Browse Source

Merge branch 'master' into staging-client

Rod Hynes 1 year ago
parent
commit
e0c0a13e7d

+ 3 - 5
MobileLibrary/Android/PsiphonTunnel/PsiphonTunnel.java

@@ -207,11 +207,9 @@ public class PsiphonTunnel {
         mNetworkMonitor = new NetworkMonitor(new NetworkMonitor.NetworkChangeListener() {
             @Override
             public void onChanged() {
-                try {
-                    reconnectPsiphon();
-                } catch (Exception e) {
-                    mHostService.onDiagnosticMessage("reconnect error: " + e);
-                }
+                // networkChanged initiates a reset of all open network
+                // connections, including a tunnel reconnect.
+                Psi.networkChanged();
             }
         });
     }

+ 7 - 4
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.m

@@ -1552,11 +1552,14 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
 
         previousNetworkStatus = atomic_exchange(&self->currentNetworkStatus, networkStatus);
 
-        // Restart if the network status or interface has changed, unless the previous status was
-        // NetworkReachabilityNotReachable, because the tunnel should be waiting for connectivity in
-        // that case.
+        // Signal when the network status or interface has changed, unless the
+        // previous status was NetworkReachabilityNotReachable, because the
+        // tunnel should be waiting for connectivity in that case.
+        //
+        // GoPsiNetworkChanged initiates a reset of all open network
+        // connections, including a tunnel reconnect.
         if ((networkStatus != previousNetworkStatus || interfaceChanged) && previousNetworkStatus != NetworkReachabilityNotReachable) {
-            GoPsiReconnectTunnel();
+            GoPsiNetworkChanged();
         }
     }
 }

+ 12 - 0
MobileLibrary/psi/psi.go

@@ -280,6 +280,18 @@ func ReconnectTunnel() {
 	}
 }
 
+// NetworkChanged initiates a reset of all open network connections, including
+// a tunnel reconnect.
+func NetworkChanged() {
+
+	controllerMutex.Lock()
+	defer controllerMutex.Unlock()
+
+	if controller != nil {
+		controller.NetworkChanged()
+	}
+}
+
 // SetDynamicConfig overrides the sponsor ID and authorizations fields set in
 // the config passed to Start. SetDynamicConfig has no effect if no Controller
 // is started.

+ 7 - 0
psiphon/common/inproxy/inproxy_test.go

