mux_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package mux
  4. import (
  5. "io"
  6. "net"
  7. "testing"
  8. "time"
  9. "github.com/pion/logging"
  10. "github.com/pion/transport/v2/packetio"
  11. "github.com/pion/transport/v2/test"
  12. "github.com/stretchr/testify/require"
  13. )
  14. const testPipeBufferSize = 8192
  15. func TestNoEndpoints(t *testing.T) {
  16. // In memory pipe
  17. ca, cb := net.Pipe()
  18. require.NoError(t, cb.Close())
  19. m := NewMux(Config{
  20. Conn: ca,
  21. BufferSize: testPipeBufferSize,
  22. LoggerFactory: logging.NewDefaultLoggerFactory(),
  23. })
  24. require.NoError(t, m.dispatch(make([]byte, 1)))
  25. require.NoError(t, m.Close())
  26. require.NoError(t, ca.Close())
  27. }
  28. type muxErrorConnReadResult struct {
  29. err error
  30. data []byte
  31. }
  32. // muxErrorConn
  33. type muxErrorConn struct {
  34. net.Conn
  35. readResults []muxErrorConnReadResult
  36. }
  37. func (m *muxErrorConn) Read(b []byte) (n int, err error) {
  38. err = m.readResults[0].err
  39. copy(b, m.readResults[0].data)
  40. n = len(m.readResults[0].data)
  41. m.readResults = m.readResults[1:]
  42. return
  43. }
  44. /*
  45. Don't end the mux readLoop for packetio.ErrTimeout or io.ErrShortBuffer, assert the following
  46. - io.ErrShortBuffer and packetio.ErrTimeout don't end the read loop
  47. - io.EOF ends the loop
  48. pion/webrtc#1720
  49. */
  50. func TestNonFatalRead(t *testing.T) {
  51. // Limit runtime in case of deadlocks
  52. lim := test.TimeOut(time.Second * 20)
  53. defer lim.Stop()
  54. expectedData := []byte("expectedData")
  55. // In memory pipe
  56. ca, cb := net.Pipe()
  57. require.NoError(t, cb.Close())
  58. conn := &muxErrorConn{ca, []muxErrorConnReadResult{
  59. // Non-fatal timeout error
  60. {packetio.ErrTimeout, nil},
  61. {nil, expectedData},
  62. {io.ErrShortBuffer, nil},
  63. {nil, expectedData},
  64. {io.EOF, nil},
  65. }}
  66. m := NewMux(Config{
  67. Conn: conn,
  68. BufferSize: testPipeBufferSize,
  69. LoggerFactory: logging.NewDefaultLoggerFactory(),
  70. })
  71. e := m.NewEndpoint(MatchAll)
  72. buff := make([]byte, testPipeBufferSize)
  73. n, err := e.Read(buff)
  74. require.NoError(t, err)
  75. require.Equal(t, buff[:n], expectedData)
  76. n, err = e.Read(buff)
  77. require.NoError(t, err)
  78. require.Equal(t, buff[:n], expectedData)
  79. <-m.closedCh
  80. require.NoError(t, m.Close())
  81. require.NoError(t, ca.Close())
  82. }
  83. // If a endpoint returns packetio.ErrFull it is a non-fatal error and shouldn't cause
  84. // the mux to be destroyed
  85. // pion/webrtc#2180
  86. func TestNonFatalDispatch(t *testing.T) {
  87. in, out := net.Pipe()
  88. m := NewMux(Config{
  89. Conn: out,
  90. LoggerFactory: logging.NewDefaultLoggerFactory(),
  91. BufferSize: 1500,
  92. })
  93. e := m.NewEndpoint(MatchSRTP)
  94. e.buffer.SetLimitSize(1)
  95. for i := 0; i <= 25; i++ {
  96. srtpPacket := []byte{128, 1, 2, 3, 4}
  97. _, err := in.Write(srtpPacket)
  98. require.NoError(t, err)
  99. }
  100. require.NoError(t, m.Close())
  101. require.NoError(t, in.Close())
  102. require.NoError(t, out.Close())
  103. }
  104. func BenchmarkDispatch(b *testing.B) {
  105. m := &Mux{
  106. endpoints: make(map[*Endpoint]MatchFunc),
  107. log: logging.NewDefaultLoggerFactory().NewLogger("mux"),
  108. }
  109. e := m.NewEndpoint(MatchSRTP)
  110. m.NewEndpoint(MatchSRTCP)
  111. buf := []byte{128, 1, 2, 3, 4}
  112. buf2 := make([]byte, 1200)
  113. b.StartTimer()
  114. for i := 0; i < b.N; i++ {
  115. err := m.dispatch(buf)
  116. if err != nil {
  117. b.Errorf("dispatch: %v", err)
  118. }
  119. _, err = e.buffer.Read(buf2)
  120. if err != nil {
  121. b.Errorf("read: %v", err)
  122. }
  123. }
  124. }