inproxy_test.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testNewTacticsPayload := []byte(prng.HexString(100))
  77. testNewTacticsTag := "new-tactics-tag"
  78. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  79. // TODO: test port mapping
  80. stunServerAddressSucceededCount := int32(0)
  81. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  82. stunServerAddressFailedCount := int32(0)
  83. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  84. roundTripperSucceededCount := int32(0)
  85. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  86. roundTripperFailedCount := int32(0)
  87. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  88. noMatch := func(RoundTripper) {}
  89. var receivedProxyMustUpgrade chan struct{}
  90. var receivedClientMustUpgrade chan struct{}
  91. if doMustUpgrade {
  92. receivedProxyMustUpgrade = make(chan struct{})
  93. receivedClientMustUpgrade = make(chan struct{})
  94. // trigger MustUpgrade
  95. proxyProtocolVersion = 0
  96. // Minimize test parameters for MustUpgrade case
  97. numProxies = 1
  98. proxyMaxClients = 1
  99. numClients = 1
  100. testDisableSTUN = true
  101. }
  102. testCtx, stopTest := context.WithCancel(context.Background())
  103. defer stopTest()
  104. testGroup := new(errgroup.Group)
  105. // Enable test to run without requiring host firewall exceptions
  106. SetAllowBogonWebRTCConnections(true)
  107. defer SetAllowBogonWebRTCConnections(false)
  108. // Init logging and profiling
  109. logger := newTestLogger()
  110. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  111. go http.Serve(pprofListener, nil)
  112. defer pprofListener.Close()
  113. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  114. // Start echo servers
  115. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  116. if err != nil {
  117. return errors.Trace(err)
  118. }
  119. defer tcpEchoListener.Close()
  120. go runTCPEchoServer(tcpEchoListener)
  121. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  122. quicEchoServer, err := newQuicEchoServer()
  123. if err != nil {
  124. return errors.Trace(err)
  125. }
  126. defer quicEchoServer.Close()
  127. go quicEchoServer.Run()
  128. // Create signed server entry with capability
  129. serverPrivateKey, err := GenerateSessionPrivateKey()
  130. if err != nil {
  131. return errors.Trace(err)
  132. }
  133. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  134. if err != nil {
  135. return errors.Trace(err)
  136. }
  137. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  138. if err != nil {
  139. return errors.Trace(err)
  140. }
  141. serverEntry := make(protocol.ServerEntryFields)
  142. serverEntry["ipAddress"] = "127.0.0.1"
  143. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  144. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  145. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  146. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  147. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  148. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  149. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  150. testServerEntryTag := prng.HexString(16)
  151. serverEntry["tag"] = testServerEntryTag
  152. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  153. protocol.NewServerEntrySignatureKeyPair()
  154. if err != nil {
  155. return errors.Trace(err)
  156. }
  157. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  158. if err != nil {
  159. return errors.Trace(err)
  160. }
  161. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  162. if err != nil {
  163. return errors.Trace(err)
  164. }
  165. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  166. if err != nil {
  167. return errors.Trace(err)
  168. }
  169. // API parameter handlers
  170. apiParameterValidator := func(params common.APIParameters) error {
  171. if len(params) != len(baseAPIParameters) {
  172. return errors.TraceNew("unexpected base API parameter count")
  173. }
  174. for name, value := range params {
  175. if value.(string) != baseAPIParameters[name].(string) {
  176. return errors.Tracef(
  177. "unexpected base API parameter: %v: %v != %v",
  178. name,
  179. value.(string),
  180. baseAPIParameters[name].(string))
  181. }
  182. }
  183. return nil
  184. }
  185. apiParameterLogFieldFormatter := func(
  186. _ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
  187. return common.LogFields(params)
  188. }
  189. // Start broker
  190. logger.WithTrace().Info("START BROKER")
  191. brokerPrivateKey, err := GenerateSessionPrivateKey()
  192. if err != nil {
  193. return errors.Trace(err)
  194. }
  195. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  196. if err != nil {
  197. return errors.Trace(err)
  198. }
  199. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  200. if err != nil {
  201. return errors.Trace(err)
  202. }
  203. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  204. if err != nil {
  205. return errors.Trace(err)
  206. }
  207. defer brokerListener.Close()
  208. brokerConfig := &BrokerConfig{
  209. Logger: logger,
  210. CommonCompartmentIDs: testCommonCompartmentIDs,
  211. APIParameterValidator: apiParameterValidator,
  212. APIParameterLogFieldFormatter: apiParameterLogFieldFormatter,
  213. GetTacticsPayload: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  214. // Exercise both new and unchanged tactics
  215. if prng.FlipCoin() {
  216. return testNewTacticsPayload, testNewTacticsTag, nil
  217. }
  218. return testUnchangedTacticsPayload, "", nil
  219. },
  220. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  221. PrivateKey: brokerPrivateKey,
  222. ObfuscationRootSecret: brokerRootObfuscationSecret,
  223. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  224. AllowProxy: func(common.GeoIPData) bool { return true },
  225. AllowClient: func(common.GeoIPData) bool { return true },
  226. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  227. }
  228. broker, err := NewBroker(brokerConfig)
  229. if err != nil {
  230. return errors.Trace(err)
  231. }
  232. err = broker.Start()
  233. if err != nil {
  234. return errors.Trace(err)
  235. }
  236. defer broker.Stop()
  237. testGroup.Go(func() error {
  238. err := runHTTPServer(brokerListener, broker)
  239. if testCtx.Err() != nil {
  240. return nil
  241. }
  242. return errors.Trace(err)
  243. })
  244. // Stub server broker request handler (in Psiphon, this will be the
  245. // destination Psiphon server; here, it's not necessary to build this
  246. // handler into the destination echo server)
  247. serverSessions, err := NewServerBrokerSessions(
  248. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey},
  249. apiParameterValidator, apiParameterLogFieldFormatter, "")
  250. if err != nil {
  251. return errors.Trace(err)
  252. }
  253. var pendingBrokerServerReportsMutex sync.Mutex
  254. pendingBrokerServerReports := make(map[ID]bool)
  255. addPendingBrokerServerReport := func(connectionID ID) {
  256. pendingBrokerServerReportsMutex.Lock()
  257. defer pendingBrokerServerReportsMutex.Unlock()
  258. pendingBrokerServerReports[connectionID] = true
  259. }
  260. hasPendingBrokerServerReports := func() bool {
  261. pendingBrokerServerReportsMutex.Lock()
  262. defer pendingBrokerServerReportsMutex.Unlock()
  263. return len(pendingBrokerServerReports) > 0
  264. }
  265. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  266. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  267. pendingBrokerServerReportsMutex.Lock()
  268. defer pendingBrokerServerReportsMutex.Unlock()
  269. // Mark the report as no longer outstanding
  270. delete(pendingBrokerServerReports, clientConnectionID)
  271. }
  272. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  273. return out, errors.Trace(err)
  274. }
  275. // Check that the tactics round trip succeeds
  276. var pendingProxyTacticsCallbacksMutex sync.Mutex
  277. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  278. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  279. pendingProxyTacticsCallbacksMutex.Lock()
  280. defer pendingProxyTacticsCallbacksMutex.Unlock()
  281. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  282. }
  283. hasPendingProxyTacticsCallbacks := func() bool {
  284. pendingProxyTacticsCallbacksMutex.Lock()
  285. defer pendingProxyTacticsCallbacksMutex.Unlock()
  286. return len(pendingProxyTacticsCallbacks) > 0
  287. }
  288. makeHandleTacticsPayload := func(
  289. proxyPrivateKey SessionPrivateKey,
  290. tacticsNetworkID string) func(_ string, _ []byte) bool {
  291. return func(networkID string, tacticsPayload []byte) bool {
  292. pendingProxyTacticsCallbacksMutex.Lock()
  293. defer pendingProxyTacticsCallbacksMutex.Unlock()
  294. // Check that the correct networkID is passed around; if not,
  295. // skip the delete, which will fail the test
  296. if networkID == tacticsNetworkID {
  297. // Certain state is reset when new tactics are applied -- the
  298. // return true case; exercise both cases
  299. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  300. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  301. return true
  302. }
  303. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  304. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  305. return false
  306. }
  307. }
  308. panic("unexpected tactics payload")
  309. }
  310. }
  311. // Start proxies
  312. logger.WithTrace().Info("START PROXIES")
  313. for i := 0; i < numProxies; i++ {
  314. proxyPrivateKey, err := GenerateSessionPrivateKey()
  315. if err != nil {
  316. return errors.Trace(err)
  317. }
  318. brokerCoordinator := &testBrokerDialCoordinator{
  319. networkID: testNetworkID,
  320. networkType: testNetworkType,
  321. brokerClientPrivateKey: proxyPrivateKey,
  322. brokerPublicKey: brokerPublicKey,
  323. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  324. brokerClientRoundTripper: newHTTPRoundTripper(
  325. brokerListener.Addr().String(), "proxy"),
  326. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  327. brokerClientRoundTripperFailed: roundTripperFailed,
  328. }
  329. webRTCCoordinator := &testWebRTCDialCoordinator{
  330. networkID: testNetworkID,
  331. networkType: testNetworkType,
  332. natType: testNATType,
  333. disableSTUN: testDisableSTUN,
  334. stunServerAddress: testSTUNServerAddress,
  335. stunServerAddressRFC5780: testSTUNServerAddress,
  336. stunServerAddressSucceeded: stunServerAddressSucceeded,
  337. stunServerAddressFailed: stunServerAddressFailed,
  338. setNATType: func(NATType) {},
  339. setPortMappingTypes: func(PortMappingTypes) {},
  340. bindToDevice: func(int) error { return nil },
  341. }
  342. // Each proxy has its own broker client
  343. brokerClient, err := NewBrokerClient(brokerCoordinator)
  344. if err != nil {
  345. return errors.Trace(err)
  346. }
  347. tacticsNetworkID := prng.HexString(32)
  348. runCtx, cancelRun := context.WithCancel(testCtx)
  349. // No deferred cancelRun due to testGroup.Go below
  350. proxy, err := NewProxy(&ProxyConfig{
  351. Logger: logger,
  352. WaitForNetworkConnectivity: func() bool {
  353. return true
  354. },
  355. GetBrokerClient: func() (*BrokerClient, error) {
  356. return brokerClient, nil
  357. },
  358. GetBaseAPIParameters: func(bool) (common.APIParameters, string, error) {
  359. return baseAPIParameters, tacticsNetworkID, nil
  360. },
  361. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  362. return webRTCCoordinator, nil
  363. },
  364. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  365. MaxClients: proxyMaxClients,
  366. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  367. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  368. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  369. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  370. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  371. time.Now().UTC().Format(time.RFC3339),
  372. connectingClients, connectedClients, bytesUp, bytesDown)
  373. },
  374. MustUpgrade: func() {
  375. close(receivedProxyMustUpgrade)
  376. cancelRun()
  377. },
  378. })
  379. if err != nil {
  380. return errors.Trace(err)
  381. }
  382. addPendingProxyTacticsCallback(proxyPrivateKey)
  383. testGroup.Go(func() error {
  384. proxy.Run(runCtx)
  385. return nil
  386. })
  387. }
  388. // Await proxy announcements before starting clients
  389. //
  390. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  391. // plus NAT discovery
  392. //
  393. // - Don't wait for > numProxies announcements due to
  394. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  395. if !doMustUpgrade {
  396. for {
  397. time.Sleep(100 * time.Millisecond)
  398. broker.matcher.announcementQueueMutex.Lock()
  399. n := broker.matcher.announcementQueue.getLen()
  400. broker.matcher.announcementQueueMutex.Unlock()
  401. if n >= numProxies {
  402. break
  403. }
  404. }
  405. }
  406. // Start clients
  407. logger.WithTrace().Info("START CLIENTS")
  408. clientsGroup := new(errgroup.Group)
  409. makeClientFunc := func(
  410. isTCP bool,
  411. isMobile bool,
  412. brokerClient *BrokerClient,
  413. webRTCCoordinator WebRTCDialCoordinator) func() error {
  414. var networkProtocol NetworkProtocol
  415. var addr string
  416. var wrapWithQUIC bool
  417. if isTCP {
  418. networkProtocol = NetworkProtocolTCP
  419. addr = tcpEchoListener.Addr().String()
  420. } else {
  421. networkProtocol = NetworkProtocolUDP
  422. addr = quicEchoServer.Addr().String()
  423. wrapWithQUIC = true
  424. }
  425. return func() error {
  426. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  427. defer cancelDial()
  428. conn, err := DialClient(
  429. dialCtx,
  430. &ClientConfig{
  431. Logger: logger,
  432. BaseAPIParameters: baseAPIParameters,
  433. BrokerClient: brokerClient,
  434. WebRTCDialCoordinator: webRTCCoordinator,
  435. ReliableTransport: isTCP,
  436. DialNetworkProtocol: networkProtocol,
  437. DialAddress: addr,
  438. PackedDestinationServerEntry: packedDestinationServerEntry,
  439. MustUpgrade: func() {
  440. close(receivedClientMustUpgrade)
  441. cancelDial()
  442. },
  443. })
  444. if err != nil {
  445. return errors.Trace(err)
  446. }
  447. var relayConn net.Conn
  448. relayConn = conn
  449. if wrapWithQUIC {
  450. quicConn, err := quic.Dial(
  451. dialCtx,
  452. conn,
  453. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  454. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true,
  455. false, false, common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  456. )
  457. if err != nil {
  458. return errors.Trace(err)
  459. }
  460. relayConn = quicConn
  461. }
  462. addPendingBrokerServerReport(conn.GetConnectionID())
  463. signalRelayComplete := make(chan struct{})
  464. clientsGroup.Go(func() error {
  465. defer close(signalRelayComplete)
  466. in := conn.InitialRelayPacket()
  467. for in != nil {
  468. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  469. if err != nil {
  470. if out == nil {
  471. return errors.Trace(err)
  472. } else {
  473. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  474. // Proceed with reset session token packet
  475. }
  476. }
  477. if out == nil {
  478. // Relay is complete
  479. break
  480. }
  481. in, err = conn.RelayPacket(testCtx, out)
  482. if err != nil {
  483. return errors.Trace(err)
  484. }
  485. }
  486. return nil
  487. })
  488. sendBytes := prng.Bytes(bytesToSend)
  489. clientsGroup.Go(func() error {
  490. for n := 0; n < bytesToSend; {
  491. m := prng.Range(1024, 32768)
  492. if bytesToSend-n < m {
  493. m = bytesToSend - n
  494. }
  495. _, err := relayConn.Write(sendBytes[n : n+m])
  496. if err != nil {
  497. return errors.Trace(err)
  498. }
  499. n += m
  500. }
  501. fmt.Printf("%d bytes sent\n", bytesToSend)
  502. return nil
  503. })
  504. clientsGroup.Go(func() error {
  505. buf := make([]byte, 32768)
  506. n := 0
  507. for n < bytesToSend {
  508. m, err := relayConn.Read(buf)
  509. if err != nil {
  510. return errors.Trace(err)
  511. }
  512. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  513. return errors.Tracef(
  514. "unexpected bytes: expected at index %d, received at index %d",
  515. bytes.Index(sendBytes, buf[:m]), n)
  516. }
  517. n += m
  518. }
  519. fmt.Printf("%d bytes received\n", bytesToSend)
  520. select {
  521. case <-signalRelayComplete:
  522. case <-testCtx.Done():
  523. }
  524. relayConn.Close()
  525. conn.Close()
  526. return nil
  527. })
  528. return nil
  529. }
  530. }
  531. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  532. clientPrivateKey, err := GenerateSessionPrivateKey()
  533. if err != nil {
  534. return nil, nil, errors.Trace(err)
  535. }
  536. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  537. if err != nil {
  538. return nil, nil, errors.Trace(err)
  539. }
  540. brokerCoordinator := &testBrokerDialCoordinator{
  541. networkID: testNetworkID,
  542. networkType: testNetworkType,
  543. commonCompartmentIDs: testCommonCompartmentIDs,
  544. brokerClientPrivateKey: clientPrivateKey,
  545. brokerPublicKey: brokerPublicKey,
  546. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  547. brokerClientRoundTripper: newHTTPRoundTripper(
  548. brokerListener.Addr().String(), "client"),
  549. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  550. brokerClientRoundTripperFailed: roundTripperFailed,
  551. brokerClientNoMatch: noMatch,
  552. }
  553. webRTCCoordinator := &testWebRTCDialCoordinator{
  554. networkID: testNetworkID,
  555. networkType: testNetworkType,
  556. natType: testNATType,
  557. disableSTUN: testDisableSTUN,
  558. stunServerAddress: testSTUNServerAddress,
  559. stunServerAddressRFC5780: testSTUNServerAddress,
  560. stunServerAddressSucceeded: stunServerAddressSucceeded,
  561. stunServerAddressFailed: stunServerAddressFailed,
  562. clientRootObfuscationSecret: clientRootObfuscationSecret,
  563. doDTLSRandomization: prng.FlipCoin(),
  564. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  565. MinPaddedMessages: 0,
  566. MaxPaddedMessages: 10,
  567. MinPaddingSize: 0,
  568. MaxPaddingSize: 1500,
  569. MinDecoyMessages: 0,
  570. MaxDecoyMessages: 10,
  571. MinDecoySize: 1,
  572. MaxDecoySize: 1500,
  573. DecoyMessageProbability: 0.5,
  574. },
  575. setNATType: func(NATType) {},
  576. setPortMappingTypes: func(PortMappingTypes) {},
  577. bindToDevice: func(int) error { return nil },
  578. // With STUN enabled (testDisableSTUN = false), there are cases
  579. // where the WebRTC Data Channel is not successfully established.
  580. // With a short enough timeout here, clients will redial and
  581. // eventually succceed.
  582. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  583. }
  584. if isMobile {
  585. webRTCCoordinator.networkType = NetworkTypeMobile
  586. webRTCCoordinator.disableInboundForMobileNetworks = true
  587. }
  588. brokerClient, err := NewBrokerClient(brokerCoordinator)
  589. if err != nil {
  590. return nil, nil, errors.Trace(err)
  591. }
  592. return brokerClient, webRTCCoordinator, nil
  593. }
  594. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  595. if err != nil {
  596. return errors.Trace(err)
  597. }
  598. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  599. if err != nil {
  600. return errors.Trace(err)
  601. }
  602. for i := 0; i < numClients; i++ {
  603. // Test a mix of TCP and UDP proxying; also test the
  604. // DisableInboundForMobileNetworks code path.
  605. isTCP := i%2 == 0
  606. isMobile := i%4 == 0
  607. // Exercise BrokerClients shared by multiple clients, but also create
  608. // several broker clients.
  609. if i%8 == 0 {
  610. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  611. if err != nil {
  612. return errors.Trace(err)
  613. }
  614. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  615. if err != nil {
  616. return errors.Trace(err)
  617. }
  618. }
  619. brokerClient := clientBrokerClient
  620. webRTCCoordinator := clientWebRTCCoordinator
  621. if isMobile {
  622. brokerClient = clientMobileBrokerClient
  623. webRTCCoordinator = clientMobileWebRTCCoordinator
  624. }
  625. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  626. }
  627. if doMustUpgrade {
  628. // Await MustUpgrade callbacks
  629. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  630. <-receivedProxyMustUpgrade
  631. <-receivedClientMustUpgrade
  632. _ = clientsGroup.Wait()
  633. } else {
  634. // Await client transfers complete
  635. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  636. err = clientsGroup.Wait()
  637. if err != nil {
  638. return errors.Trace(err)
  639. }
  640. logger.WithTrace().Info("DONE DATA TRANSFER")
  641. if hasPendingBrokerServerReports() {
  642. return errors.TraceNew("unexpected pending broker server requests")
  643. }
  644. if hasPendingProxyTacticsCallbacks() {
  645. return errors.TraceNew("unexpected pending proxy tactics callback")
  646. }
  647. // TODO: check that elapsed time is consistent with rate limit (+/-)
  648. // Check if STUN server replay callbacks were triggered
  649. if !testDisableSTUN {
  650. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  651. return errors.TraceNew("unexpected STUN server succeeded count")
  652. }
  653. }
  654. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  655. return errors.TraceNew("unexpected STUN server failed count")
  656. }
  657. // Check if RoundTripper server replay callbacks were triggered
  658. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  659. return errors.TraceNew("unexpected round tripper succeeded count")
  660. }
  661. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  662. return errors.TraceNew("unexpected round tripper failed count")
  663. }
  664. }
  665. // Await shutdowns
  666. stopTest()
  667. brokerListener.Close()
  668. err = testGroup.Wait()
  669. if err != nil {
  670. return errors.Trace(err)
  671. }
  672. return nil
  673. }
  674. func runHTTPServer(listener net.Listener, broker *Broker) error {
  675. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  676. // For this test, clients set the path to "/client" and proxies
  677. // set the path to "/proxy" and we use that to create stub GeoIP
  678. // data to pass the not-same-ASN condition.
  679. var geoIPData common.GeoIPData
  680. geoIPData.ASN = r.URL.Path
  681. requestPayload, err := ioutil.ReadAll(
  682. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  683. if err != nil {
  684. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  685. http.Error(w, "", http.StatusNotFound)
  686. return
  687. }
  688. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  689. extendTimeout := func(timeout time.Duration) {
  690. // TODO: set insufficient initial timeout, so extension is
  691. // required for success
  692. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  693. }
  694. responsePayload, err := broker.HandleSessionPacket(
  695. r.Context(),
  696. extendTimeout,
  697. nil,
  698. clientIP,
  699. geoIPData,
  700. requestPayload)
  701. if err != nil {
  702. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  703. http.Error(w, "", http.StatusNotFound)
  704. return
  705. }
  706. w.WriteHeader(http.StatusOK)
  707. w.Write(responsePayload)
  708. })
  709. // WriteTimeout will be extended via extendTimeout.
  710. httpServer := &http.Server{
  711. ReadTimeout: 10 * time.Second,
  712. WriteTimeout: 10 * time.Second,
  713. IdleTimeout: 1 * time.Minute,
  714. Handler: handler,
  715. }
  716. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  717. if err != nil {
  718. return errors.Trace(err)
  719. }
  720. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  721. if err != nil {
  722. return errors.Trace(err)
  723. }
  724. tlsConfig := &tls.Config{
  725. Certificates: []tls.Certificate{tlsCert},
  726. }
  727. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  728. return errors.Trace(err)
  729. }
  730. type httpRoundTripper struct {
  731. httpClient *http.Client
  732. endpointAddr string
  733. path string
  734. }
  735. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  736. return &httpRoundTripper{
  737. httpClient: &http.Client{
  738. Transport: &http.Transport{
  739. ForceAttemptHTTP2: true,
  740. MaxIdleConns: 2,
  741. IdleConnTimeout: 1 * time.Minute,
  742. TLSHandshakeTimeout: 10 * time.Second,
  743. TLSClientConfig: &std_tls.Config{
  744. InsecureSkipVerify: true,
  745. },
  746. },
  747. },
  748. endpointAddr: endpointAddr,
  749. path: path,
  750. }
  751. }
  752. func (r *httpRoundTripper) RoundTrip(
  753. ctx context.Context,
  754. roundTripDelay time.Duration,
  755. roundTripTimeout time.Duration,
  756. requestPayload []byte) ([]byte, error) {
  757. if roundTripDelay > 0 {
  758. common.SleepWithContext(ctx, roundTripDelay)
  759. }
  760. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  761. defer requestCancelFunc()
  762. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  763. request, err := http.NewRequestWithContext(
  764. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  765. if err != nil {
  766. return nil, errors.Trace(err)
  767. }
  768. response, err := r.httpClient.Do(request)
  769. if err != nil {
  770. return nil, errors.Trace(err)
  771. }
  772. defer response.Body.Close()
  773. if response.StatusCode != http.StatusOK {
  774. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  775. }
  776. responsePayload, err := io.ReadAll(response.Body)
  777. if err != nil {
  778. return nil, errors.Trace(err)
  779. }
  780. return responsePayload, nil
  781. }
  782. func (r *httpRoundTripper) Close() error {
  783. r.httpClient.CloseIdleConnections()
  784. return nil
  785. }
  786. func runTCPEchoServer(listener net.Listener) {
  787. for {
  788. conn, err := listener.Accept()
  789. if err != nil {
  790. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  791. return
  792. }
  793. go func(conn net.Conn) {
  794. buf := make([]byte, 32768)
  795. for {
  796. n, err := conn.Read(buf)
  797. if n > 0 {
  798. _, err = conn.Write(buf[:n])
  799. }
  800. if err != nil {
  801. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  802. return
  803. }
  804. }
  805. }(conn)
  806. }
  807. }
  808. type quicEchoServer struct {
  809. listener net.Listener
  810. obfuscationKey string
  811. }
  812. func newQuicEchoServer() (*quicEchoServer, error) {
  813. obfuscationKey := prng.HexString(32)
  814. listener, err := quic.Listen(
  815. nil,
  816. nil,
  817. "127.0.0.1:0",
  818. obfuscationKey,
  819. false)
  820. if err != nil {
  821. return nil, errors.Trace(err)
  822. }
  823. return &quicEchoServer{
  824. listener: listener,
  825. obfuscationKey: obfuscationKey,
  826. }, nil
  827. }
  828. func (q *quicEchoServer) ObfuscationKey() string {
  829. return q.obfuscationKey
  830. }
  831. func (q *quicEchoServer) Close() error {
  832. return q.listener.Close()
  833. }
  834. func (q *quicEchoServer) Addr() net.Addr {
  835. return q.listener.Addr()
  836. }
  837. func (q *quicEchoServer) Run() {
  838. for {
  839. conn, err := q.listener.Accept()
  840. if err != nil {
  841. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  842. return
  843. }
  844. go func(conn net.Conn) {
  845. buf := make([]byte, 32768)
  846. for {
  847. n, err := conn.Read(buf)
  848. if n > 0 {
  849. _, err = conn.Write(buf[:n])
  850. }
  851. if err != nil {
  852. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  853. return
  854. }
  855. }
  856. }(conn)
  857. }
  858. }