@@ -91,6 +91,9 @@ func runTestInproxy(doMustUpgrade bool) error {
 	testNewTacticsTag := "new-tactics-tag"
 	testUnchangedTacticsPayload := []byte(prng.HexString(100))
 
+	currentNetworkCtx, currentNetworkCancelFunc := context.WithCancel(context.Background())
+	defer currentNetworkCancelFunc()
+
 	// TODO: test port mapping
 
 	stunServerAddressSucceededCount := int32(0)
@@ -438,6 +441,10 @@ func runTestInproxy(doMustUpgrade bool) error {
 				return true
 			},
 
+			GetCurrentNetworkContext: func() context.Context {
+				return currentNetworkCtx
+			},
+
 			GetBrokerClient: func() (*BrokerClient, error) {
 				return brokerClient, nil
 			},

+ 16 - 0
psiphon/common/inproxy/proxy.go

@@ -87,6 +87,14 @@ type ProxyConfig struct {
 	// there is network connectivity, and false for shutdown.
 	WaitForNetworkConnectivity func() bool
 
+	// GetCurrentNetworkContext is a callback that returns a context tied to
+	// the lifetime of the host's current active network interface. If the
+	// active network changes, the previous context returned by
+	// GetCurrentNetworkContext should cancel. This context is used to
+	// immediately cancel/close individual connections when the active
+	// network changes.
+	GetCurrentNetworkContext func() context.Context
+
 	// GetBrokerClient provides a BrokerClient which the proxy will use for
 	// making broker requests. If GetBrokerClient returns a shared
 	// BrokerClient instance, the BrokerClient must support multiple,
@@ -510,6 +518,14 @@ func (p *Proxy) proxyOneClient(
 	logAnnounce func() bool,
 	signalAnnounceDone func()) (bool, error) {
 
+	// Cancel/close this connection immediately if the network changes.
+	if p.config.GetCurrentNetworkContext != nil {
+		var cancelFunc context.CancelFunc
+		ctx, cancelFunc = common.MergeContextCancel(
+			ctx, p.config.GetCurrentNetworkContext())
+		defer cancelFunc()
+	}
+
 	// Do not trigger back-off unless the proxy successfully announces and
 	// only then performs poorly.
 	//

+ 37 - 1
psiphon/common/parameters/parameters.go

@@ -315,6 +315,14 @@ const (
 	HoldOffTunnelProtocols                             = "HoldOffTunnelProtocols"
 	HoldOffTunnelFrontingProviderIDs                   = "HoldOffTunnelFrontingProviderIDs"
 	HoldOffTunnelProbability                           = "HoldOffTunnelProbability"
+	HoldOffTunnelProtocolMinDuration                   = "HoldOffTunnelProtocolMinDuration"
+	HoldOffTunnelProtocolMaxDuration                   = "HoldOffTunnelProtocolMaxDuration"
+	HoldOffTunnelProtocolNames                         = "HoldOffTunnelProtocolNames"
+	HoldOffTunnelProtocolProbability                   = "HoldOffTunnelProtocolProbability"
+	HoldOffFrontingTunnelMinDuration                   = "HoldOffFrontingTunnelMinDuration"
+	HoldOffFrontingTunnelMaxDuration                   = "HoldOffFrontingTunnelMaxDuration"
+	HoldOffFrontingTunnelProviderIDs                   = "HoldOffFrontingTunnelProviderIDs"
+	HoldOffFrontingTunnelProbability                   = "HoldOffFrontingTunnelProbability"
 	RestrictFrontingProviderIDs                        = "RestrictFrontingProviderIDs"
 	RestrictFrontingProviderIDsServerProbability       = "RestrictFrontingProviderIDsServerProbability"
 	RestrictFrontingProviderIDsClientProbability       = "RestrictFrontingProviderIDsClientProbability"
@@ -325,8 +333,16 @@ const (
 	RestrictDirectProviderRegions                      = "RestrictDirectProviderRegions"
 	RestrictDirectProviderIDsServerProbability         = "RestrictDirectProviderIDsServerProbability"
 	RestrictDirectProviderIDsClientProbability         = "RestrictDirectProviderIDsClientProbability"
+	HoldOffInproxyTunnelMinDuration                    = "HoldOffInproxyTunnelMinDuration"
+	HoldOffInproxyTunnelMaxDuration                    = "HoldOffInproxyTunnelMaxDuration"
+	HoldOffInproxyTunnelProviderRegions                = "HoldOffInproxyTunnelProviderRegions"
+	HoldOffInproxyTunnelProbability                    = "HoldOffInproxyTunnelProbability"
+	RestrictInproxyProviderRegions                     = "RestrictInproxyProviderRegions"
+	RestrictInproxyProviderIDsServerProbability        = "RestrictInproxyProviderIDsServerProbability"
+	RestrictInproxyProviderIDsClientProbability        = "RestrictInproxyProviderIDsClientProbability"
 	UpstreamProxyAllowAllServerEntrySources            = "UpstreamProxyAllowAllServerEntrySources"
 	DestinationBytesMetricsASN                         = "DestinationBytesMetricsASN"
+	DestinationBytesMetricsASNs                        = "DestinationBytesMetricsASNs"
 	DNSResolverAttemptsPerServer                       = "DNSResolverAttemptsPerServer"
 	DNSResolverAttemptsPerPreferredServer              = "DNSResolverAttemptsPerPreferredServer"
 	DNSResolverRequestTimeout                          = "DNSResolverRequestTimeout"
@@ -811,6 +827,16 @@ var defaultParameters = map[string]struct {
 	HoldOffTunnelFrontingProviderIDs: {value: []string{}},
 	HoldOffTunnelProbability:         {value: 0.0, minimum: 0.0},
 
+	HoldOffTunnelProtocolMinDuration: {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffTunnelProtocolMaxDuration: {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffTunnelProtocolNames:       {value: protocol.TunnelProtocols{}},
+	HoldOffTunnelProtocolProbability: {value: 0.0, minimum: 0.0},
+
+	HoldOffFrontingTunnelMinDuration: {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffFrontingTunnelMaxDuration: {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffFrontingTunnelProviderIDs: {value: []string{}},
+	HoldOffFrontingTunnelProbability: {value: 0.0, minimum: 0.0},
+
 	RestrictFrontingProviderIDs:                  {value: []string{}},
 	RestrictFrontingProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
@@ -824,9 +850,19 @@ var defaultParameters = map[string]struct {
 	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 
+	HoldOffInproxyTunnelMinDuration:     {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffInproxyTunnelMaxDuration:     {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffInproxyTunnelProviderRegions: {value: KeyStrings{}},
+	HoldOffInproxyTunnelProbability:     {value: 0.0, minimum: 0.0},
+
+	RestrictInproxyProviderRegions:              {value: KeyStrings{}},
+	RestrictInproxyProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
+	RestrictInproxyProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
+
 	UpstreamProxyAllowAllServerEntrySources: {value: false},
 
-	DestinationBytesMetricsASN: {value: "", flags: serverSideOnly},
+	DestinationBytesMetricsASN:  {value: "", flags: serverSideOnly},
+	DestinationBytesMetricsASNs: {value: []string{}, flags: serverSideOnly},
 
 	DNSResolverAttemptsPerServer:                {value: 2, minimum: 1},
 	DNSResolverAttemptsPerPreferredServer:       {value: 1, minimum: 1},

+ 23 - 7
psiphon/common/quic/quic.go

@@ -449,6 +449,8 @@ func Dial(
 		// isObfuscated QUIC versions. This mitigates upstream fingerprints;
 		// see ObfuscatedPacketConn.writePacket for the server-side
 		// downstream limitation.
+		//
+		// Update: quic-go now writes ECN bits; see quic-go PR 3999.
 
 		// Ensure blocked packet writes eventually timeout. Note that quic-go
 		// manages read deadlines; we set only the write deadline here.
@@ -940,16 +942,30 @@ func (t *QUICTransporter) dialQUIC() (retConnection quicConnection, retErr error
 		return nil, errors.Trace(err)
 	}
 
-	// Check for a *net.UDPConn, as expected, to support OOB operations.
+	// See `udpConn, ok := packetConn.(*net.UDPConn)` block and comment in
+	// Dial. The same two cases are implemented here, although there is no
+	// obfuscated fronted QUIC.
+	//
+	// Limitation: for FRONTED-MEEK-QUIC-OSSH, OOB operations to support
+	// reading/writing ECN bits will not be enabled due to the
+	// meekUnderlyingPacketConn wrapping in the provided udpDialer.
+
 	udpConn, ok := packetConn.(*net.UDPConn)
+
 	if !ok {
-		return nil, errors.Tracef("unexpected packetConn type: %T", packetConn)
-	}
 
-	// Ensure blocked packet writes eventually timeout. Note that quic-go
-	// manages read deadlines; we set only the write deadline here.
-	packetConn = &common.WriteTimeoutUDPConn{
-		UDPConn: udpConn,
+		// Ensure blocked packet writes eventually timeout. Note that quic-go
+		// manages read deadlines; we set only the write deadline here.
+		packetConn = &common.WriteTimeoutPacketConn{
+			PacketConn: packetConn,
+		}
+
+	} else {
+
+		// Ensure blocked packet writes eventually timeout.
+		packetConn = &common.WriteTimeoutUDPConn{
+			UDPConn: udpConn,
+		}
 	}
 
 	connection, err := dialQUIC(

+ 16 - 0
psiphon/common/utils.go

@@ -276,3 +276,19 @@ func MergeContextCancel(ctx, cancelCtx context.Context) (context.Context, contex
 		cancel(context.Canceled)
 	}
 }
+
+// MaxDuration returns the maximum duration in durations or 0 if durations is
+// empty.
+func MaxDuration(durations ...time.Duration) time.Duration {
+	if len(durations) == 0 {
+		return 0
+	}
+
+	max := durations[0]
+	for _, d := range durations[1:] {
+		if d > max {
+			max = d
+		}
+	}
+	return max
+}

+ 135 - 34
psiphon/config.go

@@ -872,13 +872,19 @@ type Config struct {
 	ConjureSTUNServerAddresses                []string
 	ConjureDTLSEmptyInitialPacketProbability  *float64
 
-	// HoldOffTunnelMinDurationMilliseconds and other HoldOffTunnel fields are
-	// for testing purposes.
-	HoldOffTunnelMinDurationMilliseconds *int
-	HoldOffTunnelMaxDurationMilliseconds *int
-	HoldOffTunnelProtocols               []string
-	HoldOffTunnelFrontingProviderIDs     []string
-	HoldOffTunnelProbability             *float64
+	// HoldOffTunnelProtocolMinDurationMilliseconds and other
+	// HoldOffTunnelProtocol fields are for testing purposes.
+	HoldOffTunnelProtocolMinDurationMilliseconds *int
+	HoldOffTunnelProtocolMaxDurationMilliseconds *int
+	HoldOffTunnelProtocolNames                   []string
+	HoldOffTunnelProtocolProbability             *float64
+
+	// HoldOffFrontingTunnelMinDurationMilliseconds and other
+	// HoldOffFrontingTunnel fields are for testing purposes.
+	HoldOffFrontingTunnelMinDurationMilliseconds *int
+	HoldOffFrontingTunnelMaxDurationMilliseconds *int
+	HoldOffFrontingTunnelProviderIDs             []string
+	HoldOffFrontingTunnelProbability             *float64
 
 	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
 	// are for testing purposes.
@@ -897,6 +903,18 @@ type Config struct {
 	RestrictDirectProviderRegions              map[string][]string
 	RestrictDirectProviderIDsClientProbability *float64
 
+	// HoldOffInproxyTunnelMinDurationMilliseconds and other HoldOffInproxy
+	// fields are for testing purposes.
+	HoldOffInproxyTunnelMinDurationMilliseconds *int
+	HoldOffInproxyTunnelMaxDurationMilliseconds *int
+	HoldOffInproxyTunnelProviderRegions         map[string][]string
+	HoldOffInproxyTunnelProbability             *float64
+
+	// RestrictInproxyProviderRegions and other RestrictInproxy fields are for
+	// testing purposes.
+	RestrictInproxyProviderRegions              map[string][]string
+	RestrictInproxyProviderIDsClientProbability *float64
+
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	UpstreamProxyAllowAllServerEntrySources *bool
 
@@ -2202,24 +2220,36 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ConjureDTLSEmptyInitialPacketProbability] = *config.ConjureDTLSEmptyInitialPacketProbability
 	}
 
-	if config.HoldOffTunnelMinDurationMilliseconds != nil {
-		applyParameters[parameters.HoldOffTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelMinDurationMilliseconds)
+	if config.HoldOffTunnelProtocolMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffTunnelProtocolMinDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelProtocolMinDurationMilliseconds)
+	}
+
+	if config.HoldOffTunnelProtocolMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffTunnelProtocolMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelProtocolMaxDurationMilliseconds)
 	}
 
-	if config.HoldOffTunnelMaxDurationMilliseconds != nil {
-		applyParameters[parameters.HoldOffTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelMaxDurationMilliseconds)
+	if len(config.HoldOffTunnelProtocolNames) > 0 {
+		applyParameters[parameters.HoldOffTunnelProtocolNames] = protocol.TunnelProtocols(config.HoldOffTunnelProtocolNames)
 	}
 
-	if len(config.HoldOffTunnelProtocols) > 0 {
-		applyParameters[parameters.HoldOffTunnelProtocols] = protocol.TunnelProtocols(config.HoldOffTunnelProtocols)
+	if config.HoldOffTunnelProtocolProbability != nil {
+		applyParameters[parameters.HoldOffTunnelProtocolProbability] = *config.HoldOffTunnelProtocolProbability
 	}
 
-	if len(config.HoldOffTunnelFrontingProviderIDs) > 0 {
-		applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = config.HoldOffTunnelFrontingProviderIDs
+	if config.HoldOffFrontingTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffFrontingTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffFrontingTunnelMinDurationMilliseconds)
 	}
 
-	if config.HoldOffTunnelProbability != nil {
-		applyParameters[parameters.HoldOffTunnelProbability] = *config.HoldOffTunnelProbability
+	if config.HoldOffFrontingTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffFrontingTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffFrontingTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffFrontingTunnelProviderIDs) > 0 {
+		applyParameters[parameters.HoldOffFrontingTunnelProviderIDs] = config.HoldOffFrontingTunnelProviderIDs
+	}
+
+	if config.HoldOffFrontingTunnelProbability != nil {
+		applyParameters[parameters.HoldOffFrontingTunnelProbability] = *config.HoldOffFrontingTunnelProbability
 	}
 
 	if config.HoldOffDirectTunnelMinDurationMilliseconds != nil {
@@ -2254,6 +2284,22 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = *config.RestrictFrontingProviderIDsClientProbability
 	}
 
+	if config.HoldOffInproxyTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffInproxyTunnelMinDurationMilliseconds)
+	}
+
+	if config.HoldOffInproxyTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffInproxyTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffInproxyTunnelProviderRegions) > 0 {
+		applyParameters[parameters.HoldOffInproxyTunnelProviderRegions] = parameters.KeyStrings(config.HoldOffInproxyTunnelProviderRegions)
+	}
+
+	if config.HoldOffInproxyTunnelProbability != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelProbability] = *config.HoldOffInproxyTunnelProbability
+	}
+
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 	}
@@ -3005,30 +3051,50 @@ func (config *Config) setDialParametersHash() {
 		}
 	}
 
-	if config.HoldOffTunnelMinDurationMilliseconds != nil {
-		hash.Write([]byte("HoldOffTunnelMinDurationMilliseconds"))
-		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelMinDurationMilliseconds))
+	if config.HoldOffTunnelProtocolMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffTunnelProtocolMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelProtocolMinDurationMilliseconds))
 	}
 
-	if config.HoldOffTunnelMaxDurationMilliseconds != nil {
-		hash.Write([]byte("HoldOffTunnelMaxDurationMilliseconds"))
-		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelMaxDurationMilliseconds))
+	if config.HoldOffTunnelProtocolMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffTunnelProtocolMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelProtocolMaxDurationMilliseconds))
 	}
 
