Просмотр исходного кода

Merge branch 'master' into staging-client

Rod Hynes 1 год назад
Родитель
Сommit
515be1482c

+ 6 - 0
psiphon/common/inproxy/api.go

@@ -137,6 +137,8 @@ const (
 	NetworkTypeUnknown NetworkType = iota
 	NetworkTypeWiFi
 	NetworkTypeMobile
+	NetworkTypeWired
+	NetworkTypeVPN
 )
 
 // NetworkProtocol is an Internet protocol, such as TCP or UDP. This enum is
@@ -453,6 +455,10 @@ func GetNetworkType(packedBaseParams protocol.PackedAPIParameters) NetworkType {
 		return NetworkTypeWiFi
 	case "MOBILE":
 		return NetworkTypeMobile
+	case "WIRED":
+		return NetworkTypeWired
+	case "VPN":
+		return NetworkTypeVPN
 	}
 	return NetworkTypeUnknown
 }

+ 54 - 3
psiphon/common/inproxy/client.go

@@ -21,6 +21,7 @@ package inproxy
 
 import (
 	"context"
+	"fmt"
 	"net"
 	"net/netip"
 	"sync"
@@ -47,6 +48,7 @@ type ClientConn struct {
 	webRTCConn   *webRTCConn
 	connectionID ID
 	remoteAddr   net.Addr
+	metrics      common.LogFields
 
 	relayMutex         sync.Mutex
 	initialRelayPacket []byte
@@ -126,6 +128,9 @@ func DialClient(
 	ctx context.Context,
 	config *ClientConfig) (retConn *ClientConn, retErr error) {
 
+	startTime := time.Now()
+	metrics := common.LogFields{}
+
 	// Configure the value returned by ClientConn.RemoteAddr. If no
 	// config.RemoteAddrOverride is specified, RemoteAddr will return a
 	// zero-value, non-nil net.Addr. The underlying webRTCConn.RemoteAddr
@@ -193,10 +198,18 @@ func DialClient(
 				Logger:                config.Logger,
 				WebRTCDialCoordinator: config.WebRTCDialCoordinator,
 			})
+
+		duration := time.Since(startTime)
+		metrics["inproxy_dial_nat_discovery_duration"] = fmt.Sprintf("%d", duration/time.Millisecond)
+		config.Logger.WithTraceFields(
+			common.LogFields{"duration": duration.String()}).Info("NAT discovery complete")
+		startTime = time.Now()
 	}
 
 	var result *clientWebRTCDialResult
-	for {
+	for attempt := 0; ; attempt += 1 {
+
+		previousAttemptsDuration := time.Since(startTime)
 
 		// Repeatedly try to establish in-proxy/WebRTC connection until the
 		// dial context is canceled or times out.
@@ -219,6 +232,16 @@ func DialClient(
 		var retry bool
 		result, retry, err = dialClientWebRTCConn(ctx, config)
 		if err == nil {
+
+			if attempt > 0 {
+				// Record the time elapsed in previous attempts.
+				metrics["inproxy_dial_failed_attempts_duration"] =
+					fmt.Sprintf("%d", previousAttemptsDuration/time.Millisecond)
+				config.Logger.WithTraceFields(
+					common.LogFields{
+						"duration": previousAttemptsDuration.String()}).Info("previous failed attempts")
+			}
+
 			break
 		}
 
@@ -241,12 +264,15 @@ func DialClient(
 		return nil, errors.Trace(err)
 	}
 
+	metrics.Add(result.metrics)
+
 	return &ClientConn{
 		config:             config,
 		webRTCConn:         result.conn,
 		connectionID:       result.connectionID,
-		initialRelayPacket: result.relayPacket,
 		remoteAddr:         remoteAddr,
+		metrics:            metrics,
+		initialRelayPacket: result.relayPacket,
 	}, nil
 }
 
@@ -313,12 +339,16 @@ type clientWebRTCDialResult struct {
 	conn         *webRTCConn
 	connectionID ID
 	relayPacket  []byte
+	metrics      common.LogFields
 }
 
 func dialClientWebRTCConn(
 	ctx context.Context,
 	config *ClientConfig) (retResult *clientWebRTCDialResult, retRetry bool, retErr error) {
 
+	startTime := time.Now()
+	metrics := common.LogFields{}
+
 	brokerCoordinator := config.BrokerClient.GetBrokerDialCoordinator()
 	personalCompartmentIDs := brokerCoordinator.PersonalCompartmentIDs()
 
@@ -353,6 +383,12 @@ func dialClientWebRTCConn(
 		}
 	}()
 
+	duration := time.Since(startTime)
+	metrics["inproxy_dial_webrtc_ice_gathering_duration"] = fmt.Sprintf("%d", duration/time.Millisecond)
+	config.Logger.WithTraceFields(
+		common.LogFields{"duration": duration.String()}).Info("ICE gathering complete")
+	startTime = time.Now()
+
 	// Send the ClientOffer request to the broker
 
 	apiParams := common.APIParameters{}
@@ -396,6 +432,12 @@ func dialClientWebRTCConn(
 		return nil, false, errors.Trace(err)
 	}
 
+	duration = time.Since(startTime)
+	metrics["inproxy_dial_broker_offer_duration"] = fmt.Sprintf("%d", duration/time.Millisecond)
+	config.Logger.WithTraceFields(
+		common.LogFields{"duration": duration.String()}).Info("Broker offer complete")
+	startTime = time.Now()
+
 	// MustUpgrade has precedence over other cases to ensure the callback is
 	// invoked. No retry when rate/entry limited or must upgrade; do retry on
 	// no-match, as a match may soon appear.
@@ -442,16 +484,25 @@ func dialClientWebRTCConn(
 		return nil, true, errors.Trace(err)
 	}
 
+	duration = time.Since(startTime)
+	metrics["inproxy_dial_webrtc_connection_duration"] = fmt.Sprintf("%d", duration/time.Millisecond)
+	config.Logger.WithTraceFields(
+		common.LogFields{"duration": duration.String()}).Info("WebRTC connection complete")
+
 	return &clientWebRTCDialResult{
 		conn:         webRTCConn,
 		connectionID: offerResponse.ConnectionID,
 		relayPacket:  offerResponse.RelayPacketToServer,
+		metrics:      metrics,
 	}, false, nil
 }
 
 // GetMetrics implements the common.MetricsSource interface.
 func (conn *ClientConn) GetMetrics() common.LogFields {
-	return conn.webRTCConn.GetMetrics()
+	metrics := common.LogFields{}
+	metrics.Add(conn.metrics)
+	metrics.Add(conn.webRTCConn.GetMetrics())
+	return metrics
 }
 
 func (conn *ClientConn) Close() error {

+ 43 - 4
psiphon/common/inproxy/session.go

@@ -37,6 +37,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	lrucache "github.com/cognusion/go-cache-lru"
 	"github.com/flynn/noise"
+	"github.com/marusama/semaphore"
 	"golang.org/x/crypto/curve25519"
 	"golang.zx2c4.com/wireguard/replay"
 )
@@ -50,6 +51,8 @@ const (
 
 	resetSessionTokenName      = "psiphon-inproxy-session-reset-session-token"
 	resetSessionTokenNonceSize = 32
+
+	maxResponderConcurrentNewSessions = 32768
 )
 
 const (
@@ -450,6 +453,12 @@ func (s *InitiatorSessions) getSession(
 	s.mutex.Lock()
 	defer s.mutex.Unlock()
 
+	// Note: unlike in ResponderSessions.getSession, there is no indication,
+	// in profiling, of high lock contention and blocking here when holding
+	// the mutex lock while calling newSession. The lock is left in place to
+	// preserve the semantics of only one concurrent newSession call,
+	// particularly for brokers initiating new sessions with servers.
+
 	session, ok := s.sessions[publicKey]
 	if ok {
 		return session, false, session.isReadyToShare(nil), nil
@@ -860,8 +869,10 @@ type ResponderSessions struct {
 	obfuscationReplayHistory    *obfuscationReplayHistory
 	expectedInitiatorPublicKeys *sessionPublicKeyLookup
 
-	mutex    sync.Mutex
+	mutex    sync.RWMutex
 	sessions *lrucache.Cache
+
+	concurrentNewSessions semaphore.Semaphore
 }
 
 // NewResponderSessions creates a new ResponderSessions which allows any
@@ -883,6 +894,7 @@ func NewResponderSessions(
 		applyTTL:                 true,
 		obfuscationReplayHistory: newObfuscationReplayHistory(),
 		sessions:                 lrucache.NewWithLRU(sessionsTTL, 1*time.Minute, sessionsMaxSize),
+		concurrentNewSessions:    semaphore.New(maxResponderConcurrentNewSessions),
 	}, nil
 }
 
@@ -1210,16 +1222,35 @@ func (s *ResponderSessions) touchSession(sessionID ID, session *session) {
 // creates a new session, and places it in the cache, if not found.
 func (s *ResponderSessions) getSession(sessionID ID) (*session, error) {
 
-	s.mutex.Lock()
-	defer s.mutex.Unlock()
+	// Concurrency: profiling indicates that holding the mutex lock here when
+	// calling newSession leads to high contention and blocking. Instead,
+	// release the lock after checking for an existing session, and then
+	// recheck -- using lrucache.Add, which fails if an entry exists -- when
+	// inserting.
+	//
+	// A read-only lock is obtained on the initial check, allowing for
+	// concurrent checks; however, note that lrucache has its own RWMutex and
+	// obtains a write lock in Get when LRU ejection may need to be performed.
+	//
+	// A semaphore is used to enforce a sanity check maximum number of
+	// concurrent newSession calls.
+	//
+	// TODO: add a timeout or stop signal to Acquire?
 
 	strSessionID := string(sessionID[:])
 
+	s.mutex.RLock()
 	entry, ok := s.sessions.Get(strSessionID)
+	s.mutex.RUnlock()
+
 	if ok {
 		return entry.(*session), nil
 	}
 
+	err := s.concurrentNewSessions.Acquire(context.Background(), 1)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 	session, err := newSession(
 		false, // !isInitiator
 		s.privateKey,
@@ -1230,12 +1261,20 @@ func (s *ResponderSessions) getSession(sessionID ID) (*session, error) {
 		nil,
 		&sessionID,
 		s.expectedInitiatorPublicKeys)
+	s.concurrentNewSessions.Release(1)
+
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	s.sessions.Set(
+	s.mutex.Lock()
+	err = s.sessions.Add(
 		strSessionID, session, lrucache.DefaultExpiration)
+	s.mutex.Unlock()
+
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
 	return session, nil
 }

+ 42 - 14
psiphon/common/inproxy/webrtc.go

@@ -315,6 +315,7 @@ func newWebRTCConn(
 
 	pionLoggerFactory := newPionLoggerFactory(
 		config.Logger,
+		func() bool { return ctx.Err() != nil },
 		config.EnableDebugLogging)
 
 	pionNetwork := newPionNetwork(
@@ -1535,7 +1536,7 @@ func (conn *webRTCConn) onConnectionStateChange(state webrtc.PeerConnectionState
 
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"state": state.String(),
-	}).Info("peer connection state changed")
+	}).Debug("peer connection state changed")
 }
 
 func (conn *webRTCConn) onICECandidate(candidate *webrtc.ICECandidate) {
@@ -1545,7 +1546,7 @@ func (conn *webRTCConn) onICECandidate(candidate *webrtc.ICECandidate) {
 
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"candidate": candidate.String(),
-	}).Info("new ICE candidate")
+	}).Debug("new ICE candidate")
 }
 
 func (conn *webRTCConn) onICEBindingRequest(m *stun.Message, local, remote ice.Candidate, pair *ice.CandidatePair) bool {
@@ -1569,7 +1570,7 @@ func (conn *webRTCConn) onICEBindingRequest(m *stun.Message, local, remote ice.C
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"local_candidate":  local.String(),
 		"remote_candidate": remote.String(),
-	}).Info("new ICE STUN binding request")
+	}).Debug("new ICE STUN binding request")
 
 	return false
 }
