Jelajahi Sumber

Add HoldOffInproxy and RestrictInproxy

Miro 1 tahun lalu
induk
melakukan
5178fc334d

+ 16 - 0
psiphon/common/parameters/parameters.go

@@ -328,6 +328,13 @@ const (
 	RestrictDirectProviderRegions                      = "RestrictDirectProviderRegions"
 	RestrictDirectProviderRegions                      = "RestrictDirectProviderRegions"
 	RestrictDirectProviderIDsServerProbability         = "RestrictDirectProviderIDsServerProbability"
 	RestrictDirectProviderIDsServerProbability         = "RestrictDirectProviderIDsServerProbability"
 	RestrictDirectProviderIDsClientProbability         = "RestrictDirectProviderIDsClientProbability"
 	RestrictDirectProviderIDsClientProbability         = "RestrictDirectProviderIDsClientProbability"
+	HoldOffInproxyTunnelMinDuration                    = "HoldOffInproxyTunnelMinDuration"
+	HoldOffInproxyTunnelMaxDuration                    = "HoldOffInproxyTunnelMaxDuration"
+	HoldOffInproxyTunnelProviderRegions                = "HoldOffInproxyTunnelProviderRegions"
+	HoldOffInproxyTunnelProbability                    = "HoldOffInproxyTunnelProbability"
+	RestrictInproxyProviderRegions                     = "RestrictInproxyProviderRegions"
+	RestrictInproxyProviderIDsServerProbability        = "RestrictInproxyProviderIDsServerProbability"
+	RestrictInproxyProviderIDsClientProbability        = "RestrictInproxyProviderIDsClientProbability"
 	UpstreamProxyAllowAllServerEntrySources            = "UpstreamProxyAllowAllServerEntrySources"
 	UpstreamProxyAllowAllServerEntrySources            = "UpstreamProxyAllowAllServerEntrySources"
 	DestinationBytesMetricsASN                         = "DestinationBytesMetricsASN"
 	DestinationBytesMetricsASN                         = "DestinationBytesMetricsASN"
 	DestinationBytesMetricsASNs                        = "DestinationBytesMetricsASNs"
 	DestinationBytesMetricsASNs                        = "DestinationBytesMetricsASNs"
@@ -832,6 +839,15 @@ var defaultParameters = map[string]struct {
 	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 
 
+	HoldOffInproxyTunnelMinDuration:     {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffInproxyTunnelMaxDuration:     {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffInproxyTunnelProviderRegions: {value: []string{}},
+	HoldOffInproxyTunnelProbability:     {value: 0.0, minimum: 0.0},
+
+	RestrictInproxyProviderRegions:              {value: KeyStrings{}},
+	RestrictInproxyProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
+	RestrictInproxyProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
+
 	UpstreamProxyAllowAllServerEntrySources: {value: false},
 	UpstreamProxyAllowAllServerEntrySources: {value: false},
 
 
 	DestinationBytesMetricsASN:  {value: "", flags: serverSideOnly},
 	DestinationBytesMetricsASN:  {value: "", flags: serverSideOnly},

+ 68 - 0
psiphon/config.go

@@ -903,6 +903,18 @@ type Config struct {
 	RestrictDirectProviderRegions              map[string][]string
 	RestrictDirectProviderRegions              map[string][]string
 	RestrictDirectProviderIDsClientProbability *float64
 	RestrictDirectProviderIDsClientProbability *float64
 
 
+	// HoldOffInproxyTunnelMinDurationMilliseconds and other HoldOffInproxy
+	// fields are for testing purposes.
+	HoldOffInproxyTunnelMinDurationMilliseconds *int
+	HoldOffInproxyTunnelMaxDurationMilliseconds *int
+	HoldOffInproxyTunnelProviderRegions         map[string][]string
+	HoldOffInproxyTunnelProbability             *float64
+
+	// RestrictInproxyProviderRegions and other RestrictInproxy fields are for
+	// testing purposes.
+	RestrictInproxyProviderRegions              map[string][]string
+	RestrictInproxyProviderIDsClientProbability *float64
+
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	UpstreamProxyAllowAllServerEntrySources *bool
 	UpstreamProxyAllowAllServerEntrySources *bool
 
 
@@ -2272,6 +2284,22 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = *config.RestrictFrontingProviderIDsClientProbability
 		applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = *config.RestrictFrontingProviderIDsClientProbability
 	}
 	}
 
 
+	if config.HoldOffInproxyTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffInproxyTunnelMinDurationMilliseconds)
+	}
+
+	if config.HoldOffInproxyTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffInproxyTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffInproxyTunnelProviderRegions) > 0 {
+		applyParameters[parameters.HoldOffInproxyTunnelProviderRegions] = parameters.KeyStrings(config.HoldOffInproxyTunnelProviderRegions)
+	}
+
+	if config.HoldOffInproxyTunnelProbability != nil {
+		applyParameters[parameters.HoldOffInproxyTunnelProbability] = *config.HoldOffInproxyTunnelProbability
+	}
+
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 	}
 	}
