inproxy_test.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testNewTacticsPayload := []byte(prng.HexString(100))
  77. testNewTacticsTag := "new-tactics-tag"
  78. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  79. // TODO: test port mapping
  80. stunServerAddressSucceededCount := int32(0)
  81. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  82. stunServerAddressFailedCount := int32(0)
  83. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  84. roundTripperSucceededCount := int32(0)
  85. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  86. roundTripperFailedCount := int32(0)
  87. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  88. noMatch := func(RoundTripper) {}
  89. var receivedProxyMustUpgrade chan struct{}
  90. var receivedClientMustUpgrade chan struct{}
  91. if doMustUpgrade {
  92. receivedProxyMustUpgrade = make(chan struct{})
  93. receivedClientMustUpgrade = make(chan struct{})
  94. // trigger MustUpgrade
  95. proxyProtocolVersion = 0
  96. // Minimize test parameters for MustUpgrade case
  97. numProxies = 1
  98. proxyMaxClients = 1
  99. numClients = 1
  100. testDisableSTUN = true
  101. }
  102. testCtx, stopTest := context.WithCancel(context.Background())
  103. defer stopTest()
  104. testGroup := new(errgroup.Group)
  105. // Enable test to run without requiring host firewall exceptions
  106. SetAllowBogonWebRTCConnections(true)
  107. defer SetAllowBogonWebRTCConnections(false)
  108. // Init logging and profiling
  109. logger := newTestLogger()
  110. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  111. go http.Serve(pprofListener, nil)
  112. defer pprofListener.Close()
  113. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  114. // Start echo servers
  115. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  116. if err != nil {
  117. return errors.Trace(err)
  118. }
  119. defer tcpEchoListener.Close()
  120. go runTCPEchoServer(tcpEchoListener)
  121. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  122. quicEchoServer, err := newQuicEchoServer()
  123. if err != nil {
  124. return errors.Trace(err)
  125. }
  126. defer quicEchoServer.Close()
  127. go quicEchoServer.Run()
  128. // Create signed server entry with capability
  129. serverPrivateKey, err := GenerateSessionPrivateKey()
  130. if err != nil {
  131. return errors.Trace(err)
  132. }
  133. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  134. if err != nil {
  135. return errors.Trace(err)
  136. }
  137. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  138. if err != nil {
  139. return errors.Trace(err)
  140. }
  141. serverEntry := make(protocol.ServerEntryFields)
  142. serverEntry["ipAddress"] = "127.0.0.1"
  143. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  144. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  145. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  146. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  147. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  148. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  149. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  150. testServerEntryTag := prng.HexString(16)
  151. serverEntry["tag"] = testServerEntryTag
  152. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  153. protocol.NewServerEntrySignatureKeyPair()
  154. if err != nil {
  155. return errors.Trace(err)
  156. }
  157. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  158. if err != nil {
  159. return errors.Trace(err)
  160. }
  161. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  162. if err != nil {
  163. return errors.Trace(err)
  164. }
  165. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  166. if err != nil {
  167. return errors.Trace(err)
  168. }
  169. // Start broker
  170. logger.WithTrace().Info("START BROKER")
  171. brokerPrivateKey, err := GenerateSessionPrivateKey()
  172. if err != nil {
  173. return errors.Trace(err)
  174. }
  175. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  176. if err != nil {
  177. return errors.Trace(err)
  178. }
  179. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  180. if err != nil {
  181. return errors.Trace(err)
  182. }
  183. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  184. if err != nil {
  185. return errors.Trace(err)
  186. }
  187. defer brokerListener.Close()
  188. brokerConfig := &BrokerConfig{
  189. Logger: logger,
  190. CommonCompartmentIDs: testCommonCompartmentIDs,
  191. APIParameterValidator: func(params common.APIParameters) error {
  192. if len(params) != len(baseAPIParameters) {
  193. return errors.TraceNew("unexpected base API parameter count")
  194. }
  195. for name, value := range params {
  196. if value.(string) != baseAPIParameters[name].(string) {
  197. return errors.Tracef(
  198. "unexpected base API parameter: %v: %v != %v",
  199. name,
  200. value.(string),
  201. baseAPIParameters[name].(string))
  202. }
  203. }
  204. return nil
  205. },
  206. APIParameterLogFieldFormatter: func(
  207. geoIPData common.GeoIPData, params common.APIParameters) common.LogFields {
  208. return common.LogFields(params)
  209. },
  210. GetTacticsPayload: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  211. // Exercise both new and unchanged tactics
  212. if prng.FlipCoin() {
  213. return testNewTacticsPayload, testNewTacticsTag, nil
  214. }
  215. return testUnchangedTacticsPayload, "", nil
  216. },
  217. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  218. PrivateKey: brokerPrivateKey,
  219. ObfuscationRootSecret: brokerRootObfuscationSecret,
  220. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  221. AllowProxy: func(common.GeoIPData) bool { return true },
  222. AllowClient: func(common.GeoIPData) bool { return true },
  223. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  224. }
  225. broker, err := NewBroker(brokerConfig)
  226. if err != nil {
  227. return errors.Trace(err)
  228. }
  229. err = broker.Start()
  230. if err != nil {
  231. return errors.Trace(err)
  232. }
  233. defer broker.Stop()
  234. testGroup.Go(func() error {
  235. err := runHTTPServer(brokerListener, broker)
  236. if testCtx.Err() != nil {
  237. return nil
  238. }
  239. return errors.Trace(err)
  240. })
  241. // Stub server broker request handler (in Psiphon, this will be the
  242. // destination Psiphon server; here, it's not necessary to build this
  243. // handler into the destination echo server)
  244. serverSessions, err := NewServerBrokerSessions(
  245. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey})
  246. if err != nil {
  247. return errors.Trace(err)
  248. }
  249. var pendingBrokerServerReportsMutex sync.Mutex
  250. pendingBrokerServerReports := make(map[ID]bool)
  251. addPendingBrokerServerReport := func(connectionID ID) {
  252. pendingBrokerServerReportsMutex.Lock()
  253. defer pendingBrokerServerReportsMutex.Unlock()
  254. pendingBrokerServerReports[connectionID] = true
  255. }
  256. hasPendingBrokerServerReports := func() bool {
  257. pendingBrokerServerReportsMutex.Lock()
  258. defer pendingBrokerServerReportsMutex.Unlock()
  259. return len(pendingBrokerServerReports) > 0
  260. }
  261. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  262. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  263. pendingBrokerServerReportsMutex.Lock()
  264. defer pendingBrokerServerReportsMutex.Unlock()
  265. // Mark the report as no longer outstanding
  266. delete(pendingBrokerServerReports, clientConnectionID)
  267. }
  268. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  269. return out, errors.Trace(err)
  270. }
  271. // Check that the tactics round trip succeeds
  272. var pendingProxyTacticsCallbacksMutex sync.Mutex
  273. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  274. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  275. pendingProxyTacticsCallbacksMutex.Lock()
  276. defer pendingProxyTacticsCallbacksMutex.Unlock()
  277. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  278. }
  279. hasPendingProxyTacticsCallbacks := func() bool {
  280. pendingProxyTacticsCallbacksMutex.Lock()
  281. defer pendingProxyTacticsCallbacksMutex.Unlock()
  282. return len(pendingProxyTacticsCallbacks) > 0
  283. }
  284. makeHandleTacticsPayload := func(
  285. proxyPrivateKey SessionPrivateKey,
  286. tacticsNetworkID string) func(_ string, _ []byte) bool {
  287. return func(networkID string, tacticsPayload []byte) bool {
  288. pendingProxyTacticsCallbacksMutex.Lock()
  289. defer pendingProxyTacticsCallbacksMutex.Unlock()
  290. // Check that the correct networkID is passed around; if not,
  291. // skip the delete, which will fail the test
  292. if networkID == tacticsNetworkID {
  293. // Certain state is reset when new tactics are applied -- the
  294. // return true case; exercise both cases
  295. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  296. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  297. return true
  298. }
  299. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  300. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  301. return false
  302. }
  303. }
  304. panic("unexpected tactics payload")
  305. }
  306. }
  307. // Start proxies
  308. logger.WithTrace().Info("START PROXIES")
  309. for i := 0; i < numProxies; i++ {
  310. proxyPrivateKey, err := GenerateSessionPrivateKey()
  311. if err != nil {
  312. return errors.Trace(err)
  313. }
  314. brokerCoordinator := &testBrokerDialCoordinator{
  315. networkID: testNetworkID,
  316. networkType: testNetworkType,
  317. brokerClientPrivateKey: proxyPrivateKey,
  318. brokerPublicKey: brokerPublicKey,
  319. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  320. brokerClientRoundTripper: newHTTPRoundTripper(
  321. brokerListener.Addr().String(), "proxy"),
  322. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  323. brokerClientRoundTripperFailed: roundTripperFailed,
  324. }
  325. webRTCCoordinator := &testWebRTCDialCoordinator{
  326. networkID: testNetworkID,
  327. networkType: testNetworkType,
  328. natType: testNATType,
  329. disableSTUN: testDisableSTUN,
  330. stunServerAddress: testSTUNServerAddress,
  331. stunServerAddressRFC5780: testSTUNServerAddress,
  332. stunServerAddressSucceeded: stunServerAddressSucceeded,
  333. stunServerAddressFailed: stunServerAddressFailed,
  334. setNATType: func(NATType) {},
  335. setPortMappingTypes: func(PortMappingTypes) {},
  336. bindToDevice: func(int) error { return nil },
  337. }
  338. // Each proxy has its own broker client
  339. brokerClient, err := NewBrokerClient(brokerCoordinator)
  340. if err != nil {
  341. return errors.Trace(err)
  342. }
  343. tacticsNetworkID := prng.HexString(32)
  344. runCtx, cancelRun := context.WithCancel(testCtx)
  345. // No deferred cancelRun due to testGroup.Go below
  346. proxy, err := NewProxy(&ProxyConfig{
  347. Logger: logger,
  348. WaitForNetworkConnectivity: func() bool {
  349. return true
  350. },
  351. GetBrokerClient: func() (*BrokerClient, error) {
  352. return brokerClient, nil
  353. },
  354. GetBaseAPIParameters: func(bool) (common.APIParameters, string, error) {
  355. return baseAPIParameters, tacticsNetworkID, nil
  356. },
  357. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  358. return webRTCCoordinator, nil
  359. },
  360. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  361. MaxClients: proxyMaxClients,
  362. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  363. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  364. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  365. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  366. fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  367. time.Now().UTC().Format(time.RFC3339),
  368. connectingClients, connectedClients, bytesUp, bytesDown)
  369. },
  370. MustUpgrade: func() {
  371. close(receivedProxyMustUpgrade)
  372. cancelRun()
  373. },
  374. })
  375. if err != nil {
  376. return errors.Trace(err)
  377. }
  378. addPendingProxyTacticsCallback(proxyPrivateKey)
  379. testGroup.Go(func() error {
  380. proxy.Run(runCtx)
  381. return nil
  382. })
  383. }
  384. // Await proxy announcements before starting clients
  385. //
  386. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  387. // plus NAT discovery
  388. //
  389. // - Don't wait for > numProxies announcements due to
  390. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  391. if !doMustUpgrade {
  392. for {
  393. time.Sleep(100 * time.Millisecond)
  394. broker.matcher.announcementQueueMutex.Lock()
  395. n := broker.matcher.announcementQueue.getLen()
  396. broker.matcher.announcementQueueMutex.Unlock()
  397. if n >= numProxies {
  398. break
  399. }
  400. }
  401. }
  402. // Start clients
  403. logger.WithTrace().Info("START CLIENTS")
  404. clientsGroup := new(errgroup.Group)
  405. makeClientFunc := func(
  406. isTCP bool,
  407. isMobile bool,
  408. brokerClient *BrokerClient,
  409. webRTCCoordinator WebRTCDialCoordinator) func() error {
  410. var networkProtocol NetworkProtocol
  411. var addr string
  412. var wrapWithQUIC bool
  413. if isTCP {
  414. networkProtocol = NetworkProtocolTCP
  415. addr = tcpEchoListener.Addr().String()
  416. } else {
  417. networkProtocol = NetworkProtocolUDP
  418. addr = quicEchoServer.Addr().String()
  419. wrapWithQUIC = true
  420. }
  421. return func() error {
  422. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  423. defer cancelDial()
  424. conn, err := DialClient(
  425. dialCtx,
  426. &ClientConfig{
  427. Logger: logger,
  428. BaseAPIParameters: baseAPIParameters,
  429. BrokerClient: brokerClient,
  430. WebRTCDialCoordinator: webRTCCoordinator,
  431. ReliableTransport: isTCP,
  432. DialNetworkProtocol: networkProtocol,
  433. DialAddress: addr,
  434. PackedDestinationServerEntry: packedDestinationServerEntry,
  435. MustUpgrade: func() {
  436. close(receivedClientMustUpgrade)
  437. cancelDial()
  438. },
  439. })
  440. if err != nil {
  441. return errors.Trace(err)
  442. }
  443. var relayConn net.Conn
  444. relayConn = conn
  445. if wrapWithQUIC {
  446. quicConn, err := quic.Dial(
  447. dialCtx,
  448. conn,
  449. &net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
  450. "test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true,
  451. false, false, common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  452. )
  453. if err != nil {
  454. return errors.Trace(err)
  455. }
  456. relayConn = quicConn
  457. }
  458. addPendingBrokerServerReport(conn.GetConnectionID())
  459. signalRelayComplete := make(chan struct{})
  460. clientsGroup.Go(func() error {
  461. defer close(signalRelayComplete)
  462. in := conn.InitialRelayPacket()
  463. for in != nil {
  464. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  465. if err != nil {
  466. if out == nil {
  467. return errors.Trace(err)
  468. } else {
  469. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  470. // Proceed with reset session token packet
  471. }
  472. }
  473. if out == nil {
  474. // Relay is complete
  475. break
  476. }
  477. in, err = conn.RelayPacket(testCtx, out)
  478. if err != nil {
  479. return errors.Trace(err)
  480. }
  481. }
  482. return nil
  483. })
  484. sendBytes := prng.Bytes(bytesToSend)
  485. clientsGroup.Go(func() error {
  486. for n := 0; n < bytesToSend; {
  487. m := prng.Range(1024, 32768)
  488. if bytesToSend-n < m {
  489. m = bytesToSend - n
  490. }
  491. _, err := relayConn.Write(sendBytes[n : n+m])
  492. if err != nil {
  493. return errors.Trace(err)
  494. }
  495. n += m
  496. }
  497. fmt.Printf("%d bytes sent\n", bytesToSend)
  498. return nil
  499. })
  500. clientsGroup.Go(func() error {
  501. buf := make([]byte, 32768)
  502. n := 0
  503. for n < bytesToSend {
  504. m, err := relayConn.Read(buf)
  505. if err != nil {
  506. return errors.Trace(err)
  507. }
  508. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  509. return errors.Tracef(
  510. "unexpected bytes: expected at index %d, received at index %d",
  511. bytes.Index(sendBytes, buf[:m]), n)
  512. }
  513. n += m
  514. }
  515. fmt.Printf("%d bytes received\n", bytesToSend)
  516. select {
  517. case <-signalRelayComplete:
  518. case <-testCtx.Done():
  519. }
  520. relayConn.Close()
  521. conn.Close()
  522. return nil
  523. })
  524. return nil
  525. }
  526. }
  527. newClientParams := func(isMobile bool) (*BrokerClient, *testWebRTCDialCoordinator, error) {
  528. clientPrivateKey, err := GenerateSessionPrivateKey()
  529. if err != nil {
  530. return nil, nil, errors.Trace(err)
  531. }
  532. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  533. if err != nil {
  534. return nil, nil, errors.Trace(err)
  535. }
  536. brokerCoordinator := &testBrokerDialCoordinator{
  537. networkID: testNetworkID,
  538. networkType: testNetworkType,
  539. commonCompartmentIDs: testCommonCompartmentIDs,
  540. brokerClientPrivateKey: clientPrivateKey,
  541. brokerPublicKey: brokerPublicKey,
  542. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  543. brokerClientRoundTripper: newHTTPRoundTripper(
  544. brokerListener.Addr().String(), "client"),
  545. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  546. brokerClientRoundTripperFailed: roundTripperFailed,
  547. brokerClientNoMatch: noMatch,
  548. }
  549. webRTCCoordinator := &testWebRTCDialCoordinator{
  550. networkID: testNetworkID,
  551. networkType: testNetworkType,
  552. natType: testNATType,
  553. disableSTUN: testDisableSTUN,
  554. stunServerAddress: testSTUNServerAddress,
  555. stunServerAddressRFC5780: testSTUNServerAddress,
  556. stunServerAddressSucceeded: stunServerAddressSucceeded,
  557. stunServerAddressFailed: stunServerAddressFailed,
  558. clientRootObfuscationSecret: clientRootObfuscationSecret,
  559. doDTLSRandomization: prng.FlipCoin(),
  560. trafficShapingParameters: &DataChannelTrafficShapingParameters{
  561. MinPaddedMessages: 0,
  562. MaxPaddedMessages: 10,
  563. MinPaddingSize: 0,
  564. MaxPaddingSize: 1500,
  565. MinDecoyMessages: 0,
  566. MaxDecoyMessages: 10,
  567. MinDecoySize: 1,
  568. MaxDecoySize: 1500,
  569. DecoyMessageProbability: 0.5,
  570. },
  571. setNATType: func(NATType) {},
  572. setPortMappingTypes: func(PortMappingTypes) {},
  573. bindToDevice: func(int) error { return nil },
  574. // With STUN enabled (testDisableSTUN = false), there are cases
  575. // where the WebRTC Data Channel is not successfully established.
  576. // With a short enough timeout here, clients will redial and
  577. // eventually succceed.
  578. webRTCAwaitDataChannelTimeout: 5 * time.Second,
  579. }
  580. if isMobile {
  581. webRTCCoordinator.networkType = NetworkTypeMobile
  582. webRTCCoordinator.disableInboundForMobileNetworks = true
  583. }
  584. brokerClient, err := NewBrokerClient(brokerCoordinator)
  585. if err != nil {
  586. return nil, nil, errors.Trace(err)
  587. }
  588. return brokerClient, webRTCCoordinator, nil
  589. }
  590. clientBrokerClient, clientWebRTCCoordinator, err := newClientParams(false)
  591. if err != nil {
  592. return errors.Trace(err)
  593. }
  594. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err := newClientParams(true)
  595. if err != nil {
  596. return errors.Trace(err)
  597. }
  598. for i := 0; i < numClients; i++ {
  599. // Test a mix of TCP and UDP proxying; also test the
  600. // DisableInboundForMobileNetworks code path.
  601. isTCP := i%2 == 0
  602. isMobile := i%4 == 0
  603. // Exercise BrokerClients shared by multiple clients, but also create
  604. // several broker clients.
  605. if i%8 == 0 {
  606. clientBrokerClient, clientWebRTCCoordinator, err = newClientParams(false)
  607. if err != nil {
  608. return errors.Trace(err)
  609. }
  610. clientMobileBrokerClient, clientMobileWebRTCCoordinator, err = newClientParams(true)
  611. if err != nil {
  612. return errors.Trace(err)
  613. }
  614. }
  615. brokerClient := clientBrokerClient
  616. webRTCCoordinator := clientWebRTCCoordinator
  617. if isMobile {
  618. brokerClient = clientMobileBrokerClient
  619. webRTCCoordinator = clientMobileWebRTCCoordinator
  620. }
  621. clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
  622. }
  623. if doMustUpgrade {
  624. // Await MustUpgrade callbacks
  625. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  626. <-receivedProxyMustUpgrade
  627. <-receivedClientMustUpgrade
  628. _ = clientsGroup.Wait()
  629. } else {
  630. // Await client transfers complete
  631. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  632. err = clientsGroup.Wait()
  633. if err != nil {
  634. return errors.Trace(err)
  635. }
  636. logger.WithTrace().Info("DONE DATA TRANSFER")
  637. if hasPendingBrokerServerReports() {
  638. return errors.TraceNew("unexpected pending broker server requests")
  639. }
  640. if hasPendingProxyTacticsCallbacks() {
  641. return errors.TraceNew("unexpected pending proxy tactics callback")
  642. }
  643. // TODO: check that elapsed time is consistent with rate limit (+/-)
  644. // Check if STUN server replay callbacks were triggered
  645. if !testDisableSTUN {
  646. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  647. return errors.TraceNew("unexpected STUN server succeeded count")
  648. }
  649. }
  650. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  651. return errors.TraceNew("unexpected STUN server failed count")
  652. }
  653. // Check if RoundTripper server replay callbacks were triggered
  654. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  655. return errors.TraceNew("unexpected round tripper succeeded count")
  656. }
  657. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  658. return errors.TraceNew("unexpected round tripper failed count")
  659. }
  660. }
  661. // Await shutdowns
  662. stopTest()
  663. brokerListener.Close()
  664. err = testGroup.Wait()
  665. if err != nil {
  666. return errors.Trace(err)
  667. }
  668. return nil
  669. }
  670. func runHTTPServer(listener net.Listener, broker *Broker) error {
  671. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  672. // For this test, clients set the path to "/client" and proxies
  673. // set the path to "/proxy" and we use that to create stub GeoIP
  674. // data to pass the not-same-ASN condition.
  675. var geoIPData common.GeoIPData
  676. geoIPData.ASN = r.URL.Path
  677. requestPayload, err := ioutil.ReadAll(
  678. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  679. if err != nil {
  680. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  681. http.Error(w, "", http.StatusNotFound)
  682. return
  683. }
  684. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  685. extendTimeout := func(timeout time.Duration) {
  686. // TODO: set insufficient initial timeout, so extension is
  687. // required for success
  688. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  689. }
  690. responsePayload, err := broker.HandleSessionPacket(
  691. r.Context(),
  692. extendTimeout,
  693. nil,
  694. clientIP,
  695. geoIPData,
  696. requestPayload)
  697. if err != nil {
  698. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  699. http.Error(w, "", http.StatusNotFound)
  700. return
  701. }
  702. w.WriteHeader(http.StatusOK)
  703. w.Write(responsePayload)
  704. })
  705. // WriteTimeout will be extended via extendTimeout.
  706. httpServer := &http.Server{
  707. ReadTimeout: 10 * time.Second,
  708. WriteTimeout: 10 * time.Second,
  709. IdleTimeout: 1 * time.Minute,
  710. Handler: handler,
  711. }
  712. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  713. if err != nil {
  714. return errors.Trace(err)
  715. }
  716. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  717. if err != nil {
  718. return errors.Trace(err)
  719. }
  720. tlsConfig := &tls.Config{
  721. Certificates: []tls.Certificate{tlsCert},
  722. }
  723. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  724. return errors.Trace(err)
  725. }
  726. type httpRoundTripper struct {
  727. httpClient *http.Client
  728. endpointAddr string
  729. path string
  730. }
  731. func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
  732. return &httpRoundTripper{
  733. httpClient: &http.Client{
  734. Transport: &http.Transport{
  735. ForceAttemptHTTP2: true,
  736. MaxIdleConns: 2,
  737. IdleConnTimeout: 1 * time.Minute,
  738. TLSHandshakeTimeout: 10 * time.Second,
  739. TLSClientConfig: &std_tls.Config{
  740. InsecureSkipVerify: true,
  741. },
  742. },
  743. },
  744. endpointAddr: endpointAddr,
  745. path: path,
  746. }
  747. }
  748. func (r *httpRoundTripper) RoundTrip(
  749. ctx context.Context,
  750. roundTripDelay time.Duration,
  751. roundTripTimeout time.Duration,
  752. requestPayload []byte) ([]byte, error) {
  753. if roundTripDelay > 0 {
  754. common.SleepWithContext(ctx, roundTripDelay)
  755. }
  756. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  757. defer requestCancelFunc()
  758. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  759. request, err := http.NewRequestWithContext(
  760. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  761. if err != nil {
  762. return nil, errors.Trace(err)
  763. }
  764. response, err := r.httpClient.Do(request)
  765. if err != nil {
  766. return nil, errors.Trace(err)
  767. }
  768. defer response.Body.Close()
  769. if response.StatusCode != http.StatusOK {
  770. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  771. }
  772. responsePayload, err := io.ReadAll(response.Body)
  773. if err != nil {
  774. return nil, errors.Trace(err)
  775. }
  776. return responsePayload, nil
  777. }
  778. func (r *httpRoundTripper) Close() error {
  779. r.httpClient.CloseIdleConnections()
  780. return nil
  781. }
  782. func runTCPEchoServer(listener net.Listener) {
  783. for {
  784. conn, err := listener.Accept()
  785. if err != nil {
  786. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  787. return
  788. }
  789. go func(conn net.Conn) {
  790. buf := make([]byte, 32768)
  791. for {
  792. n, err := conn.Read(buf)
  793. if n > 0 {
  794. _, err = conn.Write(buf[:n])
  795. }
  796. if err != nil {
  797. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  798. return
  799. }
  800. }
  801. }(conn)
  802. }
  803. }
  // quicEchoServer is the QUIC echo server that serves as the test's UDP
// destination.
type quicEchoServer struct {
	listener       net.Listener // listener created by quic.Listen
	obfuscationKey string       // obfuscation key the listener was created with
}
  808. func newQuicEchoServer() (*quicEchoServer, error) {
  809. obfuscationKey := prng.HexString(32)
  810. listener, err := quic.Listen(
  811. nil,
  812. nil,
  813. "127.0.0.1:0",
  814. obfuscationKey,
  815. false)
  816. if err != nil {
  817. return nil, errors.Trace(err)
  818. }
  819. return &quicEchoServer{
  820. listener: listener,
  821. obfuscationKey: obfuscationKey,
  822. }, nil
  823. }
  824. func (q *quicEchoServer) ObfuscationKey() string {
  825. return q.obfuscationKey
  826. }
  827. func (q *quicEchoServer) Close() error {
  828. return q.listener.Close()
  829. }
  830. func (q *quicEchoServer) Addr() net.Addr {
  831. return q.listener.Addr()
  832. }
  833. func (q *quicEchoServer) Run() {
  834. for {
  835. conn, err := q.listener.Accept()
  836. if err != nil {
  837. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  838. return
  839. }
  840. go func(conn net.Conn) {
  841. buf := make([]byte, 32768)
  842. for {
  843. n, err := conn.Read(buf)
  844. if n > 0 {
  845. _, err = conn.Write(buf[:n])
  846. }
  847. if err != nil {
  848. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  849. return
  850. }
  851. }
  852. }(conn)
  853. }
  854. }