| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243 |
- //go:build PSIPHON_ENABLE_INPROXY
- /*
- * Copyright (c) 2023, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
- package inproxy
- import (
- "bytes"
- "context"
- std_tls "crypto/tls"
- "encoding/base64"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "net/http"
- _ "net/http/pprof"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
- tls "github.com/Psiphon-Labs/psiphon-tls"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
- "golang.org/x/sync/errgroup"
- )
- func TestInproxy(t *testing.T) {
- err := runTestInproxy(false)
- if err != nil {
- t.Error(errors.Trace(err).Error())
- }
- }
- func TestInproxyMustUpgrade(t *testing.T) {
- err := runTestInproxy(true)
- if err != nil {
- t.Error(errors.Trace(err).Error())
- }
- }
- func runTestInproxy(doMustUpgrade bool) error {
- // Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
- numProxies := 5
- proxyMaxClients := 3
- numClients := 10
- bytesToSend := 1 << 20
- targetElapsedSeconds := 2
- baseAPIParameters := common.APIParameters{
- "sponsor_id": strings.ToUpper(prng.HexString(8)),
- "client_platform": "test-client-platform",
- }
- testCompartmentID, _ := MakeID()
- testCommonCompartmentIDs := []ID{testCompartmentID}
- testNetworkID := "NETWORK-ID-1"
- testNetworkType := NetworkTypeUnknown
- testNATType := NATTypeUnknown
- testSTUNServerAddress := "stun.nextcloud.com:443"
- testDisableSTUN := false
- testDisablePortMapping := false
- testNewTacticsPayload := []byte(prng.HexString(100))
- testNewTacticsTag := "new-tactics-tag"
- testUnchangedTacticsPayload := []byte(prng.HexString(100))
- currentNetworkCtx, currentNetworkCancelFunc := context.WithCancel(context.Background())
- defer currentNetworkCancelFunc()
- // TODO: test port mapping
- stunServerAddressSucceededCount := int32(0)
- stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
- stunServerAddressFailedCount := int32(0)
- stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
- roundTripperSucceededCount := int32(0)
- roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
- roundTripperFailedCount := int32(0)
- roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
- noMatch := func(RoundTripper) {}
- var receivedProxyMustUpgrade chan struct{}
- var receivedClientMustUpgrade chan struct{}
- if doMustUpgrade {
- receivedProxyMustUpgrade = make(chan struct{})
- receivedClientMustUpgrade = make(chan struct{})
- // trigger MustUpgrade
- minimumProxyProtocolVersion = LatestProtocolVersion + 1
- minimumClientProtocolVersion = LatestProtocolVersion + 1
- // Minimize test parameters for MustUpgrade case
- numProxies = 1
- proxyMaxClients = 1
- numClients = 1
- testDisableSTUN = true
- testDisablePortMapping = true
- }
- testCtx, stopTest := context.WithCancel(context.Background())
- defer stopTest()
- testGroup := new(errgroup.Group)
- // Enable test to run without requiring host firewall exceptions
- SetAllowBogonWebRTCConnections(true)
- defer SetAllowBogonWebRTCConnections(false)
- // Init logging and profiling
- logger := newTestLogger()
- pprofListener, err := net.Listen("tcp", "127.0.0.1:0")
- go http.Serve(pprofListener, nil)
- defer pprofListener.Close()
- logger.WithTrace().Info(fmt.Sprintf("PPROF: http://%s/debug/pprof", pprofListener.Addr()))
- // Start echo servers
- tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return errors.Trace(err)
- }
- defer tcpEchoListener.Close()
- go runTCPEchoServer(tcpEchoListener)
- // QUIC tests UDP proxying, and provides reliable delivery of echoed data
- quicEchoServer, err := newQuicEchoServer()
- if err != nil {
- return errors.Trace(err)
- }
- defer quicEchoServer.Close()
- go quicEchoServer.Run()
- // Create signed server entry with capability
- serverPrivateKey, err := GenerateSessionPrivateKey()
- if err != nil {
- return errors.Trace(err)
- }
- serverPublicKey, err := serverPrivateKey.GetPublicKey()
- if err != nil {
- return errors.Trace(err)
- }
- serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
- if err != nil {
- return errors.Trace(err)
- }
- serverEntry := make(protocol.ServerEntryFields)
- serverEntry["ipAddress"] = "127.0.0.1"
- _, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
- _, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
- serverEntry["inproxyOSSHPort"], _ = strconv.Atoi(tcpPort)
- serverEntry["inproxyQUICPort"], _ = strconv.Atoi(udpPort)
- serverEntry["capabilities"] = []string{"INPROXY-WEBRTC-OSSH", "INPROXY-WEBRTC-QUIC-OSSH"}
- serverEntry["inproxySessionPublicKey"] = base64.RawStdEncoding.EncodeToString(serverPublicKey[:])
- serverEntry["inproxySessionRootObfuscationSecret"] = base64.RawStdEncoding.EncodeToString(serverRootObfuscationSecret[:])
- testServerEntryTag := prng.HexString(16)
- serverEntry["tag"] = testServerEntryTag
- serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
- protocol.NewServerEntrySignatureKeyPair()
- if err != nil {
- return errors.Trace(err)
- }
- err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
- if err != nil {
- return errors.Trace(err)
- }
- packedServerEntryFields, err := protocol.EncodePackedServerEntryFields(serverEntry)
- if err != nil {
- return errors.Trace(err)
- }
- packedDestinationServerEntry, err := protocol.CBOREncoding.Marshal(packedServerEntryFields)
- if err != nil {
- return errors.Trace(err)
- }
- // API parameter handlers
- apiParameterValidator := func(params common.APIParameters) error {
- if len(params) != len(baseAPIParameters) {
- return errors.TraceNew("unexpected base API parameter count")
- }
- for name, value := range params {
- if value.(string) != baseAPIParameters[name].(string) {
- return errors.Tracef(
- "unexpected base API parameter: %v: %v != %v",
- name,
- value.(string),
- baseAPIParameters[name].(string))
- }
- }
- return nil
- }
- apiParameterLogFieldFormatter := func(
- _ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
- logFields := common.LogFields{}
- logFields.Add(common.LogFields(params))
- return logFields
- }
- // Start broker
- logger.WithTrace().Info("START BROKER")
- brokerPrivateKey, err := GenerateSessionPrivateKey()
- if err != nil {
- return errors.Trace(err)
- }
- brokerPublicKey, err := brokerPrivateKey.GetPublicKey()
- if err != nil {
- return errors.Trace(err)
- }
- brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
- if err != nil {
- return errors.Trace(err)
- }
- brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return errors.Trace(err)
- }
- defer brokerListener.Close()
- brokerConfig := &BrokerConfig{
- Logger: logger,
- CommonCompartmentIDs: testCommonCompartmentIDs,
- APIParameterValidator: apiParameterValidator,
- APIParameterLogFieldFormatter: apiParameterLogFieldFormatter,
- GetTacticsPayload: func(_ common.GeoIPData, _ common.APIParameters) ([]byte, string, error) {
- // Exercise both new and unchanged tactics
- if prng.FlipCoin() {
- return testNewTacticsPayload, testNewTacticsTag, nil
- }
- return testUnchangedTacticsPayload, "", nil
- },
- IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
- PrivateKey: brokerPrivateKey,
- ObfuscationRootSecret: brokerRootObfuscationSecret,
- ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
- AllowProxy: func(common.GeoIPData) bool { return true },
- AllowClient: func(common.GeoIPData) bool { return true },
- AllowDomainFrontedDestinations: func(common.GeoIPData) bool { return true },
- }
- broker, err := NewBroker(brokerConfig)
- if err != nil {
- return errors.Trace(err)
- }
- // Enable proxy quality (and otherwise use the default quality parameters)
- enableProxyQuality := true
- broker.SetProxyQualityParameters(
- enableProxyQuality,
- proxyQualityTTL,
- proxyQualityPendingFailedMatchDeadline,
- proxyQualityFailedMatchThreshold)
- err = broker.Start()
- if err != nil {
- return errors.Trace(err)
- }
- defer broker.Stop()
- testGroup.Go(func() error {
- err := runHTTPServer(brokerListener, broker)
- if testCtx.Err() != nil {
- return nil
- }
- return errors.Trace(err)
- })
- // Stub server broker request handler (in Psiphon, this will be the
- // destination Psiphon server; here, it's not necessary to build this
- // handler into the destination echo server)
- //
- // The stub server broker request handler also triggers a server proxy
- // quality request in the other direction.
- makeServerBrokerClientRoundTripper := func(_ SessionPublicKey) (
- RoundTripper, common.APIParameters, error) {
- return newHTTPRoundTripper(brokerListener.Addr().String(), "server"), nil, nil
- }
- serverSessionsConfig := &ServerBrokerSessionsConfig{
- Logger: logger,
- ServerPrivateKey: serverPrivateKey,
- ServerRootObfuscationSecret: serverRootObfuscationSecret,
- BrokerPublicKeys: []SessionPublicKey{brokerPublicKey},
- BrokerRootObfuscationSecrets: []ObfuscationSecret{brokerRootObfuscationSecret},
- BrokerRoundTripperMaker: makeServerBrokerClientRoundTripper,
- ProxyMetricsValidator: apiParameterValidator,
- ProxyMetricsFormatter: apiParameterLogFieldFormatter,
- ProxyMetricsPrefix: "",
- }
- serverSessions, err := NewServerBrokerSessions(serverSessionsConfig)
- if err != nil {
- return errors.Trace(err)
- }
- err = serverSessions.Start()
- if err != nil {
- return errors.Trace(err)
- }
- defer serverSessions.Stop()
- // Don't delay reporting quality.
- serverSessions.SetProxyQualityRequestParameters(
- proxyQualityReporterMaxRequestEntries,
- 0,
- proxyQualityReporterRequestTimeout,
- proxyQualityReporterRequestRetries)
- var pendingBrokerServerReportsMutex sync.Mutex
- pendingBrokerServerReports := make(map[ID]bool)
- addPendingBrokerServerReport := func(connectionID ID) {
- pendingBrokerServerReportsMutex.Lock()
- defer pendingBrokerServerReportsMutex.Unlock()
- pendingBrokerServerReports[connectionID] = true
- }
- removePendingBrokerServerReport := func(connectionID ID) {
- pendingBrokerServerReportsMutex.Lock()
- defer pendingBrokerServerReportsMutex.Unlock()
- delete(pendingBrokerServerReports, connectionID)
- }
- hasPendingBrokerServerReports := func() bool {
- pendingBrokerServerReportsMutex.Lock()
- defer pendingBrokerServerReportsMutex.Unlock()
- return len(pendingBrokerServerReports) > 0
- }
- serverQualityGroup := new(errgroup.Group)
- var serverQualityProxyIDsMutex sync.Mutex
- serverQualityProxyIDs := make(map[ID]struct{})
- testProxyASN := "65537"
- testClientASN := "65538"
- handleBrokerServerReports := func(in []byte, clientConnectionID ID) ([]byte, error) {
- handler := func(
- brokerVerifiedOriginalClientIP string,
- brokerReportedProxyID ID,
- brokerMatchedPersonalCompartments bool,
- logFields common.LogFields) {
- // Mark the report as no longer outstanding
- removePendingBrokerServerReport(clientConnectionID)
- // Trigger an asynchronous proxy quality request to the broker.
- // This roughly follows the Psiphon server functionality, where a
- // quality request is made sometime after the Psiphon handshake
- // completes, once tunnel quality thresholds are achieved.
- serverQualityGroup.Go(func() error {
- serverSessions.ReportQuality(
- brokerReportedProxyID, testProxyASN, testClientASN)
- serverQualityProxyIDsMutex.Lock()
- serverQualityProxyIDs[brokerReportedProxyID] = struct{}{}
- serverQualityProxyIDsMutex.Unlock()
- return nil
- })
- }
- out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
- return out, errors.Trace(err)
- }
- // Check that the tactics round trip succeeds
- var pendingProxyTacticsCallbacksMutex sync.Mutex
- pendingProxyTacticsCallbacks := make(map[SessionPrivateKey]bool)
- addPendingProxyTacticsCallback := func(proxyPrivateKey SessionPrivateKey) {
- pendingProxyTacticsCallbacksMutex.Lock()
- defer pendingProxyTacticsCallbacksMutex.Unlock()
- pendingProxyTacticsCallbacks[proxyPrivateKey] = true
- }
- hasPendingProxyTacticsCallbacks := func() bool {
- pendingProxyTacticsCallbacksMutex.Lock()
- defer pendingProxyTacticsCallbacksMutex.Unlock()
- return len(pendingProxyTacticsCallbacks) > 0
- }
- makeHandleTacticsPayload := func(
- proxyPrivateKey SessionPrivateKey,
- tacticsNetworkID string) func(_ string, _ []byte) bool {
- return func(networkID string, tacticsPayload []byte) bool {
- pendingProxyTacticsCallbacksMutex.Lock()
- defer pendingProxyTacticsCallbacksMutex.Unlock()
- // Check that the correct networkID is passed around; if not,
- // skip the delete, which will fail the test
- if networkID == tacticsNetworkID {
- // Certain state is reset when new tactics are applied -- the
- // return true case; exercise both cases
- if bytes.Equal(tacticsPayload, testNewTacticsPayload) {
- delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
- return true
- }
- if bytes.Equal(tacticsPayload, testUnchangedTacticsPayload) {
- delete(pendingProxyTacticsCallbacks, proxyPrivateKey)
- return false
- }
- }
- panic("unexpected tactics payload")
- }
- }
- // Start proxies
- logger.WithTrace().Info("START PROXIES")
- for i := 0; i < numProxies; i++ {
- proxyPrivateKey, err := GenerateSessionPrivateKey()
- if err != nil {
- return errors.Trace(err)
- }
- brokerCoordinator := &testBrokerDialCoordinator{
- networkID: testNetworkID,
- networkType: testNetworkType,
- brokerClientPrivateKey: proxyPrivateKey,
- brokerPublicKey: brokerPublicKey,
- brokerRootObfuscationSecret: brokerRootObfuscationSecret,
- brokerClientRoundTripper: newHTTPRoundTripper(
- brokerListener.Addr().String(), "proxy"),
- brokerClientRoundTripperSucceeded: roundTripperSucceded,
- brokerClientRoundTripperFailed: roundTripperFailed,
- // Minimize the delay before proxies reannounce after dial
- // failures, which may occur.
- announceDelay: 0,
- announceMaxBackoffDelay: 0,
- announceDelayJitter: 0.0,
- }
- webRTCCoordinator := &testWebRTCDialCoordinator{
- networkID: testNetworkID,
- networkType: testNetworkType,
- natType: testNATType,
- disableSTUN: testDisableSTUN,
- disablePortMapping: testDisablePortMapping,
- stunServerAddress: testSTUNServerAddress,
- stunServerAddressRFC5780: testSTUNServerAddress,
- stunServerAddressSucceeded: stunServerAddressSucceeded,
- stunServerAddressFailed: stunServerAddressFailed,
- setNATType: func(NATType) {},
- setPortMappingTypes: func(PortMappingTypes) {},
- bindToDevice: func(int) error { return nil },
- // Minimize the delay before proxies reannounce after failed
- // connections, which may occur.
- webRTCAwaitReadyToProxyTimeout: 5 * time.Second,
- proxyRelayInactivityTimeout: 5 * time.Second,
- }
- // Each proxy has its own broker client
- brokerClient, err := NewBrokerClient(brokerCoordinator)
- if err != nil {
- return errors.Trace(err)
- }
- tacticsNetworkID := prng.HexString(32)
- runCtx, cancelRun := context.WithCancel(testCtx)
- // No deferred cancelRun due to testGroup.Go below
- name := fmt.Sprintf("proxy-%d", i)
- proxy, err := NewProxy(&ProxyConfig{
- Logger: newTestLoggerWithComponent(name),
- WaitForNetworkConnectivity: func() bool {
- return true
- },
- GetCurrentNetworkContext: func() context.Context {
- return currentNetworkCtx
- },
- GetBrokerClient: func() (*BrokerClient, error) {
- return brokerClient, nil
- },
- GetBaseAPIParameters: func(bool) (common.APIParameters, string, error) {
- return baseAPIParameters, tacticsNetworkID, nil
- },
- MakeWebRTCDialCoordinator: func() (WebRTCDialCoordinator, error) {
- return webRTCCoordinator, nil
- },
- HandleTacticsPayload: makeHandleTacticsPayload(proxyPrivateKey, tacticsNetworkID),
- MaxClients: proxyMaxClients,
- LimitUpstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
- LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
- ActivityUpdater: func(connectingClients int32, connectedClients int32,
- bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
- fmt.Printf("[%s][%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
- time.Now().UTC().Format(time.RFC3339), name,
- connectingClients, connectedClients, bytesUp, bytesDown)
- },
- MustUpgrade: func() {
- close(receivedProxyMustUpgrade)
- cancelRun()
- },
- })
- if err != nil {
- return errors.Trace(err)
- }
- addPendingProxyTacticsCallback(proxyPrivateKey)
- testGroup.Go(func() error {
- proxy.Run(runCtx)
- return nil
- })
- }
- // Await proxy announcements before starting clients
- //
- // - Announcements may delay due to proxyAnnounceRetryDelay in Proxy.Run,
- // plus NAT discovery
- //
- // - Don't wait for > numProxies announcements due to
- // InitiatorSessions.NewRoundTrip waitToShareSession limitation
- if !doMustUpgrade {
- for {
- time.Sleep(100 * time.Millisecond)
- broker.matcher.announcementQueueMutex.Lock()
- n := broker.matcher.announcementQueue.getLen()
- broker.matcher.announcementQueueMutex.Unlock()
- if n >= numProxies {
- break
- }
- }
- }
- // Start clients
- var completedClientCount atomic.Int64
- logger.WithTrace().Info("START CLIENTS")
- clientsGroup := new(errgroup.Group)
- makeClientFunc := func(
- clientNum int,
- isTCP bool,
- brokerClient *BrokerClient,
- webRTCCoordinator WebRTCDialCoordinator) func() error {
- var networkProtocol NetworkProtocol
- var addr string
- var wrapWithQUIC bool
- if isTCP {
- networkProtocol = NetworkProtocolTCP
- addr = tcpEchoListener.Addr().String()
- } else {
- networkProtocol = NetworkProtocolUDP
- addr = quicEchoServer.Addr().String()
- wrapWithQUIC = true
- }
- return func() error {
- name := fmt.Sprintf("client-%d", clientNum)
- dialCtx, cancelDial := context.WithTimeout(testCtx, 60*time.Second)
- defer cancelDial()
- conn, err := DialClient(
- dialCtx,
- &ClientConfig{
- Logger: newTestLoggerWithComponent(name),
- BaseAPIParameters: baseAPIParameters,
- BrokerClient: brokerClient,
- WebRTCDialCoordinator: webRTCCoordinator,
- ReliableTransport: isTCP,
- DialNetworkProtocol: networkProtocol,
- DialAddress: addr,
- PackedDestinationServerEntry: packedDestinationServerEntry,
- MustUpgrade: func() {
- close(receivedClientMustUpgrade)
- cancelDial()
- },
- })
- if err != nil {
- return errors.Trace(err)
- }
- var relayConn net.Conn
- relayConn = conn
- if wrapWithQUIC {
- udpAddr, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return errors.Trace(err)
- }
- disablePathMTUDiscovery := true
- quicConn, err := quic.Dial(
- dialCtx,
- conn,
- udpAddr,
- "test",
- "QUICv1",
- nil,
- quicEchoServer.ObfuscationKey(),
- nil,
- nil,
- disablePathMTUDiscovery,
- GetQUICMaxPacketSizeAdjustment(),
- false,
- false,
- common.WrapClientSessionCache(tls.NewLRUClientSessionCache(0), ""),
- )
- if err != nil {
- return errors.Trace(err)
- }
- relayConn = quicConn
- }
- addPendingBrokerServerReport(conn.GetConnectionID())
- signalRelayComplete := make(chan struct{})
- clientsGroup.Go(func() error {
- defer close(signalRelayComplete)
- in := conn.InitialRelayPacket()
- for in != nil {
- out, err := handleBrokerServerReports(in, conn.GetConnectionID())
- if err != nil {
- if out == nil {
- return errors.Trace(err)
- } else {
- fmt.Printf("HandlePacket returned packet and error: %v\n", err)
- // Proceed with reset session token packet
- }
- }
- if out == nil {
- // Relay is complete
- break
- }
- in, err = conn.RelayPacket(testCtx, out)
- if err != nil {
- return errors.Trace(err)
- }
- }
- return nil
- })
- sendBytes := prng.Bytes(bytesToSend)
- clientsGroup.Go(func() error {
- for n := 0; n < bytesToSend; {
- m := prng.Range(1024, 32768)
- if bytesToSend-n < m {
- m = bytesToSend - n
- }
- _, err := relayConn.Write(sendBytes[n : n+m])
- if err != nil {
- return errors.Trace(err)
- }
- n += m
- }
- fmt.Printf("[%s][%s] %d bytes sent\n",
- time.Now().UTC().Format(time.RFC3339), name, bytesToSend)
- return nil
- })
- clientsGroup.Go(func() error {
- buf := make([]byte, 32768)
- n := 0
- for n < bytesToSend {
- m, err := relayConn.Read(buf)
- if err != nil {
- return errors.Trace(err)
- }
- if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
- return errors.Tracef(
- "unexpected bytes: expected at index %d, received at index %d",
- bytes.Index(sendBytes, buf[:m]), n)
- }
- n += m
- }
- completed := completedClientCount.Add(1)
- fmt.Printf("[%s][%s] %d bytes received; relay complete (%d/%d)\n",
- time.Now().UTC().Format(time.RFC3339), name,
- bytesToSend, completed, numClients)
- select {
- case <-signalRelayComplete:
- case <-testCtx.Done():
- }
- fmt.Printf("[%s][%s] closing\n",
- time.Now().UTC().Format(time.RFC3339), name)
- relayConn.Close()
- conn.Close()
- return nil
- })
- return nil
- }
- }
- newClientBrokerClient := func(
- disableWaitToShareSession bool) (*BrokerClient, error) {
- clientPrivateKey, err := GenerateSessionPrivateKey()
- if err != nil {
- return nil, errors.Trace(err)
- }
- brokerCoordinator := &testBrokerDialCoordinator{
- networkID: testNetworkID,
- networkType: testNetworkType,
- commonCompartmentIDs: testCommonCompartmentIDs,
- disableWaitToShareSession: disableWaitToShareSession,
- brokerClientPrivateKey: clientPrivateKey,
- brokerPublicKey: brokerPublicKey,
- brokerRootObfuscationSecret: brokerRootObfuscationSecret,
- brokerClientRoundTripper: newHTTPRoundTripper(
- brokerListener.Addr().String(), "client"),
- brokerClientRoundTripperSucceeded: roundTripperSucceded,
- brokerClientRoundTripperFailed: roundTripperFailed,
- brokerClientNoMatch: noMatch,
- }
- brokerClient, err := NewBrokerClient(brokerCoordinator)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return brokerClient, nil
- }
- newClientWebRTCDialCoordinator := func(
- isMobile bool,
- useMediaStreams bool) (*testWebRTCDialCoordinator, error) {
- clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
- if err != nil {
- return nil, errors.Trace(err)
- }
- var trafficShapingParameters *TrafficShapingParameters
- if useMediaStreams {
- trafficShapingParameters = &TrafficShapingParameters{
- MinPaddedMessages: 0,
- MaxPaddedMessages: 10,
- MinPaddingSize: 0,
- MaxPaddingSize: 254,
- MinDecoyMessages: 0,
- MaxDecoyMessages: 10,
- MinDecoySize: 1,
- MaxDecoySize: 1200,
- DecoyMessageProbability: 0.5,
- }
- } else {
- trafficShapingParameters = &TrafficShapingParameters{
- MinPaddedMessages: 0,
- MaxPaddedMessages: 10,
- MinPaddingSize: 0,
- MaxPaddingSize: 1500,
- MinDecoyMessages: 0,
- MaxDecoyMessages: 10,
- MinDecoySize: 1,
- MaxDecoySize: 1500,
- DecoyMessageProbability: 0.5,
- }
- }
- webRTCCoordinator := &testWebRTCDialCoordinator{
- networkID: testNetworkID,
- networkType: testNetworkType,
- natType: testNATType,
- disableSTUN: testDisableSTUN,
- stunServerAddress: testSTUNServerAddress,
- stunServerAddressRFC5780: testSTUNServerAddress,
- stunServerAddressSucceeded: stunServerAddressSucceeded,
- stunServerAddressFailed: stunServerAddressFailed,
- clientRootObfuscationSecret: clientRootObfuscationSecret,
- doDTLSRandomization: prng.FlipCoin(),
- useMediaStreams: useMediaStreams,
- trafficShapingParameters: trafficShapingParameters,
- setNATType: func(NATType) {},
- setPortMappingTypes: func(PortMappingTypes) {},
- bindToDevice: func(int) error { return nil },
- // With STUN enabled (testDisableSTUN = false), there are cases
- // where the WebRTC peer connection is not successfully
- // established. With a short enough timeout here, clients will
- // redial and eventually succceed.
- webRTCAwaitReadyToProxyTimeout: 5 * time.Second,
- }
- if isMobile {
- webRTCCoordinator.networkType = NetworkTypeMobile
- webRTCCoordinator.disableInboundForMobileNetworks = true
- }
- return webRTCCoordinator, nil
- }
- sharedBrokerClient, err := newClientBrokerClient(false)
- if err != nil {
- return errors.Trace(err)
- }
- sharedBrokerClientDisableWait, err := newClientBrokerClient(true)
- if err != nil {
- return errors.Trace(err)
- }
- for i := 0; i < numClients; i++ {
- // Test a mix of TCP and UDP proxying; also test the
- // DisableInboundForMobileNetworks code path.
- isTCP := i%2 == 0
- isMobile := i%4 == 0
- useMediaStreams := i%4 < 2
- // Exercise BrokerClients shared by multiple clients, but also create
- // several broker clients.
- var brokerClient *BrokerClient
- switch i % 3 {
- case 0:
- brokerClient = sharedBrokerClient
- case 1:
- brokerClient = sharedBrokerClientDisableWait
- case 2:
- brokerClient, err = newClientBrokerClient(true)
- if err != nil {
- return errors.Trace(err)
- }
- }
- webRTCCoordinator, err := newClientWebRTCDialCoordinator(
- isMobile, useMediaStreams)
- if err != nil {
- return errors.Trace(err)
- }
- clientsGroup.Go(
- makeClientFunc(
- i,
- isTCP,
- brokerClient,
- webRTCCoordinator))
- }
- if doMustUpgrade {
- // Await MustUpgrade callbacks
- logger.WithTrace().Info("AWAIT MUST UPGRADE")
- <-receivedProxyMustUpgrade
- <-receivedClientMustUpgrade
- _ = clientsGroup.Wait()
- } else {
- // Await client transfers complete
- logger.WithTrace().Info("AWAIT DATA TRANSFER")
- err = clientsGroup.Wait()
- if err != nil {
- return errors.Trace(err)
- }
- logger.WithTrace().Info("DONE DATA TRANSFER")
- if hasPendingBrokerServerReports() {
- return errors.TraceNew("unexpected pending broker server requests")
- }
- if hasPendingProxyTacticsCallbacks() {
- return errors.TraceNew("unexpected pending proxy tactics callback")
- }
- err = serverQualityGroup.Wait()
- if err != nil {
- return errors.Trace(err)
- }
- // Inspect the broker's proxy quality state, to verify that the proxy
- // quality request was processed.
- //
- // Limitation: currently we don't check the priority
- // announcement _queue_, as announcements may have arrived before the
- // quality request, and announcements are promoted between queues.
- serverQualityProxyIDsMutex.Lock()
- defer serverQualityProxyIDsMutex.Unlock()
- for proxyID := range serverQualityProxyIDs {
- if !broker.proxyQualityState.HasQuality(proxyID, testProxyASN, "") {
- return errors.TraceNew("unexpected missing HasQuality (no client ASN)")
- }
- if !broker.proxyQualityState.HasQuality(proxyID, testProxyASN, testClientASN) {
- return errors.TraceNew("unexpected missing HasQuality (with client ASN)")
- }
- }
- // TODO: check that elapsed time is consistent with rate limit (+/-)
- // Check if STUN server replay callbacks were triggered
- if !testDisableSTUN {
- if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
- return errors.TraceNew("unexpected STUN server succeeded count")
- }
- // Allow for some STUN server failures
- if atomic.LoadInt32(&stunServerAddressFailedCount) >= int32(numProxies/2) {
- return errors.TraceNew("unexpected STUN server failed count")
- }
- }
- // Check if RoundTripper server replay callbacks were triggered
- if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
- return errors.TraceNew("unexpected round tripper succeeded count")
- }
- if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
- return errors.TraceNew("unexpected round tripper failed count")
- }
- }
- // Await shutdowns
- stopTest()
- brokerListener.Close()
- err = testGroup.Wait()
- if err != nil {
- return errors.Trace(err)
- }
- return nil
- }
- func runHTTPServer(listener net.Listener, broker *Broker) error {
- handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // For this test, clients set the path to "/client" and proxies
- // set the path to "/proxy" and we use that to create stub GeoIP
- // data to pass the not-same-ASN condition.
- var geoIPData common.GeoIPData
- geoIPData.ASN = r.URL.Path
- requestPayload, err := ioutil.ReadAll(
- http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
- if err != nil {
- fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
- http.Error(w, "", http.StatusNotFound)
- return
- }
- clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
- extendTimeout := func(timeout time.Duration) {
- // TODO: set insufficient initial timeout, so extension is
- // required for success
- http.NewResponseController(w).SetWriteDeadline(time.Now().Add(timeout))
- }
- responsePayload, err := broker.HandleSessionPacket(
- r.Context(),
- extendTimeout,
- nil,
- clientIP,
- geoIPData,
- requestPayload)
- if err != nil {
- fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
- http.Error(w, "", http.StatusNotFound)
- return
- }
- w.WriteHeader(http.StatusOK)
- w.Write(responsePayload)
- })
- // WriteTimeout will be extended via extendTimeout.
- httpServer := &http.Server{
- ReadTimeout: 10 * time.Second,
- WriteTimeout: 10 * time.Second,
- IdleTimeout: 1 * time.Minute,
- Handler: handler,
- }
- certificate, privateKey, _, err := common.GenerateWebServerCertificate("www.example.com")
- if err != nil {
- return errors.Trace(err)
- }
- tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
- if err != nil {
- return errors.Trace(err)
- }
- tlsConfig := &tls.Config{
- Certificates: []tls.Certificate{tlsCert},
- }
- err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
- return errors.Trace(err)
- }
// httpRoundTripper is a test round tripper that POSTs request payloads to
// an HTTPS endpoint.
type httpRoundTripper struct {
	// httpClient issues the HTTPS requests.
	httpClient *http.Client

	// endpointAddr and path form the request URL,
	// https://<endpointAddr>/<path>.
	endpointAddr, path string
}
- func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
- return &httpRoundTripper{
- httpClient: &http.Client{
- Transport: &http.Transport{
- ForceAttemptHTTP2: true,
- MaxIdleConns: 2,
- IdleConnTimeout: 1 * time.Minute,
- TLSHandshakeTimeout: 10 * time.Second,
- TLSClientConfig: &std_tls.Config{
- InsecureSkipVerify: true,
- },
- },
- },
- endpointAddr: endpointAddr,
- path: path,
- }
- }
- func (r *httpRoundTripper) RoundTrip(
- ctx context.Context,
- roundTripDelay time.Duration,
- roundTripTimeout time.Duration,
- requestPayload []byte) ([]byte, error) {
- if roundTripDelay > 0 {
- common.SleepWithContext(ctx, roundTripDelay)
- }
- requestCtx, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
- defer requestCancelFunc()
- url := fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path)
- request, err := http.NewRequestWithContext(
- requestCtx, "POST", url, bytes.NewReader(requestPayload))
- if err != nil {
- return nil, errors.Trace(err)
- }
- response, err := r.httpClient.Do(request)
- if err != nil {
- return nil, errors.Trace(err)
- }
- defer response.Body.Close()
- if response.StatusCode != http.StatusOK {
- return nil, errors.Tracef("unexpected response status code: %d", response.StatusCode)
- }
- responsePayload, err := io.ReadAll(response.Body)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return responsePayload, nil
- }
- func (r *httpRoundTripper) Close() error {
- r.httpClient.CloseIdleConnections()
- return nil
- }
- func runTCPEchoServer(listener net.Listener) {
- for {
- conn, err := listener.Accept()
- if err != nil {
- fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
- return
- }
- go func(conn net.Conn) {
- buf := make([]byte, 32768)
- for {
- n, err := conn.Read(buf)
- if n > 0 {
- _, err = conn.Write(buf[:n])
- }
- if err != nil {
- fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
- return
- }
- }
- }(conn)
- }
- }
// quicEchoServer is a test echo server that runs on a QUIC listener.
type quicEchoServer struct {
	// listener is the listener that Run accepts connections from.
	listener net.Listener

	// obfuscationKey is the key the listener was created with.
	obfuscationKey string
}
- func newQuicEchoServer() (*quicEchoServer, error) {
- obfuscationKey := prng.HexString(32)
- listener, err := quic.Listen(
- nil,
- nil,
- "127.0.0.1:0",
- true,
- GetQUICMaxPacketSizeAdjustment(),
- obfuscationKey,
- false)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return &quicEchoServer{
- listener: listener,
- obfuscationKey: obfuscationKey,
- }, nil
- }
- func (q *quicEchoServer) ObfuscationKey() string {
- return q.obfuscationKey
- }
- func (q *quicEchoServer) Close() error {
- return q.listener.Close()
- }
- func (q *quicEchoServer) Addr() net.Addr {
- return q.listener.Addr()
- }
- func (q *quicEchoServer) Run() {
- for {
- conn, err := q.listener.Accept()
- if err != nil {
- fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
- return
- }
- go func(conn net.Conn) {
- buf := make([]byte, 32768)
- for {
- n, err := conn.Read(buf)
- if n > 0 {
- _, err = conn.Write(buf[:n])
- }
- if err != nil {
- fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
- return
- }
- }
- }(conn)
- }
- }
|