Просмотр исходного кода

Add provider ID

- Added new tactics HoldOffDirectServerEntryRegions,
  HoldOffDirectServerEntryProviderRegions, and
  RestrictDirectProviderIDs.
mirokuratczyk 2 лет назад
Родитель
Сommit
cc9bc8430d

+ 37 - 0
psiphon/common/parameters/keyStrings.go

@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+// KeyStrings represents a set of key/strings pairs.
+type KeyStrings map[string][]string
+
+// Validates that the keys and values are well formed.
+func (keyStrings KeyStrings) Validate() error {
+	// Always succeeds because KeyStrings is generic and does not impose any
+	// restrictions on keys/values. Consider imposing limits like maximum
+	// map/array/string sizes.
+	return nil
+}
+
+func (p ParametersAccessor) KeyStrings(name, key string) []string {
+	value := KeyStrings{}
+	p.snapshot.getValue(name, &value)
+	return value[key]
+}

+ 26 - 0
psiphon/common/parameters/parameters.go

@@ -311,6 +311,14 @@ const (
 	HoldOffTunnelProtocols                           = "HoldOffTunnelProtocols"
 	HoldOffTunnelProtocols                           = "HoldOffTunnelProtocols"
 	HoldOffTunnelFrontingProviderIDs                 = "HoldOffTunnelFrontingProviderIDs"
 	HoldOffTunnelFrontingProviderIDs                 = "HoldOffTunnelFrontingProviderIDs"
 	HoldOffTunnelProbability                         = "HoldOffTunnelProbability"
 	HoldOffTunnelProbability                         = "HoldOffTunnelProbability"
+	HoldOffDirectTunnelMinDuration                   = "HoldOffDirectTunnelMinDuration"
+	HoldOffDirectTunnelMaxDuration                   = "HoldOffDirectTunnelMaxDuration"
+	HoldOffDirectServerEntryRegions                  = "HoldOffDirectServerEntryRegions"
+	HoldOffDirectServerEntryProviderRegions          = "HoldOffDirectServerEntryProviderRegions"
+	HoldOffDirectTunnelProbability                   = "HoldOffDirectTunnelProbability"
+	RestrictDirectProviderIDs                        = "RestrictDirectProviderIDs"
+	RestrictDirectProviderIDsServerProbability       = "RestrictDirectProviderIDsServerProbability"
+	RestrictDirectProviderIDsClientProbability       = "RestrictDirectProviderIDsClientProbability"
 	RestrictFrontingProviderIDs                      = "RestrictFrontingProviderIDs"
 	RestrictFrontingProviderIDs                      = "RestrictFrontingProviderIDs"
 	RestrictFrontingProviderIDsServerProbability     = "RestrictFrontingProviderIDsServerProbability"
 	RestrictFrontingProviderIDsServerProbability     = "RestrictFrontingProviderIDsServerProbability"
 	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
 	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
@@ -702,6 +710,16 @@ var defaultParameters = map[string]struct {
 	HoldOffTunnelFrontingProviderIDs: {value: []string{}},
 	HoldOffTunnelFrontingProviderIDs: {value: []string{}},
 	HoldOffTunnelProbability:         {value: 0.0, minimum: 0.0},
 	HoldOffTunnelProbability:         {value: 0.0, minimum: 0.0},
 
 
+	HoldOffDirectTunnelMinDuration:          {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffDirectTunnelMaxDuration:          {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffDirectServerEntryRegions:         {value: []string{}},
+	HoldOffDirectServerEntryProviderRegions: {value: KeyStrings{}},
+	HoldOffDirectTunnelProbability:          {value: 0.0, minimum: 0.0},
+
+	RestrictDirectProviderIDs:                  {value: []string{}},
+	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
+	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
+
 	RestrictFrontingProviderIDs:                  {value: []string{}},
 	RestrictFrontingProviderIDs:                  {value: []string{}},
 	RestrictFrontingProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictFrontingProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
@@ -1086,6 +1104,14 @@ func (p *Parameters) Set(
 					}
 					}
 					return nil, errors.Trace(err)
 					return nil, errors.Trace(err)
 				}
 				}
+			case KeyStrings:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			case *BPFProgramSpec:
 			case *BPFProgramSpec:
 				if v != nil {
 				if v != nil {
 					err := v.Validate()
 					err := v.Validate()

+ 7 - 0
psiphon/common/parameters/parameters_test.go

@@ -189,6 +189,13 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("ConjureTransports returned %+v expected %+v", g, v)
 				t.Fatalf("ConjureTransports returned %+v expected %+v", g, v)
 			}
 			}
+		case KeyStrings:
+			for key, strings := range v {
+				g := p.Get().KeyStrings(name, key)
+				if !reflect.DeepEqual(strings, g) {
+					t.Fatalf("KeyStrings returned %+v expected %+v", g, strings)
+				}
+			}
 		default:
 		default:
 			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 		}
 		}

