transport_test.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package ice
  6. import (
  7. "context"
  8. "net"
  9. "sync"
  10. "testing"
  11. "time"
  12. "github.com/pion/stun"
  13. "github.com/pion/transport/v2/test"
  14. )
  15. func TestStressDuplex(t *testing.T) {
  16. // Check for leaking routines
  17. report := test.CheckRoutines(t)
  18. defer report()
  19. // Limit runtime in case of deadlocks
  20. lim := test.TimeOut(time.Second * 20)
  21. defer lim.Stop()
  22. // Run the test
  23. stressDuplex(t)
  24. }
  25. func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
  26. const pollRate = 100 * time.Millisecond
  27. const margin = 20 * time.Millisecond // Allow 20msec error in time
  28. ticker := time.NewTicker(pollRate)
  29. defer func() {
  30. ticker.Stop()
  31. err := c.Close()
  32. if err != nil {
  33. t.Error(err)
  34. }
  35. }()
  36. startedAt := time.Now()
  37. for cnt := time.Duration(0); cnt <= timeout+defaultKeepaliveInterval+pollRate; cnt += pollRate {
  38. <-ticker.C
  39. var cs ConnectionState
  40. err := c.agent.run(context.Background(), func(ctx context.Context, agent *Agent) {
  41. cs = agent.connectionState
  42. })
  43. if err != nil {
  44. // We should never get here.
  45. panic(err)
  46. }
  47. if cs != ConnectionStateConnected {
  48. elapsed := time.Since(startedAt)
  49. if elapsed+margin < timeout {
  50. t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000)
  51. } else {
  52. t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000)
  53. return
  54. }
  55. }
  56. }
  57. t.Fatalf("Connection failed to time out in time. (expected timeout: %v)", timeout)
  58. }
  59. func TestTimeout(t *testing.T) {
  60. if testing.Short() {
  61. t.Skip("skipping test in short mode.")
  62. }
  63. // Check for leaking routines
  64. report := test.CheckRoutines(t)
  65. defer report()
  66. // Limit runtime in case of deadlocks
  67. lim := test.TimeOut(time.Second * 20)
  68. defer lim.Stop()
  69. t.Run("WithoutDisconnectTimeout", func(t *testing.T) {
  70. ca, cb := pipe(nil)
  71. err := cb.Close()
  72. if err != nil {
  73. // We should never get here.
  74. panic(err)
  75. }
  76. testTimeout(t, ca, defaultDisconnectedTimeout)
  77. })
  78. t.Run("WithDisconnectTimeout", func(t *testing.T) {
  79. ca, cb := pipeWithTimeout(5*time.Second, 3*time.Second)
  80. err := cb.Close()
  81. if err != nil {
  82. // We should never get here.
  83. panic(err)
  84. }
  85. testTimeout(t, ca, 5*time.Second)
  86. })
  87. }
  88. func TestReadClosed(t *testing.T) {
  89. // Check for leaking routines
  90. report := test.CheckRoutines(t)
  91. defer report()
  92. // Limit runtime in case of deadlocks
  93. lim := test.TimeOut(time.Second * 20)
  94. defer lim.Stop()
  95. ca, cb := pipe(nil)
  96. err := ca.Close()
  97. if err != nil {
  98. // We should never get here.
  99. panic(err)
  100. }
  101. err = cb.Close()
  102. if err != nil {
  103. // We should never get here.
  104. panic(err)
  105. }
  106. empty := make([]byte, 10)
  107. _, err = ca.Read(empty)
  108. if err == nil {
  109. t.Fatalf("Reading from a closed channel should return an error")
  110. }
  111. }
  112. func stressDuplex(t *testing.T) {
  113. ca, cb := pipe(nil)
  114. defer func() {
  115. err := ca.Close()
  116. if err != nil {
  117. t.Fatal(err)
  118. }
  119. err = cb.Close()
  120. if err != nil {
  121. t.Fatal(err)
  122. }
  123. }()
  124. opt := test.Options{
  125. MsgSize: 10,
  126. MsgCount: 1, // Order not reliable due to UDP & potentially multiple candidate pairs.
  127. }
  128. err := test.StressDuplex(ca, cb, opt)
  129. if err != nil {
  130. t.Fatal(err)
  131. }
  132. }
  133. func check(err error) {
  134. if err != nil {
  135. panic(err)
  136. }
  137. }
  138. func gatherAndExchangeCandidates(aAgent, bAgent *Agent) {
  139. var wg sync.WaitGroup
  140. wg.Add(2)
  141. check(aAgent.OnCandidate(func(candidate Candidate) {
  142. if candidate == nil {
  143. wg.Done()
  144. }
  145. }))
  146. check(aAgent.GatherCandidates())
  147. check(bAgent.OnCandidate(func(candidate Candidate) {
  148. if candidate == nil {
  149. wg.Done()
  150. }
  151. }))
  152. check(bAgent.GatherCandidates())
  153. wg.Wait()
  154. candidates, err := aAgent.GetLocalCandidates()
  155. check(err)
  156. for _, c := range candidates {
  157. candidateCopy, copyErr := c.copy()
  158. check(copyErr)
  159. check(bAgent.AddRemoteCandidate(candidateCopy))
  160. }
  161. candidates, err = bAgent.GetLocalCandidates()
  162. check(err)
  163. for _, c := range candidates {
  164. candidateCopy, copyErr := c.copy()
  165. check(copyErr)
  166. check(aAgent.AddRemoteCandidate(candidateCopy))
  167. }
  168. }
  169. func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
  170. gatherAndExchangeCandidates(aAgent, bAgent)
  171. accepted := make(chan struct{})
  172. var aConn *Conn
  173. go func() {
  174. var acceptErr error
  175. bUfrag, bPwd, acceptErr := bAgent.GetLocalUserCredentials()
  176. check(acceptErr)
  177. aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd)
  178. check(acceptErr)
  179. close(accepted)
  180. }()
  181. aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
  182. check(err)
  183. bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd)
  184. check(err)
  185. // Ensure accepted
  186. <-accepted
  187. return aConn, bConn
  188. }
  189. func pipe(defaultConfig *AgentConfig) (*Conn, *Conn) {
  190. var urls []*stun.URI
  191. aNotifier, aConnected := onConnected()
  192. bNotifier, bConnected := onConnected()
  193. cfg := &AgentConfig{}
  194. if defaultConfig != nil {
  195. *cfg = *defaultConfig
  196. }
  197. cfg.Urls = urls
  198. cfg.NetworkTypes = supportedNetworkTypes()
  199. aAgent, err := NewAgent(cfg)
  200. check(err)
  201. check(aAgent.OnConnectionStateChange(aNotifier))
  202. bAgent, err := NewAgent(cfg)
  203. check(err)
  204. check(bAgent.OnConnectionStateChange(bNotifier))
  205. aConn, bConn := connect(aAgent, bAgent)
  206. // Ensure pair selected
  207. // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
  208. <-aConnected
  209. <-bConnected
  210. return aConn, bConn
  211. }
  212. func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration) (*Conn, *Conn) {
  213. var urls []*stun.URI
  214. aNotifier, aConnected := onConnected()
  215. bNotifier, bConnected := onConnected()
  216. cfg := &AgentConfig{
  217. Urls: urls,
  218. DisconnectedTimeout: &disconnectTimeout,
  219. KeepaliveInterval: &iceKeepalive,
  220. NetworkTypes: supportedNetworkTypes(),
  221. }
  222. aAgent, err := NewAgent(cfg)
  223. check(err)
  224. check(aAgent.OnConnectionStateChange(aNotifier))
  225. bAgent, err := NewAgent(cfg)
  226. check(err)
  227. check(bAgent.OnConnectionStateChange(bNotifier))
  228. aConn, bConn := connect(aAgent, bAgent)
  229. // Ensure pair selected
  230. // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
  231. <-aConnected
  232. <-bConnected
  233. return aConn, bConn
  234. }
  235. func onConnected() (func(ConnectionState), chan struct{}) {
  236. done := make(chan struct{})
  237. return func(state ConnectionState) {
  238. if state == ConnectionStateConnected {
  239. close(done)
  240. }
  241. }, done
  242. }
  243. func randomPort(t testing.TB) int {
  244. t.Helper()
  245. conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
  246. if err != nil {
  247. t.Fatalf("failed to pickPort: %v", err)
  248. }
  249. defer func() {
  250. _ = conn.Close()
  251. }()
  252. switch addr := conn.LocalAddr().(type) {
  253. case *net.UDPAddr:
  254. return addr.Port
  255. default:
  256. t.Fatalf("unknown addr type %T", addr)
  257. return 0
  258. }
  259. }
  260. func TestConnStats(t *testing.T) {
  261. // Check for leaking routines
  262. report := test.CheckRoutines(t)
  263. defer report()
  264. // Limit runtime in case of deadlocks
  265. lim := test.TimeOut(time.Second * 20)
  266. defer lim.Stop()
  267. ca, cb := pipe(nil)
  268. if _, err := ca.Write(make([]byte, 10)); err != nil {
  269. t.Fatal("unexpected error trying to write")
  270. }
  271. var wg sync.WaitGroup
  272. wg.Add(1)
  273. go func() {
  274. buf := make([]byte, 10)
  275. if _, err := cb.Read(buf); err != nil {
  276. panic(errRead)
  277. }
  278. wg.Done()
  279. }()
  280. wg.Wait()
  281. if ca.BytesSent() != 10 {
  282. t.Fatal("bytes sent don't match")
  283. }
  284. if cb.BytesReceived() != 10 {
  285. t.Fatal("bytes received don't match")
  286. }
  287. err := ca.Close()
  288. if err != nil {
  289. // We should never get here.
  290. panic(err)
  291. }
  292. err = cb.Close()
  293. if err != nil {
  294. // We should never get here.
  295. panic(err)
  296. }
  297. }