Просмотр исходного кода

Add ProviderIDs to traffic rules filters

Rod Hynes 10 месяцев назад
Родитель
Сommit
b835480b82

+ 27 - 0
psiphon/server/trafficRules.go

@@ -203,11 +203,18 @@ type TrafficRulesFilter struct {
 	// revoked. When omitted or false, this field is ignored.
 	AuthorizationsRevoked bool
 
+	// ProviderIDs specifies a list of server host providers which match this
+	// filter. When ProviderIDs is not empty, the current server will apply
+	// the filter only if its provider ID, from Config.GetProviderID, is in
+	// ProviderIDs.
+	ProviderIDs []string
+
 	regionLookup                map[string]bool
 	ispLookup                   map[string]bool
 	asnLookup                   map[string]bool
 	cityLookup                  map[string]bool
 	activeAuthorizationIDLookup map[string]bool
+	providerIDLookup            map[string]bool
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -555,6 +562,13 @@ func (set *TrafficRulesSet) initLookups() {
 				filter.activeAuthorizationIDLookup[ID] = true
 			}
 		}
+
+		if len(filter.ProviderIDs) >= stringLookupThreshold {
+			filter.providerIDLookup = make(map[string]bool)
+			for _, ID := range filter.ProviderIDs {
+				filter.providerIDLookup[ID] = true
+			}
+		}
 	}
 
 	initTrafficRulesLookups(&set.DefaultRules)
@@ -574,6 +588,7 @@ func (set *TrafficRulesSet) initLookups() {
 // For the return value TrafficRules, all pointer and slice fields are initialized,
 // so nil checks are not required. The caller must not modify the returned TrafficRules.
 func (set *TrafficRulesSet) GetTrafficRules(
+	serverProviderID string,
 	isFirstTunnelInSession bool,
 	tunnelProtocol string,
 	geoIPData GeoIPData,
@@ -818,6 +833,18 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			}
 		}
 
+		if len(filter.ProviderIDs) > 0 {
+			if filter.providerIDLookup != nil {
+				if !filter.providerIDLookup[serverProviderID] {
+					return false
+				}
+			} else {
+				if !common.Contains(filter.ProviderIDs, serverProviderID) {
+					return false
+				}
+			}
+		}
+
 		return true
 	}
 

+ 34 - 7
psiphon/server/trafficRules_test.go

@@ -49,6 +49,19 @@ func TestTrafficRulesFilters(t *testing.T) {
   
         {
           "Filter" : {
+            "ProviderIDs" : ["H2"]
+          },
+          "Rules" : {
+            "RateLimits" : {
+              "WriteBytesPerSecond": 99,
+              "ReadBytesPerSecond": 99
+            }
+          }
+        },
+
+        {
+          "Filter" : {
+            "ProviderIDs" : ["H1"],
             "Regions" : ["R2"],
             "HandshakeParameters" : {
                 "client_version" : ["1"]
@@ -142,8 +155,12 @@ func TestTrafficRulesFilters(t *testing.T) {
 		return p
 	}
 
+	// should never get 1st filtered rule with different provider ID
+	providerID := "H1"
+
 	testCases := []struct {
 		description                   string
+		providerID                    string
 		isFirstTunnelInSession        bool
 		tunnelProtocol                string
 		geoIPData                     GeoIPData
@@ -157,6 +174,7 @@ func TestTrafficRulesFilters(t *testing.T) {
 	}{
 		{
 			"get defaults",
+			providerID,
 			true,
 			"P1",
 			GeoIPData{Country: "R1", ISP: "I1"},
@@ -166,6 +184,7 @@ func TestTrafficRulesFilters(t *testing.T) {
 
 		{
 			"get defaults for not first tunnel in session",
+			providerID,
 			false,
 			"P1",
 			GeoIPData{Country: "R1", ISP: "I1"},
@@ -174,7 +193,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"get first filtered rule",
+			"get 2nd filtered rule (including provider ID)",
+			providerID,
 			true,
 			"P1",
 			GeoIPData{Country: "R2", ISP: "I1"},
@@ -183,7 +203,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"don't get first filtered rule with incomplete match",
+			"don't get 2nd filtered rule with incomplete match",
+			providerID,
 			true,
 			"P1",
 			GeoIPData{Country: "R2", ISP: "I1"},
@@ -192,7 +213,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"get second filtered rule",
+			"get 3rd filtered rule",
+			providerID,
 			true,
 			"P2",
 			GeoIPData{Country: "R3", ISP: "I1"},
@@ -201,7 +223,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"get second filtered rule with incomplete exception",
+			"get 3rd filtered rule with incomplete exception",
+			providerID,
 			true,
 			"P2",
 			GeoIPData{Country: "R3", ISP: "I2"},
@@ -210,7 +233,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"don't get second filtered rule due to exception",
+			"don't get 3rd filtered rule due to exception",
+			providerID,
 			true,
 			"P2",
 			GeoIPData{Country: "R3", ISP: "I2"},
@@ -219,7 +243,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"get third filtered rule",
+			"get 4th filtered rule",
+			providerID,
 			true,
 			"P1",
 			GeoIPData{Country: "R3", ISP: "I1"},
@@ -228,7 +253,8 @@ func TestTrafficRulesFilters(t *testing.T) {
 		},
 
 		{
-			"don't get third filtered rule due to exception",
+			"don't get 4th filtered rule due to exception",
+			providerID,
 			true,
 			"P1",
 			GeoIPData{Country: "R3", ISP: "I2"},
@@ -240,6 +266,7 @@ func TestTrafficRulesFilters(t *testing.T) {
 		t.Run(testCase.description, func(t *testing.T) {
 
 			rules := trafficRules.GetTrafficRules(
+				testCase.providerID,
 				testCase.isFirstTunnelInSession,
 				testCase.tunnelProtocol,
 				testCase.geoIPData,

+ 1 - 0
psiphon/server/tunnelServer.go

@@ -4629,6 +4629,7 @@ func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	// broker.
 
 	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
+		sshClient.sshServer.support.Config.GetProviderID(),
 		isFirstTunnelInSession,
 		sshClient.tunnelProtocol,
 		sshClient.clientGeoIPData,