+ 10 - 0
psiphon/common/protocol/protocol.go

@@ -272,6 +272,16 @@ func TunnelProtocolIsCompatibleWithFragmentor(protocol string) bool {
 		protocol == TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH
 		protocol == TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH
 }
 }
 
 
+func TunnelProtocolIsDirect(protocol string) bool {
+	return protocol == TUNNEL_PROTOCOL_SSH ||
+		protocol == TUNNEL_PROTOCOL_OBFUSCATED_SSH ||
+		protocol == TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET ||
+		protocol == TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH
+}
+
 func TunnelProtocolRequiresTLS12SessionTickets(protocol string) bool {
 func TunnelProtocolRequiresTLS12SessionTickets(protocol string) bool {
 	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
 	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
 }
 }

+ 5 - 0
psiphon/common/protocol/serverEntry.go

@@ -61,6 +61,7 @@ type ServerEntry struct {
 	SshObfuscatedKey                string   `json:"sshObfuscatedKey"`
 	SshObfuscatedKey                string   `json:"sshObfuscatedKey"`
 	Capabilities                    []string `json:"capabilities"`
 	Capabilities                    []string `json:"capabilities"`
 	Region                          string   `json:"region"`
 	Region                          string   `json:"region"`
+	ProviderID                      string   `json:"providerID"`
 	FrontingProviderID              string   `json:"frontingProviderID"`
 	FrontingProviderID              string   `json:"frontingProviderID"`
 	TlsOSSHPort                     int      `json:"tlsOSSHPort"`
 	TlsOSSHPort                     int      `json:"tlsOSSHPort"`
 	MeekServerPort                  int      `json:"meekServerPort"`
 	MeekServerPort                  int      `json:"meekServerPort"`
@@ -736,6 +737,10 @@ func (serverEntry *ServerEntry) HasSignature() bool {
 	return serverEntry.Signature != ""
 	return serverEntry.Signature != ""
 }
 }
 
 
+func (serverEntry *ServerEntry) HasProviderID() bool {
+	return serverEntry.ProviderID != ""
+}
+
 func (serverEntry *ServerEntry) GetDiagnosticID() string {
 func (serverEntry *ServerEntry) GetDiagnosticID() string {
 	return TagToDiagnosticID(serverEntry.Tag)
 	return TagToDiagnosticID(serverEntry.Tag)
 }
 }

+ 85 - 0
psiphon/config.go

@@ -818,6 +818,19 @@ type Config struct {
 	HoldOffTunnelFrontingProviderIDs     []string
 	HoldOffTunnelFrontingProviderIDs     []string
 	HoldOffTunnelProbability             *float64
 	HoldOffTunnelProbability             *float64
 
 
+	// HoldOffDirectTunnelMinDurationMilliseconds and other HoldOffDirectTunnel
+	// fields are for testing purposes.
+	HoldOffDirectTunnelMinDurationMilliseconds *int
+	HoldOffDirectTunnelMaxDurationMilliseconds *int
+	HoldOffDirectServerEntryRegions            []string
+	HoldOffDirectServerEntryProviderRegions    map[string][]string
+	HoldOffDirectTunnelProbability             *float64
+
+	// RestrictDirectProviderIDs and other RestrictDirectProviderIDs fields
+	// are for testing purposes.
+	RestrictDirectProviderIDs                  []string
+	RestrictDirectProviderIDsClientProbability *float64
+
 	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
 	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
 	// are for testing purposes.
 	// are for testing purposes.
 	RestrictFrontingProviderIDs                  []string
 	RestrictFrontingProviderIDs                  []string
@@ -1927,6 +1940,34 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.HoldOffTunnelProbability] = *config.HoldOffTunnelProbability
 		applyParameters[parameters.HoldOffTunnelProbability] = *config.HoldOffTunnelProbability
 	}
 	}
 
 
