Преглед изворни кода

Add fast lookups for traffic rules filters and allowed ports

- Also fix bugs in tactics fast lookups
Rod Hynes пре 6 година
родитељ
комит
a1485be9dc

+ 7 - 7
psiphon/common/tactics/tactics.go

@@ -450,8 +450,6 @@ func NewServer(
 				return common.ContextError(err)
 			}
 
-			newServer.initLookups()
-
 			// Modify actual traffic rules only after validation
 			server.RequestPublicKey = newServer.RequestPublicKey
 			server.RequestPrivateKey = newServer.RequestPrivateKey
@@ -459,6 +457,8 @@ func NewServer(
 			server.DefaultTactics = newServer.DefaultTactics
 			server.FilteredTactics = newServer.FilteredTactics
 
+			server.initLookups()
+
 			server.loaded = true
 
 			return nil
@@ -571,7 +571,7 @@ func (server *Server) Validate() error {
 	return nil
 }
 
-const lookupThreshold = 5
+const stringLookupThreshold = 5
 
 // initLookups creates map lookups for filters where the number
 // of string values to compare against exceeds a threshold where
@@ -581,17 +581,17 @@ func (server *Server) initLookups() {
 
 	for _, filteredTactics := range server.FilteredTactics {
 
-		if len(filteredTactics.Filter.Regions) >= lookupThreshold {
+		if len(filteredTactics.Filter.Regions) >= stringLookupThreshold {
 			filteredTactics.Filter.regionLookup = make(map[string]bool)
 			for _, region := range filteredTactics.Filter.Regions {
 				filteredTactics.Filter.regionLookup[region] = true
 			}
 		}
 
-		if len(filteredTactics.Filter.ISPs) >= lookupThreshold {
-			filteredTactics.Filter.regionLookup = make(map[string]bool)
+		if len(filteredTactics.Filter.ISPs) >= stringLookupThreshold {
+			filteredTactics.Filter.ispLookup = make(map[string]bool)
 			for _, ISP := range filteredTactics.Filter.ISPs {
-				filteredTactics.Filter.regionLookup[ISP] = true
+				filteredTactics.Filter.ispLookup[ISP] = true
 			}
 		}
 

+ 2 - 2
psiphon/common/tactics/tactics_test.go

@@ -116,8 +116,8 @@ func TestTactics(t *testing.T) {
       ]
     }
     `
-	if lookupThreshold != 5 {
-		t.Fatalf("unexpected lookupThreshold")
+	if stringLookupThreshold != 5 {
+		t.Fatalf("unexpected stringLookupThreshold")
 	}
 
 	encodedRequestPublicKey, encodedRequestPrivateKey, encodedObfuscatedKey, err := GenerateKeys()

+ 6 - 1
psiphon/server/server_test.go

@@ -1577,8 +1577,13 @@ func paveTrafficRulesFile(
 	requireAuthorization, deny bool,
 	livenessTestSize int) {
 
+	// Test both default and fast lookups
+	if intLookupThreshold != 10 {
+		t.Fatalf("unexpected intLookupThreshold")
+	}
+
 	allowTCPPorts := fmt.Sprintf("%d", mockWebServerPort)
-	allowUDPPorts := "53, 123"
+	allowUDPPorts := "53, 123, 10001, 10002, 10003, 10004, 10005, 10006, 10007, 10008, 10009, 10010"
 
 	if deny {
 		allowTCPPorts = "0"

+ 134 - 5
psiphon/server/trafficRules.go

@@ -147,6 +147,9 @@ type TrafficRulesFilter struct {
 	// must have been revoked. When true, authorizations must have been
 	// revoked. When omitted or false, this field is ignored.
 	AuthorizationsRevoked bool
+
+	regionLookup map[string]bool
+	ispLookup    map[string]bool
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -218,8 +221,11 @@ type TrafficRules struct {
 	// in CIDR notation.
 	// Limitation: currently, AllowSubnets only matches port
 	// forwards where the client sends an IP address. Domain
-	// names aren not resolved before checking AllowSubnets.
+	// names are not resolved before checking AllowSubnets.
 	AllowSubnets []string
+
+	allowTCPPortsLookup map[int]bool
+	allowUDPPortsLookup map[int]bool
 }
 
 // RateLimits is a clone of common.RateLimits with pointers
@@ -280,6 +286,8 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 			set.DefaultRules = newSet.DefaultRules
 			set.FilteredRules = newSet.FilteredRules
 
+			set.initLookups()
+
 			return nil
 		})
 
@@ -366,6 +374,58 @@ func (set *TrafficRulesSet) Validate() error {
 	return nil
 }
 
+const stringLookupThreshold = 5
+const intLookupThreshold = 10
+
+// initLookups creates map lookups for filters where the number of string/int
+// values to compare against exceeds a threshold where benchmarks show maps
+// are faster than looping through a string/int slice.
+func (set *TrafficRulesSet) initLookups() {
+
+	initTrafficRulesLookups := func(rules *TrafficRules) {
+
+		if len(rules.AllowTCPPorts) >= intLookupThreshold {
+			rules.allowTCPPortsLookup = make(map[int]bool)
+			for _, port := range rules.AllowTCPPorts {
+				rules.allowTCPPortsLookup[port] = true
+			}
+		}
+
+		if len(rules.AllowUDPPorts) >= intLookupThreshold {
+			rules.allowUDPPortsLookup = make(map[int]bool)
+			for _, port := range rules.AllowUDPPorts {
+				rules.allowUDPPortsLookup[port] = true
+			}
+		}
+	}
+
+	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
+
+		if len(filter.Regions) >= stringLookupThreshold {
+			filter.regionLookup = make(map[string]bool)
+			for _, region := range filter.Regions {
+				filter.regionLookup[region] = true
+			}
+		}
+
+		if len(filter.ISPs) >= stringLookupThreshold {
+			filter.ispLookup = make(map[string]bool)
+			for _, ISP := range filter.ISPs {
+				filter.ispLookup[ISP] = true
+			}
+		}
+	}
+
+	initTrafficRulesLookups(&set.DefaultRules)
+
+	for i, _ := range set.FilteredRules {
+		initTrafficRulesFilterLookups(&set.FilteredRules[i].Filter)
+		initTrafficRulesLookups(&set.FilteredRules[i].Rules)
+	}
+
+	// TODO: add lookups for MeekRateLimiter?
+}
+
 // GetTrafficRules determines the traffic rules for a client based on its attributes.
 // For the return value TrafficRules, all pointer and slice fields are initialized,
 // so nil checks are not required. The caller must not modify the returned TrafficRules.
@@ -478,14 +538,26 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		}
 
 		if len(filteredRules.Filter.Regions) > 0 {
-			if !common.Contains(filteredRules.Filter.Regions, geoIPData.Country) {
-				continue
+			if filteredRules.Filter.regionLookup != nil {
+				if !filteredRules.Filter.regionLookup[geoIPData.Country] {
+					continue
+				}
+			} else {
+				if !common.Contains(filteredRules.Filter.Regions, geoIPData.Country) {
+					continue
+				}
 			}
 		}
 
 		if len(filteredRules.Filter.ISPs) > 0 {
-			if !common.Contains(filteredRules.Filter.ISPs, geoIPData.ISP) {
-				continue
+			if filteredRules.Filter.ispLookup != nil {
+				if !filteredRules.Filter.ispLookup[geoIPData.ISP] {
+					continue
+				}
+			} else {
+				if !common.Contains(filteredRules.Filter.ISPs, geoIPData.ISP) {
+					continue
+				}
 			}
 		}
 
@@ -593,10 +665,12 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 		if filteredRules.Rules.AllowTCPPorts != nil {
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
+			trafficRules.allowTCPPortsLookup = filteredRules.Rules.allowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowUDPPorts != nil {
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
+			trafficRules.allowUDPPortsLookup = filteredRules.Rules.allowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowSubnets != nil {
@@ -616,6 +690,61 @@ func (set *TrafficRulesSet) GetTrafficRules(
 	return trafficRules
 }
 
+func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
+
+	if len(rules.AllowTCPPorts) == 0 {
+		return true
+	}
+
+	if rules.allowTCPPortsLookup != nil {
+		if rules.allowTCPPortsLookup[port] == true {
+			return true
+		}
+	} else {
+		for _, allowPort := range rules.AllowTCPPorts {
+			if port == allowPort {
+				return true
+			}
+		}
+	}
+
+	return rules.allowSubnet(remoteIP)
+}
+
+func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
+
+	if len(rules.AllowUDPPorts) == 0 {
+		return true
+	}
+
+	if rules.allowUDPPortsLookup != nil {
+		if rules.allowUDPPortsLookup[port] == true {
+			return true
+		}
+	} else {
+		for _, allowPort := range rules.AllowUDPPorts {
+			if port == allowPort {
+				return true
+			}
+		}
+	}
+
+	return rules.allowSubnet(remoteIP)
+}
+
+func (rules *TrafficRules) allowSubnet(remoteIP net.IP) bool {
+
+	for _, subnet := range rules.AllowSubnets {
+		// Note: ignoring error as config has been validated
+		_, network, _ := net.ParseCIDR(subnet)
+		if network.Contains(remoteIP) {
+			return true
+		}
+	}
+
+	return false
+}
+
 // GetMeekRateLimiterConfig gets a snapshot of the meek rate limiter
 // configuration values.
 func (set *TrafficRulesSet) GetMeekRateLimiterConfig() (int, int, []string, []string, int, int) {

+ 6 - 23
psiphon/server/tunnelServer.go

@@ -2632,30 +2632,13 @@ func (sshClient *sshClient) isPortForwardPermitted(
 
 	// Traffic rules checks.
 
-	var allowPorts []int
-	if portForwardType == portForwardTypeTCP {
-		allowPorts = sshClient.trafficRules.AllowTCPPorts
-	} else {
-		allowPorts = sshClient.trafficRules.AllowUDPPorts
-	}
-
-	if len(allowPorts) == 0 {
-		return true
-	}
-
-	// TODO: faster lookup?
-	if len(allowPorts) > 0 {
-		for _, allowPort := range allowPorts {
-			if port == allowPort {
-				return true
-			}
+	switch portForwardType {
+	case portForwardTypeTCP:
+		if sshClient.trafficRules.AllowTCPPort(remoteIP, port) {
+			return true
 		}
-	}
-
-	for _, subnet := range sshClient.trafficRules.AllowSubnets {
-		// Note: ignoring error as config has been validated
-		_, network, _ := net.ParseCIDR(subnet)
-		if network.Contains(remoteIP) {
+	case portForwardTypeUDP:
+		if sshClient.trafficRules.AllowUDPPort(remoteIP, port) {
 			return true
 		}
 	}