Просмотр исходного кода

Merge pull request #598 from rod-hynes/master

Add LimitTunnelDialPortNumbers
Rod Hynes 4 лет назад
Родитель
Сommit
80dfe1f85b

+ 49 - 0
psiphon/common/parameters/parameters.go

@@ -99,6 +99,8 @@ const (
 	InitialLimitTunnelProtocolsCandidateCount        = "InitialLimitTunnelProtocolsCandidateCount"
 	LimitTunnelProtocolsProbability                  = "LimitTunnelProtocolsProbability"
 	LimitTunnelProtocols                             = "LimitTunnelProtocols"
+	LimitTunnelDialPortNumbersProbability            = "LimitTunnelDialPortNumbersProbability"
+	LimitTunnelDialPortNumbers                       = "LimitTunnelDialPortNumbers"
 	LimitTLSProfilesProbability                      = "LimitTLSProfilesProbability"
 	LimitTLSProfiles                                 = "LimitTLSProfiles"
 	UseOnlyCustomTLSProfiles                         = "UseOnlyCustomTLSProfiles"
@@ -362,6 +364,9 @@ var defaultParameters = map[string]struct {
 	LimitTunnelProtocolsProbability: {value: 1.0, minimum: 0.0},
 	LimitTunnelProtocols:            {value: protocol.TunnelProtocols{}},
 
+	LimitTunnelDialPortNumbersProbability: {value: 1.0, minimum: 0.0},
+	LimitTunnelDialPortNumbers:            {value: TunnelProtocolPortLists{}},
+
 	LimitTLSProfilesProbability:           {value: 1.0, minimum: 0.0},
 	LimitTLSProfiles:                      {value: protocol.TLSProfiles{}},
 	UseOnlyCustomTLSProfiles:              {value: false},
@@ -931,6 +936,22 @@ func (p *Parameters) Set(
 					}
 					return nil, errors.Trace(err)
 				}
+			case FrontingSpecs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
+			case TunnelProtocolPortLists:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 
 			// Enforce any minimums. Assumes defaultParameters[name]
@@ -1389,3 +1410,31 @@ func (p ParametersAccessor) FrontingSpecs(name string) FrontingSpecs {
 	p.snapshot.getValue(name, &value)
 	return value
 }
+
+// TunnelProtocolPortLists returns a TunnelProtocolPortLists parameter value.
+func (p ParametersAccessor) TunnelProtocolPortLists(name string) TunnelProtocolPortLists {
+
+	probabilityName := name + "Probability"
+	_, ok := p.snapshot.parameters[probabilityName]
+	if ok {
+		probabilityValue := float64(1.0)
+		p.snapshot.getValue(probabilityName, &probabilityValue)
+		if !prng.FlipWeightedCoin(probabilityValue) {
+			defaultParameter, ok := defaultParameters[name]
+			if ok {
+				defaultValue, ok := defaultParameter.value.(TunnelProtocolPortLists)
+				if ok {
+					value := make(TunnelProtocolPortLists)
+					for tunnelProtocol, portLists := range defaultValue {
+						value[tunnelProtocol] = portLists
+					}
+					return value
+				}
+			}
+		}
+	}
+
+	value := make(TunnelProtocolPortLists)
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 5 - 0
psiphon/common/parameters/parameters_test.go

@@ -154,6 +154,11 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v)
 			}
+		case TunnelProtocolPortLists:
+			g := p.Get().TunnelProtocolPortLists(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v)
+			}
 		default:
 			t.Fatalf("Unhandled default type: %s", name)
 		}

+ 41 - 0
psiphon/common/parameters/portlist.go

@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// Validate checks that tunnel protocol names are valid.
+func (lists TunnelProtocolPortLists) Validate() error {
+	for tunnelProtocol, _ := range lists {
+		if tunnelProtocol != protocol.TUNNEL_PROTOCOLS_ALL &&
+			!common.Contains(protocol.SupportedTunnelProtocols, tunnelProtocol) {
+			return errors.TraceNew("invalid tunnel protocol for port list")
+		}
+	}
+	return nil
+}

+ 196 - 0
psiphon/common/portlist.go