+	if config.HoldOffDirectTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffDirectTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffDirectTunnelMinDurationMilliseconds)
+	}
+
+	if config.HoldOffDirectTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffDirectTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffDirectServerEntryRegions) > 0 {
+		applyParameters[parameters.HoldOffDirectServerEntryRegions] = config.HoldOffDirectServerEntryRegions
+	}
+
+	if len(config.HoldOffDirectServerEntryProviderRegions) > 0 {
+		applyParameters[parameters.HoldOffDirectServerEntryProviderRegions] = parameters.KeyStrings(config.HoldOffDirectServerEntryProviderRegions)
+	}
+
+	if config.HoldOffDirectTunnelProbability != nil {
+		applyParameters[parameters.HoldOffDirectTunnelProbability] = *config.HoldOffDirectTunnelProbability
+	}
+
+	if len(config.RestrictDirectProviderIDs) > 0 {
+		applyParameters[parameters.RestrictDirectProviderIDs] = config.RestrictDirectProviderIDs
+	}
+
+	if config.RestrictDirectProviderIDsClientProbability != nil {
+		applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = *config.RestrictDirectProviderIDsClientProbability
+	}
+
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 		applyParameters[parameters.RestrictFrontingProviderIDs] = config.RestrictFrontingProviderIDs
 		applyParameters[parameters.RestrictFrontingProviderIDs] = config.RestrictFrontingProviderIDs
 	}
 	}
@@ -2414,11 +2455,55 @@ func (config *Config) setDialParametersHash() {
 		}
 		}
 	}
 	}
 
 
+	if config.HoldOffDirectTunnelProbability != nil {
+		hash.Write([]byte("HoldOffDirectTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffDirectTunnelProbability)
+	}
+
+	if config.HoldOffDirectTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffDirectTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffDirectTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffDirectTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffDirectTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffDirectTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffDirectServerEntryRegions) > 0 {
+		hash.Write([]byte("HoldOffDirectServerEntryRegions"))
+		for _, region := range config.HoldOffDirectServerEntryRegions {
+			hash.Write([]byte(region))
+		}
+	}
+
+	if len(config.HoldOffDirectServerEntryProviderRegions) > 0 {
+		hash.Write([]byte("HoldOffDirectServerEntryProviderRegions"))
+		for providerID, regions := range config.HoldOffDirectServerEntryProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
 	if config.HoldOffTunnelProbability != nil {
 	if config.HoldOffTunnelProbability != nil {
 		hash.Write([]byte("HoldOffTunnelProbability"))
 		hash.Write([]byte("HoldOffTunnelProbability"))
 		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProbability)
 		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProbability)
 	}
 	}
 
 
