tcp_mux_multi_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package ice
  6. import (
  7. "io"
  8. "net"
  9. "testing"
  10. "github.com/pion/logging"
  11. "github.com/pion/stun"
  12. "github.com/pion/transport/v2/test"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestMultiTCPMux_Recv(t *testing.T) {
  17. for name, bufSize := range map[string]int{
  18. "no buffer": 0,
  19. "buffered 4MB": 4 * 1024 * 1024,
  20. } {
  21. bufSize := bufSize
  22. t.Run(name, func(t *testing.T) {
  23. report := test.CheckRoutines(t)
  24. defer report()
  25. loggerFactory := logging.NewDefaultLoggerFactory()
  26. var muxInstances []TCPMux
  27. for i := 0; i < 3; i++ {
  28. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  29. IP: net.IP{127, 0, 0, 1},
  30. Port: 0,
  31. })
  32. require.NoError(t, err, "error starting listener")
  33. defer func() {
  34. _ = listener.Close()
  35. }()
  36. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  37. Listener: listener,
  38. Logger: loggerFactory.NewLogger("ice"),
  39. ReadBufferSize: 20,
  40. WriteBufferSize: bufSize,
  41. })
  42. muxInstances = append(muxInstances, tcpMux)
  43. require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
  44. }
  45. multiMux := NewMultiTCPMuxDefault(muxInstances...)
  46. defer func() {
  47. _ = multiMux.Close()
  48. }()
  49. pktConns, err := multiMux.GetAllConns("myufrag", false, net.IP{127, 0, 0, 1})
  50. require.NoError(t, err, "error retrieving muxed connection for ufrag")
  51. for _, pktConn := range pktConns {
  52. defer func() {
  53. _ = pktConn.Close()
  54. }()
  55. conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr))
  56. require.NoError(t, err, "error dialing test TCP connection")
  57. msg := stun.New()
  58. msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
  59. msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
  60. msg.Encode()
  61. n, err := writeStreamingPacket(conn, msg.Raw)
  62. require.NoError(t, err, "error writing TCP STUN packet")
  63. recv := make([]byte, n)
  64. n2, rAddr, err := pktConn.ReadFrom(recv)
  65. require.NoError(t, err, "error receiving data")
  66. assert.Equal(t, conn.LocalAddr(), rAddr, "remote TCP address mismatch")
  67. assert.Equal(t, n, n2, "received byte size mismatch")
  68. assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
  69. // Check echo response
  70. n, err = pktConn.WriteTo(recv, conn.LocalAddr())
  71. require.NoError(t, err, "error writing echo STUN packet")
  72. recvEcho := make([]byte, n)
  73. n3, err := readStreamingPacket(conn, recvEcho)
  74. require.NoError(t, err, "error receiving echo data")
  75. assert.Equal(t, n2, n3, "received byte size mismatch")
  76. assert.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
  77. }
  78. })
  79. }
  80. }
  81. func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
  82. report := test.CheckRoutines(t)
  83. defer report()
  84. loggerFactory := logging.NewDefaultLoggerFactory()
  85. var tcpMuxInstances []TCPMux
  86. for i := 0; i < 3; i++ {
  87. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  88. IP: net.IP{127, 0, 0, 1},
  89. Port: 0,
  90. })
  91. require.NoError(t, err, "error starting listener")
  92. defer func() {
  93. _ = listener.Close()
  94. }()
  95. tcpMux := NewTCPMuxDefault(TCPMuxParams{
  96. Listener: listener,
  97. Logger: loggerFactory.NewLogger("ice"),
  98. ReadBufferSize: 20,
  99. })
  100. tcpMuxInstances = append(tcpMuxInstances, tcpMux)
  101. }
  102. muxMulti := NewMultiTCPMuxDefault(tcpMuxInstances...)
  103. _, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1})
  104. require.NoError(t, err, "error getting conn by ufrag")
  105. require.NoError(t, muxMulti.Close(), "error closing tcpMux")
  106. conn, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1})
  107. assert.Nil(t, conn, "should receive nil because mux is closed")
  108. assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
  109. }