-	if len(config.HoldOffTunnelProtocols) > 0 {
-		hash.Write([]byte("HoldOffTunnelProtocols"))
-		for _, protocol := range config.HoldOffTunnelProtocols {
+	if len(config.HoldOffTunnelProtocolNames) > 0 {
+		hash.Write([]byte("HoldOffTunnelProtocolNames"))
+		for _, protocol := range config.HoldOffTunnelProtocolNames {
 			hash.Write([]byte(protocol))
 		}
 	}
 
-	if len(config.HoldOffTunnelFrontingProviderIDs) > 0 {
-		hash.Write([]byte("HoldOffTunnelFrontingProviderIDs"))
-		for _, providerID := range config.HoldOffTunnelFrontingProviderIDs {
+	if config.HoldOffTunnelProtocolProbability != nil {
+		hash.Write([]byte("HoldOffTunnelProtocolProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProtocolProbability)
+	}
+
+	if config.HoldOffFrontingTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffFrontingTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffFrontingTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffFrontingTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffFrontingTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffFrontingTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffFrontingTunnelProviderIDs) > 0 {
+		hash.Write([]byte("HoldOffFrontingTunnelProviderIDs"))
+		for _, providerID := range config.HoldOffFrontingTunnelProviderIDs {
 			hash.Write([]byte(providerID))
 		}
 	}
 
+	if config.HoldOffFrontingTunnelProbability != nil {
+		hash.Write([]byte("HoldOffFrontingTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffFrontingTunnelProbability)
+	}
+
 	if config.HoldOffDirectTunnelProbability != nil {
 		hash.Write([]byte("HoldOffDirectTunnelProbability"))
 		binary.Write(hash, binary.LittleEndian, *config.HoldOffDirectTunnelProbability)
@@ -3054,11 +3120,6 @@ func (config *Config) setDialParametersHash() {
 		}
 	}
 
-	if config.HoldOffTunnelProbability != nil {
-		hash.Write([]byte("HoldOffTunnelProbability"))
-		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProbability)
-	}
-
 	if len(config.RestrictDirectProviderRegions) > 0 {
 		hash.Write([]byte("RestrictDirectProviderRegions"))
 		for providerID, regions := range config.RestrictDirectProviderRegions {
@@ -3086,6 +3147,46 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.RestrictFrontingProviderIDsClientProbability)
 	}
 
+	if config.HoldOffInproxyTunnelProbability != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffInproxyTunnelProbability)
+	}
+
+	if config.HoldOffInproxyTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffInproxyTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffInproxyTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffInproxyTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffInproxyTunnelProviderRegions) > 0 {
+		hash.Write([]byte("HoldOffInproxyTunnelProviderRegions"))
+		for providerID, regions := range config.HoldOffInproxyTunnelProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
+	if len(config.RestrictInproxyProviderRegions) > 0 {
+		hash.Write([]byte("RestrictInproxyProviderRegions"))
+		for providerID, regions := range config.RestrictInproxyProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
+	if config.RestrictInproxyProviderIDsClientProbability != nil {
+		hash.Write([]byte("RestrictInproxyProviderIDsClientProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.RestrictInproxyProviderIDsClientProbability)
+	}
+
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 		hash.Write([]byte("UpstreamProxyAllowAllServerEntrySources"))
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)

+ 51 - 0
psiphon/controller.go

@@ -101,6 +101,10 @@ type Controller struct {
 	inproxyLastStoredTactics                time.Time
 	establishSignalForceTacticsFetch        chan struct{}
 	inproxyClientDialRateLimiter            *rate.Limiter
+
+	currentNetworkMutex      sync.Mutex
+	currentNetworkCtx        context.Context
+	currentNetworkCancelFunc context.CancelFunc
 }
 
 // NewController initializes a new controller.
@@ -177,6 +181,18 @@ func NewController(config *Config) (controller *Controller, err error) {
 		quicTLSClientSessionCache: tls.NewLRUClientSessionCache(0),
 	}
 
+	// Initialize the current network context. This context represents the
+	// lifetime of the host's current active network interface. When
+	// Controller.NetworkChanged is called (by the Android and iOS platform
+	// code), the previous current network interface is considered to be no
+	// longer active and the corresponding current network context is canceled.
+	// Components may use currentNetworkCtx to cancel and close old network
+	// connections and quickly initiate new connections when the active
+	// interface changes.
+
+	controller.currentNetworkCtx, controller.currentNetworkCancelFunc =
+		context.WithCancel(context.Background())
+
 	// Initialize untunneledDialConfig, used by untunneled dials including
 	// remote server list and upgrade downloads.
 	controller.untunneledDialConfig = &DialConfig{
@@ -411,6 +427,9 @@ func (controller *Controller) Run(ctx context.Context) {
 		controller.packetTunnelClient.Stop()
 	}
 
+	// Cleanup current network context
+	controller.currentNetworkCancelFunc()
+
 	// All workers -- runTunnels, establishment workers, and auxilliary
 	// workers such as fetch remote server list and untunneled uprade
 	// download -- operate with the controller run context and will all
@@ -437,6 +456,37 @@ func (controller *Controller) SetDynamicConfig(sponsorID string, authorizations
 	controller.config.SetDynamicConfig(sponsorID, authorizations)
 }
 
+// NetworkChanged initiates a reset of all open network connections, including
+// a tunnel reconnect, if one is running, as well as terminating any in-proxy
+// proxy connections.
+func (controller *Controller) NetworkChanged() {
+
+	// Explicitly reset components that don't use the current network context.
+	controller.TerminateNextActiveTunnel()
+	if controller.inproxyProxyBrokerClientManager != nil {
+		controller.inproxyProxyBrokerClientManager.NetworkChanged()
+	}
+	controller.inproxyClientBrokerClientManager.NetworkChanged()
+
+	controller.currentNetworkMutex.Lock()
+	defer controller.currentNetworkMutex.Unlock()
+
+	// Cancel the previous current network context, which will interrupt any
+	// operations using this context.
+	controller.currentNetworkCancelFunc()
+
+	// Create a new context for the new current network.
+	controller.currentNetworkCtx, controller.currentNetworkCancelFunc =
+		context.WithCancel(context.Background())
+}
+
+func (controller *Controller) getCurrentNetworkContext() context.Context {
+	controller.currentNetworkMutex.Lock()
+	defer controller.currentNetworkMutex.Unlock()
+
+	return controller.currentNetworkCtx
+}
+
 // TerminateNextActiveTunnel terminates the active tunnel, which will initiate
 // establishment of a new tunnel.
 func (controller *Controller) TerminateNextActiveTunnel() {
@@ -2936,6 +2986,7 @@ func (controller *Controller) runInproxyProxy() {
 		Logger:                        NoticeCommonLogger(debugLogging),
 		EnableWebRTCDebugLogging:      debugLogging,
 		WaitForNetworkConnectivity:    controller.inproxyWaitForNetworkConnectivity,
+		GetCurrentNetworkContext:      controller.getCurrentNetworkContext,
 		GetBrokerClient:               controller.inproxyGetProxyBrokerClient,
 		GetBaseAPIParameters:          controller.inproxyGetProxyAPIParameters,
 		MakeWebRTCDialCoordinator:     controller.inproxyMakeProxyWebRTCDialCoordinator,

+ 1 - 7
psiphon/controller_test.go

@@ -147,10 +147,7 @@ func TestObfuscatedSSH(t *testing.T) {
 		})
 }
 
-func TestTLS(t *testing.T) {
-
-	t.Skipf("temporarily disabled")
-
+func TestTLSOSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
 			protocol:                 protocol.TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH,
@@ -286,9 +283,6 @@ func TestQUIC(t *testing.T) {
 }
 
 func TestFrontedQUIC(t *testing.T) {
-
-	t.Skipf("temporarily disabled")
-
 	if !quic.Enabled() {
 		t.Skip("QUIC is not enabled")
 	}

+ 58 - 15
psiphon/dialParameters.go

@@ -502,6 +502,27 @@ func MakeDialParameters(
 		}
 	}
 
+	// Skip this candidate when the clients tactics restrict usage of the
+	// provider ID. See the corresponding server-side enforcement comments in
+	// server.sshClient.setHandshakeState.
+	if protocol.TunnelProtocolUsesInproxy(dialParams.TunnelProtocol) &&
+		common.ContainsAny(
+			p.KeyStrings(parameters.RestrictInproxyProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region}) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictInproxyProviderIDsClientProbability) {
+
+			// When skipping, return nil/nil as no error should be logged.
+			// NoticeSkipServerEntry emits each skip reason, regardless
+			// of server entry, at most once per session.
+
+			NoticeSkipServerEntry(
+				"restricted provider ID: %s",
+				dialParams.ServerEntry.ProviderID)
+
+			return nil, nil
+		}
+	}
+
 	// Skip this candidate when the clients tactics restrict usage of the
 	// fronting provider ID. See the corresponding server-side enforcement
 	// comments in server.MeekServer.getSessionOrEndpoint.
@@ -980,22 +1001,32 @@ func MakeDialParameters(
 
 	if !isReplay || !replayHoldOffTunnel {
 
-		var holdOffTunnelDuration time.Duration
+		var HoldOffTunnelProtocolDuration time.Duration
+		var HoldOffFrontingTunnelDuration time.Duration
 		var holdOffDirectTunnelDuration time.Duration
+		var holdOffInproxyTunnelDuration time.Duration
 
 		if common.Contains(
-			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) ||
+			p.TunnelProtocols(parameters.HoldOffTunnelProtocolNames), dialParams.TunnelProtocol) {
 
-			(protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
-				common.Contains(
-					p.Strings(parameters.HoldOffTunnelFrontingProviderIDs),
-					dialParams.FrontingProviderID)) {
+			if p.WeightedCoinFlip(parameters.HoldOffTunnelProtocolProbability) {
 
-			if p.WeightedCoinFlip(parameters.HoldOffTunnelProbability) {
+				HoldOffTunnelProtocolDuration = prng.Period(
+					p.Duration(parameters.HoldOffTunnelProtocolMinDuration),
+					p.Duration(parameters.HoldOffTunnelProtocolMaxDuration))
+			}
+		}
 
-				holdOffTunnelDuration = prng.Period(
-					p.Duration(parameters.HoldOffTunnelMinDuration),
-					p.Duration(parameters.HoldOffTunnelMaxDuration))
+		if protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
+			common.Contains(
+				p.Strings(parameters.HoldOffFrontingTunnelProviderIDs),
+				dialParams.FrontingProviderID) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffFrontingTunnelProbability) {
+
+				HoldOffFrontingTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffFrontingTunnelMinDuration),
+					p.Duration(parameters.HoldOffFrontingTunnelMaxDuration))
 			}
 		}
 
@@ -1011,12 +1042,24 @@ func MakeDialParameters(
 			}
 		}
 
-		// Use the longest hold off duration
-		if holdOffTunnelDuration >= holdOffDirectTunnelDuration {
-			dialParams.HoldOffTunnelDuration = holdOffTunnelDuration
-		} else {
-			dialParams.HoldOffTunnelDuration = holdOffDirectTunnelDuration
+		if protocol.TunnelProtocolUsesInproxy(dialParams.TunnelProtocol) &&
+			common.ContainsAny(
+				p.KeyStrings(parameters.HoldOffInproxyTunnelProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region}) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffInproxyTunnelProbability) {
+
+				holdOffInproxyTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffInproxyTunnelMinDuration),
+					p.Duration(parameters.HoldOffInproxyTunnelMaxDuration))
+			}
 		}
+
+		// Use the longest hold off duration
+		dialParams.HoldOffTunnelDuration = common.MaxDuration(
+			HoldOffTunnelProtocolDuration,
+			HoldOffFrontingTunnelDuration,
+			holdOffDirectTunnelDuration,
+			holdOffInproxyTunnelDuration)
 	}
 
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,

