interceptor_test.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package webrtc
  6. //
  7. import (
  8. "context"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/pion/interceptor"
  13. mock_interceptor "github.com/pion/interceptor/pkg/mock"
  14. "github.com/pion/rtp"
  15. "github.com/pion/transport/v2/test"
  16. "github.com/pion/webrtc/v3/pkg/media"
  17. "github.com/stretchr/testify/assert"
  18. )
  19. // E2E test of the features of Interceptors
  20. // * Assert an extension can be set on an outbound packet
  21. // * Assert an extension can be read on an outbound packet
  22. // * Assert that attributes set by an interceptor are returned to the Reader
  23. func TestPeerConnection_Interceptor(t *testing.T) {
  24. to := test.TimeOut(time.Second * 20)
  25. defer to.Stop()
  26. report := test.CheckRoutines(t)
  27. defer report()
  28. createPC := func() *PeerConnection {
  29. m := &MediaEngine{}
  30. assert.NoError(t, m.RegisterDefaultCodecs())
  31. ir := &interceptor.Registry{}
  32. ir.Add(&mock_interceptor.Factory{
  33. NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
  34. return &mock_interceptor.Interceptor{
  35. BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
  36. return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
  37. // set extension on outgoing packet
  38. header.Extension = true
  39. header.ExtensionProfile = 0xBEDE
  40. assert.NoError(t, header.SetExtension(2, []byte("foo")))
  41. return writer.Write(header, payload, attributes)
  42. })
  43. },
  44. BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
  45. return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
  46. if a == nil {
  47. a = interceptor.Attributes{}
  48. }
  49. a.Set("attribute", "value")
  50. return reader.Read(b, a)
  51. })
  52. },
  53. }, nil
  54. },
  55. })
  56. pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
  57. assert.NoError(t, err)
  58. return pc
  59. }
  60. offerer := createPC()
  61. answerer := createPC()
  62. track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
  63. assert.NoError(t, err)
  64. _, err = offerer.AddTrack(track)
  65. assert.NoError(t, err)
  66. seenRTP, seenRTPCancel := context.WithCancel(context.Background())
  67. answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
  68. p, attributes, readErr := track.ReadRTP()
  69. assert.NoError(t, readErr)
  70. assert.Equal(t, p.Extension, true)
  71. assert.Equal(t, "foo", string(p.GetExtension(2)))
  72. assert.Equal(t, "value", attributes.Get("attribute"))
  73. seenRTPCancel()
  74. })
  75. assert.NoError(t, signalPair(offerer, answerer))
  76. func() {
  77. ticker := time.NewTicker(time.Millisecond * 20)
  78. for {
  79. select {
  80. case <-seenRTP.Done():
  81. return
  82. case <-ticker.C:
  83. assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
  84. }
  85. }
  86. }()
  87. closePairNow(t, offerer, answerer)
  88. }
  89. func Test_Interceptor_BindUnbind(t *testing.T) {
  90. lim := test.TimeOut(time.Second * 10)
  91. defer lim.Stop()
  92. report := test.CheckRoutines(t)
  93. defer report()
  94. m := &MediaEngine{}
  95. assert.NoError(t, m.RegisterDefaultCodecs())
  96. var (
  97. cntBindRTCPReader uint32
  98. cntBindRTCPWriter uint32
  99. cntBindLocalStream uint32
  100. cntUnbindLocalStream uint32
  101. cntBindRemoteStream uint32
  102. cntUnbindRemoteStream uint32
  103. cntClose uint32
  104. )
  105. mockInterceptor := &mock_interceptor.Interceptor{
  106. BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader {
  107. atomic.AddUint32(&cntBindRTCPReader, 1)
  108. return reader
  109. },
  110. BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
  111. atomic.AddUint32(&cntBindRTCPWriter, 1)
  112. return writer
  113. },
  114. BindLocalStreamFn: func(i *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
  115. atomic.AddUint32(&cntBindLocalStream, 1)
  116. return writer
  117. },
  118. UnbindLocalStreamFn: func(i *interceptor.StreamInfo) {
  119. atomic.AddUint32(&cntUnbindLocalStream, 1)
  120. },
  121. BindRemoteStreamFn: func(i *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
  122. atomic.AddUint32(&cntBindRemoteStream, 1)
  123. return reader
  124. },
  125. UnbindRemoteStreamFn: func(i *interceptor.StreamInfo) {
  126. atomic.AddUint32(&cntUnbindRemoteStream, 1)
  127. },
  128. CloseFn: func() error {
  129. atomic.AddUint32(&cntClose, 1)
  130. return nil
  131. },
  132. }
  133. ir := &interceptor.Registry{}
  134. ir.Add(&mock_interceptor.Factory{
  135. NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil },
  136. })
  137. sender, receiver, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{})
  138. assert.NoError(t, err)
  139. track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
  140. assert.NoError(t, err)
  141. _, err = sender.AddTrack(track)
  142. assert.NoError(t, err)
  143. receiverReady, receiverReadyFn := context.WithCancel(context.Background())
  144. receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
  145. _, _, readErr := track.ReadRTP()
  146. assert.NoError(t, readErr)
  147. receiverReadyFn()
  148. })
  149. assert.NoError(t, signalPair(sender, receiver))
  150. ticker := time.NewTicker(time.Millisecond * 20)
  151. defer ticker.Stop()
  152. func() {
  153. for {
  154. select {
  155. case <-receiverReady.Done():
  156. return
  157. case <-ticker.C:
  158. // Send packet to make receiver track actual creates RTPReceiver.
  159. assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
  160. }
  161. }
  162. }()
  163. closePairNow(t, sender, receiver)
  164. // Bind/UnbindLocal/RemoteStream should be called from one side.
  165. if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 {
  166. t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt)
  167. }
  168. if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 {
  169. t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt)
  170. }
  171. if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 1 {
  172. t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt)
  173. }
  174. if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 1 {
  175. t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt)
  176. }
  177. // BindRTCPWriter/Reader and Close should be called from both side.
  178. if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 {
  179. t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt)
  180. }
  181. if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 2 {
  182. t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt)
  183. }
  184. if cnt := atomic.LoadUint32(&cntClose); cnt != 2 {
  185. t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt)
  186. }
  187. }
  188. func Test_InterceptorRegistry_Build(t *testing.T) {
  189. registryBuildCount := 0
  190. ir := &interceptor.Registry{}
  191. ir.Add(&mock_interceptor.Factory{
  192. NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
  193. registryBuildCount++
  194. return &interceptor.NoOp{}, nil
  195. },
  196. })
  197. peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
  198. assert.NoError(t, err)
  199. peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
  200. assert.NoError(t, err)
  201. assert.Equal(t, 2, registryBuildCount)
  202. closePairNow(t, peerConnectionA, peerConnectionB)
  203. }