marionette_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. /*
  2. * Copyright (c) 2018, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package marionette
import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	_ "net/http/pprof"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
	"golang.org/x/sync/errgroup"
)
  33. func TestMarionette(t *testing.T) {
  34. go func() {
  35. fmt.Println(http.ListenAndServe("localhost:6060", nil))
  36. }()
  37. // Create a number of concurrent Marionette clients, each of which sends
  38. // data to the server. The server echoes back the data.
  39. clients := 5
  40. bytesToSend := 1 << 15
  41. serverReceivedBytes := int64(0)
  42. clientReceivedBytes := int64(0)
  43. serverAddress := "127.0.0.1"
  44. format := "http_simple_nonblocking"
  45. listener, err := Listen(serverAddress, format)
  46. if err != nil {
  47. t.Fatalf("Listen failed: %s", err)
  48. }
  49. testGroup, testCtx := errgroup.WithContext(context.Background())
  50. testGroup.Go(func() error {
  51. var serverGroup errgroup.Group
  52. for i := 0; i < clients; i++ {
  53. conn, err := listener.Accept()
  54. if err != nil {
  55. return errors.Trace(err)
  56. }
  57. serverGroup.Go(func() error {
  58. defer func() {
  59. fmt.Printf("Start server conn.Close\n")
  60. start := time.Now()
  61. conn.Close()
  62. fmt.Printf("Done server conn.Close: %s\n", time.Since(start))
  63. }()
  64. bytesFromClient := 0
  65. b := make([]byte, 1024)
  66. for bytesFromClient < bytesToSend {
  67. n, err := conn.Read(b)
  68. bytesFromClient += n
  69. atomic.AddInt64(&serverReceivedBytes, int64(n))
  70. if err != nil {
  71. fmt.Printf("Server read error: %s\n", err)
  72. return errors.Trace(err)
  73. }
  74. _, err = conn.Write(b[:n])
  75. if err != nil {
  76. fmt.Printf("Server write error: %s\n", err)
  77. return errors.Trace(err)
  78. }
  79. }
  80. return nil
  81. })
  82. }
  83. err := serverGroup.Wait()
  84. if err != nil {
  85. return errors.Trace(err)
  86. }
  87. return nil
  88. })
  89. for i := 0; i < clients; i++ {
  90. testGroup.Go(func() error {
  91. ctx, cancelFunc := context.WithTimeout(
  92. context.Background(), 1*time.Second)
  93. defer cancelFunc()
  94. conn, err := Dial(ctx, &net.Dialer{}, format, serverAddress)
  95. if err != nil {
  96. return errors.Trace(err)
  97. }
  98. var clientGroup errgroup.Group
  99. clientGroup.Go(func() error {
  100. defer func() {
  101. fmt.Printf("Start client conn.Close\n")
  102. start := time.Now()
  103. conn.Close()
  104. fmt.Printf("Done client conn.Close: %s\n", time.Since(start))
  105. }()
  106. b := make([]byte, 1024)
  107. bytesRead := 0
  108. for bytesRead < bytesToSend {
  109. n, err := conn.Read(b)
  110. bytesRead += n
  111. atomic.AddInt64(&clientReceivedBytes, int64(n))
  112. if err == io.EOF {
  113. break
  114. } else if err != nil {
  115. fmt.Printf("Client read error: %s\n", err)
  116. return errors.Trace(err)
  117. }
  118. }
  119. return nil
  120. })
  121. clientGroup.Go(func() error {
  122. b := make([]byte, bytesToSend)
  123. _, err := conn.Write(b)
  124. if err != nil {
  125. fmt.Printf("Client write error: %s\n", err)
  126. return errors.Trace(err)
  127. }
  128. return nil
  129. })
  130. return clientGroup.Wait()
  131. })
  132. }
  133. go func() {
  134. testGroup.Wait()
  135. }()
  136. <-testCtx.Done()
  137. fmt.Printf("Start listener.Close\n")
  138. start := time.Now()
  139. listener.Close()
  140. fmt.Printf("Done listener.Close: %s\n", time.Since(start))
  141. err = testGroup.Wait()
  142. if err != nil {
  143. t.Errorf("goroutine failed: %s", err)
  144. }
  145. bytes := atomic.LoadInt64(&serverReceivedBytes)
  146. expectedBytes := int64(clients * bytesToSend)
  147. if bytes != expectedBytes {
  148. t.Errorf("unexpected serverReceivedBytes: %d vs. %d", bytes, expectedBytes)
  149. }
  150. bytes = atomic.LoadInt64(&clientReceivedBytes)
  151. if bytes != expectedBytes {
  152. t.Errorf("unexpected clientReceivedBytes: %d vs. %d", bytes, expectedBytes)
  153. }
  154. }