+ 61 - 12
psiphon/dialParameters_test.go

@@ -80,7 +80,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("error committing configuration file: %s", err)
 	}
 
-	holdOffTunnelProtocols := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
+	holdOffTunnelProtocolNames := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
 
 	providerID := prng.HexString(8)
 	frontingProviderID := prng.HexString(8)
@@ -90,17 +90,30 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		holdOffDirectTunnelProviderRegions = map[string][]string{providerID: {""}}
 	}
 
+	var holdOffInproxyTunnelProviderRegions parameters.KeyStrings
+	if protocol.TunnelProtocolUsesInproxy(tunnelProtocol) &&
+		protocol.TunnelProtocolMinusInproxy(tunnelProtocol) == protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
+		holdOffInproxyTunnelProviderRegions = map[string][]string{providerID: {""}}
+	}
+
 	applyParameters := make(map[string]interface{})
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
-	applyParameters[parameters.HoldOffTunnelMinDuration] = "1ms"
-	applyParameters[parameters.HoldOffTunnelMaxDuration] = "10ms"
-	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
-	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
-	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
+	applyParameters[parameters.HoldOffTunnelProtocolMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffTunnelProtocolMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffTunnelProtocolNames] = holdOffTunnelProtocolNames
+	applyParameters[parameters.HoldOffTunnelProtocolProbability] = 1.0
+	applyParameters[parameters.HoldOffFrontingTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffFrontingTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffFrontingTunnelProviderIDs] = []string{frontingProviderID}
+	applyParameters[parameters.HoldOffFrontingTunnelProbability] = 1.0
 	applyParameters[parameters.HoldOffDirectTunnelMinDuration] = "1ms"
 	applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = "10ms"
 	applyParameters[parameters.HoldOffDirectTunnelProviderRegions] = holdOffDirectTunnelProviderRegions