@@ -1578,14 +1579,14 @@ func (conn *webRTCConn) onICEConnectionStateChange(state webrtc.ICEConnectionSta
 
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"state": state.String(),
-	}).Info("ICE connection state changed")
+	}).Debug("ICE connection state changed")
 }
 
 func (conn *webRTCConn) onICEGatheringStateChange(state webrtc.ICEGathererState) {
 
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"state": state.String(),
-	}).Info("ICE gathering state changed")
+	}).Debug("ICE gathering state changed")
 }
 
 func (conn *webRTCConn) onNegotiationNeeded() {
@@ -1597,7 +1598,7 @@ func (conn *webRTCConn) onSignalingStateChange(state webrtc.SignalingState) {
 
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"state": state.String(),
-	}).Info("signaling state changed")
+	}).Debug("signaling state changed")
 }
 
 func (conn *webRTCConn) onDataChannel(dataChannel *webrtc.DataChannel) {
@@ -1610,7 +1611,7 @@ func (conn *webRTCConn) onDataChannel(dataChannel *webrtc.DataChannel) {
 	conn.config.Logger.WithTraceFields(common.LogFields{
 		"label": dataChannel.Label(),
 		"ID":    dataChannel.ID(),
-	}).Info("new data channel")
+	}).Debug("new data channel")
 }
 
 func (conn *webRTCConn) onDataChannelOpen() {
@@ -1978,18 +1979,22 @@ func processSDPAddresses(
 
 type pionLoggerFactory struct {
 	logger       common.Logger
+	stopLogging  func() bool
 	debugLogging bool
 }
 
-func newPionLoggerFactory(logger common.Logger, debugLogging bool) *pionLoggerFactory {
+func newPionLoggerFactory(
+	logger common.Logger, stopLogging func() bool, debugLogging bool) *pionLoggerFactory {
+
 	return &pionLoggerFactory{
 		logger:       logger,
+		stopLogging:  stopLogging,
 		debugLogging: debugLogging,
 	}
 }
 
 func (f *pionLoggerFactory) NewLogger(scope string) pion_logging.LeveledLogger {
-	return newPionLogger(scope, f.logger, f.debugLogging)
+	return newPionLogger(scope, f.logger, f.stopLogging, f.debugLogging)
 }
 
 // pionLogger wraps common.Logger and implements
@@ -1998,56 +2003,70 @@ func (f *pionLoggerFactory) NewLogger(scope string) pion_logging.LeveledLogger {
 type pionLogger struct {
 	scope        string
 	logger       common.Logger
+	stopLogging  func() bool
 	debugLogging bool
 	warnNoPairs  int32
 }
 
-func newPionLogger(scope string, logger common.Logger, debugLogging bool) *pionLogger {
+func newPionLogger(
+	scope string, logger common.Logger, stopLogging func() bool, debugLogging bool) *pionLogger {
+
 	return &pionLogger{
 		scope:        scope,
 		logger:       logger,
+		stopLogging:  stopLogging,
 		debugLogging: debugLogging,
 	}
 }
 
 func (l *pionLogger) Trace(msg string) {
-	if !l.debugLogging {
+	if l.stopLogging() || !l.debugLogging {
 		return
 	}
 	l.logger.WithTrace().Debug(fmt.Sprintf("webRTC: %s: %s", l.scope, msg))
 }
 
 func (l *pionLogger) Tracef(format string, args ...interface{}) {
-	if !l.debugLogging {
+	if l.stopLogging() || !l.debugLogging {
 		return
 	}
 	l.logger.WithTrace().Debug(fmt.Sprintf("webRTC: %s: %s", l.scope, fmt.Sprintf(format, args...)))
 }
 
 func (l *pionLogger) Debug(msg string) {
-	if !l.debugLogging {
+	if l.stopLogging() || !l.debugLogging {
 		return
 	}
 	l.logger.WithTrace().Debug(fmt.Sprintf("[webRTC: %s: %s", l.scope, msg))
 }
 
 func (l *pionLogger) Debugf(format string, args ...interface{}) {
-	if !l.debugLogging {
+	if l.stopLogging() || !l.debugLogging {
 		return
 	}
 	l.logger.WithTrace().Debug(fmt.Sprintf("webRTC: %s: %s", l.scope, fmt.Sprintf(format, args...)))
 }
 
 func (l *pionLogger) Info(msg string) {
+	if l.stopLogging() {
+		return
+	}
 	l.logger.WithTrace().Info(fmt.Sprintf("webRTC: %s: %s", l.scope, msg))
 }
 
 func (l *pionLogger) Infof(format string, args ...interface{}) {
+	if l.stopLogging() {
+		return
+	}
 	l.logger.WithTrace().Info(fmt.Sprintf("webRTC: %s: %s", l.scope, fmt.Sprintf(format, args...)))
 }
 
 func (l *pionLogger) Warn(msg string) {
 
+	if l.stopLogging() {
+		return
+	}
+
 	// To reduce diagnostic log noise, only log this message once per dial attempt.
 	if msg == "Failed to ping without candidate pairs. Connection is not possible yet." &&
 		!atomic.CompareAndSwapInt32(&l.warnNoPairs, 0, 1) {
@@ -2058,14 +2077,23 @@ func (l *pionLogger) Warn(msg string) {
 }
 
 func (l *pionLogger) Warnf(format string, args ...interface{}) {
+	if l.stopLogging() {
+		return
+	}
 	l.logger.WithTrace().Warning(fmt.Sprintf("webRTC: %s: %s", l.scope, fmt.Sprintf(format, args...)))
 }
 
 func (l *pionLogger) Error(msg string) {
+	if l.stopLogging() {
+		return
+	}
 	l.logger.WithTrace().Error(fmt.Sprintf("webRTC: %s: %s", l.scope, msg))
 }
 
 func (l *pionLogger) Errorf(format string, args ...interface{}) {
+	if l.stopLogging() {
+		return
+	}
 	l.logger.WithTrace().Error(fmt.Sprintf("webRTC: %s: %s", l.scope, fmt.Sprintf(format, args...)))
 }
 

+ 1 - 1
psiphon/common/networkid/networkid_windows.go

@@ -170,7 +170,7 @@ func getConnectionType(ifType winipcfg.IfType, description string) string {
 	var connectionType string
 
 	switch ifType {
-	case winipcfg.IfTypeEthernetCSMACD:
+	case winipcfg.IfTypeEthernetCSMACD, winipcfg.IfTypeEthernet3Mbit, winipcfg.IfTypeFastether, winipcfg.IfTypeFastetherFX, winipcfg.IfTypeGigabitethernet, winipcfg.IfTypeIEEE80212, winipcfg.IfTypeDigitalpowerline:
 		connectionType = "WIRED"
 	case winipcfg.IfTypeIEEE80211:
 		connectionType = "WIFI"

+ 4 - 0
psiphon/common/parameters/parameters.go

@@ -358,6 +358,8 @@ const (
 	DNSResolverIncludeEDNS0Probability                 = "DNSResolverIncludeEDNS0Probability"
 	DNSResolverCacheExtensionInitialTTL                = "DNSResolverCacheExtensionInitialTTL"
 	DNSResolverCacheExtensionVerifiedTTL               = "DNSResolverCacheExtensionVerifiedTTL"
+	DNSResolverQNameRandomizeCasingProbability         = "DNSResolverQNameRandomizeCasingProbability"
+	DNSResolverQNameMustMatchProbability               = "DNSResolverQNameMustMatchProbability"
 	AddFrontingProviderPsiphonFrontingHeader           = "AddFrontingProviderPsiphonFrontingHeader"
 	DirectHTTPProtocolTransformSpecs                   = "DirectHTTPProtocolTransformSpecs"
 	DirectHTTPProtocolTransformScopedSpecNames         = "DirectHTTPProtocolTransformScopedSpecNames"
@@ -880,6 +882,8 @@ var defaultParameters = map[string]struct {
 	DNSResolverIncludeEDNS0Probability:          {value: 0.0, minimum: 0.0},
 	DNSResolverCacheExtensionInitialTTL:         {value: time.Duration(0), minimum: time.Duration(0)},
 	DNSResolverCacheExtensionVerifiedTTL:        {value: time.Duration(0), minimum: time.Duration(0)},
+	DNSResolverQNameRandomizeCasingProbability:  {value: 0.0, minimum: 0.0},
+	DNSResolverQNameMustMatchProbability:        {value: 0.0, minimum: 0.0},
 
 	AddFrontingProviderPsiphonFrontingHeader: {value: protocol.LabeledTunnelProtocols{}},
 

+ 25 - 1
psiphon/common/protocol/packed.go

@@ -792,9 +792,13 @@ func init() {
 
 		{142, "statusData", rawJSONConverter},
 
+		// Specs: server.inproxyDialParams
+
 		{143, "inproxy_webrtc_local_ice_candidate_is_private_IP", intConverter},
 		{144, "inproxy_webrtc_remote_ice_candidate_is_private_IP", intConverter},
 
+		// Specs: server.baseDialParams
+
 		{145, "tls_sent_ticket", intConverter},
 		{146, "tls_did_resume", intConverter},
 		{147, "quic_sent_ticket", intConverter},
@@ -802,7 +806,27 @@ func init() {
 		{149, "quic_dial_early", intConverter},
 		{150, "quic_obfuscated_psk", intConverter},
 
-		// Next key value = 151
+		{151, "dns_qname_random_casing", intConverter},
+		{152, "dns_qname_must_match", intConverter},
+		{153, "dns_qname_mismatches", intConverter},
+
+		// Specs: server.inproxyDialParams
+
+		{154, "inproxy_broker_dns_qname_random_casing", intConverter},
+		{155, "inproxy_broker_dns_qname_must_match", intConverter},
+		{156, "inproxy_broker_dns_qname_mismatches", intConverter},
+		{157, "inproxy_webrtc_dns_qname_random_casing", intConverter},
+		{158, "inproxy_webrtc_dns_qname_must_match", intConverter},
+		{159, "inproxy_webrtc_dns_qname_mismatches", intConverter},
+
+		{160, "inproxy_dial_nat_discovery_duration", intConverter},
+		{161, "inproxy_dial_failed_attempts_duration", intConverter},
+		{162, "inproxy_dial_webrtc_ice_gathering_duration", intConverter},
+		{163, "inproxy_dial_broker_offer_duration", intConverter},
+		{164, "inproxy_dial_webrtc_connection_duration", intConverter},
+		{165, "inproxy_broker_is_reuse", intConverter},
+
+		// Next key value = 166
 	}
 
 	for _, spec := range packedAPIParameterSpecs {

+ 14 - 4
psiphon/common/refraction/refraction.go

@@ -416,8 +416,15 @@ func dial(
 			refractionDialer.Transport = transport.ID()
 			refractionDialer.TransportConfig = config
 			refractionDialer.DisableRegistrarOverrides = disableOverrides
-			refractionDialer.DialerWithLaddr = newWriteMergeDialer(
-				refractionDialer.DialerWithLaddr, false, 32)
+			if !conjureConfig.DoDecoyRegistration {
+				// Limitation: the writeMergeConn wrapping is skipped when
+				// using decoy registration, since the refraction package
+				// uses DialerWithLaddr for both the decoy registration step
+				// as well as the following phantom dial, and the
+				// writeMergeConn is only appropriate for the phantom dial.
+				refractionDialer.DialerWithLaddr = newWriteMergeDialer(
+					refractionDialer.DialerWithLaddr, false, 32)
+			}
 
 		case protocol.CONJURE_TRANSPORT_PREFIX_OSSH:
 
@@ -442,8 +449,11 @@ func dial(
 			refractionDialer.Transport = transport.ID()
 			refractionDialer.TransportConfig = config
 			refractionDialer.DisableRegistrarOverrides = disableOverrides
-			refractionDialer.DialerWithLaddr = newWriteMergeDialer(
-				refractionDialer.DialerWithLaddr, true, 64)
+			if !conjureConfig.DoDecoyRegistration {
+				// See limitation comment above.
+				refractionDialer.DialerWithLaddr = newWriteMergeDialer(
+					refractionDialer.DialerWithLaddr, true, 64)
+			}
 
 		case protocol.CONJURE_TRANSPORT_DTLS_OSSH:
 

+ 85 - 7
psiphon/common/resolver/resolver.go

@@ -41,6 +41,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 	lrucache "github.com/cognusion/go-cache-lru"
 	"github.com/miekg/dns"
+	"golang.org/x/net/idna"
 )
 
 const (
@@ -208,6 +209,23 @@ type ResolveParameters struct {
 	// specify the same seed.
 	ProtocolTransformSeed *prng.Seed
 
+	// RandomQNameCasingSeed specifies the seed for randomizing the casing of
+	// the QName (hostname) in the DNS request. If not set, the QName casing
+	// will remain unchanged. To reproduce the same random casing, use the same
+	// seed.
+	RandomQNameCasingSeed *prng.Seed
+
+	// ResponseQNameMustMatch specifies whether the response's question section
+	// must contain exactly one entry, and that entry's QName (hostname) must
+	// exactly match the QName sent in the DNS request.
+	//
+	// RFC 1035 does not specify that the question section in the response must
+	// exactly match the question section in the request, but this behavior is
+	// expected [1].
+	//
+	// [1]: https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00#section-2.2.
+	ResponseQNameMustMatch bool
+
 	// IncludeEDNS0 indicates whether to include the EDNS(0) UDP maximum
 	// response size extension in DNS requests. The resolver can handle
 	// responses larger than 512 bytes (RFC 1035 maximum) regardless of
@@ -216,6 +234,7 @@ type ResolveParameters struct {
 	IncludeEDNS0 bool
 
 	firstAttemptWithAnswer int32
+	qnameMismatches        int32
 }
 
 // GetFirstAttemptWithAnswer returns the index of the first request attempt
@@ -235,6 +254,26 @@ func (r *ResolveParameters) setFirstAttemptWithAnswer(attempt int) {
 	atomic.StoreInt32(&r.firstAttemptWithAnswer, int32(attempt))
 }
 
+// GetQNameMismatches returns, for the most recent ResolveIP call using this
+// ResolveParameters, the number of DNS requests where the response's question
+// section either:
+//   - Did not contain exactly one entry; or
+//   - Contained one entry that had a QName (hostname) that did not match the
+//     QName sent in the DNS request.
+//
+// This information is used for logging metrics.
+//
+// The caller is responsible for synchronizing use of a ResolveParameters
+// instance (e.g, use a distinct ResolveParameters per ResolveIP to ensure
+// GetQNameMismatches refers to a specific ResolveIP).
+func (r *ResolveParameters) GetQNameMismatches() int {
+	return int(atomic.LoadInt32(&r.qnameMismatches))
+}
+
+func (r *ResolveParameters) setQNameMismatches(mismatches int) {
+	atomic.StoreInt32(&r.qnameMismatches, int32(mismatches))
+}
+
 // Implementation note: Go's standard net.Resolver supports specifying a
 // custom Dial function. This could be used to implement at least a large
 // subset of the Resolver functionality on top of Go's standard library
@@ -443,6 +482,16 @@ func (r *Resolver) MakeResolveParameters(
 		}
 	}
 
+	if p.WeightedCoinFlip(parameters.DNSResolverQNameRandomizeCasingProbability) {
+		var err error
+		params.RandomQNameCasingSeed, err = prng.NewSeed()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	params.ResponseQNameMustMatch = p.WeightedCoinFlip(parameters.DNSResolverQNameMustMatchProbability)
+
 	if p.WeightedCoinFlip(parameters.DNSResolverIncludeEDNS0Probability) {
 		params.IncludeEDNS0 = true
 	}
@@ -728,9 +777,11 @@ func (r *Resolver) ResolveIP(
 
 		server := servers[index]
 
-		// Only the first attempt pair tries transforms, as it's not certain
-		// the transforms will be compatible with DNS servers.
+		// Only the first attempt pair tries techniques that may not be
+		// compatible with all DNS servers.
 		useProtocolTransform := (i == 0 && params.ProtocolTransformSpec != nil)
+		useRandomQNameCasing := (i == 0 && params.RandomQNameCasingSeed != nil)
+		responseQNameMustMatch := (i == 0 && params.ResponseQNameMustMatch)
 
 		// Send A and AAAA requests concurrently.
 		questionTypes := []resolverQuestionType{resolverQuestionTypeA, resolverQuestionTypeAAAA}
@@ -752,7 +803,7 @@ func (r *Resolver) ResolveIP(
 			inFlight += 1
 			r.updateMetricPeakInFlight(inFlight)
 
-			go func(attempt int, questionType resolverQuestionType, useProtocolTransform bool) {
+			go func(attempt int, questionType resolverQuestionType, useProtocolTransform, useRandomQNameCasing, responseQNameMustMatch bool) {
 				defer waitGroup.Done()
 
 				// Always send a result back to the main loop, even if this
@@ -834,9 +885,11 @@ func (r *Resolver) ResolveIP(
 					r.networkConfig.logWarning,
 					params,
 					useProtocolTransform,
+					useRandomQNameCasing,
 					conn,
 					questionType,
-					hostname)
+					hostname,
+					responseQNameMustMatch)
 
 				// Update the min/max RTT metric when reported (>=0) even if
 				// the result is an error; i.e., the even if there was an
@@ -880,7 +933,7 @@ func (r *Resolver) ResolveIP(
 					}
 				}
 
-			}(i+1, questionType, useProtocolTransform)
+			}(i+1, questionType, useProtocolTransform, useRandomQNameCasing, responseQNameMustMatch)
 		}
 
 		resetTimer(requestTimeout)
@@ -1472,9 +1525,11 @@ func performDNSQuery(
 	logWarning func(error),
 	params *ResolveParameters,
 	useProtocolTransform bool,
+	useRandomQNameCasing bool,
 	conn net.Conn,
 	questionType resolverQuestionType,
-	hostname string) ([]net.IP, []time.Duration, time.Duration, error) {
+	hostname string,
+	responseQNameMustMatch bool) ([]net.IP, []time.Duration, time.Duration, error) {
 
 	if useProtocolTransform {
 		if params.ProtocolTransformSpec == nil ||
@@ -1494,6 +1549,16 @@ func performDNSQuery(
 		}
 	}
 
+	// Convert to punycode.
+	hostname, err := idna.ToASCII(hostname)
+	if err != nil {
+		return nil, nil, -1, errors.Trace(err)
+	}
+
+	if useRandomQNameCasing {
+		hostname = common.ToRandomASCIICasing(hostname, params.RandomQNameCasingSeed)
+	}
+
 	// UDPSize sets the receive buffer to > 512, even when we don't include
 	// EDNS(0), which will mitigate issues with RFC 1035 non-compliant
 	// servers. See Go issue 51127.
@@ -1523,7 +1588,7 @@ func performDNSQuery(
 	startTime := time.Now()
 
 	// Send the DNS request
-	err := dnsConn.WriteMsg(request)
+	err = dnsConn.WriteMsg(request)
 	if err != nil {
 		return nil, nil, -1, errors.Trace(err)
 	}
@@ -1531,6 +1596,10 @@ func performDNSQuery(
 	// Read and process the DNS response
 	var IPs []net.IP
 	var TTLs []time.Duration
+	var qnameMismatches int
+	defer func() {
+		params.setQNameMismatches(qnameMismatches)
+	}()
 	var lastErr error
 	RTT := time.Duration(-1)
 	for {
@@ -1571,6 +1640,15 @@ func performDNSQuery(
 			continue
 		}
 
+		if len(response.Question) != 1 || response.Question[0].Name != dns.Fqdn(hostname) {
+			qnameMismatches++
+			if responseQNameMustMatch {
+				lastErr = errors.Tracef("unexpected QName")
+				logWarning(lastErr)
+				continue
+			}
+		}
+
 		// Check the RCode.
 		//
 		// For IPv4, we expect RCodeSuccess as Psiphon will typically only

+ 102 - 21
psiphon/common/resolver/resolver_test.go

@@ -80,6 +80,8 @@ func runTestMakeResolveParameters() error {
 		"DNSResolverProtocolTransformProbability":     1.0,
 		"DNSResolverProtocolTransformSpecs":           transforms.Specs{transformName: exampleTransform},
 		"DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}},
+		"DNSResolverQNameRandomizeCasingProbability":  1.0,
+		"DNSResolverQNameMustMatchProbability":        1.0,
 		"DNSResolverIncludeEDNS0Probability":          1.0,
 	}
 
@@ -132,7 +134,7 @@ func runTestMakeResolveParameters() error {
 		}
 	}
 
-	// Test: Preferred/Transform/EDNS(0)
+	// Test: Preferred/Transform/RandomQNameCasing/QNameMustMatch/EDNS(0)
 
 	paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
 
@@ -157,6 +159,8 @@ func runTestMakeResolveParameters() error {
 		resolverParams.PreferAlternateDNSServer != true ||
 		resolverParams.ProtocolTransformName != transformName ||
 		resolverParams.ProtocolTransformSpec == nil ||
+		resolverParams.RandomQNameCasingSeed == nil ||
+		resolverParams.ResponseQNameMustMatch != true ||
 		resolverParams.IncludeEDNS0 != true {
 		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
 	}
@@ -165,6 +169,8 @@ func runTestMakeResolveParameters() error {
 
 	paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
 	paramValues["DNSResolverProtocolTransformProbability"] = 0.0
+	paramValues["DNSResolverQNameRandomizeCasingProbability"] = 0.0
+	paramValues["DNSResolverQNameMustMatchProbability"] = 0.0
 	paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
 
 	_, err = params.Set("", 0, paramValues)
@@ -188,6 +194,8 @@ func runTestMakeResolveParameters() error {
 		resolverParams.PreferAlternateDNSServer != false ||
 		resolverParams.ProtocolTransformName != "" ||
 		resolverParams.ProtocolTransformSpec != nil ||
+		resolverParams.RandomQNameCasingSeed != nil ||
+		resolverParams.ResponseQNameMustMatch != false ||
 		resolverParams.IncludeEDNS0 != false {
 		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
 	}
@@ -198,14 +206,14 @@ func runTestMakeResolveParameters() error {
 func runTestResolver() error {
 
 	// noResponseServer will not respond to requests
-	noResponseServer, err := newTestDNSServer(false, false, false)
+	noResponseServer, err := newTestDNSServer(false, false, false, false)
 	if err != nil {
 		return errors.Trace(err)
 	}
 	defer noResponseServer.stop()
 
 	// invalidIPServer will respond with an invalid IP
-	invalidIPServer, err := newTestDNSServer(true, false, false)
+	invalidIPServer, err := newTestDNSServer(true, false, false, false)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -213,7 +221,7 @@ func runTestResolver() error {
 
 	// okServer will respond to correct requests (expected domain) with the
 	// correct response (expected IPv4 or IPv6 address)
-	okServer, err := newTestDNSServer(true, true, false)
+	okServer, err := newTestDNSServer(true, true, false, false)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -221,7 +229,7 @@ func runTestResolver() error {
 
 	// alternateOkServer behaves like okServer; getRequestCount is used to
 	// confirm that the alternate server was indeed used
-	alternateOkServer, err := newTestDNSServer(true, true, false)
+	alternateOkServer, err := newTestDNSServer(true, true, false, false)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -230,12 +238,18 @@ func runTestResolver() error {
 	// transformOkServer behaves like okServer but only responds if the
 	// transform was applied; other servers do not respond if the transform
 	// is applied
-	transformOkServer, err := newTestDNSServer(true, true, true)
+	transformOkServer, err := newTestDNSServer(true, true, true, false)
 	if err != nil {
 		return errors.Trace(err)
 	}
 	defer transformOkServer.stop()
 
+	randomQNameCasingOkServer, err := newTestDNSServer(true, true, false, true)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer randomQNameCasingOkServer.stop()
+
 	servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
 
 	networkConfig := &NetworkConfig{
@@ -487,7 +501,7 @@ func runTestResolver() error {
 		return errors.TraceNew("unexpected server count")
 	}
 
-	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleRealDomain)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -529,6 +543,7 @@ func runTestResolver() error {
 
 	resolver.cache.Flush()
 
+	params.AttemptsPerServer = 0
 	params.AlternateDNSServer = transformOkServer.getAddr()
 	params.PreferAlternateDNSServer = true
 
@@ -555,12 +570,67 @@ func runTestResolver() error {
 		return errors.TraceNew("unexpected transform server request count")
 	}
 
+	params.AttemptsPerServer = 1
 	params.AlternateDNSServer = ""
 	params.PreferAlternateDNSServer = false
 	params.ProtocolTransformName = ""
 	params.ProtocolTransformSpec = nil
 	params.ProtocolTransformSeed = nil
 
+	// Test: random QName (hostname) casing
+	//
+	// Note: there's a (1/2)^N chance that the QName (hostname) with randomized
+	// casing has the same casing as the input QName, where N is the number of
+	// Unicode letters in the QName. In such an event these tests will either
+	// give a false positive or false negative depending on the subtest.
+
+	if randomQNameCasingOkServer.getRequestCount() != 0 {
+		return errors.TraceNew("unexpected random QName casing server request count")
+	}
+
+	resolver.cache.Flush()
+
+	params.AttemptsPerServer = 0
+	params.AttemptsPerPreferredServer = 1
+	params.AlternateDNSServer = randomQNameCasingOkServer.getAddr()
+	params.PreferAlternateDNSServer = true
+	params.RandomQNameCasingSeed = seed
+
+	_, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	resolver.cache.Flush()
+	params.ResponseQNameMustMatch = true
+
+	_, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err == nil {
+		return errors.TraceNew("expected QName mismatch")
+	}
+
+	resolver.cache.Flush()
+	params.AlternateDNSServer = okServer.getAddr()
+
+	_, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err == nil {
+		return errors.TraceNew("expected server to not support random QName casing")
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if randomQNameCasingOkServer.getRequestCount() < 1 {
+		return errors.TraceNew("unexpected random QName casing server request count")
+	}
+
+	params.AttemptsPerServer = 1
+	params.AlternateDNSServer = ""
+	params.PreferAlternateDNSServer = false
+	params.RandomQNameCasingSeed = nil
+
 	// Test: EDNS(0)
 
 	resolver.cache.Flush()
@@ -695,7 +765,7 @@ func runTestPublicDNSServers() ([]net.IP, string, error) {
 	}
 
 	IPs, err := resolver.ResolveIP(
-		context.Background(), networkID, params, exampleDomain)
+		context.Background(), networkID, params, exampleRealDomain)
 	if err != nil {
 		return nil, "", errors.Trace(err)
 	}
@@ -728,8 +798,10 @@ func getPublicDNSServers() []string {
 	return shuffledServers
 }
 
+var exampleDomain = fmt.Sprintf("%s.example.com", prng.Base64String(32))
+
 const (
-	exampleDomain     = "example.com"
+	exampleRealDomain = "example.com"
 	exampleIPv4       = "93.184.216.34"
 	exampleIPv4CIDR   = "93.184.216.0/24"
 	exampleIPv6       = "2606:2800:220:1:248:1893:25c8:1946"
@@ -741,15 +813,16 @@ const (
 var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
 
 type testDNSServer struct {
-	respond         bool
-	validResponse   bool
-	expectTransform bool
-	addr            string
-	requestCount    int32
-	server          *dns.Server
+	respond                 bool
+	validResponse           bool
+	expectTransform         bool
+	expectRandomQNameCasing bool
+	addr                    string
+	requestCount            int32
+	server                  *dns.Server
 }
 
-func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) {
+func newTestDNSServer(respond, validResponse, expectTransform, expectRandomQNameCasing bool) (*testDNSServer, error) {
 
 	udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
 	if err != nil {
@@ -762,10 +835,11 @@ func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSSer
 	}
 
 	s := &testDNSServer{
-		respond:         respond,
-		validResponse:   validResponse,
-		expectTransform: expectTransform,
-		addr:            udpConn.LocalAddr().String(),
+		respond:                 respond,
+		validResponse:           validResponse,
+		expectTransform:         expectTransform,
+		expectRandomQNameCasing: expectRandomQNameCasing,
+		addr:                    udpConn.LocalAddr().String(),
 	}
 
 	server := &dns.Server{
@@ -792,7 +866,9 @@ func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 		return
 	}
 
-	if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) {
+	if len(r.Question) != 1 ||
+		(!s.expectRandomQNameCasing &&
+			r.Question[0].Name != dns.Fqdn(exampleDomain)) {
 		return
 	}
 
@@ -827,6 +903,11 @@ func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 		}
 	}
 
+	if s.expectRandomQNameCasing {
+		// Simulate a server that does not preserve the casing of the QName.
+		m.Question[0].Name = dns.Fqdn(exampleDomain)
+	}
+
 	w.WriteMsg(m)
 }
 

+ 23 - 0
psiphon/common/utils.go

@@ -31,6 +31,7 @@ import (
 	"math"
 	"net/url"
 	"os"
+	"strings"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -292,3 +293,25 @@ func MaxDuration(durations ...time.Duration) time.Duration {
 	}
 	return max
 }
+
+// ToRandomASCIICasing returns s with each ASCII letter randomly mapped to
+// either its upper or lower case.
+func ToRandomASCIICasing(s string, seed *prng.Seed) string {
+
+	PRNG := prng.NewPRNGWithSeed(seed)
+
+	var b strings.Builder
+	b.Grow(len(s))
+
+	for _, r := range s {
+		isLower := ('a' <= r && r <= 'z')
+		isUpper := ('A' <= r && r <= 'Z')
+		if (isLower || isUpper) && PRNG.FlipCoin() {
+			b.WriteRune(r ^ 0x20)
+		} else {
+			b.WriteRune(r)
+		}
+	}
+
+	return b.String()
+}

+ 29 - 0
psiphon/common/utils_test.go

@@ -28,6 +28,8 @@ import (
 	"strings"
 	"testing"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 )
 
 func TestGetStringSlice(t *testing.T) {
@@ -164,3 +166,30 @@ func TestSleepWithContext(t *testing.T) {
 		t.Errorf("unexpected duration: %v", duration)
 	}
 }
+
+func TestToRandomCasing(t *testing.T) {
+	s := "test.to.random.ascii.casing.aaaa.bbbb.c" // 32 Unicode letters
+
+	seed, err := prng.NewSeed()
+	if err != nil {
+		t.Errorf("NewPRNG failed: %s", err)
+	}
+
+	randomized := ToRandomASCIICasing(s, seed)
+
+	// Note: there's a (1/2)^32 chance that the randomized string has the same
+	// casing as the input string.
+	if strings.Compare(s, randomized) == 0 {
+		t.Errorf("expected random casing")
+	}
+
+	if strings.Compare(strings.ToLower(s), strings.ToLower(randomized)) != 0 {
+		t.Errorf("expected strings to be identical minus casing")
+	}
+
+	replaySameSeed := ToRandomASCIICasing(s, seed)
+
+	if strings.Compare(randomized, replaySameSeed) != 0 {
+		t.Errorf("expected randomized string with same seed to be identical")
+	}
+}

+ 35 - 0
psiphon/config.go

@@ -947,6 +947,8 @@ type Config struct {
 	DNSResolverProtocolTransformSpecs                transforms.Specs
 	DNSResolverProtocolTransformScopedSpecNames      transforms.ScopedSpecNames
 	DNSResolverProtocolTransformProbability          *float64
+	DNSResolverQNameRandomizeCasingProbability       *float64
+	DNSResolverQNameMustMatchProbability             *float64
 	DNSResolverIncludeEDNS0Probability               *float64
 	DNSResolverCacheExtensionInitialTTLMilliseconds  *int
 	DNSResolverCacheExtensionVerifiedTTLMilliseconds *int
@@ -1658,6 +1660,7 @@ func (config *Config) SetParameters(tag string, skipOnError bool, applyParameter
 	// posting notices.
 
 	config.paramsMutex.Lock()
+	tagUnchanged := tag != "" && tag == config.params.Get().Tag()
 	validationFlags := 0
 	if skipOnError {
 		validationFlags |= parameters.ValidationSkipOnError
@@ -1670,6 +1673,20 @@ func (config *Config) SetParameters(tag string, skipOnError bool, applyParameter
 	p := config.params.Get()
 	config.paramsMutex.Unlock()
 
+	// Skip emitting notices and invoking GetTacticsAppliedReceivers when the
+	// tactics tag is unchanged. The notices are redundant, and the receivers
+	// will unnecessarily reset components such as in-proxy broker clients.
+	//
+	// At this time, the GetTactics call in launchEstablishing can result in
+	// redundant SetParameters calls with an unchanged tag.
+	//
+	// As a fail safe, and since there should not be any unwanted side
+	// effects, the above params.Set is still executed even for unchanged tags.
+
+	if tagUnchanged {
+		return nil
+	}
+
 	NoticeInfo("applied %v parameters with tag '%s'", counts, tag)
 
 	// Emit certain individual parameter values for quick reference in diagnostics.
@@ -2381,6 +2398,14 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.DNSResolverProtocolTransformProbability] = *config.DNSResolverProtocolTransformProbability
 	}
 
+	if config.DNSResolverQNameRandomizeCasingProbability != nil {
+		applyParameters[parameters.DNSResolverQNameRandomizeCasingProbability] = *config.DNSResolverQNameRandomizeCasingProbability
+	}
+
+	if config.DNSResolverQNameMustMatchProbability != nil {
+		applyParameters[parameters.DNSResolverQNameMustMatchProbability] = *config.DNSResolverQNameMustMatchProbability
+	}
+
 	if config.DNSResolverIncludeEDNS0Probability != nil {
 		applyParameters[parameters.DNSResolverIncludeEDNS0Probability] = *config.DNSResolverIncludeEDNS0Probability
 	}
@@ -3289,6 +3314,16 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, *config.DNSResolverProtocolTransformProbability)
 	}
 
+	if config.DNSResolverQNameRandomizeCasingProbability != nil {
+		hash.Write([]byte("DNSResolverQNameRandomizeCasingProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.DNSResolverQNameRandomizeCasingProbability)
+	}
+
+	if config.DNSResolverQNameMustMatchProbability != nil {
+		hash.Write([]byte("DNSResolverQNameMustMatchProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.DNSResolverQNameMustMatchProbability)
+	}
+
 	if config.DNSResolverIncludeEDNS0Probability != nil {
 		hash.Write([]byte("DNSResolverIncludeEDNS0Probability"))
 		binary.Write(hash, binary.LittleEndian, *config.DNSResolverIncludeEDNS0Probability)

+ 11 - 0
psiphon/frontingDialParameters.go

@@ -518,6 +518,17 @@ func (meekDialParameters *FrontedMeekDialParameters) GetMetrics(overridePrefix s
 			logFields[prefix+"dns_transform"] = meekDialParameters.ResolveParameters.ProtocolTransformName
 		}
 
+		if meekDialParameters.ResolveParameters.RandomQNameCasingSeed != nil {
+			logFields[prefix+"dns_qname_random_casing"] = "1"
+		}
+
+		if meekDialParameters.ResolveParameters.ResponseQNameMustMatch {
+			logFields[prefix+"dns_qname_must_match"] = "1"
+		}
+
+		logFields[prefix+"dns_qname_mismatches"] = strconv.Itoa(
+			meekDialParameters.ResolveParameters.GetQNameMismatches())
+
 		logFields[prefix+"dns_attempt"] = strconv.Itoa(
 			meekDialParameters.ResolveParameters.GetFirstAttemptWithAnswer())
 	}

+ 55 - 6
psiphon/inproxy.go

@@ -148,11 +148,28 @@ func (b *InproxyBrokerClientManager) GetBrokerClient(
 		}
 	}
 
+	// Set isReuse, which will record a metric indicating if this broker
+	// client has already been used for a successful round trip, a case which
+	// should result in faster overall dials.
+	//
+	// Limitations with HasSuccess, and the resulting isReuse metric: in some
+	// cases, it's possible that the underlying TLS connection is still
+	// redialed by net/http; or it's possible that the Noise session is
+	// invalid/expired and must be reestablished; or it can be the case that
+	// a shared broker client is only partially established at this point in
+	// time.
+	//
+	// Return a shallow copy of the broker dial params in order to record the
+	// correct isReuse, which varies depending on previous use.
+
+	brokerDialParams := *b.brokerClientInstance.brokerDialParams
+	brokerDialParams.isReuse = b.brokerClientInstance.HasSuccess()
+
 	// The b.brokerClientInstance.brokerClient is wired up to refer back to
 	// b.brokerClientInstance.brokerDialParams/roundTripper, etc.
 
 	return b.brokerClientInstance.brokerClient,
-		b.brokerClientInstance.brokerDialParams,
+		&brokerDialParams,
 		nil
 }
 
@@ -286,7 +303,6 @@ type InproxyBrokerClientInstance struct {
 	brokerRootObfuscationSecret   inproxy.ObfuscationSecret
 	brokerDialParams              *InproxyBrokerDialParameters
 	replayEnabled                 bool
-	isReplay                      bool
 	roundTripper                  *InproxyBrokerRoundTripper
 	personalCompartmentIDs        []inproxy.ID
 	commonCompartmentIDs          []inproxy.ID
@@ -531,7 +547,6 @@ func NewInproxyBrokerClientInstance(
 		brokerRootObfuscationSecret: brokerRootObfuscationSecret,
 		brokerDialParams:            brokerDialParams,
 		replayEnabled:               replayEnabled,
-		isReplay:                    isReplay,
 		roundTripper:                roundTripper,
 		personalCompartmentIDs:      personalCompartmentIDs,
 		commonCompartmentIDs:        commonCompartmentIDs,
@@ -782,6 +797,15 @@ func prepareInproxyCompartmentIDs(
 	return commonCompartmentIDs, personalCompartmentIDs, nil
 }
 
+// HasSuccess indicates whether this broker client instance has completed at
+// least one successful round trip.
+func (b *InproxyBrokerClientInstance) HasSuccess() bool {
+	b.mutex.Lock()
+	defer b.mutex.Unlock()
+
+	return !b.lastSuccess.IsZero()
+}
+
 // Close closes the broker client round tripped, including closing all
 // underlying network connections, which will interrupt any in-flight round
 // trips.
@@ -856,9 +880,9 @@ func (b *InproxyBrokerClientInstance) BrokerClientRoundTripperSucceeded(roundTri
 	// Set replay or extend the broker dial parameters replay TTL after a
 	// success. With tunnel dial parameters, the replay TTL is extended after
 	// every successful tunnel connection. Since there are potentially more
-	// and more frequent broker round trips one tunnel dial, the TTL is only
-	// extended after some target duration has elapsed, to avoid excessive
-	// datastore writes.
+	// and more frequent broker round trips compared to tunnel dials, the TTL
+	// is only extended after some target duration has elapsed, to avoid
+	// excessive datastore writes.
 
 	if b.replayEnabled && now.Sub(b.lastStoreReplay) > b.replayUpdateFrequency {
 		b.brokerDialParams.LastUsedTimestamp = time.Now()
@@ -1079,6 +1103,7 @@ func (b *InproxyBrokerClientInstance) RelayedPacketRequestTimeout() time.Duratio
 type InproxyBrokerDialParameters struct {
 	brokerSpec *parameters.InproxyBrokerSpec `json:"-"`
 	isReplay   bool                          `json:"-"`
+	isReuse    bool                          `json:"-"`
 
 	LastUsedTimestamp      time.Time
 	LastUsedBrokerSpecHash []byte
@@ -1152,6 +1177,9 @@ func (brokerDialParams *InproxyBrokerDialParameters) prepareDialConfigs(
 
 	brokerDialParams.isReplay = isReplay
 
+	// brokerDialParams.isReuse is set only later, as this is a new broker
+	// client dial.
+
 	if isReplay {
 		// FrontedHTTPDialParameters
 		//
@@ -1206,6 +1234,12 @@ func (brokerDialParams *InproxyBrokerDialParameters) GetMetrics() common.LogFiel
 	}
 	logFields["inproxy_broker_is_replay"] = isReplay
 
+	isReuse := "0"
+	if brokerDialParams.isReuse {
+		isReuse = "1"
+	}
+	logFields["inproxy_broker_is_reuse"] = isReuse
+
 	return logFields
 }
 
@@ -2085,6 +2119,17 @@ func (dialParams *InproxySTUNDialParameters) GetMetrics() common.LogFields {
 			logFields["inproxy_webrtc_dns_transform"] = dialParams.ResolveParameters.ProtocolTransformName
 		}
 
+		if dialParams.ResolveParameters.RandomQNameCasingSeed != nil {
+			logFields["inproxy_webrtc_dns_qname_random_casing"] = "1"
+		}
+
+		if dialParams.ResolveParameters.ResponseQNameMustMatch {
+			logFields["inproxy_webrtc_dns_qname_must_match"] = "1"
+		}
+
+		logFields["inproxy_webrtc_dns_qname_mismatches"] = strconv.Itoa(
+			dialParams.ResolveParameters.GetQNameMismatches())
+
 		logFields["inproxy_webrtc_dns_attempt"] = strconv.Itoa(
 			dialParams.ResolveParameters.GetFirstAttemptWithAnswer())
 	}
@@ -2490,6 +2535,10 @@ func getInproxyNetworkType(networkType string) inproxy.NetworkType {
 		return inproxy.NetworkTypeWiFi
 	case "MOBILE":
 		return inproxy.NetworkTypeMobile
+	case "WIRED":
+		return inproxy.NetworkTypeWired
+	case "VPN":
+		return inproxy.NetworkTypeVPN
 	}
 
 	return inproxy.NetworkTypeUnknown

+ 3 - 0
psiphon/net.go

@@ -179,6 +179,9 @@ type HasIPv6RouteGetter interface {
 // - "WIRED" for a wired network
 // - "VPN" for a VPN network
 // - "UNKNOWN" for when the network type cannot be determined
+//
+// Note that the functions psiphon.GetNetworkType, psiphon.getInproxyNetworkType,
+// and inproxy.GetNetworkType must all be updated when new network types are added.
 type NetworkIDGetter interface {
 	GetNetworkID() string
 }

+ 2 - 0
psiphon/notice.go

@@ -616,6 +616,8 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters, pos
 			}
 
 			if postDial {
+				args = append(args, "DNSQNameMismatches", dialParams.ResolveParameters.GetQNameMismatches())
+
 				args = append(args, "DNSAttempt", dialParams.ResolveParameters.GetFirstAttemptWithAnswer())
 			}
 		}

+ 15 - 0
psiphon/server/api.go

@@ -1120,6 +1120,9 @@ var baseDialParams = []requestParamSpec{
 	{"dns_preresolved", isAnyString, requestParamOptional},
 	{"dns_preferred", isAnyString, requestParamOptional},
 	{"dns_transform", isAnyString, requestParamOptional},
+	{"dns_qname_random_casing", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"dns_qname_must_match", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"dns_qname_mismatches", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"dns_attempt", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"http_transform", isAnyString, requestParamOptional},
 	{"seed_transform", isAnyString, requestParamOptional},
@@ -1162,10 +1165,16 @@ var inproxyDialParams = []requestParamSpec{
 	{"inproxy_broker_dns_preresolved", isAnyString, requestParamOptional},
 	{"inproxy_broker_dns_preferred", isAnyString, requestParamOptional},
 	{"inproxy_broker_dns_transform", isAnyString, requestParamOptional},
+	{"inproxy_broker_dns_qname_random_casing", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"inproxy_broker_dns_qname_must_match", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"inproxy_broker_dns_qname_mismatches", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"inproxy_broker_dns_attempt", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"inproxy_webrtc_dns_preresolved", isAnyString, requestParamOptional},
 	{"inproxy_webrtc_dns_preferred", isAnyString, requestParamOptional},
 	{"inproxy_webrtc_dns_transform", isAnyString, requestParamOptional},
+	{"inproxy_broker_dns_qname_random_casing", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"inproxy_webrtc_dns_qname_must_match", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
+	{"inproxy_webrtc_dns_qname_mismatches", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"inproxy_webrtc_dns_attempt", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"inproxy_webrtc_stun_server", isAnyString, requestParamOptional},
 	{"inproxy_webrtc_stun_server_resolved_ip_address", isAnyString, requestParamOptional},
@@ -1183,6 +1192,12 @@ var inproxyDialParams = []requestParamSpec{
 	{"inproxy_webrtc_remote_ice_candidate_type", isAnyString, requestParamOptional},
 	{"inproxy_webrtc_remote_ice_candidate_is_IPv6", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 	{"inproxy_webrtc_remote_ice_candidate_port", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_dial_nat_discovery_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_dial_failed_attempts_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_dial_webrtc_ice_gathering_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_dial_broker_offer_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_dial_webrtc_connection_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"inproxy_broker_is_reuse", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
 }
 
 // baseAndDialParams adds baseDialParams and inproxyDialParams to baseParams.

+ 11 - 4
psiphon/server/server_test.go

@@ -720,12 +720,10 @@ var (
 	testSteeringIP = "1.1.1.1"
 )
 
-var serverRuns = 0
+var lastConnectedUpdateCount = 0
 
 func runServer(t *testing.T, runConfig *runServerConfig) {
 
-	serverRuns += 1
-
 	psiphonServerIPAddress := "127.0.0.1"
 	psiphonServerPort := 4000
 
@@ -1487,7 +1485,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// Test unique user counting cases.
 	var expectUniqueUser bool
-	switch serverRuns % 3 {
+	switch lastConnectedUpdateCount % 3 {
 	case 0:
 		// Mock no last_connected.
 		psiphon.SetKeyValue("lastConnected", "")
@@ -1665,6 +1663,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		}
 		waitOnNotification(t, tunnelsEstablished, timeoutSignal, "tunnel established timeout exceeded")
 		waitOnNotification(t, homepageReceived, timeoutSignal, "homepage received timeout exceeded")
+
+		// The tunnel connected, so the local last_connected has been updated.
+		lastConnectedUpdateCount += 1
 	}
 
 	if runConfig.doChangeBytesConfig {
@@ -2536,6 +2537,8 @@ func checkExpectedServerTunnelLogFields(
 
 			// Fields sent by the client
 
+			"inproxy_broker_is_replay",
+			"inproxy_broker_is_reuse",
 			"inproxy_broker_transport",
 			"inproxy_broker_fronting_provider_id",
 			"inproxy_broker_dial_address",
@@ -2545,6 +2548,10 @@ func checkExpectedServerTunnelLogFields(
 			"inproxy_webrtc_padded_messages_received",
 			"inproxy_webrtc_decoy_messages_sent",
 			"inproxy_webrtc_decoy_messages_received",
+
+			"inproxy_dial_webrtc_ice_gathering_duration",
+			"inproxy_dial_broker_offer_duration",
+			"inproxy_dial_webrtc_connection_duration",
 		} {
 			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
 				return fmt.Errorf("missing expected field '%s'", name)

+ 11 - 0
psiphon/serverApi.go

@@ -1278,6 +1278,17 @@ func getBaseAPIParameters(
 				params["dns_transform"] = dialParams.ResolveParameters.ProtocolTransformName
 			}
 
+			if dialParams.ResolveParameters.RandomQNameCasingSeed != nil {
+				params["dns_qname_random_casing"] = "1"
+			}
+
+			if dialParams.ResolveParameters.ResponseQNameMustMatch {
+				params["dns_qname_must_match"] = "1"
+			}
+
+			params["dns_qname_mismatches"] = strconv.Itoa(
+				dialParams.ResolveParameters.GetQNameMismatches())
+
 			params["dns_attempt"] = strconv.Itoa(
 				dialParams.ResolveParameters.GetFirstAttemptWithAnswer())
 		}

+ 6 - 3
psiphon/utils.go

@@ -288,14 +288,17 @@ func GetNetworkType(networkID string) string {
 	// check for and use the common network type prefixes currently used in
 	// NetworkIDGetter implementations.
 
-	if strings.HasPrefix(networkID, "VPN") {
-		return "VPN"
-	}
 	if strings.HasPrefix(networkID, "WIFI") {
 		return "WIFI"
 	}
 	if strings.HasPrefix(networkID, "MOBILE") {
 		return "MOBILE"
 	}
+	if strings.HasPrefix(networkID, "WIRED") {
+		return "WIRED"
+	}
+	if strings.HasPrefix(networkID, "VPN") {
+		return "VPN"
+	}
 	return "UNKNOWN"
 }