@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"bytes"
+	"encoding/json"
+	"strconv"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// PortList provides a lookup for a configured list of IP ports and port
+// ranges. PortList is intended for use with JSON config files and is
+// initialized via UnmarshalJSON.
+//
+// A JSON port list field should look like:
+//
+// "FieldName": [1, 2, 3, [10, 20], [30, 40]]
+//
+// where the ports in the list are 1, 2, 3, 10-20, 30-40. UnmarshalJSON
+// validates that each port is in the range 1-65535 and that ranges have two
+// elements in increasing order. PortList is designed to be backwards
+// compatible with existing JSON config files where port list fields were
+// defined as `[]int`.
+type PortList struct {
+	portRanges [][2]int
+	lookup     map[int]bool
+}
+
+const lookupThreshold = 10
+
+// OptimizeLookups converts the internal port list representation to use a
+// map, which increases the performance of lookups for longer lists with an
+// increased memory footprint tradeoff. OptimizeLookups is not safe to use
+// concurrently with Lookup and should be called immediately after
+// UnmarshalJSON and before performing lookups.
+func (p *PortList) OptimizeLookups() {
+	if p == nil {
+		return
+	}
+	// TODO: does the threshold take long ranges into account?
+	if len(p.portRanges) > lookupThreshold {
+		p.lookup = make(map[int]bool)
+		for _, portRange := range p.portRanges {
+			for i := portRange[0]; i <= portRange[1]; i++ {
+				p.lookup[i] = true
+			}
+		}
+	}
+}
+
+// IsEmpty returns true for a nil PortList or a PortList with no entries.
+func (p *PortList) IsEmpty() bool {
+	if p == nil {
+		return true
+	}
+	return len(p.portRanges) == 0
+}
+
+// Lookup returns true if the specified port is in the port list and false
+// otherwise. Lookups on a nil PortList are allowed and return false.
+func (p *PortList) Lookup(port int) bool {
+	if p == nil {
+		return false
+	}
+	if p.lookup != nil {
+		return p.lookup[port]
+	}
+	for _, portRange := range p.portRanges {
+		if port >= portRange[0] && port <= portRange[1] {
+			return true
+		}
+	}
+	return false
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface.
+func (p *PortList) UnmarshalJSON(b []byte) error {
+
+	p.portRanges = nil
+	p.lookup = nil
+
+	if bytes.Equal(b, []byte("null")) {
+		return nil
+	}
+
+	decoder := json.NewDecoder(bytes.NewReader(b))
+	decoder.UseNumber()
+
+	var array []interface{}
+
+	err := decoder.Decode(&array)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	p.portRanges = make([][2]int, len(array))
+
+	for i, portRange := range array {
+
+		var startPort, endPort int64
+
+		if portNumber, ok := portRange.(json.Number); ok {
+
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+
+			startPort = port
+			endPort = port
+
+		} else if array, ok := portRange.([]interface{}); ok {
+
+			if len(array) != 2 {
+				return errors.TraceNew("invalid range size")
+			}
+
+			portNumber, ok := array[0].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err := portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			startPort = port
+
+			portNumber, ok = array[1].(json.Number)
+			if !ok {
+				return errors.TraceNew("invalid type")
+			}
+			port, err = portNumber.Int64()
+			if err != nil {
+				return errors.Trace(err)
+			}
+			endPort = port
+
+		} else {
+
+			return errors.TraceNew("invalid type")
+		}
+
+		if startPort < 1 || startPort > 65535 {
+			return errors.TraceNew("invalid range start")
+		}
+
+		if endPort < 1 || endPort > 65535 || endPort < startPort {
+			return errors.TraceNew("invalid range end")
+		}
+
+		p.portRanges[i] = [2]int{int(startPort), int(endPort)}
+	}
+
+	return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface.
+func (p *PortList) MarshalJSON() ([]byte, error) {
+	var json bytes.Buffer
+	json.WriteString("[")
+	for i, portRange := range p.portRanges {
+		if i > 0 {
+			json.WriteString(",")
+		}
+		if portRange[0] == portRange[1] {
+			json.WriteString(strconv.Itoa(portRange[0]))
+		} else {
+			json.WriteString("[")
+			json.WriteString(strconv.Itoa(portRange[0]))
+			json.WriteString(",")
+			json.WriteString(strconv.Itoa(portRange[1]))
+			json.WriteString("]")
+		}
+	}
+	json.WriteString("]")
+	return json.Bytes(), nil
+}

+ 222 - 0
psiphon/common/portlist_test.go

@@ -0,0 +1,222 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package common
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+	"unicode"
+)
+
+func TestPortList(t *testing.T) {
+
+	var p *PortList
+
+	err := json.Unmarshal([]byte("[1.5]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of float port number")
+	}
+
+	err = json.Unmarshal([]byte("[-1]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of negative port number")
+	}
+
+	err = json.Unmarshal([]byte("[0]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[65536]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port number")
+	}
+
+	err = json.Unmarshal([]byte("[[2,1]]"), &p)
+	if err == nil {
+		t.Fatalf("unexpected parse of invalid port range")
+	}
+
+	p = nil
+
+	if p.Lookup(1) != false {
+		t.Fatalf("unexpected nil PortList Lookup result")
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected nil PortList IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if !p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	err = json.Unmarshal([]byte("[1]"), &p)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	if p.IsEmpty() {
+		t.Fatalf("unexpected IsEmpty result")
+	}
+
+	s := struct {
+		List1 *PortList
+		List2 *PortList
+	}{}
+
+	jsonString := `
+    {
+        "List1" : [1,2,[10,20],100,[1000,2000]],
+        "List2" : [3,4,5,[300,400],1000,2000,[3000,3996],3997,3998,3999,4000]
+    }
+    `
+
+	err = json.Unmarshal([]byte(jsonString), &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	// Marshal and re-Unmarshal to exercise PortList.MarshalJSON.
+
+	jsonBytes, err := json.Marshal(s)
+	if err != nil {
+		t.Fatalf("Marshal failed: %v", err)
+	}
+
+	strip := func(s string) string {
+		return strings.Map(func(r rune) rune {
+			if unicode.IsSpace(r) {
+				return -1
+			}
+			return r
+		}, s)
+	}
+
+	if strip(jsonString) != strip(string(jsonBytes)) {
+
+		t.Fatalf("unexpected JSON encoding")
+	}
+
+	err = json.Unmarshal(jsonBytes, &s)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %v", err)
+	}
+
+	s.List1.OptimizeLookups()
+	if s.List1.lookup != nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	s.List2.OptimizeLookups()
+	if s.List2.lookup == nil {
+		t.Fatalf("unexpected lookup initialization")
+	}
+
+	for port := 0; port < 65536; port++ {
+
+		lookup1 := s.List1.Lookup(port)
+		expected1 := port == 1 ||
+			port == 2 ||
+			(port >= 10 && port <= 20) ||
+			port == 100 ||
+			(port >= 1000 && port <= 2000)
+		if lookup1 != expected1 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup1)
+		}
+
+		lookup2 := s.List2.Lookup(port)
+		expected2 := port == 3 ||
+			port == 4 ||
+			port == 5 ||
+			(port >= 300 && port <= 400) ||
+			port == 1000 || port == 2000 ||
+			(port >= 3000 && port <= 4000)
+		if lookup2 != expected2 {
+			t.Fatalf("unexpected port lookup: %d %v", port, lookup2)
+		}
+	}
+}
+
+func BenchmarkPortListLinear(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,[10,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup != nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}
+
+func BenchmarkPortListMap(b *testing.B) {
+
+	s := struct {
+		List PortList
+	}{}
+
+	jsonStruct := `
+    {
+        "List" : [1,2,3,4,5,6,7,8,9,10,[11,20]]
+    }
+    `
+
+	err := json.Unmarshal([]byte(jsonStruct), &s)
+	if err != nil {
+		b.Fatalf("Unmarshal failed: %v", err)
+	}
+	s.List.OptimizeLookups()
+	if s.List.lookup == nil {
+		b.Fatalf("unexpected lookup initialization")
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		for port := 0; port < 65536; port++ {
+			s.List.Lookup(port)
+		}
+	}
+}

+ 82 - 12
psiphon/common/protocol/serverEntry.go

@@ -477,51 +477,121 @@ type ConditionallyEnabledComponents interface {
 	RefractionNetworkingEnabled() bool
 }
 
-// GetSupportedProtocols returns a list of tunnel protocols supported
-// by the ServerEntry's capabilities.
+// TunnelProtocolPortLists is a map from tunnel protocol names (or "All") to a
+// list of port number ranges.
+type TunnelProtocolPortLists map[string]*common.PortList
+
+// GetSupportedProtocols returns a list of tunnel protocols supported by the
+// ServerEntry's capabilities and allowed by various constraints.
 func (serverEntry *ServerEntry) GetSupportedProtocols(
 	conditionallyEnabled ConditionallyEnabledComponents,
 	useUpstreamProxy bool,
 	limitTunnelProtocols []string,
+	limitTunnelDialPortNumbers TunnelProtocolPortLists,
 	excludeIntensive bool) []string {
 
 	supportedProtocols := make([]string, 0)
 
-	for _, protocol := range SupportedTunnelProtocols {
+	for _, tunnelProtocol := range SupportedTunnelProtocols {
 
-		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(protocol) {
+		if useUpstreamProxy && !TunnelProtocolSupportsUpstreamProxy(tunnelProtocol) {
 			continue
 		}
 
 		if len(limitTunnelProtocols) > 0 {
-			if !common.Contains(limitTunnelProtocols, protocol) {
+			if !common.Contains(limitTunnelProtocols, tunnelProtocol) {
 				continue
 			}
 		} else {
-			if common.Contains(DefaultDisabledTunnelProtocols, protocol) {
+			if common.Contains(DefaultDisabledTunnelProtocols, tunnelProtocol) {
 				continue
 			}
 		}
 
-		if excludeIntensive && TunnelProtocolIsResourceIntensive(protocol) {
+		if excludeIntensive && TunnelProtocolIsResourceIntensive(tunnelProtocol) {
 			continue
 		}
 
-		if (TunnelProtocolUsesQUIC(protocol) && !conditionallyEnabled.QUICEnabled()) ||
-			(TunnelProtocolUsesMarionette(protocol) && !conditionallyEnabled.MarionetteEnabled()) ||
-			(TunnelProtocolUsesRefractionNetworking(protocol) &&
+		if (TunnelProtocolUsesQUIC(tunnelProtocol) && !conditionallyEnabled.QUICEnabled()) ||
+			(TunnelProtocolUsesMarionette(tunnelProtocol) && !conditionallyEnabled.MarionetteEnabled()) ||
+			(TunnelProtocolUsesRefractionNetworking(tunnelProtocol) &&
 				!conditionallyEnabled.RefractionNetworkingEnabled()) {
 			continue
 		}
 
-		if serverEntry.SupportsProtocol(protocol) {
-			supportedProtocols = append(supportedProtocols, protocol)
+		if !serverEntry.SupportsProtocol(tunnelProtocol) {
+			continue
+		}
+
+		dialPortNumber, err := serverEntry.GetDialPortNumber(tunnelProtocol)
+		if err != nil {
+			continue
+		}
+
+		if len(limitTunnelDialPortNumbers) > 0 {
+			if portList, ok := limitTunnelDialPortNumbers[tunnelProtocol]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			} else if portList, ok := limitTunnelDialPortNumbers[TUNNEL_PROTOCOLS_ALL]; ok {
+				if !portList.Lookup(dialPortNumber) {
+					continue
+				}
+			}
 		}
 
+		supportedProtocols = append(supportedProtocols, tunnelProtocol)
+
 	}
 	return supportedProtocols
 }
 
+func (serverEntry *ServerEntry) GetDialPortNumber(tunnelProtocol string) (int, error) {
+
+	if !serverEntry.SupportsProtocol(tunnelProtocol) {
+		return 0, errors.TraceNew("protocol not supported")
+	}
+
+	switch tunnelProtocol {
+
+	case TUNNEL_PROTOCOL_SSH:
+		return serverEntry.SshPort, nil
+
+	case TUNNEL_PROTOCOL_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedPort, nil
+
+	case TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedTapDancePort, nil
+
+	case TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedConjurePort, nil
+
+	case TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
+		return serverEntry.SshObfuscatedQUICPort, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK,
+		TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
+		return 443, nil
+
+	case TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
+		return 80, nil
+
+	case TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,
+		TUNNEL_PROTOCOL_UNFRONTED_MEEK:
+		return serverEntry.MeekServerPort, nil
+
+	case TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+		// The port is encoded in the marionnete "format"
+		// Limitations:
+		// - not compatible with LimitDialPortNumbers
+		// - accurate port is not reported via dial_port_number
+		return -1, nil
+	}
+
+	return 0, errors.TraceNew("unknown protocol")
+}
+
 // GetSupportedTacticsProtocols returns a list of tunnel protocols,
 // supported by the ServerEntry's capabilities, that may be used
 // for tactics requests.

+ 18 - 4
psiphon/config.go

@@ -751,6 +751,9 @@ type Config struct {
 	// UpstreamProxyAllowAllServerEntrySources is for testing purposes.
 	UpstreamProxyAllowAllServerEntrySources *bool
 
+	// LimitTunnelDialPortNumbers is for testing purposes.
+	LimitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
 	//
@@ -1604,7 +1607,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UseOnlyCustomTLSProfiles] = *config.UseOnlyCustomTLSProfiles
 	}
 
-	if config.CustomTLSProfiles != nil {
+	if len(config.CustomTLSProfiles) > 0 {
 		applyParameters[parameters.CustomTLSProfiles] = config.CustomTLSProfiles
 	}
 
@@ -1616,7 +1619,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.NoDefaultTLSSessionIDProbability] = *config.NoDefaultTLSSessionIDProbability
 	}
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		applyParameters[parameters.DisableFrontingProviderTLSProfiles] = config.DisableFrontingProviderTLSProfiles
 	}
 
@@ -1660,7 +1663,7 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ConjureAPIRegistrarURL] = config.ConjureAPIRegistrarURL
 	}
 
-	if config.ConjureAPIRegistrarFrontingSpecs != nil {
+	if len(config.ConjureAPIRegistrarFrontingSpecs) > 0 {
 		applyParameters[parameters.ConjureAPIRegistrarFrontingSpecs] = config.ConjureAPIRegistrarFrontingSpecs
 	}
 
@@ -1720,6 +1723,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.UpstreamProxyAllowAllServerEntrySources] = *config.UpstreamProxyAllowAllServerEntrySources
 	}
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		applyParameters[parameters.LimitTunnelDialPortNumbers] = config.LimitTunnelDialPortNumbers
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 
@@ -1929,7 +1936,7 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.NoDefaultTLSSessionIDProbability)
 	}
 