+	applyParameters[parameters.HoldOffInproxyTunnelProbability] = 1.0
+	applyParameters[parameters.HoldOffInproxyTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffInproxyTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffInproxyTunnelProviderRegions] = holdOffInproxyTunnelProviderRegions
 	applyParameters[parameters.HoldOffDirectTunnelProbability] = 1.0
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
@@ -245,16 +258,21 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("missing API request fields")
 	}
 
-	expectHoldOffTunnelProtocols := common.Contains(holdOffTunnelProtocols, tunnelProtocol)
-	expectHoldOffTunnelFrontingProviderIDs := protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol)
+	expectHoldOffTunnelProtocolNames := common.Contains(holdOffTunnelProtocolNames, tunnelProtocol)
+	expectHoldOffFrontingTunnelProviderIDs := protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol)
 	expectHoldOffDirectTunnelProviderRegion := protocol.TunnelProtocolIsDirect(tunnelProtocol) &&
 		common.ContainsAny(
 			holdOffDirectTunnelProviderRegions[dialParams.ServerEntry.ProviderID],
 			[]string{"", dialParams.ServerEntry.Region})
+	expectHoldOffInproxyTunnelProviderRegion := protocol.TunnelProtocolUsesInproxy(tunnelProtocol) &&
+		common.ContainsAny(
+			holdOffInproxyTunnelProviderRegions[dialParams.ServerEntry.ProviderID],
+			[]string{"", dialParams.ServerEntry.Region})
 
-	if expectHoldOffTunnelProtocols ||
-		expectHoldOffTunnelFrontingProviderIDs ||
-		expectHoldOffDirectTunnelProviderRegion {
+	if expectHoldOffTunnelProtocolNames ||
+		expectHoldOffFrontingTunnelProviderIDs ||
+		expectHoldOffDirectTunnelProviderRegion ||
+		expectHoldOffInproxyTunnelProviderRegion {
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
@@ -557,7 +575,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 
-	// Test: client-side restrict provider ID by region
+	// Test: client-side restrict provider ID by region for direct protocols
 
 	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{providerID: {"CA"}}
 	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
@@ -588,6 +606,37 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 
+	// Test: client-side restrict provider ID by region for inproxy protocols
+
+	applyParameters[parameters.RestrictInproxyProviderRegions] = map[string][]string{providerID: {"CA"}}
+	applyParameters[parameters.RestrictInproxyProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag8", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(
+		clientConfig, steeringIPCache, nil, nil, nil, canReplay, selectProtocol, serverEntries[0], nil, nil, false, 0, 0)
+
+	if protocol.TunnelProtocolUsesInproxy(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictInproxyProviderRegions] = map[string][]string{}
+	applyParameters[parameters.RestrictInproxyProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag9", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
 	if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
 
 		steeringIPCache.Flush()

+ 19 - 1
psiphon/inproxy.go

@@ -114,6 +114,22 @@ func (b *InproxyBrokerClientManager) TacticsApplied() error {
 	return errors.Trace(b.reset(resetBrokerClientReasonTacticsApplied))
 }
 
+// NetworkChanged is called when the active network changes, to trigger a
+// broker client reset.
+func (b *InproxyBrokerClientManager) NetworkChanged() error {
+
+	b.mutex.Lock()
+	defer b.mutex.Unlock()
+
+	// Don't reset when not yet initialized; b.brokerClientInstance is
+	// initialized only on demand.
+	if b.brokerClientInstance == nil {
+		return nil
+	}
+
+	return errors.Trace(b.reset(resetBrokerClientReasonNetworkChanged))
+}
+
 // GetBrokerClient returns the current, shared broker client and its
 // corresponding dial parametrers (for metrics logging). If there is no
 // current broker client, if the network ID differs from the network ID
@@ -195,6 +211,7 @@ type resetBrokerClientReason int
 const (
 	resetBrokerClientReasonInit resetBrokerClientReason = iota + 1
 	resetBrokerClientReasonTacticsApplied
+	resetBrokerClientReasonNetworkChanged
 	resetBrokerClientReasonRoundTripperFailed
 	resetBrokerClientReasonRoundNoMatch
 )
@@ -220,7 +237,8 @@ func (b *InproxyBrokerClientManager) reset(reason resetBrokerClientReason) error
 
 	switch reason {
 	case resetBrokerClientReasonInit,
-		resetBrokerClientReasonTacticsApplied:
+		resetBrokerClientReasonTacticsApplied,
+		resetBrokerClientReasonNetworkChanged:
 		b.brokerSelectCount = 0
 
 	case resetBrokerClientReasonRoundTripperFailed,

+ 7 - 1
psiphon/server/api.go

@@ -119,8 +119,14 @@ func sshAPIRequestHandler(
 	switch name {
 
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(
+		responsePayload, err := handshakeAPIRequestHandler(
 			support, protocol.PSIPHON_API_PROTOCOL_SSH, sshClient, params)
+		if err != nil {
+			// Handshake failed, disconnect the client.
+			sshClient.stop()
+			return nil, errors.Trace(err)
+		}
+		return responsePayload, nil
 
 	case protocol.PSIPHON_API_CONNECTED_REQUEST_NAME:
 		return connectedAPIRequestHandler(

+ 3 - 1
psiphon/server/listener.go

@@ -117,7 +117,9 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 	// peer IP is not the original client IP. Indirect protocols must determine
 	// the original client IP before applying GeoIP specific tactics; see the
 	// server-side enforcement of RestrictFrontingProviderIDs for fronted meek
-	// in server.MeekServer.getSessionOrEndpoint.
+	// in server.MeekServer.getSessionOrEndpoint or of
+	// RestrictInproxyProviderRegions for inproxy in
+	// server.sshClient.setHandshakeState.
 	//
 	// At this stage, GeoIP tactics filters are active, but handshake API
 	// parameters are not.

+ 161 - 50
psiphon/server/server_test.go

@@ -369,6 +369,17 @@ func TestInproxyOSSH(t *testing.T) {
 		})
 }
 
+func TestRestrictInproxy(t *testing.T) {
+	if !inproxy.Enabled() {
+		t.Skip("inproxy is not enabled")
+	}
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:    "INPROXY-WEBRTC-OSSH",
+			doRestrictInproxy: true,
+		})
+}
+
 func TestInproxyQUICOSSH(t *testing.T) {
 	if !quic.Enabled() {
 		t.Skip("QUIC is not enabled")
@@ -578,17 +589,32 @@ func TestBurstMonitorAndDestinationBytes(t *testing.T) {
 		})
 }
 
+func TestBurstMonitorAndLegacyDestinationBytes(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:           "OSSH",
+			requireAuthorization:     true,
+			doTunneledWebRequest:     true,
+			doTunneledNTPRequest:     true,
+			doDanglingTCPConn:        true,
+			doBurstMonitor:           true,
+			doLegacyDestinationBytes: true,
+			doLogHostProvider:        true,
+		})
+}
+
 func TestChangeBytesConfig(t *testing.T) {
 	runServer(t,
 		&runServerConfig{
-			tunnelProtocol:       "OSSH",
-			requireAuthorization: true,
-			doTunneledWebRequest: true,
-			doTunneledNTPRequest: true,
-			doDanglingTCPConn:    true,
-			doDestinationBytes:   true,
-			doChangeBytesConfig:  true,
-			doLogHostProvider:    true,
+			tunnelProtocol:           "OSSH",
+			requireAuthorization:     true,
+			doTunneledWebRequest:     true,
+			doTunneledNTPRequest:     true,
+			doDanglingTCPConn:        true,
+			doDestinationBytes:       true,
+			doLegacyDestinationBytes: true,
+			doChangeBytesConfig:      true,
+			doLogHostProvider:        true,
 		})
 }
 
@@ -646,34 +672,36 @@ func TestLegacyAPIEncoding(t *testing.T) {
 }
 
 type runServerConfig struct {
-	tunnelProtocol       string
-	clientTunnelProtocol string
-	passthrough          bool
-	tlsProfile           string
-	doHotReload          bool
-	doDefaultSponsorID   bool
-	denyTrafficRules     bool
-	requireAuthorization bool
-	omitAuthorization    bool
-	doTunneledWebRequest bool
-	doTunneledNTPRequest bool
-	applyPrefix          bool
-	forceFragmenting     bool
-	forceLivenessTest    bool
-	doPruneServerEntries bool
-	doDanglingTCPConn    bool
-	doPacketManipulation bool
-	doBurstMonitor       bool
-	doSplitTunnel        bool
-	limitQUICVersions    bool
-	doDestinationBytes   bool
-	doChangeBytesConfig  bool
-	doLogHostProvider    bool
-	inspectFlows         bool
-	doSteeringIP         bool
-	doTargetBrokerSpecs  bool
-	useLegacyAPIEncoding bool
-	doPersonalPairing    bool
+	tunnelProtocol           string
+	clientTunnelProtocol     string
+	passthrough              bool
+	tlsProfile               string
+	doHotReload              bool
+	doDefaultSponsorID       bool
+	denyTrafficRules         bool
+	requireAuthorization     bool
+	omitAuthorization        bool
+	doTunneledWebRequest     bool
+	doTunneledNTPRequest     bool
+	applyPrefix              bool
+	forceFragmenting         bool
+	forceLivenessTest        bool
+	doPruneServerEntries     bool
+	doDanglingTCPConn        bool
+	doPacketManipulation     bool
+	doBurstMonitor           bool
+	doSplitTunnel            bool
+	limitQUICVersions        bool
+	doDestinationBytes       bool
+	doLegacyDestinationBytes bool
+	doChangeBytesConfig      bool
+	doLogHostProvider        bool
+	inspectFlows             bool
+	doSteeringIP             bool
+	doTargetBrokerSpecs      bool
+	useLegacyAPIEncoding     bool
+	doPersonalPairing        bool
+	doRestrictInproxy        bool
 }
 
 var (
@@ -776,7 +804,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		runConfig.applyPrefix ||
 		runConfig.forceFragmenting ||
 		runConfig.doBurstMonitor ||
-		runConfig.doDestinationBytes
+		runConfig.doDestinationBytes ||
+		runConfig.doLegacyDestinationBytes
 
 	// All servers require a tactics config with valid keys.
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
@@ -912,10 +941,12 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			livenessTestSize,
 			runConfig.doBurstMonitor,
 			runConfig.doDestinationBytes,
+			runConfig.doLegacyDestinationBytes,
 			runConfig.applyPrefix,
 			runConfig.forceFragmenting,
 			"classic",
-			inproxyTacticsParametersJSON)
+			inproxyTacticsParametersJSON,
+			runConfig.doRestrictInproxy)
 	}
 
 	blocklistFilename := filepath.Join(testDataDirName, "blocklist.csv")