+	if len(config.RestrictDirectProviderIDs) > 0 {
+		hash.Write([]byte("RestrictDirectProviderIDs"))
+		for _, providerID := range config.RestrictDirectProviderIDs {
+			hash.Write([]byte(providerID))
+		}
+	}
+
+	if config.RestrictDirectProviderIDsClientProbability != nil {
+		hash.Write([]byte("RestrictDirectProviderIDsClientProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.RestrictDirectProviderIDsClientProbability)
+	}
+
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 		hash.Write([]byte("RestrictFrontingProviderIDs"))
 		hash.Write([]byte("RestrictFrontingProviderIDs"))
 		for _, providerID := range config.RestrictFrontingProviderIDs {
 		for _, providerID := range config.RestrictFrontingProviderIDs {

+ 47 - 2
psiphon/dialParameters.go

@@ -427,9 +427,31 @@ func MakeDialParameters(
 		dialParams.TunnelProtocol = selectedProtocol
 		dialParams.TunnelProtocol = selectedProtocol
 	}
 	}
 
 
+	// Skip this candidate when the clients tactics restrict usage of the
+	// provider ID. See the corresponding server-side enforcement comments in
+	// server.TacticsListener.accept.
+	if protocol.TunnelProtocolIsDirect(dialParams.TunnelProtocol) &&
+		common.Contains(
+			p.Strings(parameters.RestrictDirectProviderIDs),
+			dialParams.ServerEntry.ProviderID) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictDirectProviderIDsClientProbability) {
+
+			// When skipping, return nil/nil as no error should be logged.
+			// NoticeSkipServerEntry emits each skip reason, regardless
+			// of server entry, at most once per session.
+
+			NoticeSkipServerEntry(
+				"restricted provider ID: %s",
+				dialParams.ServerEntry.ProviderID)
+
+			return nil, nil
+		}
+	}
+
 	// Skip this candidate when the clients tactics restrict usage of the
 	// Skip this candidate when the clients tactics restrict usage of the
 	// fronting provider ID. See the corresponding server-side enforcement
 	// fronting provider ID. See the corresponding server-side enforcement
-	// comments in server.TacticsListener.accept.
+	// comments in server.MeekServer.getSessionOrEndpoint.
 	if protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
 	if protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
 		common.Contains(
 		common.Contains(
 			p.Strings(parameters.RestrictFrontingProviderIDs),
 			p.Strings(parameters.RestrictFrontingProviderIDs),
@@ -845,6 +867,9 @@ func MakeDialParameters(
 
 
 	if !isReplay || !replayHoldOffTunnel {
 	if !isReplay || !replayHoldOffTunnel {
 
 
+		var holdOffTunnelDuration time.Duration
+		var holdOffDirectTunnelDuration time.Duration
+
 		if common.Contains(
 		if common.Contains(
 			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) ||
 			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) ||
 
 
@@ -855,12 +880,32 @@ func MakeDialParameters(
 
 
 			if p.WeightedCoinFlip(parameters.HoldOffTunnelProbability) {
 			if p.WeightedCoinFlip(parameters.HoldOffTunnelProbability) {
 
 
-				dialParams.HoldOffTunnelDuration = prng.Period(
+				holdOffTunnelDuration = prng.Period(
 					p.Duration(parameters.HoldOffTunnelMinDuration),
 					p.Duration(parameters.HoldOffTunnelMinDuration),
 					p.Duration(parameters.HoldOffTunnelMaxDuration))
 					p.Duration(parameters.HoldOffTunnelMaxDuration))
 			}
 			}
 		}
 		}
 
 
+		if protocol.TunnelProtocolIsDirect(dialParams.TunnelProtocol) &&
+			(common.Contains(
+				p.Strings(parameters.HoldOffDirectServerEntryRegions), serverEntry.Region) ||
+				common.ContainsAny(
+					p.KeyStrings(parameters.HoldOffDirectServerEntryProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region})) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffDirectTunnelProbability) {
+
+				holdOffDirectTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffDirectTunnelMinDuration),
+					p.Duration(parameters.HoldOffDirectTunnelMaxDuration))
+			}
+		}
+
+		// Use the longest hold off duration
+		if holdOffTunnelDuration >= holdOffDirectTunnelDuration {
+			dialParams.HoldOffTunnelDuration = holdOffTunnelDuration
+		} else {
+			dialParams.HoldOffTunnelDuration = holdOffDirectTunnelDuration
+		}
 	}
 	}
 
 
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,

+ 61 - 4
psiphon/dialParameters_test.go

@@ -79,8 +79,20 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	}
 
 
 	holdOffTunnelProtocols := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
 	holdOffTunnelProtocols := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
+
+	providerID := prng.HexString(8)
 	frontingProviderID := prng.HexString(8)
 	frontingProviderID := prng.HexString(8)
 
 
