Rod Hynes 4 лет назад
Родитель
Сommit
ebe7b5b3ac

+ 15 - 1
psiphon/common/portlist.go

@@ -53,6 +53,9 @@ const lookupThreshold = 10
 // concurrently with Lookup and should be called immediately after
 // UnmarshalJSON and before performing lookups.
 func (p *PortList) OptimizeLookups() {
+	if p == nil {
+		return
+	}
 	// TODO: does the threshold take long ranges into account?
 	if len(p.portRanges) > lookupThreshold {
 		p.lookup = make(map[int]bool)
@@ -64,9 +67,20 @@ func (p *PortList) OptimizeLookups() {
 	}
 }
 
+// IsEmpty returns true for a nil PortList or a PortList with no entries.
+func (p *PortList) IsEmpty() bool {
+	if p == nil {
+		return true
+	}
+	return len(p.portRanges) == 0
+}
+
 // Lookup returns true if the specified port is in the port list and false
-// otherwise.
+// otherwise. Lookups on a nil PortList are allowed and return false.
 func (p *PortList) Lookup(port int) bool {
+	if p == nil {
+		return false
+	}
 	if p.lookup != nil {
 		return p.lookup[port]
 	}

+ 62 - 5
psiphon/common/portlist_test.go

@@ -21,12 +21,15 @@ package common
 
 import (
 	"encoding/json"
+	"strings"
 	"testing"
+	"unicode"
 )
 
 func TestPortList(t *testing.T) {
 
-	var p PortList
+	var p *PortList
+
 	err := json.Unmarshal([]byte("[1.5]"), &p)
 	if err == nil {
 		t.Fatalf("unexpected parse of float port number")
@@ -52,19 +55,73 @@ func TestPortList(t *testing.T) {
 		t.Fatalf("unexpected parse of invalid port range")
 	}
 
+	p = nil
+
+	if p.Lookup(1) != false {
+		t.Fatalf("unexpected nil PortList Lookup result")
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected nil PortList IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[1]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
 	s := struct {
-		List1 PortList
-		List2 PortList
+		List1 *PortList
+		List2 *PortList
 	}{}
 
-	jsonStruct := `
+	jsonString := `
     {
         "List1" : [1,2,[10,20],100,[1000,2000]],
         "List2" : [3,4,5,[300,400],1000,2000,[3000,3996],3997,3998,3999,4000]
     }
     `
 
-	err = json.Unmarshal([]byte(jsonStruct), &s)
+	err = json.Unmarshal([]byte(jsonString), &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	// Marshal and re-Unmarshal to exercise PortList.MarshalJSON.
+
+	jsonBytes, err := json.Marshal(s)
+	if err != nil {
+		t.Fatalf("Marshal failed: %v", err)
+	}
+
+	strip := func(s string) string {
+		return strings.Map(func(r rune) rune {
+			if unicode.IsSpace(r) {
+				return -1
+			}
+			return r
+		}, s)
+	}
+
+	if strip(jsonString) != strip(string(jsonBytes)) {
+
+		t.Fatalf("unexpected JSON encoding")
+	}
+
+	err = json.Unmarshal(jsonBytes, &s)
 	if err != nil {
 		t.Fatalf("Unmarshal failed: %v", err)
 	}

+ 6 - 6
psiphon/server/server_test.go

@@ -1998,12 +1998,12 @@ func paveTrafficRulesFile(
 
 	allowTCPPorts := TCPPorts
 	allowUDPPorts := UDPPorts
-	disallowTCPPorts := "0"
-	disallowUDPPorts := "0"
+	disallowTCPPorts := "1"
+	disallowUDPPorts := "1"
 
 	if deny {
-		allowTCPPorts = "0"
-		allowUDPPorts = "0"
+		allowTCPPorts = "1"
+		allowUDPPorts = "1"
 		disallowTCPPorts = TCPPorts
 		disallowUDPPorts = UDPPorts
 	}
@@ -2033,8 +2033,8 @@ func paveTrafficRulesFile(
                 "ReadUnthrottledBytes": %d,
                 "WriteUnthrottledBytes": %d
             },
-            "AllowTCPPorts" : [0],
-            "AllowUDPPorts" : [0],
+            "AllowTCPPorts" : [1],
+            "AllowUDPPorts" : [1],
             "MeekRateLimiterHistorySize" : 10,
             "MeekRateLimiterThresholdSeconds" : 1,
             "MeekRateLimiterGarbageCollectionTriggerCount" : 1,

+ 18 - 93
psiphon/server/trafficRules.go

@@ -236,21 +236,21 @@ type TrafficRules struct {
 
 	// AllowTCPPorts specifies a list of TCP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowTCPPorts []int
+	AllowTCPPorts *common.PortList
 
 	// AllowUDPPorts specifies a list of UDP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowUDPPorts []int
+	AllowUDPPorts *common.PortList
 
 	// DisallowTCPPorts specifies a list of TCP ports that are not permitted for
 	// port forwarding. DisallowTCPPorts takes priority over AllowTCPPorts and
 	// AllowSubnets.
-	DisallowTCPPorts []int
+	DisallowTCPPorts *common.PortList
 
 	// DisallowUDPPorts specifies a list of UDP ports that are not permitted for
 	// port forwarding. DisallowUDPPorts takes priority over AllowUDPPorts and
 	// AllowSubnets.
-	DisallowUDPPorts []int
+	DisallowUDPPorts *common.PortList
 
 	// AllowSubnets specifies a list of IP address subnets for which all TCP and
 	// UDP ports are allowed. This list is consulted if a port is disallowed by
@@ -261,11 +261,6 @@ type TrafficRules struct {
 	// client sends an IP address. Domain names are not resolved before checking
 	// AllowSubnets.
 	AllowSubnets []string
-
-	allowTCPPortsLookup    map[int]bool
-	allowUDPPortsLookup    map[int]bool
-	disallowTCPPortsLookup map[int]bool
-	disallowUDPPortsLookup map[int]bool
 }
 
 // RateLimits is a clone of common.RateLimits with pointers
@@ -434,33 +429,11 @@ func (set *TrafficRulesSet) initLookups() {
 
 	initTrafficRulesLookups := func(rules *TrafficRules) {
 
-		if len(rules.AllowTCPPorts) >= intLookupThreshold {
-			rules.allowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowTCPPorts {
-				rules.allowTCPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.AllowUDPPorts) >= intLookupThreshold {
-			rules.allowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowUDPPorts {
-				rules.allowUDPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.DisallowTCPPorts) >= intLookupThreshold {
-			rules.disallowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowTCPPorts {
-				rules.disallowTCPPortsLookup[port] = true
-			}
-		}
+		rules.AllowTCPPorts.OptimizeLookups()
+		rules.AllowUDPPorts.OptimizeLookups()
+		rules.DisallowTCPPorts.OptimizeLookups()
+		rules.DisallowUDPPorts.OptimizeLookups()
 
-		if len(rules.DisallowUDPPorts) >= intLookupThreshold {
-			rules.disallowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowUDPPorts {
-				rules.disallowUDPPortsLookup[port] = true
-			}
-		}
 	}
 
 	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
@@ -600,14 +573,6 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_MAX_UDP_PORT_FORWARD_COUNT)
 	}
 
-	if trafficRules.AllowTCPPorts == nil {
-		trafficRules.AllowTCPPorts = make([]int, 0)
-	}
-
-	if trafficRules.AllowUDPPorts == nil {
-		trafficRules.AllowUDPPorts = make([]int, 0)
-	}
-
 	if trafficRules.AllowSubnets == nil {
 		trafficRules.AllowSubnets = make([]string, 0)
 	}
@@ -800,22 +765,18 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 		if filteredRules.Rules.AllowTCPPorts != nil {
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
-			trafficRules.allowTCPPortsLookup = filteredRules.Rules.allowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowUDPPorts != nil {
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
-			trafficRules.allowUDPPortsLookup = filteredRules.Rules.allowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowTCPPorts != nil {
 			trafficRules.DisallowTCPPorts = filteredRules.Rules.DisallowTCPPorts
-			trafficRules.disallowTCPPortsLookup = filteredRules.Rules.disallowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowUDPPorts != nil {
 			trafficRules.DisallowUDPPorts = filteredRules.Rules.DisallowUDPPorts
-			trafficRules.disallowUDPPortsLookup = filteredRules.Rules.disallowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowSubnets != nil {
@@ -837,34 +798,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowTCPPorts) > 0 {
-		if rules.disallowTCPPortsLookup != nil {
-			if rules.disallowTCPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowTCPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowTCPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowTCPPorts) == 0 {
+	if rules.AllowTCPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowTCPPortsLookup != nil {
-		if rules.allowTCPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowTCPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowTCPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)
@@ -872,34 +815,16 @@ func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
 func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowUDPPorts) > 0 {
-		if rules.disallowUDPPortsLookup != nil {
-			if rules.disallowUDPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowUDPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowUDPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowUDPPorts) == 0 {
+	if rules.AllowUDPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowUDPPortsLookup != nil {
-		if rules.allowUDPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowUDPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowUDPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)