inproxy_test.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testNewTacticsPayload := []byte(prng.HexString(100))
  77. testNewTacticsTag := "new-tactics-tag"
  78. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  79. currentNetworkCtx, currentNetworkCancelFunc := context.WithCancel(context.Background())
  80. defer currentNetworkCancelFunc()
  81. // TODO: test port mapping
  82. stunServerAddressSucceededCount := int32(0)
  83. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  84. stunServerAddressFailedCount := int32(0)
  85. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  86. roundTripperSucceededCount := int32(0)
  87. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  88. roundTripperFailedCount := int32(0)
  89. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  90. noMatch := func(RoundTripper) {}
  91. var receivedProxyMustUpgrade chan struct{}
  92. var receivedClientMustUpgrade chan struct{}
  93. if doMustUpgrade {
  94. receivedProxyMustUpgrade = make(chan struct{})
  95. receivedClientMustUpgrade = make(chan struct{})
  96. // trigger MustUpgrade
  97. proxyProtocolVersion = 0
  98. // Minimize test parameters for MustUpgrade case
  99. numProxies = 1
  100. proxyMaxClients = 1
  101. numClients = 1
  102. testDisableSTUN = true
  103. }
  104. testCtx, stopTest := context.WithCancel(context.Background())
  105. defer stopTest()
  106. testGroup := new(errgroup.Group)
  107. // Enable test to run without requiring host firewall exceptions
  108. SetAllowBogonWebRTCConnections(true)
  109. defer SetAllowBogonWebRTCConnections(false)
  110. // Init logging and profiling
  111. logger := newTestLogger()
  112. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  113. go http.Serve(pprofListener, nil)
  114. defer pprofListener.Close()
  115. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  116. // Start echo servers
  117. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  118. if err != nil {
  119. return errors.Trace(err)
  120. }
  121. defer tcpEchoListener.Close()
  122. go runTCPEchoServer(tcpEchoListener)
  123. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  124. quicEchoServer, err := newQuicEchoServer()
  125. if err != nil {
  126. return errors.Trace(err)
  127. }
  128. defer quicEchoServer.Close()
  129. go quicEchoServer.Run()
  130. // Create signed server entry with capability
  131. serverPrivateKey, err := GenerateSessionPrivateKey()
  132. if err != nil {
  133. return errors.Trace(err)
  134. }
  135. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  136. if err != nil {
  137. return errors.Trace(err)
  138. }
  139. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  140. if err != nil {
  141. return errors.Trace(err)
  142. }
  143. serverEntry := make(protocol.ServerEntryFields)
  144. serverEntry["ipAddress"] = "127.0.0.1"
  145. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  146. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  147. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  148. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  149. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  150. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  151. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  152. testServerEntryTag := prng.HexString(16)
  153. serverEntry["tag"] = testServerEntryTag
  154. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  155. protocol.NewServerEntrySignatureKeyPair()
  156. if err != nil {
  157. return errors.Trace(err)
  158. }
  159. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  160. if err != nil {
  161. return errors.Trace(err)
  162. }
  163. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  164. if err != nil {
  165. return errors.Trace(err)
  166. }
  167. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  168. if err != nil {
  169. return errors.Trace(err)
  170. }
  171. // API parameter handlers
  172. apiParameterValidator := func(params common.APIParameters) error {
  173. if len(params) != len(baseAPIParameters) {
  174. return errors.TraceNew("unexpected base API parameter count")
  175. }
  176. for name, value := range params {
  177. if value.(string) != baseAPIParameters[name].(string) {
  178. return errors.Tracef(
  179. "unexpected base API parameter: %v: %v != %v",
  180. name,
  181. value.(string),
  182. baseAPIParameters[name].(string))
  183. }
  184. }
  185. return nil
  186. }
  187. apiParameterLogFieldFormatter := func(
  188. _ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
  189. return common.LogFields(params)
  190. }
  191. // Start broker
  192. logger.WithTrace().Info("START BROKER")
  193. brokerPrivateKey, err := GenerateSessionPrivateKey()
  194. if err != nil {
  195. return errors.Trace(err)
  196. }
  197. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  198. if err != nil {
  199. return errors.Trace(err)
  200. }
  201. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  202. if err != nil {
  203. return errors.Trace(err)
  204. }
  205. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  206. if err != nil {
  207. return errors.Trace(err)
  208. }
  209. defer brokerListener.Close()
  210. brokerConfig := &BrokerConfig{
  211. Logger: logger,
  212. CommonCompartmentIDs: testCommonCompartmentIDs,
  213. APIParameterValidator: apiParameterValidator,
  214. APIParameterLogFieldFormatter: apiParameterLogFieldFormatter,
  215. GetTacticsPayload: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  216. // Exercise both new and unchanged tactics
  217. if prng.FlipCoin() {
  218. return testNewTacticsPayload, testNewTacticsTag, nil
  219. }
  220. return testUnchangedTacticsPayload, "", nil
  221. },
  222. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  223. PrivateKey: brokerPrivateKey,
  224. ObfuscationRootSecret: brokerRootObfuscationSecret,
  225. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  226. AllowProxy: func(common.GeoIPData) bool { return true },
  227. AllowClient: func(common.GeoIPData) bool { return true },
  228. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  229. }
  230. broker, err := NewBroker(brokerConfig)
  231. if err != nil {
  232. return errors.Trace(err)
  233. }
  234. err = broker.Start()
  235. if err != nil {
  236. return errors.Trace(err)
  237. }
  238. defer broker.Stop()
  239. testGroup.Go(func() error {
  240. err := runHTTPServer(brokerListener, broker)
  241. if testCtx.Err() != nil {
  242. return nil
  243. }
  244. return errors.Trace(err)
  245. })
  246. // Stub server broker request handler (in Psiphon, this will be the
  247. // destination Psiphon server; here, it's not necessary to build this
  248. // handler into the destination echo server)
  249. serverSessions, err := NewServerBrokerSessions(
  250. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey},
  251. apiParameterValidator, apiParameterLogFieldFormatter, "")
  252. if err != nil {
  253. return errors.Trace(err)
  254. }
  255. var pendingBrokerServerReportsMutex sync.Mutex
  256. pendingBrokerServerReports := make(map[ID]bool)
  257. addPendingBrokerServerReport := func(connectionID ID) {
  258. pendingBrokerServerReportsMutex.Lock()
  259. defer pendingBrokerServerReportsMutex.Unlock()
  260. pendingBrokerServerReports[connectionID] = true
  261. }
  262. hasPendingBrokerServerReports := func() bool {
  263. pendingBrokerServerReportsMutex.Lock()
  264. defer pendingBrokerServerReportsMutex.Unlock()
  265. return len(pendingBrokerServerReports) > 0
  266. }
  267. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  268. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  269. pendingBrokerServerReportsMutex.Lock()
  270. defer pendingBrokerServerReportsMutex.Unlock()
  271. // Mark the report as no longer outstanding
  272. delete(pendingBrokerServerReports, clientConnectionID)
  273. }
  274. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  275. return out, errors.Trace(err)
  276. }
  277. // Check that the tactics round trip succeeds
  278. var pendingProxyTacticsCallbacksMutex sync.Mutex
  279. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  280. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  281. pendingProxyTacticsCallbacksMutex.Lock()
  282. defer pendingProxyTacticsCallbacksMutex.Unlock()
  283. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  284. }
  285. hasPendingProxyTacticsCallbacks := func() bool {
  286. pendingProxyTacticsCallbacksMutex.Lock()
  287. defer pendingProxyTacticsCallbacksMutex.Unlock()
  288. return len(pendingProxyTacticsCallbacks) > 0
  289. }
  290. makeHandleTacticsPayload := func(
  291. proxyPrivateKey SessionPrivateKey,
  292. tacticsNetworkID string) func(_ string, _ []byte) bool {
  293. return func(networkID string, tacticsPayload []byte) bool {
  294. pendingProxyTacticsCallbacksMutex.Lock()
  295. defer pendingProxyTacticsCallbacksMutex.Unlock()
  296. // Check that the correct networkID is passed around; if not,
  297. // skip the delete, which will fail the test
  298. if networkID == tacticsNetworkID {
  299. // Certain state is reset when new tactics are applied -- the
  300. // return true case; exercise both cases
  301. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  302. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  303. return true
  304. }
  305. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  306. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  307. return false
  308. }
  309. }
  310. panic("unexpected tactics payload")
  311. }
  312. }
  313. // Start proxies
  314. logger.WithTrace().Info("START PROXIES")
  315. for i := 0; i < numProxies; i++ {
  316. proxyPrivateKey, err := GenerateSessionPrivateKey()
  317. if err != nil {
  318. return errors.Trace(err)
  319. }
  320. brokerCoordinator := &testBrokerDialCoordinator{
  321. networkID: testNetworkID,
  322. networkType: testNetworkType,
  323. brokerClientPrivateKey: proxyPrivateKey,
  324. brokerPublicKey: brokerPublicKey,
  325. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  326. brokerClientRoundTripper: newHTTPRoundTripper(
  327. brokerListener.Addr().String(), "proxy"),
  328. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  329. brokerClientRoundTripperFailed: roundTripperFailed,
  330. }
  331. webRTCCoordinator := &testWebRTCDialCoordinator{
  332. networkID: testNetworkID,
  333. networkType: testNetworkType,
  334. natType: testNATType,
  335. disableSTUN: testDisableSTUN,
  336. stunServerAddress: testSTUNServerAddress,
  337. stunServerAddressRFC5780: testSTUNServerAddress,
  338. stunServerAddressSucceeded: stunServerAddressSucceeded,
  339. stunServerAddressFailed: stunServerAddressFailed,
  340. setNATType: func(NATType) {},
  341. setPortMappingTypes: func(PortMappingTypes) {},
  342. bindToDevice: func(int) error { return nil },
  343. }
  344. // Each proxy has its own broker client
  345. brokerClient, err := NewBrokerClient(brokerCoordinator)
  346. if err != nil {
  347. return errors.Trace(err)
  348. }
  349. tacticsNetworkID := prng.HexString(32)
  350. runCtx, cancelRun := context.WithCancel(testCtx)
  351. // No deferred cancelRun due to testGroup.Go below
  352. proxy, err := NewProxy(&ProxyConfig{
  353. Logger: logger,
  354. WaitForNetworkConnectivity: func() bool {
  355. return true
  356. },
  357. GetCurrentNetworkContext: func() context.Context {
  358. return currentNetworkCtx
  359. },
  360. GetBrokerClient: func() (*BrokerClient, error) {
  361. return brokerClient, nil
  362. },
  363. GetBaseAPIParameters: func(bool) (common.APIParameters, string, error) {
  364. return baseAPIParameters, tacticsNetworkID, nil
  365. },
  366. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  367. return webRTCCoordinator, nil
  368. },
  369. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  370. MaxClients: proxyMaxClients,
  371. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  372. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  373. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  374. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  375. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  376. time.Now().UTC().Format(time.RFC3339),
  377. connectingClients, connectedClients, bytesUp, bytesDown)
  378. },
  379. MustUpgrade: func() {
  380. close(receivedProxyMustUpgrade)
  381. cancelRun()
  382. },
  383. })
  384. if err != nil {
  385. return errors.Trace(err)
  386. }
  387. addPendingProxyTacticsCallback(proxyPrivateKey)
  388. testGroup.Go(func() error {
  389. proxy.Run(runCtx)
  390. return nil
  391. })
  392. }
  393. // Await proxy announcements before starting clients
  394. //
  395. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  396. // plus NAT discovery
  397. //
  398. // - Don't wait for > numProxies announcements due to
  399. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  400. if !doMustUpgrade {
  401. for {
  402. time.Sleep(100 * time.Millisecond)
  403. broker.matcher.announcementQueueMutex.Lock()
  404. n := broker.matcher.announcementQueue.getLen()
  405. broker.matcher.announcementQueueMutex.Unlock()
  406. if n >= numProxies {
  407. break
  408. }
  409. }
  410. }
  411. // Start clients
  412. logger.WithTrace().Info("START CLIENTS")
  413. clientsGroup := new(errgroup.Group)
  414. makeClientFunc := func(
  415. isTCP bool,
  416. isMobile bool,
  417. brokerClient *BrokerClient,
  418. webRTCCoordinator WebRTCDialCoordinator) func() error {
  419. var networkProtocol NetworkProtocol
  420. var addr string
  421. var wrapWithQUIC bool
  422. if isTCP {
  423. networkProtocol = NetworkProtocolTCP
  424. addr = tcpEchoListener.Addr().String()
  425. } else {
  426. networkProtocol = NetworkProtocolUDP
  427. addr = quicEchoServer.Addr().String()
  428. wrapWithQUIC = true
  429. }
  430. return func() error {
  431. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  432. defer cancelDial()
  433. conn, err := DialClient(
  434. dialCtx,
  435. &ClientConfig{
  436. Logger: logger,
  437. BaseAPIParameters: baseAPIParameters,
  438. BrokerClient: brokerClient,
  439. WebRTCDialCoordinator: webRTCCoordinator,
  440. ReliableTransport: isTCP,
  441. DialNetworkProtocol: networkProtocol,
  442. DialAddress: addr,
  443. PackedDestinationServerEntry: packedDestinationServerEntry,
  444. MustUpgrade: func() {
  445. close(receivedClientMustUpgrade)
  446. cancelDial()
  447. },
  448. })
  449. if err != nil {
  450. return errors.Trace(err)
  451. }
  452. var relayConn net.Conn
  453. relayConn = conn
  454. if wrapWithQUIC {
  455. quicConn, err := quic.Dial(
  456. dialCtx,
  457. conn,
  458. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  459. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true,
  460. false, false, common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  461. )
  462. if err != nil {
  463. return errors.Trace(err)
  464. }
  465. relayConn = quicConn
  466. }
  467. addPendingBrokerServerReport(conn.GetConnectionID())
  468. signalRelayComplete := make(chan struct{})
  469. clientsGroup.Go(func() error {
  470. defer close(signalRelayComplete)
  471. in := conn.InitialRelayPacket()
  472. for in != nil {
  473. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  474. if err != nil {
  475. if out == nil {
  476. return errors.Trace(err)
  477. } else {
  478. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  479. // Proceed with reset session token packet
  480. }
  481. }
  482. if out == nil {
  483. // Relay is complete
  484. break
  485. }
  486. in, err = conn.RelayPacket(testCtx, out)
  487. if err != nil {
  488. return errors.Trace(err)
  489. }
  490. }
  491. return nil
  492. })
  493. sendBytes := prng.Bytes(bytesToSend)
  494. clientsGroup.Go(func() error {
  495. for n := 0; n < bytesToSend; {
  496. m := prng.Range(1024, 32768)
  497. if bytesToSend-n < m {
  498. m = bytesToSend - n
  499. }
  500. _, err := relayConn.Write(sendBytes[n : n+m])
  501. if err != nil {
  502. return errors.Trace(err)
  503. }
  504. n += m
  505. }
  506. fmt.Printf("%d bytes sent\n", bytesToSend)
  507. return nil
  508. })
  509. clientsGroup.Go(func() error {
  510. buf := make([]byte, 32768)
  511. n := 0
  512. for n < bytesToSend {
  513. m, err := relayConn.Read(buf)
  514. if err != nil {
  515. return errors.Trace(err)
  516. }
  517. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  518. return errors.Tracef(
  519. "unexpected bytes: expected at index %d, received at index %d",
  520. bytes.Index(sendBytes, buf[:m]), n)
  521. }
  522. n += m
  523. }
  524. fmt.Printf("%d bytes received\n", bytesToSend)
  525. select {
  526. case <-signalRelayComplete:
  527. case <-testCtx.Done():
  528. }
  529. relayConn.Close()
  530. conn.Close()
  531. return nil
  532. })
  533. return nil
  534. }
  535. }
  536. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  537. clientPrivateKey, err := GenerateSessionPrivateKey()
  538. if err != nil {
  539. return nil, nil, errors.Trace(err)
  540. }
  541. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  542. if err != nil {
  543. return nil, nil, errors.Trace(err)
  544. }
  545. brokerCoordinator := &testBrokerDialCoordinator{
  546. networkID: testNetworkID,
  547. networkType: testNetworkType,
  548. commonCompartmentIDs: testCommonCompartmentIDs,
  549. brokerClientPrivateKey: clientPrivateKey,
  550. brokerPublicKey: brokerPublicKey,
  551. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  552. brokerClientRoundTripper: newHTTPRoundTripper(
  553. brokerListener.Addr().String(), "client"),
  554. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  555. brokerClientRoundTripperFailed: roundTripperFailed,
  556. brokerClientNoMatch: noMatch,
  557. }
  558. webRTCCoordinator := &testWebRTCDialCoordinator{
  559. networkID: testNetworkID,
  560. networkType: testNetworkType,
  561. natType: testNATType,
  562. disableSTUN: testDisableSTUN,
  563. stunServerAddress: testSTUNServerAddress,
  564. stunServerAddressRFC5780: testSTUNServerAddress,
  565. stunServerAddressSucceeded: stunServerAddressSucceeded,
  566. stunServerAddressFailed: stunServerAddressFailed,
  567. clientRootObfuscationSecret: clientRootObfuscationSecret,
  568. doDTLSRandomization: prng.FlipCoin(),
  569. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  570. MinPaddedMessages: 0,
  571. MaxPaddedMessages: 10,
  572. MinPaddingSize: 0,
  573. MaxPaddingSize: 1500,
  574. MinDecoyMessages: 0,
  575. MaxDecoyMessages: 10,
  576. MinDecoySize: 1,
  577. MaxDecoySize: 1500,
  578. DecoyMessageProbability: 0.5,
  579. },
  580. setNATType: func(NATType) {},
  581. setPortMappingTypes: func(PortMappingTypes) {},
  582. bindToDevice: func(int) error { return nil },
  583. // With STUN enabled (testDisableSTUN = false), there are cases
  584. // where the WebRTC Data Channel is not successfully established.
  585. // With a short enough timeout here, clients will redial and
  586. // eventually succceed.
  587. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  588. }
  589. if isMobile {
  590. webRTCCoordinator.networkType = NetworkTypeMobile
  591. webRTCCoordinator.disableInboundForMobileNetworks = true
  592. }
  593. brokerClient, err := NewBrokerClient(brokerCoordinator)
  594. if err != nil {
  595. return nil, nil, errors.Trace(err)
  596. }
  597. return brokerClient, webRTCCoordinator, nil
  598. }
  599. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  600. if err != nil {
  601. return errors.Trace(err)
  602. }
  603. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  604. if err != nil {
  605. return errors.Trace(err)
  606. }
  607. for i := 0; i < numClients; i++ {
  608. // Test a mix of TCP and UDP proxying; also test the
  609. // DisableInboundForMobileNetworks code path.
  610. isTCP := i%2 == 0
  611. isMobile := i%4 == 0
  612. // Exercise BrokerClients shared by multiple clients, but also create
  613. // several broker clients.
  614. if i%8 == 0 {
  615. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  616. if err != nil {
  617. return errors.Trace(err)
  618. }
  619. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  620. if err != nil {
  621. return errors.Trace(err)
  622. }
  623. }
  624. brokerClient := clientBrokerClient
  625. webRTCCoordinator := clientWebRTCCoordinator
  626. if isMobile {
  627. brokerClient = clientMobileBrokerClient
  628. webRTCCoordinator = clientMobileWebRTCCoordinator
  629. }
  630. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  631. }
  632. if doMustUpgrade {
  633. // Await MustUpgrade callbacks
  634. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  635. <-receivedProxyMustUpgrade
  636. <-receivedClientMustUpgrade
  637. _ = clientsGroup.Wait()
  638. } else {
  639. // Await client transfers complete
  640. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  641. err = clientsGroup.Wait()
  642. if err != nil {
  643. return errors.Trace(err)
  644. }
  645. logger.WithTrace().Info("DONE DATA TRANSFER")
  646. if hasPendingBrokerServerReports() {
  647. return errors.TraceNew("unexpected pending broker server requests")
  648. }
  649. if hasPendingProxyTacticsCallbacks() {
  650. return errors.TraceNew("unexpected pending proxy tactics callback")
  651. }
  652. // TODO: check that elapsed time is consistent with rate limit (+/-)
  653. // Check if STUN server replay callbacks were triggered
  654. if !testDisableSTUN {
  655. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  656. return errors.TraceNew("unexpected STUN server succeeded count")
  657. }
  658. }
  659. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  660. return errors.TraceNew("unexpected STUN server failed count")
  661. }
  662. // Check if RoundTripper server replay callbacks were triggered
  663. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  664. return errors.TraceNew("unexpected round tripper succeeded count")
  665. }
  666. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  667. return errors.TraceNew("unexpected round tripper failed count")
  668. }
  669. }
  670. // Await shutdowns
  671. stopTest()
  672. brokerListener.Close()
  673. err = testGroup.Wait()
  674. if err != nil {
  675. return errors.Trace(err)
  676. }
  677. return nil
  678. }
  679. func runHTTPServer(listener net.Listener, broker *Broker) error {
  680. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  681. // For this test, clients set the path to "/client" and proxies
  682. // set the path to "/proxy" and we use that to create stub GeoIP
  683. // data to pass the not-same-ASN condition.
  684. var geoIPData common.GeoIPData
  685. geoIPData.ASN = r.URL.Path
  686. requestPayload, err := ioutil.ReadAll(
  687. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  688. if err != nil {
  689. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  690. http.Error(w, "", http.StatusNotFound)
  691. return
  692. }
  693. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  694. extendTimeout := func(timeout time.Duration) {
  695. // TODO: set insufficient initial timeout, so extension is
  696. // required for success
  697. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  698. }
  699. responsePayload, err := broker.HandleSessionPacket(
  700. r.Context(),
  701. extendTimeout,
  702. nil,
  703. clientIP,
  704. geoIPData,
  705. requestPayload)
  706. if err != nil {
  707. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  708. http.Error(w, "", http.StatusNotFound)
  709. return
  710. }
  711. w.WriteHeader(http.StatusOK)
  712. w.Write(responsePayload)
  713. })
  714. // WriteTimeout will be extended via extendTimeout.
  715. httpServer := &http.Server{
  716. ReadTimeout: 10 * time.Second,
  717. WriteTimeout: 10 * time.Second,
  718. IdleTimeout: 1 * time.Minute,
  719. Handler: handler,
  720. }
  721. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  722. if err != nil {
  723. return errors.Trace(err)
  724. }
  725. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  726. if err != nil {
  727. return errors.Trace(err)
  728. }
  729. tlsConfig := &tls.Config{
  730. Certificates: []tls.Certificate{tlsCert},
  731. }
  732. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  733. return errors.Trace(err)
  734. }