+	var holdOffDirectServerEntryRegions []string
+	if tunnelProtocol == protocol.TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH {
+		holdOffDirectServerEntryRegions = []string{"CA"}
+	}
+
+	var holdOffDirectServerEntryProviderRegions parameters.KeyStrings
+	if tunnelProtocol == protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK {
+		holdOffDirectServerEntryProviderRegions = map[string][]string{providerID: {""}}
+	}
+
 	applyParameters := make(map[string]interface{})
 	applyParameters := make(map[string]interface{})
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
@@ -89,6 +101,11 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
+	applyParameters[parameters.HoldOffDirectTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffDirectServerEntryRegions] = holdOffDirectServerEntryRegions
+	applyParameters[parameters.HoldOffDirectServerEntryProviderRegions] = holdOffDirectServerEntryProviderRegions
+	applyParameters[parameters.HoldOffDirectTunnelProbability] = 1.0
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
 	applyParameters[parameters.DirectHTTPProtocolTransformSpecs] = transforms.Specs{"spec": transforms.Spec{{"", ""}}}
 	applyParameters[parameters.DirectHTTPProtocolTransformSpecs] = transforms.Specs{"spec": transforms.Spec{{"", ""}}}
@@ -115,7 +132,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	}
 	defer CloseDataStore()
 	defer CloseDataStore()
 
 
-	serverEntries := makeMockServerEntries(tunnelProtocol, frontingProviderID, 100)
+	serverEntries := makeMockServerEntries(tunnelProtocol, "CA", providerID, frontingProviderID, 100)
 
 
 	canReplay := func(serverEntry *protocol.ServerEntry, replayProtocol string) bool {
 	canReplay := func(serverEntry *protocol.ServerEntry, replayProtocol string) bool {
 		return replayProtocol == tunnelProtocol
 		return replayProtocol == tunnelProtocol
@@ -230,8 +247,15 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("missing API request fields")
 		t.Fatalf("missing API request fields")
 	}
 	}
 
 
-	if common.Contains(holdOffTunnelProtocols, tunnelProtocol) ||
-		protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+	expectHoldOffTunnelProtocols := common.Contains(holdOffTunnelProtocols, tunnelProtocol)
+	expectHoldOffTunnelFrontingProviderIDs := protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol)
+	expectHoldOffDirectServerEntryRegions := protocol.TunnelProtocolIsDirect(tunnelProtocol) && common.Contains(holdOffDirectServerEntryRegions, dialParams.ServerEntry.Region)
+	expectHoldOffDirectServerEntryProviderRegion := protocol.TunnelProtocolIsDirect(tunnelProtocol) && common.ContainsAny(holdOffDirectServerEntryProviderRegions[dialParams.ServerEntry.ProviderID], []string{"", dialParams.ServerEntry.Region})
+
+	if expectHoldOffTunnelProtocols ||
+		expectHoldOffTunnelFrontingProviderIDs ||
+		expectHoldOffDirectServerEntryRegions ||
+		expectHoldOffDirectServerEntryProviderRegion {
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
@@ -524,6 +548,35 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 	}
 
 
+	// Test: client-side restrict provider ID
+
+	applyParameters[parameters.RestrictDirectProviderIDs] = []string{providerID}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag6", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(clientConfig, nil, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
+
+	if protocol.TunnelProtocolIsDirect(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag7", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
 	// Test: iterator shuffles
 	// Test: iterator shuffles
 
 
 	for i, serverEntry := range serverEntries {
 	for i, serverEntry := range serverEntries {
@@ -677,7 +730,7 @@ func TestLimitTunnelDialPortNumbers(t *testing.T) {
 			continue
 			continue
 		}
 		}
 
 
-		serverEntries := makeMockServerEntries(tunnelProtocol, "", 100)
+		serverEntries := makeMockServerEntries(tunnelProtocol, "", "", "", 100)
 
 
 		selected := false
 		selected := false
 		skipped := false
 		skipped := false
@@ -721,6 +774,8 @@ func TestLimitTunnelDialPortNumbers(t *testing.T) {
 
 
 func makeMockServerEntries(
 func makeMockServerEntries(
 	tunnelProtocol string,
 	tunnelProtocol string,
+	region string,
+	providerID string,
 	frontingProviderID string,
 	frontingProviderID string,
 	count int) []*protocol.ServerEntry {
 	count int) []*protocol.ServerEntry {
 
 
@@ -737,6 +792,8 @@ func makeMockServerEntries(
 			MeekServerPort:             prng.Range(60, 69),
 			MeekServerPort:             prng.Range(60, 69),
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
+			Region:                     region,
+			ProviderID:                 providerID,
 			FrontingProviderID:         frontingProviderID,
 			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),

+ 4 - 0
psiphon/notice.go

@@ -477,6 +477,10 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters, pos
 			args = append(args, "upstreamProxyCustomHeaderNames", strings.Join(dialParams.UpstreamProxyCustomHeaderNames, ","))
 			args = append(args, "upstreamProxyCustomHeaderNames", strings.Join(dialParams.UpstreamProxyCustomHeaderNames, ","))
 		}
 		}
 
 
+		if dialParams.ServerEntry.ProviderID != "" {
+			args = append(args, "providerID", dialParams.ServerEntry.ProviderID)
+		}
+
 		if dialParams.FrontingProviderID != "" {
 		if dialParams.FrontingProviderID != "" {
 			args = append(args, "frontingProviderID", dialParams.FrontingProviderID)
 			args = append(args, "frontingProviderID", dialParams.FrontingProviderID)
 		}
 		}

+ 15 - 4
psiphon/server/api.go

@@ -177,7 +177,8 @@ var handshakeRequestParams = append(
 			[]requestParamSpec{
 			[]requestParamSpec{
 				// Legacy clients may not send "session_id" in handshake
 				// Legacy clients may not send "session_id" in handshake
 				{"session_id", isHexDigits, requestParamOptional},
 				{"session_id", isHexDigits, requestParamOptional},
-				{"missing_server_entry_signature", isBase64String, requestParamOptional}},
+				{"missing_server_entry_signature", isBase64String, requestParamOptional},
+				{"missing_server_entry_provider_id", isBase64String, requestParamOptional}},
 			baseParams...),
 			baseParams...),
 		baseDialParams...),
 		baseDialParams...),
 	tacticsParams...)
 	tacticsParams...)
