demux_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. /*
  2. * Copyright (c) 2023, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "bytes"
  22. "context"
  23. std_errors "errors"
  24. "fmt"
  25. "math/rand"
  26. "net"
  27. "testing"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. )
  31. type protocolDemuxTest struct {
  32. name string
  33. classifiers []protocolClassifier
  34. classifierType []string
  35. // conns made on demand so the same test instance can be reused across
  36. // tests.
  37. conns []func() net.Conn
  38. // NOTE: duplicate expected key and value not supported. E.g.
  39. // {"1": {"A", "A"}} will result in a test failure, but
  40. // {"1": {"A"}, "2": {"A"}} will not.
  41. // Expected stream of bytes to read from each conn type. Test will halt
  42. // if any of the values are not observed.
  43. expected map[string][]string
  44. }
  45. func runProtocolDemuxTest(tt *protocolDemuxTest) error {
  46. conns := make(chan net.Conn)
  47. l := testListener{conns: conns}
  48. go func() {
  49. // send conns downstream in random order
  50. randOrd := rand.Perm(len(tt.conns))
  51. for i := range randOrd {
  52. conns <- tt.conns[i]()
  53. }
  54. }()
  55. mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers, 0)
  56. errs := make([]chan error, len(protoListeners))
  57. for i := range errs {
  58. errs[i] = make(chan error)
  59. }
  60. for i, protoListener := range protoListeners {
  61. ind := i
  62. l := protoListener
  63. go func() {
  64. defer close(errs[ind])
  65. protoListenerType := tt.classifierType[ind]
  66. expectedValues, ok := tt.expected[protoListenerType]
  67. if !ok {
  68. errs[ind] <- fmt.Errorf("conn type %s not found", protoListenerType)
  69. return
  70. }
  71. expectedValuesNotSeen := make(map[string]struct{})
  72. for _, v := range expectedValues {
  73. expectedValuesNotSeen[v] = struct{}{}
  74. }
  75. // Keep accepting conns until all conns of
  76. // protoListenerType are retrieved from the mux.
  77. for len(expectedValuesNotSeen) > 0 {
  78. conn, err := l.Accept()
  79. if err != nil {
  80. errs[ind] <- err
  81. return
  82. }
  83. connType := conn.(*bufferedConn).Conn.(*testConn).connType
  84. if connType != protoListenerType {
  85. errs[ind] <- fmt.Errorf("expected conn type %s but got %s for %s", protoListenerType, connType, conn.(*bufferedConn).buffer.String())
  86. return
  87. }
  88. var acc []byte
  89. b := make([]byte, 1) // TODO: randomize read buffer size
  90. for {
  91. n, err := conn.Read(b)
  92. if err != nil {
  93. errs[ind] <- err
  94. return
  95. }
  96. if n == 0 {
  97. break
  98. }
  99. acc = append(acc, b[:n]...)
  100. }
  101. if _, ok := expectedValuesNotSeen[string(acc)]; !ok {
  102. errs[ind] <- fmt.Errorf("unexpected value %s", string(acc))
  103. return
  104. }
  105. delete(expectedValuesNotSeen, string(acc))
  106. }
  107. }()
  108. }
  109. runErr := make(chan error)
  110. go func() {
  111. defer close(runErr)
  112. err := mux.run()
  113. if err != nil && !std_errors.Is(err, context.Canceled) {
  114. runErr <- err
  115. }
  116. }()
  117. for i := range errs {
  118. err := <-errs[i]
  119. if err != nil {
  120. return errors.Trace(err)
  121. }
  122. }
  123. err := mux.Close()
  124. if err != nil {
  125. return errors.Trace(err)
  126. }
  127. err = <-runErr
  128. if err != nil && !std_errors.Is(err, net.ErrClosed) {
  129. return errors.Trace(err)
  130. }
  131. return nil
  132. }
  133. func TestProtocolDemux(t *testing.T) {
  134. aClassifier := protocolClassifier{
  135. match: func(b []byte) bool {
  136. return bytes.HasPrefix(b, []byte("AAA"))
  137. },
  138. }
  139. bClassifier := protocolClassifier{
  140. match: func(b []byte) bool {
  141. return bytes.HasPrefix(b, []byte("BBBB"))
  142. },
  143. }
  144. // TODO: could add delay between each testConn returning bytes to simulate
  145. // network delay.
  146. tests := []protocolDemuxTest{
  147. {
  148. name: "single conn",
  149. classifiers: []protocolClassifier{
  150. aClassifier,
  151. },
  152. classifierType: []string{"A"},
  153. conns: []func() net.Conn{
  154. func() net.Conn {
  155. return &testConn{connType: "A", b: []byte("AAA")}
  156. },
  157. },
  158. expected: map[string][]string{
  159. "A": {"AAA"},
  160. },
  161. },
  162. {
  163. name: "multiple conns one of each type",
  164. classifiers: []protocolClassifier{
  165. aClassifier,
  166. bClassifier,
  167. },
  168. classifierType: []string{"A", "B"},
  169. conns: []func() net.Conn{
  170. func() net.Conn {
  171. return &testConn{connType: "A", b: []byte("AAAzzzzz")}
  172. },
  173. func() net.Conn {
  174. return &testConn{connType: "B", b: []byte("BBBBzzzzz")}
  175. },
  176. },
  177. expected: map[string][]string{
  178. "A": {"AAAzzzzz"},
  179. "B": {"BBBBzzzzz"},
  180. },
  181. },
  182. {
  183. name: "multiple conns multiple of each type",
  184. classifiers: []protocolClassifier{
  185. aClassifier,
  186. bClassifier,
  187. },
  188. classifierType: []string{"A", "B"},
  189. conns: []func() net.Conn{
  190. func() net.Conn {
  191. return &testConn{connType: "A", b: []byte("AAA1zzzzz")}
  192. },
  193. func() net.Conn {
  194. return &testConn{connType: "B", b: []byte("BBBB1zzzzz")}
  195. },
  196. func() net.Conn {
  197. return &testConn{connType: "A", b: []byte("AAA2zzzzz")}
  198. },
  199. func() net.Conn {
  200. return &testConn{connType: "B", b: []byte("BBBB2zzzzz")}
  201. },
  202. },
  203. expected: map[string][]string{
  204. "A": {"AAA1zzzzz", "AAA2zzzzz"},
  205. "B": {"BBBB1zzzzz", "BBBB2zzzzz"},
  206. },
  207. },
  208. }
  209. for _, tt := range tests {
  210. t.Run(tt.name, func(t *testing.T) {
  211. err := runProtocolDemuxTest(&tt)
  212. if err != nil {
  213. t.Fatalf("runProtocolDemuxTest failed: %v", err)
  214. }
  215. })
  216. }
  217. }
  218. func BenchmarkProtocolDemux(b *testing.B) {
  219. rand.Seed(time.Now().UnixNano())
  220. aClassifier := protocolClassifier{
  221. match: func(b []byte) bool {
  222. return bytes.HasPrefix(b, []byte("AAA"))
  223. },
  224. minBytesToMatch: 3,
  225. maxBytesToMatch: 3,
  226. }
  227. bClassifier := protocolClassifier{
  228. match: func(b []byte) bool {
  229. return bytes.HasPrefix(b, []byte("BBBB"))
  230. },
  231. minBytesToMatch: 4,
  232. maxBytesToMatch: 4,
  233. }
  234. cClassifier := protocolClassifier{
  235. match: func(b []byte) bool {
  236. return bytes.HasPrefix(b, []byte("C"))
  237. },
  238. minBytesToMatch: 1,
  239. maxBytesToMatch: 1,
  240. }
  241. connTypeToPrefix := map[string]string{
  242. "A": "AAA",
  243. "B": "BBBB",
  244. "C": "C",
  245. }
  246. var conns []func() net.Conn
  247. connsPerConnType := 100
  248. expected := make(map[string][]string)
  249. for connType, connTypePrefix := range connTypeToPrefix {
  250. for i := 0; i < connsPerConnType; i++ {
  251. s := fmt.Sprintf("%s%s%d", connTypePrefix, getRandAlphanumericString(9999), i) // include index to prevent collision even though improbable
  252. connTypeCopy := connType // avoid capturing loop variable
  253. conns = append(conns, func() net.Conn {
  254. conn := testConn{
  255. connType: connTypeCopy,
  256. b: []byte(s),
  257. }
  258. return &conn
  259. })
  260. expected[connType] = append(expected[connType], s)
  261. }
  262. }
  263. test := &protocolDemuxTest{
  264. name: "multiple conns multiple of each type",
  265. classifiers: []protocolClassifier{
  266. aClassifier,
  267. bClassifier,
  268. cClassifier,
  269. },
  270. classifierType: []string{"A", "B", "C"},
  271. conns: conns,
  272. expected: expected,
  273. }
  274. for n := 0; n < b.N; n++ {
  275. err := runProtocolDemuxTest(test)
  276. if err != nil {
  277. b.Fatalf("runProtocolDemuxTest failed: %v", err)
  278. }
  279. }
  280. }
  // getRandAlphanumericString returns an n-character string. Each character is
// drawn independently from [0-9a-zA-Z] using the math/rand global source.
func getRandAlphanumericString(n int) string {
	const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

	// The alphabet is pure ASCII, so one byte per character suffices.
	out := make([]byte, 0, n)
	for len(out) < n {
		out = append(out, alphabet[rand.Intn(len(alphabet))])
	}
	return string(out)
}
  // testListener is a net.Listener whose accepted conns are supplied through a
// channel. Close closes that channel; after that, Accept reports net.ErrClosed.
type testListener struct {
	conns chan net.Conn
}

// Accept blocks until a conn arrives on l.conns. A nil conn, which is also
// what a receive on the closed channel yields, is reported as net.ErrClosed.
func (l testListener) Accept() (net.Conn, error) {
	if conn := <-l.conns; conn != nil {
		return conn, nil
	}
	// no more conns
	return nil, net.ErrClosed
}

// Close closes the conns channel and always returns nil. Calling it twice
// panics, as closing a closed channel does.
func (l testListener) Close() error {
	close(l.conns)
	return nil
}

// Addr always returns nil; the listener has no network address.
func (l testListener) Addr() net.Addr {
	return nil
}
  // testConn is an in-memory net.Conn stub. It returns canned bytes, and
// optionally canned errors, from Read. Write always fails; Close, the address
// accessors and the deadline setters are no-ops.
type testConn struct {
	// connType is the type of the underlying connection.
	connType string
	// b is the bytes to return over Read() calls.
	b []byte
	// maxReadLen is the maximum number of bytes to return from b in a single
	// Read() call if > 0; otherwise no limit is imposed.
	maxReadLen int
	// readErrs are returned from Read() calls in order. If empty, then a nil
	// error is returned.
	readErrs []error
}

// Read first drains readErrs, returning one queued error (with n == 0) per
// call. After that it copies up to len(b) bytes from c.b, capped by a positive
// maxReadLen, and advances c.b past them.
//
// Once c.b is exhausted, Read returns (0, nil) rather than io.EOF. Callers that
// read until a zero-length read rely on this.
func (c *testConn) Read(b []byte) (n int, err error) {
	if len(c.readErrs) > 0 {
		readErr := c.readErrs[0]
		c.readErrs = c.readErrs[1:]
		return 0, readErr
	}

	numBytes := len(b)
	// Only a positive maxReadLen caps the read, as documented on the field.
	// Previously a negative value was applied and then panicked on c.b[:n].
	if c.maxReadLen > 0 && numBytes > c.maxReadLen {
		numBytes = c.maxReadLen
	}
	if numBytes > len(c.b) {
		numBytes = len(c.b)
	}

	n = copy(b, c.b[:numBytes])
	c.b = c.b[n:]
	return n, nil
}

// Write is unsupported and always fails.
func (c *testConn) Write(b []byte) (n int, err error) {
	return 0, std_errors.New("not supported")
}

// Close is a no-op.
func (c *testConn) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (c *testConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (c *testConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (c *testConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c *testConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c *testConn) SetWriteDeadline(t time.Time) error {
	return nil
}