@@ -3119,6 +3147,46 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.RestrictFrontingProviderIDsClientProbability)
 		binary.Write(hash, binary.LittleEndian, *config.RestrictFrontingProviderIDsClientProbability)
 	}
 	}
 
 
+	if config.HoldOffInproxyTunnelProbability != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffInproxyTunnelProbability)
+	}
+
+	if config.HoldOffInproxyTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffInproxyTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffInproxyTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffInproxyTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffInproxyTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffInproxyTunnelProviderRegions) > 0 {
+		hash.Write([]byte("HoldOffInproxyTunnelProviderRegions"))
+		for providerID, regions := range config.HoldOffInproxyTunnelProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
+	if len(config.RestrictInproxyProviderRegions) > 0 {
+		hash.Write([]byte("RestrictInproxyProviderRegions"))
+		for providerID, regions := range config.RestrictInproxyProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
+	if config.RestrictInproxyProviderIDsClientProbability != nil {
+		hash.Write([]byte("RestrictInproxyProviderIDsClientProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.RestrictInproxyProviderIDsClientProbability)
+	}
+
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 	if config.UpstreamProxyAllowAllServerEntrySources != nil {
 		hash.Write([]byte("UpstreamProxyAllowAllServerEntrySources"))
 		hash.Write([]byte("UpstreamProxyAllowAllServerEntrySources"))
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)

+ 36 - 1
psiphon/dialParameters.go

@@ -502,6 +502,27 @@ func MakeDialParameters(
 		}
 		}
 	}
 	}
 
 
+	// Skip this candidate when the clients tactics restrict usage of the
+	// provider ID. See the corresponding server-side enforcement comments in
+	// server.sshClient.setHandshakeState.
+	if protocol.TunnelProtocolUsesInproxy(dialParams.TunnelProtocol) &&
+		common.ContainsAny(
+			p.KeyStrings(parameters.RestrictInproxyProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region}) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictInproxyProviderIDsClientProbability) {
+
+			// When skipping, return nil/nil as no error should be logged.
+			// NoticeSkipServerEntry emits each skip reason, regardless
+			// of server entry, at most once per session.
+
+			NoticeSkipServerEntry(
+				"restricted provider ID: %s",
+				dialParams.ServerEntry.ProviderID)
+
+			return nil, nil
+		}
+	}
+
 	// Skip this candidate when the clients tactics restrict usage of the
 	// Skip this candidate when the clients tactics restrict usage of the
 	// fronting provider ID. See the corresponding server-side enforcement
 	// fronting provider ID. See the corresponding server-side enforcement
 	// comments in server.MeekServer.getSessionOrEndpoint.
 	// comments in server.MeekServer.getSessionOrEndpoint.
@@ -983,6 +1004,7 @@ func MakeDialParameters(
 		var holdOffTunnelDuration time.Duration
 		var holdOffTunnelDuration time.Duration
 		var holdOffTunnelFrontingDuration time.Duration
 		var holdOffTunnelFrontingDuration time.Duration
 		var holdOffDirectTunnelDuration time.Duration
 		var holdOffDirectTunnelDuration time.Duration
+		var holdOffInproxyTunnelDuration time.Duration
 
 
 		if common.Contains(
 		if common.Contains(
 			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) {
 			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) {
@@ -1020,11 +1042,24 @@ func MakeDialParameters(
 			}
 			}
 		}
 		}
 
 
