tcp_mux.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ice
  4. import (
  5. "encoding/binary"
  6. "errors"
  7. "io"
  8. "net"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/pion/logging"
  13. "github.com/pion/stun"
  14. )
  15. // ErrGetTransportAddress can't convert net.Addr to underlying type (UDPAddr or TCPAddr).
  16. var ErrGetTransportAddress = errors.New("failed to get local transport address")
  17. // TCPMux is allows grouping multiple TCP net.Conns and using them like UDP
  18. // net.PacketConns. The main implementation of this is TCPMuxDefault, and this
  19. // interface exists to allow mocking in tests.
  20. type TCPMux interface {
  21. io.Closer
  22. GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)
  23. RemoveConnByUfrag(ufrag string)
  24. }
  25. type ipAddr string
  26. // TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by
  27. // Ufrag. It is a default implementation of TCPMux interface.
  28. type TCPMuxDefault struct {
  29. params *TCPMuxParams
  30. closed bool
  31. // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag and local address
  32. connsIPv4, connsIPv6 map[string]map[ipAddr]*tcpPacketConn
  33. mu sync.Mutex
  34. wg sync.WaitGroup
  35. }
  36. // TCPMuxParams are parameters for TCPMux.
  37. type TCPMuxParams struct {
  38. Listener net.Listener
  39. Logger logging.LeveledLogger
  40. ReadBufferSize int
  41. // Maximum buffer size for write op. 0 means no write buffer, the write op will block until the whole packet is written
  42. // if the write buffer is full, the subsequent write packet will be dropped until it has enough space.
  43. // a default 4MB is recommended.
  44. WriteBufferSize int
  45. // A new established connection will be removed if the first STUN binding request is not received within this timeout,
  46. // avoiding the client with bad network or attacker to create a lot of empty connections.
  47. // Default 30s timeout will be used if not set.
  48. FirstStunBindTimeout time.Duration
  49. // TCPMux will create connection from STUN binding request with an unknown username, if
  50. // the connection is not used in the timeout, it will be removed to avoid resource leak / attack.
  51. // Default 30s timeout will be used if not set.
  52. AliveDurationForConnFromStun time.Duration
  53. }
  54. // NewTCPMuxDefault creates a new instance of TCPMuxDefault.
  55. func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
  56. if params.Logger == nil {
  57. params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
  58. }
  59. if params.FirstStunBindTimeout == 0 {
  60. params.FirstStunBindTimeout = 30 * time.Second
  61. }
  62. if params.AliveDurationForConnFromStun == 0 {
  63. params.AliveDurationForConnFromStun = 30 * time.Second
  64. }
  65. m := &TCPMuxDefault{
  66. params: &params,
  67. connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
  68. connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
  69. }
  70. m.wg.Add(1)
  71. go func() {
  72. defer m.wg.Done()
  73. m.start()
  74. }()
  75. return m
  76. }
  77. func (m *TCPMuxDefault) start() {
  78. m.params.Logger.Infof("Listening TCP on %s", m.params.Listener.Addr())
  79. for {
  80. conn, err := m.params.Listener.Accept()
  81. if err != nil {
  82. m.params.Logger.Infof("Error accepting connection: %s", err)
  83. return
  84. }
  85. m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  86. m.wg.Add(1)
  87. go func() {
  88. defer m.wg.Done()
  89. m.handleConn(conn)
  90. }()
  91. }
  92. }
  93. // LocalAddr returns the listening address of this TCPMuxDefault.
  94. func (m *TCPMuxDefault) LocalAddr() net.Addr {
  95. return m.params.Listener.Addr()
  96. }
  97. // GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
  98. func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) {
  99. m.mu.Lock()
  100. defer m.mu.Unlock()
  101. if m.closed {
  102. return nil, io.ErrClosedPipe
  103. }
  104. if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
  105. conn.ClearAliveTimer()
  106. return conn, nil
  107. }
  108. return m.createConn(ufrag, isIPv6, local, false)
  109. }
  110. func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) {
  111. addr, ok := m.LocalAddr().(*net.TCPAddr)
  112. if !ok {
  113. return nil, ErrGetTransportAddress
  114. }
  115. localAddr := *addr
  116. localAddr.IP = local
  117. var alive time.Duration
  118. if fromStun {
  119. alive = m.params.AliveDurationForConnFromStun
  120. }
  121. conn := newTCPPacketConn(tcpPacketParams{
  122. ReadBuffer: m.params.ReadBufferSize,
  123. WriteBuffer: m.params.WriteBufferSize,
  124. LocalAddr: &localAddr,
  125. Logger: m.params.Logger,
  126. AliveDuration: alive,
  127. })
  128. var conns map[ipAddr]*tcpPacketConn
  129. if isIPv6 {
  130. if conns, ok = m.connsIPv6[ufrag]; !ok {
  131. conns = make(map[ipAddr]*tcpPacketConn)
  132. m.connsIPv6[ufrag] = conns
  133. }
  134. } else {
  135. if conns, ok = m.connsIPv4[ufrag]; !ok {
  136. conns = make(map[ipAddr]*tcpPacketConn)
  137. m.connsIPv4[ufrag] = conns
  138. }
  139. }
  140. conns[ipAddr(local.String())] = conn
  141. m.wg.Add(1)
  142. go func() {
  143. defer m.wg.Done()
  144. <-conn.CloseChannel()
  145. m.removeConnByUfragAndLocalHost(ufrag, local)
  146. }()
  147. return conn, nil
  148. }
  149. func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
  150. err := closer.Close()
  151. if err != nil {
  152. m.params.Logger.Warnf("Error closing connection: %s", err)
  153. }
  154. }
  155. func (m *TCPMuxDefault) handleConn(conn net.Conn) {
  156. buf := make([]byte, 512)
  157. if m.params.FirstStunBindTimeout > 0 {
  158. if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
  159. m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
  160. }
  161. }
  162. n, err := readStreamingPacket(conn, buf)
  163. if err != nil {
  164. if errors.Is(err, io.ErrShortBuffer) {
  165. m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err)
  166. } else {
  167. m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
  168. }
  169. m.closeAndLogError(conn)
  170. return
  171. }
  172. if err = conn.SetReadDeadline(time.Time{}); err != nil {
  173. m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err)
  174. }
  175. buf = buf[:n]
  176. msg := &stun.Message{
  177. Raw: make([]byte, len(buf)),
  178. }
  179. // Explicitly copy raw buffer so Message can own the memory.
  180. copy(msg.Raw, buf)
  181. if err = msg.Decode(); err != nil {
  182. m.closeAndLogError(conn)
  183. m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
  184. return
  185. }
  186. if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN
  187. m.closeAndLogError(conn)
  188. m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  189. return
  190. }
  191. for _, attr := range msg.Attributes {
  192. m.params.Logger.Debugf("Message attribute: %s", attr.String())
  193. }
  194. attr, err := msg.Get(stun.AttrUsername)
  195. if err != nil {
  196. m.closeAndLogError(conn)
  197. m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  198. return
  199. }
  200. ufrag := strings.Split(string(attr), ":")[0]
  201. m.params.Logger.Debugf("Ufrag: %s", ufrag)
  202. host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
  203. if err != nil {
  204. m.closeAndLogError(conn)
  205. m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  206. return
  207. }
  208. isIPv6 := net.ParseIP(host).To4() == nil
  209. localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
  210. if !ok {
  211. m.closeAndLogError(conn)
  212. m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  213. return
  214. }
  215. m.mu.Lock()
  216. packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP)
  217. if !ok {
  218. packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true)
  219. if err != nil {
  220. m.mu.Unlock()
  221. m.closeAndLogError(conn)
  222. m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  223. return
  224. }
  225. }
  226. m.mu.Unlock()
  227. if err := packetConn.AddConn(conn, buf); err != nil {
  228. m.closeAndLogError(conn)
  229. m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
  230. return
  231. }
  232. }
  233. // Close closes the listener and waits for all goroutines to exit.
  234. func (m *TCPMuxDefault) Close() error {
  235. m.mu.Lock()
  236. m.closed = true
  237. for _, conns := range m.connsIPv4 {
  238. for _, conn := range conns {
  239. m.closeAndLogError(conn)
  240. }
  241. }
  242. for _, conns := range m.connsIPv6 {
  243. for _, conn := range conns {
  244. m.closeAndLogError(conn)
  245. }
  246. }
  247. m.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{}
  248. m.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{}
  249. err := m.params.Listener.Close()
  250. m.mu.Unlock()
  251. m.wg.Wait()
  252. return err
  253. }
  254. // RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag.
  255. func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
  256. removedConns := make([]*tcpPacketConn, 0, 4)
  257. // Keep lock section small to avoid deadlock with conn lock
  258. m.mu.Lock()
  259. if conns, ok := m.connsIPv4[ufrag]; ok {
  260. delete(m.connsIPv4, ufrag)
  261. for _, conn := range conns {
  262. removedConns = append(removedConns, conn)
  263. }
  264. }
  265. if conns, ok := m.connsIPv6[ufrag]; ok {
  266. delete(m.connsIPv6, ufrag)
  267. for _, conn := range conns {
  268. removedConns = append(removedConns, conn)
  269. }
  270. }
  271. m.mu.Unlock()
  272. // Close the connections outside the critical section to avoid
  273. // deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
  274. for _, conn := range removedConns {
  275. m.closeAndLogError(conn)
  276. }
  277. }
  278. func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) {
  279. removedConns := make([]*tcpPacketConn, 0, 4)
  280. localIP := ipAddr(local.String())
  281. // Keep lock section small to avoid deadlock with conn lock
  282. m.mu.Lock()
  283. if conns, ok := m.connsIPv4[ufrag]; ok {
  284. if conn, ok := conns[localIP]; ok {
  285. delete(conns, localIP)
  286. if len(conns) == 0 {
  287. delete(m.connsIPv4, ufrag)
  288. }
  289. removedConns = append(removedConns, conn)
  290. }
  291. }
  292. if conns, ok := m.connsIPv6[ufrag]; ok {
  293. if conn, ok := conns[localIP]; ok {
  294. delete(conns, localIP)
  295. if len(conns) == 0 {
  296. delete(m.connsIPv6, ufrag)
  297. }
  298. removedConns = append(removedConns, conn)
  299. }
  300. }
  301. m.mu.Unlock()
  302. // Close the connections outside the critical section to avoid
  303. // deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
  304. for _, conn := range removedConns {
  305. m.closeAndLogError(conn)
  306. }
  307. }
  308. func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) {
  309. var conns map[ipAddr]*tcpPacketConn
  310. if isIPv6 {
  311. conns, ok = m.connsIPv6[ufrag]
  312. } else {
  313. conns, ok = m.connsIPv4[ufrag]
  314. }
  315. if conns != nil {
  316. val, ok = conns[ipAddr(local.String())]
  317. }
  318. return
  319. }
  320. const streamingPacketHeaderLen = 2
  321. // readStreamingPacket reads 1 packet from stream
  322. // read packet bytes https://tools.ietf.org/html/rfc4571#section-2
  323. // 2-byte length header prepends each packet:
  324. //
  325. // 0 1 2 3
  326. // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  327. // -----------------------------------------------------------------
  328. // | LENGTH | RTP or RTCP packet ... |
  329. // -----------------------------------------------------------------
  330. func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
  331. header := make([]byte, streamingPacketHeaderLen)
  332. var bytesRead, n int
  333. var err error
  334. for bytesRead < streamingPacketHeaderLen {
  335. if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil {
  336. return 0, err
  337. }
  338. bytesRead += n
  339. }
  340. length := int(binary.BigEndian.Uint16(header))
  341. if length > cap(buf) {
  342. return length, io.ErrShortBuffer
  343. }
  344. bytesRead = 0
  345. for bytesRead < length {
  346. if n, err = conn.Read(buf[bytesRead:length]); err != nil {
  347. return 0, err
  348. }
  349. bytesRead += n
  350. }
  351. return bytesRead, nil
  352. }
  353. func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
  354. bufCopy := make([]byte, streamingPacketHeaderLen+len(buf))
  355. binary.BigEndian.PutUint16(bufCopy, uint16(len(buf)))
  356. copy(bufCopy[2:], buf)
  357. n, err := conn.Write(bufCopy)
  358. if err != nil {
  359. return 0, err
  360. }
  361. return n - streamingPacketHeaderLen, nil
  362. }