Просмотр исходного кода

Add RestrictDirectProviderRegions

mirokuratczyk 2 лет назад
Родитель
Сommit
6bfc32dd2d

+ 2 - 0
psiphon/common/parameters/parameters.go

@@ -317,6 +317,7 @@ const (
 	HoldOffDirectServerEntryProviderRegions          = "HoldOffDirectServerEntryProviderRegions"
 	HoldOffDirectTunnelProbability                   = "HoldOffDirectTunnelProbability"
 	RestrictDirectProviderIDs                        = "RestrictDirectProviderIDs"
+	RestrictDirectProviderRegions                    = "RestrictDirectProviderRegions"
 	RestrictDirectProviderIDsServerProbability       = "RestrictDirectProviderIDsServerProbability"
 	RestrictDirectProviderIDsClientProbability       = "RestrictDirectProviderIDsClientProbability"
 	RestrictFrontingProviderIDs                      = "RestrictFrontingProviderIDs"
@@ -717,6 +718,7 @@ var defaultParameters = map[string]struct {
 	HoldOffDirectTunnelProbability:          {value: 0.0, minimum: 0.0},
 
 	RestrictDirectProviderIDs:                  {value: []string{}},
+	RestrictDirectProviderRegions:              {value: KeyStrings{}},
 	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 

+ 18 - 3
psiphon/config.go

@@ -818,7 +818,7 @@ type Config struct {
 	HoldOffTunnelFrontingProviderIDs     []string
 	HoldOffTunnelProbability             *float64
 
-	// HoldOffDirectTunnelMinDurationMilliseconds and other HoldOffDirectTunnel
+	// HoldOffDirectTunnelMinDurationMilliseconds and other HoldOffDirect
 	// fields are for testing purposes.
 	HoldOffDirectTunnelMinDurationMilliseconds *int
 	HoldOffDirectTunnelMaxDurationMilliseconds *int
@@ -826,9 +826,10 @@ type Config struct {
 	HoldOffDirectServerEntryProviderRegions    map[string][]string
 	HoldOffDirectTunnelProbability             *float64
 
-	// RestrictDirectProviderIDs and other RestrictDirectProviderIDs fields
-	// are for testing purposes.
+	// RestrictDirectProviderIDs and other RestrictDirect fields are for
+	// testing purposes.
 	RestrictDirectProviderIDs                  []string
+	RestrictDirectProviderRegions              map[string][]string
 	RestrictDirectProviderIDsClientProbability *float64
 
 	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
@@ -1964,6 +1965,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.RestrictDirectProviderIDs] = config.RestrictDirectProviderIDs
 	}
 
+	if len(config.RestrictDirectProviderRegions) > 0 {
+		applyParameters[parameters.RestrictDirectProviderRegions] = parameters.KeyStrings(config.RestrictDirectProviderRegions)
+	}
+
 	if config.RestrictDirectProviderIDsClientProbability != nil {
 		applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = *config.RestrictDirectProviderIDsClientProbability
 	}
@@ -2499,6 +2504,16 @@ func (config *Config) setDialParametersHash() {
 		}
 	}
 
+	if len(config.RestrictDirectProviderRegions) > 0 {
+		hash.Write([]byte("RestrictDirectProviderRegions"))
+		for providerID, regions := range config.RestrictDirectProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
 	if config.RestrictDirectProviderIDsClientProbability != nil {
 		hash.Write([]byte("RestrictDirectProviderIDsClientProbability"))
 		binary.Write(hash, binary.LittleEndian, *config.RestrictDirectProviderIDsClientProbability)

+ 4 - 2
psiphon/dialParameters.go

@@ -431,9 +431,11 @@ func MakeDialParameters(
 	// provider ID. See the corresponding server-side enforcement comments in
 	// server.TacticsListener.accept.
 	if protocol.TunnelProtocolIsDirect(dialParams.TunnelProtocol) &&
-		common.Contains(
+		(common.Contains(
 			p.Strings(parameters.RestrictDirectProviderIDs),
-			dialParams.ServerEntry.ProviderID) {
+			dialParams.ServerEntry.ProviderID) ||
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictDirectProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region})) {
 		if p.WeightedCoinFlip(
 			parameters.RestrictDirectProviderIDsClientProbability) {
 

+ 31 - 0
psiphon/dialParameters_test.go

@@ -571,6 +571,37 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		}
 	}
 
+	applyParameters[parameters.RestrictDirectProviderIDs] = []string{}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag7", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	// Test: client-side restrict provider ID by region
+
+	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{providerID: {"CA"}}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag6", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(clientConfig, nil, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
+
+	if protocol.TunnelProtocolIsDirect(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{}
 	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 0.0
 	err = clientConfig.SetParameters("tag7", false, applyParameters)
 	if err != nil {

+ 11 - 0
psiphon/server/config.go

@@ -464,6 +464,7 @@ type Config struct {
 	dumpProfilesOnStopEstablishTunnelsDone         int32
 	providerID                                     string
 	frontingProviderID                             string
+	region                                         string
 	runningProtocols                               []string
 }
 
@@ -542,6 +543,11 @@ func (config *Config) GetFrontingProviderID() string {
 	return config.frontingProviderID
 }
 
+// GetRegion returns the region associated with the server.
+func (config *Config) GetRegion() string {
+	return config.region
+}
+
 // GetRunningProtocols returns the list of protcols this server is running.
 // The caller must not mutate the return value.
 func (config *Config) GetRunningProtocols() []string {
@@ -733,6 +739,11 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		} else if config.frontingProviderID != serverEntry.FrontingProviderID {
 			return nil, errors.Tracef("unsupported multiple FrontingProviderID values")
 		}
+		if config.region == "" {
+			config.region = serverEntry.Region
+		} else if config.region != serverEntry.Region {
+			return nil, errors.Tracef("unsupported multiple Region values")
+		}
 	}
 
 	config.runningProtocols = []string{}

+ 11 - 8
psiphon/server/listener.go

@@ -20,6 +20,7 @@
 package server
 
 import (
+	std_errors "errors"
 	"net"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -30,9 +31,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
 
-type restrictedProviderError struct{}
-
-func (restrictedProviderError) Error() string { return "restricted provider" }
+var errRestrictedProvider = std_errors.New("restricted provider")
 
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // of certain tactics parameters to accepted connections. Tactics filtering is
@@ -114,20 +113,24 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 	// not be applied to clients using indirect protocols, where the immediate
 	// peer IP is not the original client IP. Indirect protocols must determine
 	// the original client IP before applying GeoIP specific tactics; see the
-	// server-side enforcement of RestrictDirectProviderIDs for fronted meek in
-	// server.MeekServer.getSessionOrEndpoint.
+	// server-side enforcement of RestrictFrontingProviderIDs for fronted meek
+	// in server.MeekServer.getSessionOrEndpoint.
 	//
 	// At this stage, GeoIP tactics filters are active, but handshake API
 	// parameters are not.
 	//
 	// See the comment in server.LoadConfig regarding provider ID limitations.
 	if protocol.TunnelProtocolIsDirect(listener.tunnelProtocol) &&
-		common.Contains(
+		(common.Contains(
 			p.Strings(parameters.RestrictDirectProviderIDs),
-			listener.support.Config.GetProviderID()) {
+			listener.support.Config.GetProviderID()) ||
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictDirectProviderRegions, listener.support.Config.GetProviderID()), []string{"", listener.support.Config.GetRegion()})) {
+
 		if p.WeightedCoinFlip(
 			parameters.RestrictDirectProviderIDsServerProbability) {
-			return nil, restrictedProviderError{}
+			conn.Close()
+			return nil, errRestrictedProvider
 		}
 	}
 

+ 1 - 1
psiphon/server/tunnelServer.go

@@ -705,7 +705,7 @@ func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, liste
 				log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
 				// Temporary error, keep running
 				continue
-			} else if _, ok := err.(restrictedProviderError); ok {
+			} else if std_errors.Is(err, errRestrictedProvider) {
 				log.WithTraceFields(LogFields{"error": err}).Error("accept rejected client")
 				// Restricted provider, keep running
 				continue

+ 2 - 1
psiphon/serverApi.go

@@ -140,7 +140,8 @@ func (serverContext *ServerContext) doHandshakeRequest(
 	// The purpose of this mechanism is to rapidly add provider IDs to the
 	// server entries in client local storage, and to ensure that the client has
 	// a provider ID for its currently connected server as required for the
-	// RestrictDirectProviderIDs and HoldOffDirectServerEntryProviderRegions tactics.
+	// RestrictDirectProviderIDs, RestrictDirectProviderRegions, and
+	// HoldOffDirectServerEntryProviderRegions tactics.
 	//
 	// The server entry will be included in handshakeResponse.EncodedServerList,
 	// along side discovery servers.