inproxy_test.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133
  1. //go:build PSIPHON_ENABLE_INPROXY
  2. /*
  3. * Copyright (c) 2023, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. package inproxy
  21. import (
  22. "bytes"
  23. "context"
  24. std_tls "crypto/tls"
  25. "encoding/base64"
  26. "fmt"
  27. "io"
  28. "io/ioutil"
  29. "net"
  30. "net/http"
  31. _ "net/http/pprof"
  32. "strconv"
  33. "strings"
  34. "sync"
  35. "sync/atomic"
  36. "testing"
  37. "time"
  38. tls "github.com/Psiphon-Labs/psiphon-tls"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  44. "golang.org/x/sync/errgroup"
  45. )
  46. func TestInproxy(t *testing.T) {
  47. err := runTestInproxy(false)
  48. if err != nil {
  49. t.Errorf(errors.Trace(err).Error())
  50. }
  51. }
  52. func TestInproxyMustUpgrade(t *testing.T) {
  53. err := runTestInproxy(true)
  54. if err != nil {
  55. t.Errorf(errors.Trace(err).Error())
  56. }
  57. }
  58. func runTestInproxy(doMustUpgrade bool) error {
  59. // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
  60. numProxies := 5
  61. proxyMaxClients := 3
  62. numClients := 10
  63. bytesToSend := 1 << 20
  64. targetElapsedSeconds := 2
  65. baseAPIParameters := common.APIParameters{
  66. "sponsor_id": strings.ToUpper(prng.HexString(8)),
  67. "client_platform": "test-client-platform",
  68. }
  69. testCompartmentID, _ := MakeID()
  70. testCommonCompartmentIDs := []ID{testCompartmentID}
  71. testNetworkID := "NETWORK-ID-1"
  72. testNetworkType := NetworkTypeUnknown
  73. testNATType := NATTypeUnknown
  74. testSTUNServerAddress := "stun.nextcloud.com:443"
  75. testDisableSTUN := false
  76. testDisablePortMapping := false
  77. testNewTacticsPayload := []byte(prng.HexString(100))
  78. testNewTacticsTag := "new-tactics-tag"
  79. testUnchangedTacticsPayload := []byte(prng.HexString(100))
  80. currentNetworkCtx, currentNetworkCancelFunc := context.WithCancel(context.Background())
  81. defer currentNetworkCancelFunc()
  82. // TODO: test port mapping
  83. stunServerAddressSucceededCount := int32(0)
  84. stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
  85. stunServerAddressFailedCount := int32(0)
  86. stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
  87. roundTripperSucceededCount := int32(0)
  88. roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
  89. roundTripperFailedCount := int32(0)
  90. roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
  91. noMatch := func(RoundTripper) {}
  92. var receivedProxyMustUpgrade chan struct{}
  93. var receivedClientMustUpgrade chan struct{}
  94. if doMustUpgrade {
  95. receivedProxyMustUpgrade = make(chan struct{})
  96. receivedClientMustUpgrade = make(chan struct{})
  97. // trigger MustUpgrade
  98. minimumProxyProtocolVersion = LatestProtocolVersion + 1
  99. minimumClientProtocolVersion = LatestProtocolVersion + 1
  100. // Minimize test parameters for MustUpgrade case
  101. numProxies = 1
  102. proxyMaxClients = 1
  103. numClients = 1
  104. testDisableSTUN = true
  105. testDisablePortMapping = true
  106. }
  107. testCtx, stopTest := context.WithCancel(context.Background())
  108. defer stopTest()
  109. testGroup := new(errgroup.Group)
  110. // Enable test to run without requiring host firewall exceptions
  111. SetAllowBogonWebRTCConnections(true)
  112. defer SetAllowBogonWebRTCConnections(false)
  113. // Init logging and profiling
  114. logger := newTestLogger()
  115. pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
  116. go http.Serve(pprofListener, nil)
  117. defer pprofListener.Close()
  118. logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
  119. // Start echo servers
  120. tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
  121. if err != nil {
  122. return errors.Trace(err)
  123. }
  124. defer tcpEchoListener.Close()
  125. go runTCPEchoServer(tcpEchoListener)
  126. // QUIC tests UDP proxying, and provides reliable delivery of echoed data
  127. quicEchoServer, err := newQuicEchoServer()
  128. if err != nil {
  129. return errors.Trace(err)
  130. }
  131. defer quicEchoServer.Close()
  132. go quicEchoServer.Run()
  133. // Create signed server entry with capability
  134. serverPrivateKey, err := GenerateSessionPrivateKey()
  135. if err != nil {
  136. return errors.Trace(err)
  137. }
  138. serverPublicKey, err := serverPrivateKey.GetPublicKey()
  139. if err != nil {
  140. return errors.Trace(err)
  141. }
  142. serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  143. if err != nil {
  144. return errors.Trace(err)
  145. }
  146. serverEntry := make(protocol.ServerEntryFields)
  147. serverEntry["ipAddress"] = "127.0.0.1"
  148. _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
  149. _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
  150. serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
  151. serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
  152. serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
  153. serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
  154. serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
  155. testServerEntryTag := prng.HexString(16)
  156. serverEntry["tag"] = testServerEntryTag
  157. serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
  158. protocol.NewServerEntrySignatureKeyPair()
  159. if err != nil {
  160. return errors.Trace(err)
  161. }
  162. err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
  163. if err != nil {
  164. return errors.Trace(err)
  165. }
  166. packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
  167. if err != nil {
  168. return errors.Trace(err)
  169. }
  170. packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
  171. if err != nil {
  172. return errors.Trace(err)
  173. }
  174. // API parameter handlers
  175. apiParameterValidator := func(params common.APIParameters) error {
  176. if len(params) != len(baseAPIParameters) {
  177. return errors.TraceNew("unexpected base API parameter count")
  178. }
  179. for name, value := range params {
  180. if value.(string) != baseAPIParameters[name].(string) {
  181. return errors.Tracef(
  182. "unexpected base API parameter: %v: %v != %v",
  183. name,
  184. value.(string),
  185. baseAPIParameters[name].(string))
  186. }
  187. }
  188. return nil
  189. }
  190. apiParameterLogFieldFormatter := func(
  191. _ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
  192. logFields := common.LogFields{}
  193. logFields.Add(common.LogFields(params))
  194. return logFields
  195. }
  196. // Start broker
  197. logger.WithTrace().Info("START BROKER")
  198. brokerPrivateKey, err := GenerateSessionPrivateKey()
  199. if err != nil {
  200. return errors.Trace(err)
  201. }
  202. brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
  203. if err != nil {
  204. return errors.Trace(err)
  205. }
  206. brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  207. if err != nil {
  208. return errors.Trace(err)
  209. }
  210. brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
  211. if err != nil {
  212. return errors.Trace(err)
  213. }
  214. defer brokerListener.Close()
  215. brokerConfig := &BrokerConfig{
  216. Logger: logger,
  217. CommonCompartmentIDs: testCommonCompartmentIDs,
  218. APIParameterValidator: apiParameterValidator,
  219. APIParameterLogFieldFormatter: apiParameterLogFieldFormatter,
  220. GetTacticsPayload: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
  221. // Exercise both new and unchanged tactics
  222. if prng.FlipCoin() {
  223. return testNewTacticsPayload, testNewTacticsTag, nil
  224. }
  225. return testUnchangedTacticsPayload, "", nil
  226. },
  227. IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
  228. PrivateKey: brokerPrivateKey,
  229. ObfuscationRootSecret: brokerRootObfuscationSecret,
  230. ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
  231. AllowProxy: func(common.GeoIPData) bool { return true },
  232. AllowClient: func(common.GeoIPData) bool { return true },
  233. AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
  234. }
  235. broker, err := NewBroker(brokerConfig)
  236. if err != nil {
  237. return errors.Trace(err)
  238. }
  239. err = broker.Start()
  240. if err != nil {
  241. return errors.Trace(err)
  242. }
  243. defer broker.Stop()
  244. testGroup.Go(func() error {
  245. err := runHTTPServer(brokerListener, broker)
  246. if testCtx.Err() != nil {
  247. return nil
  248. }
  249. return errors.Trace(err)
  250. })
  251. // Stub server broker request handler (in Psiphon, this will be the
  252. // destination Psiphon server; here, it's not necessary to build this
  253. // handler into the destination echo server)
  254. serverSessions, err := NewServerBrokerSessions(
  255. serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey},
  256. apiParameterValidator, apiParameterLogFieldFormatter, "")
  257. if err != nil {
  258. return errors.Trace(err)
  259. }
  260. var pendingBrokerServerReportsMutex sync.Mutex
  261. pendingBrokerServerReports := make(map[ID]bool)
  262. addPendingBrokerServerReport := func(connectionID ID) {
  263. pendingBrokerServerReportsMutex.Lock()
  264. defer pendingBrokerServerReportsMutex.Unlock()
  265. pendingBrokerServerReports[connectionID] = true
  266. }
  267. hasPendingBrokerServerReports := func() bool {
  268. pendingBrokerServerReportsMutex.Lock()
  269. defer pendingBrokerServerReportsMutex.Unlock()
  270. return len(pendingBrokerServerReports) > 0
  271. }
  272. handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
  273. handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
  274. pendingBrokerServerReportsMutex.Lock()
  275. defer pendingBrokerServerReportsMutex.Unlock()
  276. // Mark the report as no longer outstanding
  277. delete(pendingBrokerServerReports, clientConnectionID)
  278. }
  279. out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
  280. return out, errors.Trace(err)
  281. }
  282. // Check that the tactics round trip succeeds
  283. var pendingProxyTacticsCallbacksMutex sync.Mutex
  284. pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
  285. addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
  286. pendingProxyTacticsCallbacksMutex.Lock()
  287. defer pendingProxyTacticsCallbacksMutex.Unlock()
  288. pendingProxyTacticsCallbacks[proxyPrivateKey] = true
  289. }
  290. hasPendingProxyTacticsCallbacks := func() bool {
  291. pendingProxyTacticsCallbacksMutex.Lock()
  292. defer pendingProxyTacticsCallbacksMutex.Unlock()
  293. return len(pendingProxyTacticsCallbacks) > 0
  294. }
  295. makeHandleTacticsPayload := func(
  296. proxyPrivateKey SessionPrivateKey,
  297. tacticsNetworkID string) func(_ string, _ []byte) bool {
  298. return func(networkID string, tacticsPayload []byte) bool {
  299. pendingProxyTacticsCallbacksMutex.Lock()
  300. defer pendingProxyTacticsCallbacksMutex.Unlock()
  301. // Check that the correct networkID is passed around; if not,
  302. // skip the delete, which will fail the test
  303. if networkID == tacticsNetworkID {
  304. // Certain state is reset when new tactics are applied -- the
  305. // return true case; exercise both cases
  306. if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
  307. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  308. return true
  309. }
  310. if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
  311. delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
  312. return false
  313. }
  314. }
  315. panic("unexpected tactics payload")
  316. }
  317. }
  318. // Start proxies
  319. logger.WithTrace().Info("START PROXIES")
  320. for i := 0; i < numProxies; i++ {
  321. proxyPrivateKey, err := GenerateSessionPrivateKey()
  322. if err != nil {
  323. return errors.Trace(err)
  324. }
  325. brokerCoordinator := &testBrokerDialCoordinator{
  326. networkID: testNetworkID,
  327. networkType: testNetworkType,
  328. brokerClientPrivateKey: proxyPrivateKey,
  329. brokerPublicKey: brokerPublicKey,
  330. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  331. brokerClientRoundTripper: newHTTPRoundTripper(
  332. brokerListener.Addr().String(), "proxy"),
  333. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  334. brokerClientRoundTripperFailed: roundTripperFailed,
  335. // Minimize the delay before proxies reannounce after dial
  336. // failures, which may occur.
  337. announceDelay: 0,
  338. announceMaxBackoffDelay: 0,
  339. announceDelayJitter: 0.0,
  340. }
  341. webRTCCoordinator := &testWebRTCDialCoordinator{
  342. networkID: testNetworkID,
  343. networkType: testNetworkType,
  344. natType: testNATType,
  345. disableSTUN: testDisableSTUN,
  346. disablePortMapping: testDisablePortMapping,
  347. stunServerAddress: testSTUNServerAddress,
  348. stunServerAddressRFC5780: testSTUNServerAddress,
  349. stunServerAddressSucceeded: stunServerAddressSucceeded,
  350. stunServerAddressFailed: stunServerAddressFailed,
  351. setNATType: func(NATType) {},
  352. setPortMappingTypes: func(PortMappingTypes) {},
  353. bindToDevice: func(int) error { return nil },
  354. // Minimize the delay before proxies reannounce after failed
  355. // connections, which may occur.
  356. webRTCAwaitReadyToProxyTimeout: 5 * time.Second,
  357. proxyRelayInactivityTimeout: 5 * time.Second,
  358. }
  359. // Each proxy has its own broker client
  360. brokerClient, err := NewBrokerClient(brokerCoordinator)
  361. if err != nil {
  362. return errors.Trace(err)
  363. }
  364. tacticsNetworkID := prng.HexString(32)
  365. runCtx, cancelRun := context.WithCancel(testCtx)
  366. // No deferred cancelRun due to testGroup.Go below
  367. name := fmt.Sprintf("proxy-%d", i)
  368. proxy, err := NewProxy(&ProxyConfig{
  369. Logger: newTestLoggerWithComponent(name),
  370. WaitForNetworkConnectivity: func() bool {
  371. return true
  372. },
  373. GetCurrentNetworkContext: func() context.Context {
  374. return currentNetworkCtx
  375. },
  376. GetBrokerClient: func() (*BrokerClient, error) {
  377. return brokerClient, nil
  378. },
  379. GetBaseAPIParameters: func(bool) (common.APIParameters, string, error) {
  380. return baseAPIParameters, tacticsNetworkID, nil
  381. },
  382. MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
  383. return webRTCCoordinator, nil
  384. },
  385. HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
  386. MaxClients: proxyMaxClients,
  387. LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  388. LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
  389. ActivityUpdater: func(connectingClients int32, connectedClients int32,
  390. bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
  391. fmt.Printf("[%s][%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
  392. time.Now().UTC().Format(time.RFC3339), name,
  393. connectingClients, connectedClients, bytesUp, bytesDown)
  394. },
  395. MustUpgrade: func() {
  396. close(receivedProxyMustUpgrade)
  397. cancelRun()
  398. },
  399. })
  400. if err != nil {
  401. return errors.Trace(err)
  402. }
  403. addPendingProxyTacticsCallback(proxyPrivateKey)
  404. testGroup.Go(func() error {
  405. proxy.Run(runCtx)
  406. return nil
  407. })
  408. }
  409. // Await proxy announcements before starting clients
  410. //
  411. // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
  412. // plus NAT discovery
  413. //
  414. // - Don't wait for > numProxies announcements due to
  415. // InitiatorSessions.NewRoundTrip waitToShareSession limitation
  416. if !doMustUpgrade {
  417. for {
  418. time.Sleep(100 * time.Millisecond)
  419. broker.matcher.announcementQueueMutex.Lock()
  420. n := broker.matcher.announcementQueue.getLen()
  421. broker.matcher.announcementQueueMutex.Unlock()
  422. if n >= numProxies {
  423. break
  424. }
  425. }
  426. }
  427. // Start clients
  428. var completedClientCount atomic.Int64
  429. logger.WithTrace().Info("START CLIENTS")
  430. clientsGroup := new(errgroup.Group)
  431. makeClientFunc := func(
  432. clientNum int,
  433. isTCP bool,
  434. brokerClient *BrokerClient,
  435. webRTCCoordinator WebRTCDialCoordinator) func() error {
  436. var networkProtocol NetworkProtocol
  437. var addr string
  438. var wrapWithQUIC bool
  439. if isTCP {
  440. networkProtocol = NetworkProtocolTCP
  441. addr = tcpEchoListener.Addr().String()
  442. } else {
  443. networkProtocol = NetworkProtocolUDP
  444. addr = quicEchoServer.Addr().String()
  445. wrapWithQUIC = true
  446. }
  447. return func() error {
  448. name := fmt.Sprintf("client-%d", clientNum)
  449. dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
  450. defer cancelDial()
  451. conn, err := DialClient(
  452. dialCtx,
  453. &ClientConfig{
  454. Logger: newTestLoggerWithComponent(name),
  455. BaseAPIParameters: baseAPIParameters,
  456. BrokerClient: brokerClient,
  457. WebRTCDialCoordinator: webRTCCoordinator,
  458. ReliableTransport: isTCP,
  459. DialNetworkProtocol: networkProtocol,
  460. DialAddress: addr,
  461. PackedDestinationServerEntry: packedDestinationServerEntry,
  462. MustUpgrade: func() {
  463. close(receivedClientMustUpgrade)
  464. cancelDial()
  465. },
  466. })
  467. if err != nil {
  468. return errors.Trace(err)
  469. }
  470. var relayConn net.Conn
  471. relayConn = conn
  472. if wrapWithQUIC {
  473. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  474. if err != nil {
  475. return errors.Trace(err)
  476. }
  477. disablePathMTUDiscovery := true
  478. quicConn, err := quic.Dial(
  479. dialCtx,
  480. conn,
  481. udpAddr,
  482. "test",
  483. "QUICv1",
  484. nil,
  485. quicEchoServer.ObfuscationKey(),
  486. nil,
  487. nil,
  488. disablePathMTUDiscovery,
  489. GetQUICMaxPacketSizeAdjustment(false),
  490. false,
  491. false,
  492. common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
  493. )
  494. if err != nil {
  495. return errors.Trace(err)
  496. }
  497. relayConn = quicConn
  498. }
  499. addPendingBrokerServerReport(conn.GetConnectionID())
  500. signalRelayComplete := make(chan struct{})
  501. clientsGroup.Go(func() error {
  502. defer close(signalRelayComplete)
  503. in := conn.InitialRelayPacket()
  504. for in != nil {
  505. out, err := handleBrokerServerReports(in, conn.GetConnectionID())
  506. if err != nil {
  507. if out == nil {
  508. return errors.Trace(err)
  509. } else {
  510. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  511. // Proceed with reset session token packet
  512. }
  513. }
  514. if out == nil {
  515. // Relay is complete
  516. break
  517. }
  518. in, err = conn.RelayPacket(testCtx, out)
  519. if err != nil {
  520. return errors.Trace(err)
  521. }
  522. }
  523. return nil
  524. })
  525. sendBytes := prng.Bytes(bytesToSend)
  526. clientsGroup.Go(func() error {
  527. for n := 0; n < bytesToSend; {
  528. m := prng.Range(1024, 32768)
  529. if bytesToSend-n < m {
  530. m = bytesToSend - n
  531. }
  532. _, err := relayConn.Write(sendBytes[n : n+m])
  533. if err != nil {
  534. return errors.Trace(err)
  535. }
  536. n += m
  537. }
  538. fmt.Printf("[%s][%s] %d bytes sent\n",
  539. time.Now().UTC().Format(time.RFC3339), name, bytesToSend)
  540. return nil
  541. })
  542. clientsGroup.Go(func() error {
  543. buf := make([]byte, 32768)
  544. n := 0
  545. for n < bytesToSend {
  546. m, err := relayConn.Read(buf)
  547. if err != nil {
  548. return errors.Trace(err)
  549. }
  550. if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
  551. return errors.Tracef(
  552. "unexpected bytes: expected at index %d, received at index %d",
  553. bytes.Index(sendBytes, buf[:m]), n)
  554. }
  555. n += m
  556. }
  557. completed := completedClientCount.Add(1)
  558. fmt.Printf("[%s][%s] %d bytes received; relay complete (%d/%d)\n",
  559. time.Now().UTC().Format(time.RFC3339), name,
  560. bytesToSend, completed, numClients)
  561. select {
  562. case <-signalRelayComplete:
  563. case <-testCtx.Done():
  564. }
  565. fmt.Printf("[%s][%s] closing\n",
  566. time.Now().UTC().Format(time.RFC3339), name)
  567. relayConn.Close()
  568. conn.Close()
  569. return nil
  570. })
  571. return nil
  572. }
  573. }
  574. newClientBrokerClient := func() (*BrokerClient, error) {
  575. clientPrivateKey, err := GenerateSessionPrivateKey()
  576. if err != nil {
  577. return nil, errors.Trace(err)
  578. }
  579. brokerCoordinator := &testBrokerDialCoordinator{
  580. networkID: testNetworkID,
  581. networkType: testNetworkType,
  582. commonCompartmentIDs: testCommonCompartmentIDs,
  583. brokerClientPrivateKey: clientPrivateKey,
  584. brokerPublicKey: brokerPublicKey,
  585. brokerRootObfuscationSecret: brokerRootObfuscationSecret,
  586. brokerClientRoundTripper: newHTTPRoundTripper(
  587. brokerListener.Addr().String(), "client"),
  588. brokerClientRoundTripperSucceeded: roundTripperSucceded,
  589. brokerClientRoundTripperFailed: roundTripperFailed,
  590. brokerClientNoMatch: noMatch,
  591. }
  592. brokerClient, err := NewBrokerClient(brokerCoordinator)
  593. if err != nil {
  594. return nil, errors.Trace(err)
  595. }
  596. return brokerClient, nil
  597. }
  598. newClientWebRTCDialCoordinator := func(
  599. isMobile bool,
  600. useMediaStreams bool) (*testWebRTCDialCoordinator, error) {
  601. clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  602. if err != nil {
  603. return nil, errors.Trace(err)
  604. }
  605. var trafficShapingParameters *TrafficShapingParameters
  606. if useMediaStreams {
  607. trafficShapingParameters = &TrafficShapingParameters{
  608. MinPaddedMessages: 0,
  609. MaxPaddedMessages: 10,
  610. MinPaddingSize: 0,
  611. MaxPaddingSize: 254,
  612. MinDecoyMessages: 0,
  613. MaxDecoyMessages: 10,
  614. MinDecoySize: 1,
  615. MaxDecoySize: 1200,
  616. DecoyMessageProbability: 0.5,
  617. }
  618. } else {
  619. trafficShapingParameters = &TrafficShapingParameters{
  620. MinPaddedMessages: 0,
  621. MaxPaddedMessages: 10,
  622. MinPaddingSize: 0,
  623. MaxPaddingSize: 1500,
  624. MinDecoyMessages: 0,
  625. MaxDecoyMessages: 10,
  626. MinDecoySize: 1,
  627. MaxDecoySize: 1500,
  628. DecoyMessageProbability: 0.5,
  629. }
  630. }
  631. webRTCCoordinator := &testWebRTCDialCoordinator{
  632. networkID: testNetworkID,
  633. networkType: testNetworkType,
  634. natType: testNATType,
  635. disableSTUN: testDisableSTUN,
  636. stunServerAddress: testSTUNServerAddress,
  637. stunServerAddressRFC5780: testSTUNServerAddress,
  638. stunServerAddressSucceeded: stunServerAddressSucceeded,
  639. stunServerAddressFailed: stunServerAddressFailed,
  640. clientRootObfuscationSecret: clientRootObfuscationSecret,
  641. doDTLSRandomization: prng.FlipCoin(),
  642. useMediaStreams: useMediaStreams,
  643. trafficShapingParameters: trafficShapingParameters,
  644. setNATType: func(NATType) {},
  645. setPortMappingTypes: func(PortMappingTypes) {},
  646. bindToDevice: func(int) error { return nil },
  647. // With STUN enabled (testDisableSTUN = false), there are cases
  648. // where the WebRTC peer connection is not successfully
  649. // established. With a short enough timeout here, clients will
  650. // redial and eventually succceed.
  651. webRTCAwaitReadyToProxyTimeout: 5 * time.Second,
  652. }
  653. if isMobile {
  654. webRTCCoordinator.networkType = NetworkTypeMobile
  655. webRTCCoordinator.disableInboundForMobileNetworks = true
  656. }
  657. return webRTCCoordinator, nil
  658. }
  659. sharedBrokerClient, err := newClientBrokerClient()
  660. if err != nil {
  661. return errors.Trace(err)
  662. }
  663. for i := 0; i < numClients; i++ {
  664. // Test a mix of TCP and UDP proxying; also test the
  665. // DisableInboundForMobileNetworks code path.
  666. isTCP := i%2 == 0
  667. isMobile := i%4 == 0
  668. useMediaStreams := i%4 < 2
  669. // Exercise BrokerClients shared by multiple clients, but also create
  670. // several broker clients.
  671. brokerClient := sharedBrokerClient
  672. if i%2 == 0 {
  673. brokerClient, err = newClientBrokerClient()
  674. if err != nil {
  675. return errors.Trace(err)
  676. }
  677. }
  678. webRTCCoordinator, err := newClientWebRTCDialCoordinator(
  679. isMobile, useMediaStreams)
  680. if err != nil {
  681. return errors.Trace(err)
  682. }
  683. clientsGroup.Go(
  684. makeClientFunc(
  685. i,
  686. isTCP,
  687. brokerClient,
  688. webRTCCoordinator))
  689. }
  690. if doMustUpgrade {
  691. // Await MustUpgrade callbacks
  692. logger.WithTrace().Info("AWAIT MUST UPGRADE")
  693. <-receivedProxyMustUpgrade
  694. <-receivedClientMustUpgrade
  695. _ = clientsGroup.Wait()
  696. } else {
  697. // Await client transfers complete
  698. logger.WithTrace().Info("AWAIT DATA TRANSFER")
  699. err = clientsGroup.Wait()
  700. if err != nil {
  701. return errors.Trace(err)
  702. }
  703. logger.WithTrace().Info("DONE DATA TRANSFER")
  704. if hasPendingBrokerServerReports() {
  705. return errors.TraceNew("unexpected pending broker server requests")
  706. }
  707. if hasPendingProxyTacticsCallbacks() {
  708. return errors.TraceNew("unexpected pending proxy tactics callback")
  709. }
  710. // TODO: check that elapsed time is consistent with rate limit (+/-)
  711. // Check if STUN server replay callbacks were triggered
  712. if !testDisableSTUN {
  713. if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
  714. return errors.TraceNew("unexpected STUN server succeeded count")
  715. }
  716. }
  717. if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
  718. return errors.TraceNew("unexpected STUN server failed count")
  719. }
  720. // Check if RoundTripper server replay callbacks were triggered
  721. if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
  722. return errors.TraceNew("unexpected round tripper succeeded count")
  723. }
  724. if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
  725. return errors.TraceNew("unexpected round tripper failed count")
  726. }
  727. }
  728. // Await shutdowns
  729. stopTest()
  730. brokerListener.Close()
  731. err = testGroup.Wait()
  732. if err != nil {
  733. return errors.Trace(err)
  734. }
  735. return nil
  736. }
  737. func runHTTPServer(listener net.Listener, broker *Broker) error {
  738. handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  739. // For this test, clients set the path to "/client" and proxies
  740. // set the path to "/proxy" and we use that to create stub GeoIP
  741. // data to pass the not-same-ASN condition.
  742. var geoIPData common.GeoIPData
  743. geoIPData.ASN = r.URL.Path
  744. requestPayload, err := ioutil.ReadAll(
  745. http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
  746. if err != nil {
  747. fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
  748. http.Error(w, "", http.StatusNotFound)
  749. return
  750. }
  751. clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
  752. extendTimeout := func(timeout time.Duration) {
  753. // TODO: set insufficient initial timeout, so extension is
  754. // required for success
  755. http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
  756. }
  757. responsePayload, err := broker.HandleSessionPacket(
  758. r.Context(),
  759. extendTimeout,
  760. nil,
  761. clientIP,
  762. geoIPData,
  763. requestPayload)
  764. if err != nil {
  765. fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
  766. http.Error(w, "", http.StatusNotFound)
  767. return
  768. }
  769. w.WriteHeader(http.StatusOK)
  770. w.Write(responsePayload)
  771. })
  772. // WriteTimeout will be extended via extendTimeout.
  773. httpServer := &http.Server{
  774. ReadTimeout: 10 * time.Second,
  775. WriteTimeout: 10 * time.Second,
  776. IdleTimeout: 1 * time.Minute,
  777. Handler: handler,
  778. }
  779. certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
  780. if err != nil {
  781. return errors.Trace(err)
  782. }
  783. tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
  784. if err != nil {
  785. return errors.Trace(err)
  786. }
  787. tlsConfig := &tls.Config{
  788. Certificates: []tls.Certificate{tlsCert},
  789. }
  790. err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
  791. return errors.Trace(err)
  792. }
  // httpRoundTripper is the test round tripper that POSTs request payloads
// to https://<endpointAddr>/<path> on the broker's HTTPS test server.
type httpRoundTripper struct {
	// httpClient is dedicated to this round tripper; it is never shared.
	httpClient *http.Client
	// endpointAddr is the host:port of the HTTPS server.
	endpointAddr string
	// path is the URL path segment appended after the address.
	path string
}

// newHTTPRoundTripper builds an httpRoundTripper with its own HTTP client.
//
// The transport keeps up to 2 idle connections for 1 minute, attempts
// HTTP/2, and bounds TLS handshakes at 10 seconds. The client sets no
// overall Timeout; per-request deadlines are applied in RoundTrip.
func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
	// The test server presents a freshly generated certificate, so peer
	// verification is skipped.
	clientTLSConfig := &std_tls.Config{InsecureSkipVerify: true}

	transport := &http.Transport{
		ForceAttemptHTTP2:   true,
		MaxIdleConns:        2,
		IdleConnTimeout:     time.Minute,
		TLSHandshakeTimeout: 10 * time.Second,
		TLSClientConfig:     clientTLSConfig,
	}

	roundTripper := new(httpRoundTripper)
	roundTripper.httpClient = &http.Client{Transport: transport}
	roundTripper.endpointAddr = endpointAddr
	roundTripper.path = path
	return roundTripper
}
  815. func (r *httpRoundTripper) RoundTrip(
  816. ctx context.Context,
  817. roundTripDelay time.Duration,
  818. roundTripTimeout time.Duration,
  819. requestPayload []byte) ([]byte, error) {
  820. if roundTripDelay > 0 {
  821. common.SleepWithContext(ctx, roundTripDelay)
  822. }
  823. requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  824. defer requestCancelFunc()
  825. url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
  826. request, err := http.NewRequestWithContext(
  827. requestCtx, "POST", url, bytes.NewReader(requestPayload))
  828. if err != nil {
  829. return nil, errors.Trace(err)
  830. }
  831. response, err := r.httpClient.Do(request)
  832. if err != nil {
  833. return nil, errors.Trace(err)
  834. }
  835. defer response.Body.Close()
  836. if response.StatusCode != http.StatusOK {
  837. return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
  838. }
  839. responsePayload, err := io.ReadAll(response.Body)
  840. if err != nil {
  841. return nil, errors.Trace(err)
  842. }
  843. return responsePayload, nil
  844. }
  845. func (r *httpRoundTripper) Close() error {
  846. r.httpClient.CloseIdleConnections()
  847. return nil
  848. }
  849. func runTCPEchoServer(listener net.Listener) {
  850. for {
  851. conn, err := listener.Accept()
  852. if err != nil {
  853. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  854. return
  855. }
  856. go func(conn net.Conn) {
  857. buf := make([]byte, 32768)
  858. for {
  859. n, err := conn.Read(buf)
  860. if n > 0 {
  861. _, err = conn.Write(buf[:n])
  862. }
  863. if err != nil {
  864. fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
  865. return
  866. }
  867. }
  868. }(conn)
  869. }
  870. }
  // quicEchoServer is a QUIC echo server used by the tests. Construct it
// with newQuicEchoServer.
type quicEchoServer struct {
	// listener is the QUIC listener returned by quic.Listen.
	listener net.Listener
	// obfuscationKey is the key the listener was created with.
	obfuscationKey string
}
  875. func newQuicEchoServer() (*quicEchoServer, error) {
  876. obfuscationKey := prng.HexString(32)
  877. listener, err := quic.Listen(
  878. nil,
  879. nil,
  880. "127.0.0.1:0",
  881. GetQUICMaxPacketSizeAdjustment(false),
  882. obfuscationKey,
  883. false)
  884. if err != nil {
  885. return nil, errors.Trace(err)
  886. }
  887. return &quicEchoServer{
  888. listener: listener,
  889. obfuscationKey: obfuscationKey,
  890. }, nil
  891. }
  892. func (q *quicEchoServer) ObfuscationKey() string {
  893. return q.obfuscationKey
  894. }
  895. func (q *quicEchoServer) Close() error {
  896. return q.listener.Close()
  897. }
  898. func (q *quicEchoServer) Addr() net.Addr {
  899. return q.listener.Addr()
  900. }
  901. func (q *quicEchoServer) Run() {
  902. for {
  903. conn, err := q.listener.Accept()
  904. if err != nil {
  905. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  906. return
  907. }
  908. go func(conn net.Conn) {
  909. buf := make([]byte, 32768)
  910. for {
  911. n, err := conn.Read(buf)
  912. if n > 0 {
  913. _, err = conn.Write(buf[:n])
  914. }
  915. if err != nil {
  916. fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
  917. return
  918. }
  919. }
  920. }(conn)
  921. }
  922. }