inproxy_test.go 28 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testNewTacticsPayload := []byte(prng.HexString(100))
  77. testNewTacticsTag := "new-tactics-tag"
  78. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  79. // TODO: test port mapping
  80. stunServerAddressSucceededCount := int32(0)
  81. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  82. stunServerAddressFailedCount := int32(0)
  83. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  84. roundTripperSucceededCount := int32(0)
  85. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  86. roundTripperFailedCount := int32(0)
  87. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  88. var receivedProxyMustUpgrade chan struct{}
  89. var receivedClientMustUpgrade chan struct{}
  90. if doMustUpgrade {
  91. receivedProxyMustUpgrade = make(chan struct{})
  92. receivedClientMustUpgrade = make(chan struct{})
  93. // trigger MustUpgrade
  94. proxyProtocolVersion = 0
  95. // Minimize test parameters for MustUpgrade case
  96. numProxies = 1
  97. proxyMaxClients = 1
  98. numClients = 1
  99. testDisableSTUN = true
  100. }
  101. testCtx, stopTest := context.WithCancel(context.Background())
  102. defer stopTest()
  103. testGroup := new(errgroup.Group)
  104. // Enable test to run without requiring host firewall exceptions
  105. SetAllowBogonWebRTCConnections(true)
  106. defer SetAllowBogonWebRTCConnections(false)
  107. // Init logging and profiling
  108. logger := newTestLogger()
  109. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  110. go http.Serve(pprofListener, nil)
  111. defer pprofListener.Close()
  112. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  113. // Start echo servers
  114. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  115. if err != nil {
  116. return errors.Trace(err)
  117. }
  118. defer tcpEchoListener.Close()
  119. go runTCPEchoServer(tcpEchoListener)
  120. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  121. quicEchoServer, err := newQuicEchoServer()
  122. if err != nil {
  123. return errors.Trace(err)
  124. }
  125. defer quicEchoServer.Close()
  126. go quicEchoServer.Run()
  127. // Create signed server entry with capability
  128. serverPrivateKey, err := GenerateSessionPrivateKey()
  129. if err != nil {
  130. return errors.Trace(err)
  131. }
  132. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  133. if err != nil {
  134. return errors.Trace(err)
  135. }
  136. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  137. if err != nil {
  138. return errors.Trace(err)
  139. }
  140. serverEntry := make(protocol.ServerEntryFields)
  141. serverEntry["ipAddress"] = "127.0.0.1"
  142. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  143. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  144. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  145. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  146. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  147. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  148. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  149. testServerEntryTag := prng.HexString(16)
  150. serverEntry["tag"] = testServerEntryTag
  151. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  152. protocol.NewServerEntrySignatureKeyPair()
  153. if err != nil {
  154. return errors.Trace(err)
  155. }
  156. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  157. if err != nil {
  158. return errors.Trace(err)
  159. }
  160. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  161. if err != nil {
  162. return errors.Trace(err)
  163. }
  164. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  165. if err != nil {
  166. return errors.Trace(err)
  167. }
  168. // Start broker
  169. logger.WithTrace().Info("START BROKER")
  170. brokerPrivateKey, err := GenerateSessionPrivateKey()
  171. if err != nil {
  172. return errors.Trace(err)
  173. }
  174. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  175. if err != nil {
  176. return errors.Trace(err)
  177. }
  178. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  179. if err != nil {
  180. return errors.Trace(err)
  181. }
  182. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  183. if err != nil {
  184. return errors.Trace(err)
  185. }
  186. defer brokerListener.Close()
  187. brokerConfig := &BrokerConfig{
  188. Logger: logger,
  189. CommonCompartmentIDs: testCommonCompartmentIDs,
  190. APIParameterValidator: func(params common.APIParameters) error {
  191. if len(params) != len(baseAPIParameters) {
  192. return errors.TraceNew("unexpected base API parameter count")
  193. }
  194. for name, value := range params {
  195. if value.(string) != baseAPIParameters[name].(string) {
  196. return errors.Tracef(
  197. "unexpected base API parameter: %v: %v != %v",
  198. name,
  199. value.(string),
  200. baseAPIParameters[name].(string))
  201. }
  202. }
  203. return nil
  204. },
  205. APIParameterLogFieldFormatter: func(
  206. geoIPData common.GeoIPData, params common.APIParameters) common.LogFields {
  207. return common.LogFields(params)
  208. },
  209. GetTactics: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  210. // Exercise both new and unchanged tactics
  211. if prng.FlipCoin() {
  212. return testNewTacticsPayload, testNewTacticsTag, nil
  213. }
  214. return testUnchangedTacticsPayload, "", nil
  215. },
  216. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  217. PrivateKey: brokerPrivateKey,
  218. ObfuscationRootSecret: brokerRootObfuscationSecret,
  219. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  220. AllowProxy: func(common.GeoIPData) bool { return true },
  221. AllowClient: func(common.GeoIPData) bool { return true },
  222. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  223. }
  224. broker, err := NewBroker(brokerConfig)
  225. if err != nil {
  226. return errors.Trace(err)
  227. }
  228. err = broker.Start()
  229. if err != nil {
  230. return errors.Trace(err)
  231. }
  232. defer broker.Stop()
  233. testGroup.Go(func() error {
  234. err := runHTTPServer(brokerListener, broker)
  235. if testCtx.Err() != nil {
  236. return nil
  237. }
  238. return errors.Trace(err)
  239. })
  240. // Stub server broker request handler (in Psiphon, this will be the
  241. // destination Psiphon server; here, it's not necessary to build this
  242. // handler into the destination echo server)
  243. serverSessions, err := NewServerBrokerSessions(
  244. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey})
  245. if err != nil {
  246. return errors.Trace(err)
  247. }
  248. var pendingBrokerServerReportsMutex sync.Mutex
  249. pendingBrokerServerReports := make(map[ID]bool)
  250. addPendingBrokerServerReport := func(connectionID ID) {
  251. pendingBrokerServerReportsMutex.Lock()
  252. defer pendingBrokerServerReportsMutex.Unlock()
  253. pendingBrokerServerReports[connectionID] = true
  254. }
  255. hasPendingBrokerServerReports := func() bool {
  256. pendingBrokerServerReportsMutex.Lock()
  257. defer pendingBrokerServerReportsMutex.Unlock()
  258. return len(pendingBrokerServerReports) > 0
  259. }
  260. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  261. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  262. pendingBrokerServerReportsMutex.Lock()
  263. defer pendingBrokerServerReportsMutex.Unlock()
  264. // Mark the report as no longer outstanding
  265. delete(pendingBrokerServerReports, clientConnectionID)
  266. }
  267. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  268. return out, errors.Trace(err)
  269. }
  270. // Check that the tactics round trip succeeds
  271. var pendingProxyTacticsCallbacksMutex sync.Mutex
  272. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  273. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  274. pendingProxyTacticsCallbacksMutex.Lock()
  275. defer pendingProxyTacticsCallbacksMutex.Unlock()
  276. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  277. }
  278. hasPendingProxyTacticsCallbacks := func() bool {
  279. pendingProxyTacticsCallbacksMutex.Lock()
  280. defer pendingProxyTacticsCallbacksMutex.Unlock()
  281. return len(pendingProxyTacticsCallbacks) > 0
  282. }
  283. makeHandleTacticsPayload := func(
  284. proxyPrivateKey SessionPrivateKey,
  285. tacticsNetworkID string) func(_ string, _ []byte) bool {
  286. return func(networkID string, tacticsPayload []byte) bool {
  287. pendingProxyTacticsCallbacksMutex.Lock()
  288. defer pendingProxyTacticsCallbacksMutex.Unlock()
  289. // Check that the correct networkID is passed around; if not,
  290. // skip the delete, which will fail the test
  291. if networkID == tacticsNetworkID {
  292. // Certain state is reset when new tactics are applied -- the
  293. // return true case; exercise both cases
  294. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  295. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  296. return true
  297. }
  298. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  299. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  300. return false
  301. }
  302. }
  303. panic("unexpected tactics payload")
  304. }
  305. }
  306. // Start proxies
  307. logger.WithTrace().Info("START PROXIES")
  308. for i := 0; i < numProxies; i++ {
  309. proxyPrivateKey, err := GenerateSessionPrivateKey()
  310. if err != nil {
  311. return errors.Trace(err)
  312. }
  313. brokerCoordinator := &testBrokerDialCoordinator{
  314. networkID: testNetworkID,
  315. networkType: testNetworkType,
  316. brokerClientPrivateKey: proxyPrivateKey,
  317. brokerPublicKey: brokerPublicKey,
  318. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  319. brokerClientRoundTripper: newHTTPRoundTripper(
  320. brokerListener.Addr().String(), "proxy"),
  321. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  322. brokerClientRoundTripperFailed: roundTripperFailed,
  323. }
  324. webRTCCoordinator := &testWebRTCDialCoordinator{
  325. networkID: testNetworkID,
  326. networkType: testNetworkType,
  327. natType: testNATType,
  328. disableSTUN: testDisableSTUN,
  329. stunServerAddress: testSTUNServerAddress,
  330. stunServerAddressRFC5780: testSTUNServerAddress,
  331. stunServerAddressSucceeded: stunServerAddressSucceeded,
  332. stunServerAddressFailed: stunServerAddressFailed,
  333. setNATType: func(NATType) {},
  334. setPortMappingTypes: func(PortMappingTypes) {},
  335. bindToDevice: func(int) error { return nil },
  336. }
  337. // Each proxy has its own broker client
  338. brokerClient, err := NewBrokerClient(brokerCoordinator)
  339. if err != nil {
  340. return errors.Trace(err)
  341. }
  342. tacticsNetworkID := prng.HexString(32)
  343. runCtx, cancelRun := context.WithCancel(testCtx)
  344. // No deferred cancelRun due to testGroup.Go below
  345. proxy, err := NewProxy(&ProxyConfig{
  346. Logger: logger,
  347. WaitForNetworkConnectivity: func() bool {
  348. return true
  349. },
  350. GetBrokerClient: func() (*BrokerClient, error) {
  351. return brokerClient, nil
  352. },
  353. GetBaseAPIParameters: func() (common.APIParameters, string, error) {
  354. return baseAPIParameters, tacticsNetworkID, nil
  355. },
  356. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  357. return webRTCCoordinator, nil
  358. },
  359. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  360. MaxClients: proxyMaxClients,
  361. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  362. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  363. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  364. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  365. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  366. time.Now().UTC().Format(time.RFC3339),
  367. connectingClients, connectedClients, bytesUp, bytesDown)
  368. },
  369. MustUpgrade: func() {
  370. close(receivedProxyMustUpgrade)
  371. cancelRun()
  372. },
  373. })
  374. if err != nil {
  375. return errors.Trace(err)
  376. }
  377. addPendingProxyTacticsCallback(proxyPrivateKey)
  378. testGroup.Go(func() error {
  379. proxy.Run(runCtx)
  380. return nil
  381. })
  382. }
  383. // Await proxy announcements before starting clients
  384. //
  385. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  386. // plus NAT discovery
  387. //
  388. // - Don't wait for > numProxies announcements due to
  389. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  390. if !doMustUpgrade {
  391. for {
  392. time.Sleep(100 * time.Millisecond)
  393. broker.matcher.announcementQueueMutex.Lock()
  394. n := broker.matcher.announcementQueue.getLen()
  395. broker.matcher.announcementQueueMutex.Unlock()
  396. if n >= numProxies {
  397. break
  398. }
  399. }
  400. }
  401. // Start clients
  402. logger.WithTrace().Info("START CLIENTS")
  403. clientsGroup := new(errgroup.Group)
  404. makeClientFunc := func(
  405. isTCP bool,
  406. isMobile bool,
  407. brokerClient *BrokerClient,
  408. webRTCCoordinator WebRTCDialCoordinator) func() error {
  409. var networkProtocol NetworkProtocol
  410. var addr string
  411. var wrapWithQUIC bool
  412. if isTCP {
  413. networkProtocol = NetworkProtocolTCP
  414. addr = tcpEchoListener.Addr().String()
  415. } else {
  416. networkProtocol = NetworkProtocolUDP
  417. addr = quicEchoServer.Addr().String()
  418. wrapWithQUIC = true
  419. }
  420. return func() error {
  421. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  422. defer cancelDial()
  423. conn, err := DialClient(
  424. dialCtx,
  425. &ClientConfig{
  426. Logger: logger,
  427. BaseAPIParameters: baseAPIParameters,
  428. BrokerClient: brokerClient,
  429. WebRTCDialCoordinator: webRTCCoordinator,
  430. ReliableTransport: isTCP,
  431. DialNetworkProtocol: networkProtocol,
  432. DialAddress: addr,
  433. PackedDestinationServerEntry: packedDestinationServerEntry,
  434. MustUpgrade: func() {
  435. fmt.Printf("HI!\n")
  436. close(receivedClientMustUpgrade)
  437. cancelDial()
  438. },
  439. })
  440. if err != nil {
  441. return errors.Trace(err)
  442. }
  443. var relayConn net.Conn
  444. relayConn = conn
  445. if wrapWithQUIC {
  446. quicConn, err := quic.Dial(
  447. dialCtx,
  448. conn,
  449. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  450. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true,
  451. false, false, common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  452. )
  453. if err != nil {
  454. return errors.Trace(err)
  455. }
  456. relayConn = quicConn
  457. }
  458. addPendingBrokerServerReport(conn.GetConnectionID())
  459. signalRelayComplete := make(chan struct{})
  460. clientsGroup.Go(func() error {
  461. defer close(signalRelayComplete)
  462. in := conn.InitialRelayPacket()
  463. for in != nil {
  464. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  465. if err != nil {
  466. if out == nil {
  467. return errors.Trace(err)
  468. } else {
  469. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  470. // Proceed with reset session token packet
  471. }
  472. }
  473. if out == nil {
  474. // Relay is complete
  475. break
  476. }
  477. in, err = conn.RelayPacket(testCtx, out)
  478. if err != nil {
  479. return errors.Trace(err)
  480. }
  481. }
  482. return nil
  483. })
  484. sendBytes := prng.Bytes(bytesToSend)
  485. clientsGroup.Go(func() error {
  486. for n := 0; n < bytesToSend; {
  487. m := prng.Range(1024, 32768)
  488. if bytesToSend-n < m {
  489. m = bytesToSend - n
  490. }
  491. _, err := relayConn.Write(sendBytes[n : n+m])
  492. if err != nil {
  493. return errors.Trace(err)
  494. }
  495. n += m
  496. }
  497. fmt.Printf("%d bytes sent\n", bytesToSend)
  498. return nil
  499. })
  500. clientsGroup.Go(func() error {
  501. buf := make([]byte, 32768)
  502. n := 0
  503. for n < bytesToSend {
  504. m, err := relayConn.Read(buf)
  505. if err != nil {
  506. return errors.Trace(err)
  507. }
  508. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  509. return errors.Tracef(
  510. "unexpected bytes: expected at index %d, received at index %d",
  511. bytes.Index(sendBytes, buf[:m]), n)
  512. }
  513. n += m
  514. }
  515. fmt.Printf("%d bytes received\n", bytesToSend)
  516. select {
  517. case <-signalRelayComplete:
  518. case <-testCtx.Done():
  519. }
  520. relayConn.Close()
  521. conn.Close()
  522. return nil
  523. })
  524. return nil
  525. }
  526. }
  527. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  528. clientPrivateKey, err := GenerateSessionPrivateKey()
  529. if err != nil {
  530. return nil, nil, errors.Trace(err)
  531. }
  532. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  533. if err != nil {
  534. return nil, nil, errors.Trace(err)
  535. }
  536. brokerCoordinator := &testBrokerDialCoordinator{
  537. networkID: testNetworkID,
  538. networkType: testNetworkType,
  539. commonCompartmentIDs: testCommonCompartmentIDs,
  540. brokerClientPrivateKey: clientPrivateKey,
  541. brokerPublicKey: brokerPublicKey,
  542. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  543. brokerClientRoundTripper: newHTTPRoundTripper(
  544. brokerListener.Addr().String(), "client"),
  545. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  546. brokerClientRoundTripperFailed: roundTripperFailed,
  547. }
  548. webRTCCoordinator := &testWebRTCDialCoordinator{
  549. networkID: testNetworkID,
  550. networkType: testNetworkType,
  551. natType: testNATType,
  552. disableSTUN: testDisableSTUN,
  553. stunServerAddress: testSTUNServerAddress,
  554. stunServerAddressRFC5780: testSTUNServerAddress,
  555. stunServerAddressSucceeded: stunServerAddressSucceeded,
  556. stunServerAddressFailed: stunServerAddressFailed,
  557. clientRootObfuscationSecret: clientRootObfuscationSecret,
  558. doDTLSRandomization: prng.FlipCoin(),
  559. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  560. MinPaddedMessages: 0,
  561. MaxPaddedMessages: 10,
  562. MinPaddingSize: 0,
  563. MaxPaddingSize: 1500,
  564. MinDecoyMessages: 0,
  565. MaxDecoyMessages: 10,
  566. MinDecoySize: 1,
  567. MaxDecoySize: 1500,
  568. DecoyMessageProbability: 0.5,
  569. },
  570. setNATType: func(NATType) {},
  571. setPortMappingTypes: func(PortMappingTypes) {},
  572. bindToDevice: func(int) error { return nil },
  573. // With STUN enabled (testDisableSTUN = false), there are cases
  574. // where the WebRTC Data Channel is not successfully established.
  575. // With a short enough timeout here, clients will redial and
  576. // eventually succceed.
  577. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  578. }
  579. if isMobile {
  580. webRTCCoordinator.networkType = NetworkTypeMobile
  581. webRTCCoordinator.disableInboundForMobileNetworks = true
  582. }
  583. brokerClient, err := NewBrokerClient(brokerCoordinator)
  584. if err != nil {
  585. return nil, nil, errors.Trace(err)
  586. }
  587. return brokerClient, webRTCCoordinator, nil
  588. }
  589. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  590. if err != nil {
  591. return errors.Trace(err)
  592. }
  593. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  594. if err != nil {
  595. return errors.Trace(err)
  596. }
  597. for i := 0; i < numClients; i++ {
  598. // Test a mix of TCP and UDP proxying; also test the
  599. // DisableInboundForMobileNetworks code path.
  600. isTCP := i%2 == 0
  601. isMobile := i%4 == 0
  602. // Exercise BrokerClients shared by multiple clients, but also create
  603. // several broker clients.
  604. if i%8 == 0 {
  605. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  606. if err != nil {
  607. return errors.Trace(err)
  608. }
  609. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  610. if err != nil {
  611. return errors.Trace(err)
  612. }
  613. }
  614. brokerClient := clientBrokerClient
  615. webRTCCoordinator := clientWebRTCCoordinator
  616. if isMobile {
  617. brokerClient = clientMobileBrokerClient
  618. webRTCCoordinator = clientMobileWebRTCCoordinator
  619. }
  620. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  621. }
  622. if doMustUpgrade {
  623. // Await MustUpgrade callbacks
  624. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  625. <-receivedProxyMustUpgrade
  626. <-receivedClientMustUpgrade
  627. _ = clientsGroup.Wait()
  628. } else {
  629. // Await client transfers complete
  630. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  631. err = clientsGroup.Wait()
  632. if err != nil {
  633. return errors.Trace(err)
  634. }
  635. logger.WithTrace().Info("DONE DATA TRANSFER")
  636. if hasPendingBrokerServerReports() {
  637. return errors.TraceNew("unexpected pending broker server requests")
  638. }
  639. if hasPendingProxyTacticsCallbacks() {
  640. return errors.TraceNew("unexpected pending proxy tactics callback")
  641. }
  642. // TODO: check that elapsed time is consistent with rate limit (+/-)
  643. // Check if STUN server replay callbacks were triggered
  644. if !testDisableSTUN {
  645. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  646. return errors.TraceNew("unexpected STUN server succeeded count")
  647. }
  648. }
  649. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  650. return errors.TraceNew("unexpected STUN server failed count")
  651. }
  652. // Check if RoundTripper server replay callbacks were triggered
  653. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  654. return errors.TraceNew("unexpected round tripper succeeded count")
  655. }
  656. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  657. return errors.TraceNew("unexpected round tripper failed count")
  658. }
  659. }
  660. // Await shutdowns
  661. stopTest()
  662. brokerListener.Close()
  663. err = testGroup.Wait()
  664. if err != nil {
  665. return errors.Trace(err)
  666. }
  667. return nil
  668. }
  669. func runHTTPServer(listener net.Listener, broker *Broker) error {
  670. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  671. // For this test, clients set the path to "/client" and proxies
  672. // set the path to "/proxy" and we use that to create stub GeoIP
  673. // data to pass the not-same-ASN condition.
  674. var geoIPData common.GeoIPData
  675. geoIPData.ASN = r.URL.Path
  676. requestPayload, err := ioutil.ReadAll(
  677. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  678. if err != nil {
  679. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  680. http.Error(w, "", http.StatusNotFound)
  681. return
  682. }
  683. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  684. extendTimeout := func(timeout time.Duration) {
  685. // TODO: set insufficient initial timeout, so extension is
  686. // required for success
  687. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  688. }
  689. responsePayload, err := broker.HandleSessionPacket(
  690. r.Context(),
  691. extendTimeout,
  692. nil,
  693. clientIP,
  694. geoIPData,
  695. requestPayload)
  696. if err != nil {
  697. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  698. http.Error(w, "", http.StatusNotFound)
  699. return
  700. }
  701. w.WriteHeader(http.StatusOK)
  702. w.Write(responsePayload)
  703. })
  704. // WriteTimeout will be extended via extendTimeout.
  705. httpServer := &http.Server{
  706. ReadTimeout: 10 * time.Second,
  707. WriteTimeout: 10 * time.Second,
  708. IdleTimeout: 1 * time.Minute,
  709. Handler: handler,
  710. }
  711. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  712. if err != nil {
  713. return errors.Trace(err)
  714. }
  715. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  716. if err != nil {
  717. return errors.Trace(err)
  718. }
  719. tlsConfig := &tls.Config{
  720. Certificates: []tls.Certificate{tlsCert},
  721. }
  722. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  723. return errors.Trace(err)
  724. }
  725. type httpRoundTripper struct {
  726. httpClient *http.Client
  727. endpointAddr string
  728. path string
  729. }
  730. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  731. return &httpRoundTripper{
  732. httpClient: &http.Client{
  733. Transport: &http.Transport{
  734. ForceAttemptHTTP2: true,
  735. MaxIdleConns: 2,
  736. IdleConnTimeout: 1 * time.Minute,
  737. TLSHandshakeTimeout: 10 * time.Second,
  738. TLSClientConfig: &std_tls.Config{
  739. InsecureSkipVerify: true,
  740. },
  741. },
  742. },
  743. endpointAddr: endpointAddr,
  744. path: path,
  745. }
  746. }
  747. func (r *httpRoundTripper) RoundTrip(
  748. ctx context.Context,
  749. roundTripDelay time.Duration,
  750. roundTripTimeout time.Duration,
  751. requestPayload []byte) ([]byte, error) {
  752. if roundTripDelay > 0 {
  753. common.SleepWithContext(ctx, roundTripDelay)
  754. }
  755. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  756. defer requestCancelFunc()
  757. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  758. request, err := http.NewRequestWithContext(
  759. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  760. if err != nil {
  761. return nil, errors.Trace(err)
  762. }
  763. response, err := r.httpClient.Do(request)
  764. if err != nil {
  765. return nil, errors.Trace(err)
  766. }
  767. defer response.Body.Close()
  768. if response.StatusCode != http.StatusOK {
  769. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  770. }
  771. responsePayload, err := io.ReadAll(response.Body)
  772. if err != nil {
  773. return nil, errors.Trace(err)
  774. }
  775. return responsePayload, nil
  776. }
  777. func (r *httpRoundTripper) Close() error {
  778. r.httpClient.CloseIdleConnections()
  779. return nil
  780. }
  781. func runTCPEchoServer(listener net.Listener) {
  782. for {
  783. conn, err := listener.Accept()
  784. if err != nil {
  785. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  786. return
  787. }
  788. go func(conn net.Conn) {
  789. buf := make([]byte, 32768)
  790. for {
  791. n, err := conn.Read(buf)
  792. if n > 0 {
  793. _, err = conn.Write(buf[:n])
  794. }
  795. if err != nil {
  796. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  797. return
  798. }
  799. }
  800. }(conn)
  801. }
  802. }
  803. type quicEchoServer struct {
  804. listener net.Listener
  805. obfuscationKey string
  806. }
  807. func newQuicEchoServer() (*quicEchoServer, error) {
  808. obfuscationKey := prng.HexString(32)
  809. listener, err := quic.Listen(
  810. nil,
  811. nil,
  812. "127.0.0.1:0",
  813. obfuscationKey,
  814. false)
  815. if err != nil {
  816. return nil, errors.Trace(err)
  817. }
  818. return &quicEchoServer{
  819. listener: listener,
  820. obfuscationKey: obfuscationKey,
  821. }, nil
  822. }
  823. func (q *quicEchoServer) ObfuscationKey() string {
  824. return q.obfuscationKey
  825. }
  826. func (q *quicEchoServer) Close() error {
  827. return q.listener.Close()
  828. }
  829. func (q *quicEchoServer) Addr() net.Addr {
  830. return q.listener.Addr()
  831. }
  832. func (q *quicEchoServer) Run() {
  833. for {
  834. conn, err := q.listener.Accept()
  835. if err != nil {
  836. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  837. return
  838. }
  839. go func(conn net.Conn) {
  840. buf := make([]byte, 32768)
  841. for {
  842. n, err := conn.Read(buf)
  843. if n > 0 {
  844. _, err = conn.Write(buf[:n])
  845. }
  846. if err != nil {
  847. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  848. return
  849. }
  850. }
  851. }(conn)
  852. }
  853. }