selection_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package ice
  6. import (
  7. "bytes"
  8. "context"
  9. "errors"
  10. "io"
  11. "net"
  12. "sync/atomic"
  13. "testing"
  14. "time"
  15. "github.com/pion/stun"
  16. "github.com/pion/transport/v3/test"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/stretchr/testify/require"
  19. )
  20. func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool {
  21. testMessage := []byte("Hello World")
  22. testBuffer := make([]byte, len(testMessage))
  23. readDone, readDoneCancel := context.WithCancel(context.Background())
  24. go func() {
  25. _, err := readingConn.Read(testBuffer)
  26. if errors.Is(err, io.EOF) {
  27. return
  28. }
  29. require.NoError(t, err)
  30. require.True(t, bytes.Equal(testMessage, testBuffer))
  31. readDoneCancel()
  32. }()
  33. attempts := 0
  34. for {
  35. select {
  36. case <-time.After(5 * time.Millisecond):
  37. if attempts > maxAttempts {
  38. return false
  39. }
  40. _, err := writingConn.Write(testMessage)
  41. require.NoError(t, err)
  42. attempts++
  43. case <-readDone.Done():
  44. return true
  45. }
  46. }
  47. }
  48. func TestBindingRequestHandler(t *testing.T) {
  49. defer test.CheckRoutines(t)()
  50. defer test.TimeOut(time.Second * 30).Stop()
  51. var switchToNewCandidatePair, controlledLoggingFired atomic.Value
  52. oneHour := time.Hour
  53. keepaliveInterval := time.Millisecond * 20
  54. aNotifier, aConnected := onConnected()
  55. bNotifier, bConnected := onConnected()
  56. controllingAgent, err := NewAgent(&AgentConfig{
  57. NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
  58. MulticastDNSMode: MulticastDNSModeDisabled,
  59. KeepaliveInterval: &keepaliveInterval,
  60. CheckInterval: &oneHour,
  61. BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
  62. controlledLoggingFired.Store(true)
  63. return false
  64. },
  65. })
  66. require.NoError(t, err)
  67. require.NoError(t, controllingAgent.OnConnectionStateChange(aNotifier))
  68. controlledAgent, err := NewAgent(&AgentConfig{
  69. NetworkTypes: []NetworkType{NetworkTypeUDP4},
  70. MulticastDNSMode: MulticastDNSModeDisabled,
  71. KeepaliveInterval: &keepaliveInterval,
  72. CheckInterval: &oneHour,
  73. BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
  74. // Don't switch candidate pair until we are ready
  75. val, ok := switchToNewCandidatePair.Load().(bool)
  76. return ok && val
  77. },
  78. })
  79. require.NoError(t, err)
  80. require.NoError(t, controlledAgent.OnConnectionStateChange(bNotifier))
  81. controlledConn, controllingConn := connect(controlledAgent, controllingAgent)
  82. <-aConnected
  83. <-bConnected
  84. // Assert we have connected and can send data
  85. require.True(t, sendUntilDone(t, controlledConn, controllingConn, 100))
  86. // Take the lock on the controlling Agent and unset state
  87. assert.NoError(t, controlledAgent.run(controlledAgent.context(), func(_ context.Context, controlledAgent *Agent) {
  88. for net, cs := range controlledAgent.remoteCandidates {
  89. for _, c := range cs {
  90. require.NoError(t, c.close())
  91. }
  92. delete(controlledAgent.remoteCandidates, net)
  93. }
  94. for _, c := range controlledAgent.localCandidates[NetworkTypeUDP4] {
  95. cast, ok := c.(*CandidateHost)
  96. require.True(t, ok)
  97. cast.remoteCandidateCaches = map[AddrPort]Candidate{}
  98. }
  99. controlledAgent.setSelectedPair(nil)
  100. controlledAgent.checklist = make([]*CandidatePair, 0)
  101. }))
  102. // Assert that Selected Candidate pair has only been unset on Controlled side
  103. candidatePair, err := controlledAgent.GetSelectedCandidatePair()
  104. assert.Nil(t, candidatePair)
  105. assert.NoError(t, err)
  106. candidatePair, err = controllingAgent.GetSelectedCandidatePair()
  107. assert.NotNil(t, candidatePair)
  108. assert.NoError(t, err)
  109. // Sending will fail, we no longer have a selected candidate pair
  110. require.False(t, sendUntilDone(t, controlledConn, controllingConn, 20))
  111. // Send STUN Binding requests until a new Selected Candidate Pair has been set by BindingRequestHandler
  112. switchToNewCandidatePair.Store(true)
  113. for {
  114. controllingAgent.requestConnectivityCheck()
  115. candidatePair, err = controlledAgent.GetSelectedCandidatePair()
  116. require.NoError(t, err)
  117. if candidatePair != nil {
  118. break
  119. }
  120. time.Sleep(time.Millisecond * 5)
  121. }
  122. // We have a new selected candidate pair because of BindingRequestHandler, test that it works
  123. require.True(t, sendUntilDone(t, controllingConn, controlledConn, 100))
  124. fired, ok := controlledLoggingFired.Load().(bool)
  125. require.True(t, ok)
  126. require.True(t, fired)
  127. closePipe(t, controllingConn, controlledConn)
  128. }