quic_test.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. // +build !DISABLE_QUIC
  2. /*
  3. * Copyright (c) 2018, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package quic
  21. import (
  22. "context"
  23. "io"
  24. "net"
  25. "runtime"
  26. "strings"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  32. "golang.org/x/sync/errgroup"
  33. )
  34. func TestQUIC(t *testing.T) {
  35. for negotiateQUICVersion := range supportedVersionNumbers {
  36. t.Run(negotiateQUICVersion, func(t *testing.T) {
  37. runQUIC(t, negotiateQUICVersion)
  38. })
  39. }
  40. }
  41. func runQUIC(t *testing.T, negotiateQUICVersion string) {
  42. initGoroutines := getGoroutines()
  43. clients := 10
  44. bytesToSend := 1 << 20
  45. serverReceivedBytes := int64(0)
  46. clientReceivedBytes := int64(0)
  47. // Intermittently, on some platforms, the client connection termination
  48. // packet is not received even when sent/received locally; set a brief
  49. // idle timeout to ensure the server-side client handler doesn't block too
  50. // long on Read, causing the test to fail.
  51. //
  52. // In realistic network conditions, and especially under adversarial
  53. // network conditions, we should not expect to regularly receive client
  54. // connection termination packets.
  55. serverIdleTimeout = 1 * time.Second
  56. obfuscationKey := prng.HexString(32)
  57. listener, err := Listen(nil, "127.0.0.1:0", obfuscationKey)
  58. if err != nil {
  59. t.Fatalf("Listen failed: %s", err)
  60. }
  61. serverAddress := listener.Addr().String()
  62. testGroup, testCtx := errgroup.WithContext(context.Background())
  63. testGroup.Go(func() error {
  64. var serverGroup errgroup.Group
  65. for i := 0; i < clients; i++ {
  66. conn, err := listener.Accept()
  67. if err != nil {
  68. return errors.Trace(err)
  69. }
  70. serverGroup.Go(func() error {
  71. b := make([]byte, 1024)
  72. for {
  73. n, err := conn.Read(b)
  74. atomic.AddInt64(&serverReceivedBytes, int64(n))
  75. if err == io.EOF {
  76. return nil
  77. } else if err != nil {
  78. return errors.Trace(err)
  79. }
  80. _, err = conn.Write(b[:n])
  81. if err != nil {
  82. return errors.Trace(err)
  83. }
  84. }
  85. })
  86. }
  87. err := serverGroup.Wait()
  88. if err != nil {
  89. return errors.Trace(err)
  90. }
  91. return nil
  92. })
  93. for i := 0; i < clients; i++ {
  94. testGroup.Go(func() error {
  95. ctx, cancelFunc := context.WithTimeout(
  96. context.Background(), 1*time.Second)
  97. defer cancelFunc()
  98. remoteAddr, err := net.ResolveUDPAddr("udp", serverAddress)
  99. if err != nil {
  100. return errors.Trace(err)
  101. }
  102. packetConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
  103. if err != nil {
  104. return errors.Trace(err)
  105. }
  106. obfuscationPaddingSeed, err := prng.NewSeed()
  107. if err != nil {
  108. return errors.Trace(err)
  109. }
  110. conn, err := Dial(
  111. ctx,
  112. packetConn,
  113. remoteAddr,
  114. serverAddress,
  115. negotiateQUICVersion,
  116. obfuscationKey,
  117. obfuscationPaddingSeed)
  118. if err != nil {
  119. return errors.Trace(err)
  120. }
  121. // Cancel should interrupt dialing only
  122. cancelFunc()
  123. var clientGroup errgroup.Group
  124. clientGroup.Go(func() error {
  125. defer conn.Close()
  126. b := make([]byte, 1024)
  127. bytesRead := 0
  128. for bytesRead < bytesToSend {
  129. n, err := conn.Read(b)
  130. bytesRead += n
  131. atomic.AddInt64(&clientReceivedBytes, int64(n))
  132. if err == io.EOF {
  133. break
  134. } else if err != nil {
  135. return errors.Trace(err)
  136. }
  137. }
  138. return nil
  139. })
  140. clientGroup.Go(func() error {
  141. b := make([]byte, bytesToSend)
  142. _, err := conn.Write(b)
  143. if err != nil {
  144. return errors.Trace(err)
  145. }
  146. return nil
  147. })
  148. return clientGroup.Wait()
  149. })
  150. }
  151. go func() {
  152. testGroup.Wait()
  153. }()
  154. <-testCtx.Done()
  155. listener.Close()
  156. err = testGroup.Wait()
  157. if err != nil {
  158. t.Errorf("goroutine failed: %s", err)
  159. }
  160. bytes := atomic.LoadInt64(&serverReceivedBytes)
  161. expectedBytes := int64(clients * bytesToSend)
  162. if bytes != expectedBytes {
  163. t.Errorf("unexpected serverReceivedBytes: %d vs. %d", bytes, expectedBytes)
  164. }
  165. bytes = atomic.LoadInt64(&clientReceivedBytes)
  166. if bytes != expectedBytes {
  167. t.Errorf("unexpected clientReceivedBytes: %d vs. %d", bytes, expectedBytes)
  168. }
  169. _, err = listener.Accept()
  170. if err == nil {
  171. t.Error("unexpected Accept after Close")
  172. }
  173. // Check for unexpected dangling goroutines after shutdown.
  174. //
  175. // quic-go.packetHandlerMap.listen shutdown is async and some quic-go
  176. // goroutines and/or timers dangle so this test makes allowances for these
  177. // known dangling goroutinees.
  178. expectedDanglingGoroutines := []string{
  179. "quic-go.(*packetHandlerMap).Retire.func1",
  180. "quic-go.(*packetHandlerMap).ReplaceWithClosed.func1",
  181. "quic-go.(*packetHandlerMap).RetireResetToken.func1",
  182. "gquic-go.(*packetHandlerMap).removeByConnectionIDAsString.func1",
  183. }
  184. sleepTime := 100 * time.Millisecond
  185. // The longest expected dangling goroutine is in gquic-go and is launched by a timer
  186. // that fires after ClosedSessionDeleteTimeout, which is 1m. Allow one extra second
  187. // to ensure this period elapses and the time.AfterFunc runs.
  188. //
  189. // To avoid taking 1m to run this test every time, the dangling goroutine check exits
  190. // early once no dangling goroutines are found. Note that this doesn't account for
  191. // any timers still pending at the early exit time.
  192. n := int((61 * time.Second) / sleepTime)
  193. for i := 0; i < n; i++ {
  194. // Sleep before making any checks, since quic-go.packetHandlerMap.listen
  195. // shutdown is asynchronous.
  196. time.Sleep(100 * time.Millisecond)
  197. // After the full 61s, no dangling goroutines are expected.
  198. if i == n-1 {
  199. expectedDanglingGoroutines = []string{}
  200. }
  201. hasDangling, onlyExpectedDangling := checkDanglingGoroutines(
  202. t, initGoroutines, expectedDanglingGoroutines)
  203. if !hasDangling {
  204. break
  205. } else if !onlyExpectedDangling {
  206. t.Fatalf("unexpected dangling goroutines")
  207. }
  208. }
  209. }
  210. func getGoroutines() []runtime.StackRecord {
  211. n, _ := runtime.GoroutineProfile(nil)
  212. r := make([]runtime.StackRecord, n)
  213. runtime.GoroutineProfile(r)
  214. return r
  215. }
  216. func checkDanglingGoroutines(
  217. t *testing.T,
  218. initGoroutines []runtime.StackRecord,
  219. expectedDanglingGoroutines []string) (bool, bool) {
  220. hasDangling := false
  221. onlyExpectedDangling := true
  222. current := getGoroutines()
  223. for _, g := range current {
  224. found := false
  225. for _, h := range initGoroutines {
  226. if g == h {
  227. found = true
  228. break
  229. }
  230. }
  231. if !found {
  232. stack := g.Stack()
  233. funcNames := make([]string, len(stack))
  234. skip := false
  235. isExpected := false
  236. for i := 0; i < len(stack); i++ {
  237. funcNames[i] = getFunctionName(stack[i])
  238. // The current goroutine won't have the same stack as in initGoroutines.
  239. if strings.Contains(funcNames[i], "checkDanglingGoroutines") {
  240. skip = true
  241. break
  242. }
  243. // testing.T.Run runs the the test function, f, in another goroutine. f is
  244. // the current goroutine, which captures initGoroutines.
  245. // https://github.com/golang/go/blob/release-branch.go1.13/src/testing/testing.go#L960-L961:
  246. //
  247. // go tRunner(t, f)
  248. // if !<-t.signal {
  249. // ...
  250. //
  251. // f may capture initGoroutines before or after testing.T.Run advances to
  252. // the channel receive, so the stack of the testing.T.Run goroutine may or
  253. // may not match initGoroutines. Skip it.
  254. if strings.Contains(funcNames[i], "testing.(*T).Run") {
  255. skip = true
  256. break
  257. }
  258. for _, expected := range expectedDanglingGoroutines {
  259. if strings.Contains(funcNames[i], expected) {
  260. isExpected = true
  261. break
  262. }
  263. }
  264. if isExpected {
  265. break
  266. }
  267. }
  268. if !skip {
  269. hasDangling = true
  270. if !isExpected {
  271. onlyExpectedDangling = false
  272. s := strings.Join(funcNames, " <- ")
  273. t.Logf("found unexpected dangling goroutine: %s", s)
  274. }
  275. }
  276. }
  277. }
  278. return hasDangling, onlyExpectedDangling
  279. }
  280. func getFunctionName(pc uintptr) string {
  281. funcName := runtime.FuncForPC(pc).Name()
  282. index := strings.LastIndex(funcName, "/")
  283. if index != -1 {
  284. funcName = funcName[index+1:]
  285. }
  286. return funcName
  287. }