@@ -1181,10 +1212,12 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				livenessTestSize,
 				runConfig.doBurstMonitor,
 				runConfig.doDestinationBytes,
+				runConfig.doLegacyDestinationBytes,
 				runConfig.applyPrefix,
 				runConfig.forceFragmenting,
 				"consistent",
-				inproxyTacticsParametersJSON)
+				inproxyTacticsParametersJSON,
+				runConfig.doRestrictInproxy)
 		}
 
 		p, _ := os.FindProcess(os.Getpid())
@@ -1491,6 +1524,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	pruneServerEntriesNoticesEmitted := make(chan struct{}, 1)
 	serverAlertDisallowedNoticesEmitted := make(chan struct{}, 1)
 	untunneledPortForward := make(chan struct{}, 1)
+	discardTunnel := make(chan struct{}, 1)
 
 	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
@@ -1562,6 +1596,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				if connectedClients == 1 && bytesUp > 0 && bytesDown > 0 {
 					sendNotificationReceived(inproxyActivity)
 				}
+
+			case "Info":
+				if strings.Contains(payload["message"].(string), "discard tunnel") {
+					sendNotificationReceived(discardTunnel)
+				}
 			}
 
 			if printNotice {
@@ -1614,16 +1653,23 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		close(timeoutSignal)
 	}()
 
-	waitOnNotification(t, connectedServer, timeoutSignal, "connected server timeout exceeded")
-	if doInproxy {
-		waitOnNotification(t, inproxyActivity, timeoutSignal, "inproxy activity timeout exceeded")
+	expectDiscardTunnel := runConfig.doRestrictInproxy
+
+	if expectDiscardTunnel {
+		waitOnNotification(t, discardTunnel, timeoutSignal, "discard tunnel timeout exceeded")
+		return
+	} else {
+		waitOnNotification(t, connectedServer, timeoutSignal, "connected server timeout exceeded")
+		if doInproxy {
+			waitOnNotification(t, inproxyActivity, timeoutSignal, "inproxy activity timeout exceeded")
+		}
+		waitOnNotification(t, tunnelsEstablished, timeoutSignal, "tunnel established timeout exceeded")
+		waitOnNotification(t, homepageReceived, timeoutSignal, "homepage received timeout exceeded")
 	}
-	waitOnNotification(t, tunnelsEstablished, timeoutSignal, "tunnel established timeout exceeded")
-	waitOnNotification(t, homepageReceived, timeoutSignal, "homepage received timeout exceeded")
 
 	if runConfig.doChangeBytesConfig {
 
-		if !runConfig.doDestinationBytes {
+		if !runConfig.doDestinationBytes || !runConfig.doLegacyDestinationBytes {
 			t.Fatalf("invalid test configuration")
 		}
 
@@ -1649,10 +1695,12 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			livenessTestSize,
 			runConfig.doBurstMonitor,
 			false,
+			false,
 			runConfig.applyPrefix,
 			runConfig.forceFragmenting,
 			"consistent",
-			inproxyTacticsParametersJSON)
+			inproxyTacticsParametersJSON,
+			runConfig.doRestrictInproxy)
 
 		p, _ := os.FindProcess(os.Getpid())
 		p.Signal(syscall.SIGUSR1)
@@ -1796,6 +1844,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		expectQUICVersion = limitQUICVersions[0]
 	}
 	expectDestinationBytesFields := runConfig.doDestinationBytes && !runConfig.doChangeBytesConfig
