udp_mux_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package ice
  6. import (
  7. "crypto/rand"
  8. "crypto/sha1" //nolint:gosec
  9. "encoding/binary"
  10. "net"
  11. "sync"
  12. "testing"
  13. "time"
  14. "github.com/pion/stun"
  15. "github.com/pion/transport/v2/test"
  16. "github.com/stretchr/testify/require"
  17. )
  18. func TestUDPMux(t *testing.T) {
  19. report := test.CheckRoutines(t)
  20. defer report()
  21. lim := test.TimeOut(time.Second * 30)
  22. defer lim.Stop()
  23. conn4, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
  24. require.NoError(t, err)
  25. conn6, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
  26. if err != nil {
  27. t.Log("IPv6 is not supported on this machine")
  28. }
  29. connUnspecified, err := net.ListenUDP(udp, nil)
  30. require.NoError(t, err)
  31. conn4Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero})
  32. require.NoError(t, err)
  33. conn6Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6unspecified})
  34. if err != nil {
  35. t.Log("IPv6 is not supported on this machine")
  36. }
  37. type testCase struct {
  38. name string
  39. conn net.PacketConn
  40. network string
  41. }
  42. for _, subTest := range []testCase{
  43. {name: "IPv4loopback", conn: conn4, network: udp4},
  44. {name: "IPv6loopback", conn: conn6, network: udp6},
  45. {name: "Unspecified", conn: connUnspecified, network: udp},
  46. {name: "IPv4Unspecified", conn: conn4Unspecified, network: udp4},
  47. {name: "IPv6Unspecified", conn: conn6Unspecified, network: udp6},
  48. } {
  49. network, conn := subTest.network, subTest.conn
  50. if udpConn, ok := conn.(*net.UDPConn); !ok || udpConn == nil {
  51. continue
  52. }
  53. t.Run(subTest.name, func(t *testing.T) {
  54. udpMux := NewUDPMuxDefault(UDPMuxParams{
  55. Logger: nil,
  56. UDPConn: conn,
  57. })
  58. defer func() {
  59. _ = udpMux.Close()
  60. _ = conn.Close()
  61. }()
  62. require.NotNil(t, udpMux.LocalAddr(), "udpMux.LocalAddr() is nil")
  63. wg := sync.WaitGroup{}
  64. wg.Add(1)
  65. go func() {
  66. defer wg.Done()
  67. testMuxConnection(t, udpMux, "ufrag1", udp)
  68. }()
  69. const ptrSize = 32 << (^uintptr(0) >> 63)
  70. if network == udp {
  71. wg.Add(1)
  72. go func() {
  73. defer wg.Done()
  74. testMuxConnection(t, udpMux, "ufrag2", udp4)
  75. }()
  76. // Skip IPv6 test on i386
  77. if ptrSize != 32 {
  78. testMuxConnection(t, udpMux, "ufrag3", udp6)
  79. }
  80. } else if ptrSize != 32 || network != udp6 {
  81. testMuxConnection(t, udpMux, "ufrag2", network)
  82. }
  83. wg.Wait()
  84. require.NoError(t, udpMux.Close())
  85. // Can't create more connections
  86. _, err = udpMux.GetConn("failufrag", udpMux.LocalAddr())
  87. require.Error(t, err)
  88. })
  89. }
  90. }
  91. func TestAddressEncoding(t *testing.T) {
  92. cases := []struct {
  93. name string
  94. addr net.UDPAddr
  95. }{
  96. {
  97. name: "empty address",
  98. },
  99. {
  100. name: "ipv4",
  101. addr: net.UDPAddr{
  102. IP: net.IPv4(244, 120, 0, 5),
  103. Port: 6000,
  104. Zone: "",
  105. },
  106. },
  107. {
  108. name: "ipv6",
  109. addr: net.UDPAddr{
  110. IP: net.IPv6loopback,
  111. Port: 2500,
  112. Zone: "zone",
  113. },
  114. },
  115. }
  116. for _, c := range cases {
  117. addr := c.addr
  118. t.Run(c.name, func(t *testing.T) {
  119. buf := make([]byte, maxAddrSize)
  120. n, err := encodeUDPAddr(&addr, buf)
  121. require.NoError(t, err)
  122. parsedAddr, err := decodeUDPAddr(buf[:n])
  123. require.NoError(t, err)
  124. require.EqualValues(t, &addr, parsedAddr)
  125. })
  126. }
  127. }
  128. func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
  129. pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
  130. require.NoError(t, err, "error retrieving muxed connection for ufrag")
  131. defer func() {
  132. _ = pktConn.Close()
  133. }()
  134. addr, ok := pktConn.LocalAddr().(*net.UDPAddr)
  135. require.True(t, ok, "pktConn.LocalAddr() is not a net.UDPAddr")
  136. if addr.IP.IsUnspecified() {
  137. addr = &net.UDPAddr{Port: addr.Port}
  138. }
  139. remoteConn, err := net.DialUDP(network, nil, addr)
  140. require.NoError(t, err, "error dialing test UDP connection")
  141. testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
  142. }
  143. func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
  144. // Initial messages are dropped
  145. _, err := remoteConn.Write([]byte("dropped bytes"))
  146. require.NoError(t, err)
  147. // Wait for packet to be consumed
  148. time.Sleep(time.Millisecond)
  149. // Write out to establish connection
  150. msg := stun.New()
  151. msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
  152. msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag"))
  153. msg.Encode()
  154. _, err = pktConn.WriteTo(msg.Raw, remoteConn.LocalAddr())
  155. require.NoError(t, err)
  156. // Ensure received
  157. buf := make([]byte, receiveMTU)
  158. n, err := remoteConn.Read(buf)
  159. require.NoError(t, err)
  160. require.Equal(t, msg.Raw, buf[:n])
  161. // Start writing packets through mux
  162. targetSize := 1 * 1024 * 1024
  163. readDone := make(chan struct{}, 1)
  164. remoteReadDone := make(chan struct{}, 1)
  165. // Read packets from the muxed side
  166. go func() {
  167. defer func() {
  168. t.Logf("closing read chan for: %s", ufrag)
  169. close(readDone)
  170. }()
  171. readBuf := make([]byte, receiveMTU)
  172. nextSeq := uint32(0)
  173. for read := 0; read < targetSize; {
  174. n, _, err := pktConn.ReadFrom(readBuf)
  175. require.NoError(t, err)
  176. require.Equal(t, receiveMTU, n)
  177. verifyPacket(t, readBuf[:n], nextSeq)
  178. // Write it back to sender
  179. _, err = pktConn.WriteTo(readBuf[:n], remoteConn.LocalAddr())
  180. require.NoError(t, err)
  181. read += n
  182. nextSeq++
  183. }
  184. }()
  185. go func() {
  186. defer func() {
  187. close(remoteReadDone)
  188. }()
  189. readBuf := make([]byte, receiveMTU)
  190. nextSeq := uint32(0)
  191. for read := 0; read < targetSize; {
  192. n, _, err := remoteConn.ReadFrom(readBuf)
  193. require.NoError(t, err)
  194. require.Equal(t, receiveMTU, n)
  195. verifyPacket(t, readBuf[:n], nextSeq)
  196. read += n
  197. nextSeq++
  198. }
  199. }()
  200. sequence := 0
  201. for written := 0; written < targetSize; {
  202. buf := make([]byte, receiveMTU)
  203. // Byte 0-4: sequence
  204. // Bytes 4-24: sha1 checksum
  205. // Bytes2 4-mtu: random data
  206. _, err := rand.Read(buf[24:])
  207. require.NoError(t, err)
  208. h := sha1.Sum(buf[24:]) //nolint:gosec
  209. copy(buf[4:24], h[:])
  210. binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence))
  211. _, err = remoteConn.Write(buf)
  212. require.NoError(t, err)
  213. written += len(buf)
  214. sequence++
  215. time.Sleep(time.Millisecond)
  216. }
  217. <-readDone
  218. <-remoteReadDone
  219. }
  220. func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
  221. readSeq := binary.LittleEndian.Uint32(b[0:4])
  222. require.Equal(t, nextSeq, readSeq)
  223. h := sha1.Sum(b[24:]) //nolint:gosec
  224. require.Equal(t, h[:], b[4:24])
  225. }
  226. func TestUDPMux_Agent_Restart(t *testing.T) {
  227. oneSecond := time.Second
  228. connA, connB := pipe(&AgentConfig{
  229. DisconnectedTimeout: &oneSecond,
  230. FailedTimeout: &oneSecond,
  231. })
  232. aNotifier, aConnected := onConnected()
  233. require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
  234. bNotifier, bConnected := onConnected()
  235. require.NoError(t, connB.agent.OnConnectionStateChange(bNotifier))
  236. // Maintain Credentials across restarts
  237. ufragA, pwdA, err := connA.agent.GetLocalUserCredentials()
  238. require.NoError(t, err)
  239. ufragB, pwdB, err := connB.agent.GetLocalUserCredentials()
  240. require.NoError(t, err)
  241. require.NoError(t, err)
  242. // Restart and Re-Signal
  243. require.NoError(t, connA.agent.Restart(ufragA, pwdA))
  244. require.NoError(t, connB.agent.Restart(ufragB, pwdB))
  245. require.NoError(t, connA.agent.SetRemoteCredentials(ufragB, pwdB))
  246. require.NoError(t, connB.agent.SetRemoteCredentials(ufragA, pwdA))
  247. gatherAndExchangeCandidates(connA.agent, connB.agent)
  248. // Wait until both have gone back to connected
  249. <-aConnected
  250. <-bConnected
  251. require.NoError(t, connA.agent.Close())
  252. require.NoError(t, connB.agent.Close())
  253. }