Browse Source

Merge pull request #669 from mirokuratczyk/provider-id

Add provider ID
Rod Hynes 2 years ago
parent
commit
448517f4ca

+ 37 - 0
psiphon/common/parameters/keyStrings.go

@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+// KeyStrings represents a set of key/strings pairs.
+type KeyStrings map[string][]string
+
+// Validates that the keys and values are well formed.
+func (keyStrings KeyStrings) Validate() error {
+	// Always succeeds because KeyStrings is generic and does not impose any
+	// restrictions on keys/values. Consider imposing limits like maximum
+	// map/array/string sizes.
+	return nil
+}
+
+func (p ParametersAccessor) KeyStrings(name, key string) []string {
+	value := KeyStrings{}
+	p.snapshot.getValue(name, &value)
+	return value[key]
+}

+ 28 - 0
psiphon/common/parameters/parameters.go

@@ -311,6 +311,15 @@ const (
 	HoldOffTunnelProtocols                           = "HoldOffTunnelProtocols"
 	HoldOffTunnelFrontingProviderIDs                 = "HoldOffTunnelFrontingProviderIDs"
 	HoldOffTunnelProbability                         = "HoldOffTunnelProbability"
+	HoldOffDirectTunnelMinDuration                   = "HoldOffDirectTunnelMinDuration"
+	HoldOffDirectTunnelMaxDuration                   = "HoldOffDirectTunnelMaxDuration"
+	HoldOffDirectServerEntryRegions                  = "HoldOffDirectServerEntryRegions"
+	HoldOffDirectServerEntryProviderRegions          = "HoldOffDirectServerEntryProviderRegions"
+	HoldOffDirectTunnelProbability                   = "HoldOffDirectTunnelProbability"
+	RestrictDirectProviderIDs                        = "RestrictDirectProviderIDs"
+	RestrictDirectProviderRegions                    = "RestrictDirectProviderRegions"
+	RestrictDirectProviderIDsServerProbability       = "RestrictDirectProviderIDsServerProbability"
+	RestrictDirectProviderIDsClientProbability       = "RestrictDirectProviderIDsClientProbability"
 	RestrictFrontingProviderIDs                      = "RestrictFrontingProviderIDs"
 	RestrictFrontingProviderIDsServerProbability     = "RestrictFrontingProviderIDsServerProbability"
 	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
@@ -702,6 +711,17 @@ var defaultParameters = map[string]struct {
 	HoldOffTunnelFrontingProviderIDs: {value: []string{}},
 	HoldOffTunnelProbability:         {value: 0.0, minimum: 0.0},
 
+	HoldOffDirectTunnelMinDuration:          {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffDirectTunnelMaxDuration:          {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffDirectServerEntryRegions:         {value: []string{}},
+	HoldOffDirectServerEntryProviderRegions: {value: KeyStrings{}},
+	HoldOffDirectTunnelProbability:          {value: 0.0, minimum: 0.0},
+
+	RestrictDirectProviderIDs:                  {value: []string{}},
+	RestrictDirectProviderRegions:              {value: KeyStrings{}},
+	RestrictDirectProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
+	RestrictDirectProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
+
 	RestrictFrontingProviderIDs:                  {value: []string{}},
 	RestrictFrontingProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
 	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
@@ -1086,6 +1106,14 @@ func (p *Parameters) Set(
 					}
 					return nil, errors.Trace(err)
 				}
+			case KeyStrings:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			case *BPFProgramSpec:
 				if v != nil {
 					err := v.Validate()

+ 7 - 0
psiphon/common/parameters/parameters_test.go

@@ -189,6 +189,13 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("ConjureTransports returned %+v expected %+v", g, v)
 			}
+		case KeyStrings:
+			for key, strings := range v {
+				g := p.Get().KeyStrings(name, key)
+				if !reflect.DeepEqual(strings, g) {
+					t.Fatalf("KeyStrings returned %+v expected %+v", g, strings)
+				}
+			}
 		default:
 			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 		}

+ 10 - 0
psiphon/common/protocol/protocol.go

@@ -272,6 +272,16 @@ func TunnelProtocolIsCompatibleWithFragmentor(protocol string) bool {
 		protocol == TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH
 }
 
+func TunnelProtocolIsDirect(protocol string) bool {
+	return protocol == TUNNEL_PROTOCOL_SSH ||
+		protocol == TUNNEL_PROTOCOL_OBFUSCATED_SSH ||
+		protocol == TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET ||
+		protocol == TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH
+}
+
 func TunnelProtocolRequiresTLS12SessionTickets(protocol string) bool {
 	return protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
 }

+ 5 - 0
psiphon/common/protocol/serverEntry.go

@@ -61,6 +61,7 @@ type ServerEntry struct {
 	SshObfuscatedKey                string   `json:"sshObfuscatedKey"`
 	Capabilities                    []string `json:"capabilities"`
 	Region                          string   `json:"region"`
+	ProviderID                      string   `json:"providerID"`
 	FrontingProviderID              string   `json:"frontingProviderID"`
 	TlsOSSHPort                     int      `json:"tlsOSSHPort"`
 	MeekServerPort                  int      `json:"meekServerPort"`
@@ -736,6 +737,10 @@ func (serverEntry *ServerEntry) HasSignature() bool {
 	return serverEntry.Signature != ""
 }
 
+func (serverEntry *ServerEntry) HasProviderID() bool {
+	return serverEntry.ProviderID != ""
+}
+
 func (serverEntry *ServerEntry) GetDiagnosticID() string {
 	return TagToDiagnosticID(serverEntry.Tag)
 }

+ 2 - 2
psiphon/common/transforms/httpNormalizer.go

@@ -21,7 +21,7 @@ package transforms
 
 import (
 	"bytes"
-	stderrors "errors"
+	std_errors "errors"
 	"io"
 	"net"
 	"net/textproto"
@@ -50,7 +50,7 @@ const (
 	rangeHeader               = "Range"
 )
 
-var ErrPassthroughActive = stderrors.New("passthrough")
+var ErrPassthroughActive = std_errors.New("passthrough")
 
 // HTTPNormalizer wraps a net.Conn, intercepting Read calls, and normalizes any
 // HTTP requests that are read. The HTTP request components preceeding the body

+ 8 - 8
psiphon/common/transforms/httpNormalizer_test.go

@@ -21,7 +21,7 @@ package transforms
 
 import (
 	"bytes"
-	stderrors "errors"
+	std_errors "errors"
 	"io"
 	"net"
 	"strings"
@@ -143,7 +143,7 @@ func runHTTPNormalizerTest(tt *httpNormalizerTest, useNormalizer bool) error {
 		// Subsequent writes should not impact conn or passthroughConn
 
 		_, err = normalizer.Write([]byte("ignored"))
-		if !stderrors.Is(err, ErrPassthroughActive) {
+		if !std_errors.Is(err, ErrPassthroughActive) {
 			return errors.Tracef("expected error io.EOF but got %v", err)
 		}
 
@@ -230,20 +230,20 @@ func TestHTTPNormalizerHTTPRequest(t *testing.T) {
 			headerOrder:  []string{"Host", "Content-Length"},
 			wantOutput:   "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 4\r\n\r\nabcd",
 			chunkSize:    1,
-			connReadErrs: []error{stderrors.New("err1"), stderrors.New("err2")},
+			connReadErrs: []error{std_errors.New("err1"), std_errors.New("err2")},
 		},
 		{
 			name:       "Content-Length missing",
 			input:      "POST / HTTP/1.1\r\n\r\nabcd",
 			wantOutput: "POST / HTTP/1.1\r\n\r\nabcd", // set to ensure all bytes are read
-			wantError:  stderrors.New("Content-Length missing"),
+			wantError:  std_errors.New("Content-Length missing"),
 			chunkSize:  1,
 		},
 		{
 			name:       "invalid Content-Length header value",
 			input:      "POST / HTTP/1.1\r\nContent-Length: X\r\n\r\nabcd",
 			wantOutput: "POST / HTTP/1.1\r\nContent-Length: X\r\nHost: example.com\r\n\r\nabcd", // set to ensure all bytes are read
-			wantError:  stderrors.New("strconv.ParseUint: parsing \"X\": invalid syntax"),
+			wantError:  std_errors.New("strconv.ParseUint: parsing \"X\": invalid syntax"),
 			chunkSize:  1,
 		},
 		{
@@ -330,7 +330,7 @@ func TestHTTPNormalizerHTTPRequest(t *testing.T) {
 			maxHeaderSize: 47, // up to end of Cookie header
 			wantOutput:    "POST / HTTP/1.1\r\nContent-Length: 4\r\nCookie: X\r\nRange: 1234 \r\n\r\nabcd",
 			chunkSize:     1,
-			wantError:     stderrors.New("exceeds maxReqLineAndHeadersSize"),
+			wantError:     std_errors.New("exceeds maxReqLineAndHeadersSize"),
 		},
 	}
 
@@ -424,7 +424,7 @@ func TestHTTPNormalizerHTTPServer(t *testing.T) {
 				if string(cookie) == "valid" {
 					return []byte(validateMeekCookieResult), nil
 				}
-				return nil, stderrors.New("invalid cookie")
+				return nil, std_errors.New("invalid cookie")
 			}
 			normalizer.HeaderWriteOrder = []string{"Host", "Cookie", "Content-Length"}
 
@@ -469,7 +469,7 @@ func TestHTTPNormalizerHTTPServer(t *testing.T) {
 
 				_, err = conn.Write([]byte(listenerType))
 				if err != nil {
-					if stderrors.Is(err, ErrPassthroughActive) {
+					if std_errors.Is(err, ErrPassthroughActive) {
 						return
 					}
 					recv <- &listenerState{

+ 100 - 0
psiphon/config.go

@@ -818,6 +818,20 @@ type Config struct {
 	HoldOffTunnelFrontingProviderIDs     []string
 	HoldOffTunnelProbability             *float64
 
+	// HoldOffDirectTunnelMinDurationMilliseconds and other HoldOffDirect
+	// fields are for testing purposes.
+	HoldOffDirectTunnelMinDurationMilliseconds *int
+	HoldOffDirectTunnelMaxDurationMilliseconds *int
+	HoldOffDirectServerEntryRegions            []string
+	HoldOffDirectServerEntryProviderRegions    map[string][]string
+	HoldOffDirectTunnelProbability             *float64
+
+	// RestrictDirectProviderIDs and other RestrictDirect fields are for
+	// testing purposes.
+	RestrictDirectProviderIDs                  []string
+	RestrictDirectProviderRegions              map[string][]string
+	RestrictDirectProviderIDsClientProbability *float64
+
 	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
 	// are for testing purposes.
 	RestrictFrontingProviderIDs                  []string
@@ -1927,6 +1941,38 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.HoldOffTunnelProbability] = *config.HoldOffTunnelProbability
 	}
 
+	if config.HoldOffDirectTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffDirectTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffDirectTunnelMinDurationMilliseconds)
+	}
+
+	if config.HoldOffDirectTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffDirectTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffDirectServerEntryRegions) > 0 {
+		applyParameters[parameters.HoldOffDirectServerEntryRegions] = config.HoldOffDirectServerEntryRegions
+	}
+
+	if len(config.HoldOffDirectServerEntryProviderRegions) > 0 {
+		applyParameters[parameters.HoldOffDirectServerEntryProviderRegions] = parameters.KeyStrings(config.HoldOffDirectServerEntryProviderRegions)
+	}
+
+	if config.HoldOffDirectTunnelProbability != nil {
+		applyParameters[parameters.HoldOffDirectTunnelProbability] = *config.HoldOffDirectTunnelProbability
+	}
+
+	if len(config.RestrictDirectProviderIDs) > 0 {
+		applyParameters[parameters.RestrictDirectProviderIDs] = config.RestrictDirectProviderIDs
+	}
+
+	if len(config.RestrictDirectProviderRegions) > 0 {
+		applyParameters[parameters.RestrictDirectProviderRegions] = parameters.KeyStrings(config.RestrictDirectProviderRegions)
+	}
+
+	if config.RestrictDirectProviderIDsClientProbability != nil {
+		applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = *config.RestrictDirectProviderIDsClientProbability
+	}
+
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 		applyParameters[parameters.RestrictFrontingProviderIDs] = config.RestrictFrontingProviderIDs
 	}
@@ -2414,11 +2460,65 @@ func (config *Config) setDialParametersHash() {
 		}
 	}
 
+	if config.HoldOffDirectTunnelProbability != nil {
+		hash.Write([]byte("HoldOffDirectTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffDirectTunnelProbability)
+	}
+
+	if config.HoldOffDirectTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffDirectTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffDirectTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffDirectTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffDirectTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffDirectTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffDirectServerEntryRegions) > 0 {
+		hash.Write([]byte("HoldOffDirectServerEntryRegions"))
+		for _, region := range config.HoldOffDirectServerEntryRegions {
+			hash.Write([]byte(region))
+		}
+	}
+
+	if len(config.HoldOffDirectServerEntryProviderRegions) > 0 {
+		hash.Write([]byte("HoldOffDirectServerEntryProviderRegions"))
+		for providerID, regions := range config.HoldOffDirectServerEntryProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
 	if config.HoldOffTunnelProbability != nil {
 		hash.Write([]byte("HoldOffTunnelProbability"))
 		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProbability)
 	}
 
+	if len(config.RestrictDirectProviderIDs) > 0 {
+		hash.Write([]byte("RestrictDirectProviderIDs"))
+		for _, providerID := range config.RestrictDirectProviderIDs {
+			hash.Write([]byte(providerID))
+		}
+	}
+
+	if len(config.RestrictDirectProviderRegions) > 0 {
+		hash.Write([]byte("RestrictDirectProviderRegions"))
+		for providerID, regions := range config.RestrictDirectProviderRegions {
+			hash.Write([]byte(providerID))
+			for _, region := range regions {
+				hash.Write([]byte(region))
+			}
+		}
+	}
+
+	if config.RestrictDirectProviderIDsClientProbability != nil {
+		hash.Write([]byte("RestrictDirectProviderIDsClientProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.RestrictDirectProviderIDsClientProbability)
+	}
+
 	if len(config.RestrictFrontingProviderIDs) > 0 {
 		hash.Write([]byte("RestrictFrontingProviderIDs"))
 		for _, providerID := range config.RestrictFrontingProviderIDs {

+ 49 - 2
psiphon/dialParameters.go

@@ -427,9 +427,33 @@ func MakeDialParameters(
 		dialParams.TunnelProtocol = selectedProtocol
 	}
 
+	// Skip this candidate when the clients tactics restrict usage of the
+	// provider ID. See the corresponding server-side enforcement comments in
+	// server.TacticsListener.accept.
+	if protocol.TunnelProtocolIsDirect(dialParams.TunnelProtocol) &&
+		(common.Contains(
+			p.Strings(parameters.RestrictDirectProviderIDs),
+			dialParams.ServerEntry.ProviderID) ||
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictDirectProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region})) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictDirectProviderIDsClientProbability) {
+
+			// When skipping, return nil/nil as no error should be logged.
+			// NoticeSkipServerEntry emits each skip reason, regardless
+			// of server entry, at most once per session.
+
+			NoticeSkipServerEntry(
+				"restricted provider ID: %s",
+				dialParams.ServerEntry.ProviderID)
+
+			return nil, nil
+		}
+	}
+
 	// Skip this candidate when the clients tactics restrict usage of the
 	// fronting provider ID. See the corresponding server-side enforcement
-	// comments in server.TacticsListener.accept.
+	// comments in server.MeekServer.getSessionOrEndpoint.
 	if protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
 		common.Contains(
 			p.Strings(parameters.RestrictFrontingProviderIDs),
@@ -845,6 +869,9 @@ func MakeDialParameters(
 
 	if !isReplay || !replayHoldOffTunnel {
 
+		var holdOffTunnelDuration time.Duration
+		var holdOffDirectTunnelDuration time.Duration
+
 		if common.Contains(
 			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) ||
 
@@ -855,12 +882,32 @@ func MakeDialParameters(
 
 			if p.WeightedCoinFlip(parameters.HoldOffTunnelProbability) {
 
-				dialParams.HoldOffTunnelDuration = prng.Period(
+				holdOffTunnelDuration = prng.Period(
 					p.Duration(parameters.HoldOffTunnelMinDuration),
 					p.Duration(parameters.HoldOffTunnelMaxDuration))
 			}
 		}
 
+		if protocol.TunnelProtocolIsDirect(dialParams.TunnelProtocol) &&
+			(common.Contains(
+				p.Strings(parameters.HoldOffDirectServerEntryRegions), serverEntry.Region) ||
+				common.ContainsAny(
+					p.KeyStrings(parameters.HoldOffDirectServerEntryProviderRegions, dialParams.ServerEntry.ProviderID), []string{"", serverEntry.Region})) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffDirectTunnelProbability) {
+
+				holdOffDirectTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffDirectTunnelMinDuration),
+					p.Duration(parameters.HoldOffDirectTunnelMaxDuration))
+			}
+		}
+
+		// Use the longest hold off duration
+		if holdOffTunnelDuration >= holdOffDirectTunnelDuration {
+			dialParams.HoldOffTunnelDuration = holdOffTunnelDuration
+		} else {
+			dialParams.HoldOffTunnelDuration = holdOffDirectTunnelDuration
+		}
 	}
 
 	// OSSH prefix and seed transform are applied only to the OSSH tunnel protocol,

+ 92 - 4
psiphon/dialParameters_test.go

@@ -79,8 +79,20 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 
 	holdOffTunnelProtocols := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
+
+	providerID := prng.HexString(8)
 	frontingProviderID := prng.HexString(8)
 
+	var holdOffDirectServerEntryRegions []string
+	if tunnelProtocol == protocol.TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH {
+		holdOffDirectServerEntryRegions = []string{"CA"}
+	}
+
+	var holdOffDirectServerEntryProviderRegions parameters.KeyStrings
+	if tunnelProtocol == protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK {
+		holdOffDirectServerEntryProviderRegions = map[string][]string{providerID: {""}}
+	}
+
 	applyParameters := make(map[string]interface{})
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
@@ -89,6 +101,11 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
 	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
 	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
+	applyParameters[parameters.HoldOffDirectTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffDirectTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffDirectServerEntryRegions] = holdOffDirectServerEntryRegions
+	applyParameters[parameters.HoldOffDirectServerEntryProviderRegions] = holdOffDirectServerEntryProviderRegions
+	applyParameters[parameters.HoldOffDirectTunnelProbability] = 1.0
 	applyParameters[parameters.DNSResolverAlternateServers] = []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}
 	applyParameters[parameters.DirectHTTPProtocolTransformProbability] = 1.0
 	applyParameters[parameters.DirectHTTPProtocolTransformSpecs] = transforms.Specs{"spec": transforms.Spec{{"", ""}}}
@@ -115,7 +132,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	defer CloseDataStore()
 
-	serverEntries := makeMockServerEntries(tunnelProtocol, frontingProviderID, 100)
+	serverEntries := makeMockServerEntries(tunnelProtocol, "CA", providerID, frontingProviderID, 100)
 
 	canReplay := func(serverEntry *protocol.ServerEntry, replayProtocol string) bool {
 		return replayProtocol == tunnelProtocol
@@ -230,8 +247,15 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("missing API request fields")
 	}
 
-	if common.Contains(holdOffTunnelProtocols, tunnelProtocol) ||
-		protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+	expectHoldOffTunnelProtocols := common.Contains(holdOffTunnelProtocols, tunnelProtocol)
+	expectHoldOffTunnelFrontingProviderIDs := protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol)
+	expectHoldOffDirectServerEntryRegions := protocol.TunnelProtocolIsDirect(tunnelProtocol) && common.Contains(holdOffDirectServerEntryRegions, dialParams.ServerEntry.Region)
+	expectHoldOffDirectServerEntryProviderRegion := protocol.TunnelProtocolIsDirect(tunnelProtocol) && common.ContainsAny(holdOffDirectServerEntryProviderRegions[dialParams.ServerEntry.ProviderID], []string{"", dialParams.ServerEntry.Region})
+
+	if expectHoldOffTunnelProtocols ||
+		expectHoldOffTunnelFrontingProviderIDs ||
+		expectHoldOffDirectServerEntryRegions ||
+		expectHoldOffDirectServerEntryProviderRegion {
 		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
 			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
 			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
@@ -524,6 +548,66 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetParameters failed: %s", err)
 	}
 
+	// Test: client-side restrict provider ID
+
+	applyParameters[parameters.RestrictDirectProviderIDs] = []string{providerID}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag6", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(clientConfig, nil, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
+
+	if protocol.TunnelProtocolIsDirect(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictDirectProviderIDs] = []string{}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag7", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	// Test: client-side restrict provider ID by region
+
+	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{providerID: {"CA"}}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag6", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(clientConfig, nil, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
+
+	if protocol.TunnelProtocolIsDirect(tunnelProtocol) {
+		if err == nil {
+			if dialParams != nil {
+				t.Fatalf("unexpected MakeDialParameters success")
+			}
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictDirectProviderRegions] = map[string][]string{}
+	applyParameters[parameters.RestrictDirectProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag7", false, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
 	// Test: iterator shuffles
 
 	for i, serverEntry := range serverEntries {
@@ -677,7 +761,7 @@ func TestLimitTunnelDialPortNumbers(t *testing.T) {
 			continue
 		}
 
-		serverEntries := makeMockServerEntries(tunnelProtocol, "", 100)
+		serverEntries := makeMockServerEntries(tunnelProtocol, "", "", "", 100)
 
 		selected := false
 		skipped := false
@@ -721,6 +805,8 @@ func TestLimitTunnelDialPortNumbers(t *testing.T) {
 
 func makeMockServerEntries(
 	tunnelProtocol string,
+	region string,
+	providerID string,
 	frontingProviderID string,
 	count int) []*protocol.ServerEntry {
 
@@ -737,6 +823,8 @@ func makeMockServerEntries(
 			MeekServerPort:             prng.Range(60, 69),
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
+			Region:                     region,
+			ProviderID:                 providerID,
 			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),

+ 4 - 0
psiphon/notice.go

@@ -477,6 +477,10 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters, pos
 			args = append(args, "upstreamProxyCustomHeaderNames", strings.Join(dialParams.UpstreamProxyCustomHeaderNames, ","))
 		}
 
+		if dialParams.ServerEntry.ProviderID != "" {
+			args = append(args, "providerID", dialParams.ServerEntry.ProviderID)
+		}
+
 		if dialParams.FrontingProviderID != "" {
 			args = append(args, "frontingProviderID", dialParams.FrontingProviderID)
 		}

+ 15 - 4
psiphon/server/api.go

@@ -177,7 +177,8 @@ var handshakeRequestParams = append(
 			[]requestParamSpec{
 				// Legacy clients may not send "session_id" in handshake
 				{"session_id", isHexDigits, requestParamOptional},
-				{"missing_server_entry_signature", isBase64String, requestParamOptional}},
+				{"missing_server_entry_signature", isBase64String, requestParamOptional},
+				{"missing_server_entry_provider_id", isBase64String, requestParamOptional}},
 			baseParams...),
 		baseDialParams...),
 	tacticsParams...)
@@ -351,17 +352,27 @@ func handshakeAPIRequestHandler(
 			calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
 	}
 
-	// When the client indicates that it used an unsigned server entry for this
-	// connection, return a signed copy of the server entry for the client to
-	// upgrade to. See also: comment in psiphon.doHandshakeRequest.
+	// When the client indicates that it used an out-of-date server entry for
+	// this connection, return a signed copy of the server entry for the client
+	// to upgrade to. Out-of-date server entries are either unsigned or missing
+	// a provider ID. See also: comment in psiphon.doHandshakeRequest.
 	//
 	// The missing_server_entry_signature parameter value is a server entry tag,
 	// which is used to select the correct server entry for servers with multiple
 	// entries. Identifying the server entries tags instead of server IPs prevents
 	// an enumeration attack, where a malicious client can abuse this facilty to
 	// check if an arbitrary IP address is a Psiphon server.
+	//
+	// The missing_server_entry_provider_id parameter value is a server entry
+	// tag.
 	serverEntryTag, ok := getOptionalStringRequestParam(
 		params, "missing_server_entry_signature")
+	if !ok {
+		// Do not need to check this case if we'll already return the server
+		// entry due to a missing signature.
+		serverEntryTag, ok = getOptionalStringRequestParam(
+			params, "missing_server_entry_provider_id")
+	}
 	if ok {
 		ownServerEntry, ok := support.Config.GetOwnEncodedServerEntry(serverEntryTag)
 		if ok {

+ 24 - 0
psiphon/server/config.go

@@ -40,6 +40,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
@@ -461,7 +462,9 @@ type Config struct {
 	periodicGarbageCollection                      time.Duration
 	stopEstablishTunnelsEstablishedClientThreshold int
 	dumpProfilesOnStopEstablishTunnelsDone         int32
+	providerID                                     string
 	frontingProviderID                             string
+	region                                         string
 	runningProtocols                               []string
 }
 
@@ -529,12 +532,22 @@ func (config *Config) GetOwnEncodedServerEntry(serverEntryTag string) (string, b
 	return serverEntry, ok
 }
 
+// GetProviderID returns the provider ID associated with the server.
+func (config *Config) GetProviderID() string {
+	return config.providerID
+}
+
 // GetFrontingProviderID returns the fronting provider ID associated with the
 // server's fronted protocol(s).
 func (config *Config) GetFrontingProviderID() string {
 	return config.frontingProviderID
 }
 
+// GetRegion returns the region associated with the server.
+func (config *Config) GetRegion() string {
+	return config.region
+}
+
 // GetRunningProtocols returns the list of protcols this server is running.
 // The caller must not mutate the return value.
 func (config *Config) GetRunningProtocols() []string {
@@ -716,11 +729,21 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 			return nil, errors.Tracef(
 				"protocol.DecodeServerEntry failed: %s", err)
 		}
+		if config.providerID == "" {
+			config.providerID = serverEntry.ProviderID
+		} else if config.providerID != serverEntry.ProviderID {
+			return nil, errors.Tracef("unsupported multiple ProviderID values")
+		}
 		if config.frontingProviderID == "" {
 			config.frontingProviderID = serverEntry.FrontingProviderID
 		} else if config.frontingProviderID != serverEntry.FrontingProviderID {
 			return nil, errors.Tracef("unsupported multiple FrontingProviderID values")
 		}
+		if config.region == "" {
+			config.region = serverEntry.Region
+		} else if config.region != serverEntry.Region {
+			return nil, errors.Tracef("unsupported multiple Region values")
+		}
 	}
 
 	config.runningProtocols = []string{}
@@ -1141,6 +1164,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, []byt
 		SshObfuscatedKey:              obfuscatedSSHKey,
 		Capabilities:                  capabilities,
 		Region:                        "US",
+		ProviderID:                    prng.HexString(8),
 		MeekServerPort:                meekPort,
 		MeekCookieEncryptionPublicKey: meekCookieEncryptionPublicKey,
 		MeekObfuscatedKey:             meekObfuscatedKey,

+ 4 - 4
psiphon/server/demux_test.go

@@ -22,7 +22,7 @@ package server
 import (
 	"bytes"
 	"context"
-	stderrors "errors"
+	std_errors "errors"
 	"fmt"
 	"math/rand"
 	"net"
@@ -135,7 +135,7 @@ func runProtocolDemuxTest(tt *protocolDemuxTest) error {
 		defer close(runErr)
 
 		err := mux.run()
-		if err != nil && !stderrors.Is(err, context.Canceled) {
+		if err != nil && !std_errors.Is(err, context.Canceled) {
 			runErr <- err
 		}
 	}()
@@ -153,7 +153,7 @@ func runProtocolDemuxTest(tt *protocolDemuxTest) error {
 	}
 
 	err = <-runErr
-	if err != nil && !stderrors.Is(err, net.ErrClosed) {
+	if err != nil && !std_errors.Is(err, net.ErrClosed) {
 		return errors.Trace(err)
 	}
 
@@ -398,7 +398,7 @@ func (c *testConn) Read(b []byte) (n int, err error) {
 }
 
 func (c *testConn) Write(b []byte) (n int, err error) {
-	return 0, stderrors.New("not supported")
+	return 0, std_errors.New("not supported")
 }
 
 func (c *testConn) Close() error {

+ 40 - 0
psiphon/server/listener.go

@@ -20,15 +20,19 @@
 package server
 
 import (
+	std_errors "errors"
 	"net"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
 
+var errRestrictedProvider = std_errors.New("restricted provider")
+
 // TacticsListener wraps a net.Listener and applies server-side implementation
 // of certain tactics parameters to accepted connections. Tactics filtering is
 // limited to GeoIP attributes as the client has not yet sent API parameters.
@@ -94,6 +98,42 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
 		return conn, nil
 	}
 
+	// Disconnect immediately if the tactics for the client restricts usage of
+	// the provider ID with direct protocols. The probability may be used to
+	// influence usage of a given provider with direct protocols; but when only
+	// that provider works for a given client, and the probability is less than
+	// 1.0, the client can  retry until it gets a successful coin flip.
+	//
+	// Clients will also skip direct protocol candidates with restricted
+	// provider IDs.
+	// The client-side probability, RestrictDirectProviderIDsClientProbability,
+	// is applied independently of the server-side coin flip here.
+	//
+	// The selected tactics are for the immediate peer IP and therefore must
+	// not be applied to clients using indirect protocols, where the immediate
+	// peer IP is not the original client IP. Indirect protocols must determine
+	// the original client IP before applying GeoIP specific tactics; see the
+	// server-side enforcement of RestrictFrontingProviderIDs for fronted meek
+	// in server.MeekServer.getSessionOrEndpoint.
+	//
+	// At this stage, GeoIP tactics filters are active, but handshake API
+	// parameters are not.
+	//
+	// See the comment in server.LoadConfig regarding provider ID limitations.
+	if protocol.TunnelProtocolIsDirect(listener.tunnelProtocol) &&
+		(common.Contains(
+			p.Strings(parameters.RestrictDirectProviderIDs),
+			listener.support.Config.GetProviderID()) ||
+			common.ContainsAny(
+				p.KeyStrings(parameters.RestrictDirectProviderRegions, listener.support.Config.GetProviderID()), []string{"", listener.support.Config.GetRegion()})) {
+
+		if p.WeightedCoinFlip(
+			parameters.RestrictDirectProviderIDsServerProbability) {
+			conn.Close()
+			return nil, errRestrictedProvider
+		}
+	}
+
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	//
 	// In the OSSH case, replay is always activated and it is seeded using the

+ 4 - 0
psiphon/server/tunnelServer.go

@@ -705,6 +705,10 @@ func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, liste
 				log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
 				// Temporary error, keep running
 				continue
+			} else if std_errors.Is(err, errRestrictedProvider) {
+				log.WithTraceFields(LogFields{"error": err}).Error("accept rejected client")
+				// Restricted provider, keep running
+				continue
 			}
 
 			select {

+ 24 - 5
psiphon/serverApi.go

@@ -134,6 +134,24 @@ func (serverContext *ServerContext) doHandshakeRequest(
 			serverContext.tunnel.dialParams.ServerEntry.Tag
 	}
 
+	// The server will return a signed copy of its own server entry when the
+	// client specifies this 'missing_server_entry_provider_id' parameter.
+	//
+	// The purpose of this mechanism is to rapidly add provider IDs to the
+	// server entries in client local storage, and to ensure that the client has
+	// a provider ID for its currently connected server as required for the
+	// RestrictDirectProviderIDs, RestrictDirectProviderRegions, and
+	// HoldOffDirectServerEntryProviderRegions tactics.
+	//
+	// The server entry will be included in handshakeResponse.EncodedServerList,
+	// along side discovery servers.
+	requestedMissingProviderID := false
+	if !serverContext.tunnel.dialParams.ServerEntry.HasProviderID() {
+		requestedMissingProviderID = true
+		params["missing_server_entry_provider_id"] =
+			serverContext.tunnel.dialParams.ServerEntry.Tag
+	}
+
 	doTactics := !serverContext.tunnel.config.DisableTactics
 
 	networkID := ""
@@ -272,13 +290,14 @@ func (serverContext *ServerContext) doHandshakeRequest(
 			return errors.Trace(err)
 		}
 
-		// Retain the original timestamp and source in the requestedMissingSignature
-		// case, as this server entry was not discovered here.
+		// Retain the original timestamp and source in the
+		// requestedMissingSignature and requestedMissingProviderID
+		// cases, as this server entry was not discovered here.
 		//
 		// Limitation: there is a transient edge case where
-		// requestedMissingSignature will be set for a discovery server entry that
-		// _is_ also discovered here.
-		if requestedMissingSignature &&
+		// requestedMissingSignature and/or requestedMissingProviderID will be
+		// set for a discovery server entry that _is_ also discovered here.
+		if requestedMissingSignature || requestedMissingProviderID &&
 			serverEntryFields.GetIPAddress() == serverContext.tunnel.dialParams.ServerEntry.IpAddress {
 
 			serverEntryFields.SetLocalTimestamp(serverContext.tunnel.dialParams.ServerEntry.LocalTimestamp)