+		if protocol.TunnelProtocolUsesInproxy(dialParams.TunnelProtocol) &&
+			common.ContainsAny(
+				p.KeyStrings(parameters.HoldOffInproxyTunnelProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region}) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffInproxyTunnelProbability) {
+
+				holdOffInproxyTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffInproxyTunnelMinDuration),
+					p.Duration(parameters.HoldOffInproxyTunnelMaxDuration))
+			}
+		}
+
 		// Use the longest hold off duration
 		// Use the longest hold off duration
 		dialParams.HoldOffTunnelDuration = common.MaxDuration(
 		dialParams.HoldOffTunnelDuration = common.MaxDuration(
 			holdOffTunnelDuration,
 			holdOffTunnelDuration,
 			holdOffTunnelFrontingDuration,
 			holdOffTunnelFrontingDuration,
-			holdOffDirectTunnelDuration)
+			holdOffDirectTunnelDuration,
+			holdOffInproxyTunnelDuration)
 	}
 	}
 
 
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,

+ 48 - 2
psiphon/dialParameters_test.go

@@ -90,6 +90,12 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		holdOffDirectTunnelProviderRegions = map[string][]string{providerID: {""}}
 		holdOffDirectTunnelProviderRegions = map[string][]string{providerID: {""}}
 	}
 	}
 
 
+	var holdOffInproxyTunnelProviderRegions parameters.KeyStrings
+	if protocol.TunnelProtocolUsesInproxy(tunnelProtocol) &&
+		protocol.TunnelProtocolMinusInproxy(tunnelProtocol) == protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
+		holdOffInproxyTunnelProviderRegions = map[string][]string{providerID: {""}}
+	}
+
 	applyParameters := make(map[string]interface{})
 	applyParameters := make(map[string]interface{})
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
@@ -104,6 +110,10 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffDirectTunnelMinDuration] = "1ms"
 	applyParameters[parameters.HoldOffDirectTunnelMinDuration] = "1ms"
 	applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = "10ms"
 	applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = "10ms"
 	applyParameters[parameters.HoldOffDirectTunnelProviderRegions] = holdOffDirectTunnelProviderRegions
 	applyParameters[parameters.HoldOffDirectTunnelProviderRegions] = holdOffDirectTunnelProviderRegions
+	applyParameters[parameters.HoldOffInproxyTunnelProbability] = 1.0
+	applyParameters[parameters.HoldOffInproxyTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffInproxyTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffInproxyTunnelProviderRegions] = holdOffInproxyTunnelProviderRegions
 	applyParameters[parameters.HoldOffDirectTunnelProbability] = 1.0
 	applyParameters[parameters.HoldOffDirectTunnelProbability] = 1.0
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
@@ -254,10 +264,15 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		common.ContainsAny(
 		common.ContainsAny(
 			holdOffDirectTunnelProviderRegions[dialParams.ServerEntry.ProviderID],
 			holdOffDirectTunnelProviderRegions[dialParams.ServerEntry.ProviderID],
 			[]string{"", dialParams.ServerEntry.Region})
 			[]string{"", dialParams.ServerEntry.Region})
+	expectHoldOffInproxyTunnelProviderRegion := protocol.TunnelProtocolUsesInproxy(tunnelProtocol) &&
+		common.ContainsAny(
+			holdOffInproxyTunnelProviderRegions[dialParams.ServerEntry.ProviderID],
+			[]string{"", dialParams.ServerEntry.Region})
 
 
 	if expectHoldOffTunnelProtocols ||
 	if expectHoldOffTunnelProtocols ||
 		expectHoldOffTunnelFrontingProviderIDs ||
 		expectHoldOffTunnelFrontingProviderIDs ||
