瀏覽代碼

Send client-tunneled status to DSL backend

Rod Hynes 5 月之前
父節點
當前提交
aeaab09800

+ 1 - 0
psiphon/common/dsl/api.go

@@ -182,6 +182,7 @@ const MaxRelayPayloadSize = 65536
 const (
 	PsiphonClientIPHeader        = "X-Psiphon-Client-Ip"
 	PsiphonClientGeoIPDataHeader = "X-Psiphon-Client-Geoipdata"
+	PsiphonClientTunneledHeader  = "X-Psiphon-Client-Tunneled"
 	PsiphonHostIDHeader          = "X-Psiphon-Host-Id"
 
 	requestVersion                   = 1

+ 22 - 16
psiphon/common/dsl/dsl_test.go

@@ -59,7 +59,7 @@ type testConfig struct {
 	interruptDownloads bool
 	enableRetries      bool
 	repeatBeforeTTL    bool
-	isConnected        bool
+	isTunneled         bool
 	expectFailure      bool
 	cacheServerEntries bool
 }
@@ -105,9 +105,9 @@ func TestDSLs(t *testing.T) {
 			alreadyDiscovered: true,
 		},
 		{
-			name: "first request is-connected",
+			name: "first request is-tunneled",
 
-			isConnected: true,
+			isTunneled: true,
 		},
 		{
 			name: "cache server entries",
@@ -227,8 +227,8 @@ func testDSLs(testConfig *testConfig) error {
 	if testConfig.enableRetries {
 		retryCount = 20
 	}
-	isConnected := testConfig.isConnected
-	if isConnected {
+	isTunneled := testConfig.isTunneled
+	if isTunneled {
 		discoverCount = 1
 	}
 
@@ -247,6 +247,7 @@ func testDSLs(testConfig *testConfig) error {
 			nil,
 			testClientIP,
 			testClientGeoIPData,
+			false,
 			requestPayload)
 		if err != nil {
 			return GetRelayGenericErrorResponse(), errors.Trace(err)
@@ -306,7 +307,7 @@ func testDSLs(testConfig *testConfig) error {
 	ctx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
 	defer cancelFunc()
 
-	err = fetcher.Run(ctx, isConnected)
+	err = fetcher.Run(ctx)
 	if testConfig.expectFailure && err == nil {
 		err = errors.TraceNew("unexpected success")
 	}
@@ -326,13 +327,13 @@ func testDSLs(testConfig *testConfig) error {
 			return nil, errors.TraceNew("round trip not permitted")
 		}
 
-		err = fetcher.Run(ctx, isConnected)
+		err = fetcher.Run(ctx)
 		if err != nil {
 			return errors.Trace(err)
 		}
 	}
 
-	if testConfig.alreadyDiscovered && testConfig.isConnected {
+	if testConfig.alreadyDiscovered && testConfig.isTunneled {
 		return errors.TraceNew("invalid test configuration")
 	}
 
@@ -347,27 +348,32 @@ func testDSLs(testConfig *testConfig) error {
 		dslClient.lastDiscoverTime = time.Time{}
 		dslClient.lastActiveOSLsTime = time.Time{}
 
-		err = fetcher.Run(ctx, isConnected)
+		err = fetcher.Run(ctx)
 		if err != nil {
 			return errors.Trace(err)
 		}
 	}
 
-	if testConfig.isConnected {
+	if testConfig.isTunneled {
 
-		// If the first request was isConnected, only one server entry will
-		// have been fetched and the last discover time TTL should not be
-		// set. Do another full fetch, and the
+		if dslClient.serverEntryStoreCount != 1 {
+			return errors.Tracef(
+				"unexpected server entry store count: %d", dslClient.serverEntryStoreCount)
+		}
+
+		// If the first request was isTunneled, only one server entry will
+		// have been fetched. Do another full fetch, and the following
 		// dslClient.serverEntryStoreCount check will demonstrate that all
 		// remaining server entries were downloaded and stored.
 
+		dslClient.lastDiscoverTime = time.Time{}
+
 		discoverCount = 128
-		isConnected = false
 
 		fetcherConfig.DiscoverServerEntriesMinCount = discoverCount
 		fetcherConfig.DiscoverServerEntriesMaxCount = discoverCount
 
-		err = fetcher.Run(ctx, isConnected)
+		err = fetcher.Run(ctx)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -396,7 +402,7 @@ func testDSLs(testConfig *testConfig) error {
 
 		backend.oslPaveData = backendOSLPaveData2
 
-		err = fetcher.Run(ctx, isConnected)
+		err = fetcher.Run(ctx)
 		if err != nil {
 			return errors.Trace(err)
 		}

+ 30 - 32
psiphon/common/dsl/fetcher.go

@@ -81,6 +81,11 @@ type FetcherConfig struct {
 	GetOSLFileSpecsMinCount       int
 	GetOSLFileSpecsMaxCount       int
 
+	// WaitForNetworkConnectivity is an optional  callback that should block
+	// until there is network connectivity or shutdown. The return value is
+	// true when there is network connectivity, and false for shutdown.
+	WaitForNetworkConnectivity func() bool
+
 	DoGarbageCollection func()
 }
 
@@ -128,15 +133,6 @@ func NewFetcher(config *FetcherConfig) (*Fetcher, error) {
 // Run performs a server entry discovery/download and OSL synchronization
 // sequence.
 //
-// Run supports two modes:
-//   - Frequent, intended for fetching via an established SSH tunnel, and
-//     discovering only a small number of servers. Frequent fetches can be
-//     repeated often.
-//   - Non-frequent, intended for fetching via an untunneled relay, and invoked
-//     after a client is unable to connect with its known servers. This fetch
-//     mode is intended for discovering a larger number of servers, and is
-//     subject to the DiscoverServerEntriesTTL, which skips repeated runs.
-//
 // Each Run may make incremental progress. New OSL state or new server entries
 // may be downloaded and persisted even when Run ultimately fails and returns
 // an error.
@@ -172,18 +168,15 @@ func NewFetcher(config *FetcherConfig) (*Fetcher, error) {
 //     This requirement means that if there's an ongoing untunneled fetcher run
 //     and a tunnel is established, any post-connected, frequent fetcher run
 //     must be skipped or postponed.
-func (f *Fetcher) Run(ctx context.Context, isFrequent bool) error {
-
-	if !isFrequent {
+func (f *Fetcher) Run(ctx context.Context) error {
 
-		lastTime, err := f.config.DatastoreGetLastDiscoverTime()
-		if err != nil {
-			return errors.Trace(err)
-		}
+	lastTime, err := f.config.DatastoreGetLastDiscoverTime()
+	if err != nil {
+		return errors.Trace(err)
+	}
 
-		if time.Now().Before(lastTime.Add(f.config.DiscoverServerEntriesTTL)) {
-			return nil
-		}
+	if time.Now().Before(lastTime.Add(f.config.DiscoverServerEntriesTTL)) {
+		return nil
 	}
 
 	// processOSLs will:
@@ -288,21 +281,19 @@ func (f *Fetcher) Run(ctx context.Context, isFrequent bool) error {
 		f.config.DoGarbageCollection()
 	}
 
-	if !isFrequent {
-		err = f.config.DatastoreSetLastDiscoverTime(time.Now())
-		if err != nil {
-			err = errors.Trace(err)
+	err = f.config.DatastoreSetLastDiscoverTime(time.Now())
+	if err != nil {
+		err = errors.Trace(err)
 
-			// Signal a fatal datastore error. The caller should not run any
-			// Fetcher again, for the duration of its process, since the
-			// LastDiscoverTime mechanism won't prevent excess repeats.
+		// Signal a fatal datastore error. The caller should not run any
+		// Fetcher again, for the duration of its process, since the
+		// LastDiscoverTime mechanism won't prevent excess repeats.
 
-			f.config.DatastoreFatalError(err)
-			f.config.Logger.WithTraceFields(common.LogFields{
-				"error": err.Error(),
-			}).Warning("DSL: datastore failed")
-			// Proceed with this one run
-		}
+		f.config.DatastoreFatalError(err)
+		f.config.Logger.WithTraceFields(common.LogFields{
+			"error": err.Error(),
+		}).Warning("DSL: datastore failed")
+		// Proceed with this one run
 	}
 
 	if oslErr != nil {
@@ -709,6 +700,13 @@ func (f *Fetcher) doRelayedRequest(
 	request any,
 	response any) (retRetry bool, retErr error) {
 
+	// Delay attempt to fetch while there is no network connectivity.
+
+	if f.config.WaitForNetworkConnectivity != nil &&
+		!f.config.WaitForNetworkConnectivity() {
+		return false, errors.TraceNew("shutdown")
+	}
+
 	// Add the relay wrapping.
 
 	cborRequest, err := protocol.CBOREncoding.Marshal(request)

+ 8 - 0
psiphon/common/dsl/relay.go

@@ -212,6 +212,8 @@ func (r *Relay) SetCacheParameters(
 // expected maximum request timeout, including retries; this callback may be
 // used to customize the response timeout for a transport handler.
 //
+// Set isClientTunneled when the relay uses a connected Psiphon tunnel.
+//
 // In the case of an error, the caller must log the error and send
 // dsl.GenericErrorResponse to the client. This generic error response
 // ensures that the client receives a DSL response and doesn't consider the
@@ -221,6 +223,7 @@ func (r *Relay) HandleRequest(
 	extendTimeout func(time.Duration),
 	clientIP string,
 	clientGeoIPData common.GeoIPData,
+	isClientTunneled bool,
 	cborRelayedRequest []byte) ([]byte, error) {
 
 	r.mutex.Lock()
@@ -339,6 +342,11 @@ func (r *Relay) HandleRequest(
 		}
 		httpRequest.Header.Set(PsiphonClientIPHeader, clientIP)
 		httpRequest.Header.Set(PsiphonClientGeoIPDataHeader, string(jsonGeoIPData))
+		if isClientTunneled {
+			httpRequest.Header.Set(PsiphonClientTunneledHeader, "true")
+		} else {
+			httpRequest.Header.Set(PsiphonClientTunneledHeader, "false")
+		}
 		httpRequest.Header.Set(PsiphonHostIDHeader, r.config.HostID)
 
 		startTime := time.Now()

+ 1 - 0
psiphon/server/api.go

@@ -1024,6 +1024,7 @@ func dslAPIRequestHandler(
 		nil, // no extendTimeout
 		sshClient.getClientIP(),
 		common.GeoIPData(sshClient.getClientGeoIPData()),
+		true, // client request is tunneled
 		requestPayload)
 	return responsePayload, errors.Trace(err)
 }

+ 2 - 0
psiphon/server/dsl.go

@@ -68,6 +68,7 @@ func dslHandleRequest(
 	extendTimeout func(time.Duration),
 	clientIP string,
 	clientGeoIPData common.GeoIPData,
+	isClientTunneled bool,
 	requestPayload []byte) ([]byte, error) {
 
 	relay := support.dslRelay
@@ -82,6 +83,7 @@ func dslHandleRequest(
 		extendTimeout,
 		clientIP,
 		clientGeoIPData,
+		isClientTunneled,
 		requestPayload)
 	if err != nil {
 		return dsl.GetRelayGenericErrorResponse(),

+ 1 - 0
psiphon/server/meek.go

@@ -2081,6 +2081,7 @@ func (server *MeekServer) inproxyBrokerRelayDSLRequest(
 		extendTimeout,
 		clientIP,
 		clientGeoIPData,
+		false, // client request is untunneled
 		requestPayload)
 	return responsePayload, errors.Trace(err)
 }