// httpRoundTripper is the test round tripper passed to
// testBrokerDialCoordinator as brokerClientRoundTripper. Its RoundTrip
// method POSTs request payloads to https://<endpointAddr>/<path>.
type httpRoundTripper struct {
	// httpClient issues the HTTPS requests.
	httpClient *http.Client
	// endpointAddr is the host:port of the test broker HTTP server.
	endpointAddr string
	// path is the URL path; the test uses "client" or "proxy", which
	// runHTTPServer turns into stub GeoIP ASN data.
	path string
}
  740. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  741. return &httpRoundTripper{
  742. httpClient: &http.Client{
  743. Transport: &http.Transport{
  744. ForceAttemptHTTP2: true,
  745. MaxIdleConns: 2,
  746. IdleConnTimeout: 1 * time.Minute,
  747. TLSHandshakeTimeout: 10 * time.Second,
  748. TLSClientConfig: &std_tls.Config{
  749. InsecureSkipVerify: true,
  750. },
  751. },
  752. },
  753. endpointAddr: endpointAddr,
  754. path: path,
  755. }
  756. }
  757. func (r *httpRoundTripper) RoundTrip(
  758. ctx context.Context,
  759. roundTripDelay time.Duration,
  760. roundTripTimeout time.Duration,
  761. requestPayload []byte) ([]byte, error) {
  762. if roundTripDelay > 0 {
  763. common.SleepWithContext(ctx, roundTripDelay)
  764. }
  765. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  766. defer requestCancelFunc()
  767. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  768. request, err := http.NewRequestWithContext(
  769. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  770. if err != nil {
  771. return nil, errors.Trace(err)
  772. }
  773. response, err := r.httpClient.Do(request)
  774. if err != nil {
  775. return nil, errors.Trace(err)
  776. }
  777. defer response.Body.Close()
  778. if response.StatusCode != http.StatusOK {
  779. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  780. }
  781. responsePayload, err := io.ReadAll(response.Body)
  782. if err != nil {
  783. return nil, errors.Trace(err)
  784. }
  785. return responsePayload, nil
  786. }
  787. func (r *httpRoundTripper) Close() error {
  788. r.httpClient.CloseIdleConnections()
  789. return nil
  790. }
  791. func runTCPEchoServer(listener net.Listener) {
  792. for {
  793. conn, err := listener.Accept()
  794. if err != nil {
  795. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  796. return
  797. }
  798. go func(conn net.Conn) {
  799. buf := make([]byte, 32768)
  800. for {
  801. n, err := conn.Read(buf)
  802. if n > 0 {
  803. _, err = conn.Write(buf[:n])
  804. }
  805. if err != nil {
  806. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  807. return
  808. }
  809. }
  810. }(conn)
  811. }
  812. }
// quicEchoServer is a test echo server that accepts connections from a
// listener created by quic.Listen and echoes back all received data.
type quicEchoServer struct {
	// listener is the listener returned by quic.Listen.
	listener net.Listener
	// obfuscationKey is the key passed to quic.Listen; clients obtain it
	// via ObfuscationKey when dialing.
	obfuscationKey string
}
  817. func newQuicEchoServer() (*quicEchoServer, error) {
  818. obfuscationKey := prng.HexString(32)
  819. listener, err := quic.Listen(
  820. nil,
  821. nil,
  822. "127.0.0.1:0",
  823. obfuscationKey,
  824. false)
  825. if err != nil {
  826. return nil, errors.Trace(err)
  827. }
  828. return &quicEchoServer{
  829. listener: listener,
  830. obfuscationKey: obfuscationKey,
  831. }, nil
  832. }
  833. func (q *quicEchoServer) ObfuscationKey() string {
  834. return q.obfuscationKey
  835. }
  836. func (q *quicEchoServer) Close() error {
  837. return q.listener.Close()
  838. }
  839. func (q *quicEchoServer) Addr() net.Addr {
  840. return q.listener.Addr()
  841. }
  842. func (q *quicEchoServer) Run() {
  843. for {
  844. conn, err := q.listener.Accept()
  845. if err != nil {
  846. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  847. return
  848. }
  849. go func(conn net.Conn) {
  850. buf := make([]byte, 32768)
  851. for {
  852. n, err := conn.Read(buf)
  853. if n > 0 {
  854. _, err = conn.Write(buf[:n])
  855. }
  856. if err != nil {
  857. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  858. return
  859. }
  860. }
  861. }(conn)
  862. }
  863. }