udp_mux_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package ice
  6. import (
  7. "crypto/rand"
  8. "crypto/sha1" //nolint:gosec
  9. "encoding/binary"
  10. "net"
  11. "sync"
  12. "testing"
  13. "time"
  14. "github.com/pion/stun"
  15. "github.com/pion/transport/v2/test"
  16. "github.com/stretchr/testify/require"
  17. )
  18. func TestUDPMux(t *testing.T) {
  19. report := test.CheckRoutines(t)
  20. defer report()
  21. lim := test.TimeOut(time.Second * 30)
  22. defer lim.Stop()
  23. conn4, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
  24. require.NoError(t, err)
  25. conn6, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
  26. if err != nil {
  27. t.Log("IPv6 is not supported on this machine")
  28. }
  29. connUnspecified, err := net.ListenUDP(udp, nil)
  30. require.NoError(t, err)
  31. conn4Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero})
  32. require.NoError(t, err)
  33. conn6Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6unspecified})
  34. if err != nil {
  35. t.Log("IPv6 is not supported on this machine")
  36. }
  37. type testCase struct {
  38. name string
  39. conn net.PacketConn
  40. network string
  41. }
  42. for _, subTest := range []testCase{
  43. {name: "IPv4loopback", conn: conn4, network: udp4},
  44. {name: "IPv6loopback", conn: conn6, network: udp6},
  45. {name: "Unspecified", conn: connUnspecified, network: udp},
  46. {name: "IPv4Unspecified", conn: conn4Unspecified, network: udp4},
  47. {name: "IPv6Unspecified", conn: conn6Unspecified, network: udp6},
  48. } {
  49. network, conn := subTest.network, subTest.conn
  50. if udpConn, ok := conn.(*net.UDPConn); !ok || udpConn == nil {
  51. continue
  52. }
  53. t.Run(subTest.name, func(t *testing.T) {
  54. udpMux := NewUDPMuxDefault(UDPMuxParams{
  55. Logger: nil,
  56. UDPConn: conn,
  57. })
  58. defer func() {
  59. _ = udpMux.Close()
  60. _ = conn.Close()
  61. }()
  62. require.NotNil(t, udpMux.LocalAddr(), "udpMux.LocalAddr() is nil")
  63. wg := sync.WaitGroup{}
  64. wg.Add(1)
  65. go func() {
  66. defer wg.Done()
  67. testMuxConnection(t, udpMux, "ufrag1", udp)
  68. }()
  69. const ptrSize = 32 << (^uintptr(0) >> 63)
  70. if network == udp {
  71. wg.Add(1)
  72. go func() {
  73. defer wg.Done()
  74. testMuxConnection(t, udpMux, "ufrag2", udp4)
  75. }()
  76. // Skip IPv6 test on i386
  77. if ptrSize != 32 {
  78. testMuxConnection(t, udpMux, "ufrag3", udp6)
  79. }
  80. } else if ptrSize != 32 || network != udp6 {
  81. testMuxConnection(t, udpMux, "ufrag2", network)
  82. }
  83. wg.Wait()
  84. require.NoError(t, udpMux.Close())
  85. // Can't create more connections
  86. _, err = udpMux.GetConn("failufrag", udpMux.LocalAddr())
  87. require.Error(t, err)
  88. })
  89. }
  90. }
  91. func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
  92. pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
  93. require.NoError(t, err, "error retrieving muxed connection for ufrag")
  94. defer func() {
  95. _ = pktConn.Close()
  96. }()
  97. addr, ok := pktConn.LocalAddr().(*net.UDPAddr)
  98. require.True(t, ok, "pktConn.LocalAddr() is not a net.UDPAddr")
  99. if addr.IP.IsUnspecified() {
  100. addr = &net.UDPAddr{Port: addr.Port}
  101. }
  102. remoteConn, err := net.DialUDP(network, nil, addr)
  103. require.NoError(t, err, "error dialing test UDP connection")
  104. testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
  105. }
  106. func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
  107. // Initial messages are dropped
  108. _, err := remoteConn.Write([]byte("dropped bytes"))
  109. require.NoError(t, err)
  110. // Wait for packet to be consumed
  111. time.Sleep(time.Millisecond)
  112. // Write out to establish connection
  113. msg := stun.New()
  114. msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
  115. msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag"))
  116. msg.Encode()
  117. _, err = pktConn.WriteTo(msg.Raw, remoteConn.LocalAddr())
  118. require.NoError(t, err)
  119. // Ensure received
  120. buf := make([]byte, receiveMTU)
  121. n, err := remoteConn.Read(buf)
  122. require.NoError(t, err)
  123. require.Equal(t, msg.Raw, buf[:n])
  124. // Start writing packets through mux
  125. targetSize := 1 * 1024 * 1024
  126. readDone := make(chan struct{}, 1)
  127. remoteReadDone := make(chan struct{}, 1)
  128. // Read packets from the muxed side
  129. go func() {
  130. defer func() {
  131. t.Logf("closing read chan for: %s", ufrag)
  132. close(readDone)
  133. }()
  134. readBuf := make([]byte, receiveMTU)
  135. nextSeq := uint32(0)
  136. for read := 0; read < targetSize; {
  137. n, _, err := pktConn.ReadFrom(readBuf)
  138. require.NoError(t, err)
  139. require.Equal(t, receiveMTU, n)
  140. verifyPacket(t, readBuf[:n], nextSeq)
  141. // Write it back to sender
  142. _, err = pktConn.WriteTo(readBuf[:n], remoteConn.LocalAddr())
  143. require.NoError(t, err)
  144. read += n
  145. nextSeq++
  146. }
  147. }()
  148. go func() {
  149. defer func() {
  150. close(remoteReadDone)
  151. }()
  152. readBuf := make([]byte, receiveMTU)
  153. nextSeq := uint32(0)
  154. for read := 0; read < targetSize; {
  155. n, _, err := remoteConn.ReadFrom(readBuf)
  156. require.NoError(t, err)
  157. require.Equal(t, receiveMTU, n)
  158. verifyPacket(t, readBuf[:n], nextSeq)
  159. read += n
  160. nextSeq++
  161. }
  162. }()
  163. sequence := 0
  164. for written := 0; written < targetSize; {
  165. buf := make([]byte, receiveMTU)
  166. // Byte 0-4: sequence
  167. // Bytes 4-24: sha1 checksum
  168. // Bytes2 4-mtu: random data
  169. _, err := rand.Read(buf[24:])
  170. require.NoError(t, err)
  171. h := sha1.Sum(buf[24:]) //nolint:gosec
  172. copy(buf[4:24], h[:])
  173. binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence))
  174. _, err = remoteConn.Write(buf)
  175. require.NoError(t, err)
  176. written += len(buf)
  177. sequence++
  178. time.Sleep(time.Millisecond)
  179. }
  180. <-readDone
  181. <-remoteReadDone
  182. }
  183. func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
  184. readSeq := binary.LittleEndian.Uint32(b[0:4])
  185. require.Equal(t, nextSeq, readSeq)
  186. h := sha1.Sum(b[24:]) //nolint:gosec
  187. require.Equal(t, h[:], b[4:24])
  188. }
  189. func TestUDPMux_Agent_Restart(t *testing.T) {
  190. oneSecond := time.Second
  191. connA, connB := pipe(&AgentConfig{
  192. DisconnectedTimeout: &oneSecond,
  193. FailedTimeout: &oneSecond,
  194. })
  195. aNotifier, aConnected := onConnected()
  196. require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
  197. bNotifier, bConnected := onConnected()
  198. require.NoError(t, connB.agent.OnConnectionStateChange(bNotifier))
  199. // Maintain Credentials across restarts
  200. ufragA, pwdA, err := connA.agent.GetLocalUserCredentials()
  201. require.NoError(t, err)
  202. ufragB, pwdB, err := connB.agent.GetLocalUserCredentials()
  203. require.NoError(t, err)
  204. require.NoError(t, err)
  205. // Restart and Re-Signal
  206. require.NoError(t, connA.agent.Restart(ufragA, pwdA))
  207. require.NoError(t, connB.agent.Restart(ufragB, pwdB))
  208. require.NoError(t, connA.agent.SetRemoteCredentials(ufragB, pwdB))
  209. require.NoError(t, connB.agent.SetRemoteCredentials(ufragA, pwdA))
  210. gatherAndExchangeCandidates(connA.agent, connB.agent)
  211. // Wait until both have gone back to connected
  212. <-aConnected
  213. <-bConnected
  214. require.NoError(t, connA.agent.Close())
  215. require.NoError(t, connB.agent.Close())
  216. }