@@ -351,17 +352,27 @@ func handshakeAPIRequestHandler(
 			calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
 			calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
 	}
 	}
 
 
-	// When the client indicates that it used an unsigned server entry for this
-	// connection, return a signed copy of the server entry for the client to
-	// upgrade to. See also: comment in psiphon.doHandshakeRequest.
+	// When the client indicates that it used an out-of-date server entry for
+	// this connection, return a signed copy of the server entry for the client
+	// to upgrade to. Out-of-date server entries are either unsigned or missing
+	// a provider ID. See also: comment in psiphon.doHandshakeRequest.
 	//
 	//
 	// The missing_server_entry_signature parameter value is a server entry tag,
 	// The missing_server_entry_signature parameter value is a server entry tag,
 	// which is used to select the correct server entry for servers with multiple
 	// which is used to select the correct server entry for servers with multiple
 	// entries. Identifying the server entries tags instead of server IPs prevents
 	// entries. Identifying the server entries tags instead of server IPs prevents
 	// an enumeration attack, where a malicious client can abuse this facilty to
 	// an enumeration attack, where a malicious client can abuse this facilty to
 	// check if an arbitrary IP address is a Psiphon server.
 	// check if an arbitrary IP address is a Psiphon server.
+	//
+	// The missing_server_entry_provider_id parameter value is a server entry
+	// tag.
 	serverEntryTag, ok := getOptionalStringRequestParam(
 	serverEntryTag, ok := getOptionalStringRequestParam(
 		params, "missing_server_entry_signature")
 		params, "missing_server_entry_signature")
