inproxy_test.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testNewTacticsPayload := []byte(prng.HexString(100))
  77. testNewTacticsTag := "new-tactics-tag"
  78. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  79. // TODO: test port mapping
  80. stunServerAddressSucceededCount := int32(0)
  81. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  82. stunServerAddressFailedCount := int32(0)
  83. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  84. roundTripperSucceededCount := int32(0)
  85. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  86. roundTripperFailedCount := int32(0)
  87. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  88. noMatch := func(RoundTripper) {}
  89. var receivedProxyMustUpgrade chan struct{}
  90. var receivedClientMustUpgrade chan struct{}
  91. if doMustUpgrade {
  92. receivedProxyMustUpgrade = make(chan struct{})
  93. receivedClientMustUpgrade = make(chan struct{})
  94. // trigger MustUpgrade
  95. proxyProtocolVersion = 0
  96. // Minimize test parameters for MustUpgrade case
  97. numProxies = 1
  98. proxyMaxClients = 1
  99. numClients = 1
  100. testDisableSTUN = true
  101. }
  102. testCtx, stopTest := context.WithCancel(context.Background())
  103. defer stopTest()
  104. testGroup := new(errgroup.Group)
  105. // Enable test to run without requiring host firewall exceptions
  106. SetAllowBogonWebRTCConnections(true)
  107. defer SetAllowBogonWebRTCConnections(false)
  108. // Init logging and profiling
  109. logger := newTestLogger()
  110. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  111. go http.Serve(pprofListener, nil)
  112. defer pprofListener.Close()
  113. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  114. // Start echo servers
  115. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  116. if err != nil {
  117. return errors.Trace(err)
  118. }
  119. defer tcpEchoListener.Close()
  120. go runTCPEchoServer(tcpEchoListener)
  121. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  122. quicEchoServer, err := newQuicEchoServer()
  123. if err != nil {
  124. return errors.Trace(err)
  125. }
  126. defer quicEchoServer.Close()
  127. go quicEchoServer.Run()
  128. // Create signed server entry with capability
  129. serverPrivateKey, err := GenerateSessionPrivateKey()
  130. if err != nil {
  131. return errors.Trace(err)
  132. }
  133. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  134. if err != nil {
  135. return errors.Trace(err)
  136. }
  137. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  138. if err != nil {
  139. return errors.Trace(err)
  140. }
  141. serverEntry := make(protocol.ServerEntryFields)
  142. serverEntry["ipAddress"] = "127.0.0.1"
  143. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  144. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  145. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  146. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  147. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  148. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  149. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  150. testServerEntryTag := prng.HexString(16)
  151. serverEntry["tag"] = testServerEntryTag
  152. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  153. protocol.NewServerEntrySignatureKeyPair()
  154. if err != nil {
  155. return errors.Trace(err)
  156. }
  157. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  158. if err != nil {
  159. return errors.Trace(err)
  160. }
  161. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  162. if err != nil {
  163. return errors.Trace(err)
  164. }
  165. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  166. if err != nil {
  167. return errors.Trace(err)
  168. }
  169. // Start broker
  170. logger.WithTrace().Info("START BROKER")
  171. brokerPrivateKey, err := GenerateSessionPrivateKey()
  172. if err != nil {
  173. return errors.Trace(err)
  174. }
  175. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  176. if err != nil {
  177. return errors.Trace(err)
  178. }
  179. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  180. if err != nil {
  181. return errors.Trace(err)
  182. }
  183. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  184. if err != nil {
  185. return errors.Trace(err)
  186. }
  187. defer brokerListener.Close()
  188. brokerConfig := &BrokerConfig{
  189. Logger: logger,
  190. CommonCompartmentIDs: testCommonCompartmentIDs,
  191. APIParameterValidator: func(params common.APIParameters) error {
  192. if len(params) != len(baseAPIParameters) {
  193. return errors.TraceNew("unexpected base API parameter count")
  194. }
  195. for name, value := range params {
  196. if value.(string) != baseAPIParameters[name].(string) {
  197. return errors.Tracef(
  198. "unexpected base API parameter: %v: %v != %v",
  199. name,
  200. value.(string),
  201. baseAPIParameters[name].(string))
  202. }
  203. }
  204. return nil
  205. },
  206. APIParameterLogFieldFormatter: func(
  207. geoIPData common.GeoIPData, params common.APIParameters) common.LogFields {
  208. return common.LogFields(params)
  209. },
  210. GetTactics: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  211. // Exercise both new and unchanged tactics
  212. if prng.FlipCoin() {
  213. return testNewTacticsPayload, testNewTacticsTag, nil
  214. }
  215. return testUnchangedTacticsPayload, "", nil
  216. },
  217. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  218. PrivateKey: brokerPrivateKey,
  219. ObfuscationRootSecret: brokerRootObfuscationSecret,
  220. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  221. AllowProxy: func(common.GeoIPData) bool { return true },
  222. AllowClient: func(common.GeoIPData) bool { return true },
  223. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  224. }
  225. broker, err := NewBroker(brokerConfig)
  226. if err != nil {
  227. return errors.Trace(err)
  228. }
  229. err = broker.Start()
  230. if err != nil {
  231. return errors.Trace(err)
  232. }
  233. defer broker.Stop()
  234. testGroup.Go(func() error {
  235. err := runHTTPServer(brokerListener, broker)
  236. if testCtx.Err() != nil {
  237. return nil
  238. }
  239. return errors.Trace(err)
  240. })
  241. // Stub server broker request handler (in Psiphon, this will be the
  242. // destination Psiphon server; here, it's not necessary to build this
  243. // handler into the destination echo server)
  244. serverSessions, err := NewServerBrokerSessions(
  245. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey})
  246. if err != nil {
  247. return errors.Trace(err)
  248. }
  249. var pendingBrokerServerReportsMutex sync.Mutex
  250. pendingBrokerServerReports := make(map[ID]bool)
  251. addPendingBrokerServerReport := func(connectionID ID) {
  252. pendingBrokerServerReportsMutex.Lock()
  253. defer pendingBrokerServerReportsMutex.Unlock()
  254. pendingBrokerServerReports[connectionID] = true
  255. }
  256. hasPendingBrokerServerReports := func() bool {
  257. pendingBrokerServerReportsMutex.Lock()
  258. defer pendingBrokerServerReportsMutex.Unlock()
  259. return len(pendingBrokerServerReports) > 0
  260. }
  261. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  262. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  263. pendingBrokerServerReportsMutex.Lock()
  264. defer pendingBrokerServerReportsMutex.Unlock()
  265. // Mark the report as no longer outstanding
  266. delete(pendingBrokerServerReports, clientConnectionID)
  267. }
  268. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  269. return out, errors.Trace(err)
  270. }
  271. // Check that the tactics round trip succeeds
  272. var pendingProxyTacticsCallbacksMutex sync.Mutex
  273. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  274. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  275. pendingProxyTacticsCallbacksMutex.Lock()
  276. defer pendingProxyTacticsCallbacksMutex.Unlock()
  277. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  278. }
  279. hasPendingProxyTacticsCallbacks := func() bool {
  280. pendingProxyTacticsCallbacksMutex.Lock()
  281. defer pendingProxyTacticsCallbacksMutex.Unlock()
  282. return len(pendingProxyTacticsCallbacks) > 0
  283. }
  284. makeHandleTacticsPayload := func(
  285. proxyPrivateKey SessionPrivateKey,
  286. tacticsNetworkID string) func(_ string, _ []byte) bool {
  287. return func(networkID string, tacticsPayload []byte) bool {
  288. pendingProxyTacticsCallbacksMutex.Lock()
  289. defer pendingProxyTacticsCallbacksMutex.Unlock()
  290. // Check that the correct networkID is passed around; if not,
  291. // skip the delete, which will fail the test
  292. if networkID == tacticsNetworkID {
  293. // Certain state is reset when new tactics are applied -- the
  294. // return true case; exercise both cases
  295. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  296. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  297. return true
  298. }
  299. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  300. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  301. return false
  302. }
  303. }
  304. panic("unexpected tactics payload")
  305. }
  306. }
  307. // Start proxies
  308. logger.WithTrace().Info("START PROXIES")
  309. for i := 0; i < numProxies; i++ {
  310. proxyPrivateKey, err := GenerateSessionPrivateKey()
  311. if err != nil {
  312. return errors.Trace(err)
  313. }
  314. brokerCoordinator := &testBrokerDialCoordinator{
  315. networkID: testNetworkID,
  316. networkType: testNetworkType,
  317. brokerClientPrivateKey: proxyPrivateKey,
  318. brokerPublicKey: brokerPublicKey,
  319. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  320. brokerClientRoundTripper: newHTTPRoundTripper(
  321. brokerListener.Addr().String(), "proxy"),
  322. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  323. brokerClientRoundTripperFailed: roundTripperFailed,
  324. }
  325. webRTCCoordinator := &testWebRTCDialCoordinator{
  326. networkID: testNetworkID,
  327. networkType: testNetworkType,
  328. natType: testNATType,
  329. disableSTUN: testDisableSTUN,
  330. stunServerAddress: testSTUNServerAddress,
  331. stunServerAddressRFC5780: testSTUNServerAddress,
  332. stunServerAddressSucceeded: stunServerAddressSucceeded,
  333. stunServerAddressFailed: stunServerAddressFailed,
  334. setNATType: func(NATType) {},
  335. setPortMappingTypes: func(PortMappingTypes) {},
  336. bindToDevice: func(int) error { return nil },
  337. }
  338. // Each proxy has its own broker client
  339. brokerClient, err := NewBrokerClient(brokerCoordinator)
  340. if err != nil {
  341. return errors.Trace(err)
  342. }
  343. tacticsNetworkID := prng.HexString(32)
  344. runCtx, cancelRun := context.WithCancel(testCtx)
  345. // No deferred cancelRun due to testGroup.Go below
  346. proxy, err := NewProxy(&ProxyConfig{
  347. Logger: logger,
  348. WaitForNetworkConnectivity: func() bool {
  349. return true
  350. },
  351. GetBrokerClient: func() (*BrokerClient, error) {
  352. return brokerClient, nil
  353. },
  354. GetBaseAPIParameters: func() (common.APIParameters, string, error) {
  355. return baseAPIParameters, tacticsNetworkID, nil
  356. },
  357. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  358. return webRTCCoordinator, nil
  359. },
  360. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  361. MaxClients: proxyMaxClients,
  362. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  363. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  364. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  365. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  366. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  367. time.Now().UTC().Format(time.RFC3339),
  368. connectingClients, connectedClients, bytesUp, bytesDown)
  369. },
  370. MustUpgrade: func() {
  371. close(receivedProxyMustUpgrade)
  372. cancelRun()
  373. },
  374. })
  375. if err != nil {
  376. return errors.Trace(err)
  377. }
  378. addPendingProxyTacticsCallback(proxyPrivateKey)
  379. testGroup.Go(func() error {
  380. proxy.Run(runCtx)
  381. return nil
  382. })
  383. }
  384. // Await proxy announcements before starting clients
  385. //
  386. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  387. // plus NAT discovery
  388. //
  389. // - Don't wait for > numProxies announcements due to
  390. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  391. if !doMustUpgrade {
  392. for {
  393. time.Sleep(100 * time.Millisecond)
  394. broker.matcher.announcementQueueMutex.Lock()
  395. n := broker.matcher.announcementQueue.getLen()
  396. broker.matcher.announcementQueueMutex.Unlock()
  397. if n >= numProxies {
  398. break
  399. }
  400. }
  401. }
  402. // Start clients
  403. logger.WithTrace().Info("START CLIENTS")
  404. clientsGroup := new(errgroup.Group)
  405. makeClientFunc := func(
  406. isTCP bool,
  407. isMobile bool,
  408. brokerClient *BrokerClient,
  409. webRTCCoordinator WebRTCDialCoordinator) func() error {
  410. var networkProtocol NetworkProtocol
  411. var addr string
  412. var wrapWithQUIC bool
  413. if isTCP {
  414. networkProtocol = NetworkProtocolTCP
  415. addr = tcpEchoListener.Addr().String()
  416. } else {
  417. networkProtocol = NetworkProtocolUDP
  418. addr = quicEchoServer.Addr().String()
  419. wrapWithQUIC = true
  420. }
  421. return func() error {
  422. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  423. defer cancelDial()
  424. conn, err := DialClient(
  425. dialCtx,
  426. &ClientConfig{
  427. Logger: logger,
  428. BaseAPIParameters: baseAPIParameters,
  429. BrokerClient: brokerClient,
  430. WebRTCDialCoordinator: webRTCCoordinator,
  431. ReliableTransport: isTCP,
  432. DialNetworkProtocol: networkProtocol,
  433. DialAddress: addr,
  434. PackedDestinationServerEntry: packedDestinationServerEntry,
  435. MustUpgrade: func() {
  436. fmt.Printf("HI!\n")
  437. close(receivedClientMustUpgrade)
  438. cancelDial()
  439. },
  440. })
  441. if err != nil {
  442. return errors.Trace(err)
  443. }
  444. var relayConn net.Conn
  445. relayConn = conn
  446. if wrapWithQUIC {
  447. quicConn, err := quic.Dial(
  448. dialCtx,
  449. conn,
  450. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  451. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true,
  452. false, false, common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  453. )
  454. if err != nil {
  455. return errors.Trace(err)
  456. }
  457. relayConn = quicConn
  458. }
  459. addPendingBrokerServerReport(conn.GetConnectionID())
  460. signalRelayComplete := make(chan struct{})
  461. clientsGroup.Go(func() error {
  462. defer close(signalRelayComplete)
  463. in := conn.InitialRelayPacket()
  464. for in != nil {
  465. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  466. if err != nil {
  467. if out == nil {
  468. return errors.Trace(err)
  469. } else {
  470. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  471. // Proceed with reset session token packet
  472. }
  473. }
  474. if out == nil {
  475. // Relay is complete
  476. break
  477. }
  478. in, err = conn.RelayPacket(testCtx, out)
  479. if err != nil {
  480. return errors.Trace(err)
  481. }
  482. }
  483. return nil
  484. })
  485. sendBytes := prng.Bytes(bytesToSend)
  486. clientsGroup.Go(func() error {
  487. for n := 0; n < bytesToSend; {
  488. m := prng.Range(1024, 32768)
  489. if bytesToSend-n < m {
  490. m = bytesToSend - n
  491. }
  492. _, err := relayConn.Write(sendBytes[n : n+m])
  493. if err != nil {
  494. return errors.Trace(err)
  495. }
  496. n += m
  497. }
  498. fmt.Printf("%d bytes sent\n", bytesToSend)
  499. return nil
  500. })
  501. clientsGroup.Go(func() error {
  502. buf := make([]byte, 32768)
  503. n := 0
  504. for n < bytesToSend {
  505. m, err := relayConn.Read(buf)
  506. if err != nil {
  507. return errors.Trace(err)
  508. }
  509. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  510. return errors.Tracef(
  511. "unexpected bytes: expected at index %d, received at index %d",
  512. bytes.Index(sendBytes, buf[:m]), n)
  513. }
  514. n += m
  515. }
  516. fmt.Printf("%d bytes received\n", bytesToSend)
  517. select {
  518. case <-signalRelayComplete:
  519. case <-testCtx.Done():
  520. }
  521. relayConn.Close()
  522. conn.Close()
  523. return nil
  524. })
  525. return nil
  526. }
  527. }
  528. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  529. clientPrivateKey, err := GenerateSessionPrivateKey()
  530. if err != nil {
  531. return nil, nil, errors.Trace(err)
  532. }
  533. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  534. if err != nil {
  535. return nil, nil, errors.Trace(err)
  536. }
  537. brokerCoordinator := &testBrokerDialCoordinator{
  538. networkID: testNetworkID,
  539. networkType: testNetworkType,
  540. commonCompartmentIDs: testCommonCompartmentIDs,
  541. brokerClientPrivateKey: clientPrivateKey,
  542. brokerPublicKey: brokerPublicKey,
  543. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  544. brokerClientRoundTripper: newHTTPRoundTripper(
  545. brokerListener.Addr().String(), "client"),
  546. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  547. brokerClientRoundTripperFailed: roundTripperFailed,
  548. brokerClientNoMatch: noMatch,
  549. }
  550. webRTCCoordinator := &testWebRTCDialCoordinator{
  551. networkID: testNetworkID,
  552. networkType: testNetworkType,
  553. natType: testNATType,
  554. disableSTUN: testDisableSTUN,
  555. stunServerAddress: testSTUNServerAddress,
  556. stunServerAddressRFC5780: testSTUNServerAddress,
  557. stunServerAddressSucceeded: stunServerAddressSucceeded,
  558. stunServerAddressFailed: stunServerAddressFailed,
  559. clientRootObfuscationSecret: clientRootObfuscationSecret,
  560. doDTLSRandomization: prng.FlipCoin(),
  561. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  562. MinPaddedMessages: 0,
  563. MaxPaddedMessages: 10,
  564. MinPaddingSize: 0,
  565. MaxPaddingSize: 1500,
  566. MinDecoyMessages: 0,
  567. MaxDecoyMessages: 10,
  568. MinDecoySize: 1,
  569. MaxDecoySize: 1500,
  570. DecoyMessageProbability: 0.5,
  571. },
  572. setNATType: func(NATType) {},
  573. setPortMappingTypes: func(PortMappingTypes) {},
  574. bindToDevice: func(int) error { return nil },
  575. // With STUN enabled (testDisableSTUN = false), there are cases
  576. // where the WebRTC Data Channel is not successfully established.
  577. // With a short enough timeout here, clients will redial and
  578. // eventually succceed.
  579. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  580. }
  581. if isMobile {
  582. webRTCCoordinator.networkType = NetworkTypeMobile
  583. webRTCCoordinator.disableInboundForMobileNetworks = true
  584. }
  585. brokerClient, err := NewBrokerClient(brokerCoordinator)
  586. if err != nil {
  587. return nil, nil, errors.Trace(err)
  588. }
  589. return brokerClient, webRTCCoordinator, nil
  590. }
  591. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  592. if err != nil {
  593. return errors.Trace(err)
  594. }
  595. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  596. if err != nil {
  597. return errors.Trace(err)
  598. }
  599. for i := 0; i < numClients; i++ {
  600. // Test a mix of TCP and UDP proxying; also test the
  601. // DisableInboundForMobileNetworks code path.
  602. isTCP := i%2 == 0
  603. isMobile := i%4 == 0
  604. // Exercise BrokerClients shared by multiple clients, but also create
  605. // several broker clients.
  606. if i%8 == 0 {
  607. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  608. if err != nil {
  609. return errors.Trace(err)
  610. }
  611. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  612. if err != nil {
  613. return errors.Trace(err)
  614. }
  615. }
  616. brokerClient := clientBrokerClient
  617. webRTCCoordinator := clientWebRTCCoordinator
  618. if isMobile {
  619. brokerClient = clientMobileBrokerClient
  620. webRTCCoordinator = clientMobileWebRTCCoordinator
  621. }
  622. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  623. }
  624. if doMustUpgrade {
  625. // Await MustUpgrade callbacks
  626. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  627. <-receivedProxyMustUpgrade
  628. <-receivedClientMustUpgrade
  629. _ = clientsGroup.Wait()
  630. } else {
  631. // Await client transfers complete
  632. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  633. err = clientsGroup.Wait()
  634. if err != nil {
  635. return errors.Trace(err)
  636. }
  637. logger.WithTrace().Info("DONE DATA TRANSFER")
  638. if hasPendingBrokerServerReports() {
  639. return errors.TraceNew("unexpected pending broker server requests")
  640. }
  641. if hasPendingProxyTacticsCallbacks() {
  642. return errors.TraceNew("unexpected pending proxy tactics callback")
  643. }
  644. // TODO: check that elapsed time is consistent with rate limit (+/-)
  645. // Check if STUN server replay callbacks were triggered
  646. if !testDisableSTUN {
  647. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  648. return errors.TraceNew("unexpected STUN server succeeded count")
  649. }
  650. }
  651. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  652. return errors.TraceNew("unexpected STUN server failed count")
  653. }
  654. // Check if RoundTripper server replay callbacks were triggered
  655. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  656. return errors.TraceNew("unexpected round tripper succeeded count")
  657. }
  658. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  659. return errors.TraceNew("unexpected round tripper failed count")
  660. }
  661. }
  662. // Await shutdowns
  663. stopTest()
  664. brokerListener.Close()
  665. err = testGroup.Wait()
  666. if err != nil {
  667. return errors.Trace(err)
  668. }
  669. return nil
  670. }
  671. func runHTTPServer(listener net.Listener, broker *Broker) error {
  672. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  673. // For this test, clients set the path to "/client" and proxies
  674. // set the path to "/proxy" and we use that to create stub GeoIP
  675. // data to pass the not-same-ASN condition.
  676. var geoIPData common.GeoIPData
  677. geoIPData.ASN = r.URL.Path
  678. requestPayload, err := ioutil.ReadAll(
  679. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  680. if err != nil {
  681. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  682. http.Error(w, "", http.StatusNotFound)
  683. return
  684. }
  685. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  686. extendTimeout := func(timeout time.Duration) {
  687. // TODO: set insufficient initial timeout, so extension is
  688. // required for success
  689. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  690. }
  691. responsePayload, err := broker.HandleSessionPacket(
  692. r.Context(),
  693. extendTimeout,
  694. nil,
  695. clientIP,
  696. geoIPData,
  697. requestPayload)
  698. if err != nil {
  699. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  700. http.Error(w, "", http.StatusNotFound)
  701. return
  702. }
  703. w.WriteHeader(http.StatusOK)
  704. w.Write(responsePayload)
  705. })
  706. // WriteTimeout will be extended via extendTimeout.
  707. httpServer := &http.Server{
  708. ReadTimeout: 10 * time.Second,
  709. WriteTimeout: 10 * time.Second,
  710. IdleTimeout: 1 * time.Minute,
  711. Handler: handler,
  712. }
  713. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  714. if err != nil {
  715. return errors.Trace(err)
  716. }
  717. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  718. if err != nil {
  719. return errors.Trace(err)
  720. }
  721. tlsConfig := &tls.Config{
  722. Certificates: []tls.Certificate{tlsCert},
  723. }
  724. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  725. return errors.Trace(err)
  726. }
  727. type httpRoundTripper struct {
  728. httpClient *http.Client
  729. endpointAddr string
  730. path string
  731. }
  732. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  733. return &httpRoundTripper{
  734. httpClient: &http.Client{
  735. Transport: &http.Transport{
  736. ForceAttemptHTTP2: true,
  737. MaxIdleConns: 2,
  738. IdleConnTimeout: 1 * time.Minute,
  739. TLSHandshakeTimeout: 10 * time.Second,
  740. TLSClientConfig: &std_tls.Config{
  741. InsecureSkipVerify: true,
  742. },
  743. },
  744. },
  745. endpointAddr: endpointAddr,
  746. path: path,
  747. }
  748. }
  749. func (r *httpRoundTripper) RoundTrip(
  750. ctx context.Context,
  751. roundTripDelay time.Duration,
  752. roundTripTimeout time.Duration,
  753. requestPayload []byte) ([]byte, error) {
  754. if roundTripDelay > 0 {
  755. common.SleepWithContext(ctx, roundTripDelay)
  756. }
  757. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  758. defer requestCancelFunc()
  759. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  760. request, err := http.NewRequestWithContext(
  761. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  762. if err != nil {
  763. return nil, errors.Trace(err)
  764. }
  765. response, err := r.httpClient.Do(request)
  766. if err != nil {
  767. return nil, errors.Trace(err)
  768. }
  769. defer response.Body.Close()
  770. if response.StatusCode != http.StatusOK {
  771. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  772. }
  773. responsePayload, err := io.ReadAll(response.Body)
  774. if err != nil {
  775. return nil, errors.Trace(err)
  776. }
  777. return responsePayload, nil
  778. }
  779. func (r *httpRoundTripper) Close() error {
  780. r.httpClient.CloseIdleConnections()
  781. return nil
  782. }
  783. func runTCPEchoServer(listener net.Listener) {
  784. for {
  785. conn, err := listener.Accept()
  786. if err != nil {
  787. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  788. return
  789. }
  790. go func(conn net.Conn) {
  791. buf := make([]byte, 32768)
  792. for {
  793. n, err := conn.Read(buf)
  794. if n > 0 {
  795. _, err = conn.Write(buf[:n])
  796. }
  797. if err != nil {
  798. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  799. return
  800. }
  801. }
  802. }(conn)
  803. }
  804. }
  805. type quicEchoServer struct {
  806. listener net.Listener
  807. obfuscationKey string
  808. }
  809. func newQuicEchoServer() (*quicEchoServer, error) {
  810. obfuscationKey := prng.HexString(32)
  811. listener, err := quic.Listen(
  812. nil,
  813. nil,
  814. "127.0.0.1:0",
  815. obfuscationKey,
  816. false)
  817. if err != nil {
  818. return nil, errors.Trace(err)
  819. }
  820. return &quicEchoServer{
  821. listener: listener,
  822. obfuscationKey: obfuscationKey,
  823. }, nil
  824. }
  825. func (q *quicEchoServer) ObfuscationKey() string {
  826. return q.obfuscationKey
  827. }
  828. func (q *quicEchoServer) Close() error {
  829. return q.listener.Close()
  830. }
  831. func (q *quicEchoServer) Addr() net.Addr {
  832. return q.listener.Addr()
  833. }
  834. func (q *quicEchoServer) Run() {
  835. for {
  836. conn, err := q.listener.Accept()
  837. if err != nil {
  838. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  839. return
  840. }
  841. go func(conn net.Conn) {
  842. buf := make([]byte, 32768)
  843. for {
  844. n, err := conn.Read(buf)
  845. if n > 0 {
  846. _, err = conn.Write(buf[:n])
  847. }
  848. if err != nil {
  849. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  850. return
  851. }
  852. }
  853. }(conn)
  854. }
  855. }