-	if config.DisableFrontingProviderTLSProfiles != nil {
+	if len(config.DisableFrontingProviderTLSProfiles) > 0 {
 		hash.Write([]byte("DisableFrontingProviderTLSProfiles"))
 		encodedDisableFrontingProviderTLSProfiles, _ :=
 			json.Marshal(config.DisableFrontingProviderTLSProfiles)
@@ -2044,6 +2051,13 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.UpstreamProxyAllowAllServerEntrySources)
 	}
 
+	if len(config.LimitTunnelDialPortNumbers) > 0 {
+		hash.Write([]byte("LimitTunnelDialPortNumbers"))
+		encodedLimitTunnelDialPortNumbers, _ :=
+			json.Marshal(config.LimitTunnelDialPortNumbers)
+		hash.Write(encodedLimitTunnelDialPortNumbers)
+	}
+
 	config.dialParametersHash = hash.Sum(nil)
 }
 

+ 31 - 21
psiphon/controller.go

@@ -1358,15 +1358,16 @@ func (controller *Controller) triggerFetches() {
 }
 
 type protocolSelectionConstraints struct {
-	useUpstreamProxy                    bool
-	initialLimitProtocols               protocol.TunnelProtocols
-	initialLimitProtocolsCandidateCount int
-	limitProtocols                      protocol.TunnelProtocols
-	replayCandidateCount                int
+	useUpstreamProxy                          bool
+	initialLimitTunnelProtocols               protocol.TunnelProtocols
+	initialLimitTunnelProtocolsCandidateCount int
+	limitTunnelProtocols                      protocol.TunnelProtocols
+	limitTunnelDialPortNumbers                protocol.TunnelProtocolPortLists
+	replayCandidateCount                      int
 }
 
 func (p *protocolSelectionConstraints) hasInitialProtocols() bool {
-	return len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > 0
+	return len(p.initialLimitTunnelProtocols) > 0 && p.initialLimitTunnelProtocolsCandidateCount > 0
 }
 
 func (p *protocolSelectionConstraints) isInitialCandidate(
@@ -1377,7 +1378,8 @@ func (p *protocolSelectionConstraints) isInitialCandidate(
 		len(serverEntry.GetSupportedProtocols(
 			conditionallyEnabledComponents{},
 			p.useUpstreamProxy,
-			p.initialLimitProtocols,
+			p.initialLimitTunnelProtocols,
+			p.limitTunnelDialPortNumbers,
 			excludeIntensive)) > 0
 }
 
@@ -1385,11 +1387,12 @@ func (p *protocolSelectionConstraints) isCandidate(
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) bool {
 
-	return len(p.limitProtocols) == 0 ||
+	return len(p.limitTunnelProtocols) == 0 ||
 		len(serverEntry.GetSupportedProtocols(
 			conditionallyEnabledComponents{},
 			p.useUpstreamProxy,
-			p.limitProtocols,
+			p.limitTunnelProtocols,
+			p.limitTunnelDialPortNumbers,
 			excludeIntensive)) > 0
 }
 
@@ -1413,16 +1416,19 @@ func (p *protocolSelectionConstraints) supportedProtocols(
 	excludeIntensive bool,
 	serverEntry *protocol.ServerEntry) []string {
 
-	limitProtocols := p.limitProtocols
+	limitTunnelProtocols := p.limitTunnelProtocols
 
-	if len(p.initialLimitProtocols) > 0 && p.initialLimitProtocolsCandidateCount > connectTunnelCount {
-		limitProtocols = p.initialLimitProtocols
+	if len(p.initialLimitTunnelProtocols) > 0 &&
+		p.initialLimitTunnelProtocolsCandidateCount > connectTunnelCount {
+
+		limitTunnelProtocols = p.initialLimitTunnelProtocols
 	}
 
 	return serverEntry.GetSupportedProtocols(
 		conditionallyEnabledComponents{},
 		p.useUpstreamProxy,
-		limitProtocols,
+		limitTunnelProtocols,
+		p.limitTunnelDialPortNumbers,
 		excludeIntensive)
 }
 
@@ -1578,11 +1584,15 @@ func (controller *Controller) launchEstablishing() {
 	p := controller.config.GetParameters().Get()
 
 	controller.protocolSelectionConstraints = &protocolSelectionConstraints{
-		useUpstreamProxy:                    controller.config.UseUpstreamProxy(),
-		initialLimitProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
-		initialLimitProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
-		limitProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
-		replayCandidateCount:                p.Int(parameters.ReplayCandidateCount),
+		useUpstreamProxy:                          controller.config.UseUpstreamProxy(),
+		initialLimitTunnelProtocols:               p.TunnelProtocols(parameters.InitialLimitTunnelProtocols),
+		initialLimitTunnelProtocolsCandidateCount: p.Int(parameters.InitialLimitTunnelProtocolsCandidateCount),
+		limitTunnelProtocols:                      p.TunnelProtocols(parameters.LimitTunnelProtocols),
+
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+
+		replayCandidateCount: p.Int(parameters.ReplayCandidateCount),
 	}
 
 	// ConnectionWorkerPoolSize may be set by tactics.
@@ -1626,7 +1636,7 @@ func (controller *Controller) launchEstablishing() {
 	// proceeding.
 
 	awaitResponse := tunnelPoolSize > 1 ||
-		controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0
+		controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0
 
 	// AvailableEgressRegions: after a fresh install, the outer client may not
 	// have a list of regions to display; and LimitTunnelProtocols may reduce the
@@ -1720,11 +1730,11 @@ func (controller *Controller) launchEstablishing() {
 		// protocols may have some bad effect, such as a firewall blocking all
 		// traffic from a host.
 
-		if controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount > 0 {
+		if controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount > 0 {
 
 			if reportResponse.initialCandidatesAnyEgressRegion == 0 {
 				NoticeWarning("skipping initial limit tunnel protocols")
-				controller.protocolSelectionConstraints.initialLimitProtocolsCandidateCount = 0
+				controller.protocolSelectionConstraints.initialLimitTunnelProtocolsCandidateCount = 0
 
 				// Since we were unable to satisfy the InitialLimitTunnelProtocols
 				// tactic, trigger RSL, OSL, and upgrade fetches to potentially

+ 6 - 1
psiphon/dataStore.go

@@ -671,7 +671,11 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 			return false, nil, errors.TraceNew("TargetServerEntry does not support EgressRegion")
 		}
 
-		limitTunnelProtocols := config.GetParameters().Get().TunnelProtocols(parameters.LimitTunnelProtocols)
+		p := config.GetParameters().Get()
+		limitTunnelProtocols := p.TunnelProtocols(parameters.LimitTunnelProtocols)
+		limitTunnelDialPortNumbers := protocol.TunnelProtocolPortLists(
+			p.TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers))
+
 		if len(limitTunnelProtocols) > 0 {
 			// At the ServerEntryIterator level, only limitTunnelProtocols is applied;
 			// excludeIntensive is handled higher up.
@@ -679,6 +683,7 @@ func newTargetServerEntryIterator(config *Config, isTactics bool) (bool, *Server
 				conditionallyEnabledComponents{},
 				config.UseUpstreamProxy(),
 				limitTunnelProtocols,
+				limitTunnelDialPortNumbers,
 				false)) == 0 {
 				return false, nil, errors.Tracef(
 					"TargetServerEntry does not support LimitTunnelProtocols: %v", limitTunnelProtocols)

+ 30 - 41
psiphon/dialParameters.go

@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"net"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync/atomic"
 	"time"
@@ -708,40 +709,27 @@ func MakeDialParameters(
 	// Set dial address fields. This portion of configuration is
 	// deterministic, given the parameters established or replayed so far.
 
-	switch dialParams.TunnelProtocol {
+	dialPortNumber, err := serverEntry.GetDialPortNumber(dialParams.TunnelProtocol)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
-	case protocol.TUNNEL_PROTOCOL_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshPort)
+	dialParams.DialPortNumber = strconv.Itoa(dialPortNumber)
 
-	case protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
+	switch dialParams.TunnelProtocol {
 
-	case protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedTapDancePort)
+	case protocol.TUNNEL_PROTOCOL_SSH,
+		protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH,
+		protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
 
-	case protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedConjurePort)
+		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 
-	case protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedQUICPort)
+	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK,
+		protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
 
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
-		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
-		if serverEntry.MeekFrontingDisableSNI {
-			dialParams.MeekSNIServerName = ""
-			// When SNI is omitted, the transformed host name is not used.
-			dialParams.MeekTransformedHostName = false
-		} else if !dialParams.MeekTransformedHostName {
-			dialParams.MeekSNIServerName = dialParams.MeekFrontingDialAddress
-		}
-
-	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
-		// Note: port comes from marionnete "format"
-		dialParams.DirectDialAddress = serverEntry.IpAddress
-
-	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		if serverEntry.MeekFrontingDisableSNI {
 			dialParams.MeekSNIServerName = ""
@@ -752,15 +740,17 @@ func MakeDialParameters(
 		}
 
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:80", dialParams.MeekFrontingDialAddress)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		// For FRONTED HTTP, the Host header cannot be transformed.
 		dialParams.MeekTransformedHostName = false
 
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
-			if serverEntry.MeekServerPort == 80 {
+			if dialPortNumber == 80 {
 				dialParams.MeekHostHeader = serverEntry.IpAddress
 			} else {
 				dialParams.MeekHostHeader = dialParams.MeekDialAddress
@@ -768,19 +758,24 @@ func MakeDialParameters(
 		}
 
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
-		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
+		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
+		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 			// Note: IP address in SNI field will be omitted.
 			dialParams.MeekSNIServerName = serverEntry.IpAddress
 		}
-		if serverEntry.MeekServerPort == 443 {
+		if dialPortNumber == 443 {
 			dialParams.MeekHostHeader = serverEntry.IpAddress
 		} else {
 			dialParams.MeekHostHeader = dialParams.MeekDialAddress
 		}
 
+	case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
+
+		// Note: port comes from marionnete "format"
+		dialParams.DirectDialAddress = serverEntry.IpAddress
+
 	default:
 		return nil, errors.Tracef(
 			"unknown tunnel protocol: %s", dialParams.TunnelProtocol)
@@ -789,7 +784,7 @@ func MakeDialParameters(
 
 	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) {
 
-		host, port, _ := net.SplitHostPort(dialParams.MeekDialAddress)
+		host, _, _ := net.SplitHostPort(dialParams.MeekDialAddress)
 
 		if p.Bool(parameters.MeekDialDomainsOnly) {
 			if net.ParseIP(host) != nil {
@@ -798,17 +793,11 @@ func MakeDialParameters(
 			}
 		}
 
-		dialParams.DialPortNumber = port
-
 		// The underlying TLS will automatically disable SNI for IP address server name
 		// values; we have this explicit check here so we record the correct value for stats.
 		if net.ParseIP(dialParams.MeekSNIServerName) != nil {
 			dialParams.MeekSNIServerName = ""
 		}
-
-	} else {
-
-		_, dialParams.DialPortNumber, _ = net.SplitHostPort(dialParams.DirectDialAddress)
 	}
 
 	// Initialize/replay User-Agent header for HTTP upstream proxy and meek protocols.

+ 116 - 11
psiphon/dialParameters_test.go

@@ -87,7 +87,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
-	err = clientConfig.SetParameters("tag1", true, applyParameters)
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -346,7 +346,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	// Test: no replay after change tactics
 
 	applyParameters[parameters.ReplayDialParametersTTL] = "1s"
-	err = clientConfig.SetParameters("tag2", true, applyParameters)
+	err = clientConfig.SetParameters("tag2", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -400,7 +400,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.ReplayObfuscatedQUIC] = false
 	applyParameters[parameters.ReplayLivenessTest] = false
 	applyParameters[parameters.ReplayAPIRequestPadding] = false
-	err = clientConfig.SetParameters("tag3", true, applyParameters)
+	err = clientConfig.SetParameters("tag3", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -442,7 +442,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	applyParameters[parameters.RestrictFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 1.0
-	err = clientConfig.SetParameters("tag4", true, applyParameters)
+	err = clientConfig.SetParameters("tag4", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -462,7 +462,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 
 	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 0.0
-	err = clientConfig.SetParameters("tag5", true, applyParameters)
+	err = clientConfig.SetParameters("tag5", false, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
@@ -558,6 +558,110 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 }
 
+func TestLimitTunnelDialPortNumbers(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-limit-tunnel-dial-port-numbers-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %s", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	SetNoticeWriter(ioutil.Discard)
+
+	clientConfig := &Config{
+		PropagationChannelId: "0",
+		SponsorId:            "0",
+		DataRootDirectory:    testDataDirName,
+		NetworkIDGetter:      new(testNetworkGetter),
+	}
+
+	err = clientConfig.Commit(false)
+	if err != nil {
+		t.Fatalf("error committing configuration file: %s", err)
+	}
+
+	jsonLimitDialPortNumbers := `
+    {
+        "SSH" : [[10,11]],
+        "OSSH" : [[20,21]],
+        "QUIC-OSSH" : [[30,31]],
+        "TAPDANCE-OSSH" : [[40,41]],
+        "CONJURE-OSSH" : [[50,51]],
+        "All" : [[60,61],80,443]
+    }
+    `
+
+	var limitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
+	err = json.Unmarshal([]byte(jsonLimitDialPortNumbers), &limitTunnelDialPortNumbers)
+	if err != nil {
+		t.Fatalf("Unmarshal failed: %s", err)
+	}
+
+	applyParameters := make(map[string]interface{})
+	applyParameters[parameters.LimitTunnelDialPortNumbers] = limitTunnelDialPortNumbers
+	applyParameters[parameters.LimitTunnelDialPortNumbersProbability] = 1.0
+	err = clientConfig.SetParameters("tag1", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	constraints := &protocolSelectionConstraints{
+		limitTunnelDialPortNumbers: protocol.TunnelProtocolPortLists(
+			clientConfig.GetParameters().Get().TunnelProtocolPortLists(parameters.LimitTunnelDialPortNumbers)),
+	}
+
+	selectProtocol := func(serverEntry *protocol.ServerEntry) (string, bool) {
+		return constraints.selectProtocol(0, false, serverEntry)
+	}
+
+	for _, tunnelProtocol := range protocol.SupportedTunnelProtocols {
+
+		if common.Contains(protocol.DefaultDisabledTunnelProtocols, tunnelProtocol) {
+			continue
+		}
+
+		serverEntries := makeMockServerEntries(tunnelProtocol, "", 100)
+
+		selected := false
+		skipped := false
+
+		for _, serverEntry := range serverEntries {
+
+			selectedProtocol, ok := selectProtocol(serverEntry)
+
+			if ok {
+
+				if selectedProtocol != tunnelProtocol {
+					t.Fatalf("unexpected selected protocol: %s", selectedProtocol)
+				}
+
+				port, err := serverEntry.GetDialPortNumber(selectedProtocol)
+				if err != nil {
+					t.Fatalf("GetDialPortNumber failed: %s", err)
+				}
+
+				if port%10 != 0 && port%10 != 1 && !protocol.TunnelProtocolUsesFrontedMeek(selectedProtocol) {
+					t.Fatalf("unexpected dial port number: %d", port)
+				}
+
+				selected = true
+
+			} else {
+
+				skipped = true
+			}
+		}
+
+		if !selected {
+			t.Fatalf("expected at least one selected server entry: %s", tunnelProtocol)
+		}
+
+		if !skipped && !protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+			t.Fatalf("expected at least one skipped server entry: %s", tunnelProtocol)
+		}
+	}
+}
+
 func makeMockServerEntries(
 	tunnelProtocol string,
 	frontingProviderID string,
@@ -568,17 +672,18 @@ func makeMockServerEntries(
 	for i := 0; i < count; i++ {
 		serverEntries[i] = &protocol.ServerEntry{
 			IpAddress:                  fmt.Sprintf("192.168.0.%d", i),
-			SshPort:                    1,
-			SshObfuscatedPort:          2,
-			SshObfuscatedQUICPort:      3,
-			SshObfuscatedTapDancePort:  4,
-			SshObfuscatedConjurePort:   5,
-			MeekServerPort:             6,
+			SshPort:                    prng.Range(10, 19),
+			SshObfuscatedPort:          prng.Range(20, 29),
+			SshObfuscatedQUICPort:      prng.Range(30, 39),
+			SshObfuscatedTapDancePort:  prng.Range(40, 49),
+			SshObfuscatedConjurePort:   prng.Range(50, 59),
+			MeekServerPort:             prng.Range(60, 69),
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
 			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
+			Capabilities:               []string{protocol.GetCapability(tunnelProtocol)},
 		}
 	}
 

+ 4 - 3
psiphon/notice.go

@@ -425,9 +425,10 @@ func NoticeCandidateServers(
 	singletonNoticeLogger.outputNotice(
 		"CandidateServers", noticeIsDiagnostic,
 		"region", region,
-		"initialLimitTunnelProtocols", constraints.initialLimitProtocols,
-		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitProtocolsCandidateCount,
-		"limitTunnelProtocols", constraints.limitProtocols,
+		"initialLimitTunnelProtocols", constraints.initialLimitTunnelProtocols,
+		"initialLimitTunnelProtocolsCandidateCount", constraints.initialLimitTunnelProtocolsCandidateCount,
+		"limitTunnelProtocols", constraints.limitTunnelProtocols,
+		"limitTunnelDialPortNumbers", constraints.limitTunnelDialPortNumbers,
 		"replayCandidateCount", constraints.replayCandidateCount,
 		"initialCount", initialCount,
 		"count", count,

+ 6 - 6
psiphon/server/server_test.go

@@ -1998,12 +1998,12 @@ func paveTrafficRulesFile(
 
 	allowTCPPorts := TCPPorts
 	allowUDPPorts := UDPPorts
-	disallowTCPPorts := "0"
-	disallowUDPPorts := "0"
+	disallowTCPPorts := "1"
+	disallowUDPPorts := "1"
 
 	if deny {
-		allowTCPPorts = "0"
-		allowUDPPorts = "0"
+		allowTCPPorts = "1"
+		allowUDPPorts = "1"
 		disallowTCPPorts = TCPPorts
 		disallowUDPPorts = UDPPorts
 	}
@@ -2033,8 +2033,8 @@ func paveTrafficRulesFile(
                 "ReadUnthrottledBytes": %d,
                 "WriteUnthrottledBytes": %d
             },
-            "AllowTCPPorts" : [0],
-            "AllowUDPPorts" : [0],
+            "AllowTCPPorts" : [1],
+            "AllowUDPPorts" : [1],
             "MeekRateLimiterHistorySize" : 10,
             "MeekRateLimiterThresholdSeconds" : 1,
             "MeekRateLimiterGarbageCollectionTriggerCount" : 1,

+ 18 - 93
psiphon/server/trafficRules.go

@@ -236,21 +236,21 @@ type TrafficRules struct {
 
 	// AllowTCPPorts specifies a list of TCP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowTCPPorts []int
+	AllowTCPPorts *common.PortList
 
 	// AllowUDPPorts specifies a list of UDP ports that are permitted for port
 	// forwarding. When set, only ports in the list are accessible to clients.
-	AllowUDPPorts []int
+	AllowUDPPorts *common.PortList
 
 	// DisallowTCPPorts specifies a list of TCP ports that are not permitted for
 	// port forwarding. DisallowTCPPorts takes priority over AllowTCPPorts and
 	// AllowSubnets.
-	DisallowTCPPorts []int
+	DisallowTCPPorts *common.PortList
 
 	// DisallowUDPPorts specifies a list of UDP ports that are not permitted for
 	// port forwarding. DisallowUDPPorts takes priority over AllowUDPPorts and
 	// AllowSubnets.
-	DisallowUDPPorts []int
+	DisallowUDPPorts *common.PortList
 
 	// AllowSubnets specifies a list of IP address subnets for which all TCP and
 	// UDP ports are allowed. This list is consulted if a port is disallowed by
@@ -261,11 +261,6 @@ type TrafficRules struct {
 	// client sends an IP address. Domain names are not resolved before checking
 	// AllowSubnets.
 	AllowSubnets []string
-
-	allowTCPPortsLookup    map[int]bool
-	allowUDPPortsLookup    map[int]bool
-	disallowTCPPortsLookup map[int]bool
-	disallowUDPPortsLookup map[int]bool
 }
 
 // RateLimits is a clone of common.RateLimits with pointers
@@ -434,33 +429,11 @@ func (set *TrafficRulesSet) initLookups() {
 
 	initTrafficRulesLookups := func(rules *TrafficRules) {
 
-		if len(rules.AllowTCPPorts) >= intLookupThreshold {
-			rules.allowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowTCPPorts {
-				rules.allowTCPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.AllowUDPPorts) >= intLookupThreshold {
-			rules.allowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.AllowUDPPorts {
-				rules.allowUDPPortsLookup[port] = true
-			}
-		}
-
-		if len(rules.DisallowTCPPorts) >= intLookupThreshold {
-			rules.disallowTCPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowTCPPorts {
-				rules.disallowTCPPortsLookup[port] = true
-			}
-		}
+		rules.AllowTCPPorts.OptimizeLookups()
+		rules.AllowUDPPorts.OptimizeLookups()
+		rules.DisallowTCPPorts.OptimizeLookups()
+		rules.DisallowUDPPorts.OptimizeLookups()
 
-		if len(rules.DisallowUDPPorts) >= intLookupThreshold {
-			rules.disallowUDPPortsLookup = make(map[int]bool)
-			for _, port := range rules.DisallowUDPPorts {
-				rules.disallowUDPPortsLookup[port] = true
-			}
-		}
 	}
 
 	initTrafficRulesFilterLookups := func(filter *TrafficRulesFilter) {
@@ -600,14 +573,6 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_MAX_UDP_PORT_FORWARD_COUNT)
 	}
 
-	if trafficRules.AllowTCPPorts == nil {
-		trafficRules.AllowTCPPorts = make([]int, 0)
-	}
-
-	if trafficRules.AllowUDPPorts == nil {
-		trafficRules.AllowUDPPorts = make([]int, 0)
-	}
-
 	if trafficRules.AllowSubnets == nil {
 		trafficRules.AllowSubnets = make([]string, 0)
 	}
@@ -800,22 +765,18 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 		if filteredRules.Rules.AllowTCPPorts != nil {
 			trafficRules.AllowTCPPorts = filteredRules.Rules.AllowTCPPorts
-			trafficRules.allowTCPPortsLookup = filteredRules.Rules.allowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowUDPPorts != nil {
 			trafficRules.AllowUDPPorts = filteredRules.Rules.AllowUDPPorts
-			trafficRules.allowUDPPortsLookup = filteredRules.Rules.allowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowTCPPorts != nil {
 			trafficRules.DisallowTCPPorts = filteredRules.Rules.DisallowTCPPorts
-			trafficRules.disallowTCPPortsLookup = filteredRules.Rules.disallowTCPPortsLookup
 		}
 
 		if filteredRules.Rules.DisallowUDPPorts != nil {
 			trafficRules.DisallowUDPPorts = filteredRules.Rules.DisallowUDPPorts
-			trafficRules.disallowUDPPortsLookup = filteredRules.Rules.disallowUDPPortsLookup
 		}
 
 		if filteredRules.Rules.AllowSubnets != nil {
@@ -837,34 +798,16 @@ func (set *TrafficRulesSet) GetTrafficRules(
 
 func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowTCPPorts) > 0 {
-		if rules.disallowTCPPortsLookup != nil {
-			if rules.disallowTCPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowTCPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowTCPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowTCPPorts) == 0 {
+	if rules.AllowTCPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowTCPPortsLookup != nil {
-		if rules.allowTCPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowTCPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowTCPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)
@@ -872,34 +815,16 @@ func (rules *TrafficRules) AllowTCPPort(remoteIP net.IP, port int) bool {
 
 func (rules *TrafficRules) AllowUDPPort(remoteIP net.IP, port int) bool {
 
-	if len(rules.DisallowUDPPorts) > 0 {
-		if rules.disallowUDPPortsLookup != nil {
-			if rules.disallowUDPPortsLookup[port] {
-				return false
-			}
-		} else {
-			for _, disallowPort := range rules.DisallowUDPPorts {
-				if port == disallowPort {
-					return false
-				}
-			}
-		}
+	if rules.DisallowUDPPorts.Lookup(port) {
+		return false
 	}
 
-	if len(rules.AllowUDPPorts) == 0 {
+	if rules.AllowUDPPorts.IsEmpty() {
 		return true
 	}
 
-	if rules.allowUDPPortsLookup != nil {
-		if rules.allowUDPPortsLookup[port] {
-			return true
-		}
-	} else {
-		for _, allowPort := range rules.AllowUDPPorts {
-			if port == allowPort {
-				return true
-			}
-		}
+	if rules.AllowUDPPorts.Lookup(port) {
+		return true
 	}
 
 	return rules.allowSubnet(remoteIP)