Procházet zdrojové kódy

Tactics changes

- Add individual probabilities for tunnel protocol and
  TLS profile lists and apply automatically in Get()

- Change client validation of tunnel protocol and TLS
  profile lists to accept lists with unknown (new)
  protocols and profiles; instead of rejecting the
  parameter values, the values are now pruned to
  include only known protocols/profiles.
Rod Hynes před 7 roky
rodič
revize
ad1e7ddde6

+ 61 - 11
psiphon/common/parameters/clientParameters.go

@@ -82,9 +82,12 @@ const (
 	StaggerConnectionWorkersJitter             = "StaggerConnectionWorkersJitter"
 	LimitIntensiveConnectionWorkers            = "LimitIntensiveConnectionWorkers"
 	IgnoreHandshakeStatsRegexps                = "IgnoreHandshakeStatsRegexps"
+	PrioritizeTunnelProtocolsProbability       = "PrioritizeTunnelProtocolsProbability"
 	PrioritizeTunnelProtocols                  = "PrioritizeTunnelProtocols"
 	PrioritizeTunnelProtocolsCandidateCount    = "PrioritizeTunnelProtocolsCandidateCount"
+	LimitTunnelProtocolsProbability            = "LimitTunnelProtocolsProbability"
 	LimitTunnelProtocols                       = "LimitTunnelProtocols"
+	LimitTLSProfilesProbability                = "LimitTLSProfilesProbability"
 	LimitTLSProfiles                           = "LimitTLSProfiles"
 	FragmentorProbability                      = "FragmentorProbability"
 	FragmentorLimitProtocols                   = "FragmentorLimitProtocols"
@@ -213,11 +216,14 @@ var defaultClientParameters = map[string]struct {
 	// the first establishment round. Even then, this will only happen if the
 	// client has sufficient candidates supporting the prioritized protocols.
 
+	PrioritizeTunnelProtocolsProbability:    {value: 1.0, minimum: 0.0},
 	PrioritizeTunnelProtocols:               {value: protocol.TunnelProtocols{}},
 	PrioritizeTunnelProtocolsCandidateCount: {value: 10, minimum: 0},
+	LimitTunnelProtocolsProbability:         {value: 1.0, minimum: 0.0},
 	LimitTunnelProtocols:                    {value: protocol.TunnelProtocols{}},
 
-	LimitTLSProfiles: {value: protocol.TLSProfiles{}},
+	LimitTLSProfilesProbability: {value: 1.0, minimum: 0.0},
+	LimitTLSProfiles:            {value: protocol.TLSProfiles{}},
 
 	FragmentorProbability:    {value: 0.5, minimum: 0.0},
 	FragmentorLimitProtocols: {value: protocol.TunnelProtocols{}},
@@ -400,6 +406,10 @@ func makeDefaultParameters() (map[string]interface{}, error) {
 // When skipOnError is true, unknown or invalid parameters in any
 // applyParameters are skipped instead of aborting with an error.
 //
+// For protocol.TunnelProtocols and protocol.TLSProfiles type values, when
+// skipOnError is true the values are filtered instead of validated, so
+// only known tunnel protocols and TLS profiles are retained.
+//
 // When an error is returned, the previous parameters remain completely
 // unmodified.
 //
@@ -478,20 +488,22 @@ func (p *ClientParameters) Set(
 					return nil, common.ContextError(err)
 				}
 			case protocol.TunnelProtocols:
-				err := v.Validate()
-				if err != nil {
-					if skipOnError {
-						continue
+				if skipOnError {
+					newValue = v.PruneInvalid()
+				} else {
+					err := v.Validate()
+					if err != nil {
+						return nil, common.ContextError(err)
 					}
-					return nil, common.ContextError(err)
 				}
 			case protocol.TLSProfiles:
-				err := v.Validate()
-				if err != nil {
-					if skipOnError {
-						continue
+				if skipOnError {
+					newValue = v.PruneInvalid()
+				} else {
+					err := v.Validate()
+					if err != nil {
+						return nil, common.ContextError(err)
 					}
-					return nil, common.ContextError(err)
 				}
 			}
 
@@ -671,14 +683,52 @@ func (p *ClientParametersSnapshot) Duration(name string) time.Duration {
 }
 
 // TunnelProtocols returns a protocol.TunnelProtocols parameter value.
+// If there is a corresponding Probability value, a weighted coin flip
+// will be performed and, depending on the result, the value or the
+// parameter default will be returned.
 func (p *ClientParametersSnapshot) TunnelProtocols(name string) protocol.TunnelProtocols {
+
+	probabilityName := name + "Probability"
+	probabilityValue := float64(1.0)
+	p.getValue(probabilityName, &probabilityValue)
+	if !common.FlipWeightedCoin(probabilityValue) {
+		defaultParameter, ok := defaultClientParameters[name]
+		if ok {
+			defaultValue, ok := defaultParameter.value.(protocol.TunnelProtocols)
+			if ok {
+				value := make(protocol.TunnelProtocols, len(defaultValue))
+				copy(value, defaultValue)
+				return value
+			}
+		}
+	}
+
 	value := protocol.TunnelProtocols{}
 	p.getValue(name, &value)
 	return value
 }
 
 // TLSProfiles returns a protocol.TLSProfiles parameter value.
+// If there is a corresponding Probability value, a weighted coin flip
+// will be performed and, depending on the result, the value or the
+// parameter default will be returned.
 func (p *ClientParametersSnapshot) TLSProfiles(name string) protocol.TLSProfiles {
+
+	probabilityName := name + "Probability"
+	probabilityValue := float64(1.0)
+	p.getValue(probabilityName, &probabilityValue)
+	if !common.FlipWeightedCoin(probabilityValue) {
+		defaultParameter, ok := defaultClientParameters[name]
+		if ok {
+			defaultValue, ok := defaultParameter.value.(protocol.TLSProfiles)
+			if ok {
+				value := make(protocol.TLSProfiles, len(defaultValue))
+				copy(value, defaultValue)
+				return value
+			}
+		}
+	}
+
 	value := protocol.TLSProfiles{}
 	p.getValue(name, &value)
 	return value

+ 55 - 0
psiphon/common/parameters/clientParameters_test.go

@@ -197,6 +197,61 @@ func TestNetworkLatencyMultiplier(t *testing.T) {
 
 	if 2*timeout1 != timeout2 {
 		t.Fatalf("Unexpected timeouts: 2 * %s != %s", timeout1, timeout2)
+	}
+}
+
+func TestLimitTunnelProtocolProbability(t *testing.T) {
+	p, err := NewClientParameters(nil)
+	if err != nil {
+		t.Fatalf("NewClientParameters failed: %s", err)
+	}
+
+	// Default probability should be 1.0 and always return tunnelProtocols
+
+	tunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"}
+
+	applyParameters := map[string]interface{}{
+		"LimitTunnelProtocols": tunnelProtocols,
+	}
+
+	_, err = p.Set("", false, applyParameters)
+	if err != nil {
+		t.Fatalf("Set failed: %s", err)
+	}
+
+	for i := 0; i < 1000; i++ {
+		l := p.Get().TunnelProtocols(LimitTunnelProtocols)
+		if !reflect.DeepEqual(l, tunnelProtocols) {
+			t.Fatalf("unexpected %+v != %+v", l, tunnelProtocols)
+		}
+	}
+
+	// With probability set to 0.5, should return tunnelProtocols ~50%
+
+	defaultLimitTunnelProtocols := protocol.TunnelProtocols{}
+
+	applyParameters = map[string]interface{}{
+		"LimitTunnelProtocolsProbability": 0.5,
+		"LimitTunnelProtocols":            tunnelProtocols,
+	}
+
+	_, err = p.Set("", false, applyParameters)
+	if err != nil {
+		t.Fatalf("Set failed: %s", err)
+	}
+
+	matchCount := 0
+
+	for i := 0; i < 1000; i++ {
+		l := p.Get().TunnelProtocols(LimitTunnelProtocols)
+		if reflect.DeepEqual(l, tunnelProtocols) {
+			matchCount += 1
+		} else if !reflect.DeepEqual(l, defaultLimitTunnelProtocols) {
+			t.Fatalf("unexpected %+v != %+v", l, defaultLimitTunnelProtocols)
+		}
+	}
 
+	if matchCount < 250 || matchCount > 750 {
+		t.Fatalf("Unexpected probability result: %d", matchCount)
 	}
 }

+ 20 - 0
psiphon/common/protocol/protocol.go

@@ -77,6 +77,16 @@ func (t TunnelProtocols) Validate() error {
 	return nil
 }
 
+func (t TunnelProtocols) PruneInvalid() TunnelProtocols {
+	u := make(TunnelProtocols, 0)
+	for _, p := range t {
+		if common.Contains(SupportedTunnelProtocols, p) {
+			u = append(u, p)
+		}
+	}
+	return u
+}
+
 var SupportedTunnelProtocols = TunnelProtocols{
 	TUNNEL_PROTOCOL_SSH,
 	TUNNEL_PROTOCOL_OBFUSCATED_SSH,
@@ -187,6 +197,16 @@ func (profiles TLSProfiles) Validate() error {
 	return nil
 }
 
+func (profiles TLSProfiles) PruneInvalid() TLSProfiles {
+	q := make(TLSProfiles, 0)
+	for _, p := range profiles {
+		if common.Contains(SupportedTLSProfiles, p) {
+			q = append(q, p)
+		}
+	}
+	return q
+}
+
 type HandshakeResponse struct {
 	SSHSessionID           string              `json:"ssh_session_id"`
 	Homepages              []string            `json:"homepages"`

+ 80 - 0
psiphon/common/protocol/protocol_test.go

@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2018, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package protocol
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+)
+
+func TestTunnelProtocolValidation(t *testing.T) {
+
+	err := SupportedTunnelProtocols.Validate()
+	if err != nil {
+		t.Errorf("unexpected Validate error: %s", err)
+	}
+
+	invalidProtocols := TunnelProtocols{"OSSH", "INVALID-PROTOCOL"}
+	err = invalidProtocols.Validate()
+	if err == nil {
+		t.Errorf("unexpected Validate success")
+	}
+
+	pruneProtocols := make(TunnelProtocols, 0)
+	for i, p := range SupportedTunnelProtocols {
+		pruneProtocols = append(pruneProtocols, fmt.Sprintf("INVALID-PROTOCOL-%d", i))
+		pruneProtocols = append(pruneProtocols, p)
+	}
+	pruneProtocols = append(pruneProtocols, fmt.Sprintf("INVALID-PROTOCOL-%d", len(SupportedTunnelProtocols)))
+
+	prunedProtocols := pruneProtocols.PruneInvalid()
+
+	if !reflect.DeepEqual(prunedProtocols, SupportedTunnelProtocols) {
+		t.Errorf("unexpected %+v != %+v", prunedProtocols, SupportedTunnelProtocols)
+	}
+}
+
+func TestTLSProfileValidation(t *testing.T) {
+
+	err := SupportedTLSProfiles.Validate()
+	if err != nil {
+		t.Errorf("unexpected Validate error: %s", err)
+	}
+
+	invalidProfiles := TLSProfiles{"OSSH", "INVALID-PROTOCOL"}
+	err = invalidProfiles.Validate()
+	if err == nil {
+		t.Errorf("unexpected Validate success")
+	}
+
+	pruneProfiles := make(TLSProfiles, 0)
+	for i, p := range SupportedTLSProfiles {
+		pruneProfiles = append(pruneProfiles, fmt.Sprintf("INVALID-PROFILE-%d", i))
+		pruneProfiles = append(pruneProfiles, p)
+	}
+	pruneProfiles = append(pruneProfiles, fmt.Sprintf("INVALID-PROFILE-%d", len(SupportedTLSProfiles)))
+
+	prunedProfiles := pruneProfiles.PruneInvalid()
+
+	if !reflect.DeepEqual(prunedProfiles, SupportedTLSProfiles) {
+		t.Errorf("unexpected %+v != %+v", prunedProfiles, SupportedTLSProfiles)
+	}
+}