-		expectHoldOffDirectTunnelProviderRegion {
+		expectHoldOffDirectTunnelProviderRegion ||
+		expectHoldOffInproxyTunnelProviderRegion {
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
@@ -560,7 +575,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
 
 
-	// Test: client-side restrict provider ID by region
+	// Test: client-side restrict provider ID by region for direct protocols
 
 
 	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{providerID: {"CA"}}
 	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{providerID: {"CA"}}
 	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
 	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
@@ -591,6 +606,37 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
 
 
+	// Test: client-side restrict provider ID by region for inproxy protocols
+
+	applyParameters[parameters.RestrictInproxyProviderRegions] = map[string][]string{providerID: {"CA"}}
+	applyParameters[parameters.RestrictInproxyProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag8", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(
+		clientConfig, steeringIPCache, nil, nil, nil, canReplay, selectProtocol, serverEntries[0], nil, nil, false, 0, 0)
+
+	if protocol.TunnelProtocolUsesInproxy(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictInproxyProviderRegions] = map[string][]string{}
+	applyParameters[parameters.RestrictInproxyProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag9", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
 	if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
 	if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
 
 
 		steeringIPCache.Flush()
 		steeringIPCache.Flush()

+ 7 - 1
psiphon/server/api.go

@@ -119,8 +119,14 @@ func sshAPIRequestHandler(
 	switch name {
 	switch name {
 
 
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(
+		responsePayload, err := handshakeAPIRequestHandler(
 			support, protocol.PSIPHON_API_PROTOCOL_SSH, sshClient, params)
 			support, protocol.PSIPHON_API_PROTOCOL_SSH, sshClient, params)
+		if err != nil {
+			// Handshake failed, disconnect the client.
+			sshClient.stop()
+			return nil, errors.Trace(err)
+		}
+		return responsePayload, nil
 
 
 	case protocol.PSIPHON_API_CONNECTED_REQUEST_NAME:
 	case protocol.PSIPHON_API_CONNECTED_REQUEST_NAME:
 		return connectedAPIRequestHandler(
 		return connectedAPIRequestHandler(

+ 3 - 1
psiphon/server/listener.go

@@ -117,7 +117,9 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 	// peer IP is not the original client IP. Indirect protocols must determine
 	// peer IP is not the original client IP. Indirect protocols must determine
 	// the original client IP before applying GeoIP specific tactics; see the
 	// the original client IP before applying GeoIP specific tactics; see the
 	// server-side enforcement of RestrictFrontingProviderIDs for fronted meek
 	// server-side enforcement of RestrictFrontingProviderIDs for fronted meek
-	// in server.MeekServer.getSessionOrEndpoint.
+	// in server.MeekServer.getSessionOrEndpoint or of
+	// RestrictInproxyProviderRegions for inproxy in
+	// server.sshClient.setHandshakeState.
 	//
 	//
 	// At this stage, GeoIP tactics filters are active, but handshake API
 	// At this stage, GeoIP tactics filters are active, but handshake API
 	// parameters are not.
 	// parameters are not.

+ 48 - 9
psiphon/server/server_test.go

@@ -369,6 +369,17 @@ func TestInproxyOSSH(t *testing.T) {
 		})
 		})
 }
 }
 
 
+func TestRestrictInproxy(t *testing.T) {
+	if !inproxy.Enabled() {
+		t.Skip("inproxy is not enabled")
+	}
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:    "INPROXY-WEBRTC-OSSH",
+			doRestrictInproxy: true,
+		})
+}
+
 func TestInproxyQUICOSSH(t *testing.T) {
 func TestInproxyQUICOSSH(t *testing.T) {
 	if !quic.Enabled() {
 	if !quic.Enabled() {
 		t.Skip("QUIC is not enabled")
 		t.Skip("QUIC is not enabled")
@@ -690,6 +701,7 @@ type runServerConfig struct {
 	doTargetBrokerSpecs      bool
 	doTargetBrokerSpecs      bool
 	useLegacyAPIEncoding     bool
 	useLegacyAPIEncoding     bool
 	doPersonalPairing        bool
 	doPersonalPairing        bool
+	doRestrictInproxy        bool
 }
 }
 
 
 var (
 var (
@@ -933,7 +945,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			runConfig.applyPrefix,
 			runConfig.applyPrefix,
 			runConfig.forceFragmenting,
 			runConfig.forceFragmenting,
 			"classic",
 			"classic",
-			inproxyTacticsParametersJSON)
+			inproxyTacticsParametersJSON,
+			runConfig.doRestrictInproxy)
 	}
 	}
 
 
 	blocklistFilename := filepath.Join(testDataDirName, "blocklist.csv")
 	blocklistFilename := filepath.Join(testDataDirName, "blocklist.csv")
@@ -1203,7 +1216,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				runConfig.applyPrefix,
 				runConfig.applyPrefix,
 				runConfig.forceFragmenting,
 				runConfig.forceFragmenting,
 				"consistent",
 				"consistent",
-				inproxyTacticsParametersJSON)
+				inproxyTacticsParametersJSON,
+				runConfig.doRestrictInproxy)
 		}
 		}
 
 
 		p, _ := os.FindProcess(os.Getpid())
 		p, _ := os.FindProcess(os.Getpid())
