Browse Source

Merge pull request #594 from rod-hynes/domain-fronted-registration

Domain fronted Conjure registration
Rod Hynes 5 years ago
parent
commit
9fbbe2f816

+ 7 - 0
psiphon/TCPConn.go

@@ -47,6 +47,13 @@ type TCPConn struct {
 // as a custom dialer for NewProxyAuthTransport (or http.Transport with a
 // as a custom dialer for NewProxyAuthTransport (or http.Transport with a
 // ProxyUrl), as that would result in double proxy chaining.
 // ProxyUrl), as that would result in double proxy chaining.
 func NewTCPDialer(config *DialConfig) common.Dialer {
 func NewTCPDialer(config *DialConfig) common.Dialer {
+
+	// Use config.CustomDialer when set. This ignores all other parameters in
+	// DialConfig.
+	if config.CustomDialer != nil {
+		return config.CustomDialer
+	}
+
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 		if network != "tcp" {
 		if network != "tcp" {
 			return nil, errors.Tracef("%s unsupported", network)
 			return nil, errors.Tracef("%s unsupported", network)

+ 1 - 1
psiphon/common/burst.go

@@ -203,7 +203,7 @@ func (conn *BurstMonitoredConn) IsClosed() bool {
 
 
 // GetMetrics returns log fields with burst metrics for the first, last, min
 // GetMetrics returns log fields with burst metrics for the first, last, min
 // (by rate), and max bursts for this conn. Time/duration values are reported
 // (by rate), and max bursts for this conn. Time/duration values are reported
-// in milliseconds.
+// in milliseconds. Rate is reported in bytes per second.
 func (conn *BurstMonitoredConn) GetMetrics(baseTime time.Time) LogFields {
 func (conn *BurstMonitoredConn) GetMetrics(baseTime time.Time) LogFields {
 	logFields := make(LogFields)
 	logFields := make(LogFields)
 
 

+ 18 - 0
psiphon/common/net.go

@@ -79,6 +79,24 @@ type FragmentorReplayAccessor interface {
 	GetReplay() (*prng.Seed, bool)
 	GetReplay() (*prng.Seed, bool)
 }
 }
 
 
+// HTTPRoundTripper is an adapter that allows using a function as a
+// http.RoundTripper.
+type HTTPRoundTripper struct {
+	roundTrip func(*http.Request) (*http.Response, error)
+}
+
+// NewHTTPRoundTripper creates a new HTTPRoundTripper, using the specified
+// roundTrip function for HTTP round trips.
+func NewHTTPRoundTripper(
+	roundTrip func(*http.Request) (*http.Response, error)) *HTTPRoundTripper {
+	return &HTTPRoundTripper{roundTrip: roundTrip}
+}
+
+// RoundTrip implements http.RoundTripper RoundTrip.
+func (h HTTPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
+	return h.roundTrip(request)
+}
+
 // TerminateHTTPConnection sends a 404 response to a client and also closes
 // TerminateHTTPConnection sends a 404 response to a client and also closes
 // the persistent connection.
 // the persistent connection.
 func TerminateHTTPConnection(
 func TerminateHTTPConnection(

+ 2 - 2
psiphon/common/obfuscator/history.go

@@ -126,9 +126,9 @@ func (h *SeedHistory) AddNew(
 	// an unlikely possibility that this Add and the following Get don't see the
 	// an unlikely possibility that this Add and the following Get don't see the
 	// same existing key/value state.
 	// same existing key/value state.
 
 
-	if h.seedToTime.Add(key, time.Now(), 0) == nil {
+	if h.seedToTime.Add(key, time.Now(), lrucache.DefaultExpiration) == nil {
 		// Seed was not already in cache
 		// Seed was not already in cache
-		h.seedToClientIP.Set(key, clientIP, 0)
+		h.seedToClientIP.Set(key, clientIP, lrucache.DefaultExpiration)
 		return true, nil
 		return true, nil
 	}
 	}
 
 

+ 8 - 0
psiphon/common/obfuscator/obfuscator.go

@@ -83,6 +83,10 @@ type ObfuscatorConfig struct {
 func NewClientObfuscator(
 func NewClientObfuscator(
 	config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 	config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 
 
+	if config.Keyword == "" {
+		return nil, errors.TraceNew("missing keyword")
+	}
+
 	if config.PaddingPRNGSeed == nil {
 	if config.PaddingPRNGSeed == nil {
 		return nil, errors.TraceNew("missing padding seed")
 		return nil, errors.TraceNew("missing padding seed")
 	}
 	}
@@ -148,6 +152,10 @@ func NewClientObfuscator(
 func NewServerObfuscator(
 func NewServerObfuscator(
 	config *ObfuscatorConfig, clientIP string, clientReader io.Reader) (obfuscator *Obfuscator, err error) {
 	config *ObfuscatorConfig, clientIP string, clientReader io.Reader) (obfuscator *Obfuscator, err error) {
 
 
+	if config.Keyword == "" {
+		return nil, errors.TraceNew("missing keyword")
+	}
+
 	clientToServerCipher, serverToClientCipher, paddingPRNGSeed, err := readSeedMessage(
 	clientToServerCipher, serverToClientCipher, paddingPRNGSeed, err := readSeedMessage(
 		config, clientIP, clientReader)
 		config, clientIP, clientReader)
 	if err != nil {
 	if err != nil {

+ 130 - 0
psiphon/common/parameters/frontingSpec.go

@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import (
+	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	regen "github.com/zach-klippenstein/goregen"
+)
+
+// FrontingSpecs is a list of domain fronting specs.
+type FrontingSpecs []*FrontingSpec
+
+// FrontingSpec specifies a domain fronting configuration, to be used with
+// MeekConn and MeekModePlaintextRoundTrip. In MeekModePlaintextRoundTrip, the
+// fronted origin is an arbitrary web server, not a Psiphon server. This
+// MeekConn mode requires HTTPS and server certificate validation:
+// VerifyServerName is required; VerifyPins is recommended. See also
+// psiphon.MeekConfig and psiphon.MeekConn.
+//
+// FrontingSpec.Addresses supports the functionality of both
+// ServerEntry.MeekFrontingAddressesRegex and
+// ServerEntry.MeekFrontingAddresses: multiple candidates are supported, and
+// each candidate may be a regex, or a static value (with regex syntax).
+type FrontingSpec struct {
+	FrontingProviderID string
+	Addresses          []string
+	DisableSNI         bool
+	VerifyServerName   string
+	VerifyPins         []string
+	Host               string
+}
+
+// SelectParameters selects fronting parameters from the given FrontingSpecs,
+// first selecting a spec at random. SelectParameters is similar to
+// psiphon.selectFrontingParameters, which operates on server entries.
+//
+// The return values are:
+// - Dial Address (domain or IP address)
+// - SNI (which may be transformed; unless it is "", which indicates omit SNI)
+// - VerifyServerName (see psiphon.CustomTLSConfig)
+// - VerifyPins (see psiphon.CustomTLSConfig)
+// - Host (Host header value)
+func (specs FrontingSpecs) SelectParameters() (
+	string, string, string, string, []string, string, error) {
+
+	if len(specs) == 0 {
+		return "", "", "", "", nil, "", errors.TraceNew("missing fronting spec")
+	}
+
+	spec := specs[prng.Intn(len(specs))]
+
+	if len(spec.Addresses) == 0 {
+		return "", "", "", "", nil, "", errors.TraceNew("missing fronting address")
+	}
+
+	frontingDialAddr, err := regen.Generate(
+		spec.Addresses[prng.Intn(len(spec.Addresses))])
+	if err != nil {
+		return "", "", "", "", nil, "", errors.Trace(err)
+	}
+
+	SNIServerName := frontingDialAddr
+	if spec.DisableSNI || net.ParseIP(frontingDialAddr) != nil {
+		SNIServerName = ""
+	}
+
+	return spec.FrontingProviderID,
+		frontingDialAddr,
+		SNIServerName,
+		spec.VerifyServerName,
+		spec.VerifyPins,
+		spec.Host,
+		nil
+}
+
+// Validate checks that the JSON values are well-formed.
+func (specs FrontingSpecs) Validate() error {
+
+	// An empty FrontingSpecs is allowed as a tactics setting, but
+	// SelectParameters will fail at runtime: code that uses FrontingSpecs must
+	// provide some mechanism -- or check for an empty FrontingSpecs -- to
+	// enable/disable features that use FrontingSpecs.
+
+	for _, spec := range specs {
+		if len(spec.FrontingProviderID) == 0 {
+			return errors.TraceNew("empty fronting provider ID")
+		}
+		if len(spec.Addresses) == 0 {
+			return errors.TraceNew("missing fronting addresses")
+		}
+		for _, addr := range spec.Addresses {
+			if len(addr) == 0 {
+				return errors.TraceNew("empty fronting address")
+			}
+		}
+		if len(spec.VerifyServerName) == 0 {
+			return errors.TraceNew("empty verify server name")
+		}
+		// An empty VerifyPins is allowed.
+		for _, pin := range spec.VerifyPins {
+			if len(pin) == 0 {
+				return errors.TraceNew("empty verify pin")
+			}
+		}
+		if len(spec.Host) == 0 {
+			return errors.TraceNew("empty fronting host")
+		}
+	}
+	return nil
+}

+ 25 - 1
psiphon/common/parameters/parameters.go

@@ -273,7 +273,15 @@ const (
 	ClientBurstUpstreamTargetBytes                   = "ClientBurstUpstreamTargetBytes"
 	ClientBurstUpstreamTargetBytes                   = "ClientBurstUpstreamTargetBytes"
 	ClientBurstDownstreamDeadline                    = "ClientBurstDownstreamDeadline"
 	ClientBurstDownstreamDeadline                    = "ClientBurstDownstreamDeadline"
 	ClientBurstDownstreamTargetBytes                 = "ClientBurstDownstreamTargetBytes"
 	ClientBurstDownstreamTargetBytes                 = "ClientBurstDownstreamTargetBytes"
+	ConjureCachedRegistrationTTL                     = "ConjureCachedRegistrationTTL"
+	ConjureAPIRegistrarURL                           = "ConjureAPIRegistrarURL"
+	ConjureAPIRegistrarFrontingSpecs                 = "ConjureAPIRegistrarFrontingSpecs"
+	ConjureAPIRegistrarMinDelay                      = "ConjureAPIRegistrarMinDelay"
+	ConjureAPIRegistrarMaxDelay                      = "ConjureAPIRegistrarMaxDelay"
+	ConjureDecoyRegistrarProbability                 = "ConjureDecoyRegistrarProbability"
 	ConjureDecoyRegistrarWidth                       = "ConjureDecoyRegistrarWidth"
 	ConjureDecoyRegistrarWidth                       = "ConjureDecoyRegistrarWidth"
+	ConjureDecoyRegistrarMinDelay                    = "ConjureDecoyRegistrarMinDelay"
+	ConjureDecoyRegistrarMaxDelay                    = "ConjureDecoyRegistrarMaxDelay"
 	ConjureTransportObfs4Probability                 = "ConjureTransportObfs4Probability"
 	ConjureTransportObfs4Probability                 = "ConjureTransportObfs4Probability"
 	CustomHostNameRegexes                            = "CustomHostNameRegexes"
 	CustomHostNameRegexes                            = "CustomHostNameRegexes"
 	CustomHostNameProbability                        = "CustomHostNameProbability"
 	CustomHostNameProbability                        = "CustomHostNameProbability"
@@ -577,7 +585,16 @@ var defaultParameters = map[string]struct {
 	ClientBurstDownstreamTargetBytes: {value: 0, minimum: 0},
 	ClientBurstDownstreamTargetBytes: {value: 0, minimum: 0},
 	ClientBurstDownstreamDeadline:    {value: time.Duration(0), minimum: time.Duration(0)},
 	ClientBurstDownstreamDeadline:    {value: time.Duration(0), minimum: time.Duration(0)},
 
 
-	ConjureDecoyRegistrarWidth:       {value: 5, minimum: 1},
+	ConjureCachedRegistrationTTL:     {value: time.Duration(0), minimum: time.Duration(0)},
+	ConjureAPIRegistrarURL:           {value: ""},
+	ConjureAPIRegistrarFrontingSpecs: {value: FrontingSpecs{}},
+	ConjureAPIRegistrarMinDelay:      {value: time.Duration(0), minimum: time.Duration(0)},
+	ConjureAPIRegistrarMaxDelay:      {value: time.Duration(0), minimum: time.Duration(0)},
+	ConjureDecoyRegistrarProbability: {value: 0.0, minimum: 0.0},
+	ConjureDecoyRegistrarWidth:       {value: 5, minimum: 0},
+	ConjureDecoyRegistrarMinDelay:    {value: time.Duration(0), minimum: time.Duration(0)},
+	ConjureDecoyRegistrarMaxDelay:    {value: time.Duration(0), minimum: time.Duration(0)},
+
 	ConjureTransportObfs4Probability: {value: 0.0, minimum: 0.0},
 	ConjureTransportObfs4Probability: {value: 0.0, minimum: 0.0},
 
 
 	CustomHostNameRegexes:        {value: RegexStrings{}},
 	CustomHostNameRegexes:        {value: RegexStrings{}},
@@ -1343,3 +1360,10 @@ func (p ParametersAccessor) RegexStrings(name string) RegexStrings {
 	p.snapshot.getValue(name, &value)
 	p.snapshot.getValue(name, &value)
 	return value
 	return value
 }
 }
+
+// FrontingSpecs returns a FrontingSpecs parameter value.
+func (p ParametersAccessor) FrontingSpecs(name string) FrontingSpecs {
+	value := FrontingSpecs{}
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 88 - 0
psiphon/common/refraction/config.go

@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package refraction
+
+import (
+	"net/http"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+)
+
+// ConjureConfig specifies the additional configuration for a Conjure dial.
+type ConjureConfig struct {
+
+	// RegistrationCacheTTL specifies how long to retain a successful Conjure
+	// registration for reuse in a subsequent dial. This value should be
+	// synchronized with the Conjure station configuration. When
+	// RegistrationCacheTTL is 0, registrations are not cached.
+	RegistrationCacheTTL time.Duration
+
+	// RegistrationCacheKey defines a scope or affinity for cached Conjure
+	// registrations. For example, the key can reflect the target Psiphon server
+	// as well as the current network ID. This ensures that any replay will
+	// always use the same cached registration, including its phantom IP(s). And
+	// ensures that the cache scope is restricted to the current network: when
+	// the network changes, the client's public IP changes, and previous
+	// registrations will become invalid. When the client returns to the original
+	// network, the previous registrations may be valid once again (assuming
+	// the client reverts back to its original public IP).
+	RegistrationCacheKey string
+
+	// APIRegistrarURL specifies the API registration endpoint. Setting
+	// APIRegistrarURL enables API registration. The domain fronting
+	// configuration provided by APIRegistrarHTTPClient may ignore the host
+	// portion of this URL, implicitly providing another value; the path portion
+	// is always used in the request. Only one of API registration or decoy
+	// registration can be enabled for a single dial.
+	APIRegistrarURL string
+
+	// APIRegistrarHTTPClient specifies a custom HTTP client (and underlying
+	// dialers) to be used for Conjure API registration. The
+	// APIRegistrarHTTPClient enables domain fronting of API registration web
+	// requests. This parameter is required when API registration is enabled.
+	APIRegistrarHTTPClient *http.Client
+
+	// APIRegistrarDelay specifies how long to wait after a successful API
+	// registration before initiating the phantom dial(s), as required by the
+	// Conjure protocol. This value depends on Conjure station operations and
+	// should be synchronized with the Conjure station configuration.
+	APIRegistrarDelay time.Duration
+
+	// DecoyRegistrarDialer specifies a custom dialer to be used for decoy
+	// registration. Only one of API registration or decoy registration can be
+	// enabled for a single dial.
+	DecoyRegistrarDialer common.NetDialer
+
+	// DecoyRegistrarWidth specifies how many decoys to use per registration.
+	DecoyRegistrarWidth int
+
+	// DecoyRegistrarDelay specifies how long to wait after a successful API
+	// registration before initiating the phantom dial(s), as required by the
+	// Conjure protocol.
+	//
+	// Limitation: this value is not exposed by gotapdance and is currently
+	// ignored.
+	DecoyRegistrarDelay time.Duration
+
+	// Transport may be protocol.CONJURE_TRANSPORT_MIN_OSSH or
+	// protocol.CONJURE_TRANSPORT_OBFS4_OSSH.
+	Transport string
+}

+ 287 - 26
psiphon/common/refraction/refraction.go

@@ -30,6 +30,7 @@ package refraction
 import (
 import (
 	"context"
 	"context"
 	"crypto/sha256"
 	"crypto/sha256"
+	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"os"
 	"os"
@@ -42,12 +43,14 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/armon/go-proxyproto"
 	"github.com/armon/go-proxyproto"
+	lrucache "github.com/cognusion/go-cache-lru"
 	refraction_networking_proto "github.com/refraction-networking/gotapdance/protobuf"
 	refraction_networking_proto "github.com/refraction-networking/gotapdance/protobuf"
 	refraction_networking_client "github.com/refraction-networking/gotapdance/tapdance"
 	refraction_networking_client "github.com/refraction-networking/gotapdance/tapdance"
 )
 )
 
 
 const (
 const (
 	READ_PROXY_PROTOCOL_HEADER_TIMEOUT = 5 * time.Second
 	READ_PROXY_PROTOCOL_HEADER_TIMEOUT = 5 * time.Second
+	REGISTRATION_CACHE_MAX_ENTRIES     = 256
 )
 )
 
 
 // Enabled indicates if Refraction Networking functionality is enabled.
 // Enabled indicates if Refraction Networking functionality is enabled.
@@ -178,6 +181,8 @@ func (c *stationConn) GetMetrics() common.LogFields {
 // assets) are read from dataDirectory/"refraction-networking". When no config
 // assets) are read from dataDirectory/"refraction-networking". When no config
 // is found, default assets are paved.
 // is found, default assets are paved.
 //
 //
+// dialer specifies the custom dialer for underlying TCP dials.
+//
 // The input ctx is expected to have a timeout for the dial.
 // The input ctx is expected to have a timeout for the dial.
 //
 //
 // Limitation: the parameters emitLogs and dataDirectory are used for one-time
 // Limitation: the parameters emitLogs and dataDirectory are used for one-time
@@ -194,36 +199,31 @@ func DialTapDance(
 		emitLogs,
 		emitLogs,
 		dataDirectory,
 		dataDirectory,
 		dialer,
 		dialer,
-		false,
-		nil,
-		0,
-		"",
-		address)
+		address,
+		nil)
 }
 }
 
 
 // DialConjure establishes a new Conjure connection to a Conjure station.
 // DialConjure establishes a new Conjure connection to a Conjure station.
 //
 //
+// dialer specifies the custom dialer to use for phantom dials. Additional
+// Conjure-specific parameters are specified in conjureConfig.
+//
 // See DialTapdance comment.
 // See DialTapdance comment.
 func DialConjure(
 func DialConjure(
 	ctx context.Context,
 	ctx context.Context,
 	emitLogs bool,
 	emitLogs bool,
 	dataDirectory string,
 	dataDirectory string,
 	dialer common.NetDialer,
 	dialer common.NetDialer,
-	conjureDecoyRegistrarDialer common.NetDialer,
-	conjureDecoyRegistrarWidth int,
-	conjureTransport string,
-	address string) (net.Conn, error) {
+	address string,
+	conjureConfig *ConjureConfig) (net.Conn, error) {
 
 
 	return dial(
 	return dial(
 		ctx,
 		ctx,
 		emitLogs,
 		emitLogs,
 		dataDirectory,
 		dataDirectory,
 		dialer,
 		dialer,
-		true,
-		conjureDecoyRegistrarDialer,
-		conjureDecoyRegistrarWidth,
-		conjureTransport,
-		address)
+		address,
+		conjureConfig)
 }
 }
 
 
 func dial(
 func dial(
@@ -231,11 +231,8 @@ func dial(
 	emitLogs bool,
 	emitLogs bool,
 	dataDirectory string,
 	dataDirectory string,
 	dialer common.NetDialer,
 	dialer common.NetDialer,
-	useConjure bool,
-	conjureDecoyRegistrarDialer common.NetDialer,
-	conjureDecoyRegistrarWidth int,
-	conjureTransport string,
-	address string) (net.Conn, error) {
+	address string,
+	conjureConfig *ConjureConfig) (net.Conn, error) {
 
 
 	err := initRefractionNetworking(emitLogs, dataDirectory)
 	err := initRefractionNetworking(emitLogs, dataDirectory)
 	if err != nil {
 	if err != nil {
@@ -246,6 +243,8 @@ func dial(
 		return nil, errors.TraceNew("dial context has no timeout")
 		return nil, errors.TraceNew("dial context has no timeout")
 	}
 	}
 
 
+	useConjure := conjureConfig != nil
+
 	manager := newDialManager()
 	manager := newDialManager()
 
 
 	refractionDialer := &refraction_networking_client.Dialer{
 	refractionDialer := &refraction_networking_client.Dialer{
@@ -253,23 +252,125 @@ func dial(
 		UseProxyHeader: true,
 		UseProxyHeader: true,
 	}
 	}
 
 
+	conjureCached := false
+	conjureDelay := time.Duration(0)
+
+	var conjureCachedRegistration *refraction_networking_client.ConjureReg
+	var conjureRecordRegistrar *recordRegistrar
+
 	if useConjure {
 	if useConjure {
 
 
+		// Our strategy is to try one registration per dial attempt: a cached
+		// registration, if it exists, or API or decoy registration, as configured.
+		// This assumes Psiphon establishment will try/retry many candidates as
+		// required, and that the desired mix of API/decoy registrations will be
+		// configured and generated. In good network conditions, internal gotapdance
+		// retries (via APIRegistrar.MaxRetries or APIRegistrar.SecondaryRegistrar)
+		// are unlikely to start before the Conjure dial is canceled.
+
+		// Caching registrations reduces average Conjure dial time by often
+		// eliminating the registration phase. This is especially impactful for
+		// short duration tunnels, such as on mobile. Caching also reduces domain
+		// fronted traffic and load on the API registrar and decoys.
+		//
+		// We implement a simple in-memory registration cache with the following
+		// behavior:
+		//
+		// - If a new registration succeeds, but the overall Conjure dial is
+		//   _canceled_, the registration is optimistically cached.
+		// - If the Conjure phantom dial fails, any associated cached registration
+		//   is discarded.
+		// - A cached registration's TTL is extended upon phantom dial success.
+		// - If the configured TTL changes, the cache is cleared.
+		//
+		// Limitations:
+		// - The cache is not persistent.
+		// - There is no TTL extension during a long connection.
+		// - Caching a successful registration when the phantom dial is canceled may
+		//   skip the necessary "delay" step (however, an immediate re-establishment
+		//   to the same candidate is unlikely in this case).
+		//
+		// TODO:
+		// - Revisit when gotapdance adds its own caching.
+		// - Consider "pre-registering" Conjure when already connected with a
+		//   different protocol, so a Conjure registration is available on the next
+		//   establishment; in this scenario, a tunneled API registration would not
+		//   require domain fronting.
+
 		refractionDialer.DarkDecoy = true
 		refractionDialer.DarkDecoy = true
 
 
-		refractionDialer.DarkDecoyRegistrar = refraction_networking_client.DecoyRegistrar{
-			TcpDialer: manager.makeManagedDialer(conjureDecoyRegistrarDialer.DialContext),
+		// The pop operation removes the registration from the cache. This
+		// eliminates the possibility of concurrent candidates (with the same cache
+		// key) using and modifying the same registration, a potential race
+		// condition. The popped cached registration must be reinserted in the cache
+		// after canceling or success, but not on phantom dial failure.
+
+		conjureCachedRegistration = conjureRegistrationCache.pop(
+			conjureConfig.RegistrationCacheTTL,
+			conjureConfig.RegistrationCacheKey)
+
+		if conjureCachedRegistration != nil {
+
+			refractionDialer.DarkDecoyRegistrar = &cachedRegistrar{
+				registration: conjureCachedRegistration,
+			}
+
+			conjureCached = true
+			conjureDelay = 0 // report no delay
+
+		} else if conjureConfig.APIRegistrarURL != "" {
+
+			if conjureConfig.APIRegistrarHTTPClient == nil {
+				// While not a guaranteed check, if the APIRegistrarHTTPClient isn't set
+				// then the API registration would certainly be unfronted, resulting in a
+				// fingerprintable connection leak.
+				return nil, errors.TraceNew("missing APIRegistrarHTTPClient")
+			}
+
+			refractionDialer.DarkDecoyRegistrar = &refraction_networking_client.APIRegistrar{
+				Endpoint:        conjureConfig.APIRegistrarURL,
+				ConnectionDelay: conjureConfig.APIRegistrarDelay,
+				MaxRetries:      0,
+				Client:          conjureConfig.APIRegistrarHTTPClient,
+			}
+
+			conjureDelay = conjureConfig.APIRegistrarDelay
+
+		} else if conjureConfig.DecoyRegistrarDialer != nil {
+
+			refractionDialer.DarkDecoyRegistrar = &refraction_networking_client.DecoyRegistrar{
+				TcpDialer: manager.makeManagedDialer(
+					conjureConfig.DecoyRegistrarDialer.DialContext),
+			}
+
+			refractionDialer.Width = conjureConfig.DecoyRegistrarWidth
+
+			// Limitation: the decoy regsitration delay is not currently exposed in the
+			// gotapdance API.
+			conjureDelay = -1 // don't report delay
+
+		} else {
+
+			return nil, errors.TraceNew("no conjure registrar specified")
+		}
+
+		if conjureCachedRegistration == nil && conjureConfig.RegistrationCacheTTL != 0 {
+
+			// Record the registration result in order to cache it.
+			conjureRecordRegistrar = &recordRegistrar{
+				registrar: refractionDialer.DarkDecoyRegistrar,
+			}
+			refractionDialer.DarkDecoyRegistrar = conjureRecordRegistrar
 		}
 		}
-		refractionDialer.Width = conjureDecoyRegistrarWidth
 
 
-		switch conjureTransport {
+		switch conjureConfig.Transport {
 		case protocol.CONJURE_TRANSPORT_MIN_OSSH:
 		case protocol.CONJURE_TRANSPORT_MIN_OSSH:
 			refractionDialer.Transport = refraction_networking_proto.TransportType_Min
 			refractionDialer.Transport = refraction_networking_proto.TransportType_Min
 			refractionDialer.TcpDialer = newMinTransportDialer(refractionDialer.TcpDialer)
 			refractionDialer.TcpDialer = newMinTransportDialer(refractionDialer.TcpDialer)
 		case protocol.CONJURE_TRANSPORT_OBFS4_OSSH:
 		case protocol.CONJURE_TRANSPORT_OBFS4_OSSH:
 			refractionDialer.Transport = refraction_networking_proto.TransportType_Obfs4
 			refractionDialer.Transport = refraction_networking_proto.TransportType_Obfs4
 		default:
 		default:
-			return nil, errors.Tracef("invalid Conjure transport: %s", conjureTransport)
+			return nil, errors.Tracef("invalid Conjure transport: %s", conjureConfig.Transport)
 		}
 		}
 	}
 	}
 
 
@@ -294,17 +395,152 @@ func dial(
 
 
 	conn, err := refractionDialer.DialContext(ctx, "tcp", address)
 	conn, err := refractionDialer.DialContext(ctx, "tcp", address)
 	close(dialComplete)
 	close(dialComplete)
+
 	if err != nil {
 	if err != nil {
+		// Call manager.close before updating cache, to synchronously shutdown dials
+		// and ensure there are no further concurrent reads/writes to the recorded
+		// registration before referencing it.
 		manager.close()
 		manager.close()
+	}
+
+	// Cache (or put back) a successful registration. Also put back in the
+	// specific error case where the phantom dial was canceled, as the
+	// registration may still be valid. This operation implicitly extends the TTL
+	// of a reused cached registration; we assume the Conjure station is also
+	// extending the TTL by the same amount.
+	//
+	// Limitation: the cancel case shouldn't extend the TTL.
+
+	if useConjure &&
+		(err == nil || ctx.Err() == context.Canceled) &&
+		(conjureCachedRegistration != nil || conjureRecordRegistrar != nil) {
+
+		registration := conjureCachedRegistration
+		if registration == nil {
+			// We assume gotapdance is no longer accessing the Registrar.
+			registration = conjureRecordRegistrar.registration
+		}
+
+		// conjureRecordRegistrar.registration will be nil there was no cached
+		// registration _and_ registration didn't succeed before a cancel.
+		if registration != nil {
+			conjureRegistrationCache.put(
+				conjureConfig.RegistrationCacheTTL,
+				conjureConfig.RegistrationCacheKey,
+				registration)
+		}
+	}
+
+	if err != nil {
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
 	manager.startUsingRunCtx()
 	manager.startUsingRunCtx()
 
 
-	return &refractionConn{
+	refractionConn := &refractionConn{
 		Conn:    conn,
 		Conn:    conn,
 		manager: manager,
 		manager: manager,
-	}, nil
+	}
+
+	if useConjure {
+		// Retain these values for logging metrics.
+		refractionConn.isConjure = true
+		refractionConn.conjureCached = conjureCached
+		refractionConn.conjureDelay = conjureDelay
+		refractionConn.conjureTransport = conjureConfig.Transport
+	}
+
+	return refractionConn, nil
+}
+
+type registrationCache struct {
+	mutex sync.Mutex
+	TTL   time.Duration
+	cache *lrucache.Cache
+}
+
+func newRegistrationCache() *registrationCache {
+	return &registrationCache{
+		cache: lrucache.NewWithLRU(
+			lrucache.NoExpiration,
+			1*time.Minute,
+			REGISTRATION_CACHE_MAX_ENTRIES),
+	}
+}
+
+func (c *registrationCache) put(
+	TTL time.Duration,
+	key string,
+	registration *refraction_networking_client.ConjureReg) {
+
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	// Clear the entire cache if the configured TTL changes to avoid retaining
+	// items for too long. This is expected to be an infrequent event. The
+	// go-cache-lru API does not offer a mechanism to inspect and adjust the TTL
+	// of all existing items.
+	if c.TTL != TTL {
+		c.cache.Flush()
+		c.TTL = TTL
+	}
+
+	c.cache.Set(
+		key,
+		registration,
+		c.TTL)
+}
+
+func (c *registrationCache) pop(
+	TTL time.Duration,
+	key string) *refraction_networking_client.ConjureReg {
+
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	// See TTL/Flush comment in put.
+	if c.TTL != TTL {
+		c.cache.Flush()
+		c.TTL = TTL
+	}
+
+	entry, found := c.cache.Get(key)
+	if found {
+		c.cache.Delete(key)
+		return entry.(*refraction_networking_client.ConjureReg)
+	}
+
+	return nil
+}
+
+var conjureRegistrationCache = newRegistrationCache()
+
+type cachedRegistrar struct {
+	registration *refraction_networking_client.ConjureReg
+}
+
+func (r *cachedRegistrar) Register(
+	_ *refraction_networking_client.ConjureSession,
+	_ context.Context) (*refraction_networking_client.ConjureReg, error) {
+
+	return r.registration, nil
+}
+
+type recordRegistrar struct {
+	registrar    refraction_networking_client.Registrar
+	registration *refraction_networking_client.ConjureReg
+}
+
+func (r *recordRegistrar) Register(
+	session *refraction_networking_client.ConjureSession,
+	ctx context.Context) (*refraction_networking_client.ConjureReg, error) {
+
+	registration, err := r.registrar.Register(session, ctx)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	r.registration = registration
+	return registration, nil
 }
 }
 
 
 // minTransportConn buffers the first 32-byte random HMAC write performed by
 // minTransportConn buffers the first 32-byte random HMAC write performed by
@@ -516,6 +752,11 @@ type refractionConn struct {
 	net.Conn
 	net.Conn
 	manager  *dialManager
 	manager  *dialManager
 	isClosed int32
 	isClosed int32
+
+	isConjure        bool
+	conjureCached    bool
+	conjureDelay     time.Duration
+	conjureTransport string
 }
 }
 
 
 func (conn *refractionConn) Close() error {
 func (conn *refractionConn) Close() error {
@@ -529,6 +770,26 @@ func (conn *refractionConn) IsClosed() bool {
 	return atomic.LoadInt32(&conn.isClosed) == 1
 	return atomic.LoadInt32(&conn.isClosed) == 1
 }
 }
 
 
+// GetMetrics implements the common.MetricsSource interface.
+func (conn *refractionConn) GetMetrics() common.LogFields {
+	logFields := make(common.LogFields)
+	if conn.isConjure {
+
+		cached := "0"
+		if conn.conjureCached {
+			cached = "1"
+		}
+		logFields["conjure_cached"] = cached
+
+		if conn.conjureDelay != -1 {
+			logFields["conjure_delay"] = fmt.Sprintf("%d", conn.conjureDelay/time.Millisecond)
+		}
+
+		logFields["conjure_transport"] = conn.conjureTransport
+	}
+	return logFields
+}
+
 var initRefractionNetworkingOnce sync.Once
 var initRefractionNetworkingOnce sync.Once
 
 
 func initRefractionNetworking(emitLogs bool, dataDirectory string) error {
 func initRefractionNetworking(emitLogs bool, dataDirectory string) error {

+ 1 - 1
psiphon/common/refraction/refraction_disabled.go

@@ -50,6 +50,6 @@ func DialTapDance(_ context.Context, _ bool, _ string, _ common.NetDialer, _ str
 }
 }
 
 
 // DialConjure establishes a new Conjure connection to a Conjure station.
 // DialConjure establishes a new Conjure connection to a Conjure station.
-func DialConjure(_ context.Context, _ bool, _ string, _, _ common.NetDialer, _ int, _, _ string) (net.Conn, error) {
+func DialConjure(_ context.Context, _ bool, _ string, _ common.NetDialer, _ string, _ *ConjureConfig) (net.Conn, error) {
 	return nil, errors.TraceNew("operation is not enabled")
 	return nil, errors.TraceNew("operation is not enabled")
 }
 }

+ 15 - 14
psiphon/common/tactics/tactics.go

@@ -1225,14 +1225,15 @@ func (server *Server) handleTacticsRequest(
 	server.logger.LogMetric(TACTICS_METRIC_EVENT_NAME, logFields)
 	server.logger.LogMetric(TACTICS_METRIC_EVENT_NAME, logFields)
 }
 }
 
 
-// RoundTripper performs a round trip to the specified endpoint, sending the
-// request body and returning the response body. The context may be used to
-// set a timeout or cancel the rount trip.
+// ObfuscatedRoundTripper performs a round trip to the specified endpoint,
+// sending the request body and returning the response body, with an
+// obfuscation layer applied to the endpoint value. The context may be used
+// to set a timeout or cancel the round trip.
 //
 //
-// The Psiphon client provides a RoundTripper using meek. The client will
-// handle connection details including server selection, dialing details
-// including device binding and upstream proxy, etc.
-type RoundTripper func(
+// The Psiphon client provides a ObfuscatedRoundTripper using MeekConn. The
+// client will handle connection details including server selection, dialing
+// details including device binding and upstream proxy, etc.
+type ObfuscatedRoundTripper func(
 	ctx context.Context,
 	ctx context.Context,
 	endPoint string,
 	endPoint string,
 	requestBody []byte) ([]byte, error)
 	requestBody []byte) ([]byte, error)
@@ -1343,11 +1344,11 @@ func UseStoredTactics(
 // FetchTactics performs a tactics request. When there are no stored
 // FetchTactics performs a tactics request. When there are no stored
 // speed test samples for the network ID, a speed test request is
 // speed test samples for the network ID, a speed test request is
 // performed immediately before the tactics request, using the same
 // performed immediately before the tactics request, using the same
-// RoundTripper.
+// ObfuscatedRoundTripper.
 //
 //
-// The RoundTripper transport should be established in advance, so that
-// calls to RoundTripper don't take additional time in TCP, TLS, etc.
-// handshakes.
+// The ObfuscatedRoundTripper transport should be established in advance, so
+// that calls to ObfuscatedRoundTripper don't take additional time in TCP,
+// TLS, etc. handshakes.
 //
 //
 // The caller should first call UseStoredTactics and skip FetchTactics
 // The caller should first call UseStoredTactics and skip FetchTactics
 // when there is an unexpired stored tactics record available. The
 // when there is an unexpired stored tactics record available. The
@@ -1371,7 +1372,7 @@ func FetchTactics(
 	endPointProtocol string,
 	endPointProtocol string,
 	encodedRequestPublicKey string,
 	encodedRequestPublicKey string,
 	encodedRequestObfuscatedKey string,
 	encodedRequestObfuscatedKey string,
-	roundTripper RoundTripper) (*Record, error) {
+	obfuscatedRoundTripper ObfuscatedRoundTripper) (*Record, error) {
 
 
 	networkID := getNetworkID()
 	networkID := getNetworkID()
 
 
@@ -1396,7 +1397,7 @@ func FetchTactics(
 
 
 		startTime := time.Now()
 		startTime := time.Now()
 
 
-		response, err := roundTripper(ctx, SPEED_TEST_END_POINT, request)
+		response, err := obfuscatedRoundTripper(ctx, SPEED_TEST_END_POINT, request)
 
 
 		elapsedTime := time.Since(startTime)
 		elapsedTime := time.Since(startTime)
 
 
@@ -1458,7 +1459,7 @@ func FetchTactics(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
-	boxedResponse, err := roundTripper(ctx, TACTICS_END_POINT, boxedRequest)
+	boxedResponse, err := obfuscatedRoundTripper(ctx, TACTICS_END_POINT, boxedRequest)
 	if err != nil {
 	if err != nil {
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}

+ 9 - 6
psiphon/common/tactics/tactics_test.go

@@ -248,11 +248,14 @@ func TestTactics(t *testing.T) {
 	endPointProtocol := "OSSH"
 	endPointProtocol := "OSSH"
 	differentEndPointProtocol := "SSH"
 	differentEndPointProtocol := "SSH"
 
 
-	roundTripper := func(
+	obfuscatedRoundTripper := func(
 		ctx context.Context,
 		ctx context.Context,
 		endPoint string,
 		endPoint string,
 		requestBody []byte) ([]byte, error) {
 		requestBody []byte) ([]byte, error) {
 
 
+		// This mock ObfuscatedRoundTripper does not actually obfuscate the endpoint
+		// value.
+
 		request, err := http.NewRequest(
 		request, err := http.NewRequest(
 			"POST",
 			"POST",
 			fmt.Sprintf("http://%s/%s", serverAddress, endPoint),
 			fmt.Sprintf("http://%s/%s", serverAddress, endPoint),
@@ -341,7 +344,7 @@ func TestTactics(t *testing.T) {
 		endPointRegion,
 		endPointRegion,
 		encodedRequestPublicKey,
 		encodedRequestPublicKey,
 		encodedObfuscatedKey,
 		encodedObfuscatedKey,
-		roundTripper)
+		obfuscatedRoundTripper)
 
 
 	cancelFunc()
 	cancelFunc()
 
 
@@ -413,7 +416,7 @@ func TestTactics(t *testing.T) {
 		endPointRegion,
 		endPointRegion,
 		encodedRequestPublicKey,
 		encodedRequestPublicKey,
 		encodedObfuscatedKey,
 		encodedObfuscatedKey,
-		roundTripper)
+		obfuscatedRoundTripper)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("FetchTactics failed: %s", err)
 		t.Fatalf("FetchTactics failed: %s", err)
 	}
 	}
@@ -490,7 +493,7 @@ func TestTactics(t *testing.T) {
 		endPointRegion,
 		endPointRegion,
 		encodedRequestPublicKey,
 		encodedRequestPublicKey,
 		encodedObfuscatedKey,
 		encodedObfuscatedKey,
-		roundTripper)
+		obfuscatedRoundTripper)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("FetchTactics failed: %s", err)
 		t.Fatalf("FetchTactics failed: %s", err)
 	}
 	}
@@ -646,7 +649,7 @@ func TestTactics(t *testing.T) {
 		endPointRegion,
 		endPointRegion,
 		encodedIncorrectRequestPublicKey,
 		encodedIncorrectRequestPublicKey,
 		encodedObfuscatedKey,
 		encodedObfuscatedKey,
-		roundTripper)
+		obfuscatedRoundTripper)
 	if err == nil {
 	if err == nil {
 		t.Fatalf("FetchTactics succeeded unexpectedly with incorrect request key")
 		t.Fatalf("FetchTactics succeeded unexpectedly with incorrect request key")
 	}
 	}
@@ -661,7 +664,7 @@ func TestTactics(t *testing.T) {
 		endPointRegion,
 		endPointRegion,
 		encodedRequestPublicKey,
 		encodedRequestPublicKey,
 		encodedIncorrectObfuscatedKey,
 		encodedIncorrectObfuscatedKey,
-		roundTripper)
+		obfuscatedRoundTripper)
 	if err == nil {
 	if err == nil {
 		t.Fatalf("FetchTactics succeeded unexpectedly with incorrect obfuscated key")
 		t.Fatalf("FetchTactics succeeded unexpectedly with incorrect obfuscated key")
 	}
 	}

+ 48 - 0
psiphon/config.go

@@ -718,6 +718,18 @@ type Config struct {
 	CustomHostNameProbability    *float64
 	CustomHostNameProbability    *float64
 	CustomHostNameLimitProtocols []string
 	CustomHostNameLimitProtocols []string
 
 
+	// ConjureCachedRegistrationTTLSeconds and other Conjure fields are for
+	// testing purposes.
+	ConjureCachedRegistrationTTLSeconds       *int
+	ConjureAPIRegistrarURL                    string
+	ConjureAPIRegistrarFrontingSpecs          parameters.FrontingSpecs
+	ConjureAPIRegistrarMinDelayMilliseconds   *int
+	ConjureAPIRegistrarMaxDelayMilliseconds   *int
+	ConjureDecoyRegistrarProbability          *float64
+	ConjureDecoyRegistrarWidth                *int
+	ConjureDecoyRegistrarMinDelayMilliseconds *int
+	ConjureDecoyRegistrarMaxDelayMilliseconds *int
+
 	// params is the active parameters.Parameters with defaults, config values,
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
 	// and, optionally, tactics applied.
 	//
 	//
@@ -1615,6 +1627,42 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.CustomHostNameLimitProtocols] = protocol.TunnelProtocols(config.CustomHostNameLimitProtocols)
 		applyParameters[parameters.CustomHostNameLimitProtocols] = protocol.TunnelProtocols(config.CustomHostNameLimitProtocols)
 	}
 	}
 
 
+	if config.ConjureCachedRegistrationTTLSeconds != nil {
+		applyParameters[parameters.ConjureCachedRegistrationTTL] = fmt.Sprintf("%dms", *config.ConjureCachedRegistrationTTLSeconds)
+	}
+
+	if config.ConjureAPIRegistrarURL != "" {
+		applyParameters[parameters.ConjureAPIRegistrarURL] = config.ConjureAPIRegistrarURL
+	}
+
+	if config.ConjureAPIRegistrarFrontingSpecs != nil {
+		applyParameters[parameters.ConjureAPIRegistrarFrontingSpecs] = config.ConjureAPIRegistrarFrontingSpecs
+	}
+
+	if config.ConjureAPIRegistrarMinDelayMilliseconds != nil {
+		applyParameters[parameters.ConjureAPIRegistrarMinDelay] = fmt.Sprintf("%dms", *config.ConjureAPIRegistrarMinDelayMilliseconds)
+	}
+
+	if config.ConjureAPIRegistrarMaxDelayMilliseconds != nil {
+		applyParameters[parameters.ConjureAPIRegistrarMaxDelay] = fmt.Sprintf("%dms", *config.ConjureAPIRegistrarMaxDelayMilliseconds)
+	}
+
+	if config.ConjureDecoyRegistrarProbability != nil {
+		applyParameters[parameters.ConjureDecoyRegistrarProbability] = *config.ConjureDecoyRegistrarProbability
+	}
+
+	if config.ConjureDecoyRegistrarWidth != nil {
+		applyParameters[parameters.ConjureDecoyRegistrarWidth] = *config.ConjureDecoyRegistrarWidth
+	}
+
+	if config.ConjureDecoyRegistrarMinDelayMilliseconds != nil {
+		applyParameters[parameters.ConjureDecoyRegistrarMinDelay] = fmt.Sprintf("%dms", *config.ConjureDecoyRegistrarMinDelayMilliseconds)
+	}
+
+	if config.ConjureDecoyRegistrarMaxDelayMilliseconds != nil {
+		applyParameters[parameters.ConjureDecoyRegistrarMaxDelay] = fmt.Sprintf("%dms", *config.ConjureDecoyRegistrarMaxDelayMilliseconds)
+	}
+
 	return applyParameters
 	return applyParameters
 }
 }
 
 

+ 30 - 26
psiphon/controller.go

@@ -1232,6 +1232,22 @@ func (controller *Controller) Dial(
 		return nil, errors.TraceNew("no active tunnels")
 		return nil, errors.TraceNew("no active tunnels")
 	}
 	}
 
 
+	if !controller.config.EnableSplitTunnel {
+
+		tunneledConn, splitTunnel, err := tunnel.DialTCPChannel(
+			remoteAddr, false, downstreamConn)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		if splitTunnel {
+			return nil, errors.TraceNew(
+				"unexpected split tunnel classification")
+		}
+
+		return tunneledConn, nil
+	}
+
 	// In split tunnel mode, TCP port forwards to destinations in the same
 	// In split tunnel mode, TCP port forwards to destinations in the same
 	// country as the client are untunneled.
 	// country as the client are untunneled.
 	//
 	//
@@ -1255,22 +1271,17 @@ func (controller *Controller) Dial(
 	// it does for all port forwards in non-split tunnel mode. There is no
 	// it does for all port forwards in non-split tunnel mode. There is no
 	// additional round trip for tunneled port forwards.
 	// additional round trip for tunneled port forwards.
 
 
-	untunneledCache := controller.untunneledSplitTunnelClassifications
-	var splitTunnelHost string
-	cachedUntunneled := false
+	splitTunnelHost, _, err := net.SplitHostPort(remoteAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
 
-	if controller.config.EnableSplitTunnel {
-		var err error
-		splitTunnelHost, _, err = net.SplitHostPort(remoteAddr)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
+	untunneledCache := controller.untunneledSplitTunnelClassifications
 
 
-		// If the destination hostname is in the untunneled split tunnel
-		// classifications cache, skip the round trip to the server and do the
-		// direct, untunneled dial immediately.
-		_, cachedUntunneled = untunneledCache.Get(splitTunnelHost)
-	}
+	// If the destination hostname is in the untunneled split tunnel
+	// classifications cache, skip the round trip to the server and do the
+	// direct, untunneled dial immediately.
+	_, cachedUntunneled := untunneledCache.Get(splitTunnelHost)
 
 
 	if !cachedUntunneled {
 	if !cachedUntunneled {
 
 
@@ -1282,25 +1293,17 @@ func (controller *Controller) Dial(
 
 
 		if !splitTunnel {
 		if !splitTunnel {
 
 
-			if controller.config.EnableSplitTunnel {
-
-				// Clear any cached untunneled classification entry for this destination
-				// hostname, as the server is now classifying it as tunneled.
-				untunneledCache.Delete(splitTunnelHost)
-			}
+			// Clear any cached untunneled classification entry for this destination
+			// hostname, as the server is now classifying it as tunneled.
+			untunneledCache.Delete(splitTunnelHost)
 
 
 			return tunneledConn, nil
 			return tunneledConn, nil
 		}
 		}
 
 
-		if !controller.config.EnableSplitTunnel {
-			return nil, errors.TraceNew(
-				"unexpected split tunnel classification")
-		}
-
 		// The server has indicated that the client should make a direct,
 		// The server has indicated that the client should make a direct,
 		// untunneled dial. Cache the classification to avoid this round trip in
 		// untunneled dial. Cache the classification to avoid this round trip in
 		// the immediate future.
 		// the immediate future.
-		untunneledCache.Add(splitTunnelHost, true, 0)
+		untunneledCache.Add(splitTunnelHost, true, lrucache.DefaultExpiration)
 	}
 	}
 
 
 	NoticeUntunneled(splitTunnelHost)
 	NoticeUntunneled(splitTunnelHost)
@@ -1309,6 +1312,7 @@ func (controller *Controller) Dial(
 	if err != nil {
 	if err != nil {
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
+
 	return untunneledConn, nil
 	return untunneledConn, nil
 }
 }
 
 

+ 131 - 30
psiphon/dialParameters.go

@@ -95,6 +95,8 @@ type DialParameters struct {
 	MeekDialAddress           string
 	MeekDialAddress           string
 	MeekTransformedHostName   bool
 	MeekTransformedHostName   bool
 	MeekSNIServerName         string
 	MeekSNIServerName         string
+	MeekVerifyServerName      string
+	MeekVerifyPins            []string
 	MeekHostHeader            string
 	MeekHostHeader            string
 	MeekObfuscatorPaddingSeed *prng.Seed
 	MeekObfuscatorPaddingSeed *prng.Seed
 	MeekTLSPaddingSize        int
 	MeekTLSPaddingSize        int
@@ -113,8 +115,14 @@ type DialParameters struct {
 	QUICDialSNIAddress        string
 	QUICDialSNIAddress        string
 	ObfuscatedQUICPaddingSeed *prng.Seed
 	ObfuscatedQUICPaddingSeed *prng.Seed
 
 
-	ConjureDecoyRegistrarWidth int
-	ConjureTransport           string
+	ConjureCachedRegistrationTTL time.Duration
+	ConjureAPIRegistration       bool
+	ConjureAPIRegistrarURL       string
+	ConjureAPIRegistrarDelay     time.Duration
+	ConjureDecoyRegistration     bool
+	ConjureDecoyRegistrarDelay   time.Duration
+	ConjureDecoyRegistrarWidth   int
+	ConjureTransport             string
 
 
 	LivenessTestSeed *prng.Seed
 	LivenessTestSeed *prng.Seed
 
 
@@ -392,14 +400,115 @@ func MakeDialParameters(
 		}
 		}
 	}
 	}
 
 
-	if (!isReplay || !replayTLSProfile) &&
-		protocol.TunnelProtocolUsesMeekHTTPS(dialParams.TunnelProtocol) {
+	if (!isReplay || !replayConjureRegistration) &&
+		protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
+
+		dialParams.ConjureCachedRegistrationTTL = p.Duration(parameters.ConjureCachedRegistrationTTL)
+
+		apiURL := p.String(parameters.ConjureAPIRegistrarURL)
+		decoyWidth := p.Int(parameters.ConjureDecoyRegistrarWidth)
+
+		dialParams.ConjureAPIRegistration = apiURL != ""
+		dialParams.ConjureDecoyRegistration = decoyWidth != 0
+
+		// We select only one of API or decoy registration. When both are enabled,
+		// ConjureDecoyRegistrarProbability determines the probability of using
+		// decoy registration.
+		//
+		// In general, we disable retries in gotapdance and rely on Psiphon
+		// establishment to try/retry different registration schemes. This allows us
+		// to control the proportion of registration types attempted. And, in good
+		// network conditions, individual candidates are most likely to be cancelled
+		// before they exhaust their retry options.
+
+		if dialParams.ConjureAPIRegistration && dialParams.ConjureDecoyRegistration {
+			if p.WeightedCoinFlip(parameters.ConjureDecoyRegistrarProbability) {
+				dialParams.ConjureAPIRegistration = false
+			}
+		}
+
+		if dialParams.ConjureAPIRegistration {
+
+			// While Conjure API registration uses MeekConn and specifies common meek
+			// parameters, the meek address and SNI configuration is implemented in this
+			// code block and not in common code blocks below. The exception is TLS
+			// configuration.
+			//
+			// Accordingly, replayFronting/replayHostname have no effect on Conjure API
+			// registration replay.
+
+			dialParams.ConjureAPIRegistrarURL = apiURL
+
+			frontingSpecs := p.FrontingSpecs(parameters.ConjureAPIRegistrarFrontingSpecs)
+			dialParams.FrontingProviderID,
+				dialParams.MeekFrontingDialAddress,
+				dialParams.MeekSNIServerName,
+				dialParams.MeekVerifyServerName,
+				dialParams.MeekVerifyPins,
+				dialParams.MeekFrontingHost,
+				err = frontingSpecs.SelectParameters()
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+
+			dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
+			dialParams.MeekHostHeader = dialParams.MeekFrontingHost
+
+			// For a FrontingSpec, an SNI value of "" indicates to disable/omit SNI, so
+			// never transform in that case.
+			if dialParams.MeekSNIServerName != "" {
+				if p.WeightedCoinFlip(parameters.TransformHostNameProbability) {
+					dialParams.MeekSNIServerName = selectHostName(dialParams.TunnelProtocol, p)
+					dialParams.MeekTransformedHostName = true
+				}
+			}
+
+			// The minimum delay value is determined by the Conjure station, which
+			// performs an asynchronous "liveness test" against the selected phantom
+			// IPs. The min/max range allows us to introduce some jitter so that we
+			// don't present a trivial inter-flow fingerprint: CDN connection, fixed
+			// delay, phantom dial.
+
+			minDelay := p.Duration(parameters.ConjureAPIRegistrarMinDelay)
+			maxDelay := p.Duration(parameters.ConjureAPIRegistrarMaxDelay)
+			dialParams.ConjureAPIRegistrarDelay = prng.Period(minDelay, maxDelay)
+
+		} else if dialParams.ConjureDecoyRegistration {
+
+			dialParams.ConjureDecoyRegistrarWidth = decoyWidth
+			minDelay := p.Duration(parameters.ConjureDecoyRegistrarMinDelay)
+			maxDelay := p.Duration(parameters.ConjureDecoyRegistrarMaxDelay)
+			dialParams.ConjureAPIRegistrarDelay = prng.Period(minDelay, maxDelay)
+
+		} else {
+
+			return nil, errors.TraceNew("no Conjure registrar configured")
+		}
+	}
+
+	if (!isReplay || !replayConjureTransport) &&
+		protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
+
+		dialParams.ConjureTransport = protocol.CONJURE_TRANSPORT_MIN_OSSH
+		if p.WeightedCoinFlip(
+			parameters.ConjureTransportObfs4Probability) {
+			dialParams.ConjureTransport = protocol.CONJURE_TRANSPORT_OBFS4_OSSH
+		}
+	}
+
+	usingTLS := protocol.TunnelProtocolUsesMeekHTTPS(dialParams.TunnelProtocol) ||
+		dialParams.ConjureAPIRegistration
+
+	if (!isReplay || !replayTLSProfile) && usingTLS {
 
 
 		dialParams.SelectedTLSProfile = true
 		dialParams.SelectedTLSProfile = true
 
 
 		requireTLS12SessionTickets := protocol.TunnelProtocolRequiresTLS12SessionTickets(
 		requireTLS12SessionTickets := protocol.TunnelProtocolRequiresTLS12SessionTickets(
 			dialParams.TunnelProtocol)
 			dialParams.TunnelProtocol)
-		isFronted := protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol)
+
+		isFronted := protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) ||
+			dialParams.ConjureAPIRegistration
+
 		dialParams.TLSProfile = SelectTLSProfile(
 		dialParams.TLSProfile = SelectTLSProfile(
 			requireTLS12SessionTickets, isFronted, serverEntry.FrontingProviderID, p)
 			requireTLS12SessionTickets, isFronted, serverEntry.FrontingProviderID, p)
 
 
@@ -407,8 +516,7 @@ func MakeDialParameters(
 			parameters.NoDefaultTLSSessionIDProbability)
 			parameters.NoDefaultTLSSessionIDProbability)
 	}
 	}
 
 
-	if (!isReplay || !replayRandomizedTLSProfile) &&
-		protocol.TunnelProtocolUsesMeekHTTPS(dialParams.TunnelProtocol) &&
+	if (!isReplay || !replayRandomizedTLSProfile) && usingTLS &&
 		protocol.TLSProfileIsRandomized(dialParams.TLSProfile) {
 		protocol.TLSProfileIsRandomized(dialParams.TLSProfile) {
 
 
 		dialParams.RandomizedTLSProfileSeed, err = prng.NewSeed()
 		dialParams.RandomizedTLSProfileSeed, err = prng.NewSeed()
@@ -417,8 +525,7 @@ func MakeDialParameters(
 		}
 		}
 	}
 	}
 
 
-	if (!isReplay || !replayTLSProfile) &&
-		protocol.TunnelProtocolUsesMeekHTTPS(dialParams.TunnelProtocol) {
+	if (!isReplay || !replayTLSProfile) && usingTLS {
 
 
 		// Since "Randomized-v2"/CustomTLSProfiles may be TLS 1.2 or TLS 1.3,
 		// Since "Randomized-v2"/CustomTLSProfiles may be TLS 1.2 or TLS 1.3,
 		// construct the ClientHello to determine if it's TLS 1.3. This test also
 		// construct the ClientHello to determine if it's TLS 1.3. This test also
@@ -505,22 +612,6 @@ func MakeDialParameters(
 		}
 		}
 	}
 	}
 
 
-	if (!isReplay || !replayConjureRegistration) &&
-		protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
-
-		dialParams.ConjureDecoyRegistrarWidth = p.Int(parameters.ConjureDecoyRegistrarWidth)
-	}
-
-	if (!isReplay || !replayConjureTransport) &&
-		protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
-
-		dialParams.ConjureTransport = protocol.CONJURE_TRANSPORT_MIN_OSSH
-		if p.WeightedCoinFlip(
-			parameters.ConjureTransportObfs4Probability) {
-			dialParams.ConjureTransport = protocol.CONJURE_TRANSPORT_OBFS4_OSSH
-		}
-	}
-
 	if !isReplay || !replayLivenessTest {
 	if !isReplay || !replayLivenessTest {
 
 
 		// TODO: initialize only when LivenessTestMaxUp/DownstreamBytes > 0?
 		// TODO: initialize only when LivenessTestMaxUp/DownstreamBytes > 0?
@@ -655,7 +746,9 @@ func MakeDialParameters(
 
 
 	dialCustomHeaders := makeDialCustomHeaders(config, p)
 	dialCustomHeaders := makeDialCustomHeaders(config, p)
 
 
-	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) || dialParams.UpstreamProxyType == "http" {
+	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) ||
+		dialParams.UpstreamProxyType == "http" ||
+		dialParams.ConjureAPIRegistration {
 
 
 		if !isReplay || !replayUserAgent {
 		if !isReplay || !replayUserAgent {
 			dialParams.SelectedUserAgent, dialParams.UserAgent = selectUserAgentIfUnset(p, dialCustomHeaders)
 			dialParams.SelectedUserAgent, dialParams.UserAgent = selectUserAgentIfUnset(p, dialCustomHeaders)
@@ -699,7 +792,8 @@ func MakeDialParameters(
 	// always be read.
 	// always be read.
 	dialParams.MeekResolvedIPAddress.Store("")
 	dialParams.MeekResolvedIPAddress.Store("")
 
 
-	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) {
+	if protocol.TunnelProtocolUsesMeek(dialParams.TunnelProtocol) ||
+		dialParams.ConjureAPIRegistration {
 
 
 		dialParams.meekConfig = &MeekConfig{
 		dialParams.meekConfig = &MeekConfig{
 			DiagnosticID:                  serverEntry.GetDiagnosticID(),
 			DiagnosticID:                  serverEntry.GetDiagnosticID(),
@@ -707,12 +801,14 @@ func MakeDialParameters(
 			DialAddress:                   dialParams.MeekDialAddress,
 			DialAddress:                   dialParams.MeekDialAddress,
 			UseQUIC:                       protocol.TunnelProtocolUsesFrontedMeekQUIC(dialParams.TunnelProtocol),
 			UseQUIC:                       protocol.TunnelProtocolUsesFrontedMeekQUIC(dialParams.TunnelProtocol),
 			QUICVersion:                   dialParams.QUICVersion,
 			QUICVersion:                   dialParams.QUICVersion,
-			UseHTTPS:                      protocol.TunnelProtocolUsesMeekHTTPS(dialParams.TunnelProtocol),
+			UseHTTPS:                      usingTLS,
 			TLSProfile:                    dialParams.TLSProfile,
 			TLSProfile:                    dialParams.TLSProfile,
 			NoDefaultTLSSessionID:         dialParams.NoDefaultTLSSessionID,
 			NoDefaultTLSSessionID:         dialParams.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed:      dialParams.RandomizedTLSProfileSeed,
 			RandomizedTLSProfileSeed:      dialParams.RandomizedTLSProfileSeed,
 			UseObfuscatedSessionTickets:   dialParams.TunnelProtocol == protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,
 			UseObfuscatedSessionTickets:   dialParams.TunnelProtocol == protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET,
 			SNIServerName:                 dialParams.MeekSNIServerName,
 			SNIServerName:                 dialParams.MeekSNIServerName,
+			VerifyServerName:              dialParams.MeekVerifyServerName,
+			VerifyPins:                    dialParams.MeekVerifyPins,
 			HostHeader:                    dialParams.MeekHostHeader,
 			HostHeader:                    dialParams.MeekHostHeader,
 			TransformedHostName:           dialParams.MeekTransformedHostName,
 			TransformedHostName:           dialParams.MeekTransformedHostName,
 			ClientTunnelProtocol:          dialParams.TunnelProtocol,
 			ClientTunnelProtocol:          dialParams.TunnelProtocol,
@@ -732,7 +828,11 @@ func MakeDialParameters(
 		}
 		}
 
 
 		if isTactics {
 		if isTactics {
-			dialParams.meekConfig.RoundTripperOnly = true
+			dialParams.meekConfig.Mode = MeekModeObfuscatedRoundTrip
+		} else if dialParams.ConjureAPIRegistration {
+			dialParams.meekConfig.Mode = MeekModePlaintextRoundTrip
+		} else {
+			dialParams.meekConfig.Mode = MeekModeRelay
 		}
 		}
 	}
 	}
 
 
@@ -931,7 +1031,8 @@ func getConfigStateHash(
 	return hash.Sum(nil)
 	return hash.Sum(nil)
 }
 }
 
 
-func selectFrontingParameters(serverEntry *protocol.ServerEntry) (string, string, error) {
+func selectFrontingParameters(
+	serverEntry *protocol.ServerEntry) (string, string, error) {
 
 
 	frontingDialHost := ""
 	frontingDialHost := ""
 	frontingHost := ""
 	frontingHost := ""

+ 0 - 1
psiphon/feedback.go

@@ -163,7 +163,6 @@ func SendFeedback(ctx context.Context, config *Config, diagnostics, uploadPath s
 			feedbackUploadCtx,
 			feedbackUploadCtx,
 			config,
 			config,
 			untunneledDialConfig,
 			untunneledDialConfig,
-			nil,
 			uploadURL.SkipVerify)
 			uploadURL.SkipVerify)
 		if err != nil {
 		if err != nil {
 			return errors.Trace(err)
 			return errors.Trace(err)

+ 236 - 128
psiphon/meekConn.go

@@ -49,19 +49,25 @@ import (
 	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/crypto/nacl/box"
 )
 )
 
 
-// MeekConn is based on meek-client.go from Tor and Psiphon:
+// MeekConn is based on meek-client.go from Tor:
 //
 //
 // https://gitweb.torproject.org/pluggable-transports/meek.git/blob/HEAD:/meek-client/meek-client.go
 // https://gitweb.torproject.org/pluggable-transports/meek.git/blob/HEAD:/meek-client/meek-client.go
 // CC0 1.0 Universal
 // CC0 1.0 Universal
-//
-// https://bitbucket.org/psiphon/psiphon-circumvention-system/src/default/go/meek-client/meek-client.go
 
 
 const (
 const (
 	MEEK_PROTOCOL_VERSION           = 3
 	MEEK_PROTOCOL_VERSION           = 3
 	MEEK_MAX_REQUEST_PAYLOAD_LENGTH = 65536
 	MEEK_MAX_REQUEST_PAYLOAD_LENGTH = 65536
 )
 )
 
 
-// MeekConfig specifies the behavior of a MeekConn
+type MeekMode int
+
+const (
+	MeekModeRelay = iota
+	MeekModeObfuscatedRoundTrip
+	MeekModePlaintextRoundTrip
+)
+
+// MeekConfig specifies the behavior of a MeekConn.
 type MeekConfig struct {
 type MeekConfig struct {
 
 
 	// DiagnosticID is the server ID to record in any diagnostics notices.
 	// DiagnosticID is the server ID to record in any diagnostics notices.
@@ -71,6 +77,32 @@ type MeekConfig struct {
 	// for the meek dial.
 	// for the meek dial.
 	Parameters *parameters.Parameters
 	Parameters *parameters.Parameters
 
 
+	// Mode selects the mode of operation:
+	//
+	// MeekModeRelay: encapsulates net.Conn flows in HTTP requests and responses;
+	// secures and obfuscates metadata in an encrypted HTTP cookie, making it
+	// suitable for non-TLS HTTP and HTTPS with unverifed server certificates;
+	// the caller is responsible for securing and obfuscating the net.Conn flows;
+	// the origin server should be a meek server; used for the meek tunnel
+	// protocols.
+	//
+	// MeekModeObfuscatedRoundTrip: enables ObfuscatedRoundTrip, which performs
+	// HTTP round trips; secures and obfuscates metadata, including the end point
+	// (or path), in an encrypted HTTP cookie, making it suitable for non-TLS
+	// HTTP and HTTPS with unverifed server certificates; the caller is
+	// responsible for securing and obfuscating request/response payloads; the
+	// origin server should be a meek server; used for tactics requests.
+	//
+	// MeekModePlaintextRoundTrip: enables RoundTrip; the MeekConn is an
+	// http.RoundTripper; there are no security or obfuscation measures at the
+	// HTTP level; TLS and server certificate verification is required; the
+	// origin server may be any HTTP(S) server.
+	//
+	// As with the other modes, MeekModePlaintextRoundTrip supports HTTP/2 with
+	// utls, and integration with DialParameters for replay -- which are not
+	// otherwise implemented if using just CustomTLSDialer and net.http.
+	Mode MeekMode
+
 	// DialAddress is the actual network address to dial to establish a
 	// DialAddress is the actual network address to dial to establish a
 	// connection to the meek server. This may be either a fronted or
 	// connection to the meek server. This may be either a fronted or
 	// direct address. The address must be in the form "host:port",
 	// direct address. The address must be in the form "host:port",
@@ -84,7 +116,6 @@ type MeekConfig struct {
 	QUICVersion string
 	QUICVersion string
 
 
 	// UseHTTPS indicates whether to use HTTPS (true) or HTTP (false).
 	// UseHTTPS indicates whether to use HTTPS (true) or HTTP (false).
-	// Ignored when UseQUIC is true.
 	UseHTTPS bool
 	UseHTTPS bool
 
 
 	// TLSProfile specifies the value for CustomTLSConfig.TLSProfile for all
 	// TLSProfile specifies the value for CustomTLSConfig.TLSProfile for all
@@ -101,12 +132,13 @@ type MeekConfig struct {
 	// connections created by this meek connection.
 	// connections created by this meek connection.
 	RandomizedTLSProfileSeed *prng.Seed
 	RandomizedTLSProfileSeed *prng.Seed
 
 
-	// UseObfuscatedSessionTickets indicates whether to use obfuscated
-	// session tickets. Assumes UseHTTPS is true.
+	// UseObfuscatedSessionTickets indicates whether to use obfuscated session
+	// tickets. Assumes UseHTTPS is true. Ignored for MeekModePlaintextRoundTrip.
+	//
 	UseObfuscatedSessionTickets bool
 	UseObfuscatedSessionTickets bool
 
 
-	// SNIServerName is the value to place in the TLS/QUIC SNI server_name
-	// field when HTTPS or QUIC is used.
+	// SNIServerName is the value to place in the TLS/QUIC SNI server_name field
+	// when HTTPS or QUIC is used.
 	SNIServerName string
 	SNIServerName string
 
 
 	// HostHeader is the value to place in the HTTP request Host header.
 	// HostHeader is the value to place in the HTTP request Host header.
@@ -116,46 +148,54 @@ type MeekConfig struct {
 	// in effect. This value is used for stats reporting.
 	// in effect. This value is used for stats reporting.
 	TransformedHostName bool
 	TransformedHostName bool
 
 
-	// ClientTunnelProtocol is the protocol the client is using. It's
-	// included in the meek cookie for optional use by the server, in
-	// cases where the server cannot unambiguously determine the
-	// tunnel protocol.
-	// ClientTunnelProtocol is used when selecting tactics targeted at
-	// specific protocols.
+	// VerifyServerName specifies a domain name that must appear in the server
+	// certificate. When blank, server certificate verification is disabled.
+	VerifyServerName string
+
+	// VerifyPins specifies one or more certificate pin values, one of which must
+	// appear in the verified server certificate chain. A pin value is the
+	// base64-encoded SHA2 digest of a certificate's public key. When specified,
+	// at least one pin must match at least one certificate in the chain, at any
+	// position; e.g., the root CA may be pinned, or the server certificate,
+	// etc.
+	VerifyPins []string
+
+	// ClientTunnelProtocol is the protocol the client is using. It's included in
+	// the meek cookie for optional use by the server, in cases where the server
+	// cannot unambiguously determine the tunnel protocol. ClientTunnelProtocol
+	// is used when selecting tactics targeted at specific protocols. Ignored for
+	// MeekModePlaintextRoundTrip.
 	ClientTunnelProtocol string
 	ClientTunnelProtocol string
 
 
-	// RoundTripperOnly sets the MeekConn to operate in round tripper
-	// mode, which is used for untunneled tactics requests. In this
-	// mode, a connection is established to the meek server as usual,
-	// but instead of relaying tunnel traffic, the RoundTrip function
-	// may be used to make requests. In this mode, no relay resources
-	// incuding buffers are allocated.
-	RoundTripperOnly bool
-
 	// NetworkLatencyMultiplier specifies a custom network latency multiplier to
 	// NetworkLatencyMultiplier specifies a custom network latency multiplier to
 	// apply to client parameters used by this meek connection.
 	// apply to client parameters used by this meek connection.
 	NetworkLatencyMultiplier float64
 	NetworkLatencyMultiplier float64
 
 
 	// The following values are used to create the obfuscated meek cookie.
 	// The following values are used to create the obfuscated meek cookie.
+	// Ignored for MeekModePlaintextRoundTrip.
 
 
 	MeekCookieEncryptionPublicKey string
 	MeekCookieEncryptionPublicKey string
 	MeekObfuscatedKey             string
 	MeekObfuscatedKey             string
 	MeekObfuscatorPaddingSeed     *prng.Seed
 	MeekObfuscatorPaddingSeed     *prng.Seed
 }
 }
 
 
-// MeekConn is a network connection that tunnels TCP over HTTP and supports "fronting". Meek sends
-// client->server flow in HTTP request bodies and receives server->client flow in HTTP response bodies.
-// Polling is used to achieve full duplex TCP.
+// MeekConn is a network connection that tunnels net.Conn flows over HTTP and supports
+// "domain fronting". Meek sends client->server flow in HTTP request bodies and
+// receives server->client flow in HTTP response bodies. Polling is used to
+// approximate full duplex TCP. MeekConn also offers HTTP round trip modes.
 //
 //
-// Fronting is an obfuscation technique in which the connection
-// to a web server, typically a CDN, is indistinguishable from any other HTTPS connection to the generic
-// "fronting domain" -- the HTTP Host header is used to route the requests to the actual destination.
-// See https://trac.torproject.org/projects/tor/wiki/doc/meek for more details.
+// Domain fronting is a network obfuscation technique in which the connection to a web
+// server, typically a CDN, is indistinguishable from any other HTTPS
+// connection to the generic "fronting domain" -- the HTTP Host header is used
+// to route the requests to the actual destination. See
+// https://trac.torproject.org/projects/tor/wiki/doc/meek for more details.
 //
 //
-// MeekConn also operates in unfronted mode, in which plain HTTP connections are made without routing
-// through a CDN.
+// MeekConn also support unfronted operation, in which connections are made
+// without routing through a CDN; and plain HTTP operation, without TLS or
+// QUIC, with connection metadata obfuscated in HTTP cookies.
 type MeekConn struct {
 type MeekConn struct {
 	params                    *parameters.Parameters
 	params                    *parameters.Parameters
+	mode                      MeekMode
 	networkLatencyMultiplier  float64
 	networkLatencyMultiplier  float64
 	isQUIC                    bool
 	isQUIC                    bool
 	url                       *url.URL
 	url                       *url.URL
@@ -173,14 +213,13 @@ type MeekConn struct {
 	stopRunning               context.CancelFunc
 	stopRunning               context.CancelFunc
 	relayWaitGroup            *sync.WaitGroup
 	relayWaitGroup            *sync.WaitGroup
 
 
-	// For round tripper mode
-	roundTripperOnly              bool
+	// For MeekModeObfuscatedRoundTrip
 	meekCookieEncryptionPublicKey string
 	meekCookieEncryptionPublicKey string
 	meekObfuscatedKey             string
 	meekObfuscatedKey             string
 	meekObfuscatorPaddingSeed     *prng.Seed
 	meekObfuscatorPaddingSeed     *prng.Seed
 	clientTunnelProtocol          string
 	clientTunnelProtocol          string
 
 
-	// For relay mode
+	// For MeekModeRelay
 	fullReceiveBufferLength int
 	fullReceiveBufferLength int
 	readPayloadChunkLength  int
 	readPayloadChunkLength  int
 	emptyReceiveBuffer      chan *bytes.Buffer
 	emptyReceiveBuffer      chan *bytes.Buffer
@@ -203,15 +242,38 @@ type transporter interface {
 
 
 // DialMeek returns an initialized meek connection. A meek connection is
 // DialMeek returns an initialized meek connection. A meek connection is
 // an HTTP session which does not depend on an underlying socket connection (although
 // an HTTP session which does not depend on an underlying socket connection (although
-// persistent HTTP connections are used for performance). This function does not
-// wait for the connection to be "established" before returning. A goroutine
-// is spawned which will eventually start HTTP polling.
-// When frontingAddress is not "", fronting is used. This option assumes caller has
-// already checked server entry capabilities.
+// persistent HTTP connections are used for performance). This function may not
+// wait for the connection to be established before returning.
 func DialMeek(
 func DialMeek(
 	ctx context.Context,
 	ctx context.Context,
 	meekConfig *MeekConfig,
 	meekConfig *MeekConfig,
-	dialConfig *DialConfig) (meek *MeekConn, err error) {
+	dialConfig *DialConfig) (*MeekConn, error) {
+
+	if meekConfig.UseQUIC && meekConfig.UseHTTPS {
+		return nil, errors.TraceNew(
+			"invalid config: only one of UseQUIC or UseHTTPS may be set")
+	}
+
+	if meekConfig.UseQUIC &&
+		(meekConfig.VerifyServerName != "" || len(meekConfig.VerifyPins) > 0) {
+
+		// TODO: UseQUIC VerifyServerName and VerifyPins support (required for MeekModePlaintextRoundTrip).
+
+		return nil, errors.TraceNew(
+			"invalid config: VerifyServerName and VerifyPins not supported for UseQUIC")
+	}
+
+	skipVerify := meekConfig.VerifyServerName == ""
+	if len(meekConfig.VerifyPins) > 0 && skipVerify {
+		return nil, errors.TraceNew(
+			"invalid config: VerifyServerName must be set when VerifyPins is set")
+	}
+
+	if meekConfig.Mode == MeekModePlaintextRoundTrip &&
+		(!meekConfig.UseHTTPS || skipVerify) {
+		return nil, errors.TraceNew(
+			"invalid config: MeekModePlaintextRoundTrip requires UseHTTPS and VerifyServerName")
+	}
 
 
 	runCtx, stopRunning := context.WithCancel(context.Background())
 	runCtx, stopRunning := context.WithCancel(context.Background())
 
 
@@ -229,18 +291,18 @@ func DialMeek(
 		}
 		}
 	}()
 	}()
 
 
-	meek = &MeekConn{
+	meek := &MeekConn{
 		params:                   meekConfig.Parameters,
 		params:                   meekConfig.Parameters,
+		mode:                     meekConfig.Mode,
 		networkLatencyMultiplier: meekConfig.NetworkLatencyMultiplier,
 		networkLatencyMultiplier: meekConfig.NetworkLatencyMultiplier,
 		isClosed:                 false,
 		isClosed:                 false,
 		runCtx:                   runCtx,
 		runCtx:                   runCtx,
 		stopRunning:              stopRunning,
 		stopRunning:              stopRunning,
 		relayWaitGroup:           new(sync.WaitGroup),
 		relayWaitGroup:           new(sync.WaitGroup),
-		roundTripperOnly:         meekConfig.RoundTripperOnly,
 	}
 	}
 
 
-	if !meek.roundTripperOnly {
-
+	if meek.mode == MeekModeRelay {
+		var err error
 		meek.cookie,
 		meek.cookie,
 			meek.tlsPadding,
 			meek.tlsPadding,
 			meek.limitRequestPayloadLength,
 			meek.limitRequestPayloadLength,
@@ -256,6 +318,9 @@ func DialMeek(
 		if err != nil {
 		if err != nil {
 			return nil, errors.Trace(err)
 			return nil, errors.Trace(err)
 		}
 		}
+
+		// For stats, record the size of the initial obfuscated cookie.
+		meek.cookieSize = len(meek.cookie.Name) + len(meek.cookie.Value)
 	}
 	}
 
 
 	// Configure transport: QUIC or HTTPS or HTTP
 	// Configure transport: QUIC or HTTPS or HTTP
@@ -306,7 +371,7 @@ func DialMeek(
 		//
 		//
 		//  1. ignores the HTTP request address and uses the fronting domain
 		//  1. ignores the HTTP request address and uses the fronting domain
 		//  2. optionally disables SNI -- SNI breaks fronting when used with certain CDNs.
 		//  2. optionally disables SNI -- SNI breaks fronting when used with certain CDNs.
-		//  3. skips verifying the server cert.
+		//  3. may skip verifying the server cert.
 		//
 		//
 		// Reasoning for #3:
 		// Reasoning for #3:
 		//
 		//
@@ -342,7 +407,9 @@ func DialMeek(
 			DialAddr:                      meekConfig.DialAddress,
 			DialAddr:                      meekConfig.DialAddress,
 			Dial:                          NewTCPDialer(dialConfig),
 			Dial:                          NewTCPDialer(dialConfig),
 			SNIServerName:                 meekConfig.SNIServerName,
 			SNIServerName:                 meekConfig.SNIServerName,
-			SkipVerify:                    true,
+			SkipVerify:                    skipVerify,
+			VerifyServerName:              meekConfig.VerifyServerName,
+			VerifyPins:                    meekConfig.VerifyPins,
 			TLSProfile:                    meekConfig.TLSProfile,
 			TLSProfile:                    meekConfig.TLSProfile,
 			NoDefaultTLSSessionID:         &meekConfig.NoDefaultTLSSessionID,
 			NoDefaultTLSSessionID:         &meekConfig.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed:      meekConfig.RandomizedTLSProfileSeed,
 			RandomizedTLSProfileSeed:      meekConfig.RandomizedTLSProfileSeed,
@@ -355,17 +422,21 @@ func DialMeek(
 			tlsConfig.ObfuscatedSessionTicketKey = meekConfig.MeekObfuscatedKey
 			tlsConfig.ObfuscatedSessionTicketKey = meekConfig.MeekObfuscatedKey
 		}
 		}
 
 
-		// As the passthrough message is unique and indistinguisbale from a normal
-		// TLS client random value, we set it unconditionally and not just for
-		// protocols which may support passthrough (even for those protocols,
-		// clients don't know which servers are configured to use it).
+		if meekConfig.Mode != MeekModePlaintextRoundTrip &&
+			meekConfig.MeekObfuscatedKey != "" {
 
 
-		passthroughMessage, err := obfuscator.MakeTLSPassthroughMessage(
-			meekConfig.MeekObfuscatedKey)
-		if err != nil {
-			return nil, errors.Trace(err)
+			// As the passthrough message is unique and indistinguishable from a normal
+			// TLS client random value, we set it unconditionally and not just for
+			// protocols which may support passthrough (even for those protocols,
+			// clients don't know which servers are configured to use it).
+
+			passthroughMessage, err := obfuscator.MakeTLSPassthroughMessage(
+				meekConfig.MeekObfuscatedKey)
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			tlsConfig.PassthroughMessage = passthroughMessage
 		}
 		}
-		tlsConfig.PassthroughMessage = passthroughMessage
 
 
 		tlsDialer := NewCustomTLSDialer(tlsConfig)
 		tlsDialer := NewCustomTLSDialer(tlsConfig)
 
 
@@ -478,6 +549,7 @@ func DialMeek(
 
 
 		if proxyUrl != nil {
 		if proxyUrl != nil {
 			// Wrap transport with a transport that can perform HTTP proxy auth negotiation
 			// Wrap transport with a transport that can perform HTTP proxy auth negotiation
+			var err error
 			transport, err = upstreamproxy.NewProxyAuthTransport(httpTransport, dialConfig.CustomHeaders)
 			transport, err = upstreamproxy.NewProxyAuthTransport(httpTransport, dialConfig.CustomHeaders)
 			if err != nil {
 			if err != nil {
 				return nil, errors.Trace(err)
 				return nil, errors.Trace(err)
@@ -518,7 +590,7 @@ func DialMeek(
 
 
 	// Allocate relay resources, including buffers and running the relay
 	// Allocate relay resources, including buffers and running the relay
 	// go routine, only when running in relay mode.
 	// go routine, only when running in relay mode.
-	if !meek.roundTripperOnly {
+	if meek.mode == MeekModeRelay {
 
 
 		// The main loop of a MeekConn is run in the relay() goroutine.
 		// The main loop of a MeekConn is run in the relay() goroutine.
 		// A MeekConn implements net.Conn concurrency semantics:
 		// A MeekConn implements net.Conn concurrency semantics:
@@ -559,7 +631,7 @@ func DialMeek(
 		meek.relayWaitGroup.Add(1)
 		meek.relayWaitGroup.Add(1)
 		go meek.relay()
 		go meek.relay()
 
 
-	} else {
+	} else if meek.mode == MeekModeObfuscatedRoundTrip {
 
 
 		meek.meekCookieEncryptionPublicKey = meekConfig.MeekCookieEncryptionPublicKey
 		meek.meekCookieEncryptionPublicKey = meekConfig.MeekCookieEncryptionPublicKey
 		meek.meekObfuscatedKey = meekConfig.MeekObfuscatedKey
 		meek.meekObfuscatedKey = meekConfig.MeekObfuscatedKey
@@ -617,11 +689,12 @@ func (c *cachedTLSDialer) close() {
 	}
 	}
 }
 }
 
 
-// Close terminates the meek connection. Close waits for the relay goroutine
-// to stop (in relay mode) and releases HTTP transport resources.
-// A mutex is required to support net.Conn concurrency semantics.
+// Close terminates the meek connection and releases its resources. In in
+// MeekModeRelay, Close waits for the relay goroutine to stop.
 func (meek *MeekConn) Close() (err error) {
 func (meek *MeekConn) Close() (err error) {
 
 
+	// A mutex is required to support net.Conn concurrency semantics.
+
 	meek.mutex.Lock()
 	meek.mutex.Lock()
 	isClosed := meek.isClosed
 	isClosed := meek.isClosed
 	meek.isClosed = true
 	meek.isClosed = true
@@ -671,26 +744,28 @@ func (meek *MeekConn) IsClosed() bool {
 // GetMetrics implements the common.MetricsSource interface.
 // GetMetrics implements the common.MetricsSource interface.
 func (meek *MeekConn) GetMetrics() common.LogFields {
 func (meek *MeekConn) GetMetrics() common.LogFields {
 	logFields := make(common.LogFields)
 	logFields := make(common.LogFields)
-	logFields["meek_cookie_size"] = meek.cookieSize
-	logFields["meek_tls_padding"] = meek.tlsPadding
-	logFields["meek_limit_request"] = meek.limitRequestPayloadLength
+	if meek.mode == MeekModeRelay {
+		logFields["meek_cookie_size"] = meek.cookieSize
+		logFields["meek_tls_padding"] = meek.tlsPadding
+		logFields["meek_limit_request"] = meek.limitRequestPayloadLength
+	}
 	return logFields
 	return logFields
 }
 }
 
 
-// RoundTrip makes a request to the meek server and returns the response.
-// A new, obfuscated meek cookie is created for every request. The specified
-// end point is recorded in the cookie and is not exposed as plaintext in the
-// meek traffic. The caller is responsible for obfuscating the request body.
+// ObfuscatedRoundTrip makes a request to the meek server and returns the
+// response. A new, obfuscated meek cookie is created for every request. The
+// specified end point is recorded in the cookie and is not exposed as
+// plaintext in the meek traffic. The caller is responsible for securing and
+// obfuscating the request body.
 //
 //
-// RoundTrip is not safe for concurrent use, and Close must not be called
-// concurrently. The caller must ensure onlt one RoundTrip call is active
-// at once and that it completes before calling Close.
-//
-// RoundTrip is only available in round tripper mode.
-func (meek *MeekConn) RoundTrip(
-	ctx context.Context, endPoint string, requestBody []byte) ([]byte, error) {
-
-	if !meek.roundTripperOnly {
+// ObfuscatedRoundTrip is not safe for concurrent use, and Close must not be
+// called concurrently. The caller must ensure only one ObfuscatedRoundTrip
+// call is active at once and that it completes or is cancelled before calling
+// Close.
+func (meek *MeekConn) ObfuscatedRoundTrip(
+	requestCtx context.Context, endPoint string, requestBody []byte) ([]byte, error) {
+
+	if meek.mode != MeekModeObfuscatedRoundTrip {
 		return nil, errors.TraceNew("operation unsupported")
 		return nil, errors.TraceNew("operation unsupported")
 	}
 	}
 
 
@@ -707,32 +782,23 @@ func (meek *MeekConn) RoundTrip(
 
 
 	// Note:
 	// Note:
 	//
 	//
-	// - multiple, concurrent RoundTrip calls are unsafe due to the
-	//   setRequestContext calls in newRequest.
+	// - multiple, concurrent ObfuscatedRoundTrip calls are unsafe due to the
+	//   setDialerRequestContext calls in newRequest.
 	//
 	//
-	// - concurrent Close and RoundTrip calls are unsafe as Close
-	//   does not synchronize with RoundTrip before calling
-	//   meek.transport.CloseIdleConnections(), so resources could
-	//   be left open.
+	// - concurrent Close and ObfuscatedRoundTrip calls are unsafe as Close does
+	//   not synchronize with ObfuscatedRoundTrip before calling
+	//   meek.transport.CloseIdleConnections(), so resources could be left open.
 	//
 	//
-	// At this time, RoundTrip is used for tactics in Controller and
+	// At this time, ObfuscatedRoundTrip is used for tactics in Controller and
 	// the concurrency constraints are satisfied.
 	// the concurrency constraints are satisfied.
 
 
-	request, cancelFunc, err := meek.newRequest(
-		ctx, cookie, bytes.NewReader(requestBody), 0)
+	request, err := meek.newRequest(
+		requestCtx, cookie, bytes.NewReader(requestBody), 0)
 	if err != nil {
 	if err != nil {
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
-	defer cancelFunc()
 
 
-	// Workaround for h2quic.RoundTripper context issue. See comment in
-	// MeekConn.Close.
-	if meek.isQUIC {
-		go func() {
-			<-request.Context().Done()
-			meek.transport.CloseIdleConnections()
-		}()
-	}
+	meek.scheduleQUICCloseIdle(request)
 
 
 	response, err := meek.transport.RoundTrip(request)
 	response, err := meek.transport.RoundTrip(request)
 	if err == nil {
 	if err == nil {
@@ -753,10 +819,36 @@ func (meek *MeekConn) RoundTrip(
 	return responseBody, nil
 	return responseBody, nil
 }
 }
 
 
+// RoundTrip implements the http.RoundTripper interface. RoundTrip may only be
+// used when TLS and server certificate verification are configured. RoundTrip
+// does not implement any security or obfuscation at the HTTP layer.
+//
+// RoundTrip is not safe for concurrent use, and Close must not be called
+// concurrently. The caller must ensure only one RoundTrip call is active at
+// once and that it completes or is cancelled before calling Close.
+func (meek *MeekConn) RoundTrip(request *http.Request) (*http.Response, error) {
+
+	if meek.mode != MeekModePlaintextRoundTrip {
+		return nil, errors.TraceNew("operation unsupported")
+	}
+
+	requestCtx := request.Context()
+
+	// The setDialerRequestContext/CloseIdleConnections concurrency note in
+	// ObfuscatedRoundTrip applies to RoundTrip as well.
+
+	// Ensure dials are made within the request context.
+	meek.setDialerRequestContext(requestCtx)
+
+	meek.scheduleQUICCloseIdle(request)
+
+	return meek.transport.RoundTrip(request)
+}
+
 // Read reads data from the connection.
 // Read reads data from the connection.
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
-	if meek.roundTripperOnly {
+	if meek.mode != MeekModeRelay {
 		return 0, errors.TraceNew("operation unsupported")
 		return 0, errors.TraceNew("operation unsupported")
 	}
 	}
 	if meek.IsClosed() {
 	if meek.IsClosed() {
@@ -778,7 +870,7 @@ func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 // Write writes data to the connection.
 // Write writes data to the connection.
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 // net.Conn Deadlines are ignored. net.Conn concurrency semantics are supported.
 func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
 func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
-	if meek.roundTripperOnly {
+	if meek.mode != MeekModeRelay {
 		return 0, errors.TraceNew("operation unsupported")
 		return 0, errors.TraceNew("operation unsupported")
 	}
 	}
 	if meek.IsClosed() {
 	if meek.IsClosed() {
@@ -1016,43 +1108,25 @@ func (r *readCloseSignaller) AwaitClosed() bool {
 	return false
 	return false
 }
 }
 
 
-// newRequest performs common request setup for both relay and round
-// tripper modes.
+// newRequest performs common request setup for both MeekModeRelay and
+// MeekModeObfuscatedRoundTrip.
 //
 //
 // newRequest is not safe for concurrent calls due to its use of
 // newRequest is not safe for concurrent calls due to its use of
 // setRequestContext.
 // setRequestContext.
 //
 //
 // The caller must call the returned cancelFunc.
 // The caller must call the returned cancelFunc.
 func (meek *MeekConn) newRequest(
 func (meek *MeekConn) newRequest(
-	ctx context.Context,
+	requestCtx context.Context,
 	cookie *http.Cookie,
 	cookie *http.Cookie,
 	body io.Reader,
 	body io.Reader,
-	contentLength int) (*http.Request, context.CancelFunc, error) {
+	contentLength int) (*http.Request, error) {
 
 
-	var requestCtx context.Context
-	var cancelFunc context.CancelFunc
-
-	if ctx != nil {
-		requestCtx, cancelFunc = context.WithCancel(ctx)
-	} else {
-		// - meek.stopRunning() will abort a round trip in flight
-		// - round trip will abort if it exceeds timeout
-		requestCtx, cancelFunc = context.WithTimeout(
-			meek.runCtx,
-			meek.getCustomParameters().Duration(parameters.MeekRoundTripTimeout))
-	}
-
-	// Ensure dials are made within the current request context.
-	if meek.isQUIC {
-		meek.transport.(*quic.QUICTransporter).SetRequestContext(requestCtx)
-	} else if meek.cachedTLSDialer != nil {
-		meek.cachedTLSDialer.setRequestContext(requestCtx)
-	}
+	// Ensure dials are made within the request context.
+	meek.setDialerRequestContext(requestCtx)
 
 
 	request, err := http.NewRequest("POST", meek.url.String(), body)
 	request, err := http.NewRequest("POST", meek.url.String(), body)
 	if err != nil {
 	if err != nil {
-		cancelFunc()
-		return nil, nil, errors.Trace(err)
+		return nil, errors.Trace(err)
 	}
 	}
 
 
 	request = request.WithContext(requestCtx)
 	request = request.WithContext(requestCtx)
@@ -1072,7 +1146,30 @@ func (meek *MeekConn) newRequest(
 	}
 	}
 	request.AddCookie(cookie)
 	request.AddCookie(cookie)
 
 
-	return request, cancelFunc, nil
+	return request, nil
+}
+
+// setDialerRequestContext ensures that underlying TLS/QUIC dials operate
+// within the context of the request context. setDialerRequestContext must not
+// be called while another request is already in flight.
+func (meek *MeekConn) setDialerRequestContext(requestCtx context.Context) {
+	if meek.isQUIC {
+		meek.transport.(*quic.QUICTransporter).SetRequestContext(requestCtx)
+	} else if meek.cachedTLSDialer != nil {
+		meek.cachedTLSDialer.setRequestContext(requestCtx)
+	}
+}
+
+// Workaround for h2quic.RoundTripper context issue. See comment in
+// MeekConn.Close.
+func (meek *MeekConn) scheduleQUICCloseIdle(request *http.Request) {
+	requestCtx := request.Context()
+	if meek.isQUIC && requestCtx != context.Background() {
+		go func() {
+			<-requestCtx.Done()
+			meek.transport.CloseIdleConnections()
+		}()
+	}
 }
 }
 
 
 // relayRoundTrip configures and makes the actual HTTP POST request
 // relayRoundTrip configures and makes the actual HTTP POST request
@@ -1153,9 +1250,15 @@ func (meek *MeekConn) relayRoundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 			contentLength = sendBuffer.Len()
 			contentLength = sendBuffer.Len()
 		}
 		}
 
 
-		request, cancelFunc, err := meek.newRequest(
-			//lint:ignore SA1012 meek.newRequest expects/handles nil context
-			nil,
+		// - meek.stopRunning() will abort a round trip in flight
+		// - round trip will abort if it exceeds timeout
+		requestCtx, cancelFunc := context.WithTimeout(
+			meek.runCtx,
+			meek.getCustomParameters().Duration(parameters.MeekRoundTripTimeout))
+		defer cancelFunc()
+
+		request, err := meek.newRequest(
+			requestCtx,
 			nil,
 			nil,
 			requestBody,
 			requestBody,
 			contentLength)
 			contentLength)
@@ -1256,7 +1359,7 @@ func (meek *MeekConn) relayRoundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 			}
 			}
 		}
 		}
 
 
-		// Release context resources now.
+		// Release context resources immediately.
 		cancelFunc()
 		cancelFunc()
 
 
 		// Either the request failed entirely, or there was a failure
 		// Either the request failed entirely, or there was a failure
@@ -1383,6 +1486,10 @@ func makeMeekObfuscationValues(
 	redialTLSProbability float64,
 	redialTLSProbability float64,
 	err error) {
 	err error) {
 
 
+	if meekCookieEncryptionPublicKey == "" {
+		return nil, 0, 0, 0.0, errors.TraceNew("missing public key")
+	}
+
 	cookieData := &protocol.MeekCookieData{
 	cookieData := &protocol.MeekCookieData{
 		MeekProtocolVersion:  MEEK_PROTOCOL_VERSION,
 		MeekProtocolVersion:  MEEK_PROTOCOL_VERSION,
 		ClientTunnelProtocol: clientTunnelProtocol,
 		ClientTunnelProtocol: clientTunnelProtocol,
@@ -1418,7 +1525,8 @@ func makeMeekObfuscationValues(
 
 
 	maxPadding := p.Int(parameters.MeekCookieMaxPadding)
 	maxPadding := p.Int(parameters.MeekCookieMaxPadding)
 
 
-	// Obfuscate the encrypted data
+	// Obfuscate the encrypted data. NewClientObfuscator checks that
+	// meekObfuscatedKey isn't missing.
 	obfuscator, err := obfuscator.NewClientObfuscator(
 	obfuscator, err := obfuscator.NewClientObfuscator(
 		&obfuscator.ObfuscatorConfig{
 		&obfuscator.ObfuscatorConfig{
 			Keyword:         meekObfuscatedKey,
 			Keyword:         meekObfuscatedKey,

+ 101 - 0
psiphon/meekConn_test.go

@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2021, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"context"
+	"io/ioutil"
+	"net/http"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
+)
+
+// MeekModeRelay and MeekModeObfuscatedRoundTrip are tested via meek protocol
+// and tactics test cases.
+
+func TestMeekModePlaintextRoundTrip(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-meek-mode-plaintext-round-trip-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	serverName := "example.org"
+
+	rootCAsFileName,
+		rootCACertificatePin,
+		serverCertificatePin,
+		shutdown,
+		serverAddr,
+		dialer := initTestCertificatesAndWebServer(
+		t, testDataDirName, serverName)
+	defer shutdown()
+
+	params, err := parameters.NewParameters(nil)
+	if err != nil {
+		t.Fatalf("parameters.NewParameters failed: %v", err)
+	}
+
+	meekConfig := &MeekConfig{
+		Parameters:       params,
+		Mode:             MeekModePlaintextRoundTrip,
+		DialAddress:      serverAddr,
+		UseHTTPS:         true,
+		SNIServerName:    "not-" + serverName,
+		VerifyServerName: serverName,
+		VerifyPins:       []string{rootCACertificatePin, serverCertificatePin},
+	}
+
+	dialConfig := &DialConfig{
+		TrustedCACertificatesFilename: rootCAsFileName,
+		CustomDialer:                  dialer,
+	}
+
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancelFunc()
+
+	meekConn, err := DialMeek(ctx, meekConfig, dialConfig)
+	if err != nil {
+		t.Fatalf("DialMeek failed: %v", err)
+	}
+
+	client := &http.Client{
+		Transport: meekConn,
+	}
+
+	response, err := client.Get("https://" + serverAddr + "/")
+	if err != nil {
+		t.Fatalf("http.Client.Get failed: %v", err)
+	}
+	response.Body.Close()
+
+	if response.StatusCode != http.StatusOK {
+		t.Fatalf("unexpected response code: %v", response.StatusCode)
+	}
+
+	err = meekConn.Close()
+	if err != nil {
+		t.Fatalf("MeekConn.Close failed: %v", err)
+	}
+}

+ 10 - 7
psiphon/net.go

@@ -104,6 +104,13 @@ type DialConfig struct {
 	// proxy error. As the upstream proxy is user configured, the error message
 	// proxy error. As the upstream proxy is user configured, the error message
 	// may need to be relayed to the user.
 	// may need to be relayed to the user.
 	UpstreamProxyErrorCallback func(error)
 	UpstreamProxyErrorCallback func(error)
+
+	// CustomDialer overrides the dialer created by NewNetDialer/NewTCPDialer.
+	// When CustomDialer is set, all other DialConfig parameters are ignored by
+	// NewNetDialer/NewTCPDialer. Other DialConfig consumers may still reference
+	// other DialConfig parameters; for example MeekConfig still uses
+	// TrustedCACertificatesFilename.
+	CustomDialer common.Dialer
 }
 }
 
 
 // WithoutFragmentor returns a copy of the DialConfig with any fragmentor
 // WithoutFragmentor returns a copy of the DialConfig with any fragmentor
@@ -318,25 +325,21 @@ func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration
 }
 }
 
 
 // MakeUntunneledHTTPClient returns a net/http.Client which is configured to
 // MakeUntunneledHTTPClient returns a net/http.Client which is configured to
-// use custom dialing features -- including BindToDevice, etc. If
-// verifyLegacyCertificate is not nil, it's used for certificate verification.
+// use custom dialing features -- including BindToDevice, etc.
+//
 // The context is applied to underlying TCP dials. The caller is responsible
 // The context is applied to underlying TCP dials. The caller is responsible
 // for applying the context to requests made with the returned http.Client.
 // for applying the context to requests made with the returned http.Client.
 func MakeUntunneledHTTPClient(
 func MakeUntunneledHTTPClient(
 	ctx context.Context,
 	ctx context.Context,
 	config *Config,
 	config *Config,
 	untunneledDialConfig *DialConfig,
 	untunneledDialConfig *DialConfig,
-	verifyLegacyCertificate *x509.Certificate,
 	skipVerify bool) (*http.Client, error) {
 	skipVerify bool) (*http.Client, error) {
 
 
 	dialer := NewTCPDialer(untunneledDialConfig)
 	dialer := NewTCPDialer(untunneledDialConfig)
 
 
-	// Note: when verifyLegacyCertificate is not nil, some
-	// of the other CustomTLSConfig is overridden.
 	tlsConfig := &CustomTLSConfig{
 	tlsConfig := &CustomTLSConfig{
 		Parameters:                    config.GetParameters(),
 		Parameters:                    config.GetParameters(),
 		Dial:                          dialer,
 		Dial:                          dialer,
-		VerifyLegacyCertificate:       verifyLegacyCertificate,
 		UseDialAddrSNI:                true,
 		UseDialAddrSNI:                true,
 		SNIServerName:                 "",
 		SNIServerName:                 "",
 		SkipVerify:                    skipVerify,
 		SkipVerify:                    skipVerify,
@@ -430,7 +433,7 @@ func MakeDownloadHTTPClient(
 	} else {
 	} else {
 
 
 		httpClient, err = MakeUntunneledHTTPClient(
 		httpClient, err = MakeUntunneledHTTPClient(
-			ctx, config, untunneledDialConfig, nil, skipVerify)
+			ctx, config, untunneledDialConfig, skipVerify)
 		if err != nil {
 		if err != nil {
 			return nil, false, errors.Trace(err)
 			return nil, false, errors.Trace(err)
 		}
 		}

+ 32 - 12
psiphon/server/meek.go

@@ -340,7 +340,12 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 	// 3. A request to an endpoint. This meek connection is not for relaying
 	// 3. A request to an endpoint. This meek connection is not for relaying
 	// tunnel traffic. Instead, the request is handed off to a custom handler.
 	// tunnel traffic. Instead, the request is handed off to a custom handler.
 
 
-	sessionID, session, endPoint, clientIP, err := server.getSessionOrEndpoint(request, meekCookie)
+	sessionID,
+		session,
+		underlyingConn,
+		endPoint,
+		clientIP,
+		err := server.getSessionOrEndpoint(request, meekCookie)
 	if err != nil {
 	if err != nil {
 		// Debug since session cookie errors commonly occur during
 		// Debug since session cookie errors commonly occur during
 		// normal operation.
 		// normal operation.
@@ -390,6 +395,18 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 	session.lock.Lock()
 	session.lock.Lock()
 	defer session.lock.Unlock()
 	defer session.lock.Unlock()
 
 
+	// Count this metric once the lock is acquired, to avoid concurrent and
+	// potentially incorrect session.underlyingConn updates.
+	//
+	// It should never be the case that a new underlyingConn has the same
+	// value as the previous session.underlyingConn, as each is a net.Conn
+	// interface which includes a pointer, and the previous value cannot
+	// be garbage collected until session.underlyingConn is updated.
+	if session.underlyingConn != underlyingConn {
+		atomic.AddInt64(&session.metricUnderlyingConnCount, 1)
+		session.underlyingConn = underlyingConn
+	}
+
 	// If a newer request has arrived while waiting, discard this one.
 	// If a newer request has arrived while waiting, discard this one.
 	// Do not delay processing the newest request.
 	// Do not delay processing the newest request.
 	//
 	//
@@ -570,7 +587,9 @@ func checkRangeHeader(request *http.Request) (int, bool) {
 // mode; or the endpoint is returned when the meek cookie indicates endpoint
 // mode; or the endpoint is returned when the meek cookie indicates endpoint
 // mode.
 // mode.
 func (server *MeekServer) getSessionOrEndpoint(
 func (server *MeekServer) getSessionOrEndpoint(
-	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, string, string, error) {
+	request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, string, error) {
+
+	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
 
 
 	// Check for an existing session.
 	// Check for an existing session.
 
 
@@ -582,7 +601,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 		// TODO: can multiple http client connections using same session cookie
 		// TODO: can multiple http client connections using same session cookie
 		// cause race conditions on session struct?
 		// cause race conditions on session struct?
 		session.touch()
 		session.touch()
-		return existingSessionID, session, "", "", nil
+		return existingSessionID, session, underlyingConn, "", "", nil
 	}
 	}
 
 
 	// Determine the client remote address, which is used for geolocation
 	// Determine the client remote address, which is used for geolocation
@@ -612,7 +631,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	}
 	}
 
 
 	if server.rateLimit(clientIP) {
 	if server.rateLimit(clientIP) {
-		return "", nil, "", "", errors.TraceNew("rate limit exceeded")
+		return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
 	}
 	}
 
 
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// The session is new (or expired). Treat the cookie value as a new meek
@@ -620,7 +639,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
 	payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
 	if err != nil {
 	if err != nil {
-		return "", nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", "", errors.Trace(err)
 	}
 	}
 
 
 	// Note: this meek server ignores legacy values PsiphonClientSessionId
 	// Note: this meek server ignores legacy values PsiphonClientSessionId
@@ -629,7 +648,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	err = json.Unmarshal(payloadJSON, &clientSessionData)
 	if err != nil {
 	if err != nil {
-		return "", nil, "", "", errors.Trace(err)
+		return "", nil, nil, "", "", errors.Trace(err)
 	}
 	}
 
 
 	// Handle endpoints before enforcing CheckEstablishTunnels.
 	// Handle endpoints before enforcing CheckEstablishTunnels.
@@ -637,7 +656,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// handled by servers which would otherwise reject new tunnels.
 	// handled by servers which would otherwise reject new tunnels.
 
 
 	if clientSessionData.EndPoint != "" {
 	if clientSessionData.EndPoint != "" {
-		return "", nil, clientSessionData.EndPoint, clientIP, nil
+		return "", nil, nil, clientSessionData.EndPoint, clientIP, nil
 	}
 	}
 
 
 	// Don't create new sessions when not establishing. A subsequent SSH handshake
 	// Don't create new sessions when not establishing. A subsequent SSH handshake
@@ -645,7 +664,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	if server.support.TunnelServer != nil &&
 	if server.support.TunnelServer != nil &&
 		!server.support.TunnelServer.CheckEstablishTunnels() {
 		!server.support.TunnelServer.CheckEstablishTunnels() {
-		return "", nil, "", "", errors.TraceNew("not establishing tunnels")
+		return "", nil, nil, "", "", errors.TraceNew("not establishing tunnels")
 	}
 	}
 
 
 	// Create a new session
 	// Create a new session
@@ -664,8 +683,6 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	session.touch()
 	session.touch()
 
 
-	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
-
 	// Create a new meek conn that will relay the payload
 	// Create a new meek conn that will relay the payload
 	// between meek request/responses and the tunnel server client
 	// between meek request/responses and the tunnel server client
 	// handler. The client IP is also used to initialize the
 	// handler. The client IP is also used to initialize the
@@ -697,7 +714,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 	if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
 		sessionID, err = makeMeekSessionID()
 		sessionID, err = makeMeekSessionID()
 		if err != nil {
 		if err != nil {
-			return "", nil, "", "", errors.Trace(err)
+			return "", nil, nil, "", "", errors.Trace(err)
 		}
 		}
 	}
 	}
 
 
@@ -709,7 +726,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// will close when session.delete calls Close() on the meekConn.
 	// will close when session.delete calls Close() on the meekConn.
 	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 	server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
 
 
-	return sessionID, session, "", "", nil
+	return sessionID, session, underlyingConn, "", "", nil
 }
 }
 
 
 func (server *MeekServer) rateLimit(clientIP string) bool {
 func (server *MeekServer) rateLimit(clientIP string) bool {
@@ -1149,8 +1166,10 @@ type meekSession struct {
 	metricPeakCachedResponseSize     int64
 	metricPeakCachedResponseSize     int64
 	metricPeakCachedResponseHitSize  int64
 	metricPeakCachedResponseHitSize  int64
 	metricCachedResponseMissPosition int64
 	metricCachedResponseMissPosition int64
+	metricUnderlyingConnCount        int64
 	lock                             sync.Mutex
 	lock                             sync.Mutex
 	deleted                          bool
 	deleted                          bool
+	underlyingConn                   net.Conn
 	clientConn                       *meekConn
 	clientConn                       *meekConn
 	meekProtocolVersion              int
 	meekProtocolVersion              int
 	sessionIDSent                    bool
 	sessionIDSent                    bool
@@ -1222,6 +1241,7 @@ func (session *meekSession) GetMetrics() common.LogFields {
 	logFields["meek_peak_cached_response_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseSize)
 	logFields["meek_peak_cached_response_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseSize)
 	logFields["meek_peak_cached_response_hit_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseHitSize)
 	logFields["meek_peak_cached_response_hit_size"] = atomic.LoadInt64(&session.metricPeakCachedResponseHitSize)
 	logFields["meek_cached_response_miss_position"] = atomic.LoadInt64(&session.metricCachedResponseMissPosition)
 	logFields["meek_cached_response_miss_position"] = atomic.LoadInt64(&session.metricCachedResponseMissPosition)
+	logFields["meek_underlying_connection_count"] = atomic.LoadInt64(&session.metricUnderlyingConnCount)
 	return logFields
 	return logFields
 }
 }
 
 

+ 13 - 1
psiphon/server/meek_test.go

@@ -28,11 +28,13 @@ import (
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
 	"sync"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"syscall"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/crypto/nacl/box"
@@ -253,7 +255,10 @@ func TestMeekResiliency(t *testing.T) {
 
 
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup := new(sync.WaitGroup)
 
 
+	var serverClientConn atomic.Value
+
 	clientHandler := func(_ string, conn net.Conn) {
 	clientHandler := func(_ string, conn net.Conn) {
+		serverClientConn.Store(conn)
 		name := "server"
 		name := "server"
 		relayWaitGroup.Add(1)
 		relayWaitGroup.Add(1)
 		go func() {
 		go func() {
@@ -342,7 +347,6 @@ func TestMeekResiliency(t *testing.T) {
 	// Relay data through meek while interrupting underlying TCP connections
 	// Relay data through meek while interrupting underlying TCP connections
 
 
 	name := "client"
 	name := "client"
-
 	relayWaitGroup.Add(1)
 	relayWaitGroup.Add(1)
 	go func() {
 	go func() {
 		defer relayWaitGroup.Done()
 		defer relayWaitGroup.Done()
@@ -357,6 +361,14 @@ func TestMeekResiliency(t *testing.T) {
 
 
 	relayWaitGroup.Wait()
 	relayWaitGroup.Wait()
 
 
+	// Check for multiple underlying connections
+
+	metrics := serverClientConn.Load().(common.MetricsSource).GetMetrics()
+	count := metrics["meek_underlying_connection_count"].(int64)
+	if count <= 1 {
+		t.Fatalf("unexpected meek_underlying_connection_count: %d", count)
+	}
+
 	// Graceful shutdown
 	// Graceful shutdown
 
 
 	clientConn.Close()
 	clientConn.Close()

+ 5 - 1
psiphon/server/replay.go

@@ -74,10 +74,14 @@ type replayParameters struct {
 
 
 // NewReplayCache creates a new ReplayCache.
 // NewReplayCache creates a new ReplayCache.
 func NewReplayCache(support *SupportServices) *ReplayCache {
 func NewReplayCache(support *SupportServices) *ReplayCache {
+	// Cache TTL may vary based on tactics filtering, so each cache.Add must set
+	// the entry TTL.
 	return &ReplayCache{
 	return &ReplayCache{
 		support: support,
 		support: support,
 		cache: lrucache.NewWithLRU(
 		cache: lrucache.NewWithLRU(
-			0, REPLAY_CACHE_CLEANUP_INTERVAL, REPLAY_CACHE_MAX_ENTRIES),
+			lrucache.NoExpiration,
+			REPLAY_CACHE_CLEANUP_INTERVAL,
+			REPLAY_CACHE_MAX_ENTRIES),
 		metrics: &replayCacheMetrics{},
 		metrics: &replayCacheMetrics{},
 	}
 	}
 }
 }

+ 1 - 0
psiphon/server/server_test.go

@@ -1479,6 +1479,7 @@ func checkExpectedServerTunnelLogFields(
 			"meek_transformed_host_name",
 			"meek_transformed_host_name",
 			"meek_cookie_size",
 			"meek_cookie_size",
 			"meek_limit_request",
 			"meek_limit_request",
+			"meek_underlying_connection_count",
 			tactics.APPLIED_TACTICS_TAG_PARAMETER_NAME,
 			tactics.APPLIED_TACTICS_TAG_PARAMETER_NAME,
 		} {
 		} {
 			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
 			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {

+ 1 - 1
psiphon/tactics.go

@@ -259,7 +259,7 @@ func fetchTactics(
 		dialParams.TunnelProtocol,
 		dialParams.TunnelProtocol,
 		serverEntry.TacticsRequestPublicKey,
 		serverEntry.TacticsRequestPublicKey,
 		serverEntry.TacticsRequestObfuscatedKey,
 		serverEntry.TacticsRequestObfuscatedKey,
-		meekConn.RoundTrip)
+		meekConn.ObfuscatedRoundTrip)
 	if err != nil {
 	if err != nil {
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}

+ 375 - 268
psiphon/tlsDialer.go

@@ -47,20 +47,20 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 */
 
 
-// Based on https://github.com/getlantern/tlsdialer (http://gopkg.in/getlantern/tlsdialer.v1)
-// which itself is a "Fork of crypto/tls.Dial and DialWithDialer"
+// Originally based on https://gopkg.in/getlantern/tlsdialer.v1.
 
 
 package psiphon
 package psiphon
 
 
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
+	"crypto/sha256"
 	"crypto/x509"
 	"crypto/x509"
+	"encoding/base64"
 	"encoding/hex"
 	"encoding/hex"
 	std_errors "errors"
 	std_errors "errors"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
-	"time"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -71,16 +71,16 @@ import (
 	utls "github.com/refraction-networking/utls"
 	utls "github.com/refraction-networking/utls"
 )
 )
 
 
-// CustomTLSConfig contains parameters to determine the behavior
-// of CustomTLSDial.
+// CustomTLSConfig specifies the parameters for a CustomTLSDial, supporting
+// many TLS-related network obfuscation mechanisms.
 type CustomTLSConfig struct {
 type CustomTLSConfig struct {
 
 
 	// Parameters is the active set of parameters.Parameters to use for the TLS
 	// Parameters is the active set of parameters.Parameters to use for the TLS
-	// dial.
+	// dial. Must not be nil.
 	Parameters *parameters.Parameters
 	Parameters *parameters.Parameters
 
 
-	// Dial is the network connection dialer. TLS is layered on
-	// top of a new network connection created with dialer.
+	// Dial is the network connection dialer. TLS is layered on top of a new
+	// network connection created with dialer. Must not be nil.
 	Dial common.Dialer
 	Dial common.Dialer
 
 
 	// DialAddr overrides the "addr" input to Dial when specified
 	// DialAddr overrides the "addr" input to Dial when specified
@@ -98,15 +98,35 @@ type CustomTLSConfig struct {
 	// SNIServerName is ignored when UseDialAddrSNI is true.
 	// SNIServerName is ignored when UseDialAddrSNI is true.
 	SNIServerName string
 	SNIServerName string
 
 
-	// SkipVerify completely disables server certificate verification.
-	SkipVerify bool
+	// VerifyServerName specifies a domain name that must appear in the server
+	// certificate. When specified, certificate verification checks for
+	// VerifyServerName in the server certificate, in place of the dial or SNI
+	// hostname.
+	VerifyServerName string
+
+	// VerifyPins specifies one or more certificate pin values, one of which must
+	// appear in the verified server certificate chain. A pin value is the
+	// base64-encoded SHA2 digest of a certificate's public key. When specified,
+	// at least one pin must match at least one certificate in the chain, at any
+	// position; e.g., the root CA may be pinned, or the server certificate,
+	// etc.
+	VerifyPins []string
 
 
 	// VerifyLegacyCertificate is a special case self-signed server
 	// VerifyLegacyCertificate is a special case self-signed server
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate chain. Just checks that the server presented the
 	// certificate chain. Just checks that the server presented the
-	// specified certificate. SNI is disbled when this is set.
+	// specified certificate.
+	//
+	// When VerifyLegacyCertificate is set, none of VerifyServerName, VerifyPins,
+	// SkipVerify may be set.
 	VerifyLegacyCertificate *x509.Certificate
 	VerifyLegacyCertificate *x509.Certificate
 
 
+	// SkipVerify completely disables server certificate verification.
+	//
+	// When SkipVerify is set, none of VerifyServerName, VerifyPins,
+	// VerifyLegacyCertificate may be set.
+	SkipVerify bool
+
 	// TLSProfile specifies a particular indistinguishable TLS profile to use for
 	// TLSProfile specifies a particular indistinguishable TLS profile to use for
 	// the TLS dial. Setting TLSProfile allows the caller to pin the selection so
 	// the TLS dial. Setting TLSProfile allows the caller to pin the selection so
 	// all TLS connections in a certain context (e.g. a single meek connection)
 	// all TLS connections in a certain context (e.g. a single meek connection)
@@ -159,209 +179,6 @@ func (config *CustomTLSConfig) EnableClientSessionCache() {
 	}
 	}
 }
 }
 
 
-// SelectTLSProfile picks a TLS profile at random from the available candidates.
-func SelectTLSProfile(
-	requireTLS12SessionTickets bool,
-	isFronted bool,
-	frontingProviderID string,
-	p parameters.ParametersAccessor) string {
-
-	// Two TLS profile lists are constructed, subject to limit constraints:
-	// stock, fixed parrots (non-randomized SupportedTLSProfiles) and custom
-	// parrots (CustomTLSProfileNames); and randomized. If one list is empty, the
-	// non-empty list is used. Otherwise SelectRandomizedTLSProfileProbability
-	// determines which list is used.
-	//
-	// Note that LimitTLSProfiles is not applied to CustomTLSProfiles; the
-	// presence of a candidate in CustomTLSProfiles is treated as explicit
-	// enabling.
-	//
-	// UseOnlyCustomTLSProfiles may be used to disable all stock TLS profiles and
-	// use only CustomTLSProfiles; UseOnlyCustomTLSProfiles is ignored if
-	// CustomTLSProfiles is empty.
-	//
-	// For fronted servers, DisableFrontingProviderTLSProfiles may be used
-	// to disable TLS profiles which are incompatible with the TLS stack used
-	// by the front. For example, if a utls parrot doesn't fully support all
-	// of the capabilities in the ClientHello. Unlike the LimitTLSProfiles case,
-	// DisableFrontingProviderTLSProfiles may disable CustomTLSProfiles.
-
-	limitTLSProfiles := p.TLSProfiles(parameters.LimitTLSProfiles)
-	var disableTLSProfiles protocol.TLSProfiles
-
-	if isFronted && frontingProviderID != "" {
-		disableTLSProfiles = p.LabeledTLSProfiles(
-			parameters.DisableFrontingProviderTLSProfiles, frontingProviderID)
-	}
-
-	randomizedTLSProfiles := make([]string, 0)
-	parrotTLSProfiles := make([]string, 0)
-
-	for _, tlsProfile := range p.CustomTLSProfileNames() {
-		if !common.Contains(disableTLSProfiles, tlsProfile) {
-			parrotTLSProfiles = append(parrotTLSProfiles, tlsProfile)
-		}
-	}
-
-	useOnlyCustomTLSProfiles := p.Bool(parameters.UseOnlyCustomTLSProfiles)
-	if useOnlyCustomTLSProfiles && len(parrotTLSProfiles) == 0 {
-		useOnlyCustomTLSProfiles = false
-	}
-
-	if !useOnlyCustomTLSProfiles {
-		for _, tlsProfile := range protocol.SupportedTLSProfiles {
-
-			if len(limitTLSProfiles) > 0 &&
-				!common.Contains(limitTLSProfiles, tlsProfile) {
-				continue
-			}
-
-			if common.Contains(disableTLSProfiles, tlsProfile) {
-				continue
-			}
-
-			// requireTLS12SessionTickets is specified for
-			// UNFRONTED-MEEK-SESSION-TICKET-OSSH, a protocol which depends on using
-			// obfuscated session tickets to ensure that the server doesn't send its
-			// certificate in the TLS handshake. TLS 1.2 profiles which omit session
-			// tickets should not be selected. As TLS 1.3 encrypts the server
-			// certificate message, there's no exclusion for TLS 1.3.
-
-			if requireTLS12SessionTickets &&
-				protocol.TLS12ProfileOmitsSessionTickets(tlsProfile) {
-				continue
-			}
-
-			if protocol.TLSProfileIsRandomized(tlsProfile) {
-				randomizedTLSProfiles = append(randomizedTLSProfiles, tlsProfile)
-			} else {
-				parrotTLSProfiles = append(parrotTLSProfiles, tlsProfile)
-			}
-		}
-	}
-
-	if len(randomizedTLSProfiles) > 0 &&
-		(len(parrotTLSProfiles) == 0 ||
-			p.WeightedCoinFlip(parameters.SelectRandomizedTLSProfileProbability)) {
-
-		return randomizedTLSProfiles[prng.Intn(len(randomizedTLSProfiles))]
-	}
-
-	if len(parrotTLSProfiles) == 0 {
-		return ""
-	}
-
-	return parrotTLSProfiles[prng.Intn(len(parrotTLSProfiles))]
-}
-
-func getUTLSClientHelloID(
-	p parameters.ParametersAccessor,
-	tlsProfile string) (utls.ClientHelloID, *utls.ClientHelloSpec, error) {
-
-	switch tlsProfile {
-	case protocol.TLS_PROFILE_IOS_111:
-		return utls.HelloIOS_11_1, nil, nil
-	case protocol.TLS_PROFILE_IOS_121:
-		return utls.HelloIOS_12_1, nil, nil
-	case protocol.TLS_PROFILE_CHROME_58:
-		return utls.HelloChrome_58, nil, nil
-	case protocol.TLS_PROFILE_CHROME_62:
-		return utls.HelloChrome_62, nil, nil
-	case protocol.TLS_PROFILE_CHROME_70:
-		return utls.HelloChrome_70, nil, nil
-	case protocol.TLS_PROFILE_CHROME_72:
-		return utls.HelloChrome_72, nil, nil
-	case protocol.TLS_PROFILE_CHROME_83:
-		return utls.HelloChrome_83, nil, nil
-	case protocol.TLS_PROFILE_FIREFOX_55:
-		return utls.HelloFirefox_55, nil, nil
-	case protocol.TLS_PROFILE_FIREFOX_56:
-		return utls.HelloFirefox_56, nil, nil
-	case protocol.TLS_PROFILE_FIREFOX_65:
-		return utls.HelloFirefox_65, nil, nil
-	case protocol.TLS_PROFILE_RANDOMIZED:
-		return utls.HelloRandomized, nil, nil
-	}
-
-	// utls.HelloCustom with a utls.ClientHelloSpec is used for
-	// CustomTLSProfiles.
-
-	customTLSProfile := p.CustomTLSProfile(tlsProfile)
-	if customTLSProfile == nil {
-		return utls.HelloCustom,
-			nil,
-			errors.Tracef("unknown TLS profile: %s", tlsProfile)
-	}
-
-	utlsClientHelloSpec, err := customTLSProfile.GetClientHelloSpec()
-	if err != nil {
-		return utls.ClientHelloID{}, nil, errors.Trace(err)
-	}
-
-	return utls.HelloCustom, utlsClientHelloSpec, nil
-}
-
-func getClientHelloVersion(
-	utlsClientHelloID utls.ClientHelloID,
-	utlsClientHelloSpec *utls.ClientHelloSpec) (string, error) {
-
-	switch utlsClientHelloID {
-
-	case utls.HelloIOS_11_1, utls.HelloIOS_12_1, utls.HelloChrome_58,
-		utls.HelloChrome_62, utls.HelloFirefox_55, utls.HelloFirefox_56:
-		return protocol.TLS_VERSION_12, nil
-
-	case utls.HelloChrome_70, utls.HelloChrome_72, utls.HelloChrome_83,
-		utls.HelloFirefox_65, utls.HelloGolang:
-		return protocol.TLS_VERSION_13, nil
-	}
-
-	// As utls.HelloRandomized/Custom may be either TLS 1.2 or TLS 1.3, we cannot
-	// perform a simple ClientHello ID check. BuildHandshakeState is run, which
-	// constructs the entire ClientHello.
-	//
-	// Assumes utlsClientHelloID.Seed has been set; otherwise the result is
-	// ephemeral.
-	//
-	// BenchmarkRandomizedGetClientHelloVersion indicates that this operation
-	// takes on the order of 0.05ms and allocates ~8KB for randomized client
-	// hellos.
-
-	conn := utls.UClient(
-		nil,
-		&utls.Config{InsecureSkipVerify: true},
-		utlsClientHelloID)
-
-	if utlsClientHelloSpec != nil {
-		err := conn.ApplyPreset(utlsClientHelloSpec)
-		if err != nil {
-			return "", errors.Trace(err)
-		}
-	}
-
-	err := conn.BuildHandshakeState()
-	if err != nil {
-		return "", errors.Trace(err)
-	}
-
-	for _, v := range conn.HandshakeState.Hello.SupportedVersions {
-		if v == utls.VersionTLS13 {
-			return protocol.TLS_VERSION_13, nil
-		}
-	}
-
-	return protocol.TLS_VERSION_12, nil
-}
-
-func IsTLSConnUsingHTTP2(conn net.Conn) bool {
-	if c, ok := conn.(*utls.UConn); ok {
-		state := c.ConnectionState()
-		return state.NegotiatedProtocolIsMutual &&
-			state.NegotiatedProtocol == "h2"
-	}
-	return false
-}
-
 // NewCustomTLSDialer creates a new dialer based on CustomTLSDial.
 // NewCustomTLSDialer creates a new dialer based on CustomTLSDial.
 func NewCustomTLSDialer(config *CustomTLSConfig) common.Dialer {
 func NewCustomTLSDialer(config *CustomTLSConfig) common.Dialer {
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -369,21 +186,28 @@ func NewCustomTLSDialer(config *CustomTLSConfig) common.Dialer {
 	}
 	}
 }
 }
 
 
-// CustomTLSDial is a customized replacement for tls.Dial.
-// Based on tlsdialer.DialWithDialer which is based on crypto/tls.DialWithDialer.
-//
-// To ensure optimal TLS profile selection when using CustomTLSDial for tunnel
-// protocols, call SelectTLSProfile first and set its result into
-// config.TLSProfile.
+// CustomTLSDial dials a new TLS connection using the parameters set in
+// CustomTLSConfig.
 //
 //
-// tlsdialer comment:
-//   Note - if sendServerName is false, the VerifiedChains field on the
-//   connection's ConnectionState will never get populated.
+// The dial aborts if ctx becomes Done before the dial completes.
 func CustomTLSDial(
 func CustomTLSDial(
 	ctx context.Context,
 	ctx context.Context,
 	network, addr string,
 	network, addr string,
 	config *CustomTLSConfig) (net.Conn, error) {
 	config *CustomTLSConfig) (net.Conn, error) {
 
 
+	if (config.SkipVerify &&
+		(config.VerifyLegacyCertificate != nil ||
+			len(config.VerifyServerName) > 0 ||
+			len(config.VerifyPins) > 0)) ||
+
+		(config.VerifyLegacyCertificate != nil &&
+			(config.SkipVerify ||
+				len(config.VerifyServerName) > 0 ||
+				len(config.VerifyPins) > 0)) {
+
+		return nil, errors.TraceNew("incompatible certification verification parameters")
+	}
+
 	p := config.Parameters.Get()
 	p := config.Parameters.Get()
 
 
 	dialAddr := addr
 	dialAddr := addr
@@ -402,51 +226,116 @@ func CustomTLSDial(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
-	selectedTLSProfile := config.TLSProfile
+	var tlsConfigRootCAs *x509.CertPool
+	if !config.SkipVerify &&
+		config.VerifyLegacyCertificate == nil &&
+		config.TrustedCACertificatesFilename != "" {
 
 
-	if selectedTLSProfile == "" {
-		selectedTLSProfile = SelectTLSProfile(false, false, "", p)
+		tlsConfigRootCAs = x509.NewCertPool()
+		certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		tlsConfigRootCAs.AppendCertsFromPEM(certData)
 	}
 	}
 
 
+	// In some cases, config.SkipVerify is false, but
+	// utls.Config.InsecureSkipVerify will be set to true to disable verification
+	// in utls that will otherwise fail: when SNI is omitted, and when
+	// VerifyServerName differs from SNI. In these cases, the certificate chain
+	// is verified in VerifyPeerCertificate.
+
 	tlsConfigInsecureSkipVerify := false
 	tlsConfigInsecureSkipVerify := false
 	tlsConfigServerName := ""
 	tlsConfigServerName := ""
+	verifyServerName := hostname
 
 
 	if config.SkipVerify {
 	if config.SkipVerify {
 		tlsConfigInsecureSkipVerify = true
 		tlsConfigInsecureSkipVerify = true
 	}
 	}
 
 
 	if config.UseDialAddrSNI {
 	if config.UseDialAddrSNI {
+
+		// Set SNI to match the dial hostname. This is the standard case.
 		tlsConfigServerName = hostname
 		tlsConfigServerName = hostname
-	} else if config.SNIServerName != "" && config.VerifyLegacyCertificate == nil {
-		// Set the ServerName and rely on the usual logic in
-		// tls.Conn.Handshake() to do its verification.
-		// Note: Go TLS will automatically omit this ServerName when it's an IP address
+
+	} else if config.SNIServerName != "" {
+
+		// Set a custom SNI value. If this value doesn't match the server
+		// certificate, SkipVerify and/or VerifyServerName may need to be
+		// configured; but by itself this case doesn't necessarily require
+		// custom certificate verification.
 		tlsConfigServerName = config.SNIServerName
 		tlsConfigServerName = config.SNIServerName
+
 	} else {
 	} else {
-		// No SNI.
-		// Disable verification in tls.Conn.Handshake().  We'll verify manually
-		// after handshaking
+
+		// Omit SNI. If SkipVerify is not set, this case requires custom certificate
+		// verification, which will check that the server certificate matches either
+		// the dial hostname or VerifyServerName, as if the SNI were set to one of
+		// those values.
 		tlsConfigInsecureSkipVerify = true
 		tlsConfigInsecureSkipVerify = true
 	}
 	}
 
 
-	var tlsRootCAs *x509.CertPool
+	// When VerifyServerName does not match the SNI, custom certificate
+	// verification is necessary.
+	if config.VerifyServerName != "" && config.VerifyServerName != tlsConfigServerName {
+		verifyServerName = config.VerifyServerName
+		tlsConfigInsecureSkipVerify = true
+	}
 
 
-	if !config.SkipVerify &&
-		config.VerifyLegacyCertificate == nil &&
-		config.TrustedCACertificatesFilename != "" {
+	// With the VerifyPeerCertificate callback, we perform any custom certificate
+	// verification at the same point in the TLS handshake as standard utls
+	// verification; and abort the handshake at the same point, if custom
+	// verification fails.
+	var tlsConfigVerifyPeerCertificate func([][]byte, [][]*x509.Certificate) error
+	if !config.SkipVerify {
+		tlsConfigVerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
 
 
-		tlsRootCAs = x509.NewCertPool()
-		certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
-		if err != nil {
-			return nil, errors.Trace(err)
+			if config.VerifyLegacyCertificate != nil {
+				return verifyLegacyCertificate(
+					rawCerts, config.VerifyLegacyCertificate)
+			}
+
+			if tlsConfigInsecureSkipVerify {
+
+				// Limitation: this verification path does not set the utls.Conn's
+				// ConnectionState certificate information.
+
+				if len(verifiedChains) > 0 {
+					return errors.TraceNew("unexpected verified chains")
+				}
+				var err error
+				verifiedChains, err = verifyServerCertificate(
+					tlsConfigRootCAs, rawCerts, verifyServerName)
+				if err != nil {
+					return errors.Trace(err)
+				}
+			}
+
+			if len(config.VerifyPins) > 0 {
+				err := verifyCertificatePins(
+					config.VerifyPins, verifiedChains)
+				if err != nil {
+					return errors.Trace(err)
+				}
+			}
+
+			return nil
 		}
 		}
-		tlsRootCAs.AppendCertsFromPEM(certData)
 	}
 	}
 
 
+	// Note: utls will automatically omit SNI when ServerName is an IP address.
+
 	tlsConfig := &utls.Config{
 	tlsConfig := &utls.Config{
-		RootCAs:            tlsRootCAs,
-		InsecureSkipVerify: tlsConfigInsecureSkipVerify,
-		ServerName:         tlsConfigServerName,
+		RootCAs:               tlsConfigRootCAs,
+		InsecureSkipVerify:    tlsConfigInsecureSkipVerify,
+		ServerName:            tlsConfigServerName,
+		VerifyPeerCertificate: tlsConfigVerifyPeerCertificate,
+	}
+
+	selectedTLSProfile := config.TLSProfile
+
+	if selectedTLSProfile == "" {
+		selectedTLSProfile = SelectTLSProfile(false, false, "", p)
 	}
 	}
 
 
 	utlsClientHelloID, utlsClientHelloSpec, err := getUTLSClientHelloID(
 	utlsClientHelloID, utlsClientHelloSpec, err := getUTLSClientHelloID(
@@ -697,16 +586,6 @@ func CustomTLSDial(
 		<-resultChannel
 		<-resultChannel
 	}
 	}
 
 
-	if err == nil && !config.SkipVerify && tlsConfigInsecureSkipVerify {
-
-		if config.VerifyLegacyCertificate != nil {
-			err = verifyLegacyCertificate(conn, config.VerifyLegacyCertificate)
-		} else {
-			// Manually verify certificates
-			err = verifyServerCerts(conn, hostname)
-		}
-	}
-
 	if err != nil {
 	if err != nil {
 		rawConn.Close()
 		rawConn.Close()
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
@@ -715,24 +594,33 @@ func CustomTLSDial(
 	return conn, nil
 	return conn, nil
 }
 }
 
 
-func verifyLegacyCertificate(conn *utls.UConn, expectedCertificate *x509.Certificate) error {
-	certs := conn.ConnectionState().PeerCertificates
-	if len(certs) < 1 {
-		return errors.TraceNew("no certificate to verify")
+func verifyLegacyCertificate(rawCerts [][]byte, expectedCertificate *x509.Certificate) error {
+	if len(rawCerts) < 1 {
+		return errors.TraceNew("missing certificate")
 	}
 	}
-	if !bytes.Equal(certs[0].Raw, expectedCertificate.Raw) {
+	if !bytes.Equal(rawCerts[0], expectedCertificate.Raw) {
 		return errors.TraceNew("unexpected certificate")
 		return errors.TraceNew("unexpected certificate")
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
-func verifyServerCerts(conn *utls.UConn, hostname string) error {
-	certs := conn.ConnectionState().PeerCertificates
+func verifyServerCertificate(
+	rootCAs *x509.CertPool, rawCerts [][]byte, verifyServerName string) ([][]*x509.Certificate, error) {
+
+	// This duplicates the verification logic in utls (and standard crypto/tls).
+
+	certs := make([]*x509.Certificate, len(rawCerts))
+	for i, rawCert := range rawCerts {
+		cert, err := x509.ParseCertificate(rawCert)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		certs[i] = cert
+	}
 
 
 	opts := x509.VerifyOptions{
 	opts := x509.VerifyOptions{
-		Roots:         nil, // Use host's root CAs
-		CurrentTime:   time.Now(),
-		DNSName:       hostname,
+		Roots:         rootCAs,
+		DNSName:       verifyServerName,
 		Intermediates: x509.NewCertPool(),
 		Intermediates: x509.NewCertPool(),
 	}
 	}
 
 
@@ -743,11 +631,230 @@ func verifyServerCerts(conn *utls.UConn, hostname string) error {
 		opts.Intermediates.AddCert(cert)
 		opts.Intermediates.AddCert(cert)
 	}
 	}
 
 
-	_, err := certs[0].Verify(opts)
+	verifiedChains, err := certs[0].Verify(opts)
 	if err != nil {
 	if err != nil {
-		return errors.Trace(err)
+		return nil, errors.Trace(err)
 	}
 	}
-	return nil
+
+	return verifiedChains, nil
+}
+
+func verifyCertificatePins(pins []string, verifiedChains [][]*x509.Certificate) error {
+	for _, chain := range verifiedChains {
+		for _, cert := range chain {
+			publicKeyDigest := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
+			expectedPin := base64.StdEncoding.EncodeToString(publicKeyDigest[:])
+			if common.Contains(pins, expectedPin) {
+				// Return success on the first match of any certificate public key to any
+				// pin.
+				return nil
+			}
+		}
+	}
+	return errors.TraceNew("no pin found")
+}
+
+func IsTLSConnUsingHTTP2(conn net.Conn) bool {
+	if c, ok := conn.(*utls.UConn); ok {
+		state := c.ConnectionState()
+		return state.NegotiatedProtocolIsMutual &&
+			state.NegotiatedProtocol == "h2"
+	}
+	return false
+}
+
+// SelectTLSProfile picks a TLS profile at random from the available candidates.
+func SelectTLSProfile(
+	requireTLS12SessionTickets bool,
+	isFronted bool,
+	frontingProviderID string,
+	p parameters.ParametersAccessor) string {
+
+	// Two TLS profile lists are constructed, subject to limit constraints:
+	// stock, fixed parrots (non-randomized SupportedTLSProfiles) and custom
+	// parrots (CustomTLSProfileNames); and randomized. If one list is empty, the
+	// non-empty list is used. Otherwise SelectRandomizedTLSProfileProbability
+	// determines which list is used.
+	//
+	// Note that LimitTLSProfiles is not applied to CustomTLSProfiles; the
+	// presence of a candidate in CustomTLSProfiles is treated as explicit
+	// enabling.
+	//
+	// UseOnlyCustomTLSProfiles may be used to disable all stock TLS profiles and
+	// use only CustomTLSProfiles; UseOnlyCustomTLSProfiles is ignored if
+	// CustomTLSProfiles is empty.
+	//
+	// For fronted servers, DisableFrontingProviderTLSProfiles may be used
+	// to disable TLS profiles which are incompatible with the TLS stack used
+	// by the front. For example, if a utls parrot doesn't fully support all
+	// of the capabilities in the ClientHello. Unlike the LimitTLSProfiles case,
+	// DisableFrontingProviderTLSProfiles may disable CustomTLSProfiles.
+
+	limitTLSProfiles := p.TLSProfiles(parameters.LimitTLSProfiles)
+	var disableTLSProfiles protocol.TLSProfiles
+
+	if isFronted && frontingProviderID != "" {
+		disableTLSProfiles = p.LabeledTLSProfiles(
+			parameters.DisableFrontingProviderTLSProfiles, frontingProviderID)
+	}
+
+	randomizedTLSProfiles := make([]string, 0)
+	parrotTLSProfiles := make([]string, 0)
+
+	for _, tlsProfile := range p.CustomTLSProfileNames() {
+		if !common.Contains(disableTLSProfiles, tlsProfile) {
+			parrotTLSProfiles = append(parrotTLSProfiles, tlsProfile)
+		}
+	}
+
+	useOnlyCustomTLSProfiles := p.Bool(parameters.UseOnlyCustomTLSProfiles)
+	if useOnlyCustomTLSProfiles && len(parrotTLSProfiles) == 0 {
+		useOnlyCustomTLSProfiles = false
+	}
+
+	if !useOnlyCustomTLSProfiles {
+		for _, tlsProfile := range protocol.SupportedTLSProfiles {
+
+			if len(limitTLSProfiles) > 0 &&
+				!common.Contains(limitTLSProfiles, tlsProfile) {
+				continue
+			}
+
+			if common.Contains(disableTLSProfiles, tlsProfile) {
+				continue
+			}
+
+			// requireTLS12SessionTickets is specified for
+			// UNFRONTED-MEEK-SESSION-TICKET-OSSH, a protocol which depends on using
+			// obfuscated session tickets to ensure that the server doesn't send its
+			// certificate in the TLS handshake. TLS 1.2 profiles which omit session
+			// tickets should not be selected. As TLS 1.3 encrypts the server
+			// certificate message, there's no exclusion for TLS 1.3.
+
+			if requireTLS12SessionTickets &&
+				protocol.TLS12ProfileOmitsSessionTickets(tlsProfile) {
+				continue
+			}
+
+			if protocol.TLSProfileIsRandomized(tlsProfile) {
+				randomizedTLSProfiles = append(randomizedTLSProfiles, tlsProfile)
+			} else {
+				parrotTLSProfiles = append(parrotTLSProfiles, tlsProfile)
+			}
+		}
+	}
+
+	if len(randomizedTLSProfiles) > 0 &&
+		(len(parrotTLSProfiles) == 0 ||
+			p.WeightedCoinFlip(parameters.SelectRandomizedTLSProfileProbability)) {
+
+		return randomizedTLSProfiles[prng.Intn(len(randomizedTLSProfiles))]
+	}
+
+	if len(parrotTLSProfiles) == 0 {
+		return ""
+	}
+
+	return parrotTLSProfiles[prng.Intn(len(parrotTLSProfiles))]
+}
+
+func getUTLSClientHelloID(
+	p parameters.ParametersAccessor,
+	tlsProfile string) (utls.ClientHelloID, *utls.ClientHelloSpec, error) {
+
+	switch tlsProfile {
+	case protocol.TLS_PROFILE_IOS_111:
+		return utls.HelloIOS_11_1, nil, nil
+	case protocol.TLS_PROFILE_IOS_121:
+		return utls.HelloIOS_12_1, nil, nil
+	case protocol.TLS_PROFILE_CHROME_58:
+		return utls.HelloChrome_58, nil, nil
+	case protocol.TLS_PROFILE_CHROME_62:
+		return utls.HelloChrome_62, nil, nil
+	case protocol.TLS_PROFILE_CHROME_70:
+		return utls.HelloChrome_70, nil, nil
+	case protocol.TLS_PROFILE_CHROME_72:
+		return utls.HelloChrome_72, nil, nil
+	case protocol.TLS_PROFILE_CHROME_83:
+		return utls.HelloChrome_83, nil, nil
+	case protocol.TLS_PROFILE_FIREFOX_55:
+		return utls.HelloFirefox_55, nil, nil
+	case protocol.TLS_PROFILE_FIREFOX_56:
+		return utls.HelloFirefox_56, nil, nil
+	case protocol.TLS_PROFILE_FIREFOX_65:
+		return utls.HelloFirefox_65, nil, nil
+	case protocol.TLS_PROFILE_RANDOMIZED:
+		return utls.HelloRandomized, nil, nil
+	}
+
+	// utls.HelloCustom with a utls.ClientHelloSpec is used for
+	// CustomTLSProfiles.
+
+	customTLSProfile := p.CustomTLSProfile(tlsProfile)
+	if customTLSProfile == nil {
+		return utls.HelloCustom,
+			nil,
+			errors.Tracef("unknown TLS profile: %s", tlsProfile)
+	}
+
+	utlsClientHelloSpec, err := customTLSProfile.GetClientHelloSpec()
+	if err != nil {
+		return utls.ClientHelloID{}, nil, errors.Trace(err)
+	}
+
+	return utls.HelloCustom, utlsClientHelloSpec, nil
+}
+
+func getClientHelloVersion(
+	utlsClientHelloID utls.ClientHelloID,
+	utlsClientHelloSpec *utls.ClientHelloSpec) (string, error) {
+
+	switch utlsClientHelloID {
+
+	case utls.HelloIOS_11_1, utls.HelloIOS_12_1, utls.HelloChrome_58,
+		utls.HelloChrome_62, utls.HelloFirefox_55, utls.HelloFirefox_56:
+		return protocol.TLS_VERSION_12, nil
+
+	case utls.HelloChrome_70, utls.HelloChrome_72, utls.HelloChrome_83,
+		utls.HelloFirefox_65, utls.HelloGolang:
+		return protocol.TLS_VERSION_13, nil
+	}
+
+	// As utls.HelloRandomized/Custom may be either TLS 1.2 or TLS 1.3, we cannot
+	// perform a simple ClientHello ID check. BuildHandshakeState is run, which
+	// constructs the entire ClientHello.
+	//
+	// Assumes utlsClientHelloID.Seed has been set; otherwise the result is
+	// ephemeral.
+	//
+	// BenchmarkRandomizedGetClientHelloVersion indicates that this operation
+	// takes on the order of 0.05ms and allocates ~8KB for randomized client
+	// hellos.
+
+	conn := utls.UClient(
+		nil,
+		&utls.Config{InsecureSkipVerify: true},
+		utlsClientHelloID)
+
+	if utlsClientHelloSpec != nil {
+		err := conn.ApplyPreset(utlsClientHelloSpec)
+		if err != nil {
+			return "", errors.Trace(err)
+		}
+	}
+
+	err := conn.BuildHandshakeState()
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	for _, v := range conn.HandshakeState.Hello.SupportedVersions {
+		if v == utls.VersionTLS13 {
+			return protocol.TLS_VERSION_13, nil
+		}
+	}
+
+	return protocol.TLS_VERSION_12, nil
 }
 }
 
 
 func init() {
 func init() {

+ 360 - 10
psiphon/tlsDialer_test.go

@@ -21,11 +21,24 @@ package psiphon
 
 
 import (
 import (
 	"context"
 	"context"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/sha256"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
+	"encoding/pem"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"math/big"
 	"net"
 	"net"
+	"net/http"
+	"os"
+	"path/filepath"
 	"strings"
 	"strings"
+	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -37,6 +50,343 @@ import (
 	utls "github.com/refraction-networking/utls"
 	utls "github.com/refraction-networking/utls"
 )
 )
 
 
+func TestTLSCertificateVerification(t *testing.T) {
+
+	testDataDirName, err := ioutil.TempDir("", "psiphon-tls-certificate-verification-test")
+	if err != nil {
+		t.Fatalf("TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
+
+	serverName := "example.org"
+
+	rootCAsFileName,
+		rootCACertificatePin,
+		serverCertificatePin,
+		shutdown,
+		serverAddr,
+		dialer := initTestCertificatesAndWebServer(
+		t, testDataDirName, serverName)
+	defer shutdown()
+
+	// Test: without custom RootCAs, the TLS dial fails.
+
+	params, err := parameters.NewParameters(nil)
+	if err != nil {
+		t.Fatalf("parameters.NewParameters failed: %v", err)
+	}
+
+	conn, err := CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters: params,
+			Dial:       dialer,
+		})
+
+	if err == nil {
+		conn.Close()
+		t.Errorf("unexpected success without custom RootCAs")
+	}
+
+	// Test: without custom RootCAs and with SkipVerify, the TLS dial succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters: params,
+			Dial:       dialer,
+			SkipVerify: true,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+	// Test: with custom RootCAs, the TLS dial succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+	// Test: with SNI changed and VerifyServerName set, the TLS dial succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			SNIServerName:                 "not-" + serverName,
+			VerifyServerName:              serverName,
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+	// Test: with an invalid pin, the TLS dial fails.
+
+	invalidPin := base64.StdEncoding.EncodeToString(make([]byte, 32))
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			VerifyPins:                    []string{invalidPin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err == nil {
+		conn.Close()
+		t.Errorf("unexpected success without invalid pin")
+	}
+
+	// Test: with the root CA certirficate pinned, the TLS dial succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			VerifyPins:                    []string{rootCACertificatePin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+	// Test: with the server certificate pinned, the TLS dial succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			VerifyPins:                    []string{serverCertificatePin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+
+	// Test: with SNI changed, VerifyServerName set, and pinning the TLS dial
+	// succeeds.
+
+	conn, err = CustomTLSDial(
+		context.Background(), "tcp", serverAddr,
+		&CustomTLSConfig{
+			Parameters:                    params,
+			Dial:                          dialer,
+			SNIServerName:                 "not-" + serverName,
+			VerifyServerName:              serverName,
+			VerifyPins:                    []string{rootCACertificatePin},
+			TrustedCACertificatesFilename: rootCAsFileName,
+		})
+
+	if err != nil {
+		t.Errorf("CustomTLSDial failed: %v", err)
+	} else {
+		conn.Close()
+	}
+}
+
+// initTestCertificatesAndWebServer creates a Root CA, a web server
+// certificate, for serverName, signed by that Root CA, and runs a web server
+// that uses that server certificate. initRootCAandWebServer returns:
+//
+// - the file name containing the Root CA, to be used with
+//   CustomTLSConfig.TrustedCACertificatesFilename
+//
+// - pin values for the Root CA and server certificare, to be used with
+//   CustomTLSConfig.VerifyPins
+//
+// - a shutdown function which the caller must invoked to terminate the web
+//   server
+//
+// - the web server dial address: serverName and port
+//
+// - and a dialer function, which bypasses DNS resolution of serverName, to be
+//   used with CustomTLSConfig.Dial
+func initTestCertificatesAndWebServer(
+	t *testing.T,
+	testDataDirName string,
+	serverName string) (string, string, string, func(), string, common.Dialer) {
+
+	// Generate a root CA certificate.
+
+	rootCACertificate := &x509.Certificate{
+		SerialNumber: big.NewInt(1),
+		Subject: pkix.Name{
+			Organization: []string{"test"},
+		},
+		NotBefore:             time.Now(),
+		NotAfter:              time.Now().AddDate(1, 0, 0),
+		IsCA:                  true,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+		BasicConstraintsValid: true,
+	}
+
+	rootCAPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
+	if err != nil {
+		t.Fatalf("rsa.GenerateKey failed: %v", err)
+	}
+
+	rootCACertificateBytes, err := x509.CreateCertificate(
+		rand.Reader,
+		rootCACertificate,
+		rootCACertificate,
+		&rootCAPrivateKey.PublicKey,
+		rootCAPrivateKey)
+	if err != nil {
+		t.Fatalf("x509.CreateCertificate failed: %v", err)
+	}
+
+	pemRootCACertificate := pem.EncodeToMemory(
+		&pem.Block{
+			Type:  "CERTIFICATE",
+			Bytes: rootCACertificateBytes,
+		})
+
+	// Generate a server certificate.
+
+	serverCertificate := &x509.Certificate{
+		SerialNumber: big.NewInt(2),
+		Subject: pkix.Name{
+			Organization: []string{"test"},
+		},
+		DNSNames:    []string{serverName},
+		NotBefore:   time.Now(),
+		NotAfter:    time.Now().AddDate(1, 0, 0),
+		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		KeyUsage:    x509.KeyUsageDigitalSignature,
+	}
+
+	serverPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
+	if err != nil {
+		t.Fatalf("rsa.GenerateKey failed: %v", err)
+	}
+
+	serverCertificateBytes, err := x509.CreateCertificate(
+		rand.Reader,
+		serverCertificate,
+		rootCACertificate,
+		&serverPrivateKey.PublicKey,
+		rootCAPrivateKey)
+	if err != nil {
+		t.Fatalf("x509.CreateCertificate failed: %v", err)
+	}
+
+	pemServerCertificate := pem.EncodeToMemory(
+		&pem.Block{
+			Type:  "CERTIFICATE",
+			Bytes: serverCertificateBytes,
+		})
+
+	pemServerPrivateKey := pem.EncodeToMemory(
+		&pem.Block{
+			Type:  "RSA PRIVATE KEY",
+			Bytes: x509.MarshalPKCS1PrivateKey(serverPrivateKey),
+		})
+
+	// Pave Root CA file.
+
+	rootCAsFileName := filepath.Join(testDataDirName, "RootCAs.pem")
+	err = ioutil.WriteFile(rootCAsFileName, pemRootCACertificate, 0600)
+	if err != nil {
+		t.Fatalf("WriteFile failed: %v", err)
+	}
+
+	// Calculate certificate pins.
+
+	parsedCertificate, err := x509.ParseCertificate(rootCACertificateBytes)
+	if err != nil {
+		t.Fatalf("x509.ParseCertificate failed: %v", err)
+	}
+	publicKeyDigest := sha256.Sum256(parsedCertificate.RawSubjectPublicKeyInfo)
+	rootCACertificatePin := base64.StdEncoding.EncodeToString(publicKeyDigest[:])
+
+	parsedCertificate, err = x509.ParseCertificate(serverCertificateBytes)
+	if err != nil {
+		t.Fatalf("x509.ParseCertificate failed: %v", err)
+	}
+	publicKeyDigest = sha256.Sum256(parsedCertificate.RawSubjectPublicKeyInfo)
+	serverCertificatePin := base64.StdEncoding.EncodeToString(publicKeyDigest[:])
+
+	// Run an HTTPS server with the server certificate.
+
+	dialAddr := "127.0.0.1:8000"
+	serverAddr := fmt.Sprintf("%s:8000", serverName)
+
+	serverKeyPair, err := tls.X509KeyPair(
+		pemServerCertificate, pemServerPrivateKey)
+	if err != nil {
+		t.Fatalf("tls.X509KeyPair failed: %v", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte("test"))
+	})
+
+	server := &http.Server{
+		Addr: dialAddr,
+		TLSConfig: &tls.Config{
+			Certificates: []tls.Certificate{serverKeyPair},
+		},
+		Handler: mux,
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		wg.Done()
+		server.ListenAndServeTLS("", "")
+	}()
+
+	shutdown := func() {
+		server.Shutdown(context.Background())
+		wg.Wait()
+	}
+
+	// Initialize a custom dialer for the client which bypasses DNS resolution.
+
+	dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
+		d := &net.Dialer{}
+		// Ignore the address input, which will be serverAddr, and dial dialAddr, as
+		// if the serverName in serverAddr had been resolved to "127.0.0.1".
+		return d.DialContext(ctx, network, dialAddr)
+	}
+
+	return rootCAsFileName,
+		rootCACertificatePin,
+		serverCertificatePin,
+		shutdown,
+		serverAddr,
+		dialer
+}
+
 func TestTLSDialerCompatibility(t *testing.T) {
 func TestTLSDialerCompatibility(t *testing.T) {
 
 
 	// This test checks that each TLS profile can successfully complete a TLS
 	// This test checks that each TLS profile can successfully complete a TLS
@@ -76,12 +426,12 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 
 
 		certificate, privateKey, err := common.GenerateWebServerCertificate(values.GetHostName())
 		certificate, privateKey, err := common.GenerateWebServerCertificate(values.GetHostName())
 		if err != nil {
 		if err != nil {
-			t.Fatalf("%s\n", err)
+			t.Fatalf("common.GenerateWebServerCertificate failed: %v", err)
 		}
 		}
 
 
 		tlsCertificate, err := tris.X509KeyPair([]byte(certificate), []byte(privateKey))
 		tlsCertificate, err := tris.X509KeyPair([]byte(certificate), []byte(privateKey))
 		if err != nil {
 		if err != nil {
-			t.Fatalf("%s\n", err)
+			t.Fatalf("tris.X509KeyPair failed: %v", err)
 		}
 		}
 
 
 		config := &tris.Config{
 		config := &tris.Config{
@@ -93,7 +443,7 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 
 
 		tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
 		tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
 		if err != nil {
 		if err != nil {
-			t.Fatalf("%s\n", err)
+			t.Fatalf("net.Listen failed: %v", err)
 		}
 		}
 
 
 		tlsListener := tris.NewListener(tcpListener, config)
 		tlsListener := tris.NewListener(tcpListener, config)
@@ -109,7 +459,7 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 				}
 				}
 				err = conn.(*tris.Conn).Handshake()
 				err = conn.(*tris.Conn).Handshake()
 				if err != nil {
 				if err != nil {
-					t.Logf("server handshake: %s", err)
+					t.Logf("tris.Conn.Handshake failed: %v", err)
 				}
 				}
 				conn.Close()
 				conn.Close()
 			}
 			}
@@ -157,7 +507,7 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 			conn, err := CustomTLSDial(ctx, "tcp", address, tlsConfig)
 			conn, err := CustomTLSDial(ctx, "tcp", address, tlsConfig)
 
 
 			if err != nil {
 			if err != nil {
-				t.Logf("%s (transformHostname: %v): %s\n",
+				t.Logf("CustomTLSDial failed: %s (transformHostname: %v): %v",
 					tlsProfile, transformHostname, err)
 					tlsProfile, transformHostname, err)
 			} else {
 			} else {
 
 
@@ -184,7 +534,7 @@ func testTLSDialerCompatibility(t *testing.T, address string) {
 		}
 		}
 
 
 		result := fmt.Sprintf(
 		result := fmt.Sprintf(
-			"%s: %d/%d successful; negotiated TLS versions: %v\n",
+			"%s: %d/%d successful; negotiated TLS versions: %v",
 			tlsProfile, success, repeats, tlsVersions)
 			tlsProfile, success, repeats, tlsVersions)
 
 
 		if success == repeats {
 		if success == repeats {
@@ -252,7 +602,7 @@ func TestSelectTLSProfile(t *testing.T) {
 		utlsClientHelloID, utlsClientHelloSpec, err :=
 		utlsClientHelloID, utlsClientHelloSpec, err :=
 			getUTLSClientHelloID(params.Get(), profile)
 			getUTLSClientHelloID(params.Get(), profile)
 		if err != nil {
 		if err != nil {
-			t.Fatalf("getUTLSClientHelloID failed: %s\n", err)
+			t.Fatalf("getUTLSClientHelloID failed: %v", err)
 		}
 		}
 
 
 		var unexpectedClientHelloID, unexpectedClientHelloSpec bool
 		var unexpectedClientHelloID, unexpectedClientHelloSpec bool
@@ -334,7 +684,7 @@ func makeCustomTLSProfilesParameters(
 
 
 	params, err := parameters.NewParameters(nil)
 	params, err := parameters.NewParameters(nil)
 	if err != nil {
 	if err != nil {
-		t.Fatalf("NewParameters failed: %s\n", err)
+		t.Fatalf("NewParameters failed: %v", err)
 	}
 	}
 
 
 	// Equivilent to utls.HelloChrome_62
 	// Equivilent to utls.HelloChrome_62
@@ -370,7 +720,7 @@ func makeCustomTLSProfilesParameters(
 
 
 	err = json.Unmarshal(customTLSProfilesJSON, &customTLSProfiles)
 	err = json.Unmarshal(customTLSProfilesJSON, &customTLSProfiles)
 	if err != nil {
 	if err != nil {
-		t.Fatalf("Unmarshal failed: %s", err)
+		t.Fatalf("Unmarshal failed: %v", err)
 	}
 	}
 
 
 	applyParameters := make(map[string]interface{})
 	applyParameters := make(map[string]interface{})
@@ -394,7 +744,7 @@ func makeCustomTLSProfilesParameters(
 
 
 	_, err = params.Set("", false, applyParameters)
 	_, err = params.Set("", false, applyParameters)
 	if err != nil {
 	if err != nil {
-		t.Fatalf("Set failed: %s", err)
+		t.Fatalf("Set failed: %v", err)
 	}
 	}
 
 
 	customTLSProfileNames := params.Get().CustomTLSProfileNames()
 	customTLSProfileNames := params.Get().CustomTLSProfileNames()

+ 86 - 10
psiphon/tunnel.go

@@ -30,6 +30,7 @@ import (
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
+	"net/http"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
@@ -790,22 +791,97 @@ func dialTunnel(
 
 
 	} else if protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
 	} else if protocol.TunnelProtocolUsesConjure(dialParams.TunnelProtocol) {
 
 
-		// The Conjure "phantom" connection is compatible with fragmentation, but
-		// the decoy registrar connection, like Tapdance, is not, so force it off.
-		// Any tunnel fragmentation metrics will refer to the "phantom" connection
-		// only.
-		decoyRegistrarDialer := NewNetDialer(
-			dialParams.GetDialConfig().WithoutFragmentor())
+		// Specify a cache key with a scope that ensures that:
+		//
+		// (a) cached registrations aren't used across different networks, as a
+		// registration requires the client's public IP to match the value at time
+		// of registration;
+		//
+		// (b) cached registrations are associated with specific Psiphon server
+		// candidates, to ensure that replay will use the same phantom IP(s).
+		//
+		// This scheme allows for reuse of cached registrations on network A when a
+		// client roams from network A to network B and back to network A.
+		//
+		// Using the network ID as a proxy for client public IP address is a
+		// heurisitic: it's possible that a clients public IP address changes
+		// without the network ID changing, and it's not guaranteed that the client
+		// will be assigned the original public IP on network A; so there's some
+		// chance the registration cannot be reused.
+
+		cacheKey := dialParams.NetworkID + dialParams.ServerEntry.IpAddress
+
+		conjureConfig := &refraction.ConjureConfig{
+			RegistrationCacheTTL: dialParams.ConjureCachedRegistrationTTL,
+			RegistrationCacheKey: cacheKey,
+			Transport:            dialParams.ConjureTransport,
+		}
+
+		if dialParams.ConjureAPIRegistration {
+
+			// Use MeekConn to domain front Conjure API registration.
+			//
+			// ConjureAPIRegistrarFrontingSpecs are applied via
+			// dialParams.GetMeekConfig, and will be subject to replay.
+			//
+			// Since DialMeek will create a TLS connection immediately, and a cached
+			// registration may be used, we will delay initializing the MeekConn-based
+			// RoundTripper until we know it's needed. This is implemented by passing
+			// in a RoundTripper that establishes a MeekConn when RoundTrip is called.
+			//
+			// In refraction.dial we configure 0 retries for API registration requests,
+			// assuming it's better to let another Psiphon candidate retry, with new
+			// domaing fronting parameters. As such, we expect only one round trip call
+			// per NewHTTPRoundTripper, so, in practise, there's no performance penalty
+			// from establishing a new MeekConn per round trip.
+			//
+			// Performing the full DialMeek/RoundTrip operation here allows us to call
+			// MeekConn.Close and ensure all resources are immediately cleaned up.
+			roundTrip := func(request *http.Request) (*http.Response, error) {
+				conn, err := DialMeek(
+					ctx, dialParams.GetMeekConfig(), dialParams.GetDialConfig())
+				if err != nil {
+					return nil, errors.Trace(err)
+				}
+				defer conn.Close()
+				response, err := conn.RoundTrip(request)
+				if err != nil {
+					return nil, errors.Trace(err)
+				}
+				// Currently, gotapdance does not read the response body. When that
+				// changes, we will need to ensure MeekConn.Close does not make the
+				// response body unavailable, perhaps by reading into a buffer and
+				// replacing reponse.Body. For now, we can immediately close it.
+				response.Body.Close()
+				return response, nil
+			}
+
+			conjureConfig.APIRegistrarHTTPClient = &http.Client{
+				Transport: common.NewHTTPRoundTripper(roundTrip),
+			}
+
+			conjureConfig.APIRegistrarURL = dialParams.ConjureAPIRegistrarURL
+			conjureConfig.APIRegistrarDelay = dialParams.ConjureAPIRegistrarDelay
+
+		} else if dialParams.ConjureDecoyRegistration {
+
+			// The Conjure "phantom" connection is compatible with fragmentation, but
+			// the decoy registrar connection, like Tapdance, is not, so force it off.
+			// Any tunnel fragmentation metrics will refer to the "phantom" connection
+			// only.
+			conjureConfig.DecoyRegistrarDialer = NewNetDialer(
+				dialParams.GetDialConfig().WithoutFragmentor())
+			conjureConfig.DecoyRegistrarWidth = dialParams.ConjureDecoyRegistrarWidth
+			conjureConfig.DecoyRegistrarDelay = dialParams.ConjureDecoyRegistrarDelay
+		}
 
 
 		dialConn, err = refraction.DialConjure(
 		dialConn, err = refraction.DialConjure(
 			ctx,
 			ctx,
 			config.EmitRefractionNetworkingLogs,
 			config.EmitRefractionNetworkingLogs,
 			config.GetPsiphonDataDirectory(),
 			config.GetPsiphonDataDirectory(),
 			NewNetDialer(dialParams.GetDialConfig()),
 			NewNetDialer(dialParams.GetDialConfig()),
-			decoyRegistrarDialer,
-			dialParams.ConjureDecoyRegistrarWidth,
-			dialParams.ConjureTransport,
-			dialParams.DirectDialAddress)
+			dialParams.DirectDialAddress,
+			conjureConfig)
 		if err != nil {
 		if err != nil {
 			return nil, errors.Trace(err)
 			return nil, errors.Trace(err)
 		}
 		}