packetman_linux_test.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. //go:build PSIPHON_RUN_PACKET_MANIPULATOR_TEST
  2. // +build PSIPHON_RUN_PACKET_MANIPULATOR_TEST
  3. /*
  4. * Copyright (c) 2020, Psiphon Inc.
  5. * All rights reserved.
  6. *
  7. * This program is free software: you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation, either version 3 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  19. *
  20. */
  21. package packetman
  22. import (
  23. "fmt"
  24. "io"
  25. "io/ioutil"
  26. "net"
  27. "net/http"
  28. "strconv"
  29. "testing"
  30. "time"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace"
  33. )
  34. func TestPacketManipulatorIPv4(t *testing.T) {
  35. testPacketManipulator(false, t)
  36. }
  37. func TestPacketManipulatorIPv6(t *testing.T) {
  38. testPacketManipulator(true, t)
  39. }
  40. func testPacketManipulator(useIPv6 bool, t *testing.T) {
  41. // Test: run a Manipulator in front of a web server; make an HTTP request;
  42. // the expected transformation spec should be executed (as reported by
  43. // GetAppliedSpecName) and the request must succeed.
  44. ipv4, ipv6, err := common.GetRoutableInterfaceIPAddresses()
  45. if err != nil {
  46. t.Fatalf("GetRoutableInterfaceIPAddressesfailed: %v", err)
  47. }
  48. network := "tcp4"
  49. address := net.JoinHostPort(ipv4.String(), "0")
  50. if useIPv6 {
  51. if ipv6 == nil {
  52. t.Skipf("test unsupported: no IP address")
  53. }
  54. network = "tcp6"
  55. address = net.JoinHostPort(ipv6.String(), "0")
  56. }
  57. listener, err := net.Listen(network, address)
  58. if err != nil {
  59. t.Fatalf("net.Listen failed: %v", err)
  60. }
  61. defer listener.Close()
  62. hostStr, portStr, err := net.SplitHostPort(listener.Addr().String())
  63. if err != nil {
  64. t.Fatalf("net.SplitHostPort failed: %s", err.Error())
  65. }
  66. listenerPort, _ := strconv.Atoi(portStr)
  67. // [["TCP-flags S"]] replaces the original SYN-ACK packet with a single
  68. // SYN packet, implementing TCP simultaneous open.
  69. testSpecName := "test-spec"
  70. extraDataValue := "extra-data"
  71. config := &Config{
  72. Logger: newTestLogger(),
  73. ProtocolPorts: []int{listenerPort},
  74. Specs: []*Spec{{Name: testSpecName, PacketSpecs: [][]string{{"TCP-flags S"}}}},
  75. SelectSpecName: func(protocolPort int, _ net.IP) (string, interface{}) {
  76. if protocolPort == listenerPort {
  77. return testSpecName, extraDataValue
  78. }
  79. return "", nil
  80. },
  81. QueueNumber: 1,
  82. }
  83. m, err := NewManipulator(config)
  84. if err != nil {
  85. t.Fatalf("NewManipulator failed: %v", err)
  86. }
  87. err = m.Start()
  88. if err != nil {
  89. t.Fatalf("Manipulator.Start failed: %v", err)
  90. }
  91. defer m.Stop()
  92. go func() {
  93. serveMux := http.NewServeMux()
  94. serveMux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
  95. io.WriteString(w, "test-response\n")
  96. })
  97. server := &http.Server{
  98. Handler: serveMux,
  99. ConnState: func(conn net.Conn, state http.ConnState) {
  100. if state == http.StateNew {
  101. localAddr := conn.LocalAddr().(*net.TCPAddr)
  102. remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
  103. specName, extraData, err := m.GetAppliedSpecName(localAddr, remoteAddr)
  104. if err != nil {
  105. t.Fatalf("GetAppliedSpecName failed: %v", err)
  106. }
  107. if specName != testSpecName {
  108. t.Fatalf("unexpected spec name: %s", specName)
  109. }
  110. extraDataStr, ok := extraData.(string)
  111. if !ok || extraDataStr != extraDataValue {
  112. t.Fatalf("unexpected extra data value: %v", extraData)
  113. }
  114. }
  115. },
  116. }
  117. server.Serve(listener)
  118. }()
  119. httpClient := &http.Client{
  120. Timeout: 30 * time.Second,
  121. }
  122. response, err := httpClient.Get(fmt.Sprintf("http://%s:%s", hostStr, portStr))
  123. if err != nil {
  124. t.Fatalf("http.Get failed: %v", err)
  125. }
  126. defer response.Body.Close()
  127. _, err = ioutil.ReadAll(response.Body)
  128. if err != nil {
  129. t.Fatalf("ioutil.ReadAll failed: %v", err)
  130. }
  131. if response.StatusCode != http.StatusOK {
  132. t.Fatalf("unexpected response code: %d", response.StatusCode)
  133. }
  134. }
  135. func newTestLogger() common.Logger {
  136. return &testLogger{}
  137. }
  138. type testLogger struct {
  139. }
  140. func (logger *testLogger) WithTrace() common.LogTrace {
  141. return &testLogTrace{
  142. trace: stacktrace.GetParentFunctionName(),
  143. }
  144. }
  145. func (logger *testLogger) WithTraceFields(fields common.LogFields) common.LogTrace {
  146. return &testLogTrace{
  147. trace: stacktrace.GetParentFunctionName(),
  148. fields: fields,
  149. }
  150. }
  151. func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
  152. }
  153. func (logger *testLogger) IsLogLevelDebug() bool {
  154. return true
  155. }
  156. type testLogTrace struct {
  157. trace string
  158. fields common.LogFields
  159. }
  160. func (log *testLogTrace) log(
  161. noticeType string, args ...interface{}) {
  162. fmt.Printf("[%s] %s: %+v: %s\n",
  163. noticeType,
  164. log.trace,
  165. log.fields,
  166. fmt.Sprint(args...))
  167. }
  168. func (log *testLogTrace) Debug(args ...interface{}) {
  169. log.log("DEBUG", args...)
  170. }
  171. func (log *testLogTrace) Info(args ...interface{}) {
  172. log.log("INFO", args...)
  173. }
  174. func (log *testLogTrace) Warning(args ...interface{}) {
  175. log.log("ALERT", args...)
  176. }
  177. func (log *testLogTrace) Error(args ...interface{}) {
  178. log.log("ERROR", args...)
  179. }