tcp_mux.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ice
  4. import (
  5. "encoding/binary"
  6. "errors"
  7. "io"
  8. "net"
  9. "strings"
  10. "sync"
  11. "github.com/pion/logging"
  12. "github.com/pion/stun"
  13. )
  14. // ErrGetTransportAddress can't convert net.Addr to underlying type (UDPAddr or TCPAddr).
  15. var ErrGetTransportAddress = errors.New("failed to get local transport address")
  16. // TCPMux is allows grouping multiple TCP net.Conns and using them like UDP
  17. // net.PacketConns. The main implementation of this is TCPMuxDefault, and this
  18. // interface exists to allow mocking in tests.
  19. type TCPMux interface {
  20. io.Closer
  21. GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)
  22. RemoveConnByUfrag(ufrag string)
  23. }
  24. type ipAddr string
  25. // TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by
  26. // Ufrag. It is a default implementation of TCPMux interface.
  27. type TCPMuxDefault struct {
  28. params *TCPMuxParams
  29. closed bool
  30. // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag and local address
  31. connsIPv4, connsIPv6 map[string]map[ipAddr]*tcpPacketConn
  32. mu sync.Mutex
  33. wg sync.WaitGroup
  34. }
  35. // TCPMuxParams are parameters for TCPMux.
  36. type TCPMuxParams struct {
  37. Listener net.Listener
  38. Logger logging.LeveledLogger
  39. ReadBufferSize int
  40. // Maximum buffer size for write op. 0 means no write buffer, the write op will block until the whole packet is written
  41. // if the write buffer is full, the subsequent write packet will be dropped until it has enough space.
  42. // a default 4MB is recommended.
  43. WriteBufferSize int
  44. }
  45. // NewTCPMuxDefault creates a new instance of TCPMuxDefault.
  46. func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
  47. if params.Logger == nil {
  48. params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
  49. }
  50. m := &TCPMuxDefault{
  51. params: &params,
  52. connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
  53. connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
  54. }
  55. m.wg.Add(1)
  56. go func() {
  57. defer m.wg.Done()
  58. m.start()
  59. }()
  60. return m
  61. }
  62. func (m *TCPMuxDefault) start() {
  63. m.params.Logger.Infof("Listening TCP on %s", m.params.Listener.Addr())
  64. for {
  65. conn, err := m.params.Listener.Accept()
  66. if err != nil {
  67. m.params.Logger.Infof("Error accepting connection: %s", err)
  68. return
  69. }
  70. m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  71. m.wg.Add(1)
  72. go func() {
  73. defer m.wg.Done()
  74. m.handleConn(conn)
  75. }()
  76. }
  77. }
  78. // LocalAddr returns the listening address of this TCPMuxDefault.
  79. func (m *TCPMuxDefault) LocalAddr() net.Addr {
  80. return m.params.Listener.Addr()
  81. }
  82. // GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
  83. func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) {
  84. m.mu.Lock()
  85. defer m.mu.Unlock()
  86. if m.closed {
  87. return nil, io.ErrClosedPipe
  88. }
  89. if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
  90. return conn, nil
  91. }
  92. return m.createConn(ufrag, isIPv6, local)
  93. }
  94. func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP) (*tcpPacketConn, error) {
  95. addr, ok := m.LocalAddr().(*net.TCPAddr)
  96. if !ok {
  97. return nil, ErrGetTransportAddress
  98. }
  99. localAddr := *addr
  100. localAddr.IP = local
  101. conn := newTCPPacketConn(tcpPacketParams{
  102. ReadBuffer: m.params.ReadBufferSize,
  103. WriteBuffer: m.params.WriteBufferSize,
  104. LocalAddr: &localAddr,
  105. Logger: m.params.Logger,
  106. })
  107. var conns map[ipAddr]*tcpPacketConn
  108. if isIPv6 {
  109. if conns, ok = m.connsIPv6[ufrag]; !ok {
  110. conns = make(map[ipAddr]*tcpPacketConn)
  111. m.connsIPv6[ufrag] = conns
  112. }
  113. } else {
  114. if conns, ok = m.connsIPv4[ufrag]; !ok {
  115. conns = make(map[ipAddr]*tcpPacketConn)
  116. m.connsIPv4[ufrag] = conns
  117. }
  118. }
  119. conns[ipAddr(local.String())] = conn
  120. m.wg.Add(1)
  121. go func() {
  122. defer m.wg.Done()
  123. <-conn.CloseChannel()
  124. m.removeConnByUfragAndLocalHost(ufrag, local)
  125. }()
  126. return conn, nil
  127. }
  128. func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
  129. err := closer.Close()
  130. if err != nil {
  131. m.params.Logger.Warnf("Error closing connection: %s", err)
  132. }
  133. }
  134. func (m *TCPMuxDefault) handleConn(conn net.Conn) {
  135. buf := make([]byte, receiveMTU)
  136. n, err := readStreamingPacket(conn, buf)
  137. if err != nil {
  138. m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err)
  139. return
  140. }
  141. buf = buf[:n]
  142. msg := &stun.Message{
  143. Raw: make([]byte, len(buf)),
  144. }
  145. // Explicitly copy raw buffer so Message can own the memory.
  146. copy(msg.Raw, buf)
  147. if err = msg.Decode(); err != nil {
  148. m.closeAndLogError(conn)
  149. m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
  150. return
  151. }
  152. if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN
  153. m.closeAndLogError(conn)
  154. m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  155. return
  156. }
  157. for _, attr := range msg.Attributes {
  158. m.params.Logger.Debugf("Message attribute: %s", attr.String())
  159. }
  160. attr, err := msg.Get(stun.AttrUsername)
  161. if err != nil {
  162. m.closeAndLogError(conn)
  163. m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  164. return
  165. }
  166. ufrag := strings.Split(string(attr), ":")[0]
  167. m.params.Logger.Debugf("Ufrag: %s", ufrag)
  168. m.mu.Lock()
  169. defer m.mu.Unlock()
  170. host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
  171. if err != nil {
  172. m.closeAndLogError(conn)
  173. m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  174. return
  175. }
  176. isIPv6 := net.ParseIP(host).To4() == nil
  177. localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
  178. if !ok {
  179. m.closeAndLogError(conn)
  180. m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  181. return
  182. }
  183. packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP)
  184. if !ok {
  185. packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP)
  186. if err != nil {
  187. m.closeAndLogError(conn)
  188. m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  189. return
  190. }
  191. }
  192. if err := packetConn.AddConn(conn, buf); err != nil {
  193. m.closeAndLogError(conn)
  194. m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
  195. return
  196. }
  197. }
  198. // Close closes the listener and waits for all goroutines to exit.
  199. func (m *TCPMuxDefault) Close() error {
  200. m.mu.Lock()
  201. m.closed = true
  202. for _, conns := range m.connsIPv4 {
  203. for _, conn := range conns {
  204. m.closeAndLogError(conn)
  205. }
  206. }
  207. for _, conns := range m.connsIPv6 {
  208. for _, conn := range conns {
  209. m.closeAndLogError(conn)
  210. }
  211. }
  212. m.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{}
  213. m.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{}
  214. err := m.params.Listener.Close()
  215. m.mu.Unlock()
  216. m.wg.Wait()
  217. return err
  218. }
  219. // RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag.
  220. func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
  221. removedConns := make([]*tcpPacketConn, 0, 4)
  222. // Keep lock section small to avoid deadlock with conn lock
  223. m.mu.Lock()
  224. if conns, ok := m.connsIPv4[ufrag]; ok {
  225. delete(m.connsIPv4, ufrag)
  226. for _, conn := range conns {
  227. removedConns = append(removedConns, conn)
  228. }
  229. }
  230. if conns, ok := m.connsIPv6[ufrag]; ok {
  231. delete(m.connsIPv6, ufrag)
  232. for _, conn := range conns {
  233. removedConns = append(removedConns, conn)
  234. }
  235. }
  236. m.mu.Unlock()
  237. // Close the connections outside the critical section to avoid
  238. // deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
  239. for _, conn := range removedConns {
  240. m.closeAndLogError(conn)
  241. }
  242. }
  243. func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) {
  244. removedConns := make([]*tcpPacketConn, 0, 4)
  245. localIP := ipAddr(local.String())
  246. // Keep lock section small to avoid deadlock with conn lock
  247. m.mu.Lock()
  248. if conns, ok := m.connsIPv4[ufrag]; ok {
  249. if conn, ok := conns[localIP]; ok {
  250. delete(conns, localIP)
  251. if len(conns) == 0 {
  252. delete(m.connsIPv4, ufrag)
  253. }
  254. removedConns = append(removedConns, conn)
  255. }
  256. }
  257. if conns, ok := m.connsIPv6[ufrag]; ok {
  258. if conn, ok := conns[localIP]; ok {
  259. delete(conns, localIP)
  260. if len(conns) == 0 {
  261. delete(m.connsIPv6, ufrag)
  262. }
  263. removedConns = append(removedConns, conn)
  264. }
  265. }
  266. m.mu.Unlock()
  267. // Close the connections outside the critical section to avoid
  268. // deadlocking TCP mux if (*tcpPacketConn).Close() blocks.
  269. for _, conn := range removedConns {
  270. m.closeAndLogError(conn)
  271. }
  272. }
  273. func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) {
  274. var conns map[ipAddr]*tcpPacketConn
  275. if isIPv6 {
  276. conns, ok = m.connsIPv6[ufrag]
  277. } else {
  278. conns, ok = m.connsIPv4[ufrag]
  279. }
  280. if conns != nil {
  281. val, ok = conns[ipAddr(local.String())]
  282. }
  283. return
  284. }
  285. const streamingPacketHeaderLen = 2
  286. // readStreamingPacket reads 1 packet from stream
  287. // read packet bytes https://tools.ietf.org/html/rfc4571#section-2
  288. // 2-byte length header prepends each packet:
  289. //
  290. // 0 1 2 3
  291. // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  292. // -----------------------------------------------------------------
  293. // | LENGTH | RTP or RTCP packet ... |
  294. // -----------------------------------------------------------------
  295. func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
  296. header := make([]byte, streamingPacketHeaderLen)
  297. var bytesRead, n int
  298. var err error
  299. for bytesRead < streamingPacketHeaderLen {
  300. if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil {
  301. return 0, err
  302. }
  303. bytesRead += n
  304. }
  305. length := int(binary.BigEndian.Uint16(header))
  306. if length > cap(buf) {
  307. return length, io.ErrShortBuffer
  308. }
  309. bytesRead = 0
  310. for bytesRead < length {
  311. if n, err = conn.Read(buf[bytesRead:length]); err != nil {
  312. return 0, err
  313. }
  314. bytesRead += n
  315. }
  316. return bytesRead, nil
  317. }
  318. func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
  319. bufCopy := make([]byte, streamingPacketHeaderLen+len(buf))
  320. binary.BigEndian.PutUint16(bufCopy, uint16(len(buf)))
  321. copy(bufCopy[2:], buf)
  322. n, err := conn.Write(bufCopy)
  323. if err != nil {
  324. return 0, err
  325. }
  326. return n - streamingPacketHeaderLen, nil
  327. }