+	expectLegacyDestinationBytesFields := runConfig.doLegacyDestinationBytes && !runConfig.doChangeBytesConfig
 	expectMeekHTTPVersion := ""
 	if protocol.TunnelProtocolUsesMeek(runConfig.tunnelProtocol) {
 		if protocol.TunnelProtocolUsesFrontedMeek(runConfig.tunnelProtocol) {
@@ -1829,6 +1878,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectUDPDataTransfer,
 			expectQUICVersion,
 			expectDestinationBytesFields,
+			expectLegacyDestinationBytesFields,
 			passthroughAddress,
 			expectMeekHTTPVersion,
 			inproxyTestConfig,
@@ -2086,6 +2136,7 @@ func checkExpectedServerTunnelLogFields(
 	expectUDPDataTransfer bool,
 	expectQUICVersion string,
 	expectDestinationBytesFields bool,
+	expectLegacyDestinationBytesFields bool,
 	expectPassthroughAddress *string,
 	expectMeekHTTPVersion string,
 	inproxyTestConfig *inproxyTestConfig,
@@ -2691,6 +2742,45 @@ func checkExpectedServerTunnelLogFields(
 		}
 	}
 
+	for _, name := range []string{
+		"asn_dest_bytes",
+		"asn_dest_bytes_up_tcp",
+		"asn_dest_bytes_down_tcp",
+		"asn_dest_bytes_up_udp",
+		"asn_dest_bytes_down_udp",
+	} {
+		if expectDestinationBytesFields && fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+
+		} else if !expectDestinationBytesFields && fields[name] != nil {
+			return fmt.Errorf("unexpected field '%s'", name)
+		}
+	}
+
+	if expectDestinationBytesFields {
+		for _, pair := range [][]string{
+			{"asn_dest_bytes", "bytes"},
+			{"asn_dest_bytes_up_tcp", "bytes_up_tcp"},
+			{"asn_dest_bytes_down_tcp", "bytes_down_tcp"},
+			{"asn_dest_bytes_up_udp", "bytes_up_udp"},
+			{"asn_dest_bytes_down_udp", "bytes_down_udp"},
+		} {
+			if _, ok := fields[pair[0]].(map[string]any)[testGeoIPASN].(float64); !ok {
+				return fmt.Errorf("missing field entry %s: '%v'", pair[0], testGeoIPASN)
+			}
+			value0 := int64(fields[pair[0]].(map[string]any)[testGeoIPASN].(float64))
+			value1 := int64(fields[pair[1]].(float64))
+			ok := value0 == value1
+			if pair[0] == "asn_dest_bytes_up_udp" || pair[0] == "asn_dest_bytes_down_udp" || pair[0] == "asn_dest_bytes" {
+				// DNS requests are excluded from destination bytes counting
+				ok = value0 > 0 && value0 < value1
+			}
+			if !ok {
+				return fmt.Errorf("unexpected field value %s: %v != %v", pair[0], fields[pair[0]], fields[pair[1]])
+			}
+		}
+	}
+
 	for _, name := range []string{
 		"dest_bytes_asn",
 		"dest_bytes_up_tcp",
@@ -2699,15 +2789,15 @@ func checkExpectedServerTunnelLogFields(
 		"dest_bytes_down_udp",
 		"dest_bytes",
 	} {
-		if expectDestinationBytesFields && fields[name] == nil {
+		if expectLegacyDestinationBytesFields && fields[name] == nil {
 			return fmt.Errorf("missing expected field '%s'", name)
 
-		} else if !expectDestinationBytesFields && fields[name] != nil {
+		} else if !expectLegacyDestinationBytesFields && fields[name] != nil {
 			return fmt.Errorf("unexpected field '%s'", name)
 		}
 	}
 
-	if expectDestinationBytesFields {
+	if expectLegacyDestinationBytesFields {
 		name := "dest_bytes_asn"
 		if fields[name].(string) != testGeoIPASN {
 			return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
@@ -3385,10 +3475,12 @@ func paveTacticsConfigFile(
 	livenessTestSize int,
 	doBurstMonitor bool,
 	doDestinationBytes bool,
+	doLegacyDestinationBytes bool,
 	applyOsshPrefix bool,
 	enableOsshPrefixFragmenting bool,
 	discoveryStategy string,
-	inproxyParametersJSON string) {
+	inproxyParametersJSON string,
+	doRestrictAllInproxyProviderRegions bool) {
 
 	// Setting LimitTunnelProtocols passively exercises the
 	// server-side LimitTunnelProtocols enforcement.
@@ -3406,6 +3498,8 @@ func paveTacticsConfigFile(
           %s
           %s
           %s
+          %s
+          %s
           "LimitTunnelProtocols" : ["%s"],
           "FragmentorLimitProtocols" : ["%s"],
           "FragmentorProbability" : 1.0,
@@ -3490,6 +3584,13 @@ func paveTacticsConfigFile(
 	destinationBytesParameters := ""
 	if doDestinationBytes {
 		destinationBytesParameters = fmt.Sprintf(`
+          "DestinationBytesMetricsASNs" : ["%s"],
+	`, testGeoIPASN)
+	}
+
+	legacyDestinationBytesParameters := ""
+	if doLegacyDestinationBytes {
+		legacyDestinationBytesParameters = fmt.Sprintf(`
           "DestinationBytesMetricsASN" : "%s",
 	`, testGeoIPASN)
 	}
@@ -3506,6 +3607,14 @@ func paveTacticsConfigFile(
 	`, strconv.FormatBool(enableOsshPrefixFragmenting))
 	}
 
+	restrictInproxyParameters := ""
+	if doRestrictAllInproxyProviderRegions {
+		restrictInproxyParameters = `
+		"RestrictInproxyProviderRegions": {"" : [""]},
+		"RestrictInproxyProviderIDsServerProbability": 1.0,
+	`
+	}
+
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey,
@@ -3513,8 +3622,10 @@ func paveTacticsConfigFile(
 		tacticsRequestObfuscatedKey,
 		burstParameters,
 		destinationBytesParameters,
+		legacyDestinationBytesParameters,
 		osshPrefix,
 		inproxyParametersJSON,
+		restrictInproxyParameters,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,

+ 149 - 35
psiphon/server/tunnelServer.go

@@ -1809,9 +1809,7 @@ type sshClient struct {
 	sendAlertRequests                    chan protocol.AlertRequest
 	sentAlertRequests                    map[string]bool
 	peakMetrics                          peakMetrics
-	destinationBytesMetricsASN           string
-	tcpDestinationBytesMetrics           destinationBytesMetrics
-	udpDestinationBytesMetrics           destinationBytesMetrics
+	destinationBytesMetrics              map[string]*protocolDestinationBytesMetrics
 }
 
 type trafficState struct {
@@ -1941,6 +1939,11 @@ type handshakeState struct {
 	inproxyRelayLogFields   common.LogFields
 }
 
+type protocolDestinationBytesMetrics struct {
+	tcpMetrics destinationBytesMetrics
+	udpMetrics destinationBytesMetrics
+}
+
 type destinationBytesMetrics struct {
 	bytesUp   int64
 	bytesDown int64
@@ -3365,20 +3368,18 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	logFields["random_stream_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.downstreamBytes
 	logFields["random_stream_sent_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.sentDownstreamBytes
 
-	if sshClient.destinationBytesMetricsASN != "" {
+	if sshClient.destinationBytesMetrics != nil {
 
-		// Check if the configured DestinationBytesMetricsASN has changed
-		// (or been cleared). If so, don't log and discard the accumulated
-		// bytes to ensure we don't continue to record stats as previously
-		// configured.
+		// Only log destination bytes for ASNs that remain enabled in tactics.
 		//
-		// Any counts accumulated before the DestinationBytesMetricsASN change
-		// are lost. At this time we can't change
-		// sshClient.destinationBytesMetricsASN dynamically, after a tactics
-		// hot reload, as there may be destination bytes port forwards that
-		// were in place before the change, which will continue to count.
-
-		logDestBytes := true
+		// Any counts accumulated before DestinationBytesMetricsASN[s] changes
+		// are lost. At this time we can't change destination byte counting
+		// dynamically, after a tactics hot reload, as there may be
+		// destination bytes port forwards that were in place before the
+		// change, which will continue to count.
+
+		destinationBytesMetricsASNs := []string{}
+		destinationBytesMetricsASN := ""
 		if sshClient.sshServer.support.ServerTacticsParametersCache != nil {
 
 			// Target this using the client, not peer, GeoIP. In the case of
@@ -3387,24 +3388,69 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 			// have transferred.
 
 			p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.clientGeoIPData)
-			if err != nil || p.IsNil() ||
-				sshClient.destinationBytesMetricsASN != p.String(parameters.DestinationBytesMetricsASN) {
-				logDestBytes = false
+			if err == nil && !p.IsNil() {
+				destinationBytesMetricsASNs = p.Strings(parameters.DestinationBytesMetricsASNs)
+				destinationBytesMetricsASN = p.String(parameters.DestinationBytesMetricsASN)
 			}
+			p.Close()
 		}
 
-		if logDestBytes {
-			bytesUpTCP := sshClient.tcpDestinationBytesMetrics.getBytesUp()
-			bytesDownTCP := sshClient.tcpDestinationBytesMetrics.getBytesDown()
-			bytesUpUDP := sshClient.udpDestinationBytesMetrics.getBytesUp()
-			bytesDownUDP := sshClient.udpDestinationBytesMetrics.getBytesDown()
-
-			logFields["dest_bytes_asn"] = sshClient.destinationBytesMetricsASN
-			logFields["dest_bytes_up_tcp"] = bytesUpTCP
-			logFields["dest_bytes_down_tcp"] = bytesDownTCP
-			logFields["dest_bytes_up_udp"] = bytesUpUDP
-			logFields["dest_bytes_down_udp"] = bytesDownUDP
-			logFields["dest_bytes"] = bytesUpTCP + bytesDownTCP + bytesUpUDP + bytesDownUDP
+		if destinationBytesMetricsASN != "" {
+
+			// Log any parameters.DestinationBytesMetricsASN data in the
+			// legacy log field format.
+
+			destinationBytesMetrics, ok :=
+				sshClient.destinationBytesMetrics[destinationBytesMetricsASN]
+
+			if ok {
+				bytesUpTCP := destinationBytesMetrics.tcpMetrics.getBytesUp()
+				bytesDownTCP := destinationBytesMetrics.tcpMetrics.getBytesDown()
+				bytesUpUDP := destinationBytesMetrics.udpMetrics.getBytesUp()
+				bytesDownUDP := destinationBytesMetrics.udpMetrics.getBytesDown()
+
+				logFields["dest_bytes_asn"] = destinationBytesMetricsASN
+				logFields["dest_bytes"] = bytesUpTCP + bytesDownTCP + bytesUpUDP + bytesDownUDP
+				logFields["dest_bytes_up_tcp"] = bytesUpTCP
+				logFields["dest_bytes_down_tcp"] = bytesDownTCP
+				logFields["dest_bytes_up_udp"] = bytesUpUDP
+				logFields["dest_bytes_down_udp"] = bytesDownUDP
+			}
+		}
+
+		if len(destinationBytesMetricsASNs) > 0 {
+
+			destBytes := make(map[string]int64)
+			destBytesUpTCP := make(map[string]int64)
+			destBytesDownTCP := make(map[string]int64)
+			destBytesUpUDP := make(map[string]int64)
+			destBytesDownUDP := make(map[string]int64)
+
+			for _, ASN := range destinationBytesMetricsASNs {
+
+				destinationBytesMetrics, ok :=
+					sshClient.destinationBytesMetrics[ASN]
+				if !ok {
+					continue
+				}
+
+				bytesUpTCP := destinationBytesMetrics.tcpMetrics.getBytesUp()
+				bytesDownTCP := destinationBytesMetrics.tcpMetrics.getBytesDown()
+				bytesUpUDP := destinationBytesMetrics.udpMetrics.getBytesUp()
+				bytesDownUDP := destinationBytesMetrics.udpMetrics.getBytesDown()
+
+				destBytes[ASN] = bytesUpTCP + bytesDownTCP + bytesUpUDP + bytesDownUDP
+				destBytesUpTCP[ASN] = bytesUpTCP
+				destBytesDownTCP[ASN] = bytesDownTCP
+				destBytesUpUDP[ASN] = bytesUpUDP
+				destBytesDownUDP[ASN] = bytesDownUDP
+			}
+
+			logFields["asn_dest_bytes"] = destBytes
+			logFields["asn_dest_bytes_up_tcp"] = destBytesUpTCP
+			logFields["asn_dest_bytes_down_tcp"] = destBytesDownTCP
+			logFields["asn_dest_bytes_up_udp"] = destBytesUpUDP
+			logFields["asn_dest_bytes_down_udp"] = destBytesDownUDP
 		}
 	}
 
@@ -3738,6 +3784,46 @@ func (sshClient *sshClient) setHandshakeState(
 		return nil, errors.TraceNew("handshake already completed")
 	}
 
+	if sshClient.isInproxyTunnelProtocol {
+
+		p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.clientGeoIPData)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// Skip check if no tactics are configured.
+		//
+		// Disconnect immediately if the tactics for the client restricts usage
+		// of the provider ID with inproxy protocols. The probability may be
+		// used to influence usage of a given provider with inproxy protocols;
+		// but when only that provider works for a given client, and the
+		// probability is less than 1.0, the client can retry until it gets a
+		// successful coin flip.
+		//
+		// Clients will also skip inproxy protocol candidates with restricted
+		// provider IDs.
+		// The client-side probability,
+		// RestrictInproxyProviderIDsClientProbability, is applied
+		// independently of the server-side coin flip here.
+		//
+		// At this stage, GeoIP tactics filters are active, but handshake API
+		// parameters are not.
+		//
+		// See the comment in server.LoadConfig regarding provider ID
+		// limitations.
+		if !p.IsNil() &&
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictInproxyProviderRegions,
+					sshClient.sshServer.support.Config.GetProviderID()),
+				[]string{"", sshClient.sshServer.support.Config.GetRegion()}) {
+
+			if p.WeightedCoinFlip(
+				parameters.RestrictInproxyProviderIDsServerProbability) {
+				return nil, errRestrictedProvider
+			}
+		}
+	}
+
 	// Verify the authorizations submitted by the client. Verified, active
 	// (non-expired) access types will be available for traffic rules
 	// filtering.
@@ -4112,26 +4198,54 @@ func (sshClient *sshClient) setDestinationBytesMetrics() {
 		return
 	}
 
-	sshClient.destinationBytesMetricsASN = p.String(parameters.DestinationBytesMetricsASN)
+	ASNs := p.Strings(parameters.DestinationBytesMetricsASNs)
+
+	// Merge in any legacy parameters.DestinationBytesMetricsASN
+	// configuration. Data for this target will be logged using the legacy
+	// log field format; see logTunnel. If an ASN is in _both_ configuration
+	// parameters, its data will be logged in both log field formats.
+	ASN := p.String(parameters.DestinationBytesMetricsASN)
+
+	if len(ASNs) == 0 && ASN == "" {
+		return
+	}
+
+	sshClient.destinationBytesMetrics = make(map[string]*protocolDestinationBytesMetrics)
+
+	for _, ASN := range ASNs {
+		if ASN != "" {
+			sshClient.destinationBytesMetrics[ASN] = &protocolDestinationBytesMetrics{}
+		}
+	}
+
+	if ASN != "" {
+		sshClient.destinationBytesMetrics[ASN] = &protocolDestinationBytesMetrics{}
+	}
 }
 
 func (sshClient *sshClient) newDestinationBytesMetricsUpdater(portForwardType int, IPAddress net.IP) *destinationBytesMetrics {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	if sshClient.destinationBytesMetricsASN == "" {
+	if sshClient.destinationBytesMetrics == nil {
 		return nil
 	}
 
-	if sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN != sshClient.destinationBytesMetricsASN {
+	destinationASN := sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN
+
+	// Future enhancement: for 5 or fewer ASNs, iterate over a slice instead
+	// of using a map? See, for example, stringLookupThreshold in
+	// common/tactics.
+	metrics, ok := sshClient.destinationBytesMetrics[destinationASN]
+	if !ok {
 		return nil
 	}
 
 	if portForwardType == portForwardTypeTCP {
-		return &sshClient.tcpDestinationBytesMetrics
+		return &metrics.tcpMetrics
 	}
 
-	return &sshClient.udpDestinationBytesMetrics
+	return &metrics.udpMetrics
 }
 
 func (sshClient *sshClient) getActivityUpdaters(portForwardType int, IPAddress net.IP) []common.ActivityUpdater {

+ 2 - 1
psiphon/serverApi.go

@@ -142,7 +142,8 @@ func (serverContext *ServerContext) doHandshakeRequest(ignoreStatsRegexps bool)
 	// The purpose of this mechanism is to rapidly add provider IDs to the
 	// server entries in client local storage, and to ensure that the client has
 	// a provider ID for its currently connected server as required for the
-	// RestrictDirectProviderRegions, and HoldOffDirectTunnelProviderRegions
+	// RestrictDirectProviderRegions, HoldOffDirectTunnelProviderRegions,
+	// RestrictInproxyProviderRegions, and HoldOffInproxyTunnelProviderRegions
 	// tactics.
 	//
 	// The server entry will be included in handshakeResponse.EncodedServerList,