@@ -1510,6 +1524,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	pruneServerEntriesNoticesEmitted := make(chan struct{}, 1)
 	pruneServerEntriesNoticesEmitted := make(chan struct{}, 1)
 	serverAlertDisallowedNoticesEmitted := make(chan struct{}, 1)
 	serverAlertDisallowedNoticesEmitted := make(chan struct{}, 1)
 	untunneledPortForward := make(chan struct{}, 1)
 	untunneledPortForward := make(chan struct{}, 1)
+	discardTunnel := make(chan struct{}, 1)
 
 
 	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 		func(notice []byte) {
@@ -1581,6 +1596,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				if connectedClients == 1 && bytesUp > 0 && bytesDown > 0 {
 				if connectedClients == 1 && bytesUp > 0 && bytesDown > 0 {
 					sendNotificationReceived(inproxyActivity)
 					sendNotificationReceived(inproxyActivity)
 				}
 				}
+
+			case "Info":
+				if strings.Contains(payload["message"].(string), "discard tunnel") {
+					sendNotificationReceived(discardTunnel)
+				}
 			}
 			}
 
 
 			if printNotice {
 			if printNotice {
@@ -1633,12 +1653,19 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		close(timeoutSignal)
 		close(timeoutSignal)
 	}()
 	}()
 
 
-	waitOnNotification(t, connectedServer, timeoutSignal, "connected server timeout exceeded")
-	if doInproxy {
-		waitOnNotification(t, inproxyActivity, timeoutSignal, "inproxy activity timeout exceeded")
+	expectDiscardTunnel := runConfig.doRestrictInproxy
+
+	if expectDiscardTunnel {
+		waitOnNotification(t, discardTunnel, timeoutSignal, "discard tunnel timeout exceeded")
+		return
+	} else {
+		waitOnNotification(t, connectedServer, timeoutSignal, "connected server timeout exceeded")
+		if doInproxy {
+			waitOnNotification(t, inproxyActivity, timeoutSignal, "inproxy activity timeout exceeded")
+		}
+		waitOnNotification(t, tunnelsEstablished, timeoutSignal, "tunnel established timeout exceeded")
+		waitOnNotification(t, homepageReceived, timeoutSignal, "homepage received timeout exceeded")
 	}
 	}
-	waitOnNotification(t, tunnelsEstablished, timeoutSignal, "tunnel established timeout exceeded")
-	waitOnNotification(t, homepageReceived, timeoutSignal, "homepage received timeout exceeded")
 
 
 	if runConfig.doChangeBytesConfig {
 	if runConfig.doChangeBytesConfig {
 
 
@@ -1672,7 +1699,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			runConfig.applyPrefix,
 			runConfig.applyPrefix,
 			runConfig.forceFragmenting,
 			runConfig.forceFragmenting,
 			"consistent",
 			"consistent",
-			inproxyTacticsParametersJSON)
+			inproxyTacticsParametersJSON,
+			runConfig.doRestrictInproxy)
 
 
 		p, _ := os.FindProcess(os.Getpid())
 		p, _ := os.FindProcess(os.Getpid())
 		p.Signal(syscall.SIGUSR1)
 		p.Signal(syscall.SIGUSR1)
@@ -3451,7 +3479,8 @@ func paveTacticsConfigFile(
 	applyOsshPrefix bool,
 	applyOsshPrefix bool,
 	enableOsshPrefixFragmenting bool,
 	enableOsshPrefixFragmenting bool,
 	discoveryStategy string,
 	discoveryStategy string,
-	inproxyParametersJSON string) {
+	inproxyParametersJSON string,
+	doRestrictAllInproxyProviderRegions bool) {
 
 
 	// Setting LimitTunnelProtocols passively exercises the
 	// Setting LimitTunnelProtocols passively exercises the
 	// server-side LimitTunnelProtocols enforcement.
 	// server-side LimitTunnelProtocols enforcement.
@@ -3470,6 +3499,7 @@ func paveTacticsConfigFile(
           %s
           %s
           %s
           %s
           %s
           %s
+          %s
           "LimitTunnelProtocols" : ["%s"],
           "LimitTunnelProtocols" : ["%s"],
           "FragmentorLimitProtocols" : ["%s"],
           "FragmentorLimitProtocols" : ["%s"],
           "FragmentorProbability" : 1.0,
           "FragmentorProbability" : 1.0,
@@ -3577,6 +3607,14 @@ func paveTacticsConfigFile(
 	`, strconv.FormatBool(enableOsshPrefixFragmenting))
 	`, strconv.FormatBool(enableOsshPrefixFragmenting))
 	}
 	}
 
 