+	if !ok {
+		// Do not need to check this case if we'll already return the server
+		// entry due to a missing signature.
+		serverEntryTag, ok = getOptionalStringRequestParam(
+			params, "missing_server_entry_provider_id")
+	}
 	if ok {
 	if ok {
 		ownServerEntry, ok := support.Config.GetOwnEncodedServerEntry(serverEntryTag)
 		ownServerEntry, ok := support.Config.GetOwnEncodedServerEntry(serverEntryTag)
 		if ok {
 		if ok {

+ 13 - 0
psiphon/server/config.go

@@ -40,6 +40,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
@@ -461,6 +462,7 @@ type Config struct {
 	periodicGarbageCollection                      time.Duration
 	periodicGarbageCollection                      time.Duration
 	stopEstablishTunnelsEstablishedClientThreshold int
 	stopEstablishTunnelsEstablishedClientThreshold int
 	dumpProfilesOnStopEstablishTunnelsDone         int32
 	dumpProfilesOnStopEstablishTunnelsDone         int32
+	providerID                                     string
 	frontingProviderID                             string
 	frontingProviderID                             string
 	runningProtocols                               []string
 	runningProtocols                               []string
 }
 }
@@ -529,6 +531,11 @@ func (config *Config) GetOwnEncodedServerEntry(serverEntryTag string) (string, b
 	return serverEntry, ok
 	return serverEntry, ok
 }
 }
 
 
+// GetProviderID returns the provider ID associated with the server.
+func (config *Config) GetProviderID() string {
+	return config.providerID
+}
+
 // GetFrontingProviderID returns the fronting provider ID associated with the
 // GetFrontingProviderID returns the fronting provider ID associated with the
 // server's fronted protocol(s).
 // server's fronted protocol(s).
 func (config *Config) GetFrontingProviderID() string {
 func (config *Config) GetFrontingProviderID() string {
@@ -716,6 +723,11 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 			return nil, errors.Tracef(
 			return nil, errors.Tracef(
 				"protocol.DecodeServerEntry failed: %s", err)
 				"protocol.DecodeServerEntry failed: %s", err)
 		}
 		}
+		if config.providerID == "" {
+			config.providerID = serverEntry.ProviderID
+		} else if config.providerID != serverEntry.ProviderID {
+			return nil, errors.Tracef("unsupported multiple ProviderID values")
+		}
 		if config.frontingProviderID == "" {
 		if config.frontingProviderID == "" {
 			config.frontingProviderID = serverEntry.FrontingProviderID
 			config.frontingProviderID = serverEntry.FrontingProviderID
 		} else if config.frontingProviderID != serverEntry.FrontingProviderID {
 		} else if config.frontingProviderID != serverEntry.FrontingProviderID {
@@ -1141,6 +1153,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, []byt
 		SshObfuscatedKey:              obfuscatedSSHKey,
 		SshObfuscatedKey:              obfuscatedSSHKey,
 		Capabilities:                  capabilities,
 		Capabilities:                  capabilities,
 		Region:                        "US",
 		Region:                        "US",
+		ProviderID:                    prng.HexString(8),
 		MeekServerPort:                meekPort,
 		MeekServerPort:                meekPort,
 		MeekCookieEncryptionPublicKey: meekCookieEncryptionPublicKey,
 		MeekCookieEncryptionPublicKey: meekCookieEncryptionPublicKey,
 		MeekObfuscatedKey:             meekObfuscatedKey,
 		MeekObfuscatedKey:             meekObfuscatedKey,

+ 37 - 0
psiphon/server/listener.go

@@ -25,10 +25,15 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
 )
 
 
+type restrictedProviderError struct{}
+
+func (restrictedProviderError) Error() string { return "restricted provider" }
+
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // of certain tactics parameters to accepted connections. Tactics filtering is
 // of certain tactics parameters to accepted connections. Tactics filtering is
 // limited to GeoIP attributes as the client has not yet sent API parameters.
 // limited to GeoIP attributes as the client has not yet sent API parameters.
@@ -94,6 +99,38 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return conn, nil
 		return conn, nil
 	}
 	}
 
 
+	// Disconnect immediately if the tactics for the client restricts usage of
+	// the provider ID with direct protocols. The probability may be used to
+	// influence usage of a given provider with direct protocols; but when only
+	// that provider works for a given client, and the probability is less than
+	// 1.0, the client can  retry until it gets a successful coin flip.
+	//
+	// Clients will also skip direct protocol candidates with restricted
+	// provider IDs.
+	// The client-side probability, RestrictDirectProviderIDsClientProbability,
+	// is applied independently of the server-side coin flip here.
+	//
+	// The selected tactics are for the immediate peer IP and therefore must
+	// not be applied to clients using indirect protocols, where the immediate
+	// peer IP is not the original client IP. Indirect protocols must determine
+	// the original client IP before applying GeoIP specific tactics; see the
+	// server-side enforcement of RestrictDirectProviderIDs for fronted meek in
+	// server.MeekServer.getSessionOrEndpoint.
+	//
+	// At this stage, GeoIP tactics filters are active, but handshake API
+	// parameters are not.
+	//
+	// See the comment in server.LoadConfig regarding provider ID limitations.
+	if protocol.TunnelProtocolIsDirect(listener.tunnelProtocol) &&
+		common.Contains(
+			p.Strings(parameters.RestrictDirectProviderIDs),
+			listener.support.Config.GetProviderID()) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictDirectProviderIDsServerProbability) {
+			return nil, restrictedProviderError{}
+		}
+	}
+
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	//
 	//
 	// In the OSSH case, replay is always activated and it is seeded using the
 	// In the OSSH case, replay is always activated and it is seeded using the

