inproxy_test.go 28 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  43. "golang.org/x/sync/errgroup"
  44. )
  45. func TestInproxy(t *testing.T) {
  46. err := runTestInproxy(false)
  47. if err != nil {
  48. t.Errorf(errors.Trace(err).Error())
  49. }
  50. }
  51. func TestInproxyMustUpgrade(t *testing.T) {
  52. err := runTestInproxy(true)
  53. if err != nil {
  54. t.Errorf(errors.Trace(err).Error())
  55. }
  56. }
  57. func runTestInproxy(doMustUpgrade bool) error {
  58. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  59. numProxies := 5
  60. proxyMaxClients := 3
  61. numClients := 10
  62. bytesToSend := 1 << 20
  63. targetElapsedSeconds := 2
  64. baseAPIParameters := common.APIParameters{
  65. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  66. "client_platform": "test-client-platform",
  67. }
  68. testCompartmentID, _ := MakeID()
  69. testCommonCompartmentIDs := []ID{testCompartmentID}
  70. testNetworkID := "NETWORK-ID-1"
  71. testNetworkType := NetworkTypeUnknown
  72. testNATType := NATTypeUnknown
  73. testSTUNServerAddress := "stun.nextcloud.com:443"
  74. testDisableSTUN := false
  75. testNewTacticsPayload := []byte(prng.HexString(100))
  76. testNewTacticsTag := "new-tactics-tag"
  77. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  78. // TODO: test port mapping
  79. stunServerAddressSucceededCount := int32(0)
  80. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  81. stunServerAddressFailedCount := int32(0)
  82. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  83. roundTripperSucceededCount := int32(0)
  84. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  85. roundTripperFailedCount := int32(0)
  86. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  87. var receivedProxyMustUpgrade chan struct{}
  88. var receivedClientMustUpgrade chan struct{}
  89. if doMustUpgrade {
  90. receivedProxyMustUpgrade = make(chan struct{})
  91. receivedClientMustUpgrade = make(chan struct{})
  92. // trigger MustUpgrade
  93. proxyProtocolVersion = 0
  94. // Minimize test parameters for MustUpgrade case
  95. numProxies = 1
  96. proxyMaxClients = 1
  97. numClients = 1
  98. testDisableSTUN = true
  99. }
  100. testCtx, stopTest := context.WithCancel(context.Background())
  101. defer stopTest()
  102. testGroup := new(errgroup.Group)
  103. // Enable test to run without requiring host firewall exceptions
  104. SetAllowBogonWebRTCConnections(true)
  105. defer SetAllowBogonWebRTCConnections(false)
  106. // Init logging and profiling
  107. logger := newTestLogger()
  108. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  109. go http.Serve(pprofListener, nil)
  110. defer pprofListener.Close()
  111. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  112. // Start echo servers
  113. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  114. if err != nil {
  115. return errors.Trace(err)
  116. }
  117. defer tcpEchoListener.Close()
  118. go runTCPEchoServer(tcpEchoListener)
  119. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  120. quicEchoServer, err := newQuicEchoServer()
  121. if err != nil {
  122. return errors.Trace(err)
  123. }
  124. defer quicEchoServer.Close()
  125. go quicEchoServer.Run()
  126. // Create signed server entry with capability
  127. serverPrivateKey, err := GenerateSessionPrivateKey()
  128. if err != nil {
  129. return errors.Trace(err)
  130. }
  131. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  132. if err != nil {
  133. return errors.Trace(err)
  134. }
  135. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  136. if err != nil {
  137. return errors.Trace(err)
  138. }
  139. serverEntry := make(protocol.ServerEntryFields)
  140. serverEntry["ipAddress"] = "127.0.0.1"
  141. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  142. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  143. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  144. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  145. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  146. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  147. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  148. testServerEntryTag := prng.HexString(16)
  149. serverEntry["tag"] = testServerEntryTag
  150. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  151. protocol.NewServerEntrySignatureKeyPair()
  152. if err != nil {
  153. return errors.Trace(err)
  154. }
  155. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  156. if err != nil {
  157. return errors.Trace(err)
  158. }
  159. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  160. if err != nil {
  161. return errors.Trace(err)
  162. }
  163. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  164. if err != nil {
  165. return errors.Trace(err)
  166. }
  167. // Start broker
  168. logger.WithTrace().Info("START BROKER")
  169. brokerPrivateKey, err := GenerateSessionPrivateKey()
  170. if err != nil {
  171. return errors.Trace(err)
  172. }
  173. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  174. if err != nil {
  175. return errors.Trace(err)
  176. }
  177. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  178. if err != nil {
  179. return errors.Trace(err)
  180. }
  181. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  182. if err != nil {
  183. return errors.Trace(err)
  184. }
  185. defer brokerListener.Close()
  186. brokerConfig := &BrokerConfig{
  187. Logger: logger,
  188. CommonCompartmentIDs: testCommonCompartmentIDs,
  189. APIParameterValidator: func(params common.APIParameters) error {
  190. if len(params) != len(baseAPIParameters) {
  191. return errors.TraceNew("unexpected base API parameter count")
  192. }
  193. for name, value := range params {
  194. if value.(string) != baseAPIParameters[name].(string) {
  195. return errors.Tracef(
  196. "unexpected base API parameter: %v: %v != %v",
  197. name,
  198. value.(string),
  199. baseAPIParameters[name].(string))
  200. }
  201. }
  202. return nil
  203. },
  204. APIParameterLogFieldFormatter: func(
  205. geoIPData common.GeoIPData, params common.APIParameters) common.LogFields {
  206. return common.LogFields(params)
  207. },
  208. GetTactics: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  209. // Exercise both new and unchanged tactics
  210. if prng.FlipCoin() {
  211. return testNewTacticsPayload, testNewTacticsTag, nil
  212. }
  213. return testUnchangedTacticsPayload, "", nil
  214. },
  215. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  216. PrivateKey: brokerPrivateKey,
  217. ObfuscationRootSecret: brokerRootObfuscationSecret,
  218. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  219. AllowProxy: func(common.GeoIPData) bool { return true },
  220. AllowClient: func(common.GeoIPData) bool { return true },
  221. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  222. }
  223. broker, err := NewBroker(brokerConfig)
  224. if err != nil {
  225. return errors.Trace(err)
  226. }
  227. err = broker.Start()
  228. if err != nil {
  229. return errors.Trace(err)
  230. }
  231. defer broker.Stop()
  232. testGroup.Go(func() error {
  233. err := runHTTPServer(brokerListener, broker)
  234. if testCtx.Err() != nil {
  235. return nil
  236. }
  237. return errors.Trace(err)
  238. })
  239. // Stub server broker request handler (in Psiphon, this will be the
  240. // destination Psiphon server; here, it's not necessary to build this
  241. // handler into the destination echo server)
  242. serverSessions, err := NewServerBrokerSessions(
  243. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey})
  244. if err != nil {
  245. return errors.Trace(err)
  246. }
  247. var pendingBrokerServerReportsMutex sync.Mutex
  248. pendingBrokerServerReports := make(map[ID]bool)
  249. addPendingBrokerServerReport := func(connectionID ID) {
  250. pendingBrokerServerReportsMutex.Lock()
  251. defer pendingBrokerServerReportsMutex.Unlock()
  252. pendingBrokerServerReports[connectionID] = true
  253. }
  254. hasPendingBrokerServerReports := func() bool {
  255. pendingBrokerServerReportsMutex.Lock()
  256. defer pendingBrokerServerReportsMutex.Unlock()
  257. return len(pendingBrokerServerReports) > 0
  258. }
  259. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  260. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  261. pendingBrokerServerReportsMutex.Lock()
  262. defer pendingBrokerServerReportsMutex.Unlock()
  263. // Mark the report as no longer outstanding
  264. delete(pendingBrokerServerReports, clientConnectionID)
  265. }
  266. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  267. return out, errors.Trace(err)
  268. }
  269. // Check that the tactics round trip succeeds
  270. var pendingProxyTacticsCallbacksMutex sync.Mutex
  271. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  272. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  273. pendingProxyTacticsCallbacksMutex.Lock()
  274. defer pendingProxyTacticsCallbacksMutex.Unlock()
  275. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  276. }
  277. hasPendingProxyTacticsCallbacks := func() bool {
  278. pendingProxyTacticsCallbacksMutex.Lock()
  279. defer pendingProxyTacticsCallbacksMutex.Unlock()
  280. return len(pendingProxyTacticsCallbacks) > 0
  281. }
  282. makeHandleTacticsPayload := func(
  283. proxyPrivateKey SessionPrivateKey,
  284. tacticsNetworkID string) func(_ string, _ []byte) bool {
  285. return func(networkID string, tacticsPayload []byte) bool {
  286. pendingProxyTacticsCallbacksMutex.Lock()
  287. defer pendingProxyTacticsCallbacksMutex.Unlock()
  288. // Check that the correct networkID is passed around; if not,
  289. // skip the delete, which will fail the test
  290. if networkID == tacticsNetworkID {
  291. // Certain state is reset when new tactics are applied -- the
  292. // return true case; exercise both cases
  293. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  294. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  295. return true
  296. }
  297. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  298. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  299. return false
  300. }
  301. }
  302. panic("unexpected tactics payload")
  303. }
  304. }
  305. // Start proxies
  306. logger.WithTrace().Info("START PROXIES")
  307. for i := 0; i < numProxies; i++ {
  308. proxyPrivateKey, err := GenerateSessionPrivateKey()
  309. if err != nil {
  310. return errors.Trace(err)
  311. }
  312. brokerCoordinator := &testBrokerDialCoordinator{
  313. networkID: testNetworkID,
  314. networkType: testNetworkType,
  315. brokerClientPrivateKey: proxyPrivateKey,
  316. brokerPublicKey: brokerPublicKey,
  317. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  318. brokerClientRoundTripper: newHTTPRoundTripper(
  319. brokerListener.Addr().String(), "proxy"),
  320. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  321. brokerClientRoundTripperFailed: roundTripperFailed,
  322. }
  323. webRTCCoordinator := &testWebRTCDialCoordinator{
  324. networkID: testNetworkID,
  325. networkType: testNetworkType,
  326. natType: testNATType,
  327. disableSTUN: testDisableSTUN,
  328. stunServerAddress: testSTUNServerAddress,
  329. stunServerAddressRFC5780: testSTUNServerAddress,
  330. stunServerAddressSucceeded: stunServerAddressSucceeded,
  331. stunServerAddressFailed: stunServerAddressFailed,
  332. setNATType: func(NATType) {},
  333. setPortMappingTypes: func(PortMappingTypes) {},
  334. bindToDevice: func(int) error { return nil },
  335. }
  336. // Each proxy has its own broker client
  337. brokerClient, err := NewBrokerClient(brokerCoordinator)
  338. if err != nil {
  339. return errors.Trace(err)
  340. }
  341. tacticsNetworkID := prng.HexString(32)
  342. runCtx, cancelRun := context.WithCancel(testCtx)
  343. // No deferred cancelRun due to testGroup.Go below
  344. proxy, err := NewProxy(&ProxyConfig{
  345. Logger: logger,
  346. WaitForNetworkConnectivity: func() bool {
  347. return true
  348. },
  349. GetBrokerClient: func() (*BrokerClient, error) {
  350. return brokerClient, nil
  351. },
  352. GetBaseAPIParameters: func() (common.APIParameters, string, error) {
  353. return baseAPIParameters, tacticsNetworkID, nil
  354. },
  355. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  356. return webRTCCoordinator, nil
  357. },
  358. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  359. MaxClients: proxyMaxClients,
  360. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  361. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  362. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  363. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  364. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  365. time.Now().UTC().Format(time.RFC3339),
  366. connectingClients, connectedClients, bytesUp, bytesDown)
  367. },
  368. MustUpgrade: func() {
  369. close(receivedProxyMustUpgrade)
  370. cancelRun()
  371. },
  372. })
  373. if err != nil {
  374. return errors.Trace(err)
  375. }
  376. addPendingProxyTacticsCallback(proxyPrivateKey)
  377. testGroup.Go(func() error {
  378. proxy.Run(runCtx)
  379. return nil
  380. })
  381. }
  382. // Await proxy announcements before starting clients
  383. //
  384. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  385. // plus NAT discovery
  386. //
  387. // - Don't wait for > numProxies announcements due to
  388. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  389. if !doMustUpgrade {
  390. for {
  391. time.Sleep(100 * time.Millisecond)
  392. broker.matcher.announcementQueueMutex.Lock()
  393. n := broker.matcher.announcementQueue.getLen()
  394. broker.matcher.announcementQueueMutex.Unlock()
  395. if n >= numProxies {
  396. break
  397. }
  398. }
  399. }
  400. // Start clients
  401. logger.WithTrace().Info("START CLIENTS")
  402. clientsGroup := new(errgroup.Group)
  403. makeClientFunc := func(
  404. isTCP bool,
  405. isMobile bool,
  406. brokerClient *BrokerClient,
  407. webRTCCoordinator WebRTCDialCoordinator) func() error {
  408. var networkProtocol NetworkProtocol
  409. var addr string
  410. var wrapWithQUIC bool
  411. if isTCP {
  412. networkProtocol = NetworkProtocolTCP
  413. addr = tcpEchoListener.Addr().String()
  414. } else {
  415. networkProtocol = NetworkProtocolUDP
  416. addr = quicEchoServer.Addr().String()
  417. wrapWithQUIC = true
  418. }
  419. return func() error {
  420. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  421. defer cancelDial()
  422. conn, err := DialClient(
  423. dialCtx,
  424. &ClientConfig{
  425. Logger: logger,
  426. BaseAPIParameters: baseAPIParameters,
  427. BrokerClient: brokerClient,
  428. WebRTCDialCoordinator: webRTCCoordinator,
  429. ReliableTransport: isTCP,
  430. DialNetworkProtocol: networkProtocol,
  431. DialAddress: addr,
  432. PackedDestinationServerEntry: packedDestinationServerEntry,
  433. MustUpgrade: func() {
  434. fmt.Printf("HI!\n")
  435. close(receivedClientMustUpgrade)
  436. cancelDial()
  437. },
  438. })
  439. if err != nil {
  440. return errors.Trace(err)
  441. }
  442. var relayConn net.Conn
  443. relayConn = conn
  444. if wrapWithQUIC {
  445. quicConn, err := quic.Dial(
  446. dialCtx,
  447. conn,
  448. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  449. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true)
  450. if err != nil {
  451. return errors.Trace(err)
  452. }
  453. relayConn = quicConn
  454. }
  455. addPendingBrokerServerReport(conn.GetConnectionID())
  456. signalRelayComplete := make(chan struct{})
  457. clientsGroup.Go(func() error {
  458. defer close(signalRelayComplete)
  459. in := conn.InitialRelayPacket()
  460. for in != nil {
  461. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  462. if err != nil {
  463. if out == nil {
  464. return errors.Trace(err)
  465. } else {
  466. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  467. // Proceed with reset session token packet
  468. }
  469. }
  470. if out == nil {
  471. // Relay is complete
  472. break
  473. }
  474. in, err = conn.RelayPacket(testCtx, out)
  475. if err != nil {
  476. return errors.Trace(err)
  477. }
  478. }
  479. return nil
  480. })
  481. sendBytes := prng.Bytes(bytesToSend)
  482. clientsGroup.Go(func() error {
  483. for n := 0; n < bytesToSend; {
  484. m := prng.Range(1024, 32768)
  485. if bytesToSend-n < m {
  486. m = bytesToSend - n
  487. }
  488. _, err := relayConn.Write(sendBytes[n : n+m])
  489. if err != nil {
  490. return errors.Trace(err)
  491. }
  492. n += m
  493. }
  494. fmt.Printf("%d bytes sent\n", bytesToSend)
  495. return nil
  496. })
  497. clientsGroup.Go(func() error {
  498. buf := make([]byte, 32768)
  499. n := 0
  500. for n < bytesToSend {
  501. m, err := relayConn.Read(buf)
  502. if err != nil {
  503. return errors.Trace(err)
  504. }
  505. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  506. return errors.Tracef(
  507. "unexpected bytes: expected at index %d, received at index %d",
  508. bytes.Index(sendBytes, buf[:m]), n)
  509. }
  510. n += m
  511. }
  512. fmt.Printf("%d bytes received\n", bytesToSend)
  513. select {
  514. case <-signalRelayComplete:
  515. case <-testCtx.Done():
  516. }
  517. relayConn.Close()
  518. conn.Close()
  519. return nil
  520. })
  521. return nil
  522. }
  523. }
  524. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  525. clientPrivateKey, err := GenerateSessionPrivateKey()
  526. if err != nil {
  527. return nil, nil, errors.Trace(err)
  528. }
  529. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  530. if err != nil {
  531. return nil, nil, errors.Trace(err)
  532. }
  533. brokerCoordinator := &testBrokerDialCoordinator{
  534. networkID: testNetworkID,
  535. networkType: testNetworkType,
  536. commonCompartmentIDs: testCommonCompartmentIDs,
  537. brokerClientPrivateKey: clientPrivateKey,
  538. brokerPublicKey: brokerPublicKey,
  539. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  540. brokerClientRoundTripper: newHTTPRoundTripper(
  541. brokerListener.Addr().String(), "client"),
  542. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  543. brokerClientRoundTripperFailed: roundTripperFailed,
  544. }
  545. webRTCCoordinator := &testWebRTCDialCoordinator{
  546. networkID: testNetworkID,
  547. networkType: testNetworkType,
  548. natType: testNATType,
  549. disableSTUN: testDisableSTUN,
  550. stunServerAddress: testSTUNServerAddress,
  551. stunServerAddressRFC5780: testSTUNServerAddress,
  552. stunServerAddressSucceeded: stunServerAddressSucceeded,
  553. stunServerAddressFailed: stunServerAddressFailed,
  554. clientRootObfuscationSecret: clientRootObfuscationSecret,
  555. doDTLSRandomization: prng.FlipCoin(),
  556. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  557. MinPaddedMessages: 0,
  558. MaxPaddedMessages: 10,
  559. MinPaddingSize: 0,
  560. MaxPaddingSize: 1500,
  561. MinDecoyMessages: 0,
  562. MaxDecoyMessages: 10,
  563. MinDecoySize: 1,
  564. MaxDecoySize: 1500,
  565. DecoyMessageProbability: 0.5,
  566. },
  567. setNATType: func(NATType) {},
  568. setPortMappingTypes: func(PortMappingTypes) {},
  569. bindToDevice: func(int) error { return nil },
  570. // With STUN enabled (testDisableSTUN = false), there are cases
  571. // where the WebRTC Data Channel is not successfully established.
  572. // With a short enough timeout here, clients will redial and
  573. // eventually succceed.
  574. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  575. }
  576. if isMobile {
  577. webRTCCoordinator.networkType = NetworkTypeMobile
  578. webRTCCoordinator.disableInboundForMobileNetworks = true
  579. }
  580. brokerClient, err := NewBrokerClient(brokerCoordinator)
  581. if err != nil {
  582. return nil, nil, errors.Trace(err)
  583. }
  584. return brokerClient, webRTCCoordinator, nil
  585. }
  586. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  587. if err != nil {
  588. return errors.Trace(err)
  589. }
  590. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  591. if err != nil {
  592. return errors.Trace(err)
  593. }
  594. for i := 0; i < numClients; i++ {
  595. // Test a mix of TCP and UDP proxying; also test the
  596. // DisableInboundForMobileNetworks code path.
  597. isTCP := i%2 == 0
  598. isMobile := i%4 == 0
  599. // Exercise BrokerClients shared by multiple clients, but also create
  600. // several broker clients.
  601. if i%8 == 0 {
  602. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  603. if err != nil {
  604. return errors.Trace(err)
  605. }
  606. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  607. if err != nil {
  608. return errors.Trace(err)
  609. }
  610. }
  611. brokerClient := clientBrokerClient
  612. webRTCCoordinator := clientWebRTCCoordinator
  613. if isMobile {
  614. brokerClient = clientMobileBrokerClient
  615. webRTCCoordinator = clientMobileWebRTCCoordinator
  616. }
  617. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  618. }
  619. if doMustUpgrade {
  620. // Await MustUpgrade callbacks
  621. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  622. <-receivedProxyMustUpgrade
  623. <-receivedClientMustUpgrade
  624. _ = clientsGroup.Wait()
  625. } else {
  626. // Await client transfers complete
  627. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  628. err = clientsGroup.Wait()
  629. if err != nil {
  630. return errors.Trace(err)
  631. }
  632. logger.WithTrace().Info("DONE DATA TRANSFER")
  633. if hasPendingBrokerServerReports() {
  634. return errors.TraceNew("unexpected pending broker server requests")
  635. }
  636. if hasPendingProxyTacticsCallbacks() {
  637. return errors.TraceNew("unexpected pending proxy tactics callback")
  638. }
  639. // TODO: check that elapsed time is consistent with rate limit (+/-)
  640. // Check if STUN server replay callbacks were triggered
  641. if !testDisableSTUN {
  642. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  643. return errors.TraceNew("unexpected STUN server succeeded count")
  644. }
  645. }
  646. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  647. return errors.TraceNew("unexpected STUN server failed count")
  648. }
  649. // Check if RoundTripper server replay callbacks were triggered
  650. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  651. return errors.TraceNew("unexpected round tripper succeeded count")
  652. }
  653. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  654. return errors.TraceNew("unexpected round tripper failed count")
  655. }
  656. }
  657. // Await shutdowns
  658. stopTest()
  659. brokerListener.Close()
  660. err = testGroup.Wait()
  661. if err != nil {
  662. return errors.Trace(err)
  663. }
  664. return nil
  665. }
  666. func runHTTPServer(listener net.Listener, broker *Broker) error {
  667. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  668. // For this test, clients set the path to "/client" and proxies
  669. // set the path to "/proxy" and we use that to create stub GeoIP
  670. // data to pass the not-same-ASN condition.
  671. var geoIPData common.GeoIPData
  672. geoIPData.ASN = r.URL.Path
  673. requestPayload, err := ioutil.ReadAll(
  674. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  675. if err != nil {
  676. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  677. http.Error(w, "", http.StatusNotFound)
  678. return
  679. }
  680. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  681. extendTimeout := func(timeout time.Duration) {
  682. // TODO: set insufficient initial timeout, so extension is
  683. // required for success
  684. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  685. }
  686. responsePayload, err := broker.HandleSessionPacket(
  687. r.Context(),
  688. extendTimeout,
  689. nil,
  690. clientIP,
  691. geoIPData,
  692. requestPayload)
  693. if err != nil {
  694. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  695. http.Error(w, "", http.StatusNotFound)
  696. return
  697. }
  698. w.WriteHeader(http.StatusOK)
  699. w.Write(responsePayload)
  700. })
  701. // WriteTimeout will be extended via extendTimeout.
  702. httpServer := &http.Server{
  703. ReadTimeout: 10 * time.Second,
  704. WriteTimeout: 10 * time.Second,
  705. IdleTimeout: 1 * time.Minute,
  706. Handler: handler,
  707. }
  708. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  709. if err != nil {
  710. return errors.Trace(err)
  711. }
  712. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  713. if err != nil {
  714. return errors.Trace(err)
  715. }
  716. tlsConfig := &tls.Config{
  717. Certificates: []tls.Certificate{tlsCert},
  718. }
  719. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  720. return errors.Trace(err)
  721. }
  722. type httpRoundTripper struct {
  723. httpClient *http.Client
  724. endpointAddr string
  725. path string
  726. }
  727. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  728. return &httpRoundTripper{
  729. httpClient: &http.Client{
  730. Transport: &http.Transport{
  731. ForceAttemptHTTP2: true,
  732. MaxIdleConns: 2,
  733. IdleConnTimeout: 1 * time.Minute,
  734. TLSHandshakeTimeout: 10 * time.Second,
  735. TLSClientConfig: &tls.Config{
  736. InsecureSkipVerify: true,
  737. },
  738. },
  739. },
  740. endpointAddr: endpointAddr,
  741. path: path,
  742. }
  743. }
  744. func (r *httpRoundTripper) RoundTrip(
  745. ctx context.Context,
  746. roundTripDelay time.Duration,
  747. roundTripTimeout time.Duration,
  748. requestPayload []byte) ([]byte, error) {
  749. if roundTripDelay > 0 {
  750. common.SleepWithContext(ctx, roundTripDelay)
  751. }
  752. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  753. defer requestCancelFunc()
  754. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  755. request, err := http.NewRequestWithContext(
  756. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  757. if err != nil {
  758. return nil, errors.Trace(err)
  759. }
  760. response, err := r.httpClient.Do(request)
  761. if err != nil {
  762. return nil, errors.Trace(err)
  763. }
  764. defer response.Body.Close()
  765. if response.StatusCode != http.StatusOK {
  766. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  767. }
  768. responsePayload, err := io.ReadAll(response.Body)
  769. if err != nil {
  770. return nil, errors.Trace(err)
  771. }
  772. return responsePayload, nil
  773. }
  774. func (r *httpRoundTripper) Close() error {
  775. r.httpClient.CloseIdleConnections()
  776. return nil
  777. }
  778. func runTCPEchoServer(listener net.Listener) {
  779. for {
  780. conn, err := listener.Accept()
  781. if err != nil {
  782. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  783. return
  784. }
  785. go func(conn net.Conn) {
  786. buf := make([]byte, 32768)
  787. for {
  788. n, err := conn.Read(buf)
  789. if n > 0 {
  790. _, err = conn.Write(buf[:n])
  791. }
  792. if err != nil {
  793. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  794. return
  795. }
  796. }
  797. }(conn)
  798. }
  799. }
  800. type quicEchoServer struct {
  801. listener net.Listener
  802. obfuscationKey string
  803. }
  804. func newQuicEchoServer() (*quicEchoServer, error) {
  805. obfuscationKey := prng.HexString(32)
  806. listener, err := quic.Listen(
  807. nil,
  808. nil,
  809. "127.0.0.1:0",
  810. obfuscationKey,
  811. false)
  812. if err != nil {
  813. return nil, errors.Trace(err)
  814. }
  815. return &quicEchoServer{
  816. listener: listener,
  817. obfuscationKey: obfuscationKey,
  818. }, nil
  819. }
  820. func (q *quicEchoServer) ObfuscationKey() string {
  821. return q.obfuscationKey
  822. }
  823. func (q *quicEchoServer) Close() error {
  824. return q.listener.Close()
  825. }
  826. func (q *quicEchoServer) Addr() net.Addr {
  827. return q.listener.Addr()
  828. }
  829. func (q *quicEchoServer) Run() {
  830. for {
  831. conn, err := q.listener.Accept()
  832. if err != nil {
  833. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  834. return
  835. }
  836. go func(conn net.Conn) {
  837. buf := make([]byte, 32768)
  838. for {
  839. n, err := conn.Read(buf)
  840. if n > 0 {
  841. _, err = conn.Write(buf[:n])
  842. }
  843. if err != nil {
  844. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  845. return
  846. }
  847. }
  848. }(conn)
  849. }
  850. }