+	restrictInproxyParameters := ""
+	if doRestrictAllInproxyProviderRegions {
+		restrictInproxyParameters = `
+		"RestrictInproxyProviderRegions": {"" : [""]},
+		"RestrictInproxyProviderIDsServerProbability": 1.0,
+	`
+	}
+
 	tacticsConfigJSON := fmt.Sprintf(
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey,
 		tacticsRequestPublicKey,
@@ -3587,6 +3625,7 @@ func paveTacticsConfigFile(
 		legacyDestinationBytesParameters,
 		legacyDestinationBytesParameters,
 		osshPrefix,
 		osshPrefix,
 		inproxyParametersJSON,
 		inproxyParametersJSON,
+		restrictInproxyParameters,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,

+ 40 - 0
psiphon/server/tunnelServer.go

@@ -3784,6 +3784,46 @@ func (sshClient *sshClient) setHandshakeState(
 		return nil, errors.TraceNew("handshake already completed")
 		return nil, errors.TraceNew("handshake already completed")
 	}
 	}
 
 
+	if sshClient.isInproxyTunnelProtocol {
+
+		p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.clientGeoIPData)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// Skip check if no tactics are configured.
+		//
+		// Disconnect immediately if the tactics for the client restricts usage
+		// of the provider ID with inproxy protocols. The probability may be
+		// used to influence usage of a given provider with inproxy protocols;
+		// but when only that provider works for a given client, and the
+		// probability is less than 1.0, the client can retry until it gets a
+		// successful coin flip.
+		//
+		// Clients will also skip inproxy protocol candidates with restricted
+		// provider IDs.
+		// The client-side probability,
+		// RestrictInproxyProviderIDsClientProbability, is applied
+		// independently of the server-side coin flip here.
+		//
+		// At this stage, GeoIP tactics filters are active, but handshake API
+		// parameters are not.
+		//
+		// See the comment in server.LoadConfig regarding provider ID
+		// limitations.
+		if !p.IsNil() &&
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictInproxyProviderRegions,
+					sshClient.sshServer.support.Config.GetProviderID()),
+				[]string{"", sshClient.sshServer.support.Config.GetRegion()}) {
+
+			if p.WeightedCoinFlip(
+				parameters.RestrictInproxyProviderIDsServerProbability) {
+				return nil, errRestrictedProvider
+			}
+		}
+	}
+
 	// Verify the authorizations submitted by the client. Verified, active
 	// Verify the authorizations submitted by the client. Verified, active
 	// (non-expired) access types will be available for traffic rules
 	// (non-expired) access types will be available for traffic rules
 	// filtering.
 	// filtering.

+ 2 - 1
psiphon/serverApi.go

@@ -142,7 +142,8 @@ func (serverContext *ServerContext) doHandshakeRequest(ignoreStatsRegexps bool)
 	// The purpose of this mechanism is to rapidly add provider IDs to the
 	// The purpose of this mechanism is to rapidly add provider IDs to the
 	// server entries in client local storage, and to ensure that the client has
 	// server entries in client local storage, and to ensure that the client has
 	// a provider ID for its currently connected server as required for the
 	// a provider ID for its currently connected server as required for the
-	// RestrictDirectProviderRegions, and HoldOffDirectTunnelProviderRegions
+	// RestrictDirectProviderRegions, HoldOffDirectTunnelProviderRegions,
+	// RestrictInproxyProviderRegions, and HoldOffInproxyTunnelProviderRegions
 	// tactics.
 	// tactics.
 	//
 	//
 	// The server entry will be included in handshakeResponse.EncodedServerList,
 	// The server entry will be included in handshakeResponse.EncodedServerList,