+ 4 - 0
psiphon/server/tunnelServer.go

@@ -705,6 +705,10 @@ func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, liste
 				log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
 				log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
 				// Temporary error, keep running
 				// Temporary error, keep running
 				continue
 				continue
+			} else if _, ok := err.(restrictedProviderError); ok {
+				log.WithTraceFields(LogFields{"error": err}).Error("accept rejected client")
+				// Restricted provider, keep running
+				continue
 			}
 			}
 
 
 			select {
 			select {

+ 23 - 5
psiphon/serverApi.go

@@ -134,6 +134,23 @@ func (serverContext *ServerContext) doHandshakeRequest(
 			serverContext.tunnel.dialParams.ServerEntry.Tag
 			serverContext.tunnel.dialParams.ServerEntry.Tag
 	}
 	}
 
 
+	// The server will return a signed copy of its own server entry when the
+	// client specifies this 'missing_server_entry_provider_id' parameter.
+	//
+	// The purpose of this mechanism is to rapidly add provider IDs to the
+	// server entries in client local storage, and to ensure that the client has
+	// a provider ID for its currently connected server as required for the
+	// RestrictDirectProviderIDs and HoldOffDirectServerEntryProviderRegions tactics.
+	//
+	// The server entry will be included in handshakeResponse.EncodedServerList,
+	// along side discovery servers.
+	requestedMissingProviderID := false
+	if !serverContext.tunnel.dialParams.ServerEntry.HasProviderID() {
+		requestedMissingProviderID = true
+		params["missing_server_entry_provider_id"] =
+			serverContext.tunnel.dialParams.ServerEntry.Tag
+	}
+
 	doTactics := !serverContext.tunnel.config.DisableTactics
 	doTactics := !serverContext.tunnel.config.DisableTactics
 
 
 	networkID := ""
 	networkID := ""
@@ -272,13 +289,14 @@ func (serverContext *ServerContext) doHandshakeRequest(
 			return errors.Trace(err)
 			return errors.Trace(err)
 		}
 		}
 
 
-		// Retain the original timestamp and source in the requestedMissingSignature
-		// case, as this server entry was not discovered here.
+		// Retain the original timestamp and source in the
+		// requestedMissingSignature and requestedMissingProviderID
+		// cases, as this server entry was not discovered here.
 		//
 		//
 		// Limitation: there is a transient edge case where
 		// Limitation: there is a transient edge case where
-		// requestedMissingSignature will be set for a discovery server entry that
-		// _is_ also discovered here.
-		if requestedMissingSignature &&
+		// requestedMissingSignature and/or requestedMissingProviderID will be
+		// set for a discovery server entry that _is_ also discovered here.
+		if requestedMissingSignature || requestedMissingProviderID &&
 			serverEntryFields.GetIPAddress() == serverContext.tunnel.dialParams.ServerEntry.IpAddress {
 			serverEntryFields.GetIPAddress() == serverContext.tunnel.dialParams.ServerEntry.IpAddress {
 
 
 			serverEntryFields.SetLocalTimestamp(serverContext.tunnel.dialParams.ServerEntry.LocalTimestamp)
 			serverEntryFields.SetLocalTimestamp(serverContext.tunnel.dialParams.ServerEntry.LocalTimestamp)