Просмотр исходного кода

Merge pull request #597 from rod-hynes/master

Performance and protocol selection enhancements
Rod Hynes 4 лет назад
Родитель
Сommit
60eeb9c78c

+ 5 - 7
psiphon/common/api.go

@@ -30,13 +30,11 @@ type APIParameterValidator func(APIParameters) error
 
 // GeoIPData is type-compatible with psiphon/server.GeoIPData.
 type GeoIPData struct {
-	Country           string
-	City              string
-	ISP               string
-	ASN               string
-	ASO               string
-	HasDiscoveryValue bool
-	DiscoveryValue    int
+	Country string
+	City    string
+	ISP     string
+	ASN     string
+	ASO     string
 }
 
 // APIParameterLogFieldFormatter is a function that returns formatted

+ 20 - 1
psiphon/common/parameters/parameters.go

@@ -231,6 +231,7 @@ const (
 	ReplayAPIRequestPadding                          = "ReplayAPIRequestPadding"
 	ReplayLaterRoundMoveToFrontProbability           = "ReplayLaterRoundMoveToFrontProbability"
 	ReplayRetainFailedProbability                    = "ReplayRetainFailedProbability"
+	ReplayHoldOffTunnel                              = "ReplayHoldOffTunnel"
 	APIRequestUpstreamPaddingMinBytes                = "APIRequestUpstreamPaddingMinBytes"
 	APIRequestUpstreamPaddingMaxBytes                = "APIRequestUpstreamPaddingMaxBytes"
 	APIRequestDownstreamPaddingMinBytes              = "APIRequestDownstreamPaddingMinBytes"
@@ -286,6 +287,14 @@ const (
 	CustomHostNameRegexes                            = "CustomHostNameRegexes"
 	CustomHostNameProbability                        = "CustomHostNameProbability"
 	CustomHostNameLimitProtocols                     = "CustomHostNameLimitProtocols"
+	HoldOffTunnelMinDuration                         = "HoldOffTunnelMinDuration"
+	HoldOffTunnelMaxDuration                         = "HoldOffTunnelMaxDuration"
+	HoldOffTunnelProtocols                           = "HoldOffTunnelProtocols"
+	HoldOffTunnelFrontingProviderIDs                 = "HoldOffTunnelFrontingProviderIDs"
+	HoldOffTunnelProbability                         = "HoldOffTunnelProbability"
+	RestrictFrontingProviderIDs                      = "RestrictFrontingProviderIDs"
+	RestrictFrontingProviderIDsServerProbability     = "RestrictFrontingProviderIDsServerProbability"
+	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
 )
 
 const (
@@ -533,6 +542,7 @@ var defaultParameters = map[string]struct {
 	ReplayAPIRequestPadding:                {value: true},
 	ReplayLaterRoundMoveToFrontProbability: {value: 0.0, minimum: 0.0},
 	ReplayRetainFailedProbability:          {value: 0.5, minimum: 0.0},
+	ReplayHoldOffTunnel:                    {value: true},
 
 	APIRequestUpstreamPaddingMinBytes:   {value: 0, minimum: 0},
 	APIRequestUpstreamPaddingMaxBytes:   {value: 1024, minimum: 0},
@@ -600,6 +610,16 @@ var defaultParameters = map[string]struct {
 	CustomHostNameRegexes:        {value: RegexStrings{}},
 	CustomHostNameProbability:    {value: 0.0, minimum: 0.0},
 	CustomHostNameLimitProtocols: {value: protocol.TunnelProtocols{}},
+
+	HoldOffTunnelMinDuration:         {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffTunnelMaxDuration:         {value: time.Duration(0), minimum: time.Duration(0)},
+	HoldOffTunnelProtocols:           {value: protocol.TunnelProtocols{}},
+	HoldOffTunnelFrontingProviderIDs: {value: []string{}},
+	HoldOffTunnelProbability:         {value: 0.0, minimum: 0.0},
+
+	RestrictFrontingProviderIDs:                  {value: []string{}},
+	RestrictFrontingProviderIDsServerProbability: {value: 0.0, minimum: 0.0, flags: serverSideOnly},
+	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used
@@ -826,7 +846,6 @@ func (p *Parameters) Set(
 					}
 				}
 			case protocol.LabeledTLSProfiles:
-
 				if skipOnError {
 					newValue = v.PruneInvalid(customTLSProfileNames)
 				} else {

+ 18 - 4
psiphon/common/protocol/protocol.go

@@ -253,14 +253,27 @@ func TunnelProtocolMayUseServerPacketManipulation(protocol string) bool {
 		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
 }
 
-func UseClientTunnelProtocol(
+func IsValidClientTunnelProtocol(
 	clientProtocol string,
+	listenerProtocol string,
 	serverProtocols TunnelProtocols) bool {
 
+	if !common.Contains(serverProtocols, clientProtocol) {
+		return false
+	}
+
+	// If the client reports the same tunnel protocol as the listener, the value
+	// is valid.
+
+	if clientProtocol == listenerProtocol {
+		return true
+	}
+
 	// When the server is running multiple fronted protocols, and the client
-	// reports a fronted protocol, use the client's reported tunnel protocol
-	// since some CDNs forward several protocols to the same server port; in this
-	// case the server port is not sufficient to distinguish these protocols.
+	// reports a fronted protocol, the client's reported tunnel protocol is
+	// presumed to be valid since some CDNs forward several protocols to the same
+	// server port; in this case the listener port is not sufficient to
+	// distinguish these protocols.
 
 	if !TunnelProtocolUsesFrontedMeek(clientProtocol) {
 		return false
@@ -445,6 +458,7 @@ type HandshakeResponse struct {
 	HttpsRequestRegexes      []map[string]string `json:"https_request_regexes"`
 	EncodedServerList        []string            `json:"encoded_server_list"`
 	ClientRegion             string              `json:"client_region"`
+	ClientAddress            string              `json:"client_address"`
 	ServerTimestamp          string              `json:"server_timestamp"`
 	ActiveAuthorizationIDs   []string            `json:"active_authorization_ids"`
 	TacticsPayload           json.RawMessage     `json:"tactics_payload"`

+ 246 - 86
psiphon/common/tun/tun.go

@@ -210,6 +210,12 @@ type ServerConfig struct {
 	// IPv6 DNS traffic. It functions like GetDNSResolverIPv4Addresses.
 	GetDNSResolverIPv6Addresses func() []net.IP
 
+	// EnableDNSFlowTracking specifies whether to apply flow tracking to DNS
+	// flows, as required for DNS quality metrics. Typically there are many
+	// short-lived DNS flows to track and each tracked flow adds some overhead,
+	// so this defaults to off.
+	EnableDNSFlowTracking bool
+
 	// DownstreamPacketQueueSize specifies the size of the downstream
 	// packet queue. The packet tunnel server multiplexes all client
 	// packets through a single tun device, so when a packet is read,
@@ -357,6 +363,12 @@ type MetricsUpdater func(
 	TCPApplicationBytesDown, TCPApplicationBytesUp,
 	UDPApplicationBytesDown, UDPApplicationBytesUp int64)
 
+// DNSQualityReporter is a function which receives a DNS quality report:
+// whether a DNS request received a reponse, the elapsed time, and the
+// resolver used.
+type DNSQualityReporter func(
+	receivedResponse bool, requestDuration time.Duration, resolverIP net.IP)
+
 // ClientConnected handles new client connections, creating or resuming
 // a session and returns with client packet handlers running.
 //
@@ -394,7 +406,8 @@ func (server *Server) ClientConnected(
 	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
 	checkAllowedDomainFunc AllowedDomainChecker,
 	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
-	metricsUpdater MetricsUpdater) error {
+	metricsUpdater MetricsUpdater,
+	dnsQualityReporter DNSQualityReporter) error {
 
 	// It's unusual to call both sync.WaitGroup.Add() _and_ Done() in the same
 	// goroutine. There's no other place to call Add() since ClientConnected is
@@ -452,11 +465,21 @@ func (server *Server) ClientConnected(
 			lastActivity:             int64(monotime.Now()),
 			sessionID:                sessionID,
 			metrics:                  new(packetMetrics),
+			enableDNSFlowTracking:    server.config.EnableDNSFlowTracking,
 			DNSResolverIPv4Addresses: append([]net.IP(nil), DNSResolverIPv4Addresses...),
 			DNSResolverIPv6Addresses: append([]net.IP(nil), server.config.GetDNSResolverIPv6Addresses()...),
 			workers:                  new(sync.WaitGroup),
 		}
 
+		// One-time, for this session, random resolver selection for TCP transparent
+		// DNS forwarding. See comment in processPacket.
+		if len(clientSession.DNSResolverIPv4Addresses) > 0 {
+			clientSession.TCPDNSResolverIPv4Index = prng.Intn(len(clientSession.DNSResolverIPv4Addresses))
+		}
+		if len(clientSession.DNSResolverIPv6Addresses) > 0 {
+			clientSession.TCPDNSResolverIPv6Index = prng.Intn(len(clientSession.DNSResolverIPv6Addresses))
+		}
+
 		// allocateIndex initializes session.index, session.assignedIPv4Address,
 		// and session.assignedIPv6Address; and updates server.indexToSession and
 		// server.sessionIDToIndex.
@@ -479,7 +502,8 @@ func (server *Server) ClientConnected(
 		checkAllowedUDPPortFunc,
 		checkAllowedDomainFunc,
 		flowActivityUpdaterMaker,
-		metricsUpdater)
+		metricsUpdater,
+		dnsQualityReporter)
 
 	return nil
 }
@@ -518,7 +542,8 @@ func (server *Server) resumeSession(
 	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
 	checkAllowedDomainFunc AllowedDomainChecker,
 	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
-	metricsUpdater MetricsUpdater) {
+	metricsUpdater MetricsUpdater,
+	dnsQualityReporter DNSQualityReporter) {
 
 	session.mutex.Lock()
 	defer session.mutex.Unlock()
@@ -560,6 +585,8 @@ func (server *Server) resumeSession(
 
 	session.setMetricsUpdater(&metricsUpdater)
 
+	session.setDNSQualityReporter(&dnsQualityReporter)
+
 	session.channel = channel
 
 	// Parent context is not server.runContext so that session workers
@@ -687,6 +714,9 @@ func (server *Server) removeSession(session *session) {
 	server.sessionIDToIndex.Delete(session.sessionID)
 	server.indexToSession.Delete(session.index)
 	server.interruptSession(session)
+
+	// Delete flows to ensure any pending flow metrics are reported.
+	session.deleteFlows()
 }
 
 func (server *Server) runOrphanMetricsCheckpointer() {
@@ -1071,22 +1101,26 @@ type session struct {
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	lastActivity             int64
 	lastFlowReapIndex        int64
+	downstreamPackets        unsafe.Pointer
 	checkAllowedTCPPortFunc  unsafe.Pointer
 	checkAllowedUDPPortFunc  unsafe.Pointer
 	checkAllowedDomainFunc   unsafe.Pointer
 	flowActivityUpdaterMaker unsafe.Pointer
 	metricsUpdater           unsafe.Pointer
-	downstreamPackets        unsafe.Pointer
+	dnsQualityReporter       unsafe.Pointer
 
 	allowBogons              bool
 	metrics                  *packetMetrics
 	sessionID                string
 	index                    int32
+	enableDNSFlowTracking    bool
 	DNSResolverIPv4Addresses []net.IP
+	TCPDNSResolverIPv4Index  int
 	assignedIPv4Address      net.IP
 	setOriginalIPv4Address   int32
 	originalIPv4Address      net.IP
 	DNSResolverIPv6Addresses []net.IP
+	TCPDNSResolverIPv6Index  int
 	assignedIPv6Address      net.IP
 	setOriginalIPv6Address   int32
 	originalIPv6Address      net.IP
@@ -1138,6 +1172,14 @@ func (session *session) getOriginalIPv6Address() net.IP {
 	return session.originalIPv6Address
 }
 
+func (session *session) setDownstreamPackets(p *PacketQueue) {
+	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(p))
+}
+
+func (session *session) getDownstreamPackets() *PacketQueue {
+	return (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
+}
+
 func (session *session) setCheckAllowedTCPPortFunc(p *AllowedPortChecker) {
 	atomic.StorePointer(&session.checkAllowedTCPPortFunc, unsafe.Pointer(p))
 }
@@ -1198,12 +1240,16 @@ func (session *session) getMetricsUpdater() MetricsUpdater {
 	return *p
 }
 
-func (session *session) setDownstreamPackets(p *PacketQueue) {
-	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(p))
+func (session *session) setDNSQualityReporter(p *DNSQualityReporter) {
+	atomic.StorePointer(&session.dnsQualityReporter, unsafe.Pointer(p))
 }
 
-func (session *session) getDownstreamPackets() *PacketQueue {
-	return (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
+func (session *session) getDNSQualityReporter() DNSQualityReporter {
+	p := (*DNSQualityReporter)(atomic.LoadPointer(&session.dnsQualityReporter))
+	if p == nil {
+		return nil
+	}
+	return *p
 }
 
 // flowID identifies an IP traffic flow using the conventional
@@ -1249,14 +1295,23 @@ type flowState struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
-	lastUpstreamPacketTime   int64
-	lastDownstreamPacketTime int64
-	activityUpdaters         []FlowActivityUpdater
+	firstUpstreamPacketTime   int64
+	lastUpstreamPacketTime    int64
+	firstDownstreamPacketTime int64
+	lastDownstreamPacketTime  int64
+	isDNS                     bool
+	dnsQualityReporter        DNSQualityReporter
+	activityUpdaters          []FlowActivityUpdater
 }
 
 func (flowState *flowState) expired(idleExpiry time.Duration) bool {
 	now := monotime.Now()
-	return (now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastUpstreamPacketTime))) > idleExpiry) ||
+
+	// Traffic in either direction keeps the flow alive. Initially, only one of
+	// lastUpstreamPacketTime or lastDownstreamPacketTime will be set by
+	// startTrackingFlow, and the other value will be 0 and evaluate as expired.
+
+	return (now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastUpstreamPacketTime))) > idleExpiry) &&
 		(now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastDownstreamPacketTime))) > idleExpiry)
 }
 
@@ -1271,7 +1326,7 @@ func (session *session) isTrackingFlow(ID flowID) bool {
 
 	// Check if flow is expired but not yet reaped.
 	if flowState.expired(FLOW_IDLE_EXPIRY) {
-		session.flows.Delete(ID)
+		session.deleteFlow(ID, flowState)
 		return false
 	}
 
@@ -1285,6 +1340,7 @@ func (session *session) isTrackingFlow(ID flowID) bool {
 // - one-time permissions checks for a flow
 // - OSLs
 // - domain bytes transferred [TODO]
+// - DNS quality metrics
 //
 // The applicationData from the first packet in the flow is
 // inspected to determine any associated hostname, using HTTP or
@@ -1305,7 +1361,10 @@ func (session *session) isTrackingFlow(ID flowID) bool {
 // startTrackingFlow may be called from concurrent goroutines; if
 // the flow is already tracked, it is simply updated.
 func (session *session) startTrackingFlow(
-	ID flowID, direction packetDirection, applicationData []byte) {
+	ID flowID,
+	direction packetDirection,
+	applicationData []byte,
+	isDNS bool) {
 
 	now := int64(monotime.Now())
 
@@ -1334,12 +1393,16 @@ func (session *session) startTrackingFlow(
 	}
 
 	flowState := &flowState{
-		activityUpdaters: activityUpdaters,
+		isDNS:              isDNS,
+		activityUpdaters:   activityUpdaters,
+		dnsQualityReporter: session.getDNSQualityReporter(),
 	}
 
 	if direction == packetDirectionServerUpstream {
+		flowState.firstUpstreamPacketTime = now
 		flowState.lastUpstreamPacketTime = now
 	} else {
+		flowState.firstDownstreamPacketTime = now
 		flowState.lastDownstreamPacketTime = now
 	}
 
@@ -1350,7 +1413,9 @@ func (session *session) startTrackingFlow(
 }
 
 func (session *session) updateFlow(
-	ID flowID, direction packetDirection, applicationData []byte) {
+	ID flowID,
+	direction packetDirection,
+	applicationData []byte) {
 
 	f, ok := session.flows.Load(ID)
 	if !ok {
@@ -1366,10 +1431,16 @@ func (session *session) updateFlow(
 
 	if direction == packetDirectionServerUpstream {
 		upstreamBytes = int64(len(applicationData))
+
+		atomic.CompareAndSwapInt64(&flowState.firstUpstreamPacketTime, 0, now)
+
 		atomic.StoreInt64(&flowState.lastUpstreamPacketTime, now)
+
 	} else {
 		downstreamBytes = int64(len(applicationData))
 
+		atomic.CompareAndSwapInt64(&flowState.firstDownstreamPacketTime, 0, now)
+
 		// Follows common.ActivityMonitoredConn semantics, where
 		// duration is updated only for downstream activity. This
 		// is intened to produce equivalent behaviour for port
@@ -1384,17 +1455,67 @@ func (session *session) updateFlow(
 	}
 }
 
+// deleteFlow stops tracking a flow and logs any outstanding metrics.
+// flowState is passed in to avoid duplicating the lookup that all callers
+// have already performed.
+func (session *session) deleteFlow(ID flowID, flowState *flowState) {
+
+	if flowState.isDNS {
+
+		dnsStartTime := monotime.Time(
+			atomic.LoadInt64(&flowState.firstUpstreamPacketTime))
+
+		if dnsStartTime > 0 {
+
+			// Record DNS quality metrics using a heuristic: if a packet was sent and
+			// then a packet was received, assume the DNS request successfully received
+			// a valid response; failure occurs when the resolver fails to provide a
+			// response; a "no such host" response is still a success. Limitations: we
+			// assume a resolver will not respond when, e.g., rate limiting; we ignore
+			// subsequent requests made via the same UDP/TCP flow; deleteFlow may be
+			// called only after the flow has expired, which adds some delay to the
+			// recording of the DNS metric.
+
+			dnsEndTime := monotime.Time(
+				atomic.LoadInt64(&flowState.firstDownstreamPacketTime))
+
+			dnsSuccess := true
+			if dnsEndTime == 0 {
+				dnsSuccess = false
+				dnsEndTime = monotime.Now()
+			}
+
+			resolveElapsedTime := dnsEndTime.Sub(dnsStartTime)
+
+			flowState.dnsQualityReporter(
+				dnsSuccess,
+				resolveElapsedTime,
+				net.IP(ID.upstreamIPAddress[:]))
+		}
+	}
+
+	session.flows.Delete(ID)
+}
+
 // reapFlows removes expired idle flows.
 func (session *session) reapFlows() {
 	session.flows.Range(func(key, value interface{}) bool {
 		flowState := value.(*flowState)
 		if flowState.expired(FLOW_IDLE_EXPIRY) {
-			session.flows.Delete(key)
+			session.deleteFlow(key.(flowID), flowState)
 		}
 		return true
 	})
 }
 
+// deleteFlows deletes all flows.
+func (session *session) deleteFlows() {
+	session.flows.Range(func(key, value interface{}) bool {
+		session.deleteFlow(key.(flowID), value.(*flowState))
+		return true
+	})
+}
+
 type packetMetrics struct {
 	upstreamRejectReasons   [packetRejectReasonCount]int64
 	downstreamRejectReasons [packetRejectReasonCount]int64
@@ -2287,8 +2408,8 @@ func processPacket(
 	// Check if the packet qualifies for transparent DNS rewriting
 	//
 	// - Both TCP and UDP DNS packets may qualify
-	// - Transparent DNS flows are not tracked, as most DNS
-	//   resolutions are very-short lived exchanges
+	// - Unless configured, transparent DNS flows are not tracked,
+	//   as most DNS resolutions are very-short lived exchanges
 	// - The traffic rules checks are bypassed, since transparent
 	//   DNS is essential
 
@@ -2301,7 +2422,9 @@ func processPacket(
 			// will be rewritten to go to one of the server's resolvers.
 
 			if destinationPort == portNumberDNS {
-				if version == 4 && destinationIPAddress.Equal(transparentDNSResolverIPv4Address) {
+				if version == 4 &&
+					destinationIPAddress.Equal(transparentDNSResolverIPv4Address) {
+
 					numResolvers := len(session.DNSResolverIPv4Addresses)
 					if numResolvers > 0 {
 						doTransparentDNS = true
@@ -2310,7 +2433,9 @@ func processPacket(
 						return false
 					}
 
-				} else if version == 6 && destinationIPAddress.Equal(transparentDNSResolverIPv6Address) {
+				} else if version == 6 &&
+					destinationIPAddress.Equal(transparentDNSResolverIPv6Address) {
+
 					numResolvers := len(session.DNSResolverIPv6Addresses)
 					if numResolvers > 0 {
 						doTransparentDNS = true
@@ -2372,9 +2497,82 @@ func processPacket(
 		}
 	}
 
+	// Apply rewrites before determining flow ID to ensure that corresponding up-
+	// and downstream flows yield the same flow ID.
+
+	var rewriteSourceIPAddress, rewriteDestinationIPAddress net.IP
+
+	if direction == packetDirectionServerUpstream {
+
+		// Store original source IP address to be replaced in
+		// downstream rewriting.
+
+		if version == 4 {
+			session.setOriginalIPv4AddressIfNotSet(sourceIPAddress)
+			rewriteSourceIPAddress = session.assignedIPv4Address
+		} else { // version == 6
+			session.setOriginalIPv6AddressIfNotSet(sourceIPAddress)
+			rewriteSourceIPAddress = session.assignedIPv6Address
+		}
+
+		// Rewrite DNS packets destinated for the transparent DNS target addresses
+		// to go to one of the server's resolvers. This random selection uses
+		// math/rand to minimize overhead.
+		//
+		// Limitation: TCP packets are always assigned to the same resolver, as
+		// currently there is no method for tracking the assigned resolver per TCP
+		// flow.
+
+		if doTransparentDNS {
+			if version == 4 {
+
+				index := session.TCPDNSResolverIPv4Index
+				if protocol == internetProtocolUDP {
+					index = rand.Intn(len(session.DNSResolverIPv4Addresses))
+				}
+				rewriteDestinationIPAddress = session.DNSResolverIPv4Addresses[index]
+
+			} else { // version == 6
+
+				index := session.TCPDNSResolverIPv6Index
+				if protocol == internetProtocolUDP {
+					index = rand.Intn(len(session.DNSResolverIPv6Addresses))
+				}
+				rewriteDestinationIPAddress = session.DNSResolverIPv6Addresses[index]
+			}
+		}
+
+	} else if direction == packetDirectionServerDownstream {
+
+		// Destination address will be original source address.
+
+		if version == 4 {
+			rewriteDestinationIPAddress = session.getOriginalIPv4Address()
+		} else { // version == 6
+			rewriteDestinationIPAddress = session.getOriginalIPv6Address()
+		}
+
+		if rewriteDestinationIPAddress == nil {
+			metrics.rejectedPacket(direction, packetRejectNoOriginalAddress)
+			return false
+		}
+
+		// Rewrite source address of packets from servers' resolvers
+		// to transparent DNS target address.
+
+		if doTransparentDNS {
+
+			if version == 4 {
+				rewriteSourceIPAddress = transparentDNSResolverIPv4Address
+			} else { // version == 6
+				rewriteSourceIPAddress = transparentDNSResolverIPv6Address
+			}
+		}
+	}
+
 	// Check if flow is tracked before checking traffic permission
 
-	doFlowTracking := !doTransparentDNS && isServer
+	doFlowTracking := isServer && (!doTransparentDNS || session.enableDNSFlowTracking)
 
 	// TODO: verify this struct is stack allocated
 	var ID flowID
@@ -2384,12 +2582,32 @@ func processPacket(
 	if doFlowTracking {
 
 		if direction == packetDirectionServerUpstream {
-			ID.set(
-				sourceIPAddress, sourcePort, destinationIPAddress, destinationPort, protocol)
+
+			// Reflect rewrites in the upstream case and don't reflect rewrites in the
+			// following downstream case: all flow IDs are in the upstream space, with
+			// the assigned private IP for the client and, in the case of DNS, the
+			// actual resolver IP.
+
+			srcIP := sourceIPAddress
+			if rewriteSourceIPAddress != nil {
+				srcIP = rewriteSourceIPAddress
+			}
+
+			destIP := destinationIPAddress
+			if rewriteDestinationIPAddress != nil {
+				destIP = rewriteDestinationIPAddress
+			}
+
+			ID.set(srcIP, sourcePort, destIP, destinationPort, protocol)
 
 		} else if direction == packetDirectionServerDownstream {
+
 			ID.set(
-				destinationIPAddress, destinationPort, sourceIPAddress, sourcePort, protocol)
+				destinationIPAddress,
+				destinationPort,
+				sourceIPAddress,
+				sourcePort,
+				protocol)
 		}
 
 		isTrackingFlow = session.isTrackingFlow(ID)
@@ -2477,68 +2695,10 @@ func processPacket(
 		}
 	}
 
-	// Configure rewriting.
+	// Apply packet rewrites. IP (v4 only) and TCP/UDP all have packet
+	// checksums which are updated to relect the rewritten headers.
 
 	var checksumAccumulator int32
-	var rewriteSourceIPAddress, rewriteDestinationIPAddress net.IP
-
-	if direction == packetDirectionServerUpstream {
-
-		// Store original source IP address to be replaced in
-		// downstream rewriting.
-
-		if version == 4 {
-			session.setOriginalIPv4AddressIfNotSet(sourceIPAddress)
-			rewriteSourceIPAddress = session.assignedIPv4Address
-		} else { // version == 6
-			session.setOriginalIPv6AddressIfNotSet(sourceIPAddress)
-			rewriteSourceIPAddress = session.assignedIPv6Address
-		}
-
-		// Rewrite DNS packets destinated for the transparent DNS target
-		// addresses to go to one of the server's resolvers.
-
-		if doTransparentDNS {
-
-			if version == 4 {
-				rewriteDestinationIPAddress = session.DNSResolverIPv4Addresses[rand.Intn(
-					len(session.DNSResolverIPv4Addresses))]
-			} else { // version == 6
-				rewriteDestinationIPAddress = session.DNSResolverIPv6Addresses[rand.Intn(
-					len(session.DNSResolverIPv6Addresses))]
-			}
-		}
-
-	} else if direction == packetDirectionServerDownstream {
-
-		// Destination address will be original source address.
-
-		if version == 4 {
-			rewriteDestinationIPAddress = session.getOriginalIPv4Address()
-		} else { // version == 6
-			rewriteDestinationIPAddress = session.getOriginalIPv6Address()
-		}
-
-		if rewriteDestinationIPAddress == nil {
-			metrics.rejectedPacket(direction, packetRejectNoOriginalAddress)
-			return false
-		}
-
-		// Rewrite source address  of packets from servers' resolvers
-		// to transparent DNS target address.
-
-		if doTransparentDNS {
-
-			if version == 4 {
-				rewriteSourceIPAddress = transparentDNSResolverIPv4Address
-			} else { // version == 6
-				rewriteSourceIPAddress = transparentDNSResolverIPv6Address
-			}
-		}
-	}
-
-	// Apply rewrites. IP (v4 only) and TCP/UDP all have packet
-	// checksums which are updated to relect the rewritten headers.
 
 	if rewriteSourceIPAddress != nil {
 		checksumAccumulate(sourceIPAddress, false, &checksumAccumulator)
@@ -2570,7 +2730,7 @@ func processPacket(
 
 	if doFlowTracking {
 		if !isTrackingFlow {
-			session.startTrackingFlow(ID, direction, applicationData)
+			session.startTrackingFlow(ID, direction, applicationData, doTransparentDNS)
 		} else {
 			session.updateFlow(ID, direction, applicationData)
 		}

+ 4 - 1
psiphon/common/tun/tun_test.go

@@ -409,6 +409,8 @@ func (server *testServer) run() {
 			checkAllowedPortFunc := func(net.IP, int) bool { return true }
 			checkAllowedDomainFunc := func(string) bool { return true }
 
+			dnsQualityReporter := func(_ bool, _ time.Duration, _ net.IP) {}
+
 			server.tunServer.ClientConnected(
 				sessionID,
 				signalConn,
@@ -416,7 +418,8 @@ func (server *testServer) run() {
 				checkAllowedPortFunc,
 				checkAllowedDomainFunc,
 				server.updaterMaker,
-				server.metricsUpdater)
+				server.metricsUpdater,
+				dnsQualityReporter)
 
 			signalConn.Wait()
 

+ 12 - 0
psiphon/common/utils.go

@@ -22,6 +22,7 @@ package common
 import (
 	"bytes"
 	"compress/zlib"
+	"context"
 	"crypto/rand"
 	std_errors "errors"
 	"fmt"
@@ -226,3 +227,14 @@ func SafeParseRequestURI(rawurl string) (*url.URL, error) {
 	}
 	return parsedURL, err
 }
+
+// SleepWithContext returns after the specified duration or once the input ctx
+// is done, whichever is first.
+func SleepWithContext(ctx context.Context, duration time.Duration) {
+	timer := time.NewTimer(duration)
+	defer timer.Stop()
+	select {
+	case <-timer.C:
+	case <-ctx.Done():
+	}
+}

+ 21 - 0
psiphon/common/utils_test.go

@@ -21,11 +21,13 @@ package common
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"net/url"
 	"reflect"
 	"strings"
 	"testing"
+	"time"
 )
 
 func TestGetStringSlice(t *testing.T) {
@@ -142,3 +144,22 @@ func TestSafeParseRequestURI(t *testing.T) {
 		t.Error("URL in error string")
 	}
 }
+
+func TestSleepWithContext(t *testing.T) {
+
+	start := time.Now()
+	SleepWithContext(context.Background(), 2*time.Millisecond)
+	duration := time.Since(start)
+	if duration/time.Millisecond != 2 {
+		t.Errorf("unexpected duration: %v", duration)
+	}
+
+	start = time.Now()
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancelFunc()
+	SleepWithContext(ctx, 2*time.Millisecond)
+	duration = time.Since(start)
+	if duration/time.Millisecond != 1 {
+		t.Errorf("unexpected duration: %v", duration)
+	}
+}

+ 86 - 0
psiphon/config.go

@@ -437,6 +437,10 @@ type Config struct {
 	// EmitServerAlerts indicates whether to emit notices for server alerts.
 	EmitServerAlerts bool
 
+	// EmitClientAddress indicates whether to emit the client's public network
+	// address, IP and port, as seen by the server.
+	EmitClientAddress bool
+
 	// RateLimits specify throttling configuration for the tunnel.
 	RateLimits common.RateLimits
 
@@ -731,6 +735,19 @@ type Config struct {
 	ConjureDecoyRegistrarMinDelayMilliseconds *int
 	ConjureDecoyRegistrarMaxDelayMilliseconds *int
 
+	// HoldOffTunnelMinDurationMilliseconds and other HoldOffTunnel fields are
+	// for testing purposes.
+	HoldOffTunnelMinDurationMilliseconds *int
+	HoldOffTunnelMaxDurationMilliseconds *int
+	HoldOffTunnelProtocols               []string
+	HoldOffTunnelFrontingProviderIDs     []string
+	HoldOffTunnelProbability             *float64
+
+	// RestrictFrontingProviderIDs and other RestrictFrontingProviderIDs fields
+	// are for testing purposes.
+	RestrictFrontingProviderIDs                  []string
+	RestrictFrontingProviderIDsClientProbability *float64
+
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
 	//
@@ -1668,6 +1685,34 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ConjureDecoyRegistrarMaxDelay] = fmt.Sprintf("%dms", *config.ConjureDecoyRegistrarMaxDelayMilliseconds)
 	}
 
+	if config.HoldOffTunnelMinDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffTunnelMinDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelMinDurationMilliseconds)
+	}
+
+	if config.HoldOffTunnelMaxDurationMilliseconds != nil {
+		applyParameters[parameters.HoldOffTunnelMaxDuration] = fmt.Sprintf("%dms", *config.HoldOffTunnelMaxDurationMilliseconds)
+	}
+
+	if len(config.HoldOffTunnelProtocols) > 0 {
+		applyParameters[parameters.HoldOffTunnelProtocols] = protocol.TunnelProtocols(config.HoldOffTunnelProtocols)
+	}
+
+	if len(config.HoldOffTunnelFrontingProviderIDs) > 0 {
+		applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = config.HoldOffTunnelFrontingProviderIDs
+	}
+
+	if config.HoldOffTunnelProbability != nil {
+		applyParameters[parameters.HoldOffTunnelProbability] = *config.HoldOffTunnelProbability
+	}
+
+	if len(config.RestrictFrontingProviderIDs) > 0 {
+		applyParameters[parameters.RestrictFrontingProviderIDs] = config.RestrictFrontingProviderIDs
+	}
+
+	if config.RestrictFrontingProviderIDsClientProbability != nil {
+		applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = *config.RestrictFrontingProviderIDsClientProbability
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 
@@ -1946,6 +1991,47 @@ func (config *Config) setDialParametersHash() {
 		binary.Write(hash, binary.LittleEndian, int64(*config.ConjureDecoyRegistrarMaxDelayMilliseconds))
 	}
 
+	if config.HoldOffTunnelMinDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffTunnelMinDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelMinDurationMilliseconds))
+	}
+
+	if config.HoldOffTunnelMaxDurationMilliseconds != nil {
+		hash.Write([]byte("HoldOffTunnelMaxDurationMilliseconds"))
+		binary.Write(hash, binary.LittleEndian, int64(*config.HoldOffTunnelMaxDurationMilliseconds))
+	}
+
+	if len(config.HoldOffTunnelProtocols) > 0 {
+		hash.Write([]byte("HoldOffTunnelProtocols"))
+		for _, protocol := range config.HoldOffTunnelProtocols {
+			hash.Write([]byte(protocol))
+		}
+	}
+
+	if len(config.HoldOffTunnelFrontingProviderIDs) > 0 {
+		hash.Write([]byte("HoldOffTunnelFrontingProviderIDs"))
+		for _, providerID := range config.HoldOffTunnelFrontingProviderIDs {
+			hash.Write([]byte(providerID))
+		}
+	}
+
+	if config.HoldOffTunnelProbability != nil {
+		hash.Write([]byte("HoldOffTunnelProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.HoldOffTunnelProbability)
+	}
+
+	if len(config.RestrictFrontingProviderIDs) > 0 {
+		hash.Write([]byte("RestrictFrontingProviderIDs"))
+		for _, providerID := range config.RestrictFrontingProviderIDs {
+			hash.Write([]byte(providerID))
+		}
+	}
+
+	if config.RestrictFrontingProviderIDsClientProbability != nil {
+		hash.Write([]byte("RestrictFrontingProviderIDsClientProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.RestrictFrontingProviderIDsClientProbability)
+	}
+
 	config.dialParametersHash = hash.Sum(nil)
 }
 

+ 37 - 0
psiphon/dialParameters.go

@@ -128,6 +128,8 @@ type DialParameters struct {
 
 	APIRequestPaddingSeed *prng.Seed
 
+	HoldOffTunnelDuration time.Duration
+
 	DialConnMetrics          common.MetricsSource `json:"-"`
 	ObfuscatedSSHConnMetrics common.MetricsSource `json:"-"`
 
@@ -184,6 +186,7 @@ func MakeDialParameters(
 	replayLivenessTest := p.Bool(parameters.ReplayLivenessTest)
 	replayUserAgent := p.Bool(parameters.ReplayUserAgent)
 	replayAPIRequestPadding := p.Bool(parameters.ReplayAPIRequestPadding)
+	replayHoldOffTunnel := p.Bool(parameters.ReplayHoldOffTunnel)
 
 	// Check for existing dial parameters for this server/network ID.
 
@@ -345,6 +348,20 @@ func MakeDialParameters(
 		dialParams.TunnelProtocol = selectedProtocol
 	}
 
+	// Skip this candidate when the clients tactics restrict usage of the
+	// fronting provider ID. See the corresponding server-side enforcement
+	// comments in server.TacticsListener.accept.
+	if protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
+		common.Contains(
+			p.Strings(parameters.RestrictFrontingProviderIDs),
+			dialParams.ServerEntry.FrontingProviderID) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictFrontingProviderIDsClientProbability) {
+			return nil, errors.Tracef(
+				"restricted fronting provider ID: %s", dialParams.ServerEntry.FrontingProviderID)
+		}
+	}
+
 	if config.UseUpstreamProxy() &&
 		!protocol.TunnelProtocolSupportsUpstreamProxy(dialParams.TunnelProtocol) {
 
@@ -628,6 +645,26 @@ func MakeDialParameters(
 		}
 	}
 
+	if !isReplay || !replayHoldOffTunnel {
+
+		if common.Contains(
+			p.TunnelProtocols(parameters.HoldOffTunnelProtocols), dialParams.TunnelProtocol) ||
+
+			(protocol.TunnelProtocolUsesFrontedMeek(dialParams.TunnelProtocol) &&
+				common.Contains(
+					p.Strings(parameters.HoldOffTunnelFrontingProviderIDs),
+					dialParams.FrontingProviderID)) {
+
+			if p.WeightedCoinFlip(parameters.HoldOffTunnelProbability) {
+
+				dialParams.HoldOffTunnelDuration = prng.Period(
+					p.Duration(parameters.HoldOffTunnelMinDuration),
+					p.Duration(parameters.HoldOffTunnelMaxDuration))
+			}
+		}
+
+	}
+
 	// Set dial address fields. This portion of configuration is
 	// deterministic, given the parameters established or replayed so far.
 

+ 53 - 2
psiphon/dialParameters_test.go

@@ -76,9 +76,17 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("error committing configuration file: %s", err)
 	}
 
+	holdOffTunnelProtocols := protocol.TunnelProtocols{protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH}
+	frontingProviderID := prng.HexString(8)
+
 	applyParameters := make(map[string]interface{})
 	applyParameters[parameters.TransformHostNameProbability] = 1.0
 	applyParameters[parameters.PickUserAgentProbability] = 1.0
+	applyParameters[parameters.HoldOffTunnelMinDuration] = "1ms"
+	applyParameters[parameters.HoldOffTunnelMaxDuration] = "10ms"
+	applyParameters[parameters.HoldOffTunnelProtocols] = holdOffTunnelProtocols
+	applyParameters[parameters.HoldOffTunnelFrontingProviderIDs] = []string{frontingProviderID}
+	applyParameters[parameters.HoldOffTunnelProbability] = 1.0
 	err = clientConfig.SetParameters("tag1", true, applyParameters)
 	if err != nil {
 		t.Fatalf("SetParameters failed: %s", err)
@@ -90,7 +98,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	defer CloseDataStore()
 
-	serverEntries := makeMockServerEntries(tunnelProtocol, 100)
+	serverEntries := makeMockServerEntries(tunnelProtocol, frontingProviderID, 100)
 
 	canReplay := func(serverEntry *protocol.ServerEntry, replayProtocol string) bool {
 		return replayProtocol == tunnelProtocol
@@ -204,6 +212,18 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("missing API request fields")
 	}
 
+	if common.Contains(holdOffTunnelProtocols, tunnelProtocol) ||
+		protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+		if dialParams.HoldOffTunnelDuration < 1*time.Millisecond ||
+			dialParams.HoldOffTunnelDuration > 10*time.Millisecond {
+			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
+		}
+	} else {
+		if dialParams.HoldOffTunnelDuration != 0 {
+			t.Fatalf("unexpected hold-off duration: %v", dialParams.HoldOffTunnelDuration)
+		}
+	}
+
 	dialConfig := dialParams.GetDialConfig()
 	if dialConfig.UpstreamProxyErrorCallback == nil {
 		t.Fatalf("missing upstreamProxyErrorCallback")
@@ -418,6 +438,33 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("unexpected replayed fields")
 	}
 
+	// Test: client-side restrict fronting provider ID
+
+	applyParameters[parameters.RestrictFrontingProviderIDs] = []string{frontingProviderID}
+	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 1.0
+	err = clientConfig.SetParameters("tag4", true, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
+	dialParams, err = MakeDialParameters(clientConfig, nil, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
+
+	if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
+		if err == nil {
+			t.Fatalf("unexpected MakeDialParameters success")
+		}
+	} else {
+		if err != nil {
+			t.Fatalf("MakeDialParameters failed: %s", err)
+		}
+	}
+
+	applyParameters[parameters.RestrictFrontingProviderIDsClientProbability] = 0.0
+	err = clientConfig.SetParameters("tag5", true, applyParameters)
+	if err != nil {
+		t.Fatalf("SetParameters failed: %s", err)
+	}
+
 	// Test: iterator shuffles
 
 	for i, serverEntry := range serverEntries {
@@ -509,7 +556,10 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 }
 
-func makeMockServerEntries(tunnelProtocol string, count int) []*protocol.ServerEntry {
+func makeMockServerEntries(
+	tunnelProtocol string,
+	frontingProviderID string,
+	count int) []*protocol.ServerEntry {
 
 	serverEntries := make([]*protocol.ServerEntry, count)
 
@@ -524,6 +574,7 @@ func makeMockServerEntries(tunnelProtocol string, count int) []*protocol.ServerE
 			MeekServerPort:             6,
 			MeekFrontingHosts:          []string{"www1.example.org", "www2.example.org", "www3.example.org"},
 			MeekFrontingAddressesRegex: "[a-z0-9]{1,64}.example.org",
+			FrontingProviderID:         frontingProviderID,
 			LocalSource:                protocol.SERVER_ENTRY_SOURCE_EMBEDDED,
 			LocalTimestamp:             common.TruncateTimestampToHour(common.GetCurrentTimestamp()),
 		}

+ 24 - 3
psiphon/notice.go

@@ -667,6 +667,15 @@ func NoticeClientRegion(region string) {
 		"region", region)
 }
 
+// NoticeClientAddress is the client's public network address, the IP address
+// and port, as seen by the server and reported to the client in the
+// handshake.
+func NoticeClientAddress(address string) {
+	singletonNoticeLogger.outputNotice(
+		"ClientAddress", 0,
+		"address", address)
+}
+
 // NoticeTunnels is how many active tunnels are available. The client should use this to
 // determine connecting/unexpected disconnect state transitions. When count is 0, the core is
 // disconnected; when count > 1, the core is connected.
@@ -925,9 +934,21 @@ func NoticeServerAlert(alert protocol.AlertRequest) {
 
 // NoticeBursts reports tunnel data transfer burst metrics.
 func NoticeBursts(diagnosticID string, burstMetrics common.LogFields) {
-	singletonNoticeLogger.outputNotice(
-		"Bursts", noticeIsDiagnostic,
-		append([]interface{}{"diagnosticID", diagnosticID}, listCommonFields(burstMetrics)...)...)
+	if GetEmitNetworkParameters() {
+		singletonNoticeLogger.outputNotice(
+			"Bursts", noticeIsDiagnostic,
+			append([]interface{}{"diagnosticID", diagnosticID}, listCommonFields(burstMetrics)...)...)
+	}
+}
+
+// NoticeHoldOffTunnel reports tunnel hold-offs.
+func NoticeHoldOffTunnel(diagnosticID string, duration time.Duration) {
+	if GetEmitNetworkParameters() {
+		singletonNoticeLogger.outputNotice(
+			"HoldOffTunnel", noticeIsDiagnostic,
+			"diagnosticID", diagnosticID,
+			"duration", duration.String())
+	}
 }
 
 type repetitiveNoticeState struct {

+ 50 - 7
psiphon/server/api.go

@@ -20,6 +20,8 @@
 package server
 
 import (
+	"crypto/hmac"
+	"crypto/sha256"
 	"crypto/subtle"
 	"encoding/base64"
 	"encoding/json"
@@ -60,6 +62,7 @@ const (
 //
 func sshAPIRequestHandler(
 	support *SupportServices,
+	clientAddr string,
 	geoIPData GeoIPData,
 	authorizedAccessTypes []string,
 	name string,
@@ -87,6 +90,7 @@ func sshAPIRequestHandler(
 	return dispatchAPIRequestHandler(
 		support,
 		protocol.PSIPHON_SSH_API_PROTOCOL,
+		clientAddr,
 		geoIPData,
 		authorizedAccessTypes,
 		name,
@@ -98,6 +102,7 @@ func sshAPIRequestHandler(
 func dispatchAPIRequestHandler(
 	support *SupportServices,
 	apiProtocol string,
+	clientAddr string,
 	geoIPData GeoIPData,
 	authorizedAccessTypes []string,
 	name string,
@@ -146,14 +151,22 @@ func dispatchAPIRequestHandler(
 	}
 
 	switch name {
+
 	case protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(support, apiProtocol, geoIPData, params)
+		return handshakeAPIRequestHandler(
+			support, apiProtocol, clientAddr, geoIPData, params)
+
 	case protocol.PSIPHON_API_CONNECTED_REQUEST_NAME:
-		return connectedAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
+		return connectedAPIRequestHandler(
+			support, clientAddr, geoIPData, authorizedAccessTypes, params)
+
 	case protocol.PSIPHON_API_STATUS_REQUEST_NAME:
-		return statusAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
+		return statusAPIRequestHandler(
+			support, clientAddr, geoIPData, authorizedAccessTypes, params)
+
 	case protocol.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME:
-		return clientVerificationAPIRequestHandler(support, geoIPData, authorizedAccessTypes, params)
+		return clientVerificationAPIRequestHandler(
+			support, clientAddr, geoIPData, authorizedAccessTypes, params)
 	}
 
 	return nil, errors.Tracef("invalid request name: %s", name)
@@ -177,6 +190,7 @@ var handshakeRequestParams = append(
 func handshakeAPIRequestHandler(
 	support *SupportServices,
 	apiProtocol string,
+	clientAddr string,
 	geoIPData GeoIPData,
 	params common.APIParameters) ([]byte, error) {
 
@@ -289,10 +303,20 @@ func handshakeAPIRequestHandler(
 
 	pad_response, _ := getPaddingSizeRequestParam(params, "pad_response")
 
-	if !geoIPData.HasDiscoveryValue {
-		return nil, errors.TraceNew("unexpected missing discovery value")
+	// Discover new servers
+
+	host, _, err := net.SplitHostPort(clientAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	clientIP := net.ParseIP(host)
+	if clientIP == nil {
+		return nil, errors.TraceNew("missing client IP")
 	}
-	encodedServerList := db.DiscoverServers(geoIPData.DiscoveryValue)
+
+	encodedServerList := db.DiscoverServers(
+		calculateDiscoveryValue(support.Config.DiscoveryValueHMACKey, clientIP))
 
 	// When the client indicates that it used an unsigned server entry for this
 	// connection, return a signed copy of the server entry for the client to
@@ -324,6 +348,7 @@ func handshakeAPIRequestHandler(
 		HttpsRequestRegexes:      httpsRequestRegexes,
 		EncodedServerList:        encodedServerList,
 		ClientRegion:             geoIPData.Country,
+		ClientAddress:            clientAddr,
 		ServerTimestamp:          common.GetCurrentTimestamp(),
 		ActiveAuthorizationIDs:   handshakeStateInfo.activeAuthorizationIDs,
 		TacticsPayload:           marshaledTacticsPayload,
@@ -340,6 +365,21 @@ func handshakeAPIRequestHandler(
 	return responsePayload, nil
 }
 
+// calculateDiscoveryValue derives a value from the client IP address to be
+// used as input in the server discovery algorithm.
+// See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
+// for full details.
+func calculateDiscoveryValue(discoveryValueHMACKey string, ipAddress net.IP) int {
+	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
+	//     # Mix bits from all octets of the client IP address to determine the
+	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
+	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
+	// TODO: use 3-octet algorithm?
+	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
+	hash.Write([]byte(ipAddress.String()))
+	return int(hash.Sum(nil)[0])
+}
+
 // uniqueUserParams are the connected request parameters which are logged for
 // unique_user events.
 var uniqueUserParams = append(
@@ -370,6 +410,7 @@ var updateOnConnectedParamNames = append(
 // connected_timestamp is truncated as a privacy measure.
 func connectedAPIRequestHandler(
 	support *SupportServices,
+	clientAddr string,
 	geoIPData GeoIPData,
 	authorizedAccessTypes []string,
 	params common.APIParameters) ([]byte, error) {
@@ -497,6 +538,7 @@ var failedTunnelStatParams = append(
 // string). Stats processor must handle this input with care.
 func statusAPIRequestHandler(
 	support *SupportServices,
+	clientAddr string,
 	geoIPData GeoIPData,
 	authorizedAccessTypes []string,
 	params common.APIParameters) ([]byte, error) {
@@ -720,6 +762,7 @@ func statusAPIRequestHandler(
 // for older Android clients that still send verification requests
 func clientVerificationAPIRequestHandler(
 	support *SupportServices,
+	clientAddr string,
 	geoIPData GeoIPData,
 	authorizedAccessTypes []string,
 	params common.APIParameters) ([]byte, error) {

+ 48 - 3
psiphon/server/config.go

@@ -61,7 +61,7 @@ const (
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH                      = 32
 	PERIODIC_GARBAGE_COLLECTION                         = 120 * time.Second
 	STOP_ESTABLISH_TUNNELS_ESTABLISHED_CLIENT_THRESHOLD = 20
-	DEFAULT_LOG_FILE_REOPEN_RETRIES                     = 10
+	DEFAULT_LOG_FILE_REOPEN_RETRIES                     = 25
 )
 
 // Config specifies the configuration and behavior of a Psiphon
@@ -154,8 +154,8 @@ type Config struct {
 	// protocols include:
 	// "SSH", "OSSH", "UNFRONTED-MEEK-OSSH", "UNFRONTED-MEEK-HTTPS-OSSH",
 	// "UNFRONTED-MEEK-SESSION-TICKET-OSSH", "FRONTED-MEEK-OSSH",
-	// ""FRONTED-MEEK-QUIC-OSSH" FRONTED-MEEK-HTTP-OSSH", "QUIC-OSSH",
-	// ""MARIONETTE-OSSH", and TAPDANCE-OSSH".
+	// "FRONTED-MEEK-QUIC-OSSH", "FRONTED-MEEK-HTTP-OSSH", "QUIC-OSSH",
+	// "MARIONETTE-OSSH", "TAPDANCE-OSSH", abd "CONJURE-OSSH".
 	//
 	// In the case of "MARIONETTE-OSSH" the port value is ignored and must be
 	// set to 0. The port value specified in the Marionette format is used.
@@ -335,6 +335,10 @@ type Config struct {
 	// PacketTunnelEgressInterface specifies tun.ServerConfig.EgressInterface.
 	PacketTunnelEgressInterface string
 
+	// PacketTunnelEnableDNSFlowTracking sets
+	// tun.ServerConfig.EnableDNSFlowTracking.
+	PacketTunnelEnableDNSFlowTracking bool
+
 	// PacketTunnelDownstreamPacketQueueSize specifies
 	// tun.ServerConfig.DownStreamPacketQueueSize.
 	PacketTunnelDownstreamPacketQueueSize int
@@ -420,6 +424,8 @@ type Config struct {
 	periodicGarbageCollection                      time.Duration
 	stopEstablishTunnelsEstablishedClientThreshold int
 	dumpProfilesOnStopEstablishTunnelsDone         int32
+	frontingProviderID                             string
+	runningProtocols                               []string
 }
 
 // GetLogFileReopenConfig gets the reopen retries, and create/mode inputs for
@@ -486,6 +492,18 @@ func (config *Config) GetOwnEncodedServerEntry(serverEntryTag string) (string, b
 	return serverEntry, ok
 }
 
+// GetFrontingProviderID returns the fronting provider ID associated with the
+// server's fronted protocol(s).
+func (config *Config) GetFrontingProviderID() string {
+	return config.frontingProviderID
+}
+
+// GetRunningProtocols returns the list of protcols this server is running.
+// The caller must not mutate the return value.
+func (config *Config) GetRunningProtocols() []string {
+	return config.runningProtocols
+}
+
 // LoadConfig loads and validates a JSON encoded server config.
 func LoadConfig(configJSON []byte) (*Config, error) {
 
@@ -625,6 +643,33 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 			"AccessControlVerificationKeyRing is invalid: %s", err)
 	}
 
+	// Limitation: the following is a shortcut which extracts the server's
+	// fronting provider ID from the server's OwnEncodedServerEntries. This logic
+	// assumes a server has only one fronting provider. In principle, it's
+	// possible for server with multiple server entries to run multiple fronted
+	// protocols, each with a different fronting provider ID.
+	//
+	// TODO: add an explicit parameter mapping tunnel protocol ports to fronting
+	// provider IDs.
+
+	for _, encodedServerEntry := range config.OwnEncodedServerEntries {
+		serverEntry, err := protocol.DecodeServerEntry(encodedServerEntry, "", "")
+		if err != nil {
+			return nil, errors.Tracef(
+				"protocol.DecodeServerEntry failed: %s", err)
+		}
+		if config.frontingProviderID == "" {
+			config.frontingProviderID = serverEntry.FrontingProviderID
+		} else if config.frontingProviderID != serverEntry.FrontingProviderID {
+			return nil, errors.Tracef("unsupported multiple FrontingProviderID values")
+		}
+	}
+
+	config.runningProtocols = []string{}
+	for tunnelProtocol := range config.TunnelProtocolPorts {
+		config.runningProtocols = append(config.runningProtocols, tunnelProtocol)
+	}
+
 	return &config, nil
 }
 

+ 18 - 5
psiphon/server/dns.go

@@ -172,21 +172,28 @@ func (dns *DNSResolver) reloadWhenStale() {
 	}
 }
 
+// GetAll returns a list of all DNS resolver addresses. Cached values are
+// updated if they're stale. If reloading fails, the previous values are
+// used.
+func (dns *DNSResolver) GetAll() []net.IP {
+	return dns.getAll(true, true)
+}
+
 // GetAllIPv4 returns a list of all IPv4 DNS resolver addresses.
 // Cached values are updated if they're stale. If reloading fails,
 // the previous values are used.
 func (dns *DNSResolver) GetAllIPv4() []net.IP {
-	return dns.getAll(false)
+	return dns.getAll(true, false)
 }
 
 // GetAllIPv6 returns a list of all IPv6 DNS resolver addresses.
 // Cached values are updated if they're stale. If reloading fails,
 // the previous values are used.
 func (dns *DNSResolver) GetAllIPv6() []net.IP {
-	return dns.getAll(true)
+	return dns.getAll(false, true)
 }
 
-func (dns *DNSResolver) getAll(wantIPv6 bool) []net.IP {
+func (dns *DNSResolver) getAll(wantIPv4, wantIPv6 bool) []net.IP {
 
 	dns.reloadWhenStale()
 
@@ -195,8 +202,14 @@ func (dns *DNSResolver) getAll(wantIPv6 bool) []net.IP {
 
 	resolvers := make([]net.IP, 0)
 	for _, resolver := range dns.resolvers {
-		if (resolver.To4() == nil) == wantIPv6 {
-			resolvers = append(resolvers, resolver)
+		if resolver.To4() != nil {
+			if wantIPv4 {
+				resolvers = append(resolvers, resolver)
+			}
+		} else {
+			if wantIPv6 {
+				resolvers = append(resolvers, resolver)
+			}
 		}
 	}
 	return resolvers

+ 14 - 58
psiphon/server/geoip.go

@@ -20,8 +20,6 @@
 package server
 
 import (
-	"crypto/hmac"
-	"crypto/sha256"
 	"fmt"
 	"io"
 	"net"
@@ -45,17 +43,13 @@ const (
 // GeoIPData is GeoIP data for a client session. Individual client
 // IP addresses are neither logged nor explicitly referenced during a session.
 // The GeoIP country, city, and ISP corresponding to a client IP address are
-// resolved and then logged along with usage stats. The DiscoveryValue is
-// a special value derived from the client IP that's used to compartmentalize
-// discoverable servers (see calculateDiscoveryValue for details).
+// resolved and then logged along with usage stats.
 type GeoIPData struct {
-	Country           string
-	City              string
-	ISP               string
-	ASN               string
-	ASO               string
-	HasDiscoveryValue bool
-	DiscoveryValue    int
+	Country string
+	City    string
+	ISP     string
+	ASN     string
+	ASO     string
 }
 
 // NewGeoIPData returns a GeoIPData initialized with the expected
@@ -93,9 +87,8 @@ func (g GeoIPData) SetLogFieldsWithPrefix(prefix string, logFields LogFields) {
 // supports hot reloading of MaxMind data while the server is
 // running.
 type GeoIPService struct {
-	databases             []*geoIPDatabase
-	sessionCache          *cache.Cache
-	discoveryValueHMACKey string
+	databases    []*geoIPDatabase
+	sessionCache *cache.Cache
 }
 
 type geoIPDatabase struct {
@@ -107,14 +100,11 @@ type geoIPDatabase struct {
 }
 
 // NewGeoIPService initializes a new GeoIPService.
-func NewGeoIPService(
-	databaseFilenames []string,
-	discoveryValueHMACKey string) (*GeoIPService, error) {
+func NewGeoIPService(databaseFilenames []string) (*GeoIPService, error) {
 
 	geoIP := &GeoIPService{
-		databases:             make([]*geoIPDatabase, len(databaseFilenames)),
-		sessionCache:          cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute),
-		discoveryValueHMACKey: discoveryValueHMACKey,
+		databases:    make([]*geoIPDatabase, len(databaseFilenames)),
+		sessionCache: cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute),
 	}
 
 	for i, filename := range databaseFilenames {
@@ -203,21 +193,12 @@ func (geoIP *GeoIPService) Reloaders() []common.Reloader {
 }
 
 // Lookup determines a GeoIPData for a given string client IP address.
-//
-// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
-// GeoIPData.HasDiscoveryValue is true.
-func (geoIP *GeoIPService) Lookup(
-	strIP string, addDiscoveryValue bool) GeoIPData {
-
-	return geoIP.LookupIP(net.ParseIP(strIP), addDiscoveryValue)
+func (geoIP *GeoIPService) Lookup(strIP string) GeoIPData {
+	return geoIP.LookupIP(net.ParseIP(strIP))
 }
 
 // LookupIP determines a GeoIPData for a given client IP address.
-//
-// When addDiscoveryValue is true, GeoIPData.DiscoveryValue is calculated and
-// GeoIPData.HasDiscoveryValue is true.
-func (geoIP *GeoIPService) LookupIP(
-	IP net.IP, addDiscoveryValue bool) GeoIPData {
+func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
 
 	result := NewGeoIPData()
 
@@ -274,14 +255,6 @@ func (geoIP *GeoIPService) LookupIP(
 		result.ASO = geoIPFields.ASO
 	}
 
-	// Populate DiscoveryValue fields (even when there's no GeoIP database).
-
-	if addDiscoveryValue {
-		result.HasDiscoveryValue = true
-		result.DiscoveryValue = calculateDiscoveryValue(
-			geoIP.discoveryValueHMACKey, IP)
-	}
-
 	return result
 }
 
@@ -325,20 +298,3 @@ func (geoIP *GeoIPService) InSessionCache(sessionID string) bool {
 	_, found := geoIP.sessionCache.Get(sessionID)
 	return found
 }
-
-// calculateDiscoveryValue derives a value from the client IP address to be
-// used as input in the server discovery algorithm. Since we do not explicitly
-// store the client IP address, we must derive the value here and store it for
-// later use by the discovery algorithm.
-// See https://github.com/Psiphon-Inc/psiphon-automation/tree/master/Automation/psi_ops_discovery.py
-// for full details.
-func calculateDiscoveryValue(discoveryValueHMACKey string, ipAddress net.IP) int {
-	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
-	//     # Mix bits from all octets of the client IP address to determine the
-	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
-	//     return ord(hmac.new(HMAC_KEY, ip_address, hashlib.sha256).digest()[0])
-	// TODO: use 3-octet algorithm?
-	hash := hmac.New(sha256.New, []byte(discoveryValueHMACKey))
-	hash.Write([]byte(ipAddress.String()))
-	return int(hash.Sum(nil)[0])
-}

+ 41 - 0
psiphon/server/listener.go

@@ -25,6 +25,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 )
@@ -57,6 +58,18 @@ func NewTacticsListener(
 // Accept calls the underlying listener's Accept, and then applies server-side
 // tactics to new connections.
 func (listener *TacticsListener) Accept() (net.Conn, error) {
+	for {
+		// accept may discard a successfully accepted conn. In that case, accept
+		// returns nil, nil; call accept until either the conn or err is not nil.
+		conn, err := listener.accept()
+		if conn != nil || err != nil {
+			// Don't modify error from net.Listener
+			return conn, err
+		}
+	}
+}
+
+func (listener *TacticsListener) accept() (net.Conn, error) {
 
 	conn, err := listener.Listener.Accept()
 	if err != nil {
@@ -77,6 +90,34 @@ func (listener *TacticsListener) Accept() (net.Conn, error) {
 		return conn, nil
 	}
 
+	// Disconnect immediately if the clients tactics restricts usage of the
+	// fronting provider ID. The probability may be used to influence usage of a
+	// given fronting provider; but when only that provider works for a given
+	// client, and the probability is less than 1.0, the client can retry until
+	// it gets a successful coin flip.
+	//
+	// Clients will also skip candidates with restricted fronting provider IDs.
+	// The client-side probability, RestrictFrontingProviderIDsClientProbability,
+	// is applied independently of the server-side coin flip here.
+	//
+	//
+	// At this stage, GeoIP tactics filters are active, but handshake API
+	// parameters are not.
+	//
+	// See the comment in server.LoadConfig regarding fronting provider ID
+	// limitations.
+
+	if protocol.TunnelProtocolUsesFrontedMeek(listener.tunnelProtocol) &&
+		common.Contains(
+			p.Strings(parameters.RestrictFrontingProviderIDs),
+			listener.support.Config.GetFrontingProviderID()) {
+		if p.WeightedCoinFlip(
+			parameters.RestrictFrontingProviderIDsServerProbability) {
+			conn.Close()
+			return nil, nil
+		}
+	}
+
 	// Server-side fragmentation may be synchronized with client-side in two ways.
 	//
 	// In the OSSH case, replay is always activated and it is seeded using the

+ 50 - 7
psiphon/server/listener_test.go

@@ -28,13 +28,16 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 )
 
 func TestListener(t *testing.T) {
 
-	tunnelProtocol := protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH
+	tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
+
+	frontingProviderID := prng.HexString(8)
 
 	tacticsConfigJSONFormat := `
     {
@@ -54,14 +57,25 @@ func TestListener(t *testing.T) {
           },
           "Tactics" : {
             "Parameters" : {
-              "LimitTunnelProtocols" : ["%s"],
               "FragmentorDownstreamLimitProtocols" : ["%s"],
               "FragmentorDownstreamProbability" : 1.0,
               "FragmentorDownstreamMinTotalBytes" : 1,
               "FragmentorDownstreamMaxTotalBytes" : 1,
               "FragmentorDownstreamMinWriteBytes" : 1,
-              "FragmentorDownstreamMaxWriteBytes" : 1,
-              "FragmentorDownstreamLimitProtocols" : ["OSSH"]
+              "FragmentorDownstreamMaxWriteBytes" : 1
+            }
+          }
+        },
+        {
+          "Filter" : {
+            "Regions": ["R3"],
+            "ISPs": ["I3"],
+            "Cities": ["C3"]
+          },
+          "Tactics" : {
+            "Parameters" : {
+              "RestrictFrontingProviderIDs" : ["%s"],
+              "RestrictFrontingProviderIDsServerProbability" : 1.0
             }
           }
         }
@@ -78,7 +92,7 @@ func TestListener(t *testing.T) {
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
-		tunnelProtocol, tunnelProtocol)
+		tunnelProtocol, frontingProviderID)
 
 	tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
 
@@ -108,31 +122,54 @@ func TestListener(t *testing.T) {
 	listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData {
 		return GeoIPData{Country: "R1", ISP: "I1", City: "C2"}
 	}
+	listenerRestrictedFrontingProviderIDGeoIP := func(string) GeoIPData {
+		return GeoIPData{Country: "R3", ISP: "I3", City: "C3"}
+	}
+	listenerUnrestrictedFrontingProviderIDWrongRegion := func(string) GeoIPData {
+		return GeoIPData{Country: "R2", ISP: "I3", City: "C3"}
+	}
 
 	listenerTestCases := []struct {
 		description      string
 		geoIPLookup      func(string) GeoIPData
 		expectFragmentor bool
+		expectConnection bool
 	}{
 		{
 			"fragmented",
 			listenerFragmentedGeoIP,
 			true,
+			true,
 		},
 		{
 			"unfragmented-region",
 			listenerUnfragmentedGeoIPWrongRegion,
 			false,
+			true,
 		},
 		{
 			"unfragmented-ISP",
 			listenerUnfragmentedGeoIPWrongISP,
 			false,
+			true,
 		},
 		{
 			"unfragmented-city",
 			listenerUnfragmentedGeoIPWrongCity,
 			false,
+			true,
+		},
+		{
+			"restricted",
+			listenerRestrictedFrontingProviderIDGeoIP,
+			false,
+			false,
+		},
+		{
+			"unrestricted-region",
+			listenerUnrestrictedFrontingProviderIDWrongRegion,
+			false,
+			true,
 		},
 	}
 
@@ -145,6 +182,7 @@ func TestListener(t *testing.T) {
 			}
 
 			support := &SupportServices{
+				Config:        &Config{frontingProviderID: frontingProviderID},
 				TacticsServer: tacticsServer,
 			}
 			support.ReplayCache = NewReplayCache(support)
@@ -172,11 +210,14 @@ func TestListener(t *testing.T) {
 				}
 			}()
 
-			timer := time.NewTimer(3 * time.Second)
+			timer := time.NewTimer(1 * time.Second)
 			defer timer.Stop()
 
 			select {
 			case serverConn := <-result:
+				if !testCase.expectConnection {
+					t.Fatalf("unexpected accepted connection")
+				}
 				_, isFragmentor := serverConn.(*fragmentor.Conn)
 				if testCase.expectFragmentor && !isFragmentor {
 					t.Fatalf("unexpected non-fragmentor: %T", serverConn)
@@ -185,7 +226,9 @@ func TestListener(t *testing.T) {
 				}
 				serverConn.Close()
 			case <-timer.C:
-				t.Fatalf("timeout before expected accepted connection")
+				if testCase.expectConnection {
+					t.Fatalf("timeout before expected accepted connection")
+				}
 			}
 
 			clientConn.Close()

+ 45 - 11
psiphon/server/meek.go

@@ -359,7 +359,7 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		// Endpoint mode. Currently, this means it's handled by the tactics
 		// request handler.
 
-		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 		handled := server.support.TacticsServer.HandleEndPoint(
 			endPoint, common.GeoIPData(geoIPData), responseWriter, request)
 		if !handled {
@@ -621,7 +621,7 @@ func (server *MeekServer) getSessionOrEndpoint(
 				proxyClientIP := strings.Split(value, ",")[0]
 				if net.ParseIP(proxyClientIP) != nil &&
 					server.support.GeoIPService.Lookup(
-						proxyClientIP, false).Country != GEOIP_UNKNOWN_VALUE {
+						proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
 
 					clientIP = proxyClientIP
 					break
@@ -630,10 +630,6 @@ func (server *MeekServer) getSessionOrEndpoint(
 		}
 	}
 
-	if server.rateLimit(clientIP) {
-		return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
-	}
-
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// cookie, extract the payload, and create a new session.
 
@@ -651,6 +647,32 @@ func (server *MeekServer) getSessionOrEndpoint(
 		return "", nil, nil, "", "", errors.Trace(err)
 	}
 
+	tunnelProtocol := server.listenerTunnelProtocol
+
+	if clientSessionData.ClientTunnelProtocol != "" {
+
+		if !protocol.IsValidClientTunnelProtocol(
+			clientSessionData.ClientTunnelProtocol,
+			server.listenerTunnelProtocol,
+			server.support.Config.GetRunningProtocols()) {
+
+			return "", nil, nil, "", "", errors.Tracef(
+				"invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol)
+		}
+
+		tunnelProtocol = clientSessionData.ClientTunnelProtocol
+	}
+
+	// Any rate limit is enforced after the meek cookie is validated, so a prober
+	// without the obfuscation secret will be unable to fingerprint the server
+	// based on response time combined with the rate limit configuration. The
+	// rate limit is primarily intended to limit memory resource consumption and
+	// not the overhead incurred by cookie validation.
+
+	if server.rateLimit(clientIP, tunnelProtocol) {
+		return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
+	}
+
 	// Handle endpoints before enforcing CheckEstablishTunnels.
 	// Currently, endpoints are tactics requests, and we allow these to be
 	// handled by servers which would otherwise reject new tunnels.
@@ -729,19 +751,31 @@ func (server *MeekServer) getSessionOrEndpoint(
 	return sessionID, session, underlyingConn, "", "", nil
 }
 
-func (server *MeekServer) rateLimit(clientIP string) bool {
+func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool {
 
-	historySize, thresholdSeconds, regions, ISPs, cities, GCTriggerCount, _ :=
+	historySize,
+		thresholdSeconds,
+		tunnelProtocols,
+		regions,
+		ISPs,
+		cities,
+		GCTriggerCount, _ :=
 		server.support.TrafficRulesSet.GetMeekRateLimiterConfig()
 
 	if historySize == 0 {
 		return false
 	}
 
+	if len(tunnelProtocols) > 0 {
+		if !common.Contains(tunnelProtocols, tunnelProtocol) {
+			return false
+		}
+	}
+
 	if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
 
 		// TODO: avoid redundant GeoIP lookups?
-		geoIPData := server.support.GeoIPService.Lookup(clientIP, false)
+		geoIPData := server.support.GeoIPService.Lookup(clientIP)
 
 		if len(regions) > 0 {
 			if !common.Contains(regions, geoIPData.Country) {
@@ -811,7 +845,7 @@ func (server *MeekServer) rateLimit(clientIP string) bool {
 
 func (server *MeekServer) rateLimitWorker() {
 
-	_, _, _, _, _, _, reapFrequencySeconds :=
+	_, _, _, _, _, _, _, reapFrequencySeconds :=
 		server.support.TrafficRulesSet.GetMeekRateLimiterConfig()
 
 	timer := time.NewTimer(time.Duration(reapFrequencySeconds) * time.Second)
@@ -821,7 +855,7 @@ func (server *MeekServer) rateLimitWorker() {
 		select {
 		case <-timer.C:
 
-			_, thresholdSeconds, _, _, _, _, reapFrequencySeconds :=
+			_, thresholdSeconds, _, _, _, _, _, reapFrequencySeconds :=
 				server.support.TrafficRulesSet.GetMeekRateLimiterConfig()
 
 			server.rateLimitLock.Lock()

+ 29 - 12
psiphon/server/meek_test.go

@@ -37,6 +37,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"golang.org/x/crypto/nacl/box"
 )
 
@@ -400,9 +401,18 @@ func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) (
 }
 
 func TestMeekRateLimiter(t *testing.T) {
+	runTestMeekRateLimiter(t, true)
+	runTestMeekRateLimiter(t, false)
+}
+
+func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
+
+	attempts := 10
 
 	allowedConnections := 5
-	testDurationSeconds := 10
+	if !rateLimit {
+		allowedConnections = 10
+	}
 
 	// Run meek server
 
@@ -414,14 +424,23 @@ func TestMeekRateLimiter(t *testing.T) {
 	meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
 	meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 
+	tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK
+
+	meekRateLimiterTunnelProtocols := []string{tunnelProtocol}
+	if !rateLimit {
+		meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS}
+	}
+
 	mockSupport := &SupportServices{
 		Config: &Config{
 			MeekObfuscatedKey:              meekObfuscatedKey,
 			MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
+			TunnelProtocolPorts:            map[string]int{tunnelProtocol: 0},
 		},
 		TrafficRulesSet: &TrafficRulesSet{
 			MeekRateLimiterHistorySize:                   allowedConnections,
-			MeekRateLimiterThresholdSeconds:              testDurationSeconds,
+			MeekRateLimiterThresholdSeconds:              attempts,
+			MeekRateLimiterTunnelProtocols:               meekRateLimiterTunnelProtocols,
 			MeekRateLimiterGarbageCollectionTriggerCount: 1,
 			MeekRateLimiterReapHistoryFrequencySeconds:   1,
 		},
@@ -444,7 +463,7 @@ func TestMeekRateLimiter(t *testing.T) {
 	server, err := NewMeekServer(
 		mockSupport,
 		listener,
-		"",
+		tunnelProtocol,
 		0,
 		useTLS,
 		isFronted,
@@ -486,15 +505,13 @@ func TestMeekRateLimiter(t *testing.T) {
 	}()
 
 	// Run meek clients:
-	// For 10 seconds, connect once per second vs. rate limit of 5-per-10 seconds,
+	// For 10 attempts, connect once per second vs. rate limit of 5-per-10 seconds,
 	// so about half of the connections should be rejected by the rate limiter.
 
-	stopTime := time.Now().Add(time.Duration(testDurationSeconds) * time.Second)
-
 	totalConnections := 0
 	totalFailures := 0
 
-	for {
+	for i := 0; i < attempts; i++ {
 
 		dialConfig := &psiphon.DialConfig{}
 
@@ -541,14 +558,14 @@ func TestMeekRateLimiter(t *testing.T) {
 			totalConnections += 1
 		}
 
-		if !time.Now().Before(stopTime) {
-			break
+		if i < attempts-1 {
+			time.Sleep(1 * time.Second)
 		}
-
-		time.Sleep(1 * time.Second)
 	}
 
-	if totalConnections != allowedConnections || totalFailures == 0 {
+	if totalConnections != allowedConnections ||
+		totalFailures != attempts-totalConnections {
+
 		t.Fatalf(
 			"Unexpected results: %d connections, %d failures",
 			totalConnections, totalFailures)

+ 1 - 1
psiphon/server/packetman.go

@@ -141,7 +141,7 @@ func selectPacketManipulationSpec(
 			"packet manipulation protocol port not found: %d", protocolPort)
 	}
 
-	geoIPData := support.GeoIPService.LookupIP(clientIP, false)
+	geoIPData := support.GeoIPService.LookupIP(clientIP)
 
 	specName, doReplay := support.ReplayCache.GetReplayPacketManipulation(
 		targetTunnelProtocol, geoIPData)

+ 8 - 4
psiphon/server/services.go

@@ -98,6 +98,7 @@ func RunServices(configJSON []byte) (retErr error) {
 			SudoNetworkConfigCommands:   config.PacketTunnelSudoNetworkConfigCommands,
 			GetDNSResolverIPv4Addresses: support.DNSResolver.GetAllIPv4,
 			GetDNSResolverIPv6Addresses: support.DNSResolver.GetAllIPv6,
+			EnableDNSFlowTracking:       config.PacketTunnelEnableDNSFlowTracking,
 			EgressInterface:             config.PacketTunnelEgressInterface,
 			DownstreamPacketQueueSize:   config.PacketTunnelDownstreamPacketQueueSize,
 			SessionIdleExpirySeconds:    config.PacketTunnelSessionIdleExpirySeconds,
@@ -376,9 +377,13 @@ func logServerLoad(support *SupportServices) {
 
 	serverLoad.Add(support.ServerTacticsParametersCache.GetMetrics())
 
-	protocolStats, regionStats :=
+	upstreamStats, protocolStats, regionStats :=
 		support.TunnelServer.GetLoadStats()
 
+	for name, value := range upstreamStats {
+		serverLoad[name] = value
+	}
+
 	for protocol, stats := range protocolStats {
 		serverLoad[protocol] = stats
 	}
@@ -415,7 +420,7 @@ func logIrregularTunnel(
 	logFields["event_name"] = "irregular_tunnel"
 	logFields["listener_protocol"] = listenerTunnelProtocol
 	logFields["listener_port_number"] = listenerPort
-	support.GeoIPService.Lookup(clientIP, false).SetLogFields(logFields)
+	support.GeoIPService.Lookup(clientIP).SetLogFields(logFields)
 	logFields["tunnel_error"] = tunnelError.Error()
 	log.LogRawFieldsWithTimestamp(logFields)
 }
@@ -459,8 +464,7 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 		return nil, errors.Trace(err)
 	}
 
-	geoIPService, err := NewGeoIPService(
-		config.GeoIPDatabaseFilenames, config.DiscoveryValueHMACKey)
+	geoIPService, err := NewGeoIPService(config.GeoIPDatabaseFilenames)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}

+ 50 - 5
psiphon/server/trafficRules.go

@@ -70,8 +70,16 @@ type TrafficRulesSet struct {
 	// any client endpoint request or any request to create a new session, but
 	// not any meek request for an existing session, if the
 	// MeekRateLimiterHistorySize requests occur in
-	// MeekRateLimiterThresholdSeconds. The scope of rate limiting may be
-	// limited using LimitMeekRateLimiterRegions/ISPs/Cities.
+	// MeekRateLimiterThresholdSeconds.
+	//
+	// A use case for the the meek rate limiter is to mitigate dangling resource
+	// usage that results from meek connections that are partially established
+	// and then interrupted (e.g, drop packets after allowing up to the initial
+	// HTTP request and header lines). In the case of CDN fronted meek, the CDN
+	// itself may hold open the interrupted connection.
+	//
+	// The scope of rate limiting may be
+	// limited using LimitMeekRateLimiterTunnelProtocols/Regions/ISPs/Cities.
 	//
 	// Hot reloading a new history size will result in existing history being
 	// truncated.
@@ -81,6 +89,11 @@ type TrafficRulesSet struct {
 	// specification and must be set when MeekRateLimiterHistorySize is set.
 	MeekRateLimiterThresholdSeconds int
 
+	// MeekRateLimiterTunnelProtocols, if set, limits application of the meek
+	// late-stage rate limiter to the specified meek protocols. When omitted or
+	// empty, meek rate limiting is applied to all meek protocols.
+	MeekRateLimiterTunnelProtocols []string
+
 	// MeekRateLimiterRegions, if set, limits application of the meek
 	// late-stage rate limiter to clients in the specified list of GeoIP
 	// countries. When omitted or empty, meek rate limiting, if configured,
@@ -266,6 +279,12 @@ type RateLimits struct {
 	WriteBytesPerSecond   *int64
 	CloseAfterExhausted   *bool
 
+	// EstablishmentRead/WriteBytesPerSecond are used in place of
+	// Read/WriteBytesPerSecond for tunnels in the establishment phase, from the
+	// initial network connection up to the completion of the API handshake.
+	EstablishmentReadBytesPerSecond  *int64
+	EstablishmentWriteBytesPerSecond *int64
+
 	// UnthrottleFirstTunnelOnly specifies whether any
 	// ReadUnthrottledBytes/WriteUnthrottledBytes apply
 	// only to the first tunnel in a session.
@@ -273,14 +292,19 @@ type RateLimits struct {
 }
 
 // CommonRateLimits converts a RateLimits to a common.RateLimits.
-func (rateLimits *RateLimits) CommonRateLimits() common.RateLimits {
-	return common.RateLimits{
+func (rateLimits *RateLimits) CommonRateLimits(handshaked bool) common.RateLimits {
+	r := common.RateLimits{
 		ReadUnthrottledBytes:  *rateLimits.ReadUnthrottledBytes,
 		ReadBytesPerSecond:    *rateLimits.ReadBytesPerSecond,
 		WriteUnthrottledBytes: *rateLimits.WriteUnthrottledBytes,
 		WriteBytesPerSecond:   *rateLimits.WriteBytesPerSecond,
 		CloseAfterExhausted:   *rateLimits.CloseAfterExhausted,
 	}
+	if !handshaked {
+		r.ReadBytesPerSecond = *rateLimits.EstablishmentReadBytesPerSecond
+		r.WriteBytesPerSecond = *rateLimits.EstablishmentWriteBytesPerSecond
+	}
+	return r
 }
 
 // NewTrafficRulesSet initializes a TrafficRulesSet with
@@ -306,6 +330,7 @@ func NewTrafficRulesSet(filename string) (*TrafficRulesSet, error) {
 			// Modify actual traffic rules only after validation
 			set.MeekRateLimiterHistorySize = newSet.MeekRateLimiterHistorySize
 			set.MeekRateLimiterThresholdSeconds = newSet.MeekRateLimiterThresholdSeconds
+			set.MeekRateLimiterTunnelProtocols = newSet.MeekRateLimiterTunnelProtocols
 			set.MeekRateLimiterRegions = newSet.MeekRateLimiterRegions
 			set.MeekRateLimiterISPs = newSet.MeekRateLimiterISPs
 			set.MeekRateLimiterCities = newSet.MeekRateLimiterCities
@@ -349,6 +374,8 @@ func (set *TrafficRulesSet) Validate() error {
 			(rules.RateLimits.ReadBytesPerSecond != nil && *rules.RateLimits.ReadBytesPerSecond < 0) ||
 			(rules.RateLimits.WriteUnthrottledBytes != nil && *rules.RateLimits.WriteUnthrottledBytes < 0) ||
 			(rules.RateLimits.WriteBytesPerSecond != nil && *rules.RateLimits.WriteBytesPerSecond < 0) ||
+			(rules.RateLimits.EstablishmentReadBytesPerSecond != nil && *rules.RateLimits.EstablishmentReadBytesPerSecond < 0) ||
+			(rules.RateLimits.EstablishmentWriteBytesPerSecond != nil && *rules.RateLimits.EstablishmentWriteBytesPerSecond < 0) ||
 			(rules.DialTCPPortForwardTimeoutMilliseconds != nil && *rules.DialTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleTCPPortForwardTimeoutMilliseconds != nil && *rules.IdleTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleUDPPortForwardTimeoutMilliseconds != nil && *rules.IdleUDPPortForwardTimeoutMilliseconds < 0) ||
@@ -527,6 +554,14 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
 	}
 
+	if trafficRules.RateLimits.EstablishmentReadBytesPerSecond == nil {
+		trafficRules.RateLimits.EstablishmentReadBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.EstablishmentWriteBytesPerSecond == nil {
+		trafficRules.RateLimits.EstablishmentWriteBytesPerSecond = new(int64)
+	}
+
 	if trafficRules.RateLimits.UnthrottleFirstTunnelOnly == nil {
 		trafficRules.RateLimits.UnthrottleFirstTunnelOnly = new(bool)
 	}
@@ -727,6 +762,14 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
 		}
 
+		if filteredRules.Rules.RateLimits.EstablishmentReadBytesPerSecond != nil {
+			trafficRules.RateLimits.EstablishmentReadBytesPerSecond = filteredRules.Rules.RateLimits.EstablishmentReadBytesPerSecond
+		}
+
+		if filteredRules.Rules.RateLimits.EstablishmentWriteBytesPerSecond != nil {
+			trafficRules.RateLimits.EstablishmentWriteBytesPerSecond = filteredRules.Rules.RateLimits.EstablishmentWriteBytesPerSecond
+		}
+
 		if filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly != nil {
 			trafficRules.RateLimits.UnthrottleFirstTunnelOnly = filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly
 		}
@@ -877,7 +920,8 @@ func (rules *TrafficRules) allowSubnet(remoteIP net.IP) bool {
 
 // GetMeekRateLimiterConfig gets a snapshot of the meek rate limiter
 // configuration values.
-func (set *TrafficRulesSet) GetMeekRateLimiterConfig() (int, int, []string, []string, []string, int, int) {
+func (set *TrafficRulesSet) GetMeekRateLimiterConfig() (
+	int, int, []string, []string, []string, []string, int, int) {
 
 	set.ReloadableFile.RLock()
 	defer set.ReloadableFile.RUnlock()
@@ -895,6 +939,7 @@ func (set *TrafficRulesSet) GetMeekRateLimiterConfig() (int, int, []string, []st
 
 	return set.MeekRateLimiterHistorySize,
 		set.MeekRateLimiterThresholdSeconds,
+		set.MeekRateLimiterTunnelProtocols,
 		set.MeekRateLimiterRegions,
 		set.MeekRateLimiterISPs,
 		set.MeekRateLimiterCities,

+ 289 - 136
psiphon/server/tunnelServer.go

@@ -199,7 +199,7 @@ func (server *TunnelServer) Run() error {
 			support,
 			listener,
 			tunnelProtocol,
-			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP, false) })
+			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP) })
 
 		log.WithTraceFields(
 			LogFields{
@@ -264,7 +264,9 @@ func (server *TunnelServer) Run() error {
 // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
 // include current connected client count, total number of current port
 // forwards.
-func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
+func (server *TunnelServer) GetLoadStats() (
+	UpstreamStats, ProtocolStats, RegionStats) {
+
 	return server.sshServer.getLoadStats()
 }
 
@@ -460,11 +462,6 @@ func (sshServer *sshServer) getEstablishTunnelsMetrics() (bool, int64) {
 // occurs, it will send the error to the listenerError channel.
 func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError chan<- error) {
 
-	runningProtocols := make([]string, 0)
-	for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
-		runningProtocols = append(runningProtocols, tunnelProtocol)
-	}
-
 	handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
 
 		// Note: establish tunnel limiter cannot simply stop TCP
@@ -477,28 +474,15 @@ func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError
 			return
 		}
 
-		// The tunnelProtocol passed to handleClient is used for stats,
-		// throttling, etc. When the tunnel protocol can be determined
-		// unambiguously from the listening port, use that protocol and
-		// don't use any client-declared value. Only use the client's
-		// value, if present, in special cases where the listening port
-		// cannot distinguish the protocol.
+		// tunnelProtocol is used for stats and traffic rules. In many cases, its
+		// value is unambiguously determined by the listener port. In certain cases,
+		// such as multiple fronted protocols with a single backend listener, the
+		// client's reported tunnel protocol value is used. The caller must validate
+		// clientTunnelProtocol with protocol.IsValidClientTunnelProtocol.
+
 		tunnelProtocol := sshListener.tunnelProtocol
 		if clientTunnelProtocol != "" {
-
-			if !common.Contains(runningProtocols, clientTunnelProtocol) {
-				log.WithTraceFields(
-					LogFields{
-						"clientTunnelProtocol": clientTunnelProtocol}).
-					Warning("invalid client tunnel protocol")
-				clientConn.Close()
-				return
-			}
-
-			if protocol.UseClientTunnelProtocol(
-				clientTunnelProtocol, runningProtocols) {
-				tunnelProtocol = clientTunnelProtocol
-			}
+			tunnelProtocol = clientTunnelProtocol
 		}
 
 		// sshListener.tunnelProtocol indictes the tunnel protocol run by the
@@ -717,51 +701,84 @@ func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
 	client.stop()
 }
 
-type ProtocolStats map[string]map[string]int64
-type RegionStats map[string]map[string]map[string]int64
+type UpstreamStats map[string]interface{}
+type ProtocolStats map[string]map[string]interface{}
+type RegionStats map[string]map[string]map[string]interface{}
 
-func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
+func (sshServer *sshServer) getLoadStats() (
+	UpstreamStats, ProtocolStats, RegionStats) {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
-	// Explicitly populate with zeros to ensure 0 counts in log messages
-	zeroStats := func() map[string]int64 {
-		stats := make(map[string]int64)
-		stats["accepted_clients"] = 0
-		stats["established_clients"] = 0
-		stats["dialing_tcp_port_forwards"] = 0
-		stats["tcp_port_forwards"] = 0
-		stats["total_tcp_port_forwards"] = 0
-		stats["udp_port_forwards"] = 0
-		stats["total_udp_port_forwards"] = 0
-		stats["tcp_port_forward_dialed_count"] = 0
-		stats["tcp_port_forward_dialed_duration"] = 0
-		stats["tcp_port_forward_failed_count"] = 0
-		stats["tcp_port_forward_failed_duration"] = 0
-		stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
-		stats["tcp_port_forward_rejected_disallowed_count"] = 0
-		stats["udp_port_forward_rejected_disallowed_count"] = 0
-		stats["tcp_ipv4_port_forward_dialed_count"] = 0
-		stats["tcp_ipv4_port_forward_dialed_duration"] = 0
-		stats["tcp_ipv4_port_forward_failed_count"] = 0
-		stats["tcp_ipv4_port_forward_failed_duration"] = 0
-		stats["tcp_ipv6_port_forward_dialed_count"] = 0
-		stats["tcp_ipv6_port_forward_dialed_duration"] = 0
-		stats["tcp_ipv6_port_forward_failed_count"] = 0
-		stats["tcp_ipv6_port_forward_failed_duration"] = 0
+	// Explicitly populate with zeros to ensure 0 counts in log messages.
+
+	zeroClientStats := func() map[string]interface{} {
+		stats := make(map[string]interface{})
+		stats["accepted_clients"] = int64(0)
+		stats["established_clients"] = int64(0)
 		return stats
 	}
 
-	zeroProtocolStats := func() map[string]map[string]int64 {
-		stats := make(map[string]map[string]int64)
-		stats["ALL"] = zeroStats()
+	// Due to hot reload and changes to the underlying system configuration, the
+	// set of resolver IPs may change between getLoadStats calls, so this
+	// enumeration for zeroing is a best effort.
+	resolverIPs := sshServer.support.DNSResolver.GetAll()
+
+	// Fields which are primarily concerned with upstream/egress performance.
+	zeroUpstreamStats := func() map[string]interface{} {
+		stats := make(map[string]interface{})
+		stats["dialing_tcp_port_forwards"] = int64(0)
+		stats["tcp_port_forwards"] = int64(0)
+		stats["total_tcp_port_forwards"] = int64(0)
+		stats["udp_port_forwards"] = int64(0)
+		stats["total_udp_port_forwards"] = int64(0)
+		stats["tcp_port_forward_dialed_count"] = int64(0)
+		stats["tcp_port_forward_dialed_duration"] = int64(0)
+		stats["tcp_port_forward_failed_count"] = int64(0)
+		stats["tcp_port_forward_failed_duration"] = int64(0)
+		stats["tcp_port_forward_rejected_dialing_limit_count"] = int64(0)
+		stats["tcp_port_forward_rejected_disallowed_count"] = int64(0)
+		stats["udp_port_forward_rejected_disallowed_count"] = int64(0)
+		stats["tcp_ipv4_port_forward_dialed_count"] = int64(0)
+		stats["tcp_ipv4_port_forward_dialed_duration"] = int64(0)
+		stats["tcp_ipv4_port_forward_failed_count"] = int64(0)
+		stats["tcp_ipv4_port_forward_failed_duration"] = int64(0)
+		stats["tcp_ipv6_port_forward_dialed_count"] = int64(0)
+		stats["tcp_ipv6_port_forward_dialed_duration"] = int64(0)
+		stats["tcp_ipv6_port_forward_failed_count"] = int64(0)
+		stats["tcp_ipv6_port_forward_failed_duration"] = int64(0)
+
+		zeroDNSStats := func() map[string]int64 {
+			m := map[string]int64{"ALL": 0}
+			for _, resolverIP := range resolverIPs {
+				m[resolverIP.String()] = 0
+			}
+			return m
+		}
+
+		stats["dns_count"] = zeroDNSStats()
+		stats["dns_duration"] = zeroDNSStats()
+		stats["dns_failed_count"] = zeroDNSStats()
+		stats["dns_failed_duration"] = zeroDNSStats()
+		return stats
+	}
+
+	zeroProtocolStats := func() map[string]map[string]interface{} {
+		stats := make(map[string]map[string]interface{})
+		stats["ALL"] = zeroClientStats()
 		for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
-			stats[tunnelProtocol] = zeroStats()
+			stats[tunnelProtocol] = zeroClientStats()
 		}
 		return stats
 	}
 
+	addInt64 := func(stats map[string]interface{}, name string, value int64) {
+		stats[name] = stats[name].(int64) + value
+	}
+
+	upstreamStats := zeroUpstreamStats()
+
 	// [<protocol or ALL>][<stat name>] -> count
 	protocolStats := zeroProtocolStats()
 
@@ -778,11 +795,11 @@ func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
 					regionStats[region] = zeroProtocolStats()
 				}
 
-				protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
-				protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
+				addInt64(protocolStats["ALL"], "accepted_clients", acceptedClientCount)
+				addInt64(protocolStats[tunnelProtocol], "accepted_clients", acceptedClientCount)
 
-				regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
-				regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
+				addInt64(regionStats[region]["ALL"], "accepted_clients", acceptedClientCount)
+				addInt64(regionStats[region][tunnelProtocol], "accepted_clients", acceptedClientCount)
 			}
 		}
 	}
@@ -798,75 +815,108 @@ func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
 			regionStats[region] = zeroProtocolStats()
 		}
 
-		stats := []map[string]int64{
+		for _, stats := range []map[string]interface{}{
 			protocolStats["ALL"],
 			protocolStats[tunnelProtocol],
 			regionStats[region]["ALL"],
-			regionStats[region][tunnelProtocol]}
-
-		for _, stat := range stats {
-
-			stat["established_clients"] += 1
-
-			// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
-
-			stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
-			stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
-			stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
-			// client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
-			stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
-			stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
-
-			stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.TCPPortForwardDialedCount
-			stat["tcp_port_forward_dialed_duration"] +=
-				int64(client.qualityMetrics.TCPPortForwardDialedDuration / time.Millisecond)
-			stat["tcp_port_forward_failed_count"] += client.qualityMetrics.TCPPortForwardFailedCount
-			stat["tcp_port_forward_failed_duration"] +=
-				int64(client.qualityMetrics.TCPPortForwardFailedDuration / time.Millisecond)
-			stat["tcp_port_forward_rejected_dialing_limit_count"] +=
-				client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount
-			stat["tcp_port_forward_rejected_disallowed_count"] +=
-				client.qualityMetrics.TCPPortForwardRejectedDisallowedCount
-			stat["udp_port_forward_rejected_disallowed_count"] +=
-				client.qualityMetrics.UDPPortForwardRejectedDisallowedCount
-
-			stat["tcp_ipv4_port_forward_dialed_count"] += client.qualityMetrics.TCPIPv4PortForwardDialedCount
-			stat["tcp_ipv4_port_forward_dialed_duration"] +=
-				int64(client.qualityMetrics.TCPIPv4PortForwardDialedDuration / time.Millisecond)
-			stat["tcp_ipv4_port_forward_failed_count"] += client.qualityMetrics.TCPIPv4PortForwardFailedCount
-			stat["tcp_ipv4_port_forward_failed_duration"] +=
-				int64(client.qualityMetrics.TCPIPv4PortForwardFailedDuration / time.Millisecond)
-
-			stat["tcp_ipv6_port_forward_dialed_count"] += client.qualityMetrics.TCPIPv6PortForwardDialedCount
-			stat["tcp_ipv6_port_forward_dialed_duration"] +=
-				int64(client.qualityMetrics.TCPIPv6PortForwardDialedDuration / time.Millisecond)
-			stat["tcp_ipv6_port_forward_failed_count"] += client.qualityMetrics.TCPIPv6PortForwardFailedCount
-			stat["tcp_ipv6_port_forward_failed_duration"] +=
-				int64(client.qualityMetrics.TCPIPv6PortForwardFailedDuration / time.Millisecond)
-		}
-
-		client.qualityMetrics.TCPPortForwardDialedCount = 0
-		client.qualityMetrics.TCPPortForwardDialedDuration = 0
-		client.qualityMetrics.TCPPortForwardFailedCount = 0
-		client.qualityMetrics.TCPPortForwardFailedDuration = 0
-		client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount = 0
-		client.qualityMetrics.TCPPortForwardRejectedDisallowedCount = 0
-		client.qualityMetrics.UDPPortForwardRejectedDisallowedCount = 0
-
-		client.qualityMetrics.TCPIPv4PortForwardDialedCount = 0
-		client.qualityMetrics.TCPIPv4PortForwardDialedDuration = 0
-		client.qualityMetrics.TCPIPv4PortForwardFailedCount = 0
-		client.qualityMetrics.TCPIPv4PortForwardFailedDuration = 0
-
-		client.qualityMetrics.TCPIPv6PortForwardDialedCount = 0
-		client.qualityMetrics.TCPIPv6PortForwardDialedDuration = 0
-		client.qualityMetrics.TCPIPv6PortForwardFailedCount = 0
-		client.qualityMetrics.TCPIPv6PortForwardFailedDuration = 0
+			regionStats[region][tunnelProtocol]} {
+
+			addInt64(stats, "established_clients", 1)
+		}
+
+		// Note:
+		// - can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
+		// - client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
+
+		addInt64(upstreamStats, "dialing_tcp_port_forwards",
+			client.tcpTrafficState.concurrentDialingPortForwardCount)
+
+		addInt64(upstreamStats, "tcp_port_forwards",
+			client.tcpTrafficState.concurrentPortForwardCount)
+
+		addInt64(upstreamStats, "total_tcp_port_forwards",
+			client.tcpTrafficState.totalPortForwardCount)
+
+		addInt64(upstreamStats, "udp_port_forwards",
+			client.udpTrafficState.concurrentPortForwardCount)
+
+		addInt64(upstreamStats, "total_udp_port_forwards",
+			client.udpTrafficState.totalPortForwardCount)
+
+		addInt64(upstreamStats, "tcp_port_forward_dialed_count",
+			client.qualityMetrics.TCPPortForwardDialedCount)
+
+		addInt64(upstreamStats, "tcp_port_forward_dialed_duration",
+			int64(client.qualityMetrics.TCPPortForwardDialedDuration/time.Millisecond))
+
+		addInt64(upstreamStats, "tcp_port_forward_failed_count",
+			client.qualityMetrics.TCPPortForwardFailedCount)
+
+		addInt64(upstreamStats, "tcp_port_forward_failed_duration",
+			int64(client.qualityMetrics.TCPPortForwardFailedDuration/time.Millisecond))
+
+		addInt64(upstreamStats, "tcp_port_forward_rejected_dialing_limit_count",
+			client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount)
+
+		addInt64(upstreamStats, "tcp_port_forward_rejected_disallowed_count",
+			client.qualityMetrics.TCPPortForwardRejectedDisallowedCount)
+
+		addInt64(upstreamStats, "udp_port_forward_rejected_disallowed_count",
+			client.qualityMetrics.UDPPortForwardRejectedDisallowedCount)
+
+		addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_count",
+			client.qualityMetrics.TCPIPv4PortForwardDialedCount)
+
+		addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_duration",
+			int64(client.qualityMetrics.TCPIPv4PortForwardDialedDuration/time.Millisecond))
+
+		addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_count",
+			client.qualityMetrics.TCPIPv4PortForwardFailedCount)
+
+		addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_duration",
+			int64(client.qualityMetrics.TCPIPv4PortForwardFailedDuration/time.Millisecond))
+
+		addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_count",
+			client.qualityMetrics.TCPIPv6PortForwardDialedCount)
+
+		addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_duration",
+			int64(client.qualityMetrics.TCPIPv6PortForwardDialedDuration/time.Millisecond))
+
+		addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_count",
+			client.qualityMetrics.TCPIPv6PortForwardFailedCount)
+
+		addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_duration",
+			int64(client.qualityMetrics.TCPIPv6PortForwardFailedDuration/time.Millisecond))
+
+		// DNS metrics limitations:
+		// - port forwards (sshClient.handleTCPChannel) don't know or log the resolver IP.
+		// - udpgw and packet tunnel transparent DNS use a heuristic to classify success/failure,
+		//   and there may be some delay before these code paths report DNS metrics.
+
+		// Every client.qualityMetrics DNS map has an "ALL" entry.
+
+		for key, value := range client.qualityMetrics.DNSCount {
+			upstreamStats["dns_count"].(map[string]int64)[key] += value
+		}
+
+		for key, value := range client.qualityMetrics.DNSDuration {
+			upstreamStats["dns_duration"].(map[string]int64)[key] += int64(value / time.Millisecond)
+		}
+
+		for key, value := range client.qualityMetrics.DNSFailedCount {
+			upstreamStats["dns_failed_count"].(map[string]int64)[key] += value
+		}
+
+		for key, value := range client.qualityMetrics.DNSFailedDuration {
+			upstreamStats["dns_failed_duration"].(map[string]int64)[key] += int64(value / time.Millisecond)
+		}
+
+		client.qualityMetrics.reset()
 
 		client.Unlock()
 	}
 
-	return protocolStats, regionStats
+	return upstreamStats, protocolStats, regionStats
 }
 
 func (sshServer *sshServer) getEstablishedClientCount() int {
@@ -1102,7 +1152,7 @@ func (sshServer *sshServer) handleClient(
 	}
 
 	geoIPData := sshServer.support.GeoIPService.Lookup(
-		common.IPAddressFromAddr(clientAddr), true)
+		common.IPAddressFromAddr(clientAddr))
 
 	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
 	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
@@ -1158,6 +1208,7 @@ func (sshServer *sshServer) handleClient(
 		tunnelProtocol,
 		serverPacketManipulation,
 		replayedServerPacketManipulation,
+		clientAddr,
 		geoIPData)
 
 	// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
@@ -1203,6 +1254,7 @@ type sshClient struct {
 	throttledConn                        *common.ThrottledConn
 	serverPacketManipulation             string
 	replayedServerPacketManipulation     bool
+	clientAddr                           net.Addr
 	geoIPData                            GeoIPData
 	sessionID                            string
 	isFirstTunnelInSession               bool
@@ -1213,7 +1265,7 @@ type sshClient struct {
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	udpTrafficState                      trafficState
-	qualityMetrics                       qualityMetrics
+	qualityMetrics                       *qualityMetrics
 	tcpPortForwardLRU                    *common.LRUConns
 	oslClientSeedState                   *osl.ClientSeedState
 	signalIssueSLOKs                     chan struct{}
@@ -1269,6 +1321,57 @@ type qualityMetrics struct {
 	TCPIPv6PortForwardDialedDuration        time.Duration
 	TCPIPv6PortForwardFailedCount           int64
 	TCPIPv6PortForwardFailedDuration        time.Duration
+	DNSCount                                map[string]int64
+	DNSDuration                             map[string]time.Duration
+	DNSFailedCount                          map[string]int64
+	DNSFailedDuration                       map[string]time.Duration
+}
+
+func newQualityMetrics() *qualityMetrics {
+	return &qualityMetrics{
+		DNSCount:          make(map[string]int64),
+		DNSDuration:       make(map[string]time.Duration),
+		DNSFailedCount:    make(map[string]int64),
+		DNSFailedDuration: make(map[string]time.Duration),
+	}
+}
+
+func (q *qualityMetrics) reset() {
+
+	q.TCPPortForwardDialedCount = 0
+	q.TCPPortForwardDialedDuration = 0
+	q.TCPPortForwardFailedCount = 0
+	q.TCPPortForwardFailedDuration = 0
+	q.TCPPortForwardRejectedDialingLimitCount = 0
+	q.TCPPortForwardRejectedDisallowedCount = 0
+
+	q.UDPPortForwardRejectedDisallowedCount = 0
+
+	q.TCPIPv4PortForwardDialedCount = 0
+	q.TCPIPv4PortForwardDialedDuration = 0
+	q.TCPIPv4PortForwardFailedCount = 0
+	q.TCPIPv4PortForwardFailedDuration = 0
+
+	q.TCPIPv6PortForwardDialedCount = 0
+	q.TCPIPv6PortForwardDialedDuration = 0
+	q.TCPIPv6PortForwardFailedCount = 0
+	q.TCPIPv6PortForwardFailedDuration = 0
+
+	// Retain existing maps to avoid memory churn. The Go compiler optimizes map
+	// clearing operations of the following form.
+
+	for k := range q.DNSCount {
+		delete(q.DNSCount, k)
+	}
+	for k := range q.DNSDuration {
+		delete(q.DNSDuration, k)
+	}
+	for k := range q.DNSFailedCount {
+		delete(q.DNSFailedCount, k)
+	}
+	for k := range q.DNSFailedDuration {
+		delete(q.DNSFailedDuration, k)
+	}
 }
 
 type handshakeState struct {
@@ -1296,6 +1399,7 @@ func newSshClient(
 	tunnelProtocol string,
 	serverPacketManipulation string,
 	replayedServerPacketManipulation bool,
+	clientAddr net.Addr,
 	geoIPData GeoIPData) *sshClient {
 
 	runCtx, stopRunning := context.WithCancel(context.Background())
@@ -1310,8 +1414,10 @@ func newSshClient(
 		tunnelProtocol:                   tunnelProtocol,
 		serverPacketManipulation:         serverPacketManipulation,
 		replayedServerPacketManipulation: replayedServerPacketManipulation,
+		clientAddr:                       clientAddr,
 		geoIPData:                        geoIPData,
 		isFirstTunnelInSession:           true,
+		qualityMetrics:                   newQualityMetrics(),
 		tcpPortForwardLRU:                common.NewLRUConns(),
 		signalIssueSLOKs:                 make(chan struct{}, 1),
 		runCtx:                           runCtx,
@@ -1988,8 +2094,14 @@ func (sshClient *sshClient) handleSSHRequests(requests <-chan *ssh.Request) {
 			// Note: unlock before use is only safe as long as referenced sshClient data,
 			// such as slices in handshakeState, is read-only after initially set.
 
+			clientAddr := ""
+			if sshClient.clientAddr != nil {
+				clientAddr = sshClient.clientAddr.String()
+			}
+
 			responsePayload, err = sshAPIRequestHandler(
 				sshClient.sshServer.support,
+				clientAddr,
 				sshClient.geoIPData,
 				authorizedAccessTypes,
 				request.Type,
@@ -2161,9 +2273,9 @@ func (sshClient *sshClient) handleNewRandomStreamChannel(
 	// is available pre-handshake, albeit with additional restrictions.
 	//
 	// The random stream is subject to throttling in traffic rules; for
-	// unthrottled liveness tests, set initial Read/WriteUnthrottledBytes as
-	// required. The random stream maximum count and response size cap
-	// mitigate clients abusing the facility to waste server resources.
+	// unthrottled liveness tests, set EstablishmentRead/WriteBytesPerSecond as
+	// required. The random stream maximum count and response size cap mitigate
+	// clients abusing the facility to waste server resources.
 	//
 	// Like all other channels, this channel type is handled asynchronously,
 	// so it's possible to run at any point in the tunnel lifecycle.
@@ -2337,6 +2449,8 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 		sshClient.Unlock()
 	}
 
+	dnsQualityReporter := sshClient.updateQualityMetricsWithDNSResult
+
 	err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
 		sshClient.sessionID,
 		packetTunnelChannel,
@@ -2344,7 +2458,8 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 		checkAllowedUDPPortFunc,
 		checkAllowedDomainFunc,
 		flowActivityUpdaterMaker,
-		metricUpdater)
+		metricUpdater,
+		dnsQualityReporter)
 	if err != nil {
 		log.WithTraceFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
 		sshClient.setPacketTunnelChannel(nil)
@@ -3008,7 +3123,8 @@ func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	if sshClient.throttledConn != nil {
 		// Any existing throttling state is reset.
 		sshClient.throttledConn.SetLimits(
-			sshClient.trafficRules.RateLimits.CommonRateLimits())
+			sshClient.trafficRules.RateLimits.CommonRateLimits(
+				sshClient.handshakeState.completed))
 	}
 
 	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
@@ -3102,7 +3218,8 @@ func (sshClient *sshClient) rateLimits() common.RateLimits {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	return sshClient.trafficRules.RateLimits.CommonRateLimits()
+	return sshClient.trafficRules.RateLimits.CommonRateLimits(
+		sshClient.handshakeState.completed)
 }
 
 func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
@@ -3496,6 +3613,33 @@ func (sshClient *sshClient) updateQualityMetricsWithUDPRejectedDisallowed() {
 	sshClient.qualityMetrics.UDPPortForwardRejectedDisallowedCount += 1
 }
 
+func (sshClient *sshClient) updateQualityMetricsWithDNSResult(
+	success bool, duration time.Duration, resolverIP net.IP) {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	resolver := ""
+	if resolverIP != nil {
+		resolver = resolverIP.String()
+	}
+	if success {
+		sshClient.qualityMetrics.DNSCount["ALL"] += 1
+		sshClient.qualityMetrics.DNSDuration["ALL"] += duration
+		if resolver != "" {
+			sshClient.qualityMetrics.DNSCount[resolver] += 1
+			sshClient.qualityMetrics.DNSDuration[resolver] += duration
+		}
+	} else {
+		sshClient.qualityMetrics.DNSFailedCount["ALL"] += 1
+		sshClient.qualityMetrics.DNSFailedDuration["ALL"] += duration
+		if resolver != "" {
+			sshClient.qualityMetrics.DNSFailedCount[resolver] += 1
+			sshClient.qualityMetrics.DNSFailedDuration[resolver] += duration
+		}
+	}
+}
+
 func (sshClient *sshClient) handleTCPChannel(
 	remainingDialTimeout time.Duration,
 	hostToConnect string,
@@ -3573,6 +3717,17 @@ func (sshClient *sshClient) handleTCPChannel(
 	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
 	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 
+	resolveElapsedTime := time.Since(dialStartTime)
+
+	// Record DNS metrics. If LookupIPAddr returns net.DNSError.IsNotFound, this
+	// is "no such host" and not a DNS failure. Limitation: the resolver IP is
+	// not known.
+
+	dnsErr, ok := err.(*net.DNSError)
+	dnsNotFound := ok && dnsErr.IsNotFound
+	dnsSuccess := err == nil || dnsNotFound
+	sshClient.updateQualityMetricsWithDNSResult(dnsSuccess, resolveElapsedTime, nil)
+
 	// IPv4 is preferred in case the host has limited IPv6 routing. IPv6 is
 	// selected and attempted only when there's no IPv4 option.
 	// TODO: shuffle list to try other IPs?
@@ -3593,8 +3748,6 @@ func (sshClient *sshClient) handleTCPChannel(
 		err = std_errors.New("no IP address")
 	}
 
-	resolveElapsedTime := time.Since(dialStartTime)
-
 	if err != nil {
 
 		// Record a port forward failure
@@ -3639,7 +3792,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	if doSplitTunnel {
 
-		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP, false)
+		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)
 
 		if destinationGeoIPData.Country == sshClient.geoIPData.Country &&
 			sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE {

+ 50 - 9
psiphon/server/udp.go

@@ -29,6 +29,7 @@ import (
 	"sync"
 	"sync/atomic"
 
+	"github.com/Psiphon-Labs/goarista/monotime"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -241,12 +242,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 				preambleSize: message.preambleSize,
 				remoteIP:     message.remoteIP,
 				remotePort:   message.remotePort,
+				dialIP:       dialIP,
 				conn:         conn,
 				lruEntry:     lruEntry,
 				bytesUp:      0,
 				bytesDown:    0,
 				mux:          mux,
 			}
+
+			if message.forwardDNS {
+				portForward.dnsFirstWriteTime = int64(monotime.Now())
+			}
+
 			mux.portForwardsMutex.Lock()
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
@@ -291,15 +298,18 @@ type udpPortForward struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
-	bytesUp      int64
-	bytesDown    int64
-	connID       uint16
-	preambleSize int
-	remoteIP     []byte
-	remotePort   uint16
-	conn         net.Conn
-	lruEntry     *common.LRUConnsEntry
-	mux          *udpPortForwardMultiplexer
+	dnsFirstWriteTime int64
+	dnsFirstReadTime  int64
+	bytesUp           int64
+	bytesDown         int64
+	connID            uint16
+	preambleSize      int
+	remoteIP          []byte
+	remotePort        uint16
+	dialIP            net.IP
+	conn              net.Conn
+	lruEntry          *common.LRUConnsEntry
+	mux               *udpPortForwardMultiplexer
 }
 
 func (portForward *udpPortForward) relayDownstream() {
@@ -330,6 +340,11 @@ func (portForward *udpPortForward) relayDownstream() {
 			break
 		}
 
+		if atomic.LoadInt64(&portForward.dnsFirstWriteTime) > 0 &&
+			atomic.LoadInt64(&portForward.dnsFirstReadTime) == 0 { // Check if already set before invoking Now.
+			atomic.CompareAndSwapInt64(&portForward.dnsFirstReadTime, 0, int64(monotime.Now()))
+		}
+
 		err = writeUdpgwPreamble(
 			portForward.preambleSize,
 			0,
@@ -369,6 +384,32 @@ func (portForward *udpPortForward) relayDownstream() {
 	bytesDown := atomic.LoadInt64(&portForward.bytesDown)
 	portForward.mux.sshClient.closedPortForward(portForwardTypeUDP, bytesUp, bytesDown)
 
+	dnsStartTime := monotime.Time(atomic.LoadInt64(&portForward.dnsFirstWriteTime))
+	if dnsStartTime > 0 {
+
+		// Record DNS metrics using a heuristic: if a UDP packet was written and
+		// then a packet was read, assume the DNS request successfully received a
+		// valid response; failure occurs when the resolver fails to provide a
+		// response; a "no such host" response is still a success. Limitations: we
+		// assume a resolver will not respond when, e.g., rate limiting; we ignore
+		// subsequent requests made via the same UDP port forward.
+
+		dnsEndTime := monotime.Time(atomic.LoadInt64(&portForward.dnsFirstReadTime))
+
+		dnsSuccess := true
+		if dnsEndTime == 0 {
+			dnsSuccess = false
+			dnsEndTime = monotime.Now()
+		}
+
+		resolveElapsedTime := dnsEndTime.Sub(dnsStartTime)
+
+		portForward.mux.sshClient.updateQualityMetricsWithDNSResult(
+			dnsSuccess,
+			resolveElapsedTime,
+			net.IP(portForward.dialIP))
+	}
+
 	log.WithTraceFields(
 		LogFields{
 			"remoteAddr": fmt.Sprintf("%s:%d",

+ 4 - 0
psiphon/server/webServer.go

@@ -240,6 +240,7 @@ func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Requ
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
+			r.RemoteAddr,
 			webServer.lookupGeoIPData(params),
 			nil,
 			protocol.PSIPHON_API_HANDSHAKE_REQUEST_NAME,
@@ -271,6 +272,7 @@ func (webServer *webServer) connectedHandler(w http.ResponseWriter, r *http.Requ
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
+			r.RemoteAddr,
 			webServer.lookupGeoIPData(params),
 			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_CONNECTED_REQUEST_NAME,
@@ -296,6 +298,7 @@ func (webServer *webServer) statusHandler(w http.ResponseWriter, r *http.Request
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
+			r.RemoteAddr,
 			webServer.lookupGeoIPData(params),
 			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_STATUS_REQUEST_NAME,
@@ -322,6 +325,7 @@ func (webServer *webServer) clientVerificationHandler(w http.ResponseWriter, r *
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
 			protocol.PSIPHON_WEB_API_PROTOCOL,
+			r.RemoteAddr,
 			webServer.lookupGeoIPData(params),
 			nil, // authorizedAccessTypes not logged in web API transport
 			protocol.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME,

+ 5 - 3
psiphon/serverApi.go

@@ -55,7 +55,6 @@ type ServerContext struct {
 	tunnel                   *Tunnel
 	psiphonHttpsClient       *http.Client
 	statsRegexps             *transferstats.Regexps
-	clientRegion             string
 	clientUpgradeVersion     string
 	serverHandshakeTimestamp string
 	paddingPRNG              *prng.PRNG
@@ -225,8 +224,11 @@ func (serverContext *ServerContext) doHandshakeRequest(
 		return errors.Trace(err)
 	}
 
-	serverContext.clientRegion = handshakeResponse.ClientRegion
-	NoticeClientRegion(serverContext.clientRegion)
+	if serverContext.tunnel.config.EmitClientAddress {
+		NoticeClientAddress(handshakeResponse.ClientAddress)
+	}
+
+	NoticeClientRegion(handshakeResponse.ClientRegion)
 
 	var serverEntries []protocol.ServerEntryFields
 

+ 11 - 2
psiphon/tunnel.go

@@ -695,8 +695,6 @@ func dialTunnel(
 	// parameters are cleared, no longer to be retried, if the tunnel fails to
 	// connect.
 	//
-	//
-	//
 	// Limitation: dials that fail to connect due to the server being in a
 	// load-limiting state are not distinguished and excepted from this
 	// logic.
@@ -1158,6 +1156,17 @@ func dialTunnel(
 
 	cleanupConn = nil
 
+	// When configured to do so, hold-off on activating this tunnel. This allows
+	// some extra time for slower but less resource intensive protocols to
+	// establish tunnels. By holding off post-connect, the client has this
+	// established tunnel ready to activate in case other protocols fail to
+	// establish. This hold-off phase continues to consume one connection worker.
+
+	if dialParams.HoldOffTunnelDuration > 0 {
+		NoticeHoldOffTunnel(dialParams.ServerEntry.GetDiagnosticID(), dialParams.HoldOffTunnelDuration)
+		common.SleepWithContext(ctx, dialParams.HoldOffTunnelDuration)
+	}
+
 	// Note: dialConn may be used to close the underlying network connection
 	// but should not be used to perform I/O as that would interfere with SSH
 	// (and also bypasses throttling).