tcp_mux_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ice
  4. import (
  5. "io"
  6. "net"
  7. "os"
  8. "testing"
  9. "time"
  10. "github.com/pion/logging"
  11. "github.com/pion/stun"
  12. "github.com/pion/transport/v2/test"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. var _ TCPMux = &TCPMuxDefault{}
  17. func TestTCPMux_Recv(t *testing.T) {
  18. for name, bufSize := range map[string]int{
  19. "no buffer": 0,
  20. "buffered 4MB": 4 * 1024 * 1024,
  21. } {
  22. bufSize := bufSize
  23. t.Run(name, func(t *testing.T) {
  24. report := test.CheckRoutines(t)
  25. defer report()
  26. loggerFactory := logging.NewDefaultLoggerFactory()
  27. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  28. IP: net.IP{127, 0, 0, 1},
  29. Port: 0,
  30. })
  31. require.NoError(t, err, "error starting listener")
  32. defer func() {
  33. _ = listener.Close()
  34. }()
  35. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  36. Listener: listener,
  37. Logger: loggerFactory.NewLogger("ice"),
  38. ReadBufferSize: 20,
  39. WriteBufferSize: bufSize,
  40. })
  41. defer func() {
  42. _ = tcpMux.Close()
  43. }()
  44. require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
  45. conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
  46. require.NoError(t, err, "error dialing test TCP connection")
  47. msg := stun.New()
  48. msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
  49. msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
  50. msg.Encode()
  51. n, err := writeStreamingPacket(conn, msg.Raw)
  52. require.NoError(t, err, "error writing TCP STUN packet")
  53. pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listener.Addr().(*net.TCPAddr).IP)
  54. require.NoError(t, err, "error retrieving muxed connection for ufrag")
  55. defer func() {
  56. _ = pktConn.Close()
  57. }()
  58. recv := make([]byte, n)
  59. n2, rAddr, err := pktConn.ReadFrom(recv)
  60. require.NoError(t, err, "error receiving data")
  61. assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
  62. assert.Equal(t, n, n2, "received byte size mismatch")
  63. assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
  64. // Check echo response
  65. n, err = pktConn.WriteTo(recv, conn.LocalAddr())
  66. require.NoError(t, err, "error writing echo STUN packet")
  67. recvEcho := make([]byte, n)
  68. n3, err := readStreamingPacket(conn, recvEcho)
  69. require.NoError(t, err, "error receiving echo data")
  70. assert.Equal(t, n2, n3, "received byte size mismatch")
  71. assert.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
  72. })
  73. }
  74. }
  75. func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
  76. report := test.CheckRoutines(t)
  77. defer report()
  78. loggerFactory := logging.NewDefaultLoggerFactory()
  79. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  80. IP: net.IP{127, 0, 0, 1},
  81. Port: 0,
  82. })
  83. require.NoError(t, err, "error starting listener")
  84. defer func() {
  85. _ = listener.Close()
  86. }()
  87. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  88. Listener: listener,
  89. Logger: loggerFactory.NewLogger("ice"),
  90. ReadBufferSize: 20,
  91. })
  92. defer func() {
  93. _ = tcpMux.Close()
  94. }()
  95. _, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP)
  96. require.NoError(t, err, "error getting conn by ufrag")
  97. require.NoError(t, tcpMux.Close(), "error closing tcpMux")
  98. conn, err := tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP)
  99. assert.Nil(t, conn, "should receive nil because mux is closed")
  100. assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
  101. }
  102. func TestTCPMux_FirstPacketTimeout(t *testing.T) {
  103. report := test.CheckRoutines(t)
  104. defer report()
  105. loggerFactory := logging.NewDefaultLoggerFactory()
  106. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  107. IP: net.IP{127, 0, 0, 1},
  108. Port: 0,
  109. })
  110. require.NoError(t, err, "error starting listener")
  111. defer func() {
  112. _ = listener.Close()
  113. }()
  114. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  115. Listener: listener,
  116. Logger: loggerFactory.NewLogger("ice"),
  117. ReadBufferSize: 20,
  118. FirstStunBindTimeout: time.Second,
  119. })
  120. require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
  121. conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
  122. require.NoError(t, err, "error dialing test TCP connection")
  123. defer func() {
  124. _ = conn.Close()
  125. }()
  126. // Don't send any data, the mux should close the connection after the timeout
  127. time.Sleep(1500 * time.Millisecond)
  128. require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
  129. buf := make([]byte, 1)
  130. _, err = conn.Read(buf)
  131. require.ErrorIs(t, err, io.EOF)
  132. }
  133. func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
  134. report := test.CheckRoutines(t)
  135. defer report()
  136. loggerFactory := logging.NewDefaultLoggerFactory()
  137. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  138. IP: net.IP{127, 0, 0, 1},
  139. Port: 0,
  140. })
  141. require.NoError(t, err, "error starting listener")
  142. defer func() {
  143. _ = listener.Close()
  144. }()
  145. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  146. Listener: listener,
  147. Logger: loggerFactory.NewLogger("ice"),
  148. ReadBufferSize: 20,
  149. AliveDurationForConnFromStun: time.Second,
  150. })
  151. defer func() {
  152. _ = tcpMux.Close()
  153. }()
  154. require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
  155. t.Run("close connection from stun msg after timeout", func(t *testing.T) {
  156. conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
  157. require.NoError(t, err, "error dialing test TCP connection")
  158. defer func() {
  159. _ = conn.Close()
  160. }()
  161. msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
  162. stun.NewUsername("myufrag:otherufrag"),
  163. stun.NewShortTermIntegrity("myufrag"),
  164. stun.Fingerprint,
  165. )
  166. require.NoError(t, err, "error building STUN packet")
  167. msg.Encode()
  168. _, err = writeStreamingPacket(conn, msg.Raw)
  169. require.NoError(t, err, "error writing TCP STUN packet")
  170. time.Sleep(1500 * time.Millisecond)
  171. require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
  172. buf := make([]byte, 1)
  173. _, err = conn.Read(buf)
  174. require.ErrorIs(t, err, io.EOF)
  175. })
  176. t.Run("connection keep alive if access by user", func(t *testing.T) {
  177. conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
  178. require.NoError(t, err, "error dialing test TCP connection")
  179. defer func() {
  180. _ = conn.Close()
  181. }()
  182. msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
  183. stun.NewUsername("myufrag2:otherufrag2"),
  184. stun.NewShortTermIntegrity("myufrag2"),
  185. stun.Fingerprint,
  186. )
  187. require.NoError(t, err, "error building STUN packet")
  188. msg.Encode()
  189. n, err := writeStreamingPacket(conn, msg.Raw)
  190. require.NoError(t, err, "error writing TCP STUN packet")
  191. // wait for the connection to be created
  192. time.Sleep(100 * time.Millisecond)
  193. pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP)
  194. require.NoError(t, err, "error retrieving muxed connection for ufrag")
  195. defer func() {
  196. _ = pktConn.Close()
  197. }()
  198. time.Sleep(1500 * time.Millisecond)
  199. // timeout, not closed
  200. buf := make([]byte, 1024)
  201. require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond)))
  202. _, err = conn.Read(buf)
  203. require.ErrorIs(t, err, os.ErrDeadlineExceeded)
  204. recv := make([]byte, n)
  205. n2, rAddr, err := pktConn.ReadFrom(recv)
  206. require.NoError(t, err, "error receiving data")
  207. assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
  208. assert.Equal(t, n, n2, "received byte size mismatch")
  209. assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
  210. })
  211. }