conn_go_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package dtls
  6. import (
  7. "bytes"
  8. "context"
  9. "crypto/tls"
  10. "errors"
  11. "net"
  12. "testing"
  13. "time"
  14. "github.com/pion/dtls/v2/pkg/crypto/selfsign"
  15. "github.com/pion/transport/v2/dpipe"
  16. "github.com/pion/transport/v2/test"
  17. )
  18. func TestContextConfig(t *testing.T) {
  19. // Limit runtime in case of deadlocks
  20. lim := test.TimeOut(time.Second * 20)
  21. defer lim.Stop()
  22. report := test.CheckRoutines(t)
  23. defer report()
  24. addrListen, err := net.ResolveUDPAddr("udp", "localhost:0")
  25. if err != nil {
  26. t.Fatalf("Unexpected error: %v", err)
  27. }
  28. // Dummy listener
  29. listen, err := net.ListenUDP("udp", addrListen)
  30. if err != nil {
  31. t.Fatalf("Unexpected error: %v", err)
  32. }
  33. defer func() {
  34. _ = listen.Close()
  35. }()
  36. addr, ok := listen.LocalAddr().(*net.UDPAddr)
  37. if !ok {
  38. t.Fatal("Failed to cast net.UDPAddr")
  39. }
  40. cert, err := selfsign.GenerateSelfSigned()
  41. if err != nil {
  42. t.Fatalf("Unexpected error: %v", err)
  43. }
  44. config := &Config{
  45. ConnectContextMaker: func() (context.Context, func()) {
  46. return context.WithTimeout(context.Background(), 40*time.Millisecond)
  47. },
  48. Certificates: []tls.Certificate{cert},
  49. }
  50. dials := map[string]struct {
  51. f func() (func() (net.Conn, error), func())
  52. order []byte
  53. }{
  54. "Dial": {
  55. f: func() (func() (net.Conn, error), func()) {
  56. return func() (net.Conn, error) {
  57. return Dial("udp", addr, config)
  58. }, func() {
  59. }
  60. },
  61. order: []byte{0, 1, 2},
  62. },
  63. "DialWithContext": {
  64. f: func() (func() (net.Conn, error), func()) {
  65. ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
  66. return func() (net.Conn, error) {
  67. return DialWithContext(ctx, "udp", addr, config)
  68. }, func() {
  69. cancel()
  70. }
  71. },
  72. order: []byte{0, 2, 1},
  73. },
  74. "Client": {
  75. f: func() (func() (net.Conn, error), func()) {
  76. ca, _ := dpipe.Pipe()
  77. return func() (net.Conn, error) {
  78. return Client(ca, config)
  79. }, func() {
  80. _ = ca.Close()
  81. }
  82. },
  83. order: []byte{0, 1, 2},
  84. },
  85. "ClientWithContext": {
  86. f: func() (func() (net.Conn, error), func()) {
  87. ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
  88. ca, _ := dpipe.Pipe()
  89. return func() (net.Conn, error) {
  90. return ClientWithContext(ctx, ca, config)
  91. }, func() {
  92. cancel()
  93. _ = ca.Close()
  94. }
  95. },
  96. order: []byte{0, 2, 1},
  97. },
  98. "Server": {
  99. f: func() (func() (net.Conn, error), func()) {
  100. ca, _ := dpipe.Pipe()
  101. return func() (net.Conn, error) {
  102. return Server(ca, config)
  103. }, func() {
  104. _ = ca.Close()
  105. }
  106. },
  107. order: []byte{0, 1, 2},
  108. },
  109. "ServerWithContext": {
  110. f: func() (func() (net.Conn, error), func()) {
  111. ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
  112. ca, _ := dpipe.Pipe()
  113. return func() (net.Conn, error) {
  114. return ServerWithContext(ctx, ca, config)
  115. }, func() {
  116. cancel()
  117. _ = ca.Close()
  118. }
  119. },
  120. order: []byte{0, 2, 1},
  121. },
  122. }
  123. for name, dial := range dials {
  124. dial := dial
  125. t.Run(name, func(t *testing.T) {
  126. done := make(chan struct{})
  127. go func() {
  128. d, cancel := dial.f()
  129. conn, err := d()
  130. defer cancel()
  131. var netError net.Error
  132. if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck
  133. t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
  134. close(done)
  135. return
  136. }
  137. done <- struct{}{}
  138. if err == nil {
  139. _ = conn.Close()
  140. }
  141. }()
  142. var order []byte
  143. early := time.After(20 * time.Millisecond)
  144. late := time.After(60 * time.Millisecond)
  145. func() {
  146. for len(order) < 3 {
  147. select {
  148. case <-early:
  149. order = append(order, 0)
  150. case _, ok := <-done:
  151. if !ok {
  152. return
  153. }
  154. order = append(order, 1)
  155. case <-late:
  156. order = append(order, 2)
  157. }
  158. }
  159. }()
  160. if !bytes.Equal(dial.order, order) {
  161. t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order)
  162. }
  163. })
  164. }
  165. }