Kaynağa Gözat

Add inproxy package

Rod Hynes 2 yıl önce
ebeveyn
işleme
e6ed2a9c19
34 değiştirilmiş dosya ile 12219 ekleme ve 34 silme
  1. 2 0
      .github/workflows/tests.yml
  2. 34 0
      psiphon/common/inproxy/Dockerfile
  3. 875 0
      psiphon/common/inproxy/api.go
  4. 1191 0
      psiphon/common/inproxy/broker.go
  5. 234 0
      psiphon/common/inproxy/brokerClient.go
  6. 422 0
      psiphon/common/inproxy/client.go
  7. 242 0
      psiphon/common/inproxy/dialParameters.go
  8. 394 0
      psiphon/common/inproxy/dialParameters_test.go
  9. 422 0
      psiphon/common/inproxy/discovery.go
  10. 130 0
      psiphon/common/inproxy/discovery_test.go
  11. 131 0
      psiphon/common/inproxy/doc.go
  12. 101 0
      psiphon/common/inproxy/dtls.go
  13. 809 0
      psiphon/common/inproxy/inproxy_test.go
  14. 749 0
      psiphon/common/inproxy/matcher.go
  15. 458 0
      psiphon/common/inproxy/matcher_test.go
  16. 406 0
      psiphon/common/inproxy/nat.go
  17. 400 0
      psiphon/common/inproxy/obfuscation.go
  18. 416 0
      psiphon/common/inproxy/obfuscation_test.go
  19. 159 0
      psiphon/common/inproxy/packet-trace.diff
  20. 250 0
      psiphon/common/inproxy/portmapper.go
  21. 30 0
      psiphon/common/inproxy/portmapper_android.go
  22. 28 0
      psiphon/common/inproxy/portmapper_other.go
  23. 630 0
      psiphon/common/inproxy/proxy.go
  24. 173 0
      psiphon/common/inproxy/records.go
  25. 173 0
      psiphon/common/inproxy/server.go
  26. 1525 0
      psiphon/common/inproxy/session.go
  27. 518 0
      psiphon/common/inproxy/session_test.go
  28. 1081 0
      psiphon/common/inproxy/webrtc.go
  29. 4 0
      psiphon/common/logger.go
  30. 5 0
      psiphon/common/prng/prng.go
  31. 2 0
      psiphon/common/protocol/protocol.go
  32. 126 34
      psiphon/common/protocol/serverEntry.go
  33. 77 0
      psiphon/common/protocol/serverEntry_test.go
  34. 22 0
      psiphon/common/utils.go

+ 2 - 0
.github/workflows/tests.yml

@@ -80,6 +80,7 @@ jobs:
           go test -v -race ./psiphon/common/accesscontrol
           go test -v -race ./psiphon/common/crypto/ssh
           go test -v -race ./psiphon/common/fragmentor
+          go test -v -race ./psiphon/common/inproxy
           go test -v -race ./psiphon/common/monotime
           go test -v -race ./psiphon/common/obfuscator
           go test -v -race ./psiphon/common/osl
@@ -111,6 +112,7 @@ jobs:
           go test -v -covermode=count -coverprofile=accesscontrol.coverprofile ./psiphon/common/accesscontrol
           go test -v -covermode=count -coverprofile=ssh.coverprofile ./psiphon/common/crypto/ssh
           go test -v -covermode=count -coverprofile=fragmentor.coverprofile ./psiphon/common/fragmentor
+          go test -v -covermode=count -coverprofile=inproxy.coverprofile ./psiphon/common/inproxy
           go test -v -covermode=count -coverprofile=monotime.coverprofile ./psiphon/common/monotime
           go test -v -covermode=count -coverprofile=obfuscator.coverprofile ./psiphon/common/obfuscator
           go test -v -covermode=count -coverprofile=osl.coverprofile ./psiphon/common/osl

+ 34 - 0
psiphon/common/inproxy/Dockerfile

@@ -0,0 +1,34 @@
+# docker build --no-cache=true -t psiphon-inproxy-test .
+#
+# docker run \
+#   --platform=linux/amd64 \
+#   --user "$(id -u):$(id -g)" \ [maybe omit this line]
+#   --rm \
+#   -v $(go env GOCACHE):/.cache/go-build \
+#   -v $(go env GOMODCACHE):/go/pkg/mod \
+#   -v $PWD:/go/src/inproxy \
+#   psiphon-inproxy-test \
+#   /bin/bash -c 'PION_LOG_TRACE=all go test -v -timeout 30s -run TestInProxy'
+
+FROM --platform=linux/amd64 ubuntu:18.04
+
+# Install system-level dependencies.
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update -y && apt-get install -y --no-install-recommends \
+    build-essential \
+    ca-certificates \
+    curl \
+    git \
+    pkg-config \
+  && apt-get clean \
+  && rm -rf /var/lib/apt/lists/*
+
+# Install Go.
+ENV GOVERSION=go1.19.2 GOROOT=/usr/local/go GOPATH=/go PATH=$PATH:/usr/local/go/bin:/go/bin CGO_ENABLED=1
+
+RUN curl -L https://storage.googleapis.com/golang/$GOVERSION.linux-amd64.tar.gz -o /tmp/go.tar.gz \
+   && tar -C /usr/local -xzf /tmp/go.tar.gz \
+   && rm /tmp/go.tar.gz \
+   && echo $GOVERSION > $GOROOT/VERSION
+
+WORKDIR /go/src/inproxy

+ 875 - 0
psiphon/common/inproxy/api.go

@@ -0,0 +1,875 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"encoding/hex"
+	"fmt"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/pion/webrtc/v3"
+)
+
+const (
+	ProxyProtocolVersion1 = 1
+	MaxCompartmentIDs     = 10
+)
+
+// ID is a unique identifier used to identify inproxy connections and actors.
+type ID [32]byte
+
+// MakeID returns a new ID whose bytes are read from crypto/rand. If the
+// read fails, the traced error is returned along with the ID value as it
+// stood after the failed read.
+func MakeID() (ID, error) {
+	var newID ID
+	if _, err := rand.Read(newID[:]); err != nil {
+		return newID, errors.Trace(err)
+	}
+	return newID, nil
+}
+
+// IDFromString parses the hex encoding of an ID. It returns a traced error
+// when s is not valid hex or when the decoded value is not exactly the size
+// of an ID; in both cases the returned ID is the zero value.
+func IDFromString(s string) (ID, error) {
+	var id ID
+	decoded, err := hex.DecodeString(s)
+	switch {
+	case err != nil:
+		return id, errors.Trace(err)
+	case len(decoded) != len(id):
+		return id, errors.TraceNew("invalid length")
+	}
+	copy(id[:], decoded)
+	return id, nil
+}
+
+// MarshalText returns the hex string form of the ID, as produced by String,
+// as a byte slice. The returned error is always nil.
+func (id ID) MarshalText() ([]byte, error) {
+	encoded := []byte(id.String())
+	return encoded, nil
+}
+
+// String returns the ID formatted as lowercase hex.
+func (id ID) String() string {
+	return fmt.Sprintf("%x", id[:])
+}
+
+// Equal reports whether id and x hold the same bytes. The comparison is
+// performed with subtle.ConstantTimeCompare.
+func (id ID) Equal(x ID) bool {
+	matched := subtle.ConstantTimeCompare(id[:], x[:])
+	return matched == 1
+}
+
+// HaveCommonIDs reports whether at least one ID appears in both a and b.
+// It returns false when either list is empty.
+func HaveCommonIDs(a, b []ID) bool {
+	for i := range a {
+		for j := range b {
+			// Each individual comparison is constant time; the total number
+			// of comparisons is not, and so might leak the size of a list.
+			if a[i].Equal(b[j]) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// TransportSecret is a value used to validate a broker transport front, such
+// as a CDN.
+type TransportSecret [32]byte
+
+// Equal reports whether s and t hold the same bytes. The comparison is
+// performed with subtle.ConstantTimeCompare.
+func (s TransportSecret) Equal(t TransportSecret) bool {
+	matched := subtle.ConstantTimeCompare(s[:], t[:])
+	return matched == 1
+}
+
+// NetworkType is the type of a network, such as WiFi or Mobile. This enum is
+// used for compact API message encoding.
+type NetworkType int32
+
+const (
+	NetworkTypeUnknown NetworkType = iota
+	NetworkTypeWiFi
+	NetworkTypeMobile
+)
+
+// NetworkProtocol is an Internet protocol, such as TCP or UDP. This enum is
+// used for compact API message encoding.
+type NetworkProtocol int32
+
+const (
+	NetworkProtocolTCP NetworkProtocol = iota
+	NetworkProtocolUDP
+)
+
+// NetworkProtocolFromString converts a "net" package network protocol string
+// value to a NetworkProtocol.
+func NetworkProtocolFromString(networkProtocol string) (NetworkProtocol, error) {
+	switch networkProtocol {
+	case "tcp":
+		return NetworkProtocolTCP, nil
+	case "udp":
+		return NetworkProtocolUDP, nil
+	}
+	var p NetworkProtocol
+	return p, errors.Tracef("unknown network protocol: %s", networkProtocol)
+}
+
+// String returns the "net" package network protocol string for p: "tcp" for
+// NetworkProtocolTCP and "udp" for NetworkProtocolUDP. Any other value
+// yields the empty string.
+func (p NetworkProtocol) String() string {
+	if p == NetworkProtocolTCP {
+		return "tcp"
+	}
+	if p == NetworkProtocolUDP {
+		return "udp"
+	}
+	// Unrecognized value. The empty string will cause net dials to fail.
+	return ""
+}
+
+// ProxyMetrics are network topology and resource metrics provided by a
+// proxy to a broker. The broker uses this information when matching proxies
+// and clients.
+type ProxyMetrics struct {
+	BaseMetrics                   BaseMetrics      `cbor:"1,keyasint,omitempty"`
+	ProxyProtocolVersion          int32            `cbor:"2,keyasint,omitempty"`
+	NATType                       NATType          `cbor:"3,keyasint,omitempty"`
+	PortMappingTypes              PortMappingTypes `cbor:"4,keyasint,omitempty"`
+	MaxClients                    int32            `cbor:"6,keyasint,omitempty"`
+	ConnectingClients             int32            `cbor:"7,keyasint,omitempty"`
+	ConnectedClients              int32            `cbor:"8,keyasint,omitempty"`
+	LimitUpstreamBytesPerSecond   int64            `cbor:"9,keyasint,omitempty"`
+	LimitDownstreamBytesPerSecond int64            `cbor:"10,keyasint,omitempty"`
+	PeakUpstreamBytesPerSecond    int64            `cbor:"11,keyasint,omitempty"`
+	PeakDownstreamBytesPerSecond  int64            `cbor:"12,keyasint,omitempty"`
+}
+
+// ClientMetrics are network topology metrics provided by a client to a
+// broker. The broker uses this information when matching proxies and
+// clients.
+type ClientMetrics struct {
+	BaseMetrics          BaseMetrics      `cbor:"1,keyasint,omitempty"`
+	ProxyProtocolVersion int32            `cbor:"2,keyasint,omitempty"`
+	NATType              NATType          `cbor:"3,keyasint,omitempty"`
+	PortMappingTypes     PortMappingTypes `cbor:"4,keyasint,omitempty"`
+}
+
+// ProxyAnnounceRequest is an API request sent from a proxy to a broker,
+// announcing that it is available for a client connection. Proxies send one
+// ProxyAnnounceRequest for each available client connection. The broker will
+// match the proxy with a client and return WebRTC connection information
+// in the response.
+//
+// PersonalCompartmentIDs limits the clients to those that supply one of the
+// specified compartment IDs; personal compartment IDs are distributed from
+// proxy operators to client users out-of-band and provide optional access
+// control.
+//
+// The proxy's session public key is an implicit and cryptographically
+// verified proxy ID.
+type ProxyAnnounceRequest struct {
+	PersonalCompartmentIDs []ID          `cbor:"1,keyasint,omitempty"`
+	Metrics                *ProxyMetrics `cbor:"2,keyasint,omitempty"`
+}
+
+// TODO: send ProxyAnnounceRequest/ClientOfferRequest.Metrics only with the
+// first request in a session and cache.
+
+// ProxyAnnounceResponse returns the connection information for a matched
+// client. To establish a WebRTC connection, the proxy uses the client's
+// offer SDP to create its own answer SDP and send that to the broker in a
+// subsequent ProxyAnswerRequest. The ConnectionID is a unique identifier for
+// this single connection and must be relayed back in the ProxyAnswerRequest.
+//
+// ClientRootObfuscationSecret is generated (or replayed) by the client and
+// sent to the proxy and used to drive obfuscation operations.
+//
+// DestinationAddress is the dial address for the Psiphon server the proxy is
+// to relay client traffic with. The broker validates that the dial address
+// corresponds to a valid Psiphon server.
+//
+// OperatorMessageJSON is an optional message bundle to be forwarded to the
+// user interface for display to the user; for example, to alert the proxy
+// operator of a configuration issue; the JSON schema is not defined here.
+type ProxyAnnounceResponse struct {
+	OperatorMessageJSON         string                    `cbor:"1,keyasint,omitempty"`
+	ConnectionID                ID                        `cbor:"2,keyasint,omitempty"`
+	ClientProxyProtocolVersion  int32                     `cbor:"3,keyasint,omitempty"`
+	ClientOfferSDP              webrtc.SessionDescription `cbor:"4,keyasint,omitempty"`
+	ClientRootObfuscationSecret ObfuscationSecret         `cbor:"5,keyasint,omitempty"`
+	NetworkProtocol             NetworkProtocol           `cbor:"6,keyasint,omitempty"`
+	DestinationAddress          string                    `cbor:"7,keyasint,omitempty"`
+}
+
+// ClientOfferRequest is an API request sent from a client to a broker,
+// requesting a proxy connection. The client sends its WebRTC offer SDP with
+// this request.
+//
+// Clients specify known compartment IDs and are matched with proxies in those
+// compartments. CommonCompartmentIDs are compartment IDs managed by Psiphon
+// and revealed through tactics or bundled with server lists.
+// PersonalCompartmentIDs are compartment IDs shared privately between users,
+// out-of-band.
+//
+// ClientRootObfuscationSecret is generated (or replayed) by the client and
+// sent to the proxy and used to drive obfuscation operations.
+//
+// To specify the Psiphon server it wishes to proxy to, the client sends the
+// full, digitally signed Psiphon server entry to the broker and also the
+// specific dial address that it has selected for that server. The broker
+// validates the server entry signature, the server in-proxy capability, and
+// that the dial address corresponds to the network protocol, IP address or
+// domain, and destination port for a valid Psiphon tunnel protocol run by
+// the specified server entry.
+type ClientOfferRequest struct {
+	Metrics                     *ClientMetrics            `cbor:"1,keyasint,omitempty"`
+	CommonCompartmentIDs        []ID                      `cbor:"2,keyasint,omitempty"`
+	PersonalCompartmentIDs      []ID                      `cbor:"3,keyasint,omitempty"`
+	ClientOfferSDP              webrtc.SessionDescription `cbor:"4,keyasint,omitempty"`
+	ICECandidateTypes           ICECandidateTypes         `cbor:"5,keyasint,omitempty"`
+	ClientRootObfuscationSecret ObfuscationSecret         `cbor:"6,keyasint,omitempty"`
+	DestinationServerEntryJSON  []byte                    `cbor:"7,keyasint,omitempty"`
+	NetworkProtocol             NetworkProtocol           `cbor:"8,keyasint,omitempty"`
+	DestinationAddress          string                    `cbor:"9,keyasint,omitempty"`
+}
+
+// TODO: Encode SDPs using CBOR without field names, similar to base metrics
+// transformation? Same with DestinationServerEntryJSON.
+
+// ClientOfferResponse returns the connecting information for a matched proxy.
+// The proxy's WebRTC SDP is an answer to the offer sent in
+// ClientOfferRequest and is used to begin dialing the WebRTC connection.
+//
+// Once the client completes its connection to the Psiphon server, it must
+// relay a BrokerServerRequest to the server on behalf of the broker. This
+// relay is conducted within a secure session. First, the client sends
+// RelayPacketToServer to the server. Then the client relays the response to
+// the broker using ClientRelayedPacketRequests and continues to relay using
+// ClientRelayedPacketRequests until complete. ConnectionID identifies this
+// connection and its relayed BrokerServerRequest.
+type ClientOfferResponse struct {
+	ConnectionID                 ID                        `cbor:"1,keyasint,omitempty"`
+	SelectedProxyProtocolVersion int32                     `cbor:"2,keyasint,omitempty"`
+	ProxyAnswerSDP               webrtc.SessionDescription `cbor:"3,keyasint,omitempty"`
+	RelayPacketToServer          []byte                    `cbor:"4,keyasint,omitempty"`
+}
+
+// ProxyAnswerRequest is an API request sent from a proxy to a broker,
+// following ProxyAnnounceResponse, with the WebRTC answer SDP corresponding
+// to the client offer SDP received in ProxyAnnounceResponse. ConnectionID
+// identifies the connection begun in ProxyAnnounceResponse.
+//
+// If the proxy was unable to establish an answer SDP or failed for some other
+// reason, it should still send ProxyAnswerRequest with AnswerError
+// populated; the broker will signal the client to abort this connection.
+type ProxyAnswerRequest struct {
+	ConnectionID                 ID                        `cbor:"1,keyasint,omitempty"`
+	SelectedProxyProtocolVersion int32                     `cbor:"2,keyasint,omitempty"`
+	ProxyAnswerSDP               webrtc.SessionDescription `cbor:"3,keyasint,omitempty"`
+	ICECandidateTypes            ICECandidateTypes         `cbor:"4,keyasint,omitempty"`
+	AnswerError                  string                    `cbor:"5,keyasint,omitempty"`
+}
+
+// ProxyAnswerResponse is the acknowledgement for a ProxyAnswerRequest.
+type ProxyAnswerResponse struct {
+}
+
+// ClientRelayedPacketRequest is an API request sent from a client to a
+// broker, relaying a secure session packet from the Psiphon server to the
+// broker. This relay is a continuation of the broker/server exchange begun
+// with ClientOfferResponse.RelayPacketToServer. PacketFromServer is the next
+// packet from the server. SessionInvalid indicates, to the broker, that the
+// session is invalid -- it may have expired -- and so the broker should
+// begin establishing a new session, and then send its BrokerServerRequest in
+// that new session.
+type ClientRelayedPacketRequest struct {
+	ConnectionID     ID     `cbor:"1,keyasint,omitempty"`
+	PacketFromServer []byte `cbor:"2,keyasint,omitempty"`
+	SessionInvalid   bool   `cbor:"3,keyasint,omitempty"`
+}
+
+// ClientRelayedPacketResponse returns the next packet from the broker to the
+// server. When PacketToServer is empty, the broker/server exchange is done
+// and the client stops relaying packets.
+type ClientRelayedPacketResponse struct {
+	PacketToServer []byte `cbor:"1,keyasint,omitempty"`
+}
+
+// BrokerServerRequest is an API request sent from a broker to a Psiphon
+// server. This delivers, to the server, information that neither the client
+// nor the proxy is trusted to report. ProxyID is the proxy ID to be logged
+// with server_tunnel to attribute traffic to a specific proxy. ClientIP is
+// the original client IP as seen by the broker; this is the IP value to be
+// used in GeoIP-related operations including traffic rules, tactics, and OSL
+// progress. ProxyIP is the proxy IP as seen by the broker; this value should
+// match the Psiphon server's observed client IP. Additional fields are
+// metrics to be logged with server_tunnel.
+type BrokerServerRequest struct {
+	ProxyID                     ID               `cbor:"1,keyasint,omitempty"`
+	ConnectionID                ID               `cbor:"2,keyasint,omitempty"`
+	MatchedCommonCompartments   bool             `cbor:"3,keyasint,omitempty"`
+	MatchedPersonalCompartments bool             `cbor:"4,keyasint,omitempty"`
+	ProxyNATType                NATType          `cbor:"5,keyasint,omitempty"`
+	ProxyPortMappingTypes       PortMappingTypes `cbor:"6,keyasint,omitempty"`
+	ClientNATType               NATType          `cbor:"7,keyasint,omitempty"`
+	ClientPortMappingTypes      PortMappingTypes `cbor:"8,keyasint,omitempty"`
+	ClientIP                    string           `cbor:"9,keyasint,omitempty"`
+	ProxyIP                     string           `cbor:"10,keyasint,omitempty"`
+}
+
+// BrokerServerResponse returns an acknowledgement of the BrokerServerRequest
+// to the broker from the Psiphon server. The ConnectionID must match the
+// value in the BrokerServerRequest.
+type BrokerServerResponse struct {
+	ConnectionID ID     `cbor:"1,keyasint,omitempty"`
+	ErrorMessage string `cbor:"2,keyasint,omitempty"`
+}
+
+// BaseMetrics is a compact encoding of Psiphon base API metrics, such as
+// sponsor_id, client_platform, and so on.
+type BaseMetrics map[int]interface{}
+
+// GetNetworkType reads the network_type entry from the base metrics and
+// translates it to a NetworkType: "WIFI" maps to NetworkTypeWiFi and
+// "MOBILE" maps to NetworkTypeMobile. NetworkTypeUnknown is returned when
+// network_type has no integer key in baseMetricsNameToInt, when the entry
+// is absent or not a string, or when the string is any other value. This is
+// the one base metric that is used in the broker logic, and not simply
+// logged.
+func (metrics BaseMetrics) GetNetworkType() NetworkType {
+	key, known := baseMetricsNameToInt["network_type"]
+	if !known {
+		return NetworkTypeUnknown
+	}
+	// A missing entry yields a nil interface value, for which the checked
+	// type assertion reports false, so absence and a non-string value are
+	// both handled here.
+	networkType, isString := metrics[key].(string)
+	if !isString {
+		return NetworkTypeUnknown
+	}
+	switch networkType {
+	case "WIFI":
+		return NetworkTypeWiFi
+	case "MOBILE":
+		return NetworkTypeMobile
+	default:
+		return NetworkTypeUnknown
+	}
+}
+
+// EncodeBaseMetrics converts API parameters into BaseMetrics by replacing
+// each parameter name with its integer key from baseMetricsNameToInt;
+// values are carried over unchanged. A traced error is returned, with nil
+// metrics, as soon as a parameter name has no integer key.
+func EncodeBaseMetrics(params common.APIParameters) (BaseMetrics, error) {
+	encoded := make(BaseMetrics, len(params))
+	for name, value := range params {
+		key, known := baseMetricsNameToInt[name]
+		if known {
+			encoded[key] = value
+			continue
+		}
+		// The API metric to be sent is not in baseMetricsNameToInt. This
+		// will occur if baseMetricsNameToInt is not updated when new API
+		// metrics are added. Fail the operation and, ultimately, the dial
+		// rather than proceeding without the metric.
+		return nil, errors.Tracef("unknown name: %s", name)
+	}
+	return encoded, nil
+}
+
+// DecodeBaseMetrics converts BaseMetrics back into API parameters by
+// replacing each integer key with its name from baseMetricsIntToName;
+// values are carried over unchanged. Entries whose key has no name are
+// dropped.
+func DecodeBaseMetrics(metrics BaseMetrics) common.APIParameters {
+	decoded := common.APIParameters{}
+	for key, value := range metrics {
+		// A received API metric that is not in baseMetricsIntToName is
+		// skipped, so it is not logged, and decoding proceeds.
+		if name, known := baseMetricsIntToName[key]; known {
+			decoded[name] = value
+		}
+	}
+	return decoded
+}
+
+// Sanity check lengths for array inputs.
+const (
+	maxICECandidateTypes = 10
+	maxPortMappingTypes  = 10
+)
+
+// ValidateAndGetLogFields checks the ProxyMetrics and, when every check
+// passes, returns the common.LogFields to be logged for them.
+//
+// The checks run in a fixed order and the first failure is returned as a
+// traced error: base metrics must be present and, once decoded, must pass
+// baseMetricsValidator; the proxy protocol version must be
+// ProxyProtocolVersion1; the NAT type must be valid; and the port mapping
+// types must have at most maxPortMappingTypes entries and must be valid.
+//
+// The log fields start from the output of formatter applied to geoIPData
+// and the decoded base metrics; the remaining proxy metrics are then added.
+func (metrics *ProxyMetrics) ValidateAndGetLogFields(
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	if metrics.BaseMetrics == nil {
+		return nil, errors.TraceNew("missing base metrics")
+	}
+
+	params := DecodeBaseMetrics(metrics.BaseMetrics)
+
+	if err := baseMetricsValidator(params); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Cases are evaluated top to bottom; the first failing check wins.
+	switch {
+	case metrics.ProxyProtocolVersion != ProxyProtocolVersion1:
+		return nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
+	case !metrics.NATType.IsValid():
+		return nil, errors.Tracef("invalid NAT type: %v", metrics.NATType)
+	case len(metrics.PortMappingTypes) > maxPortMappingTypes:
+		return nil, errors.Tracef("invalid portmapping types length: %d", len(metrics.PortMappingTypes))
+	case !metrics.PortMappingTypes.IsValid():
+		return nil, errors.Tracef("invalid portmapping types: %v", metrics.PortMappingTypes)
+	}
+
+	fields := formatter(geoIPData, params)
+
+	fields["proxy_protocol_version"] = metrics.ProxyProtocolVersion
+	fields["nat_type"] = metrics.NATType
+	fields["port_mapping_types"] = metrics.PortMappingTypes
+	fields["max_clients"] = metrics.MaxClients
+	fields["connecting_clients"] = metrics.ConnectingClients
+	fields["connected_clients"] = metrics.ConnectedClients
+	fields["limit_upstream_bytes_per_second"] = metrics.LimitUpstreamBytesPerSecond
+	fields["limit_downstream_bytes_per_second"] = metrics.LimitDownstreamBytesPerSecond
+	fields["peak_upstream_bytes_per_second"] = metrics.PeakUpstreamBytesPerSecond
+	fields["peak_downstream_bytes_per_second"] = metrics.PeakDownstreamBytesPerSecond
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields checks the ClientMetrics and, when every check
+// passes, returns the common.LogFields to be logged for them.
+//
+// The checks run in a fixed order and the first failure is returned as a
+// traced error: base metrics must be present and, once decoded, must pass
+// baseMetricsValidator; the proxy protocol version must be
+// ProxyProtocolVersion1; the NAT type must be valid; and the port mapping
+// types must have at most maxPortMappingTypes entries and must be valid.
+//
+// The log fields start from the output of formatter applied to geoIPData
+// and the decoded base metrics; the remaining client metrics are then
+// added.
+func (metrics *ClientMetrics) ValidateAndGetLogFields(
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	if metrics.BaseMetrics == nil {
+		return nil, errors.TraceNew("missing base metrics")
+	}
+
+	params := DecodeBaseMetrics(metrics.BaseMetrics)
+
+	if err := baseMetricsValidator(params); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Cases are evaluated top to bottom; the first failing check wins.
+	switch {
+	case metrics.ProxyProtocolVersion != ProxyProtocolVersion1:
+		return nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
+	case !metrics.NATType.IsValid():
+		return nil, errors.Tracef("invalid NAT type: %v", metrics.NATType)
+	case len(metrics.PortMappingTypes) > maxPortMappingTypes:
+		return nil, errors.Tracef("invalid portmapping types length: %d", len(metrics.PortMappingTypes))
+	case !metrics.PortMappingTypes.IsValid():
+		return nil, errors.Tracef("invalid portmapping types: %v", metrics.PortMappingTypes)
+	}
+
+	fields := formatter(geoIPData, params)
+
+	fields["proxy_protocol_version"] = metrics.ProxyProtocolVersion
+	fields["nat_type"] = metrics.NATType
+	fields["port_mapping_types"] = metrics.PortMappingTypes
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields checks the ProxyAnnounceRequest and, when every
+// check passes, returns the common.LogFields to be logged for it.
+//
+// A traced error is returned when there are more than MaxCompartmentIDs
+// personal compartment IDs, when Metrics is missing, or when the metrics
+// fail their own ValidateAndGetLogFields. The log fields are those of the
+// metrics plus has_personal_compartment_ids.
+func (request *ProxyAnnounceRequest) ValidateAndGetLogFields(
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	personalCount := len(request.PersonalCompartmentIDs)
+	if personalCount > MaxCompartmentIDs {
+		return nil, errors.Tracef("invalid compartment IDs length: %d", personalCount)
+	}
+
+	if request.Metrics == nil {
+		return nil, errors.TraceNew("missing metrics")
+	}
+
+	fields, err := request.Metrics.ValidateAndGetLogFields(
+		baseMetricsValidator, formatter, geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// PersonalCompartmentIDs are user-generated and shared out-of-band;
+	// only their presence is logged, not the values, since the values may
+	// link users.
+	fields["has_personal_compartment_ids"] = personalCount > 0
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields checks the ClientOfferRequest and, when every
+// check passes, returns the common.LogFields to be logged for it.
+//
+// The checks run in a fixed order and the first failure is returned as a
+// traced error: neither compartment ID list may exceed MaxCompartmentIDs
+// entries; the client offer SDP must pass ValidateSDPAddresses; the ICE
+// candidate types must be valid; and Metrics must be present and pass its
+// own ValidateAndGetLogFields. The log fields are those of the metrics plus
+// compartment ID presence flags, the ICE candidate types, and the has_IPv6
+// SDP metric.
+func (request *ClientOfferRequest) ValidateAndGetLogFields(
+	lookupGeoIP LookupGeoIP,
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	commonCount := len(request.CommonCompartmentIDs)
+	if commonCount > MaxCompartmentIDs {
+		return nil, errors.Tracef("invalid compartment IDs length: %d", commonCount)
+	}
+
+	personalCount := len(request.PersonalCompartmentIDs)
+	if personalCount > MaxCompartmentIDs {
+		return nil, errors.Tracef("invalid compartment IDs length: %d", personalCount)
+	}
+
+	// Client offer SDP candidate addresses must match the country and ASN of
+	// the client. Don't facilitate connections to arbitrary destinations.
+	offerSDP := []byte(request.ClientOfferSDP.SDP)
+	sdpMetrics, err := ValidateSDPAddresses(offerSDP, lookupGeoIP, geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The client's self-reported ICECandidateTypes are used instead of the
+	// candidate types that can be derived from the SDP, since port mapping
+	// types are edited into the SDP in a way that makes them
+	// indistinguishable from host candidate types.
+	if !request.ICECandidateTypes.IsValid() {
+		return nil, errors.Tracef("invalid ICE candidate types: %v", request.ICECandidateTypes)
+	}
+
+	if request.Metrics == nil {
+		return nil, errors.TraceNew("missing metrics")
+	}
+
+	fields, err := request.Metrics.ValidateAndGetLogFields(
+		baseMetricsValidator, formatter, geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// CommonCompartmentIDs are generated and managed and are a form of
+	// obfuscation secret, so are not logged. PersonalCompartmentIDs are
+	// user-generated and shared out-of-band; values are not logged since
+	// they may link users. Only the presence of each is recorded.
+	fields["has_common_compartment_ids"] = commonCount > 0
+	fields["has_personal_compartment_ids"] = personalCount > 0
+	fields["ice_candidate_types"] = request.ICECandidateTypes
+	fields["has_IPv6"] = sdpMetrics.HasIPv6
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields checks the ProxyAnswerRequest and, when every
+// check passes, returns the common.LogFields to be logged for it.
+//
+// The checks run in a fixed order and the first failure is returned as a
+// traced error: the proxy answer SDP must pass ValidateSDPAddresses; the
+// ICE candidate types must be valid; and the selected proxy protocol
+// version must be ProxyProtocolVersion1. The log fields start from the
+// output of formatter applied to geoIPData and empty API parameters, to
+// which the connection ID, ICE candidate types, has_IPv6 SDP metric, and
+// answer error are added. baseMetricsValidator is not used.
+func (request *ProxyAnswerRequest) ValidateAndGetLogFields(
+	lookupGeoIP LookupGeoIP,
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	// Proxy answer SDP candidate addresses must match the country and ASN of
+	// the proxy. Don't facilitate connections to arbitrary destinations.
+	answerSDP := []byte(request.ProxyAnswerSDP.SDP)
+	sdpMetrics, err := ValidateSDPAddresses(answerSDP, lookupGeoIP, geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The proxy's self-reported ICECandidateTypes are used instead of the
+	// candidate types that can be derived from the SDP, since port mapping
+	// types are edited into the SDP in a way that makes them
+	// indistinguishable from host candidate types.
+	switch {
+	case !request.ICECandidateTypes.IsValid():
+		return nil, errors.Tracef("invalid ICE candidate types: %v", request.ICECandidateTypes)
+	case request.SelectedProxyProtocolVersion != ProxyProtocolVersion1:
+		return nil, errors.Tracef("invalid select proxy protocol version: %v", request.SelectedProxyProtocolVersion)
+	}
+
+	fields := formatter(geoIPData, common.APIParameters{})
+
+	fields["connection_id"] = request.ConnectionID
+	fields["ice_candidate_types"] = request.ICECandidateTypes
+	fields["has_IPv6"] = sdpMetrics.HasIPv6
+	fields["answer_error"] = request.AnswerError
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields returns the common.LogFields to be logged for the
+// ClientRelayedPacketRequest. No field is checked and the returned error is
+// always nil. The log fields start from the output of formatter applied to
+// geoIPData and empty API parameters, to which the connection ID and the
+// session invalid flag are added. baseMetricsValidator is not used.
+func (request *ClientRelayedPacketRequest) ValidateAndGetLogFields(
+	baseMetricsValidator common.APIParameterValidator,
+	formatter common.APIParameterLogFieldFormatter,
+	geoIPData common.GeoIPData) (common.LogFields, error) {
+
+	fields := formatter(geoIPData, common.APIParameters{})
+
+	fields["connection_id"] = request.ConnectionID
+	fields["session_invalid"] = request.SessionInvalid
+
+	return fields, nil
+}
+
+// ValidateAndGetLogFields checks the BrokerServerRequest and, when every
+// check passes, returns the common.LogFields to be logged for it.
+//
+// The checks run in a fixed order and the first failure is returned as a
+// traced error: the proxy NAT type, the proxy port mapping types, the
+// client NAT type, and the client port mapping types must each be valid.
+// The log fields carry the proxy ID, the connection ID, the compartment
+// match flags, and the NAT and port mapping types of both peers.
+func (request *BrokerServerRequest) ValidateAndGetLogFields() (common.LogFields, error) {
+
+	if !request.ProxyNATType.IsValid() {
+		return nil, errors.Tracef("invalid proxy NAT type: %v", request.ProxyNATType)
+	}
+
+	if !request.ProxyPortMappingTypes.IsValid() {
+		return nil, errors.Tracef("invalid proxy portmapping types: %v", request.ProxyPortMappingTypes)
+	}
+
+	if !request.ClientNATType.IsValid() {
+		return nil, errors.Tracef("invalid client NAT type: %v", request.ClientNATType)
+	}
+
+	if !request.ClientPortMappingTypes.IsValid() {
+		return nil, errors.Tracef("invalid client portmapping types: %v", request.ClientPortMappingTypes)
+	}
+
+	// Neither ClientIP nor ProxyIP is logged.
+
+	logFields := common.LogFields{}
+
+	logFields["proxy_id"] = request.ProxyID
+	logFields["connection_id"] = request.ConnectionID
+	logFields["matched_common_compartments"] = request.MatchedCommonCompartments
+	logFields["matched_personal_compartments"] = request.MatchedPersonalCompartments
+	logFields["proxy_nat_type"] = request.ProxyNATType
+	logFields["proxy_port_mapping_types"] = request.ProxyPortMappingTypes
+	logFields["client_nat_type"] = request.ClientNATType
+	logFields["client_port_mapping_types"] = request.ClientPortMappingTypes
+
+	// Return the populated fields. Returning a fresh, empty LogFields here
+	// would silently discard every field assembled above.
+	return logFields, nil
+}
+
+// MarshalProxyAnnounceRequest encodes request with marshalRecord, using
+// record type recordTypeAPIProxyAnnounceRequest, and returns the resulting
+// payload. The error from marshalRecord is passed through errors.Trace.
+func MarshalProxyAnnounceRequest(request *ProxyAnnounceRequest) ([]byte, error) {
+	payload, err := marshalRecord(request, recordTypeAPIProxyAnnounceRequest)
+	return payload, errors.Trace(err)
+}
+
+// UnmarshalProxyAnnounceRequest decodes payload with unmarshalRecord,
+// expecting record type recordTypeAPIProxyAnnounceRequest, and returns the
+// decoded request. The error from unmarshalRecord is passed through
+// errors.Trace; the request pointer is returned as unmarshalRecord left it,
+// so callers should check the error before using it.
+func UnmarshalProxyAnnounceRequest(payload []byte) (*ProxyAnnounceRequest, error) {
+	var request *ProxyAnnounceRequest
+	err := unmarshalRecord(recordTypeAPIProxyAnnounceRequest, payload, &request)
+	return request, errors.Trace(err)
+}
+
+// MarshalProxyAnnounceResponse encodes response with marshalRecord using
+// the recordTypeAPIProxyAnnounceResponse record type. A marshalRecord error
+// is returned wrapped with errors.Trace.
+func MarshalProxyAnnounceResponse(response *ProxyAnnounceResponse) ([]byte, error) {
+	encoded, err := marshalRecord(response, recordTypeAPIProxyAnnounceResponse)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalProxyAnnounceResponse decodes payload with unmarshalRecord,
+// using the recordTypeAPIProxyAnnounceResponse record type, into a
+// ProxyAnnounceResponse. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalProxyAnnounceResponse(payload []byte) (*ProxyAnnounceResponse, error) {
+	var decoded *ProxyAnnounceResponse
+	if err := unmarshalRecord(recordTypeAPIProxyAnnounceResponse, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalProxyAnswerRequest encodes request with marshalRecord using the
+// recordTypeAPIProxyAnswerRequest record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalProxyAnswerRequest(request *ProxyAnswerRequest) ([]byte, error) {
+	encoded, err := marshalRecord(request, recordTypeAPIProxyAnswerRequest)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalProxyAnswerRequest decodes payload with unmarshalRecord, using
+// the recordTypeAPIProxyAnswerRequest record type, into a
+// ProxyAnswerRequest. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalProxyAnswerRequest(payload []byte) (*ProxyAnswerRequest, error) {
+	var decoded *ProxyAnswerRequest
+	if err := unmarshalRecord(recordTypeAPIProxyAnswerRequest, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalProxyAnswerResponse encodes response with marshalRecord using the
+// recordTypeAPIProxyAnswerResponse record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalProxyAnswerResponse(response *ProxyAnswerResponse) ([]byte, error) {
+	encoded, err := marshalRecord(response, recordTypeAPIProxyAnswerResponse)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalProxyAnswerResponse decodes payload with unmarshalRecord, using
+// the recordTypeAPIProxyAnswerResponse record type, into a
+// ProxyAnswerResponse. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalProxyAnswerResponse(payload []byte) (*ProxyAnswerResponse, error) {
+	var decoded *ProxyAnswerResponse
+	if err := unmarshalRecord(recordTypeAPIProxyAnswerResponse, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalClientOfferRequest encodes request with marshalRecord using the
+// recordTypeAPIClientOfferRequest record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalClientOfferRequest(request *ClientOfferRequest) ([]byte, error) {
+	encoded, err := marshalRecord(request, recordTypeAPIClientOfferRequest)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalClientOfferRequest decodes payload with unmarshalRecord, using
+// the recordTypeAPIClientOfferRequest record type, into a
+// ClientOfferRequest. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalClientOfferRequest(payload []byte) (*ClientOfferRequest, error) {
+	var decoded *ClientOfferRequest
+	if err := unmarshalRecord(recordTypeAPIClientOfferRequest, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalClientOfferResponse encodes response with marshalRecord using the
+// recordTypeAPIClientOfferResponse record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalClientOfferResponse(response *ClientOfferResponse) ([]byte, error) {
+	encoded, err := marshalRecord(response, recordTypeAPIClientOfferResponse)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalClientOfferResponse decodes payload with unmarshalRecord, using
+// the recordTypeAPIClientOfferResponse record type, into a
+// ClientOfferResponse. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalClientOfferResponse(payload []byte) (*ClientOfferResponse, error) {
+	var decoded *ClientOfferResponse
+	if err := unmarshalRecord(recordTypeAPIClientOfferResponse, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalClientRelayedPacketRequest encodes request with marshalRecord
+// using the recordTypeAPIClientRelayedPacketRequest record type. A
+// marshalRecord error is returned wrapped with errors.Trace.
+func MarshalClientRelayedPacketRequest(request *ClientRelayedPacketRequest) ([]byte, error) {
+	encoded, err := marshalRecord(request, recordTypeAPIClientRelayedPacketRequest)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalClientRelayedPacketRequest decodes payload with unmarshalRecord,
+// using the recordTypeAPIClientRelayedPacketRequest record type, into a
+// ClientRelayedPacketRequest. An unmarshalRecord error is returned wrapped
+// with errors.Trace.
+func UnmarshalClientRelayedPacketRequest(payload []byte) (*ClientRelayedPacketRequest, error) {
+	var decoded *ClientRelayedPacketRequest
+	if err := unmarshalRecord(recordTypeAPIClientRelayedPacketRequest, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalClientRelayedPacketResponse encodes response with marshalRecord
+// using the recordTypeAPIClientRelayedPacketResponse record type. A
+// marshalRecord error is returned wrapped with errors.Trace.
+func MarshalClientRelayedPacketResponse(response *ClientRelayedPacketResponse) ([]byte, error) {
+	encoded, err := marshalRecord(response, recordTypeAPIClientRelayedPacketResponse)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalClientRelayedPacketResponse decodes payload with
+// unmarshalRecord, using the recordTypeAPIClientRelayedPacketResponse
+// record type, into a ClientRelayedPacketResponse. An unmarshalRecord error
+// is returned wrapped with errors.Trace.
+func UnmarshalClientRelayedPacketResponse(payload []byte) (*ClientRelayedPacketResponse, error) {
+	var decoded *ClientRelayedPacketResponse
+	if err := unmarshalRecord(recordTypeAPIClientRelayedPacketResponse, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalBrokerServerRequest encodes request with marshalRecord using the
+// recordTypeAPIBrokerServerRequest record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalBrokerServerRequest(request *BrokerServerRequest) ([]byte, error) {
+	encoded, err := marshalRecord(request, recordTypeAPIBrokerServerRequest)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalBrokerServerRequest decodes payload with unmarshalRecord, using
+// the recordTypeAPIBrokerServerRequest record type, into a
+// BrokerServerRequest. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalBrokerServerRequest(payload []byte) (*BrokerServerRequest, error) {
+	var decoded *BrokerServerRequest
+	if err := unmarshalRecord(recordTypeAPIBrokerServerRequest, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// MarshalBrokerServerResponse encodes response with marshalRecord using the
+// recordTypeAPIBrokerServerResponse record type. A marshalRecord error is
+// returned wrapped with errors.Trace.
+func MarshalBrokerServerResponse(response *BrokerServerResponse) ([]byte, error) {
+	encoded, err := marshalRecord(response, recordTypeAPIBrokerServerResponse)
+	if err != nil {
+		return encoded, errors.Trace(err)
+	}
+	return encoded, nil
+}
+
+// UnmarshalBrokerServerResponse decodes payload with unmarshalRecord, using
+// the recordTypeAPIBrokerServerResponse record type, into a
+// BrokerServerResponse. An unmarshalRecord error is returned wrapped with
+// errors.Trace.
+func UnmarshalBrokerServerResponse(payload []byte) (*BrokerServerResponse, error) {
+	var decoded *BrokerServerResponse
+	if err := unmarshalRecord(recordTypeAPIBrokerServerResponse, payload, &decoded); err != nil {
+		return decoded, errors.Trace(err)
+	}
+	return decoded, nil
+}
+
+// baseMetricsNameToInt and baseMetricsIntToName are the two directions of
+// the mapping between base metrics JSON field names and CBOR field numbers.
+// Both are populated once, by init.
+var (
+	baseMetricsNameToInt map[string]int
+	baseMetricsIntToName map[int]string
+)
+
+// init builds the base metrics name/number lookup maps. Each name's field
+// number is its position in the list below.
+func init() {
+
+	// This list must be updated when new base metrics are added. New
+	// metrics must be appended at the end: inserting or reordering entries
+	// would change the field numbers of existing metrics.
+
+	orderedNames := []string{
+		"server_secret",
+		"client_session_id",
+		"propagation_channel_id",
+		"sponsor_id",
+		"client_version",
+		"client_platform",
+		"client_features",
+		"client_build_rev",
+		"device_region",
+		"session_id",
+		"relay_protocol",
+		"ssh_client_version",
+		"upstream_proxy_type",
+		"upstream_proxy_custom_header_names",
+		"fronting_provider_id",
+		"meek_dial_address",
+		"meek_resolved_ip_address",
+		"meek_sni_server_name",
+		"meek_host_header",
+		"meek_transformed_host_name",
+		"user_agent",
+		"tls_profile",
+		"tls_version",
+		"server_entry_region",
+		"server_entry_source",
+		"server_entry_timestamp",
+		"applied_tactics_tag",
+		"dial_port_number",
+		"quic_version",
+		"quic_dial_sni_address",
+		"quic_disable_client_path_mtu_discovery",
+		"upstream_bytes_fragmented",
+		"upstream_min_bytes_written",
+		"upstream_max_bytes_written",
+		"upstream_min_delayed",
+		"upstream_max_delayed",
+		"padding",
+		"pad_response",
+		"is_replay",
+		"egress_region",
+		"dial_duration",
+		"candidate_number",
+		"established_tunnels_count",
+		"upstream_ossh_padding",
+		"meek_cookie_size",
+		"meek_limit_request",
+		"meek_tls_padding",
+		"network_latency_multiplier",
+		"client_bpf",
+		"network_type",
+		"conjure_cached",
+		"conjure_delay",
+		"conjure_transport",
+		"split_tunnel",
+		"split_tunnel_regions",
+		"dns_preresolved",
+		"dns_preferred",
+		"dns_transform",
+		"dns_attempt",
+	}
+
+	baseMetricsNameToInt = make(map[string]int, len(orderedNames))
+	baseMetricsIntToName = make(map[int]string, len(orderedNames))
+	for fieldNumber := range orderedNames {
+		fieldName := orderedNames[fieldNumber]
+		baseMetricsNameToInt[fieldName] = fieldNumber
+		baseMetricsIntToName[fieldNumber] = fieldName
+	}
+}

+ 1191 - 0
psiphon/common/inproxy/broker.go

@@ -0,0 +1,1191 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"net"
+	"strconv"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/buraksezer/consistent"
+	"github.com/cespare/xxhash"
+	lrucache "github.com/cognusion/go-cache-lru"
+)
+
+const (
+
+	// BrokerReadTimeout is the read timeout, the duration before a request is
+	// fully read, that should be applied by the provided broker transport.
+	// For example, when the provided transport is net/http, set
+	// http.Server.ReadTimeout to at least BrokerReadTimeout.
+	BrokerReadTimeout = 5 * time.Second
+
+	// BrokerWriteTimeout is the write timeout, the duration before a response
+	// is fully written, that should be applied by the provided broker
+	// transport. This timeout accommodates the long polling performed by the
+	// proxy announce request. Both the immediate transport provider and any
+	// front (e.g., a CDN) must be configured to use this timeout.
+	BrokerWriteTimeout = (brokerProxyAnnounceTimeout + 5*time.Second)
+
+	// BrokerIdleTimeout is the idle timeout, the duration before an idle
+	// persistent connection is closed, that should be applied by the
+	// provided broker transport.
+	BrokerIdleTimeout = 2 * time.Minute
+
+	// BrokerMaxRequestBodySize is the maximum request size, that should be
+	// enforced by the provided broker transport.
+	BrokerMaxRequestBodySize = 65536
+
+	brokerProxyAnnounceTimeout         = 2 * time.Minute  // default for BrokerConfig.ProxyAnnounceTimeout
+	brokerClientOfferTimeout           = 10 * time.Second // default for BrokerConfig.ClientOfferTimeout
+	brokerPendingServerRequestsTTL     = 60 * time.Second // default for BrokerConfig.PendingServerRequestsTTL
+	brokerPendingServerRequestsMaxSize = 100000           // size argument for the pendingServerRequests LRU cache
+	brokerMetricName                   = "in-proxy-broker" // metric name passed to Logger.LogMetric
+)
+
+// LookupGeoIP is a callback that takes an IP address string and returns the common.GeoIPData for it.
+type LookupGeoIP func(IP string) common.GeoIPData
+
+// Broker is the in-proxy broker component, which matches clients and proxies
+// and provides WebRTC signaling functionality.
+//
+// Both clients and proxies send requests to the broker to obtain matches and
+// exchange WebRTC SDPs. Broker does not implement a transport or obfuscation
+// layer; instead that is provided by the HandleSessionPacket caller. A
+// typical implementation would provide a domain fronted web server which
+// runs a Broker and calls Broker.HandleSessionPacket to handle web requests
+// encapsulating secure session packets.
+type Broker struct {
+	config                *BrokerConfig          // configuration supplied to NewBroker
+	initiatorSessions     *InitiatorSessions     // sessions initiated by the broker, to servers
+	responderSessions     *ResponderSessions     // sessions initiated by clients and proxies
+	matcher               *Matcher               // matches proxy announcements with client offers
+	pendingServerRequests *lrucache.Cache        // LRU cache with TTL; see NewBroker
+	commonCompartments    *consistent.Consistent // NOTE(review): presumably the consistent hashing state for common compartment ID assignment; confirm
+}
+
+// BrokerConfig specifies the configuration for a Broker.
+type BrokerConfig struct {
+
+	// Logger is used to log events.
+	Logger common.Logger
+
+	// CommonCompartmentIDs is a list of common compartment IDs to apply to
+	// proxies that announce without a personal compartment ID. Common
+	// compartment IDs are managed by Psiphon and distributed to clients via
+	// tactics or embedded in OSLs. Clients must supply a valid compartment
+	// ID to match with a proxy.
+	CommonCompartmentIDs []ID
+
+	// AllowProxy is a callback which can indicate whether a proxy with the
+	// given GeoIP data is allowed to match with common compartment ID
+	// clients. Proxies with personal compartment IDs are always allowed.
+	AllowProxy func(common.GeoIPData) bool
+
+	// AllowClient is a callback which can indicate whether a client with the
+	// given GeoIP data is allowed to match with common compartment ID
+	// proxies. Clients are always allowed to match based on personal
+	// compartment ID.
+	AllowClient func(common.GeoIPData) bool
+
+	// AllowDomainDestination is a callback which can indicate whether a
+	// client with the given GeoIP data is allowed to specify a proxied
+	// destination with a domain name. When false, only IP address
+	// destinations are allowed.
+	//
+	// While tactics may be set to instruct clients to use only direct
+	// server tunnel protocols, with IP address destinations, this callback
+	// adds server-side enforcement.
+	AllowDomainDestination func(common.GeoIPData) bool
+
+	// LookupGeoIP provides GeoIP lookup service.
+	LookupGeoIP LookupGeoIP
+
+	// APIParameterValidator is a callback that validates base API metrics.
+	APIParameterValidator common.APIParameterValidator
+
+	// APIParameterLogFieldFormatter is a callback that formats base API metrics.
+	APIParameterLogFieldFormatter common.APIParameterLogFieldFormatter
+
+	// TransportSecret is a value that must be supplied by the provided
+	// transport. In the case of domain fronting, this is used to validate
+	// that the peer is a trusted CDN, and so its relayed client IP
+	// (e.g., X-Forwarded-For header) is legitimate.
+	TransportSecret TransportSecret
+
+	// PrivateKey is the broker's secure session long term private key.
+	PrivateKey SessionPrivateKey
+
+	// ObfuscationRootSecret is the broker's secure session long term obfuscation key.
+	ObfuscationRootSecret ObfuscationSecret
+
+	// ServerEntrySignaturePublicKey is the key used to verify Psiphon server
+	// entry signatures.
+	ServerEntrySignaturePublicKey string
+
+	// IsValidServerEntryTag is a callback which checks if the specified
+	// server entry tag is on the list of valid and active Psiphon server
+	// entry tags.
+	IsValidServerEntryTag func(serverEntryTag string) bool
+
+	// These timeout parameters may be used to override defaults.
+	ProxyAnnounceTimeout     time.Duration
+	ClientOfferTimeout       time.Duration
+	PendingServerRequestsTTL time.Duration
+}
+
+// NewBroker creates a Broker from config. It returns an error when
+// config.CommonCompartmentIDs is empty or when the responder sessions
+// cannot be created. The returned Broker's matcher is not running until
+// Start is called.
+func NewBroker(config *BrokerConfig) (*Broker, error) {
+
+	// A broker cannot operate without at least one common compartment ID.
+	// At a minimum, one ID will be used and distributed to clients via
+	// tactics, limiting matching to those clients targeted to receive
+	// those tactics parameters.
+
+	if len(config.CommonCompartmentIDs) == 0 {
+		return nil, errors.TraceNew("missing common compartment IDs")
+	}
+
+	// The broker is the initiator for sessions with servers, which carry
+	// BrokerServerRequests. Servers will be configured to establish
+	// sessions only with brokers with specified public keys.
+
+	initiatorSessions := NewInitiatorSessions(config.PrivateKey)
+
+	// The broker is the responder for sessions with clients and proxies,
+	// which carry their requests to the broker. Clients and proxies are
+	// configured to establish sessions only with specified broker public
+	// keys.
+
+	responderSessions, err := NewResponderSessions(
+		config.PrivateKey, config.ObfuscationRootSecret)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	matcher := NewMatcher(&MatcherConfig{
+		Logger: config.Logger,
+	})
+
+	// Use the configured TTL for pending server requests, falling back to
+	// the package default.
+
+	pendingServerRequestsTTL := common.ValueOrDefault(
+		config.PendingServerRequestsTTL, brokerPendingServerRequestsTTL)
+
+	broker := &Broker{
+		config:            config,
+		initiatorSessions: initiatorSessions,
+		responderSessions: responderSessions,
+		matcher:           matcher,
+		pendingServerRequests: lrucache.NewWithLRU(
+			pendingServerRequestsTTL,
+			1*time.Minute,
+			brokerPendingServerRequestsMaxSize),
+	}
+
+	broker.pendingServerRequests.OnEvicted(broker.evictedPendingServerRequest)
+
+	broker.initializeCommonCompartmentIDHashing()
+
+	return broker, nil
+}
+
+// Start starts the broker's matcher. A matcher start error is returned
+// wrapped with errors.Trace.
+func (b *Broker) Start() error {
+	if err := b.matcher.Start(); err != nil {
+		return errors.Trace(err)
+	}
+	return nil
+}
+
+// Stop stops the broker's matcher.
+func (b *Broker) Stop() {
+	matcher := b.matcher
+	matcher.Stop()
+}
+
+// HandleSessionPacket handles a session packet from a client or proxy and
+// provides a response packet. The packet is part of a secure session and may
+// be a session handshake message, or a session-wrapped request payload.
+// Request payloads are routed to API request endpoints.
+//
+// The caller is expected to provide a transport obfuscation layer, such as
+// domain fronted HTTPS. The session has an obfuscation layer that ensures
+// that packets are fully random, randomly padded, and cannot be replayed.
+// This makes session packets suitable to embed as plaintext in some
+// transports.
+//
+// The caller is responsible for rate limiting and enforcing timeouts and
+// maximum payload size checks.
+//
+// Secure sessions support multiplexing concurrent requests, as long as the
+// provided transport, for example HTTP/2, supports this as well.
+//
+// The input ctx should be canceled if the client/proxy disconnects from the
+// transport while HandleSessionPacket is running, since long-polling proxy
+// announcement requests will otherwise remain blocked until eventual
+// timeout; net/http does this.
+//
+// When HandleSessionPacket returns an error, the transport provider should
+// apply anti-probing mechanisms, since the client/proxy may be a prober or
+// scanner. When a client/proxy tries to use an existing session that has
+// expired on the broker, this results in an error. This failure must be
+// relayed to the client/proxy, which will then start establishing a new
+// session. No specifics about the expiry error case need to be or should be
+// transmitted by the transport. For example, with an HTTP-type transport, a
+// generic 404 error should suffice both as an anti-probing response and as a
+// signal that a session is expired. Furthermore, HTTP-type transports may
+// keep underlying network connections open in both the anti-probing and
+// expired session cases, which facilitates a fast re-establishment by
+// legitimate clients/proxies.
+//
+// An error is returned when transportSecret does not match the configured
+// transport secret, when the session layer rejects inPacket, when the
+// unwrapped request's record preamble cannot be read or has an unexpected
+// record type, or when the selected API handler fails.
+func (b *Broker) HandleSessionPacket(
+	ctx context.Context,
+	transportSecret TransportSecret,
+	brokerClientIP string,
+	geoIPData common.GeoIPData,
+	inPacket []byte) ([]byte, error) {
+
+	// Check that the transport peer has supplied the expected transport secret.
+	// In the case of CDN domain fronting, the trusted CDN is configured to
+	// add an HTTP header containing the secret. The original client IP and
+	// derived GeoIP information is only trusted when the correct transport
+	// secret is supplied. The security of the secret depends on the
+	// transport; for example, HTTPS between the CDN and the broker; the
+	// transport secret cannot be injected into a secure session.
+
+	if !b.config.TransportSecret.Equal(transportSecret) {
+		return nil, errors.TraceNew("invalid transport secret")
+	}
+
+	// handleUnwrappedRequest handles requests after session unwrapping.
+	// responderSessions.HandlePacket handles both session establishment and
+	// request unwrapping, and invokes handleUnwrappedRequest once a session
+	// is established and a valid request unwrapped.
+
+	handleUnwrappedRequest := func(initiatorID ID, unwrappedRequestPayload []byte) ([]byte, error) {
+
+		// Fail immediately when the record preamble cannot be read, so
+		// the underlying error is reported instead of being overwritten
+		// and surfacing as an "unexpected API record type" error.
+
+		recordType, err := peekRecordPreambleType(unwrappedRequestPayload)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		var responsePayload []byte
+
+		switch recordType {
+		case recordTypeAPIProxyAnnounceRequest:
+			responsePayload, err = b.handleProxyAnnounce(ctx, geoIPData, initiatorID, unwrappedRequestPayload)
+		case recordTypeAPIProxyAnswerRequest:
+			responsePayload, err = b.handleProxyAnswer(ctx, brokerClientIP, geoIPData, initiatorID, unwrappedRequestPayload)
+		case recordTypeAPIClientOfferRequest:
+			responsePayload, err = b.handleClientOffer(ctx, brokerClientIP, geoIPData, initiatorID, unwrappedRequestPayload)
+		case recordTypeAPIClientRelayedPacketRequest:
+			responsePayload, err = b.handleClientRelayedPacket(ctx, geoIPData, initiatorID, unwrappedRequestPayload)
+		default:
+			return nil, errors.Tracef("unexpected API record type %v", recordType)
+		}
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		return responsePayload, nil
+	}
+
+	outPacket, err := b.responderSessions.HandlePacket(
+		inPacket, handleUnwrappedRequest)
+	if err != nil {
+
+		// An error here could be due to invalid session traffic or an expired
+		// session, which is expected. For anti-probing purposes, the
+		// transport response should be the same in either case.
+
+		return nil, errors.Trace(err)
+	}
+
+	return outPacket, nil
+}
+
+// handleProxyAnnounce processes a proxy announcement: it enqueues the
+// announcement with the matcher, waits for a matching client offer, and
+// returns that offer as the marshaled response. The wait uses a long
+// timeout so that the request can idle until a matching client arrives.
+//
+// The outcome is always logged once a connection ID has been generated.
+func (b *Broker) handleProxyAnnounce(
+	ctx context.Context,
+	geoIPData common.GeoIPData,
+	initiatorID ID,
+	requestPayload []byte) (retResponse []byte, retErr error) {
+
+	startTime := time.Now()
+
+	// These are read by the deferred logging function, so they must only
+	// be assigned, never redeclared, below.
+	var (
+		logFields   common.LogFields
+		clientOffer *MatchOffer
+	)
+
+	// Possible future enhancement: the broker could make its own test
+	// connection to the proxy to verify its effectiveness, including
+	// simulating a symmetric NAT client.
+
+	// One announcement represents availability for exactly one client
+	// match. A proxy that can serve several clients sends several requests.
+	//
+	// The announcement request and response could be extended to let the
+	// proxy state availability for multiple clients in one request, with
+	// multiple client offers returned in the response.
+	//
+	// If, as we expect, proxies run on home ISPs have limited upstream
+	// bandwidth, they will support only a couple of concurrent clients, and
+	// the simple single-client-announcement model may be sufficient. Also,
+	// if the transport is HTTP/2, multiple requests can be multiplexed over
+	// a single connection (and session) in any case.
+
+	// The proxy ID is not an explicit request parameter; it is the proxy's
+	// session public key. By completing the session handshake, the proxy
+	// has proven that it holds the corresponding private key. Proxy IDs are
+	// logged to attribute traffic to a specific proxy.
+
+	proxyID := initiatorID
+
+	// The connection ID associates proxy announcements, client offers, and
+	// proxy answers, and also associates Psiphon tunnels with in-proxy
+	// pairings.
+	connectionID, err := MakeID()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Log the outcome on every return path from here on.
+	defer func() {
+		if logFields == nil {
+			logFields = make(common.LogFields)
+		}
+		logFields["broker_event"] = "proxy-announce"
+		logFields["proxy_id"] = proxyID
+		logFields["elapsed_time"] = time.Since(startTime) / time.Millisecond
+		logFields["connection_id"] = connectionID
+		if clientOffer != nil {
+			// The target Psiphon server ID (diagnostic ID) is logged only
+			// when a match was made, so its presence indicates a match.
+			logFields["destination_server_id"] = clientOffer.DestinationServerID
+		}
+		if retErr != nil {
+			logFields["error"] = retErr.Error()
+		}
+		b.config.Logger.LogMetric(brokerMetricName, logFields)
+	}()
+
+	announceRequest, err := UnmarshalProxyAnnounceRequest(requestPayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields, err = announceRequest.ValidateAndGetLogFields(
+		b.config.APIParameterValidator,
+		b.config.APIParameterLogFieldFormatter,
+		geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// A proxy that specifies personal compartment IDs is always allowed,
+	// since it will be used only by clients specifically configured to use
+	// it, and it is not assigned a common compartment ID.
+	//
+	// Otherwise, AllowProxy may disallow the proxy based on geolocation
+	// (for example, censored locations). An allowed proxy is assigned one
+	// common compartment ID using consistent hashing keyed with the proxy
+	// ID, in an effort to keep proxies consistently assigned to the same
+	// compartment.
+
+	var commonCompartmentIDs []ID
+	if len(announceRequest.PersonalCompartmentIDs) == 0 {
+
+		if !b.config.AllowProxy(geoIPData) {
+			return nil, errors.TraceNew("proxy disallowed")
+		}
+
+		compartmentID, err := b.selectCommonCompartmentID(proxyID)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		commonCompartmentIDs = []ID{compartmentID}
+	}
+
+	// Enqueue the announcement and await a client offer.
+
+	announceTimeout := common.ValueOrDefault(
+		b.config.ProxyAnnounceTimeout, brokerProxyAnnounceTimeout)
+
+	announceCtx, cancelAnnounce := context.WithTimeout(ctx, announceTimeout)
+	defer cancelAnnounce()
+
+	announcement := &MatchAnnouncement{
+		Properties: MatchProperties{
+			CommonCompartmentIDs:   commonCompartmentIDs,
+			PersonalCompartmentIDs: announceRequest.PersonalCompartmentIDs,
+			GeoIPData:              geoIPData,
+			NetworkType:            announceRequest.Metrics.BaseMetrics.GetNetworkType(),
+			NATType:                announceRequest.Metrics.NATType,
+			PortMappingTypes:       announceRequest.Metrics.PortMappingTypes,
+		},
+		ProxyID:              proxyID,
+		ConnectionID:         connectionID,
+		ProxyProtocolVersion: announceRequest.Metrics.ProxyProtocolVersion,
+	}
+
+	clientOffer, err = b.matcher.Announce(announceCtx, announcement)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The response carries the client offer. The proxy follows up with an
+	// answer request, which is relayed to the client, and then the WebRTC
+	// dial begins.
+
+	// Limitation: as part of the client's tunnel establishment horse race, a
+	// client may abort an in-proxy dial at any point. If the overall dial is
+	// past the SDP exchange and aborted during the WebRTC connection
+	// establishment, the client may leave the proxy's Proxy.proxyOneClient
+	// dangling until timeout. Consider adding a signal from the client to
+	// the proxy, relayed by the broker, that a dial is aborted.
+
+	announceResponse := &ProxyAnnounceResponse{
+		ConnectionID:                connectionID,
+		ClientProxyProtocolVersion:  clientOffer.ClientProxyProtocolVersion,
+		ClientOfferSDP:              clientOffer.ClientOfferSDP,
+		ClientRootObfuscationSecret: clientOffer.ClientRootObfuscationSecret,
+		NetworkProtocol:             clientOffer.NetworkProtocol,
+		DestinationAddress:          clientOffer.DestinationAddress,
+	}
+
+	responsePayload, err := MarshalProxyAnnounceResponse(announceResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePayload, nil
+}
+
+// handleClientOffer processes a client offer: it enqueues the offer with
+// the matcher, waits for a matching proxy and that proxy's answer, and
+// returns the answer as the marshaled response, together with the initial
+// broker/server session packet to be relayed to the server.
+//
+// The wait uses a shorter timeout than handleProxyAnnounce since the client
+// has supplied an SDP with STUN hole punches which will expire; and, in
+// general, the client is trying to connect immediately and is also trying
+// other candidates.
+//
+// The outcome is always logged.
+func (b *Broker) handleClientOffer(
+	ctx context.Context,
+	clientIP string,
+	geoIPData common.GeoIPData,
+	initiatorID ID,
+	requestPayload []byte) (retResponse []byte, retErr error) {
+
+	// Possible future enhancement: have proxies send offer SDPs with their
+	// announcements, and have clients long poll to await a match and then
+	// provide an answer. That order of operations would make sense if
+	// client demand is high and proxy supply is lower.
+	//
+	// Also see comment in Proxy.proxyOneClient for other alternative
+	// approaches.
+
+	// The client's session public key is ephemeral and is not logged.
+
+	startTime := time.Now()
+
+	// These are read by the deferred logging function, so they must only
+	// be assigned, never redeclared, below.
+	var (
+		logFields              common.LogFields
+		serverParams           *serverParams
+		clientMatchOffer       *MatchOffer
+		proxyMatchAnnouncement *MatchAnnouncement
+		proxyAnswer            *MatchAnswer
+	)
+
+	// Log the outcome on every return path.
+	defer func() {
+		if logFields == nil {
+			logFields = make(common.LogFields)
+		}
+		logFields["broker_event"] = "client-offer"
+		if serverParams != nil {
+			logFields["destination_server_id"] = serverParams.serverID
+		}
+		logFields["elapsed_time"] = time.Since(startTime) / time.Millisecond
+		if proxyAnswer != nil {
+
+			// These fields are present only when a match was made, the
+			// proxy delivered an answer, and the client was still waiting
+			// for it.
+
+			logFields["connection_id"] = proxyAnswer.ConnectionID
+			logFields["client_nat_type"] = clientMatchOffer.Properties.NATType
+			logFields["client_port_mapping_types"] = clientMatchOffer.Properties.PortMappingTypes
+			logFields["proxy_nat_type"] = proxyMatchAnnouncement.Properties.NATType
+			logFields["proxy_port_mapping_types"] = proxyMatchAnnouncement.Properties.PortMappingTypes
+			logFields["preferred_nat_match"] =
+				clientMatchOffer.Properties.IsPreferredNATMatch(&proxyMatchAnnouncement.Properties)
+
+			// TODO: also log proxy ice_candidate_types and has_IPv6; for the
+			// client, these values are added by ValidateAndGetLogFields.
+		}
+		if retErr != nil {
+			logFields["error"] = retErr.Error()
+		}
+
+		b.config.Logger.LogMetric(brokerMetricName, logFields)
+	}()
+
+	offerRequest, err := UnmarshalClientOfferRequest(requestPayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields, err = offerRequest.ValidateAndGetLogFields(
+		b.config.LookupGeoIP,
+		b.config.APIParameterValidator,
+		b.config.APIParameterLogFieldFormatter,
+		geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// AllowClient may disallow clients from certain geolocations from
+	// offering. A disallowed client may still match proxies with shared
+	// personal compartment IDs, in which case its common compartment IDs
+	// are dropped; a disallowed client without personal compartment IDs is
+	// rejected.
+
+	commonCompartmentIDs := offerRequest.CommonCompartmentIDs
+
+	if clientAllowed := b.config.AllowClient(geoIPData); !clientAllowed {
+
+		if len(offerRequest.PersonalCompartmentIDs) == 0 {
+			return nil, errors.TraceNew("client disallowed")
+		}
+
+		// Only match personal compartment IDs.
+		commonCompartmentIDs = nil
+	}
+
+	// The proxy destination specified by the client must be a valid dial
+	// address for a signed Psiphon server entry. This ensures a client
+	// can't misuse a proxy to connect to arbitrary destinations.
+
+	serverParams, err = b.validateDestination(
+		geoIPData,
+		offerRequest.DestinationServerEntryJSON,
+		offerRequest.NetworkProtocol,
+		offerRequest.DestinationAddress)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Enqueue the client offer and await a proxy matching and subsequent
+	// proxy answer.
+
+	offerTimeout := common.ValueOrDefault(
+		b.config.ClientOfferTimeout, brokerClientOfferTimeout)
+
+	offerCtx, cancelOffer := context.WithTimeout(ctx, offerTimeout)
+	defer cancelOffer()
+
+	clientMatchOffer = &MatchOffer{
+		Properties: MatchProperties{
+			CommonCompartmentIDs:   commonCompartmentIDs,
+			PersonalCompartmentIDs: offerRequest.PersonalCompartmentIDs,
+			GeoIPData:              geoIPData,
+			NetworkType:            offerRequest.Metrics.BaseMetrics.GetNetworkType(),
+			NATType:                offerRequest.Metrics.NATType,
+			PortMappingTypes:       offerRequest.Metrics.PortMappingTypes,
+		},
+		ClientProxyProtocolVersion:  offerRequest.Metrics.ProxyProtocolVersion,
+		ClientOfferSDP:              offerRequest.ClientOfferSDP,
+		ClientRootObfuscationSecret: offerRequest.ClientRootObfuscationSecret,
+		NetworkProtocol:             offerRequest.NetworkProtocol,
+		DestinationAddress:          offerRequest.DestinationAddress,
+		DestinationServerID:         serverParams.serverID,
+	}
+
+	proxyAnswer, proxyMatchAnnouncement, err = b.matcher.Offer(
+		offerCtx, clientMatchOffer)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Only the type of compartment matching that occurred is reported. As
+	// PersonalCompartmentIDs are user-generated and shared, actual matching
+	// values are not logged as they may link users.
+
+	// TODO: log matching common compartment IDs?
+
+	proxyProperties := &proxyMatchAnnouncement.Properties
+	clientProperties := &clientMatchOffer.Properties
+
+	matchedCommonCompartments := HaveCommonIDs(
+		proxyProperties.CommonCompartmentIDs,
+		clientProperties.CommonCompartmentIDs)
+
+	matchedPersonalCompartments := HaveCommonIDs(
+		proxyProperties.PersonalCompartmentIDs,
+		clientProperties.PersonalCompartmentIDs)
+
+	// Initiate a BrokerServerRequest, which reports important information
+	// about the connection, including the original client IP, plus other
+	// values to be logged with server_tunnel, to the server. The request is
+	// sent through a secure session established between the broker and the
+	// server.
+	//
+	// The broker may already have an established session with the server. In
+	// this case, only one relay round trip between the client and server
+	// will be necessary; the first round trip will be embedded in the
+	// Psiphon handshake.
+
+	serverRequest := &BrokerServerRequest{
+		ProxyID:                     proxyAnswer.ProxyID,
+		ConnectionID:                proxyAnswer.ConnectionID,
+		MatchedCommonCompartments:   matchedCommonCompartments,
+		MatchedPersonalCompartments: matchedPersonalCompartments,
+		ProxyNATType:                proxyProperties.NATType,
+		ProxyPortMappingTypes:       proxyProperties.PortMappingTypes,
+		ClientNATType:               clientProperties.NATType,
+		ClientPortMappingTypes:      clientProperties.PortMappingTypes,
+		ClientIP:                    clientIP,
+		ProxyIP:                     proxyAnswer.ProxyIP,
+	}
+
+	relayPacket, err := b.initiateRelayedServerRequest(
+		serverParams,
+		proxyAnswer.ConnectionID,
+		serverRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The response carries the proxy answer and the initial broker/server
+	// session packet.
+
+	offerResponse := &ClientOfferResponse{
+		ConnectionID:                 proxyAnswer.ConnectionID,
+		SelectedProxyProtocolVersion: proxyAnswer.SelectedProxyProtocolVersion,
+		ProxyAnswerSDP:               proxyAnswer.ProxyAnswerSDP,
+		RelayPacketToServer:          relayPacket,
+	}
+
+	responsePayload, err := MarshalClientOfferResponse(offerResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePayload, nil
+}
+
+// handleProxyAnswer receives a proxy answer and delivers it to the waiting
+// client.
+//
+// requestPayload is unmarshaled as a ProxyAnswerRequest and validated. When
+// the request carries an AnswerError, the matcher is notified through
+// AnswerError; otherwise a MatchAnswer is assembled from the request, proxyIP,
+// and initiatorID, and handed to the matcher's Answer. On success the
+// returned payload is a marshaled, empty ProxyAnswerResponse; on failure the
+// payload is nil and a traced error is returned.
+//
+// The outcome, including any returned error, is always logged as
+// a "proxy-answer" broker event. ctx is not referenced in this function
+// body.
+func (b *Broker) handleProxyAnswer(
+	ctx context.Context,
+	proxyIP string,
+	geoIPData common.GeoIPData,
+	initiatorID ID,
+	requestPayload []byte) (retResponse []byte, retErr error) {
+
+	startTime := time.Now()
+
+	// These are populated as the request is processed and are read by the
+	// deferred logging function below.
+	var logFields common.LogFields
+	var proxyAnswer *MatchAnswer
+	var answerError string
+
+	// The proxy ID is an implicit parameter: it's the proxy's session public
+	// key.
+	proxyID := initiatorID
+
+	// Always log the outcome.
+	defer func() {
+		// logFields remains nil when unmarshaling or validation fails before
+		// ValidateAndGetLogFields returns fields.
+		if logFields == nil {
+			logFields = make(common.LogFields)
+		}
+		logFields["broker_event"] = "proxy-answer"
+		logFields["proxy_id"] = proxyID
+		logFields["elapsed_time"] = time.Since(startTime) / time.Millisecond
+		// proxyAnswer is only set on the non-error answer path, so no
+		// connection_id is logged for proxy-reported answer errors.
+		if proxyAnswer != nil {
+			logFields["connection_id"] = proxyAnswer.ConnectionID
+		}
+		if answerError != "" {
+			// This is a proxy-reported error that occurred while creating the answer.
+			logFields["answer_error"] = answerError
+		}
+		if retErr != nil {
+			logFields["error"] = retErr.Error()
+		}
+		b.config.Logger.LogMetric(brokerMetricName, logFields)
+	}()
+
+	answerRequest, err := UnmarshalProxyAnswerRequest(requestPayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields, err = answerRequest.ValidateAndGetLogFields(
+		b.config.LookupGeoIP,
+		b.config.APIParameterValidator,
+		b.config.APIParameterLogFieldFormatter,
+		geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if answerRequest.AnswerError != "" {
+
+		// The proxy failed to create an answer.
+
+		answerError = answerRequest.AnswerError
+
+		b.matcher.AnswerError(initiatorID, answerRequest.ConnectionID)
+
+	} else {
+
+		// Deliver the answer to the client.
+
+		proxyAnswer = &MatchAnswer{
+			ProxyIP:                      proxyIP,
+			ProxyID:                      initiatorID,
+			ConnectionID:                 answerRequest.ConnectionID,
+			SelectedProxyProtocolVersion: answerRequest.SelectedProxyProtocolVersion,
+			ProxyAnswerSDP:               answerRequest.ProxyAnswerSDP,
+		}
+
+		err = b.matcher.Answer(proxyAnswer)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	// There is no data in this response, it's simply an acknowledgement that
+	// the answer was received. Upon receiving the response, the proxy should
+	// begin the WebRTC dial operation.
+
+	responsePayload, err := MarshalProxyAnswerResponse(
+		&ProxyAnswerResponse{})
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePayload, nil
+}
+
+// handleClientRelayedPacket facilitates broker/server sessions. The initial
+// packet from the broker is sent to the client in the ClientOfferResponse.
+// The client sends that to the server in the Psiphon handshake and receives
+// a server packet in the handshake response. That server packet is then
+// delivered to the broker in a ClientRelayedPacketRequest. If the session
+// was already established, the relay ends here. If the session needs to be
+// [re-]negotiated, there are additional ClientRelayedPacket round trips
+// until the session is established and the BrokerServerRequest is securely
+// exchanged between the broker and server.
+//
+// requestPayload is unmarshaled as a ClientRelayedPacketRequest and
+// validated; the relay state is then looked up in pendingServerRequests by
+// connection ID. On success the returned payload is a marshaled
+// ClientRelayedPacketResponse carrying the next packet for the server, which
+// is nil once the exchange is complete. The outcome is always logged as
+// a "client-relayed-packet" broker event. ctx and initiatorID are not
+// referenced in this function body.
+func (b *Broker) handleClientRelayedPacket(
+	ctx context.Context,
+	geoIPData common.GeoIPData,
+	initiatorID ID,
+	requestPayload []byte) (retResponse []byte, retErr error) {
+
+	startTime := time.Now()
+
+	// These are populated as the request is processed and are read by the
+	// deferred logging function below.
+	var logFields common.LogFields
+	var relayedPacketRequest *ClientRelayedPacketRequest
+	var serverResponse *BrokerServerResponse
+	var serverID string
+
+	// Always log the outcome.
+	defer func() {
+		if logFields == nil {
+			logFields = make(common.LogFields)
+		}
+		logFields["broker_event"] = "client-relayed-packet"
+		logFields["elapsed_time"] = time.Since(startTime) / time.Millisecond
+		if relayedPacketRequest != nil {
+			logFields["connection_id"] = relayedPacketRequest.ConnectionID
+		}
+		if serverResponse != nil {
+			logFields["server_response"] = true
+		}
+		if serverID != "" {
+			logFields["server_id"] = serverID
+		}
+		if retErr != nil {
+			logFields["error"] = retErr.Error()
+		}
+		b.config.Logger.LogMetric(brokerMetricName, logFields)
+	}()
+
+	// relayedPacketRequest is declared above in this same scope, so this :=
+	// assigns to it (only err is newly declared) and the deferred logger
+	// observes the value.
+	relayedPacketRequest, err := UnmarshalClientRelayedPacketRequest(requestPayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields, err = relayedPacketRequest.ValidateAndGetLogFields(
+		b.config.APIParameterValidator,
+		b.config.APIParameterLogFieldFormatter,
+		geoIPData)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The relay state is associated with the connection ID.
+
+	strConnectionID := string(relayedPacketRequest.ConnectionID[:])
+
+	entry, ok := b.pendingServerRequests.Get(strConnectionID)
+	if !ok {
+		// The relay state is not found; it may have been evicted from the
+		// cache. The client will receive a generic error in this case and
+		// should stop relaying. Assuming the server is configured to require
+		// a BrokerServerRequest, the tunnel will be terminated, so the
+		// client should also abandon the dial.
+		return nil, errors.TraceNew("no pending request")
+	}
+	pendingServerRequest := entry.(*pendingServerRequest)
+
+	serverID = pendingServerRequest.serverID
+
+	// When the broker tries to use an existing session that is expired on the
+	// server, the server will indicate that the session is invalid. The
+	// broker resets the session and starts to establish a new session.
+	// There's only one reset and re-establish attempt.
+	//
+	// The non-waiting session establishment mode is used for broker/server
+	// sessions: if multiple clients concurrently try to relay new sessions,
+	// all establishments will happen in parallel without forcing any clients
+	// to wait for one client to lead the establishment. The last established
+	// session will be retained for reuse.
+	//
+	// The client can forge the SessionInvalid flag, but has no incentive to
+	// do so.
+
+	// The compare-and-swap on resetSession enforces the single reset.
+	if relayedPacketRequest.SessionInvalid &&
+		atomic.CompareAndSwapInt32(&pendingServerRequest.resetSession, 0, 1) {
+
+		pendingServerRequest.roundTrip.ResetSession()
+	}
+
+	// Next is given a nil ctx since we're not waiting for any other client to
+	// establish the session.
+	out, err := pendingServerRequest.roundTrip.Next(
+		nil, relayedPacketRequest.PacketFromServer)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// When out is nil, the exchange is over and the BrokerServer response
+	// from the server should be available.
+	if out == nil {
+
+		// Remove the cached state. Setting the deleted flag skips a cache
+		// eviction log.
+
+		atomic.StoreInt32(&pendingServerRequest.deleted, 1)
+		b.pendingServerRequests.Delete(strConnectionID)
+
+		// Get the response cached in the session round tripper.
+
+		serverResponsePayload, err := pendingServerRequest.roundTrip.Response()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// This assigns the function-scope serverResponse, which the deferred
+		// logger checks.
+		serverResponse, err = UnmarshalBrokerServerResponse(serverResponsePayload)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// If ErrorMessage is set, the server has rejected the connection.
+
+		if serverResponse.ErrorMessage != "" {
+			return nil, errors.Tracef("server error: %s", serverResponse.ErrorMessage)
+		}
+
+		// Check that the server has acknowledged the expected connection ID.
+
+		if relayedPacketRequest.ConnectionID != serverResponse.ConnectionID {
+			return nil, errors.Tracef(
+				"expected connection ID: %v, got: %v",
+				relayedPacketRequest.ConnectionID,
+				serverResponse.ConnectionID)
+		}
+	}
+
+	// Return the next broker packet for the client to relay to the server.
+	// When it receives a nil PacketToServer, the client will stop relaying.
+
+	responsePayload, err := MarshalClientRelayedPacketResponse(
+		&ClientRelayedPacketResponse{
+			PacketToServer: out,
+		})
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePayload, nil
+}
+
+// pendingServerRequest is the relay state for one in-flight
+// BrokerServerRequest. Entries are stored in Broker.pendingServerRequests,
+// keyed by the connection ID bytes converted to a string.
+type pendingServerRequest struct {
+	// serverID is copied from serverParams.serverID and is used in logs.
+	serverID      string
+	// serverRequest is the request passed to initiateRelayedServerRequest.
+	serverRequest *BrokerServerRequest
+	// roundTrip drives the broker/server session exchange for this request.
+	roundTrip     *InitiatorRoundTrip
+	// resetSession is accessed atomically; it is swapped from 0 to 1 when
+	// the session is reset, limiting resets to one per request.
+	resetSession  int32
+	// deleted is accessed atomically; it is set to 1 before the entry is
+	// deliberately removed from the cache, so that the eviction callback
+	// skips its timeout log.
+	deleted       int32
+}
+
+func (b *Broker) initiateRelayedServerRequest(
+	serverParams *serverParams,
+	connectionID ID,
+	serverRequest *BrokerServerRequest) ([]byte, error) {
+
+	requestPayload, err := MarshalBrokerServerRequest(serverRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Force a new, concurrent session establishment with the server even if
+	// another handshake is already in progess, relayed by some other client.
+	// This ensures clients don't block waiting for other client relays
+	// through other tunnels. The last established session will be retained
+	// for reuse.
+
+	waitToShareSession := false
+
+	roundTrip, err := b.initiatorSessions.NewRoundTrip(
+		serverParams.sessionPublicKey,
+		serverParams.sessionRootObfuscationSecret,
+		waitToShareSession,
+		requestPayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	relayPacket, err := roundTrip.Next(nil, nil)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	strConnectionID := string(connectionID[:])
+
+	b.pendingServerRequests.Set(
+		strConnectionID,
+		&pendingServerRequest{
+			serverID:      serverParams.serverID,
+			serverRequest: serverRequest,
+			roundTrip:     roundTrip,
+		},
+		lrucache.DefaultExpiration)
+
+	return relayPacket, nil
+}
+
+func (b *Broker) evictedPendingServerRequest(
+	connectionID string, entry interface{}) {
+
+	pendingServerRequest := entry.(*pendingServerRequest)
+
+	// Don't log when the entry was removed by handleClientRelayedPacket due
+	// to completion (this OnEvicted callback gets called in that case).
+	if atomic.LoadInt32(&pendingServerRequest.deleted) == 1 {
+		return
+	}
+
+	b.config.Logger.WithTraceFields(common.LogFields{
+		"server_id":     pendingServerRequest.serverID,
+		"connection_id": connectionID,
+	}).Info("pending server request timed out")
+
+	// TODO: consider adding a signal from the broker to the proxy to
+	// terminate this proxied connection when the BrokerServerResponse does
+	// not arrive in time.
+}
+
+// serverParams holds the values extracted from a validated server entry by
+// validateDestination: a server ID for logging and the key material for the
+// broker/server session.
+type serverParams struct {
+	// serverID is the server entry's diagnostic ID.
+	serverID                     string
+	// sessionPublicKey is decoded from the server entry's
+	// InProxySessionPublicKey.
+	sessionPublicKey             SessionPublicKey
+	// sessionRootObfuscationSecret is decoded from the server entry's
+	// InProxySessionRootObfuscationSecret.
+	sessionRootObfuscationSecret ObfuscationSecret
+}
+
+// validateDestination checks that the client's specified proxy dial
+// destination is a valid destination address for a tunnel protocol in the
+// specified signed and valid Psiphon server entry.
+//
+// destinationServerEntryJSON is the JSON-encoded server entry supplied by
+// the client; networkProtocol and destinationAddress (host:port) are the
+// requested dial parameters; geoIPData is passed to the
+// AllowDomainDestination check. On success, the returned serverParams
+// contains the server's diagnostic ID and its session key material. Any
+// failed check returns a nil serverParams and a traced error.
+func (b *Broker) validateDestination(
+	geoIPData common.GeoIPData,
+	destinationServerEntryJSON []byte,
+	networkProtocol NetworkProtocol,
+	destinationAddress string) (*serverParams, error) {
+
+	var serverEntryFields protocol.ServerEntryFields
+
+	err := json.Unmarshal(destinationServerEntryJSON, &serverEntryFields)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Strip any unsigned fields, which could be forged by the client. In
+	// particular, this includes the server entry tag, which, in some cases,
+	// is locally populated by a client for its own reference.
+
+	serverEntryFields.RemoveUnsignedFields()
+
+	// Check that the server entry is signed by Psiphon. Otherwise a client
+	// could manufacture a server entry corresponding to an arbitrary dial
+	// destination.
+
+	err = serverEntryFields.VerifySignature(
+		b.config.ServerEntrySignaturePublicKey)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The server entry tag must be set and signed by Psiphon, as local,
+	// client derived tags are unsigned and untrusted.
+
+	serverEntryTag := serverEntryFields.GetTag()
+
+	if serverEntryTag == "" {
+		return nil, errors.TraceNew("missing server entry tag")
+	}
+
+	// Check that the server entry tag is on a list of active and valid
+	// Psiphon server entry tags. This ensures that an obsolete entry for a
+	// pruned server cannot be misused by a client to proxy to what's no
+	// longer a Psiphon server.
+
+	if !b.config.IsValidServerEntryTag(serverEntryTag) {
+		return nil, errors.TraceNew("invalid server entry tag")
+	}
+
+	serverID := serverEntryFields.GetDiagnosticID()
+
+	serverEntry, err := serverEntryFields.GetServerEntry()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The server entry must include the in-proxy capability. This capability
+	// is set for only a subset of all Psiphon servers, to limit the number
+	// of servers a proxy can observe and enumerate. Well-behaved clients
+	// will not send any server entries lacking this capability, but here the
+	// broker enforces it.
+
+	if !serverEntry.SupportsInProxy() {
+		return nil, errors.TraceNew("missing inproxy capability")
+	}
+
+	// Validate the dial host (IP or domain) and port matches a tunnel
+	// protocol offered by the server entry.
+
+	destHost, destPort, err := net.SplitHostPort(destinationAddress)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	destPortNum, err := strconv.Atoi(destPort)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// For domain fronted cases, since we can't verify the Host header, access
+	// is strictly limited to targeted clients. Clients should use tactics
+	// to avoid disallowed domain dial address cases, but here the broker
+	// enforces it.
+	//
+	// TODO: this issue could be further mitigated by signaling the proxy to
+	// terminate client connections that fail to deliver a timely
+	// BrokerServerResponse from the expected Psiphon server. See the comment
+	// in evictedPendingServerRequest.
+
+	// Any host that doesn't parse as an IP address is treated as a domain.
+	isDomain := net.ParseIP(destHost) == nil
+	if isDomain && !b.config.AllowDomainDestination(geoIPData) {
+		return nil, errors.TraceNew("domain destination disallowed")
+	}
+
+	if !serverEntry.IsValidDialAddress(networkProtocol.String(), destHost, destPortNum) {
+		return nil, errors.TraceNew("invalid destination address")
+	}
+
+	// Extract and return the key material to be used for the secure session
+	// and BrokerServer exchange between the broker and the Psiphon server
+	// corresponding to this server entry.
+
+	params := &serverParams{
+		serverID: serverID,
+	}
+
+	sessionPublicKey, err := base64.StdEncoding.DecodeString(
+		serverEntry.InProxySessionPublicKey)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	// The decoded length must exactly match the destination field length, as
+	// copy would otherwise silently truncate or leave trailing zero bytes.
+	if len(sessionPublicKey) != len(params.sessionPublicKey) {
+		return nil, errors.TraceNew("invalid session public key length")
+	}
+
+	sessionRootObfuscationSecret, err := base64.StdEncoding.DecodeString(
+		serverEntry.InProxySessionRootObfuscationSecret)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	if len(sessionRootObfuscationSecret) != len(params.sessionRootObfuscationSecret) {
+		return nil, errors.TraceNew("invalid session root obfuscation secret length")
+	}
+
+	copy(params.sessionPublicKey[:], sessionPublicKey)
+	copy(params.sessionRootObfuscationSecret[:], sessionRootObfuscationSecret)
+
+	return params, nil
+}
+
+// initializeCommonCompartmentIDHashing builds the consistent hashing state,
+// b.commonCompartments, with one member per configured common compartment
+// ID.
+func (b *Broker) initializeCommonCompartmentIDHashing() {
+
+	// Proxies that don't specify personal compartment IDs are assigned to
+	// one of the common, Psiphon-specified, compartment IDs. To limit access
+	// to proxies, the common compartment IDs are distributed to targeted
+	// clients through tactics or embedded in OSLs.
+	//
+	// Consistent hashing on the proxy ID, which covers all announcements for
+	// a single proxy, is used in an effort to keep each proxy's assignment
+	// stable. This matters most for long-lived, permanent proxies that are
+	// not behind any NAT.
+	//
+	// Consistent hashing doesn't eliminate reassignment: when
+	// CommonCompartmentIDs changes, a subset of proxies will still change
+	// assignment.
+
+	compartmentIDs := b.config.CommonCompartmentIDs
+
+	members := make([]consistent.Member, 0, len(compartmentIDs))
+	for _, id := range compartmentIDs {
+		members = append(members, consistentMember(id.String()))
+	}
+
+	hashingConfig := consistent.Config{
+		PartitionCount:    consistent.DefaultPartitionCount,
+		ReplicationFactor: consistent.DefaultReplicationFactor,
+		Load:              consistent.DefaultLoad,
+		Hasher:            xxhasher{},
+	}
+
+	b.commonCompartments = consistent.New(members, hashingConfig)
+}
+
+// xxhasher wraps github.com/cespare/xxhash.Sum64 in the interface expected by
+// github.com/buraksezer/consistent. xxhash is a high quality hash function
+// used in github.com/buraksezer/consistent examples.
+type xxhasher struct{}
+
+// Sum64 returns the xxhash 64-bit hash of data.
+func (h xxhasher) Sum64(data []byte) uint64 {
+	return xxhash.Sum64(data)
+}
+
+// consistentMember wraps the string type with the interface expected by
+// github.com/buraksezer/consistent.
+type consistentMember string
+
+// String returns the underlying string value of the member.
+func (m consistentMember) String() string {
+	return string(m)
+}
+
+// selectCommonCompartmentID returns the common compartment ID located for
+// proxyID by the consistent hashing state in b.commonCompartments.
+func (b *Broker) selectCommonCompartmentID(proxyID ID) (ID, error) {
+
+	// Members are stored in string form, so the located member is converted
+	// back into an ID.
+	member := b.commonCompartments.LocateKey(proxyID[:])
+
+	compartmentID, err := IDFromString(member.String())
+	if err != nil {
+		return compartmentID, errors.Trace(err)
+	}
+
+	return compartmentID, nil
+}

+ 234 - 0
psiphon/common/inproxy/brokerClient.go

@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// BrokerClient is used to make requests to a broker.
+//
+// Each BrokerClient maintains a secure broker session. A BrokerClient and its
+// session may be used for multiple concurrent requests. Session key material
+// is provided by DialParameters and must remain static for the lifetime of
+// the BrokerClient.
+//
+// Round trips between the BrokerClient and broker are provided by
+// BrokerClientRoundTripper from DialParameters. The RoundTripper must
+// maintain the association between a request payload and the corresponding
+// response payload. The canonical RoundTripper is an HTTP client, with
+// HTTP/2 or HTTP/3 used to multiplex concurrent requests.
+//
+// When the DialParameters BrokerClientRoundTripperSucceeded callback is
+// invoked, the RoundTripper provider may mark the RoundTripper dial
+// properties for replay.
+//
+// When the DialParameters BrokerClientRoundTripperFailed callback is
+// invoked, the RoundTripper provider should clear any replay state and also
+// create a new RoundTripper to be returned from BrokerClientRoundTripper.
+//
+// BrokerClient does not have a Close operation. The user should close the
+// provided RoundTripper as appropriate.
+//
+// The secure session layer includes obfuscation that provides random padding
+// and uniformly random payload content. The RoundTripper is expected to add
+// its own obfuscation layer; for example, domain fronting.
+type BrokerClient struct {
+	// dialParams is the DialParameters passed to NewBrokerClient.
+	dialParams DialParameters
+	// sessions is created by NewBrokerClient with the session private key.
+	sessions   *InitiatorSessions
+}
+
+// NewBrokerClient initializes a new BrokerClient with the provided
+// DialParameters.
+func NewBrokerClient(dialParams DialParameters) (*BrokerClient, error) {
+
+	// A client is expected to use an ephemeral key, and can return a
+	// zero-value private key. Each proxy should use a peristent key, as the
+	// corresponding public key is the proxy ID, which is used to credit the
+	// proxy for its service.
+
+	privateKey := dialParams.BrokerClientPrivateKey()
+	if privateKey.IsZero() {
+		var err error
+		privateKey, err = GenerateSessionPrivateKey()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	return &BrokerClient{
+		dialParams: dialParams,
+		sessions:   NewInitiatorSessions(privateKey),
+	}, nil
+}
+
+// ProxyAnnounce marshals the ProxyAnnounceRequest, performs a broker round
+// trip, and returns the unmarshaled ProxyAnnounceResponse.
+func (b *BrokerClient) ProxyAnnounce(
+	ctx context.Context,
+	request *ProxyAnnounceRequest) (*ProxyAnnounceResponse, error) {
+
+	encodedRequest, err := MarshalProxyAnnounceRequest(request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	encodedResponse, err := b.roundTrip(ctx, encodedRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	announceResponse, err := UnmarshalProxyAnnounceResponse(encodedResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return announceResponse, nil
+}
+
+// ClientOffer marshals the ClientOfferRequest, performs a broker round trip,
+// and returns the unmarshaled ClientOfferResponse.
+func (b *BrokerClient) ClientOffer(
+	ctx context.Context,
+	request *ClientOfferRequest) (*ClientOfferResponse, error) {
+
+	encodedRequest, err := MarshalClientOfferRequest(request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	encodedResponse, err := b.roundTrip(ctx, encodedRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	offerResponse, err := UnmarshalClientOfferResponse(encodedResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return offerResponse, nil
+}
+
+// ProxyAnswer marshals the ProxyAnswerRequest, performs a broker round trip,
+// and returns the unmarshaled ProxyAnswerResponse.
+func (b *BrokerClient) ProxyAnswer(
+	ctx context.Context,
+	request *ProxyAnswerRequest) (*ProxyAnswerResponse, error) {
+
+	encodedRequest, err := MarshalProxyAnswerRequest(request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	encodedResponse, err := b.roundTrip(ctx, encodedRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	answerResponse, err := UnmarshalProxyAnswerResponse(encodedResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return answerResponse, nil
+}
+
+// ClientRelayedPacket marshals the ClientRelayedPacketRequest, performs a
+// broker round trip, and returns the unmarshaled
+// ClientRelayedPacketResponse.
+func (b *BrokerClient) ClientRelayedPacket(
+	ctx context.Context,
+	request *ClientRelayedPacketRequest) (*ClientRelayedPacketResponse, error) {
+
+	encodedRequest, err := MarshalClientRelayedPacketRequest(request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	encodedResponse, err := b.roundTrip(ctx, encodedRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	relayedPacketResponse, err := UnmarshalClientRelayedPacketResponse(encodedResponse)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return relayedPacketResponse, nil
+}
+
+func (b *BrokerClient) roundTrip(
+	ctx context.Context,
+	request []byte) ([]byte, error) {
+
+	// The round tripper may need to establish a transport-level connection;
+	// or this may already be established.
+
+	roundTripper, err := b.dialParams.BrokerClientRoundTripper()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// InitiatorSessions.RoundTrip may make serveral round trips with
+	// roundTripper in order to complete a session establishment handshake.
+	//
+	// When there's an active session, only a single round trip is required,
+	// to exchange the application-level request and response.
+	//
+	// When a concurrent BrokerClient request is currently performing a
+	// session handshake, InitiatorSessions.RoundTrip will await completion
+	// of that handshake before sending the application-layer request.
+	//
+	// Retries are built in to InitiatorSessions.RoundTrip: if there's an
+	// existing session and it's expired, there will be additional round
+	// trips to establish a fresh session.
+	//
+	// While the round tripper is responsible for maintaining the
+	// request/response association, the application-level request and
+	// response are tagged with a RoundTripID which is checked to ensure the
+	// association is maintained.
+
+	waitToShareSession := true
+
+	response, err := b.sessions.RoundTrip(
+		ctx,
+		roundTripper,
+		b.dialParams.BrokerPublicKey(),
+		b.dialParams.BrokerRootObfuscationSecret(),
+		waitToShareSession,
+		request)
+	if err != nil {
+
+		// The DialParameters provider should close the existing
+		// BrokerClientRoundTripper and create a new RoundTripper to return
+		// in the next BrokerClientRoundTripper call.
+		//
+		// The session will be closed, if necessary, by InitiatorSessions.
+		// It's possible that the session remains valid and only the
+		// RoundTripper transport layer needs to be reset.
+		b.dialParams.BrokerClientRoundTripperFailed(roundTripper)
+
+		return nil, errors.Trace(err)
+	}
+
+	b.dialParams.BrokerClientRoundTripperSucceeded(roundTripper)
+
+	return response, nil
+}

+ 422 - 0
psiphon/common/inproxy/client.go

@@ -0,0 +1,422 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// clientOfferRequestTimeout should be set to no more than
+// brokerClientOfferTimeout. These constants are the fallbacks passed to
+// common.ValueOrDefault alongside the DialParameters OfferRequestTimeout,
+// OfferRetryDelay, and OfferRetryJitter values.
+
+const (
+	clientOfferRequestTimeout = 10 * time.Second
+	clientOfferRetryDelay     = 1 * time.Second
+	clientOfferRetryJitter    = 0.3
+)
+
+// ClientConn is a network connection to an in-proxy, which is relayed to a
+// Psiphon server destination. Psiphon clients use a ClientConn in place of a
+// physical TCP or UDP socket connection, passing the ClientConn into tunnel
+// protocol dials. ClientConn implements both net.Conn and net.PacketConn,
+// with net.PacketConn's ReadFrom/WriteTo behaving as if connected to the
+// initial dial address.
+type ClientConn struct {
+	// The dial configuration, the established WebRTC connection, and the
+	// connection ID received in the broker's ClientOffer response.
+	// NOTE(review): brokerClient is not assigned by DialClient in this file;
+	// RelayPacket uses config.BrokerClient instead.
+	config       *ClientConfig
+	brokerClient *BrokerClient
+	webRTCConn   *WebRTCConn
+	connectionID ID
+
+	// relayMutex guards initialRelayPacket and is held for the duration of
+	// each RelayPacket broker request.
+	relayMutex         sync.Mutex
+	initialRelayPacket []byte
+}
+
+// ClientConfig specifies the configuration for a ClientConn dial.
+type ClientConfig struct {
+
+	// Logger is used to log events.
+	Logger common.Logger
+
+	// BaseMetrics should be populated with Psiphon handshake metrics
+	// parameters. These will be sent to and logged by the broker.
+	BaseMetrics common.APIParameters
+
+	// DialParameters specifies specific WebRTC dial strategies and
+	// settings; DialParameters also facilitates dial replay by receiving
+	// callbacks when individual dial steps succeed or fail.
+	DialParameters DialParameters
+
+	// BrokerClient is the BrokerClient to use for broker API calls. The
+	// BrokerClient may be shared with other client dials, allowing for
+	// connection and session reuse.
+	BrokerClient *BrokerClient
+
+	// ReliableTransport specifies whether to use reliable delivery with the
+	// underlying WebRTC DataChannel that relays the ClientConn traffic. When
+	// using a ClientConn to proxy traffic that expects reliable delivery, as
+	// if the physical network protocol were TCP, specify true. When using a
+	// ClientConn to proxy traffic that expects unreliable delivery, such as
+	// QUIC protocols expecting the physical network protocol UDP, specify
+	// false.
+	ReliableTransport bool
+
+	// DialNetworkProtocol specifies whether the in-proxy will relay TCP or UDP
+	// traffic.
+	DialNetworkProtocol NetworkProtocol
+
+	// DialAddress is the host:port destination network address the in-proxy
+	// will relay traffic to.
+	DialAddress string
+
+	// DestinationServerEntryJSON is a signed Psiphon server entry
+	// corresponding to the destination dial address. This signed server
+	// entry is sent to the broker, which will use it to validate that the
+	// server is a valid in-proxy destination.
+	// ServerEntryFields.RemoveUnsignedFields can be called to prune local
+	// fields before sending.
+	DestinationServerEntryJSON []byte
+}
+
+// DialClient establishes an in-proxy connection for relaying traffic to the
+// specified destination. DialClient first contacts the broker and initiates
+// an in-proxy pairing. config.BrokerClient may be shared by multiple dials,
+// and may have a preexisting connection and session with the broker.
+func DialClient(
+	ctx context.Context,
+	config *ClientConfig) (retConn *ClientConn, retErr error) {
+
+	// Reset and configure port mapper component, as required. See
+	// initPortMapper comment.
+	initPortMapper(config.DialParameters)
+
+	// Future improvements:
+	//
+	// - The broker connection and session, when not already established,
+	//   could be established concurrent with the WebRTC offer setup
+	//   (STUN/ICE gathering).
+	//
+	// - The STUN state used for NAT discovery could be reused for the WebRTC
+	//   dial.
+	//
+	// - A subsequent WebRTC offer setup could be run concurrent with the
+	//   client offer request, in case that request or WebRTC connections
+	//   fails, so that the offer is immediately ready for a retry.
+
+	if config.DialParameters.DiscoverNAT() {
+
+		// NAT discovery, using the RFC5780 algorithms is optional and
+		// conditional on the DiscoverNAT flag. Discovery is performed
+		// synchronously, so that NAT topology metrics can be reported to the
+		// broker in the ClientOffer request. For clients, NAT discovery is
+		// intended to be performed at a low sampling rate, since the RFC5780
+		// traffic may be unusual(differs from standard STUN requests for
+		// ICE) and since this step delays the dial. Clients should to cache
+		// their NAT discovery outcomes, associated with the current network
+		// by network ID, so metrics can be reported even without a discovery
+		// step; this is facilitated by DialParameters.
+		//
+		// NAT topology metrics are used by the broker to optimize client and
+		// in-proxy matching.
+		//
+		// For client NAT discovery, port mapping type discovery is skipped
+		// since port mappings are attempted when preparing the WebRTC offer,
+		// which also happens before the ClientOffer request.
+
+		NATDiscover(
+			ctx,
+			&NATDiscoverConfig{
+				Logger:          config.Logger,
+				DialParameters:  config.DialParameters,
+				SkipPortMapping: true,
+			})
+	}
+
+	var result *clientWebRTCDialResult
+	for {
+
+		// Repeatedly try to establish in-proxy/WebRTC connection until the
+		// dial context is canceled or times out.
+		//
+		// If a broker request fails, the
+		// DialParameters.BrokerClientRoundTripperFailed callback will be
+		// invoked, so the Psiphon client will have an opportunity to select
+		// new broker connection parameters before a retry. Similarly, when
+		// STUN servers fail, DialParameters.STUNServerAddressFailed will be
+		// invoked, giving the Psiphon client an opportunity to select new
+		// STUN server parameter -- although, in this failure case, the
+		// WebRTC connection attemp can succeed with other ICE candidates or
+		// no ICE candidates.
+
+		err := ctx.Err()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		var retry bool
+		result, retry, err = dialClientWebRTCConn(ctx, config)
+		if err == nil {
+			break
+		}
+
+		if retry {
+			config.Logger.WithTraceFields(common.LogFields{"error": err}).Warning("dial failed")
+
+			// This delay is intended avoid overloading the broker with
+			// repeated requests. A jitter is applied to mitigate a traffic
+			// fingerprint.
+
+			common.SleepWithJitter(
+				ctx,
+				common.ValueOrDefault(config.DialParameters.OfferRetryDelay(), clientOfferRetryDelay),
+				common.ValueOrDefault(config.DialParameters.OfferRetryJitter(), clientOfferRetryJitter))
+
+			continue
+		}
+
+		return nil, errors.Trace(err)
+	}
+
+	return &ClientConn{
+		config:             config,
+		webRTCConn:         result.conn,
+		connectionID:       result.connectionID,
+		initialRelayPacket: result.relayPacket,
+	}, nil
+}
+
+// GetConnectionID returns the in-proxy connection ID, which the client should
+// include with its Psiphon handshake parameters. The value is the
+// ConnectionID received in the broker's ClientOffer response during the dial.
+func (conn *ClientConn) GetConnectionID() ID {
+	return conn.connectionID
+}
+
+// InitialRelayPacket returns the first packet of the broker->server
+// messaging session. The client must relay these packets to facilitate that
+// message exchange; session security ensures that clients cannot decrypt,
+// modify, or replay the session packets. The Psiphon client will send the
+// initial packet as a parameter in the Psiphon server handshake request.
+//
+// The stored packet is cleared when it is returned, so subsequent calls
+// return nil.
+func (conn *ClientConn) InitialRelayPacket() []byte {
+	conn.relayMutex.Lock()
+	packet := conn.initialRelayPacket
+	conn.initialRelayPacket = nil
+	conn.relayMutex.Unlock()
+
+	return packet
+}
+
+// RelayPacket relays a server->broker messaging session packet, received by
+// the client, back to the broker. It returns the next broker->server packet,
+// if any, or nil when the message exchange is complete. Psiphon clients
+// receive a server->broker packet in the Psiphon server handshake response
+// and exchange any additional packets in a post-handshake Psiphon server
+// request.
+//
+// If RelayPacket fails, the client should close the ClientConn and redial.
+func (conn *ClientConn) RelayPacket(
+	ctx context.Context, in []byte, sessionInvalid bool) ([]byte, error) {
+
+	// Future improvement: relaying these packets back to the broker is
+	// potentially an inter-flow fingerprint, with the client alternating
+	// between the WebRTC flow and its broker connection. This might be
+	// avoided by having the client connect to the broker via the tunnel,
+	// resuming its broker session and relaying any further packets.
+
+	// Limitation: the mutex only ensures that this ClientConn doesn't make
+	// concurrent ClientRelayedPacket requests. The client remains
+	// responsible for delivering packets in the correct relay sequence.
+	conn.relayMutex.Lock()
+	defer conn.relayMutex.Unlock()
+
+	relayRequest := &ClientRelayedPacketRequest{
+		ConnectionID:     conn.connectionID,
+		PacketFromServer: in,
+		SessionInvalid:   sessionInvalid,
+	}
+
+	relayResponse, err := conn.config.BrokerClient.ClientRelayedPacket(ctx, relayRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return relayResponse.PacketToServer, nil
+}
+
+// clientWebRTCDialResult is the outcome of a successful
+// dialClientWebRTCConn attempt: the established WebRTC connection, plus the
+// connection ID and initial relay packet taken from the broker's ClientOffer
+// response.
+type clientWebRTCDialResult struct {
+	conn         *WebRTCConn
+	connectionID ID
+	relayPacket  []byte
+}
+
+// dialClientWebRTCConn makes a single in-proxy dial attempt: it prepares a
+// WebRTC offer, submits the offer to the broker in a ClientOffer request,
+// applies the proxy's answer, and awaits the initial data channel.
+//
+// On failure, the retry result indicates whether the caller may try again;
+// on success it is false. The WebRTC connection is closed before returning
+// whenever an error is returned.
+func dialClientWebRTCConn(
+	ctx context.Context,
+	config *ClientConfig) (retResult *clientWebRTCDialResult, retRetry bool, retErr error) {
+
+	dialParams := config.DialParameters
+
+	// Prepare the WebRTC offer.
+
+	rootObfuscationSecret := dialParams.ClientRootObfuscationSecret()
+
+	webRTCConfig := &WebRTCConfig{
+		Logger:                      config.Logger,
+		DialParameters:              dialParams,
+		ClientRootObfuscationSecret: rootObfuscationSecret,
+		ReliableTransport:           config.ReliableTransport,
+	}
+
+	webRTCConn, offerSDP, offerSDPMetrics, err := NewWebRTCConnWithOffer(ctx, webRTCConfig)
+	if err != nil {
+		return nil, true, errors.Trace(err)
+	}
+
+	// Release the WebRTC connection on any early, failed return. On success,
+	// the connection is handed to the caller in the result.
+	defer func() {
+		if retErr == nil {
+			return
+		}
+		webRTCConn.Close()
+	}()
+
+	// Submit the ClientOffer request to the broker, bounded by the offer
+	// request timeout.
+
+	requestTimeout := common.ValueOrDefault(
+		dialParams.OfferRequestTimeout(), clientOfferRequestTimeout)
+	offerRequestCtx, cancelOfferRequest := context.WithTimeout(ctx, requestTimeout)
+	defer cancelOfferRequest()
+
+	encodedBaseMetrics, err := EncodeBaseMetrics(config.BaseMetrics)
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	// At this point, DialParameters.NATType may be populated from discovery,
+	// or replayed from a previous run on the same network ID, and
+	// DialParameters.PortMappingTypes may be populated via
+	// NewWebRTCConnWithOffer.
+
+	offerRequest := &ClientOfferRequest{
+		Metrics: &ClientMetrics{
+			BaseMetrics:          encodedBaseMetrics,
+			ProxyProtocolVersion: ProxyProtocolVersion1,
+			NATType:              dialParams.NATType(),
+			PortMappingTypes:     dialParams.PortMappingTypes(),
+		},
+		CommonCompartmentIDs:        dialParams.CommonCompartmentIDs(),
+		PersonalCompartmentIDs:      dialParams.PersonalCompartmentIDs(),
+		ClientOfferSDP:              offerSDP,
+		ICECandidateTypes:           offerSDPMetrics.ICECandidateTypes,
+		ClientRootObfuscationSecret: rootObfuscationSecret,
+		DestinationServerEntryJSON:  config.DestinationServerEntryJSON,
+		NetworkProtocol:             config.DialNetworkProtocol,
+		DestinationAddress:          config.DialAddress,
+	}
+
+	offerResponse, err := config.BrokerClient.ClientOffer(offerRequestCtx, offerRequest)
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	if selected := offerResponse.SelectedProxyProtocolVersion; selected != ProxyProtocolVersion1 {
+		return nil, false, errors.Tracef(
+			"Unsupported proxy protocol version: %d", selected)
+	}
+
+	// Complete the WebRTC DataChannel connection using the proxy's answer.
+	// Note that the wait for the data channel is bounded by the dial ctx,
+	// not by the offer request timeout.
+
+	if err := webRTCConn.SetRemoteSDP(offerResponse.ProxyAnswerSDP); err != nil {
+		return nil, true, errors.Trace(err)
+	}
+
+	if err := webRTCConn.AwaitInitialDataChannel(ctx); err != nil {
+		return nil, true, errors.Trace(err)
+	}
+
+	dialResult := &clientWebRTCDialResult{
+		conn:         webRTCConn,
+		connectionID: offerResponse.ConnectionID,
+		relayPacket:  offerResponse.RelayPacketToServer,
+	}
+
+	return dialResult, false, nil
+}
+
+// GetMetrics implements the common.MetricsSource interface. No metrics are
+// reported yet; an empty common.LogFields is always returned.
+func (conn *ClientConn) GetMetrics() common.LogFields {
+
+	// TODO: determine which WebRTC ICE candidate was chosen, and log its
+	// type (host, server reflexive, etc.), and whether it's IPv6.
+
+	return common.LogFields{}
+}
+
+// Close closes the underlying WebRTC connection and returns any resulting
+// error, passed through errors.Trace.
+func (conn *ClientConn) Close() error {
+	return errors.Trace(conn.webRTCConn.Close())
+}
+
+// Read reads relayed data from the underlying WebRTC connection into p,
+// returning the number of bytes read.
+func (conn *ClientConn) Read(p []byte) (int, error) {
+	n, err := conn.webRTCConn.Read(p)
+	if err != nil {
+		return n, errors.Trace(err)
+	}
+	return n, nil
+}
+
+// Write sends p through the in-proxy connection, returning the number of
+// bytes written. len(p) should be under 32K.
+func (conn *ClientConn) Write(p []byte) (int, error) {
+	n, err := conn.webRTCConn.Write(p)
+	if err != nil {
+		return n, errors.Trace(err)
+	}
+	return n, nil
+}
+
+// LocalAddr returns the local address reported by the underlying WebRTC
+// connection.
+func (conn *ClientConn) LocalAddr() net.Addr {
+	return conn.webRTCConn.LocalAddr()
+}
+
+// RemoteAddr returns the remote address reported by the underlying WebRTC
+// connection.
+func (conn *ClientConn) RemoteAddr() net.Addr {
+	return conn.webRTCConn.RemoteAddr()
+}
+
+// SetDeadline forwards the deadline t to the underlying WebRTC connection's
+// SetDeadline.
+func (conn *ClientConn) SetDeadline(t time.Time) error {
+	return conn.webRTCConn.SetDeadline(t)
+}
+
+// SetReadDeadline forwards the read deadline t to the underlying WebRTC
+// connection's SetReadDeadline.
+func (conn *ClientConn) SetReadDeadline(t time.Time) error {
+	return conn.webRTCConn.SetReadDeadline(t)
+}
+
+// SetWriteDeadline is a no-op that always returns nil; the deadline t is
+// ignored.
+func (conn *ClientConn) SetWriteDeadline(t time.Time) error {
+
+	// Limitation: this is a workaround; webRTCConn doesn't support
+	// SetWriteDeadline, but common/quic calls SetWriteDeadline on
+	// net.PacketConns to avoid hanging on EAGAIN when the conn is an actual
+	// UDP socket. See the comment in common/quic.writeTimeoutUDPConn. In
+	// this case, the conn is not a UDP socket and that particular
+	// SetWriteDeadline use case doesn't apply. Silently ignore the deadline
+	// and report no error.
+
+	return nil
+}
+
+// ReadFrom is the net.PacketConn-style read. It reads from the underlying
+// WebRTC connection into b and always reports that connection's remote
+// address as the source. The read error is returned as is.
+func (conn *ClientConn) ReadFrom(b []byte) (int, net.Addr, error) {
+	webRTCConn := conn.webRTCConn
+	n, err := webRTCConn.Read(b)
+	remoteAddr := webRTCConn.RemoteAddr()
+	return n, remoteAddr, err
+}
+
+// WriteTo is the net.PacketConn-style write. The destination address
+// argument is ignored and b is written to the underlying WebRTC connection.
+// The write error is returned as is.
+func (conn *ClientConn) WriteTo(b []byte, _ net.Addr) (int, error) {
+	return conn.webRTCConn.Write(b)
+}

+ 242 - 0
psiphon/common/inproxy/dialParameters.go

@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"net"
+	"time"
+)
+
+// RoundTripper provides a request/response round trip network transport with
+// blocking circumvention capabilities. A typical implementation is domain
+// fronted HTTPS. RoundTripper is used by clients and proxies to make
+// requests to brokers.
+type RoundTripper interface {
+	// RoundTrip sends requestPayload and returns the corresponding
+	// responsePayload, or an error.
+	RoundTrip(ctx context.Context, requestPayload []byte) (responsePayload []byte, err error)
+}
+
+// DialParameters provides in-proxy dial parameters and configuration, used by
+// both clients and proxies, and an interface for signaling when parameters
+// are successful or not, to facilitate replay of successful parameters.
+//
+// Each DialParameters should provide values selected in the context of a
+// single network, as identified by a network ID. A distinct DialParameters
+// should be created for each client in-proxy dial, with new or replayed
+// parameters selected as appropriate. One proxy run uses a single
+// DialParameters for all proxied connections. The proxy should be restarted
+// with a new DialParameters when the underlying network changes.
+type DialParameters interface {
+
+	// CommonCompartmentIDs is the list of common, Psiphon-managed, in-proxy
+	// compartment IDs known to a client. These IDs are delivered through
+	// tactics, or embedded in OSLs.
+	//
+	// At most MaxCompartmentIDs may be sent to a broker; if necessary, the
+	// provider may return a subset of known compartment IDs and replay when
+	// the overall dial is a success; and/or retain only the most recently
+	// discovered compartment IDs.
+	//
+	// CommonCompartmentIDs is not called for proxies.
+	CommonCompartmentIDs() []ID
+
+	// PersonalCompartmentIDs are compartment IDs distributed from proxy
+	// operators to client users out-of-band and provide optional access
+	// control. For example, a proxy operator may want to provide access only
+	// to certain users, and/or users want to use only a proxy run by a
+	// certain operator.
+	//
+	// At most MaxCompartmentIDs may be sent to a broker; for typical use
+	// cases, both clients and proxies will specify a single personal
+	// compartment ID.
+	PersonalCompartmentIDs() []ID
+
+	// NetworkID returns the network ID for the network this DialParameters
+	// is associated with. For a single DialParameters, the NetworkID value
+	// should not change. Replay-facilitating calls, Succeeded/Failed, all
+	// assume the network and network ID remain static. The network ID value
+	// is used by in-proxy dials to track internal state that depends on the
+	// current network; this includes the port mapping types supported by the
+	// network.
+	NetworkID() string
+
+	// NetworkType returns the network type for the current network, or
+	// NetworkTypeUnknown if unknown.
+	NetworkType() NetworkType
+
+	// BrokerClientPrivateKey is the client or proxy's private key to be used
+	// in the secure session established with a broker. Clients should
+	// generate ephemeral keys; this is done automatically when a zero-value
+	// SessionPrivateKey is returned. Proxies may generate and persist
+	// long-lived keys to enable traffic attribution to a proxy, identified
+	// by a proxy ID, the corresponding public key.
+	BrokerClientPrivateKey() SessionPrivateKey
+
+	// BrokerPublicKey is the public key for the broker selected by the
+	// provider and reachable via BrokerClientRoundTripper. The broker is
+	// authenticated in the secure session.
+	BrokerPublicKey() SessionPublicKey
+
+	// BrokerRootObfuscationSecret is the root obfuscation secret for the
+	// broker and used in the secure session.
+	BrokerRootObfuscationSecret() ObfuscationSecret
+
+	// BrokerClientRoundTripper returns a RoundTripper to use for broker
+	// requests. The provider handles selecting a broker and broker
+	// addressing, as well as providing a round trip network transport with
+	// blocking circumvention capabilities. A typical implementation is
+	// domain fronted HTTPS. The RoundTripper should offer persistent network
+	// connections and request multiplexing, for example with HTTP/2, so that
+	// a single connection can be used for many concurrent requests.
+	//
+	// Clients and proxies make round trips to establish a secure session with
+	// the broker, on top of the provided transport, and to exchange API
+	// requests with the broker.
+	BrokerClientRoundTripper() (RoundTripper, error)
+
+	// BrokerClientRoundTripperSucceeded is called after a successful round
+	// trip using the specified RoundTripper. This signal is used to set
+	// replay for the round tripper's successful dial parameters.
+	// BrokerClientRoundTripperSucceeded is called once per successful round
+	// trip; the provider can choose to set replay only once.
+	BrokerClientRoundTripperSucceeded(roundTripper RoundTripper)
+
+	// BrokerClientRoundTripperFailed is called after a failed round trip
+	// using the specified RoundTripper. This signal is used to clear replay
+	// for the round tripper's unsuccessful dial parameters. The provider
+	// will arrange for a new RoundTripper to be returned from the next
+	// BrokerClientRoundTripper call, discarding the current RoundTripper
+	// after closing its network resources.
+	BrokerClientRoundTripperFailed(roundTripper RoundTripper)
+
+	// ClientRootObfuscationSecret is the root obfuscation secret generated by
+	// or replayed by the client, which will be used to drive and replay
+	// obfuscation operations for the WebRTC dial, including any DTLS
+	// randomization. The proxy receives the same root obfuscation secret,
+	// relayed by the broker, and so the client's selection drives
+	// obfuscation/replay on both sides.
+	ClientRootObfuscationSecret() ObfuscationSecret
+
+	// DoDTLSRandomization indicates whether to perform DTLS ClientHello
+	// randomization. DoDTLSRandomization is specified by clients, which may
+	// use a weighted coin flip or a replay to determine the value.
+	DoDTLSRandomization() bool
+
+	// STUNServerAddress selects a STUN server to use for this dial. When
+	// RFC5780 is true, the STUN server must support RFC5780 NAT discovery;
+	// otherwise, only basic STUN bind operation support is required. Clients
+	// and proxies will receive a list of STUN server candidates via tactics,
+	// and select a candidate at random or replay for each dial. If
+	// STUNServerAddress returns "", STUN operations are skipped but the dial
+	// may still succeed if a port mapping can be established.
+	STUNServerAddress(RFC5780 bool) string
+
+	// STUNServerAddressSucceeded is called after a successful STUN operation
+	// with the STUN server specified by the address. This signal is used to
+	// set replay for successful STUN servers. STUNServerAddressSucceeded
+	// will be called when the STUN operation succeeds, regardless of the
+	// outcome of the rest of the dial. RFC5780 is true when the STUN server
+	// was used for NAT discovery.
+	STUNServerAddressSucceeded(RFC5780 bool, address string)
+
+	// STUNServerAddressFailed is called after a failed STUN operation and is
+	// used to clear replay for the specified STUN server.
+	STUNServerAddressFailed(RFC5780 bool, address string)
+
+	// DiscoverNAT indicates whether a client dial should start with NAT
+	// discovery. Discovering and reporting the client NAT type will assist
+	// in broker matching. However, RFC5780 NAT discovery can slow down a
+	// dial and potentially looks like atypical network traffic. Client NAT
+	// discovery is controlled by tactics and may be disabled or set to run
+	// with a small probability. Discovered NAT types and portmapping types
+	// may be cached and used with future dials via SetNATType/NATType and
+	// SetPortMappingTypes/PortMappingTypes.
+	//
+	// Proxies always perform NAT discovery on start up, since that doesn't
+	// delay a client dial.
+	DiscoverNAT() bool
+
+	// DisableSTUN indicates whether to skip STUN operations.
+	DisableSTUN() bool
+
+	// DisablePortMapping indicates whether to skip port mapping operations.
+	DisablePortMapping() bool
+
+	// DisableInboundForMobleNetworks indicates that all attempts to set up
+	// inbound operations -- including STUN and port mapping -- should be
+	// skipped when the network type is NetworkTypeMobile. This skips
+	// operations that can slow down dials and are unlikely to succeed on
+	// most mobile networks with CGNAT.
+	DisableInboundForMobleNetworks() bool
+
+	// NATType returns any persisted NAT type for the current network, as set
+	// by SetNATType. When NATTypeUnknown is returned, NAT discovery may be
+	// run.
+	NATType() NATType
+
+	// SetNATType is called when the NAT type for the current network has been
+	// discovered. The provider should persist this value, associated with
+	// the current network ID and with a reasonable TTL, so the value can be
+	// reused in subsequent dials without having to re-run NAT discovery.
+	SetNATType(t NATType)
+
+	// PortMappingTypes returns any persisted, supported port mapping types
+	// for the current network, as set by SetPortMappingTypes. When an empty
+	// list is returned port mapping discovery may be run. A list containing
+	// only PortMappingTypeNone indicates that no supported port mapping
+	// types were discovered.
+	PortMappingTypes() PortMappingTypes
+
+	// SetPortMappingTypes is called with the supported port mapping types
+	// discovered for the current network. The provider should persist this
+	// value, associated with the current network ID and with a reasonable
+	// TTL, so the value can be reused in subsequent dials without having to
+	// re-run port mapping discovery.
+	SetPortMappingTypes(t PortMappingTypes)
+
+	// ResolveAddress resolves a domain and returns its IP address. Clients
+	// and proxies may use this to hook into the Psiphon custom resolver. The
+	// provider adds the custom resolver tactics and network ID parameters
+	// required by psiphon/common.Resolver.
+	ResolveAddress(ctx context.Context, address string) (string, error)
+
+	// UDPListen dials a local UDP socket. The socket should be bound to a
+	// specific interface as required for VPN modes, and set a write timeout
+	// to mitigate the issue documented in psiphon/common.WriteTimeoutUDPConn.
+	UDPListen() (net.PacketConn, error)
+
+	// BindToDevice binds a socket, specified by the file descriptor, to an
+	// interface that isn't routed through a VPN when Psiphon is running in
+	// VPN mode. BindToDevice is used in cases where a custom dialer cannot
+	// be used, and UDPListen cannot be called.
+	BindToDevice(fileDescriptor int) error
+
+	// The following methods return tunable timeouts, retry delays, and retry
+	// jitter factors. In the client dial, OfferRequestTimeout,
+	// OfferRetryDelay, and OfferRetryJitter are passed through
+	// common.ValueOrDefault along with package default constants.
+	// NOTE(review): presumably a zero value selects the package default --
+	// confirm against common.ValueOrDefault and the callers of the remaining
+	// methods.
+	DiscoverNATTimeout() time.Duration
+	OfferRequestTimeout() time.Duration
+	OfferRetryDelay() time.Duration
+	OfferRetryJitter() float64
+	AnnounceRequestTimeout() time.Duration
+	AnnounceRetryDelay() time.Duration
+	AnnounceRetryJitter() float64
+	WebRTCAnswerTimeout() time.Duration
+	AnswerRequestTimeout() time.Duration
+	ProxyClientConnectTimeout() time.Duration
+	ProxyDestinationDialTimeout() time.Duration
+}

+ 394 - 0
psiphon/common/inproxy/dialParameters_test.go

@@ -0,0 +1,394 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace"
+)
+
+// testDialParameters is a test double whose methods mirror the DialParameters
+// provider interface. Each accessor returns the corresponding field, and the
+// callback-style methods forward to the func fields configured by the test.
+type testDialParameters struct {
+	// mutex is held by the accessor methods while they read or write fields
+	// and while they invoke the configured callbacks.
+	mutex                             sync.Mutex
+	commonCompartmentIDs              []ID
+	personalCompartmentIDs            []ID
+	networkID                         string
+	networkType                       NetworkType
+	brokerClientPrivateKey            SessionPrivateKey
+	brokerPublicKey                   SessionPublicKey
+	brokerRootObfuscationSecret       ObfuscationSecret
+	brokerClientRoundTripper          RoundTripper
+	brokerClientRoundTripperSucceeded func(RoundTripper)
+	brokerClientRoundTripperFailed    func(RoundTripper)
+	clientRootObfuscationSecret       ObfuscationSecret
+	doDTLSRandomization               bool
+	// stunServerAddress and stunServerAddressRFC5780 are returned by
+	// STUNServerAddress for RFC5780 == false and RFC5780 == true respectively.
+	stunServerAddress                 string
+	stunServerAddressRFC5780          string
+	stunServerAddressSucceeded        func(RFC5780 bool, address string)
+	stunServerAddressFailed           func(RFC5780 bool, address string)
+	discoverNAT                       bool
+	disableSTUN                       bool
+	disablePortMapping                bool
+	// The "Moble" spelling matches the DisableInboundForMobleNetworks method
+	// name.
+	disableInboundForMobleNetworks    bool
+	// natType and portMappingTypes are overwritten by SetNATType and
+	// SetPortMappingTypes, which also invoke setNATType and
+	// setPortMappingTypes.
+	natType                           NATType
+	setNATType                        func(NATType)
+	portMappingTypes                  PortMappingTypes
+	setPortMappingTypes               func(PortMappingTypes)
+	bindToDevice                      func(int) error
+	discoverNATTimeout                time.Duration
+	offerRequestTimeout               time.Duration
+	offerRetryDelay                   time.Duration
+	offerRetryJitter                  float64
+	announceRequestTimeout            time.Duration
+	announceRetryDelay                time.Duration
+	announceRetryJitter               float64
+	webRTCAnswerTimeout               time.Duration
+	answerRequestTimeout              time.Duration
+	proxyClientConnectTimeout         time.Duration
+	proxyDestinationDialTimeout       time.Duration
+}
+
+// CommonCompartmentIDs returns the configured common compartment IDs. The
+// returned slice is the stored slice, not a copy.
+func (t *testDialParameters) CommonCompartmentIDs() []ID {
+	t.mutex.Lock()
+	ids := t.commonCompartmentIDs
+	t.mutex.Unlock()
+	return ids
+}
+
+// PersonalCompartmentIDs returns the configured personal compartment IDs. The
+// returned slice is the stored slice, not a copy.
+func (t *testDialParameters) PersonalCompartmentIDs() []ID {
+	t.mutex.Lock()
+	ids := t.personalCompartmentIDs
+	t.mutex.Unlock()
+	return ids
+}
+
+// NetworkID returns the configured network ID.
+func (t *testDialParameters) NetworkID() string {
+	t.mutex.Lock()
+	networkID := t.networkID
+	t.mutex.Unlock()
+	return networkID
+}
+
+// NetworkType returns the configured network type.
+func (t *testDialParameters) NetworkType() NetworkType {
+	t.mutex.Lock()
+	networkType := t.networkType
+	t.mutex.Unlock()
+	return networkType
+}
+
+// BrokerClientPrivateKey returns the configured broker client session
+// private key.
+func (t *testDialParameters) BrokerClientPrivateKey() SessionPrivateKey {
+	t.mutex.Lock()
+	privateKey := t.brokerClientPrivateKey
+	t.mutex.Unlock()
+	return privateKey
+}
+
+// BrokerPublicKey returns the configured broker session public key.
+func (t *testDialParameters) BrokerPublicKey() SessionPublicKey {
+	t.mutex.Lock()
+	publicKey := t.brokerPublicKey
+	t.mutex.Unlock()
+	return publicKey
+}
+
+// BrokerRootObfuscationSecret returns the configured broker root obfuscation
+// secret.
+func (t *testDialParameters) BrokerRootObfuscationSecret() ObfuscationSecret {
+	t.mutex.Lock()
+	secret := t.brokerRootObfuscationSecret
+	t.mutex.Unlock()
+	return secret
+}
+
+// BrokerClientRoundTripper returns the configured broker round tripper. The
+// returned error is always nil.
+func (t *testDialParameters) BrokerClientRoundTripper() (RoundTripper, error) {
+	t.mutex.Lock()
+	roundTripper := t.brokerClientRoundTripper
+	t.mutex.Unlock()
+	return roundTripper, nil
+}
+
+// BrokerClientRoundTripperSucceeded forwards the round tripper to the
+// brokerClientRoundTripperSucceeded callback, which is invoked while the
+// mutex is held.
+func (t *testDialParameters) BrokerClientRoundTripperSucceeded(roundTripper RoundTripper) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	succeeded := t.brokerClientRoundTripperSucceeded
+	succeeded(roundTripper)
+}
+
+// BrokerClientRoundTripperFailed forwards the round tripper to the
+// brokerClientRoundTripperFailed callback, which is invoked while the mutex
+// is held.
+func (t *testDialParameters) BrokerClientRoundTripperFailed(roundTripper RoundTripper) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	failed := t.brokerClientRoundTripperFailed
+	failed(roundTripper)
+}
+
+// ClientRootObfuscationSecret returns the configured client root obfuscation
+// secret.
+func (t *testDialParameters) ClientRootObfuscationSecret() ObfuscationSecret {
+	t.mutex.Lock()
+	secret := t.clientRootObfuscationSecret
+	t.mutex.Unlock()
+	return secret
+}
+
+// DoDTLSRandomization returns the configured DTLS randomization flag.
+func (t *testDialParameters) DoDTLSRandomization() bool {
+	t.mutex.Lock()
+	enabled := t.doDTLSRandomization
+	t.mutex.Unlock()
+	return enabled
+}
+
+// STUNServerAddress returns the configured RFC5780 STUN server address when
+// RFC5780 is true, and the regular STUN server address otherwise.
+func (t *testDialParameters) STUNServerAddress(RFC5780 bool) string {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	address := t.stunServerAddress
+	if RFC5780 {
+		address = t.stunServerAddressRFC5780
+	}
+	return address
+}
+
+// STUNServerAddressSucceeded forwards its arguments to the
+// stunServerAddressSucceeded callback, which is invoked while the mutex is
+// held.
+func (t *testDialParameters) STUNServerAddressSucceeded(RFC5780 bool, address string) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	succeeded := t.stunServerAddressSucceeded
+	succeeded(RFC5780, address)
+}
+
+// STUNServerAddressFailed forwards its arguments to the
+// stunServerAddressFailed callback, which is invoked while the mutex is held.
+func (t *testDialParameters) STUNServerAddressFailed(RFC5780 bool, address string) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	failed := t.stunServerAddressFailed
+	failed(RFC5780, address)
+}
+
+// DiscoverNAT returns the configured NAT discovery flag.
+func (t *testDialParameters) DiscoverNAT() bool {
+	t.mutex.Lock()
+	enabled := t.discoverNAT
+	t.mutex.Unlock()
+	return enabled
+}
+
+// DisableSTUN returns the configured flag that disables STUN.
+func (t *testDialParameters) DisableSTUN() bool {
+	t.mutex.Lock()
+	disabled := t.disableSTUN
+	t.mutex.Unlock()
+	return disabled
+}
+
+// DisablePortMapping returns the configured flag that disables port mapping.
+func (t *testDialParameters) DisablePortMapping() bool {
+	t.mutex.Lock()
+	disabled := t.disablePortMapping
+	t.mutex.Unlock()
+	return disabled
+}
+
+// DisableInboundForMobleNetworks returns the configured
+// disableInboundForMobleNetworks flag.
+func (t *testDialParameters) DisableInboundForMobleNetworks() bool {
+	t.mutex.Lock()
+	disabled := t.disableInboundForMobleNetworks
+	t.mutex.Unlock()
+	return disabled
+}
+
+// NATType returns the current NAT type, which is either the initially
+// configured value or the value most recently stored by SetNATType.
+func (t *testDialParameters) NATType() NATType {
+	t.mutex.Lock()
+	natType := t.natType
+	t.mutex.Unlock()
+	return natType
+}
+
+// SetNATType stores natType, so that later NATType calls return it, and then
+// invokes the setNATType callback while the mutex is still held.
+func (t *testDialParameters) SetNATType(natType NATType) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	t.natType = natType
+	notify := t.setNATType
+	notify(natType)
+}
+
+// PortMappingTypes returns the current port mapping types, which is either
+// the initially configured value or the copy most recently stored by
+// SetPortMappingTypes.
+func (t *testDialParameters) PortMappingTypes() PortMappingTypes {
+	t.mutex.Lock()
+	portMappingTypes := t.portMappingTypes
+	t.mutex.Unlock()
+	return portMappingTypes
+}
+
+// SetPortMappingTypes stores a copy of portMappingTypes, so that later
+// PortMappingTypes calls return it, and then invokes the setPortMappingTypes
+// callback with the original argument while the mutex is still held.
+func (t *testDialParameters) SetPortMappingTypes(portMappingTypes PortMappingTypes) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+
+	// Store a private, non-nil copy so the caller's slice is not retained.
+	stored := make(PortMappingTypes, len(portMappingTypes))
+	copy(stored, portMappingTypes)
+	t.portMappingTypes = stored
+
+	notify := t.setPortMappingTypes
+	notify(portMappingTypes)
+}
+
+// ResolveAddress returns address unchanged with a nil error; ctx is not
+// used.
+func (t *testDialParameters) ResolveAddress(ctx context.Context, address string) (string, error) {
+	// No hostnames are resolved in the test.
+	return address, nil
+}
+
+// UDPListen opens a UDP socket via net.ListenUDP with a nil local address
+// and returns it as a net.PacketConn.
+//
+// On failure a literal nil net.PacketConn is returned. Returning the
+// *net.UDPConn result directly would wrap a nil pointer in a non-nil
+// interface value, which a caller comparing the conn against nil would
+// mistake for a usable socket.
+func (t *testDialParameters) UDPListen() (net.PacketConn, error) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	conn, err := net.ListenUDP("udp", nil)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	return conn, nil
+}
+
+// BindToDevice forwards fileDescriptor to the bindToDevice callback, which is
+// invoked while the mutex is held, and returns its result passed through
+// errors.Trace.
+func (t *testDialParameters) BindToDevice(fileDescriptor int) error {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	err := t.bindToDevice(fileDescriptor)
+	return errors.Trace(err)
+}
+
+// DiscoverNATTimeout returns the configured NAT discovery timeout.
+func (t *testDialParameters) DiscoverNATTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.discoverNATTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// OfferRequestTimeout returns the configured offer request timeout.
+func (t *testDialParameters) OfferRequestTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.offerRequestTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// OfferRetryDelay returns the configured offer retry delay.
+func (t *testDialParameters) OfferRetryDelay() time.Duration {
+	t.mutex.Lock()
+	delay := t.offerRetryDelay
+	t.mutex.Unlock()
+	return delay
+}
+
+// OfferRetryJitter returns the configured offer retry jitter.
+func (t *testDialParameters) OfferRetryJitter() float64 {
+	t.mutex.Lock()
+	jitter := t.offerRetryJitter
+	t.mutex.Unlock()
+	return jitter
+}
+
+// AnnounceRequestTimeout returns the configured announce request timeout.
+func (t *testDialParameters) AnnounceRequestTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.announceRequestTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// AnnounceRetryDelay returns the configured announce retry delay.
+func (t *testDialParameters) AnnounceRetryDelay() time.Duration {
+	t.mutex.Lock()
+	delay := t.announceRetryDelay
+	t.mutex.Unlock()
+	return delay
+}
+
+// AnnounceRetryJitter returns the configured announce retry jitter.
+func (t *testDialParameters) AnnounceRetryJitter() float64 {
+	t.mutex.Lock()
+	jitter := t.announceRetryJitter
+	t.mutex.Unlock()
+	return jitter
+}
+
+// WebRTCAnswerTimeout returns the configured WebRTC answer timeout.
+func (t *testDialParameters) WebRTCAnswerTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.webRTCAnswerTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// AnswerRequestTimeout returns the configured answer request timeout.
+func (t *testDialParameters) AnswerRequestTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.answerRequestTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// ProxyClientConnectTimeout returns the configured proxy client connect
+// timeout.
+func (t *testDialParameters) ProxyClientConnectTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.proxyClientConnectTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// ProxyDestinationDialTimeout returns the configured proxy destination dial
+// timeout.
+func (t *testDialParameters) ProxyDestinationDialTimeout() time.Duration {
+	t.mutex.Lock()
+	timeout := t.proxyDestinationDialTimeout
+	t.mutex.Unlock()
+	return timeout
+}
+
+// testLogger is a logger for tests that prints log and metric lines to
+// stdout via fmt.Printf.
+type testLogger struct {
+	// logLevelDebug is 1 when debug logging is enabled and 0 otherwise. It is
+	// read and written with sync/atomic by IsLogLevelDebug and
+	// SetLogLevelDebug.
+	logLevelDebug int32
+}
+
+// newTestLogger returns a testLogger with debug logging enabled.
+func newTestLogger() *testLogger {
+	logger := new(testLogger)
+	logger.logLevelDebug = 1
+	return logger
+}
+
+// WithTrace returns a log trace, without fields, labelled with the value of
+// stacktrace.GetParentFunctionName.
+func (logger *testLogger) WithTrace() common.LogTrace {
+	// GetParentFunctionName is called directly in this method, as before, so
+	// the call stack it inspects is unchanged.
+	trace := stacktrace.GetParentFunctionName()
+	return &testLoggerTrace{logger: logger, trace: trace}
+}
+
+// WithTraceFields returns a log trace carrying the given fields, labelled
+// with the value of stacktrace.GetParentFunctionName.
+func (logger *testLogger) WithTraceFields(fields common.LogFields) common.LogTrace {
+	// GetParentFunctionName is called directly in this method, as before, so
+	// the call stack it inspects is unchanged.
+	trace := stacktrace.GetParentFunctionName()
+	return &testLoggerTrace{logger: logger, trace: trace, fields: fields}
+}
+
+// LogMetric prints a timestamped METRIC line containing the metric name and
+// the JSON encoding of fields. A JSON encoding error is ignored.
+func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
+	encodedFields, _ := json.Marshal(fields)
+	timestamp := time.Now().UTC().Format(time.RFC3339)
+	fmt.Printf(
+		"[%s] METRIC: %s: %s\n",
+		timestamp, metric, string(encodedFields))
+}
+
+// IsLogLevelDebug reports whether debug logging is enabled.
+func (logger *testLogger) IsLogLevelDebug() bool {
+	level := atomic.LoadInt32(&logger.logLevelDebug)
+	return level == 1
+}
+
+// SetLogLevelDebug enables or disables debug logging by atomically storing 1
+// or 0 in logLevelDebug.
+func (logger *testLogger) SetLogLevelDebug(logLevelDebug bool) {
+	if logLevelDebug {
+		atomic.StoreInt32(&logger.logLevelDebug, 1)
+		return
+	}
+	atomic.StoreInt32(&logger.logLevelDebug, 0)
+}
+
+// testLoggerTrace is the log trace returned by testLogger.WithTrace and
+// testLogger.WithTraceFields.
+type testLoggerTrace struct {
+	// logger is the parent testLogger, consulted for the debug log level.
+	logger *testLogger
+	// trace is the function name label printed with each log line.
+	trace  string
+	// fields are optional; when non-empty they are appended to each log line
+	// as JSON.
+	fields common.LogFields
+}
+
+func (logger *testLoggerTrace) log(priority, message string) {
+	now := time.Now().UTC().Format(time.RFC3339)
+	if len(logger.fields) == 0 {
+		fmt.Printf(
+			"[%s] %s: %s: %s\n",
+			now, priority, logger.trace, message)
+	} else {
+		fields := common.LogFields{}
+		for k, v := range logger.fields {
+			switch v := v.(type) {
+			case error:
+				// Workaround for Go issue 5161: error types marshal to "{}"
+				fields[k] = v.Error()
+			default:
+				fields[k] = v
+			}
+		}
+		jsonFields, _ := json.Marshal(fields)
+		fmt.Printf(
+			"[%s] %s: %s: %s %s\n",
+			now, priority, logger.trace, message, string(jsonFields))
+	}
+}
+
+// Debug logs args at DEBUG priority. It does nothing, and does not format
+// args, unless the parent logger has debug logging enabled.
+func (logger *testLoggerTrace) Debug(args ...interface{}) {
+	if logger.logger.IsLogLevelDebug() {
+		message := fmt.Sprint(args...)
+		logger.log("DEBUG", message)
+	}
+}
+
+// Info logs args at INFO priority.
+func (logger *testLoggerTrace) Info(args ...interface{}) {
+	message := fmt.Sprint(args...)
+	logger.log("INFO", message)
+}
+
+// Warning logs args at WARNING priority.
+func (logger *testLoggerTrace) Warning(args ...interface{}) {
+	message := fmt.Sprint(args...)
+	logger.log("WARNING", message)
+}
+
+// Error logs args at ERROR priority.
+func (logger *testLoggerTrace) Error(args ...interface{}) {
+	message := fmt.Sprint(args...)
+	logger.log("ERROR", message)
+}

+ 422 - 0
psiphon/common/inproxy/discovery.go

@@ -0,0 +1,422 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/pion/stun"
+)
+
+const (
+	// discoverNATTimeout is the fallback overall discovery timeout that
+	// NATDiscover passes to common.ValueOrDefault alongside
+	// DialParameters.DiscoverNATTimeout.
+	discoverNATTimeout          = 10 * time.Second
+	// discoverNATRoundTripTimeout is the read deadline doSTUNRoundTrip applies
+	// while waiting for a STUN response.
+	discoverNATRoundTripTimeout = 2 * time.Second
+)
+
+// NATDiscoverConfig specifies the configuration for a NATDiscover run.
+type NATDiscoverConfig struct {
+
+	// Logger is used to log events.
+	Logger common.Logger
+
+	// DialParameters specifies the STUN and discovery settings, and receives
+	// the discovery results.
+	DialParameters DialParameters
+
+	// SkipPortMapping indicates whether to skip port mapping type discovery,
+	// as clients do since they will gather the same stats during the WebRTC
+	// offer preparation.
+	SkipPortMapping bool
+}
+
+// NATDiscover runs NAT type and port mapping type discovery operations.
+//
+// Successfuly results are delivered to NATDiscoverConfig.DialParameters
+// callbacks, SetNATType and SetPortMappingTypes, which should cache results
+// associated with the current network, by network ID.
+//
+// NAT discovery will invoke DialParameter callbacks
+// STUNServerAddressSucceeded and STUNServerAddressFailed, which may be used
+// to mark or unmark STUN servers for replay.
+func NATDiscover(
+	ctx context.Context,
+	config *NATDiscoverConfig) {
+
+	// Run discovery until the specified timeout, or ctx is done. NAT and port
+	// mapping discovery are run concurrently.
+
+	discoverCtx, cancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(config.DialParameters.DiscoverNATTimeout(), discoverNATTimeout))
+	defer cancelFunc()
+
+	discoveryWaitGroup := new(sync.WaitGroup)
+
+	if config.DialParameters.NATType().NeedsDiscovery() &&
+		!config.DialParameters.DisableSTUN() {
+
+		discoveryWaitGroup.Add(1)
+		go func() {
+			defer discoveryWaitGroup.Done()
+
+			natType, err := discoverNATType(discoverCtx, config)
+
+			if err == nil {
+				// Deliver the result. The DialParameters provider may cache
+				// this result, associated wih the current networkID.
+				config.DialParameters.SetNATType(natType)
+			}
+
+			config.Logger.WithTraceFields(common.LogFields{
+				"nat_type": natType.String(),
+				"error":    err,
+			}).Info("NAT type discovery")
+
+		}()
+	}
+
+	if !config.SkipPortMapping &&
+		config.DialParameters.PortMappingTypes().NeedsDiscovery() &&
+		!config.DialParameters.DisablePortMapping() {
+
+		discoveryWaitGroup.Add(1)
+		go func() {
+			defer discoveryWaitGroup.Done()
+
+			portMappingTypes, err := discoverPortMappingTypes(
+				discoverCtx, config.Logger)
+
+			if err == nil {
+				// Deliver the result. The DialParameters provider may cache
+				// this result, associated wih the current networkID.
+				config.DialParameters.SetPortMappingTypes(portMappingTypes)
+			}
+
+			config.Logger.WithTraceFields(common.LogFields{
+				"port_mapping_types": portMappingTypes.String(),
+				"error":              err,
+			}).Info("Port mapping type discovery")
+
+		}()
+	}
+
+	discoveryWaitGroup.Wait()
+}
+
+// discoverNATType runs the NAT mapping and NAT filtering tests against the
+// RFC5780 STUN server supplied by config.DialParameters and combines the two
+// results with MakeNATType.
+//
+// The tests run in a goroutine over a socket obtained from UDPListen. When
+// ctx is done first, the socket is closed to interrupt the goroutine and the
+// context error is returned. The STUN server is reported via
+// STUNServerAddressSucceeded on success, and via STUNServerAddressFailed on a
+// failure that occurs while ctx is not yet done.
+func discoverNATType(
+	ctx context.Context,
+	config *NATDiscoverConfig) (NATType, error) {
+
+	RFC5780 := true
+	stunServerAddress := config.DialParameters.STUNServerAddress(RFC5780)
+
+	if stunServerAddress == "" {
+		return NATTypeUnknown, errors.TraceNew("no RFC5780 STUN server")
+	}
+
+	// The STUN server will observe proxy IP addresses. Enumeration is
+	// mitigated by using various public STUN servers, including Psiphon STUN
+	// servers for proxies in non-censored regions. Proxies are also more
+	// ephemeral than Psiphon servers.
+
+	// Limitation: RFC5780, "4.1. Source Port Selection" recommends using the
+	// same source port for NAT discovery _and_ subsequent NAT traversal
+	// applications, such as WebRTC ICE. It's stated that the discovered NAT
+	// type may only be valid for the particular tested port.
+	//
+	// We don't do this at this time, as we don't want to incur the full
+	// RFC5780 discovery overhead for every WebRTC dial, and expect that, in
+	// most typical cases, the network NAT type applies to all ports.
+	// Furthermore, the UDP conn that owns the tested port may need to be
+	// closed to interrupt discovery.
+
+	conn, err := config.DialParameters.UDPListen()
+	if err != nil {
+		return NATTypeUnknown, errors.Trace(err)
+	}
+	defer conn.Close()
+
+	type result struct {
+		NATType NATType
+		err     error
+	}
+
+	// The channel is buffered so the goroutine's single send never blocks.
+	resultChannel := make(chan result, 1)
+
+	go func() {
+
+		serverAddress, err := config.DialParameters.ResolveAddress(
+			ctx, stunServerAddress)
+		if err != nil {
+			resultChannel <- result{err: errors.Trace(err)}
+			return
+		}
+
+		mapping, err := discoverNATMapping(ctx, conn, serverAddress)
+		if err != nil {
+			resultChannel <- result{err: errors.Trace(err)}
+			return
+		}
+
+		filtering, err := discoverNATFiltering(ctx, conn, serverAddress)
+		if err != nil {
+			resultChannel <- result{err: errors.Trace(err)}
+			return
+		}
+
+		resultChannel <- result{NATType: MakeNATType(mapping, filtering)}
+	}()
+
+	var outcome result
+	select {
+	case outcome = <-resultChannel:
+	case <-ctx.Done():
+		outcome.err = errors.Trace(ctx.Err())
+		// Closing the conn interrupts the goroutine; wait for it to finish
+		// and discard whatever it reports.
+		conn.Close()
+		<-resultChannel
+	}
+
+	if outcome.err == nil {
+		config.DialParameters.STUNServerAddressSucceeded(RFC5780, stunServerAddress)
+		return outcome.NATType, nil
+	}
+
+	// Report the server as failed only while ctx is not yet done.
+	if ctx.Err() == nil {
+		config.DialParameters.STUNServerAddressFailed(RFC5780, stunServerAddress)
+	}
+
+	return NATTypeUnknown, errors.Trace(outcome.err)
+}
+
+// discoverNATMapping and discoverNATFiltering are modifications of:
+// https://github.com/pion/stun/blob/b321a45be43b07685c639943aaa28e6841517799/cmd/stun-nat-behaviour/main.go
+
+// https://github.com/pion/stun/blob/b321a45be43b07685c639943aaa28e6841517799/LICENSE.md:
+/*
+Copyright 2018 Pion LLC
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
+
+// RFC5780: 4.3.  Determining NAT Mapping Behavior
+//
+// discoverNATMapping sends up to three binding requests from conn and
+// compares the XOR-MAPPED-ADDRESS values reported in the responses: first to
+// serverAddress, then to the server's other IP with the primary port, then to
+// the server's other IP and port.
+//
+// An error is returned when a round trip fails, when the first response lacks
+// XOR-MAPPED-ADDRESS or OTHER-ADDRESS, or when a later response lacks
+// XOR-MAPPED-ADDRESS. ctx is not consulted; the round trips are interrupted
+// by read deadlines or by closing conn.
+func discoverNATMapping(
+	ctx context.Context,
+	conn net.PacketConn,
+	serverAddress string) (NATMapping, error) {
+
+	// Test I: Regular binding request
+
+	request := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
+
+	response, _, err := doSTUNRoundTrip(request, conn, serverAddress)
+	if err != nil {
+		return NATMappingUnknown, errors.Trace(err)
+	}
+	responseFields := parseSTUNMessage(response)
+	if responseFields.xorAddr == nil || responseFields.otherAddr == nil {
+		return NATMappingUnknown, errors.TraceNew("NAT discovery not supported")
+	}
+	if responseFields.xorAddr.String() == conn.LocalAddr().String() {
+		return NATMappingEndpointIndependent, nil
+	}
+
+	otherAddress := responseFields.otherAddr
+
+	// Test II: Send binding request to the other address but primary port
+
+	_, serverPort, err := net.SplitHostPort(serverAddress)
+	if err != nil {
+		return NATMappingUnknown, errors.Trace(err)
+	}
+
+	address := net.JoinHostPort(otherAddress.IP.String(), serverPort)
+	response2, _, err := doSTUNRoundTrip(request, conn, address)
+	if err != nil {
+		return NATMappingUnknown, errors.Trace(err)
+	}
+	response2Fields := parseSTUNMessage(response2)
+	if response2Fields.xorAddr == nil {
+		// Without this check, String would be called on a nil pointer below.
+		return NATMappingUnknown, errors.TraceNew(
+			"missing XOR-MAPPED-ADDRESS in STUN response")
+	}
+	if response2Fields.xorAddr.String() == responseFields.xorAddr.String() {
+		return NATMappingEndpointIndependent, nil
+	}
+
+	// Test III: Send binding request to the other address and port
+
+	response3, _, err := doSTUNRoundTrip(request, conn, otherAddress.String())
+	if err != nil {
+		return NATMappingUnknown, errors.Trace(err)
+	}
+	response3Fields := parseSTUNMessage(response3)
+	if response3Fields.xorAddr == nil {
+		return NATMappingUnknown, errors.TraceNew(
+			"missing XOR-MAPPED-ADDRESS in STUN response")
+	}
+	if response3Fields.xorAddr.String() == response2Fields.xorAddr.String() {
+		return NATMappingAddressDependent, nil
+	}
+
+	return NATMappingAddressPortDependent, nil
+}
+
+// RFC5780: 4.4.  Determining NAT Filtering Behavior
+//
+// discoverNATFiltering first checks that the server's binding response
+// carries both XOR-MAPPED-ADDRESS and OTHER-ADDRESS, then sends binding
+// requests with CHANGE-REQUEST attributes. A response to a change request
+// classifies the filtering behavior; a timeout moves on to the next test; any
+// other failure is returned as an error. ctx is not consulted; the round
+// trips are interrupted by read deadlines or by closing conn.
+func discoverNATFiltering(
+	ctx context.Context,
+	conn net.PacketConn,
+	serverAddress string) (NATFiltering, error) {
+
+	// Test I: Regular binding request
+
+	request := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
+	response, _, err := doSTUNRoundTrip(request, conn, serverAddress)
+	if err != nil {
+		return NATFilteringUnknown, errors.Trace(err)
+	}
+	fields := parseSTUNMessage(response)
+	if fields.xorAddr == nil || fields.otherAddr == nil {
+		return NATFilteringUnknown, errors.TraceNew("NAT discovery not supported")
+	}
+
+	// Test II: Request to change both IP and port
+
+	request = stun.MustBuild(stun.TransactionID, stun.BindingRequest)
+	request.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x06})
+
+	_, responseTimeout, err := doSTUNRoundTrip(request, conn, serverAddress)
+	if err == nil {
+		return NATFilteringEndpointIndependent, nil
+	}
+	if !responseTimeout {
+		return NATFilteringUnknown, errors.Trace(err)
+	}
+
+	// Test III: Request to change port only
+
+	request = stun.MustBuild(stun.TransactionID, stun.BindingRequest)
+	request.Add(stun.AttrChangeRequest, []byte{0x00, 0x00, 0x00, 0x02})
+
+	_, responseTimeout, err = doSTUNRoundTrip(request, conn, serverAddress)
+	if err == nil {
+		return NATFilteringAddressDependent, nil
+	}
+	if !responseTimeout {
+		return NATFilteringUnknown, errors.Trace(err)
+	}
+
+	// Neither change request was answered.
+	return NATFilteringAddressPortDependent, nil
+}
+
+// parseSTUNMessage reads the XOR-MAPPED-ADDRESS, OTHER-ADDRESS,
+// RESPONSE-ORIGIN, MAPPED-ADDRESS and SOFTWARE attributes from message. Each
+// field of the result is nil when GetFrom fails for that attribute.
+func parseSTUNMessage(message *stun.Message) (ret struct {
+	xorAddr    *stun.XORMappedAddress
+	otherAddr  *stun.OtherAddress
+	respOrigin *stun.ResponseOrigin
+	mappedAddr *stun.MappedAddress
+	software   *stun.Software
+},
+) {
+	if xorAddr := new(stun.XORMappedAddress); xorAddr.GetFrom(message) == nil {
+		ret.xorAddr = xorAddr
+	}
+	if otherAddr := new(stun.OtherAddress); otherAddr.GetFrom(message) == nil {
+		ret.otherAddr = otherAddr
+	}
+	if respOrigin := new(stun.ResponseOrigin); respOrigin.GetFrom(message) == nil {
+		ret.respOrigin = respOrigin
+	}
+	if mappedAddr := new(stun.MappedAddress); mappedAddr.GetFrom(message) == nil {
+		ret.mappedAddr = mappedAddr
+	}
+	if software := new(stun.Software); software.GetFrom(message) == nil {
+		ret.software = software
+	}
+	return ret
+}
+
+// doSTUNRoundTrip assigns request a new transaction ID, sends it to
+// remoteAddress over conn, and waits up to discoverNATRoundTripTimeout for
+// the response carrying the same transaction ID.
+//
+// The bool result is true only when the read timed out; in that case the
+// returned message is nil and the error is non-nil. All other failures return
+// nil, false and a non-nil error.
+//
+// Decoded STUN messages with a different transaction ID are skipped, so a
+// late reply to an earlier, timed-out request cannot be mistaken for the
+// reply to this one. A datagram that fails to decode is still returned as an
+// error rather than skipped, so a malformed reply is not reported as a
+// timeout.
+func doSTUNRoundTrip(
+	request *stun.Message,
+	conn net.PacketConn,
+	remoteAddress string) (*stun.Message, bool, error) {
+
+	remoteAddr, err := net.ResolveUDPAddr("udp", remoteAddress)
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	err = request.NewTransactionID()
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	_, err = conn.WriteTo(request.Raw, remoteAddr)
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	// The deadline is absolute, so it bounds the whole read loop below.
+	err = conn.SetReadDeadline(time.Now().Add(discoverNATRoundTripTimeout))
+	if err != nil {
+		return nil, false, errors.Trace(err)
+	}
+
+	for {
+		var buffer [1500]byte
+		n, _, err := conn.ReadFrom(buffer[:])
+		if err != nil {
+			if e, ok := err.(net.Error); ok && e.Timeout() {
+				return nil, true, errors.Trace(err)
+			}
+			return nil, false, errors.Trace(err)
+		}
+
+		response := new(stun.Message)
+		response.Raw = buffer[:n]
+		err = response.Decode()
+		if err != nil {
+			return nil, false, errors.Trace(err)
+		}
+
+		if response.TransactionID != request.TransactionID {
+			// Not the reply to this request; keep waiting until the deadline.
+			continue
+		}
+
+		return response, false, nil
+	}
+}
+
+// discoverPortMappingTypes returns the port mapping types reported by
+// probePortMapping, or nil and a traced error when the probe fails.
+func discoverPortMappingTypes(
+	ctx context.Context,
+	logger common.Logger) (PortMappingTypes, error) {
+
+	discovered, err := probePortMapping(ctx, logger)
+	if err == nil {
+		return discovered, nil
+	}
+
+	return nil, errors.Trace(err)
+}

+ 130 - 0
psiphon/common/inproxy/discovery_test.go

@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"sync/atomic"
+	"testing"
+)
+
+// TestNATDiscovery runs NATDiscover three times against a public STUN server
+// and checks how often each DialParameters callback has been invoked: with
+// port mapping disabled, with STUN disabled, and with both enabled.
+//
+// The test needs network access to the STUN server named below.
+func TestNATDiscovery(t *testing.T) {
+
+	// TODO: run local STUN and port mapping servers to test against, along
+	// with iptables rules to simulate NAT conditions
+
+	stunServerAddress := "stun.nextcloud.com:443"
+
+	var setNATTypeCallCount,
+		setPortMappingTypesCallCount,
+		stunServerAddressSucceededCallCount,
+		stunServerAddressFailedCallCount int32
+
+	dialParams := &testDialParameters{
+		stunServerAddress:        stunServerAddress,
+		stunServerAddressRFC5780: stunServerAddress,
+
+		setNATType: func(NATType) {
+			atomic.AddInt32(&setNATTypeCallCount, 1)
+		},
+
+		setPortMappingTypes: func(PortMappingTypes) {
+			atomic.AddInt32(&setPortMappingTypesCallCount, 1)
+		},
+
+		stunServerAddressSucceeded: func(RFC5780 bool, address string) {
+			atomic.AddInt32(&stunServerAddressSucceededCallCount, 1)
+			if address != stunServerAddress {
+				t.Errorf("unexpected STUN server address")
+			}
+		},
+
+		stunServerAddressFailed: func(RFC5780 bool, address string) {
+			atomic.AddInt32(&stunServerAddressFailedCallCount, 1)
+			if address != stunServerAddress {
+				t.Errorf("unexpected STUN server address")
+			}
+		},
+	}
+
+	// checkCallCounts compares the cumulative callback call counts with the
+	// expected values, given in the order setNATType, setPortMappingTypes,
+	// stunServerAddressSucceeded, stunServerAddressFailed.
+	checkCallCounts := func(a, b, c, d int32) {
+		t.Helper()
+
+		checks := []struct {
+			name     string
+			counter  *int32
+			expected int32
+		}{
+			{"setNATType", &setNATTypeCallCount, a},
+			{"setPortMappingTypes", &setPortMappingTypesCallCount, b},
+			{"stunServerAddressSucceeded", &stunServerAddressSucceededCallCount, c},
+			{"stunServerAddressFailed", &stunServerAddressFailedCallCount, d},
+		}
+
+		for _, check := range checks {
+			callCount := atomic.LoadInt32(check.counter)
+			if callCount != check.expected {
+				t.Errorf(
+					"unexpected %s call count: got %d, expected %d",
+					check.name, callCount, check.expected)
+			}
+		}
+	}
+
+	config := &NATDiscoverConfig{
+		Logger:         newTestLogger(),
+		DialParameters: dialParams,
+	}
+
+	// Should do STUN only
+
+	dialParams.disablePortMapping = true
+
+	NATDiscover(context.Background(), config)
+
+	checkCallCounts(1, 0, 1, 0)
+
+	// Should do port mapping only
+
+	dialParams.disableSTUN = true
+	dialParams.disablePortMapping = false
+
+	NATDiscover(context.Background(), config)
+
+	checkCallCounts(1, 1, 1, 0)
+
+	// Should skip both and use values cached in DialParameters
+
+	dialParams.disableSTUN = false
+	dialParams.disablePortMapping = false
+
+	NATDiscover(context.Background(), config)
+
+	checkCallCounts(1, 1, 1, 0)
+
+	t.Logf("NAT Type: %s", dialParams.NATType())
+	t.Logf("Port Mapping Types: %s", dialParams.PortMappingTypes())
+}

+ 131 - 0
psiphon/common/inproxy/doc.go

@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+/*
+Package inproxy enables 3rd party, ephemeral proxies to help Psiphon clients
+connect to the Psiphon network.
+
+The in-proxy architecture is inspired by and similar to Tor's snowflake
+pluggable transport, https://snowflake.torproject.org/.
+
+With in-proxy, Psiphon clients are matched with proxies by brokers run by the
+Psiphon network.
+
+In addition to proxies in unblocked regions, proxies in blocked regions are
+supported, to facilitate use cases such as a local region hop from a
+mobile ISP, where international traffic may be expensive and throttled, to a
+home ISP, which may be less restricted.
+
+The proxy/server hop uses the full range of Psiphon tunnel protocols,
+providing blocking circumvention on the 2nd hop.
+
+Proxies don't create Psiphon tunnels, they just relay either TCP or UDP flows
+from the client to the server, where those flows are Psiphon tunnel
+protocols. Proxies don't need to be upgraded in order to relay newer Psiphon
+tunnel protocols or protocol variants.
+
+Proxies cannot see the client traffic within the relayed Psiphon tunnel.
+Brokers verify that client destinations are valid Psiphon servers only, so
+proxies cannot be misused for non-Psiphon relaying.
+
+To limit the set of Psiphon servers that proxies can observe and enumerate,
+client destinations are limited to the set of servers specifically designated
+with in-proxy capabilities. This is enforced by the broker.
+
+Proxies are compartmentalized in two ways; (1) personal proxies will use a
+personal compartment ID to limit access to clients run by users with whom the
+proxy operator has shared, out-of-band, a personal compartment ID, or access
+token; (2) common proxies will be assigned a common compartment ID by the
+Psiphon network to limit access to clients that have obtained the common
+compartment ID, or access token, from Psiphon through channels such as
+targeted tactics or embedded in OSLs.
+
+Proxies are expected to be run for longer periods, on desktop computers. The
+in-proxy design does not currently support browser extension or website
+widget proxies.
+
+The client/proxy hop uses WebRTC, with the broker playing the role of a WebRTC
+signaling server in addition to matching clients and proxies. Clients and
+proxies gather ICE candidates, including any host candidates, IPv4 or IPv6,
+as well as STUN server reflexive candidates. In addition, any available port
+mapping protocols -- UPnP-IGD, NAT-PMP, PCP -- are used to gather port
+mapping candidates, which are injected into ICE SDPs as host candidates. TURN
+candidates are not used.
+
+NAT topology discovery is performed and metrics are sent to the broker to
+optimize utility and matching of proxies to clients. Mobile networks may be
+assumed to be CGNAT in case NAT discovery fails or is skipped. And, for mobile
+networks, there is an option to skip discovery and STUN for a faster dial.
+
+The client/proxy hop is a WebRTC data channel; on the wire, it is DTLS, preceded
+by an ICE STUN packet. By default, WebRTC DTLS is configured to look like
+common browsers. In addition, the DTLS ClientHello can be randomized. Proxy
+endpoints are ephemeral, but if they were to be scanned or probed, the
+response should look like common WebRTC stacks that receive packets from
+invalid peers.
+
+Clients and proxies connect to brokers via a domain fronting transport; the
+transport is abstracted and other channels may be provided. Within that
+transport, a Noise protocol framework session is established between
+clients/proxies and a broker, to ensure privacy, authentication, and replay
+defense between the end points; not even a domain fronting CDN can observe
+the transactions within a session. The session has an additional obfuscation
+layer that renders the messages as fully random, which may be suitable for
+encapsulating in plaintext transports; adds random padding; and detects
+replay of any message.
+
+For clients and proxies, all broker and WebRTC dial parameters, including
+domain fronting, STUN server selection, NAT discovery behavior, timeouts, and
+so on are remotely configurable via Psiphon tactics. Callbacks facilitate
+replay of successful dial parameters for individual stages of a dial,
+including a successful broker connection, or a working STUN server.
+
+For each proxied client tunnel, brokers use secure sessions to send the
+destination Psiphon server a message indicating the proxy ID that's relaying
+the client's traffic, the original client IP, and additional metrics to be
+logged with the server_tunnel log for the tunnel. Neither a client nor a
+proxy is trusted to report the original client IP or the proxy ID.
+
+Instead of having the broker connect out to Psiphon servers, and trying to
+synchronize reliable arrival of these messages, the broker uses the client to
+relay secure session packets -- the message and response, preceded by a
+session handshake if required. These session packets piggyback on top of
+client/broker and client/server round trips that happen anyway, including the
+Psiphon API handshake.
+
+Psiphon servers with in-proxy capabilities should be configured, on in-proxy
+listeners, to require receipt of this broker message before finalizing
+traffic rules, issuing tactics, issuing OSL progress, or allowing traffic
+tunneling. The original client IP reported by the broker should be used for
+all client GeoIP policy decisions and logging.
+
+The proxy ID is the proxy's secure session public key; the proxy proves
+possession of the corresponding private key in the session handshake. Proxy
+IDs are not revealed to clients; only to brokers and Psiphon servers. A proxy
+may maintain a long-term key pair and corresponding proxy ID, and that may be
+used by Psiphon to assign reputation to well-performing proxies or to issue
+rewards for proxies.
+
+The proxy is designed to be bundled with the tunnel-core client, run
+optionally, and integrated with its tactics, data store, and logging. The
+broker is designed to be bundled with the Psiphon server, psiphond, and, like
+tactics requests, run under MeekServer; and use the tactics, psinet database,
+GeoIP services, and logging services provided by psiphond.
+*/
+package inproxy

+ 101 - 0
psiphon/common/inproxy/dtls.go

@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"net"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	lrucache "github.com/cognusion/go-cache-lru"
+)
+
+// dtlsSeedCacheTTL should be long enough for the seed to remain available in
+// the cache between when it's first set at the start of WebRTC operations,
+// and until all DTLS dials have completed.
+
+const (
+	dtlsSeedCacheTTL     = 60 * time.Second // default expiration given to lrucache.NewWithLRU
+	dtlsSeedCacheMaxSize = 10000            // item limit given to lrucache.NewWithLRU
+)
+
+// SetDTLSSeed derives a common/prng seed from the obfuscation secret and
+// caches it for later use in randomizing a DTLS ClientHello.
+//
+// The cache key is the local address of the specified conn. A fork of
+// pion/dtls can then look up the seed and apply randomization, without seeds
+// having to be passed down through many forked pion layers. Concurrent dials
+// must use distinct conns with distinct local addresses (including port
+// number).
+//
+// Either side of a WebRTC connection may randomize its ClientHello. The
+// isOffer flag selects a different salt for each side, so that one shared
+// secret yields two distinct random streams. The client generates or replays
+// the obfuscation secret, and the Broker relays it to the proxy.
+//
+// TTL optionally sets how long the seed stays cached, such as for the
+// duration of a dial timeout; when TTL is <= 0, a default TTL is used.
+//
+// An error is returned when the obfuscation secret has an unexpected length
+// or when the salted seed cannot be derived.
+func SetDTLSSeed(conn net.PacketConn, obfuscationSecret ObfuscationSecret, isOffer bool, TTL time.Duration) error {
+
+	if len(obfuscationSecret) != prng.SEED_LENGTH {
+		return errors.TraceNew("unexpected obfuscation secret length")
+	}
+
+	// The offer side and the answer side salt the same secret differently.
+	salt := "inproxy-proxy-DTLS-seed"
+	if isOffer {
+		salt = "inproxy-client-DTLS-seed"
+	}
+
+	var secretSeed prng.Seed
+	copy(secretSeed[:], obfuscationSecret[:])
+
+	derivedSeed, err := prng.NewSaltedSeed(&secretSeed, salt)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	expiry := TTL
+	if expiry <= 0 {
+		expiry = lrucache.DefaultExpiration
+	}
+
+	// When a previously used local port number is reused in a new dial, the
+	// seed cached for the earlier dial is replaced.
+
+	cacheKey := conn.LocalAddr().String()
+	dtlsSeedCache.Set(cacheKey, derivedSeed, expiry)
+
+	return nil
+}
+
+// GetDTLSSeed fetches the seed that SetDTLSSeed cached under the specified
+// conn's local address. An error is returned when no seed is cached for that
+// address, or when the cached value is not a *prng.Seed.
+func GetDTLSSeed(conn *net.UDPConn) (*prng.Seed, error) {
+	value, ok := dtlsSeedCache.Get(conn.LocalAddr().String())
+	if !ok {
+		return nil, errors.TraceNew("missing seed")
+	}
+	// Use a checked assertion: a bare assertion would panic, rather than
+	// return an error, if the cache ever held a value of another type.
+	seed, ok := value.(*prng.Seed)
+	if !ok {
+		return nil, errors.TraceNew("unexpected seed type")
+	}
+	return seed, nil
+}
+
+var dtlsSeedCache *lrucache.Cache
+
+func init() {
+	dtlsSeedCache = lrucache.NewWithLRU(
+		dtlsSeedCacheTTL, 1*time.Minute, dtlsSeedCacheMaxSize)
+}

+ 809 - 0
psiphon/common/inproxy/inproxy_test.go

@@ -0,0 +1,809 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"bytes"
+	"context"
+	"crypto/tls"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
+	"golang.org/x/sync/errgroup"
+)
+
+// TestInProxy runs the end-to-end in-proxy scenario in runTestInProxy and
+// reports any returned error as a test failure.
+func TestInProxy(t *testing.T) {
+	err := runTestInProxy()
+	if err != nil {
+		// Use Error rather than Errorf: the traced error text is not a format
+		// string, and any '%' in it would be misread as a formatting verb.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+func runTestInProxy() error {
+	// End-to-end run: echo servers, a broker, proxies, and clients echoing data.
+	// Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
+
+	numProxies := 5
+	proxyMaxClients := 2
+	numClients := 10
+
+	bytesToSend := 1 << 20
+	messageSize := 1 << 10
+	targetElapsedSeconds := 2
+
+	baseMetrics := common.APIParameters{
+		"sponsor_id":      "test-sponsor-id",
+		"client_platform": "test-client-platform",
+	}
+
+	testTransportSecret, _ := MakeID()
+
+	testCompartmentID, _ := MakeID()
+	testCommonCompartmentIDs := []ID{testCompartmentID}
+
+	testNetworkID := "NETWORK-ID-1"
+	testNetworkType := NetworkTypeUnknown
+	testNATType := NATTypeUnknown
+	testSTUNServerAddress := "stun.nextcloud.com:443"
+
+	// TODO: test port mapping
+
+	stunServerAddressSucceededCount := int32(0)
+	stunServerAddressSucceeded := func(bool, string) { atomic.AddInt32(&stunServerAddressSucceededCount, 1) }
+	stunServerAddressFailedCount := int32(0)
+	stunServerAddressFailed := func(bool, string) { atomic.AddInt32(&stunServerAddressFailedCount, 1) }
+
+	roundTripperSucceededCount := int32(0)
+	roundTripperSucceded := func(RoundTripper) { atomic.AddInt32(&roundTripperSucceededCount, 1) }
+	roundTripperFailedCount := int32(0)
+	roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
+
+	testCtx, stopTest := context.WithCancel(context.Background())
+	defer stopTest()
+
+	testGroup := new(errgroup.Group)
+
+	// Enable test to run without requiring host firewall exceptions
+	setAllowLoopbackWebRTCConnections(true)
+
+	// Init logging
+
+	logger := newTestLogger()
+
+	// Start echo servers
+
+	tcpEchoListener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer tcpEchoListener.Close()
+	go runTCPEchoServer(tcpEchoListener)
+
+	// QUIC tests UDP proxying, and provides reliable delivery of echoed data
+	quicEchoServer, err := newQuicEchoServer()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer quicEchoServer.Close()
+	go quicEchoServer.Run()
+
+	// Create signed server entry with capability
+
+	serverPrivateKey, err := GenerateSessionPrivateKey()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	serverPublicKey, err := GetSessionPublicKey(serverPrivateKey)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	serverRootObfuscationSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	serverEntry := make(protocol.ServerEntryFields)
+	serverEntry["ipAddress"] = "127.0.0.1"
+	_, tcpPort, _ := net.SplitHostPort(tcpEchoListener.Addr().String())
+	_, udpPort, _ := net.SplitHostPort(quicEchoServer.Addr().String())
+	serverEntry["sshObfuscatedPort"], _ = strconv.Atoi(tcpPort)
+	serverEntry["sshObfuscatedQUICPort"], _ = strconv.Atoi(udpPort)
+	serverEntry["capabilities"] = []string{"OSSH", "QUIC", "inproxy"}
+	serverEntry["inProxySessionPublicKey"] = base64.StdEncoding.EncodeToString(serverPublicKey[:])
+	serverEntry["inProxySessionRootObfuscationSecret"] = base64.StdEncoding.EncodeToString(serverRootObfuscationSecret[:])
+	testServerEntryTag := prng.HexString(16)
+	serverEntry["tag"] = testServerEntryTag
+
+	serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey, err :=
+		protocol.NewServerEntrySignatureKeyPair()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	err = serverEntry.AddSignature(serverEntrySignaturePublicKey, serverEntrySignaturePrivateKey)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	serverEntryJSON, err := json.Marshal(serverEntry)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Start broker
+
+	brokerPrivateKey, err := GenerateSessionPrivateKey()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	brokerPublicKey, err := GetSessionPublicKey(brokerPrivateKey)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	brokerRootObfuscationSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	brokerListener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer brokerListener.Close()
+
+	brokerConfig := &BrokerConfig{
+
+		Logger: logger,
+
+		CommonCompartmentIDs: testCommonCompartmentIDs,
+
+		APIParameterValidator: func(params common.APIParameters) error {
+			if len(params) != len(baseMetrics) {
+				return errors.TraceNew("unexpected base metrics")
+			}
+			for name, value := range params {
+				if value.(string) != baseMetrics[name].(string) {
+					return errors.TraceNew("unexpected base metrics")
+				}
+			}
+			return nil
+		},
+
+		APIParameterLogFieldFormatter: func(
+			geoIPData common.GeoIPData, params common.APIParameters) common.LogFields {
+			return common.LogFields(params)
+		},
+
+		TransportSecret: TransportSecret(testTransportSecret),
+
+		PrivateKey: brokerPrivateKey,
+
+		ObfuscationRootSecret: brokerRootObfuscationSecret,
+
+		ServerEntrySignaturePublicKey: serverEntrySignaturePublicKey,
+
+		IsValidServerEntryTag: func(serverEntryTag string) bool { return serverEntryTag == testServerEntryTag },
+
+		AllowProxy:             func(common.GeoIPData) bool { return true },
+		AllowClient:            func(common.GeoIPData) bool { return true },
+		AllowDomainDestination: func(common.GeoIPData) bool { return true },
+	}
+
+	broker, err := NewBroker(brokerConfig)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = broker.Start()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer broker.Stop()
+
+	testGroup.Go(func() error {
+		err := runHTTPServer(brokerListener, broker)
+		if testCtx.Err() != nil {
+			return nil
+		}
+		return errors.Trace(err)
+	})
+
+	// Stub server broker request handler (in Psiphon, this will be the
+	// destination Psiphon server; here, it's not necessary to build this
+	// handler into the destination echo server)
+
+	serverSessions, err := NewServerBrokerSessions(
+		serverPrivateKey, serverRootObfuscationSecret, []SessionPublicKey{brokerPublicKey})
+	if err != nil {
+		return errors.Trace(err)
+	}
+	// Connection IDs added here must be cleared by the handler before the test ends.
+	var pendingBrokerServerRequestsMutex sync.Mutex
+	pendingBrokerServerRequests := make(map[ID]bool)
+
+	addPendingBrokerServerRequest := func(connectionID ID) {
+		pendingBrokerServerRequestsMutex.Lock()
+		defer pendingBrokerServerRequestsMutex.Unlock()
+		pendingBrokerServerRequests[connectionID] = true
+	}
+
+	hasPendingBrokerServerRequests := func() bool {
+		pendingBrokerServerRequestsMutex.Lock()
+		defer pendingBrokerServerRequestsMutex.Unlock()
+		return len(pendingBrokerServerRequests) > 0
+	}
+
+	handleBrokerServerRequests := func(in []byte, clientConnectionID ID) ([]byte, error) {
+
+		handler := func(brokerVerifiedOriginalClientIP string, logFields common.LogFields) {
+			pendingBrokerServerRequestsMutex.Lock()
+			defer pendingBrokerServerRequestsMutex.Unlock()
+
+			// Mark the request as no longer outstanding
+			delete(pendingBrokerServerRequests, clientConnectionID)
+		}
+
+		out, err := serverSessions.HandlePacket(logger, in, clientConnectionID, handler)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		return out, nil
+	}
+
+	// Start proxies
+
+	for i := 0; i < numProxies; i++ {
+
+		proxyPrivateKey, err := GenerateSessionPrivateKey()
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		dialParams := &testDialParameters{
+			networkID:                  testNetworkID,
+			networkType:                testNetworkType,
+			natType:                    testNATType,
+			stunServerAddress:          testSTUNServerAddress,
+			stunServerAddressRFC5780:   testSTUNServerAddress,
+			stunServerAddressSucceeded: stunServerAddressSucceeded,
+			stunServerAddressFailed:    stunServerAddressFailed,
+
+			brokerClientPrivateKey:      proxyPrivateKey,
+			brokerPublicKey:             brokerPublicKey,
+			brokerRootObfuscationSecret: brokerRootObfuscationSecret,
+			brokerClientRoundTripper: newHTTPRoundTripper(
+				brokerListener.Addr().String(), "proxy"),
+			brokerClientRoundTripperSucceeded: roundTripperSucceded,
+			brokerClientRoundTripperFailed:    roundTripperFailed,
+
+			setNATType:          func(NATType) {},
+			setPortMappingTypes: func(PortMappingTypes) {},
+			bindToDevice:        func(int) error { return nil },
+		}
+
+		proxy, err := NewProxy(&ProxyConfig{
+			Logger:                        logger,
+			BaseMetrics:                   baseMetrics,
+			DialParameters:                dialParams,
+			MaxClients:                    proxyMaxClients,
+			LimitUpstreamBytesPerSecond:   bytesToSend / targetElapsedSeconds,
+			LimitDownstreamBytesPerSecond: bytesToSend / targetElapsedSeconds,
+			ActivityUpdater: func(connectingClients int32, connectedClients int32,
+				bytesUp int64, bytesDown int64, bytesDuration time.Duration) {
+
+				fmt.Printf("[%s] ACTIVITY: %d connecting, %d connected, %d up, %d down\n",
+					time.Now().UTC().Format(time.RFC3339),
+					connectingClients, connectedClients, bytesUp, bytesDown)
+			},
+		})
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		testGroup.Go(func() error {
+			proxy.Run(testCtx)
+			return nil
+		})
+	}
+
+	// Run clients
+
+	clientsGroup := new(errgroup.Group)
+
+	makeClientFunc := func(
+		isTCP bool,
+		isMobile bool,
+		dialParams DialParameters,
+		brokerClient *BrokerClient) func() error {
+
+		var networkProtocol NetworkProtocol
+		var addr string
+		var wrapWithQUIC bool
+		if isTCP {
+			networkProtocol = NetworkProtocolTCP
+			addr = tcpEchoListener.Addr().String()
+		} else {
+			networkProtocol = NetworkProtocolUDP
+			addr = quicEchoServer.Addr().String()
+			wrapWithQUIC = true
+		}
+
+		return func() error {
+			// Bound each dial, including the QUIC handshake, to 30 seconds.
+			dialCtx, cancelDial := context.WithTimeout(testCtx, 30*time.Second)
+			defer cancelDial()
+
+			conn, err := DialClient(
+				dialCtx,
+				&ClientConfig{
+					Logger:                     logger,
+					BaseMetrics:                baseMetrics,
+					DialParameters:             dialParams,
+					BrokerClient:               brokerClient,
+					ReliableTransport:          isTCP,
+					DialNetworkProtocol:        networkProtocol,
+					DialAddress:                addr,
+					DestinationServerEntryJSON: serverEntryJSON,
+				})
+			if err != nil {
+				return errors.Trace(err)
+			}
+
+			var relayConn net.Conn
+			relayConn = conn
+
+			if wrapWithQUIC {
+				quicConn, err := quic.Dial(
+					dialCtx,
+					conn,
+					&net.UDPAddr{Port: 1}, // This address is ignored, but the zero value is not allowed
+					"test", "QUICv1", nil, quicEchoServer.ObfuscationKey(), nil, nil, true)
+				if err != nil {
+					return errors.Trace(err)
+				}
+				relayConn = quicConn
+			}
+
+			addPendingBrokerServerRequest(conn.GetConnectionID())
+			signalRelayComplete := make(chan struct{})
+
+			clientsGroup.Go(func() error {
+				defer close(signalRelayComplete)
+
+				in := conn.InitialRelayPacket()
+				for in != nil {
+					out, err := handleBrokerServerRequests(in, conn.GetConnectionID())
+
+					// In general, trying to use an expired session results in an expected error...
+					sessionInvalid := err != nil
+
+					// ...but no error is expected in this test run.
+					if err != nil {
+						fmt.Printf("handleBrokerServerRequests failed: %v\n", err)
+					}
+
+					in, err = conn.RelayPacket(testCtx, out, sessionInvalid)
+					if err != nil {
+						return errors.Trace(err)
+					}
+				}
+
+				return nil
+			})
+			// Writer and reader below share this random payload to verify the echo.
+			sendBytes := prng.Bytes(bytesToSend)
+
+			clientsGroup.Go(func() error {
+				for n := 0; n < bytesToSend; n += messageSize {
+					m := messageSize
+					if bytesToSend-n < m {
+						m = bytesToSend - n
+					}
+					_, err := relayConn.Write(sendBytes[n : n+m])
+					if err != nil {
+						return errors.Trace(err)
+					}
+				}
+				fmt.Printf("%d bytes sent\n", bytesToSend)
+				return nil
+			})
+
+			clientsGroup.Go(func() error {
+				buf := make([]byte, messageSize)
+				n := 0
+				for n < bytesToSend {
+					m, err := relayConn.Read(buf)
+					if err != nil {
+						return errors.Trace(err)
+					}
+					if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
+						return errors.Tracef(
+							"unexpected bytes: expected at index %d, received at index %d",
+							bytes.Index(sendBytes, buf[:m]), n)
+					}
+					n += m
+				}
+				fmt.Printf("%d bytes received\n", bytesToSend)
+				// Keep the conns open until the relay exchange goroutine has finished.
+				select {
+				case <-signalRelayComplete:
+				case <-testCtx.Done():
+				}
+
+				relayConn.Close()
+				conn.Close()
+
+				return nil
+			})
+
+			return nil
+		}
+	}
+
+	newClientParams := func(isMobile bool) (*testDialParameters, *BrokerClient, error) {
+
+		clientPrivateKey, err := GenerateSessionPrivateKey()
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+
+		clientRootObfuscationSecret, err := GenerateRootObfuscationSecret()
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+
+		dialParams := &testDialParameters{
+			commonCompartmentIDs: testCommonCompartmentIDs,
+
+			networkID:                  testNetworkID,
+			networkType:                testNetworkType,
+			natType:                    testNATType,
+			stunServerAddress:          testSTUNServerAddress,
+			stunServerAddressRFC5780:   testSTUNServerAddress,
+			stunServerAddressSucceeded: stunServerAddressSucceeded,
+			stunServerAddressFailed:    stunServerAddressFailed,
+
+			brokerClientPrivateKey:      clientPrivateKey,
+			brokerPublicKey:             brokerPublicKey,
+			brokerRootObfuscationSecret: brokerRootObfuscationSecret,
+			brokerClientRoundTripper: newHTTPRoundTripper(
+				brokerListener.Addr().String(), "client"),
+			brokerClientRoundTripperSucceeded: roundTripperSucceded,
+			brokerClientRoundTripperFailed:    roundTripperFailed,
+
+			clientRootObfuscationSecret: clientRootObfuscationSecret,
+			doDTLSRandomization:         true,
+
+			setNATType:          func(NATType) {},
+			setPortMappingTypes: func(PortMappingTypes) {},
+			bindToDevice:        func(int) error { return nil },
+		}
+
+		if isMobile {
+			dialParams.networkType = NetworkTypeMobile
+			dialParams.disableInboundForMobleNetworks = true
+		}
+
+		brokerClient, err := NewBrokerClient(dialParams)
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+
+		return dialParams, brokerClient, nil
+	}
+
+	clientDialParams, clientBrokerClient, err := newClientParams(false)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	clientMobileDialParams, clientMobileBrokerClient, err := newClientParams(true)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	for i := 0; i < numClients; i++ {
+
+		// Test a mix of TCP and UDP proxying; also test the
+		// DisableInboundForMobleNetworks code path.
+
+		isTCP := i%2 == 0
+		isMobile := i%4 == 0
+
+		// Exercise BrokerClients shared by multiple clients, but also create
+		// several broker clients.
+		if i%8 == 0 {
+			clientDialParams, clientBrokerClient, err = newClientParams(false)
+			if err != nil {
+				return errors.Trace(err)
+			}
+
+			clientMobileDialParams, clientMobileBrokerClient, err = newClientParams(true)
+			if err != nil {
+				return errors.Trace(err)
+			}
+		}
+
+		dialParams := clientDialParams
+		brokerClient := clientBrokerClient
+		if isMobile {
+			dialParams = clientMobileDialParams
+			brokerClient = clientMobileBrokerClient
+		}
+
+		clientsGroup.Go(makeClientFunc(isTCP, isMobile, dialParams, brokerClient))
+	}
+
+	// Await client transfers complete
+
+	err = clientsGroup.Wait()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if hasPendingBrokerServerRequests() {
+		return errors.TraceNew("unexpected pending broker server requests")
+	}
+
+	// Await shutdowns
+
+	stopTest()
+	brokerListener.Close()
+
+	err = testGroup.Wait()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// TODO: check that elapsed time is consistent with rate limit (+/-)
+
+	// Check if STUN server replay callbacks were triggered
+	if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
+		return errors.TraceNew("unexpected STUN server succeeded count")
+	}
+	if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
+		return errors.TraceNew("unexpected STUN server failed count")
+	}
+
+	// Check if RoundTripper server replay callbacks were triggered
+	if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
+		return errors.TraceNew("unexpected round tripper succeeded count")
+	}
+	if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
+		return errors.TraceNew("unexpected round tripper failed count")
+	}
+
+	return nil
+}
+
+// runHTTPServer serves broker session packets over TLS on the specified
+// listener until Serve returns, and returns the traced Serve error. Each
+// request body is passed to broker.HandleSessionPacket and the resulting
+// payload is written as the response; a failure to read or handle a request
+// is logged and answered with a 404.
+func runHTTPServer(listener net.Listener, broker *Broker) error {
+
+	httpServer := &http.Server{
+		ReadTimeout:  BrokerReadTimeout,
+		WriteTimeout: BrokerWriteTimeout,
+		IdleTimeout:  BrokerIdleTimeout,
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
+			// For this test, clients set the path to "/client" and proxies
+			// set the path to "/proxy" and we use that to create stub GeoIP
+			// data to pass the not-same-ASN condition.
+			var geoIPData common.GeoIPData
+			geoIPData.ASN = r.URL.Path
+
+			// Not an actual HTTP header in this test.
+			transportSecret := broker.config.TransportSecret
+
+			// The request body size is capped at BrokerMaxRequestBodySize.
+			requestPayload, err := ioutil.ReadAll(
+				http.MaxBytesReader(w, r.Body, BrokerMaxRequestBodySize))
+			if err != nil {
+				fmt.Printf("runHTTPServer ioutil.ReadAll failed: %v\n", err)
+				http.Error(w, "", http.StatusNotFound)
+				return
+			}
+			clientIP, _, _ := net.SplitHostPort(r.RemoteAddr)
+			responsePayload, err := broker.HandleSessionPacket(
+				r.Context(),
+				transportSecret,
+				clientIP,
+				geoIPData,
+				requestPayload)
+			if err != nil {
+				// Terminate the log line, as the other log messages in this
+				// file do, so that consecutive failures don't run together.
+				fmt.Printf("runHTTPServer HandleSessionPacket failed: %v\n", err)
+				http.Error(w, "", http.StatusNotFound)
+				return
+			}
+			w.WriteHeader(http.StatusOK)
+			w.Write(responsePayload)
+		}),
+	}
+
+	// The server presents a certificate generated for this run; the test
+	// round tripper skips certificate verification.
+	certificate, privateKey, err := common.GenerateWebServerCertificate("www.example.com")
+	if err != nil {
+		return errors.Trace(err)
+	}
+	tlsCert, err := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	tlsConfig := &tls.Config{
+		Certificates: []tls.Certificate{tlsCert},
+	}
+
+	err = httpServer.Serve(tls.NewListener(listener, tlsConfig))
+	return errors.Trace(err)
+}
+
+// httpRoundTripper is a test round tripper that POSTs request payloads to an
+// HTTPS endpoint. The URL path is set per round tripper; the test uses
+// "client" and "proxy".
+type httpRoundTripper struct {
+	httpClient   *http.Client
+	endpointAddr string
+	path         string
+}
+
+// newHTTPRoundTripper creates an httpRoundTripper that sends requests to
+// https://<endpointAddr>/<path>.
+func newHTTPRoundTripper(endpointAddr string, path string) *httpRoundTripper {
+
+	// Certificate verification is skipped; the test server's certificate is
+	// generated at run time.
+	transport := &http.Transport{
+		ForceAttemptHTTP2:   true,
+		MaxIdleConns:        2,
+		IdleConnTimeout:     BrokerIdleTimeout,
+		TLSHandshakeTimeout: 1 * time.Second,
+		TLSClientConfig:     &tls.Config{InsecureSkipVerify: true},
+	}
+
+	return &httpRoundTripper{
+		httpClient:   &http.Client{Transport: transport},
+		endpointAddr: endpointAddr,
+		path:         path,
+	}
+}
+
+// RoundTrip POSTs requestPayload to the endpoint and returns the response
+// body. An error is returned when the request cannot be created or sent,
+// when the response status is not 200 OK, or when the body cannot be read.
+func (r *httpRoundTripper) RoundTrip(
+	ctx context.Context, requestPayload []byte) ([]byte, error) {
+
+	request, err := http.NewRequestWithContext(
+		ctx,
+		"POST",
+		fmt.Sprintf("https://%s/%s", r.endpointAddr, r.path),
+		bytes.NewReader(requestPayload))
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	response, err := r.httpClient.Do(request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	defer response.Body.Close()
+
+	if status := response.StatusCode; status != http.StatusOK {
+		return nil, errors.Tracef("unexpected response status code: %d", status)
+	}
+
+	responsePayload, err := io.ReadAll(response.Body)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePayload, nil
+}
+
+// Close closes any idle connections held by the underlying HTTP client. It
+// always returns nil.
+func (r *httpRoundTripper) Close() error {
+	r.httpClient.CloseIdleConnections()
+	return nil
+}
+
+// runTCPEchoServer accepts connections on the specified listener until
+// Accept fails, and echoes back whatever each connection sends. Each
+// connection is served by its own goroutine, which logs the error that ends
+// its echo loop and then closes the connection.
+func runTCPEchoServer(listener net.Listener) {
+
+	for {
+		conn, err := listener.Accept()
+		if err != nil {
+			fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
+			return
+		}
+		go func(conn net.Conn) {
+			// Release the accepted connection when the echo loop exits;
+			// otherwise every finished client would leave its socket open.
+			defer conn.Close()
+			buf := make([]byte, 1024)
+			for {
+				n, err := conn.Read(buf)
+				if n > 0 {
+					// Echo what was read; a write error replaces the read
+					// error and ends the loop below.
+					_, err = conn.Write(buf[:n])
+				}
+				if err != nil {
+					fmt.Printf("runTCPEchoServer failed: %v\n", errors.Trace(err))
+					return
+				}
+			}
+		}(conn)
+	}
+}
+
+// quicEchoServer is a test echo server that listens with quic.Listen, using
+// a randomly generated obfuscation key.
+type quicEchoServer struct {
+	listener       net.Listener
+	obfuscationKey string
+}
+
+// newQuicEchoServer creates a quicEchoServer listening on a loopback address
+// with a system-assigned port. Call Run to start accepting connections.
+func newQuicEchoServer() (*quicEchoServer, error) {
+
+	obfuscationKey := prng.HexString(32)
+
+	listener, err := quic.Listen(
+		nil,
+		nil,
+		"127.0.0.1:0",
+		obfuscationKey,
+		false)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return &quicEchoServer{
+		listener:       listener,
+		obfuscationKey: obfuscationKey,
+	}, nil
+}
+
+// ObfuscationKey returns the obfuscation key the listener was created with.
+func (q *quicEchoServer) ObfuscationKey() string {
+	return q.obfuscationKey
+}
+
+// Close closes the listener.
+func (q *quicEchoServer) Close() error {
+	return q.listener.Close()
+}
+
+// Addr returns the listener's address.
+func (q *quicEchoServer) Addr() net.Addr {
+	return q.listener.Addr()
+}
+
+// Run accepts connections until Accept fails, and echoes back whatever each
+// connection sends. Each connection is served by its own goroutine, which
+// logs the error that ends its echo loop and then closes the connection.
+func (q *quicEchoServer) Run() {
+
+	for {
+		conn, err := q.listener.Accept()
+		if err != nil {
+			fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
+			return
+		}
+		go func(conn net.Conn) {
+			// Release the accepted connection when the echo loop exits;
+			// otherwise every finished client would leave its conn open.
+			defer conn.Close()
+			buf := make([]byte, 1024)
+			for {
+				n, err := conn.Read(buf)
+				if n > 0 {
+					// Echo what was read; a write error replaces the read
+					// error and ends the loop below.
+					_, err = conn.Write(buf[:n])
+				}
+				if err != nil {
+					fmt.Printf("quicEchoServer failed: %v\n", errors.Trace(err))
+					return
+				}
+			}
+		}(conn)
+	}
+}

+ 749 - 0
psiphon/common/inproxy/matcher.go

@@ -0,0 +1,749 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+package inproxy
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	lrucache "github.com/cognusion/go-cache-lru"
+	"github.com/gammazero/deque"
+	"github.com/pion/webrtc/v3"
+)
+
+// TTLs should be aligned with STUN hole punch lifetimes.
+
+const (
+	// Announcements are not enqueued once the announcement queue holds this
+	// many entries.
+	matcherAnnouncementQueueMaxSize = 100000
+	// Offers are not enqueued once the offer queue holds this many entries.
+	matcherOfferQueueMaxSize        = 100000
+	// Default TTL given to the pending answers cache; individual entries
+	// normally override it with the remaining lifetime of the Offer ctx.
+	matcherPendingAnswersTTL        = 30 * time.Second
+	// Size limit given to the pending answers cache.
+	matcherPendingAnswersMaxSize    = 100000
+)
+
+// Matcher matches proxy announcements with client offers. Matcher also
+// coordinates pending proxy answers and routes answers to the awaiting
+// client offer handler.
+//
+// Matching prioritizes selecting the oldest announcements and client offers,
+// as they are closest to timing out.
+//
+// The client and proxy must supply matching personal or common compartment
+// IDs. Personal compartment matching is preferred. Common compartments are
+// managed by Psiphon and can be obtained via a tactics parameter or via an
+// OSL embedding.
+//
+// Matching prefers to pair proxies and clients in a way that maximizes total
+// possible matches. For a client or proxy with less-limited NAT traversal, a
+// pairing with more-limited NAT traversal is preferred; and vice versa.
+// Candidates with unknown NAT types and mobile network types are assumed to
+// have the most limited NAT traversal capability.
+//
+// Preferred matchings take priority over announcement age.
+//
+// The client and proxy will not match if they are in the same country and
+// ASN, as it's assumed that doesn't provide any blocking circumvention
+// benefit. Disallowing proxies in certain blocked countries is handled at a
+// higher level; any such proxies should not be enqueued for matching.
+type Matcher struct {
+	config *MatcherConfig
+
+	// runMutex guards the run state set by Start and cleared by Stop:
+	// runContext, stopRunning, and the worker goroutine tracked by
+	// waitGroup.
+
+	runMutex    sync.Mutex
+	runContext  context.Context
+	stopRunning context.CancelFunc
+	waitGroup   *sync.WaitGroup
+
+	// The announcement queue is implicitly sorted by announcement age. The
+	// count fields are used to skip searching deeper into the queue for
+	// preferred matches.
+
+	// TODO: replace queue and counts with an indexed, in-memory database?
+
+	announcementQueueMutex                      sync.Mutex
+	announcementQueue                           *deque.Deque[*announcementEntry]
+	announcementsPersonalCompartmentalizedCount int
+	announcementsUnlimitedNATCount              int
+	announcementsPartiallyLimitedNATCount       int
+	announcementsStrictlyLimitedNATCount        int
+
+	// The offer queue is also implicitly sorted by offer age. Both an offer
+	// and announcement queue are required since either announcements or
+	// offers can arrive while there are no available pairings.
+
+	offerQueueMutex sync.Mutex
+	offerQueue      *deque.Deque[*offerEntry]
+
+	// matchSignal wakes the match worker after an announcement or offer is
+	// enqueued.
+
+	matchSignal chan struct{}
+
+	// pendingAnswers holds a pendingAnswer for each matched announcement
+	// whose proxy answer has not yet arrived, keyed by pendingAnswerKey.
+
+	pendingAnswers *lrucache.Cache
+}
+
+// MatchProperties specifies the compartment, GeoIP, and network topology
+// matching properties of clients and proxies.
+type MatchProperties struct {
+	CommonCompartmentIDs   []ID
+	PersonalCompartmentIDs []ID
+	GeoIPData              common.GeoIPData
+	NetworkType            NetworkType
+	NATType                NATType
+	PortMappingTypes       PortMappingTypes
+}
+
+// EffectiveNATType reduces the network properties to a single effective NAT
+// type. An available port mapping is treated as a NAT type with unlimited
+// NAT traversal. An unknown NAT type on a mobile network is treated as
+// CGNAT with limited NAT traversal. Otherwise the reported NAT type is used
+// as-is.
+func (p *MatchProperties) EffectiveNATType() NATType {
+
+	switch {
+
+	case p.PortMappingTypes.Available():
+		return NATTypePortMapping
+
+	// TODO: can a peer have limited NAT traversal for IPv4 and also have a
+	// publicly reachable IPv6 ICE host candidate? If so, change the
+	// effective NAT type? Depends on whether the matched peer can use IPv6.
+
+	case p.NATType == NATTypeUnknown && p.NetworkType == NetworkTypeMobile:
+		return NATTypeMobileNetwork
+
+	default:
+		return p.NATType
+	}
+}
+
+// ExistsPreferredNATMatch reports whether a preferred NAT matching exists
+// for this candidate's effective NAT type, given which kinds of pairing
+// candidates are currently available.
+func (p *MatchProperties) ExistsPreferredNATMatch(
+	unlimitedNAT, partiallyLimitedNAT, limitedNAT bool) bool {
+
+	effectiveNATType := p.EffectiveNATType()
+
+	return effectiveNATType.ExistsPreferredMatch(
+		unlimitedNAT, partiallyLimitedNAT, limitedNAT)
+}
+
+// IsPreferredNATMatch reports whether the peer candidate's effective NAT
+// type is a preferred NAT matching for this candidate's effective NAT type.
+func (p *MatchProperties) IsPreferredNATMatch(
+	peerMatchProperties *MatchProperties) bool {
+
+	ownNATType := p.EffectiveNATType()
+	peerNATType := peerMatchProperties.EffectiveNATType()
+
+	return ownNATType.IsPreferredMatch(peerNATType)
+}
+
+// IsPersonalCompartmentalized reports whether the candidate supplied at
+// least one personal compartment ID.
+func (p *MatchProperties) IsPersonalCompartmentalized() bool {
+	personalCount := len(p.PersonalCompartmentIDs)
+	return personalCount != 0
+}
+
+// MatchAnnouncement is a proxy announcement to be queued for matching. The
+// ProxyID and ConnectionID together identify the pending answer that follows
+// a match.
+type MatchAnnouncement struct {
+	Properties           MatchProperties
+	ProxyID              ID
+	ConnectionID         ID
+	ProxyProtocolVersion int32
+}
+
+// MatchOffer is a client offer to be queued for matching. Only Properties
+// is consulted when selecting a match; the remaining fields are carried
+// along and delivered to the matched Announce caller.
+type MatchOffer struct {
+	Properties                  MatchProperties
+	ClientProxyProtocolVersion  int32
+	ClientOfferSDP              webrtc.SessionDescription
+	ClientRootObfuscationSecret ObfuscationSecret
+	NetworkProtocol             NetworkProtocol
+	DestinationAddress          string
+	DestinationServerID         string
+}
+
+// MatchAnswer is a proxy answer, the proxy's follow up to a matched
+// announcement, to be routed to the awaiting client offer. The ProxyID and
+// ConnectionID are used to look up the pending answer created at match time.
+type MatchAnswer struct {
+	ProxyIP                      string
+	ProxyID                      ID
+	ConnectionID                 ID
+	SelectedProxyProtocolVersion int32
+	ProxyAnswerSDP               webrtc.SessionDescription
+}
+
+// announcementEntry is an announcement queue entry, an announcement with its
+// associated lifetime context and signaling channel. An entry whose ctx is
+// done is treated as expired during matching. The matched offer is delivered
+// to the blocked Announce call via offerChan.
+type announcementEntry struct {
+	ctx          context.Context
+	announcement *MatchAnnouncement
+	offerChan    chan *MatchOffer
+}
+
+// offerEntry is an offer queue entry, an offer with its associated lifetime
+// context and signaling channel. An entry whose ctx is done is treated as
+// expired during matching. The proxy's answer, or a failure indication, is
+// delivered to the blocked Offer call via answerChan.
+type offerEntry struct {
+	ctx        context.Context
+	offer      *MatchOffer
+	answerChan chan *answerInfo
+}
+
+// answerInfo is an answer and its associated announcement. It is the value
+// sent over an offer entry's answerChan when the proxy's answer arrives.
+type answerInfo struct {
+	announcement *MatchAnnouncement
+	answer       *MatchAnswer
+}
+
+// pendingAnswer represents an answer that is expected to arrive from a
+// proxy. It records the matched announcement and the answerChan of the
+// matched offer, so the answer can be routed to the waiting Offer call.
+type pendingAnswer struct {
+	announcement *MatchAnnouncement
+	answerChan   chan *answerInfo
+}
+
+// MatcherConfig specifies the configuration for a matcher.
+type MatcherConfig struct {
+
+	// Logger is used to log events, including debug-level match metrics.
+	Logger common.Logger
+}
+
+// NewMatcher creates a new Matcher with empty queues. The Matcher does not
+// perform any matching until Start is called.
+func NewMatcher(config *MatcherConfig) *Matcher {
+
+	// matcherPendingAnswersTTL is not configurable; it supplies a default
+	// that is expected to be ignored when each entry's TTL is set to the
+	// Offer ctx timeout.
+
+	pendingAnswers := lrucache.NewWithLRU(
+		matcherPendingAnswersTTL,
+		1*time.Minute,
+		matcherPendingAnswersMaxSize)
+
+	matcher := &Matcher{
+		config:            config,
+		waitGroup:         new(sync.WaitGroup),
+		announcementQueue: deque.New[*announcementEntry](),
+		offerQueue:        deque.New[*offerEntry](),
+		matchSignal:       make(chan struct{}, 1),
+		pendingAnswers:    pendingAnswers,
+	}
+
+	return matcher
+}
+
+// Start launches the Matcher's worker goroutine, which matches announcements
+// and offers. Start returns an error when the Matcher is already running.
+func (m *Matcher) Start() error {
+
+	m.runMutex.Lock()
+	defer m.runMutex.Unlock()
+
+	if m.runContext != nil {
+		return errors.TraceNew("already running")
+	}
+
+	runContext, stopRunning := context.WithCancel(context.Background())
+	m.runContext = runContext
+	m.stopRunning = stopRunning
+
+	m.waitGroup.Add(1)
+	go func() {
+		defer m.waitGroup.Done()
+		m.matchWorker(runContext)
+	}()
+
+	return nil
+}
+
+// Stop stops running the Matcher and its worker goroutine. Stop blocks until
+// the worker goroutine has exited. Calling Stop on a Matcher that is not
+// running is a no-op.
+//
+// Limitation: Stop is not synchronized with Announce/Offer/Answer, so items
+// can get enqueued during and after a Stop call. Stop is intended more for a
+// full broker shutdown, where this won't be a concern.
+func (m *Matcher) Stop() {
+
+	m.runMutex.Lock()
+	defer m.runMutex.Unlock()
+
+	// stopRunning is nil before Start and after a previous Stop; invoking it
+	// in that state would panic.
+
+	if m.stopRunning == nil {
+		return
+	}
+
+	m.stopRunning()
+	m.waitGroup.Wait()
+	m.runContext, m.stopRunning = nil, nil
+}
+
+// Announce enqueues the proxy announcement and blocks until it is matched
+// with a returned offer or ctx is done. The caller must not mutate the
+// announcement or its properties after calling Announce.
+//
+// Announce returns an error immediately, without blocking, when the
+// announcement queue is full. When ctx is done first, the announcement is
+// withdrawn and the ctx error is returned.
+//
+// The offer is sent to the proxy by the broker, and then the proxy sends its
+// answer back to the broker, which calls Answer with that value.
+func (m *Matcher) Announce(
+	ctx context.Context,
+	proxyAnnouncement *MatchAnnouncement) (*MatchOffer, error) {
+
+	entry := &announcementEntry{
+		ctx:          ctx,
+		announcement: proxyAnnouncement,
+		offerChan:    make(chan *MatchOffer, 1),
+	}
+
+	// When the queue is full the entry is not enqueued and can never be
+	// matched, so fail now instead of blocking until ctx is done.
+
+	if !m.addAnnouncementEntry(entry) {
+		return nil, errors.TraceNew("announcement queue full")
+	}
+
+	// Await client offer.
+
+	var clientOffer *MatchOffer
+
+	select {
+	case <-ctx.Done():
+		m.removeAnnouncementEntry(entry)
+		return nil, errors.Trace(ctx.Err())
+
+	case clientOffer = <-entry.offerChan:
+	}
+
+	return clientOffer, nil
+}
+
+// Offer enqueues the client offer and blocks until it is matched with a
+// returned announcement or ctx is done. The caller must not mutate the offer
+// or its properties after calling Offer.
+//
+// Offer returns an error immediately, without blocking, when the offer queue
+// is full. When ctx is done first, the offer is withdrawn and the ctx error
+// is returned. A "no answer" error is returned when the match was made but
+// no usable proxy answer was delivered.
+//
+// The answer is returned to the client by the broker, and the WebRTC
+// connection is dialed. The original announcement is also returned, so its
+// match properties can be logged.
+func (m *Matcher) Offer(
+	ctx context.Context,
+	clientOffer *MatchOffer) (*MatchAnswer, *MatchAnnouncement, error) {
+
+	entry := &offerEntry{
+		ctx:        ctx,
+		offer:      clientOffer,
+		answerChan: make(chan *answerInfo, 1),
+	}
+
+	// When the queue is full the entry is not enqueued and can never be
+	// matched, so fail now instead of blocking until ctx is done.
+
+	if !m.addOfferEntry(entry) {
+		return nil, nil, errors.TraceNew("offer queue full")
+	}
+
+	// Await proxy answer.
+
+	var proxyAnswerInfo *answerInfo
+
+	select {
+	case <-ctx.Done():
+		m.removeOfferEntry(entry)
+
+		// TODO: also remove any pendingAnswers entry? The entry TTL is set to
+		// the Offer ctx, the client request, timeout, so it will eventually
+		// get removed. But a client may abort its request earlier than the
+		// timeout.
+
+		return nil, nil, errors.Trace(ctx.Err())
+	case proxyAnswerInfo = <-entry.answerChan:
+	}
+
+	if proxyAnswerInfo == nil {
+
+		// nil will be delivered to the channel when either the proxy
+		// announcement request concurrently timed out, or the answer
+		// indicated a proxy error, or the answer did not arrive in time.
+		return nil, nil, errors.TraceNew("no answer")
+	}
+
+	// This is a sanity check and not expected to fail.
+	if !proxyAnswerInfo.answer.ConnectionID.Equal(
+		proxyAnswerInfo.announcement.ConnectionID) {
+		return nil, nil, errors.TraceNew("unexpected connection ID")
+	}
+
+	return proxyAnswerInfo.answer, proxyAnswerInfo.announcement, nil
+}
+
+// Answer routes a proxy's answer to the Offer call awaiting it. The answer's
+// ProxyID and ConnectionID must be those of the original announcement, as
+// they form the lookup key for the pending answer. The caller must not
+// mutate the answer after calling Answer. Answer does not block.
+//
+// A "no client" error is returned when there is no pending answer for the
+// key, which means the client is no longer awaiting the response.
+func (m *Matcher) Answer(
+	proxyAnswer *MatchAnswer) error {
+
+	key := m.pendingAnswerKey(proxyAnswer.ProxyID, proxyAnswer.ConnectionID)
+
+	value, found := m.pendingAnswers.Get(key)
+	if !found {
+		// The client is no longer awaiting the response.
+		return errors.TraceNew("no client")
+	}
+
+	m.pendingAnswers.Delete(key)
+
+	pending := value.(*pendingAnswer)
+
+	info := &answerInfo{
+		announcement: pending.announcement,
+		answer:       proxyAnswer,
+	}
+	pending.answerChan <- info
+
+	return nil
+}
+
+// AnswerError signals to an awaiting Offer call that the proxy failed to
+// produce an answer. The proxyID and connectionID must be those of the
+// original announcement. When there is no pending answer for that pair,
+// AnswerError does nothing.
+//
+// The failure indication is returned to the awaiting Offer call and sent to
+// the matched client.
+func (m *Matcher) AnswerError(proxyID ID, connectionID ID) {
+
+	key := m.pendingAnswerKey(proxyID, connectionID)
+
+	value, found := m.pendingAnswers.Get(key)
+	if !found {
+		// The client is no longer awaiting the response.
+		return
+	}
+
+	m.pendingAnswers.Delete(key)
+
+	// Closing the channel delivers nil, a failed indicator, to any receiver.
+	pending := value.(*pendingAnswer)
+	close(pending.answerChan)
+}
+
+// matchWorker is the body of the matching worker goroutine. It sleeps until
+// matchSignal indicates that a queue item was added, runs a full matching
+// pass, and repeats. It returns when ctx is done.
+func (m *Matcher) matchWorker(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-m.matchSignal:
+			m.matchAllOffers()
+		}
+	}
+}
+
+// matchAllOffers iterates over the queues, making all possible matches. Both
+// queue mutexes are held for the duration of the pass. Offers whose ctx is
+// done are dropped from the offer queue; offers with no match are left in
+// place; matched offers and their announcements are removed from the queues.
+func (m *Matcher) matchAllOffers() {
+
+	m.announcementQueueMutex.Lock()
+	defer m.announcementQueueMutex.Unlock()
+	m.offerQueueMutex.Lock()
+	defer m.offerQueueMutex.Unlock()
+
+	// Take each offer in turn, and select an announcement match. There is an
+	// implicit preference for older client offers, sooner to timeout, at the
+	// front of the queue.
+
+	// TODO: consider matching one offer, then releasing the locks to allow
+	// more announcements to be enqueued, then continuing to match.
+
+	i := 0
+	end := m.offerQueue.Len()
+
+	for i < end && m.announcementQueue.Len() > 0 {
+
+		offerEntry := m.offerQueue.At(i)
+
+		// Skip and remove this offer if its deadline has already passed.
+		// There is no signal to the awaiting Offer function, as it will exit
+		// based on the same ctx. The index i is not advanced, since the next
+		// offer has shifted into position i.
+
+		if offerEntry.ctx.Err() != nil {
+			m.offerQueue.Remove(i)
+			end -= 1
+			continue
+		}
+
+		j, ok := m.matchOffer(offerEntry)
+		if !ok {
+
+			// No match, so leave this offer in place in the queue and move to
+			// the next.
+
+			i++
+			continue
+		}
+
+		if m.config.Logger.IsLogLevelDebug() {
+			m.config.Logger.WithTraceFields(common.LogFields{
+				"match_index":             j,
+				"offer_queue_size":        m.offerQueue.Len(),
+				"announcement_queue_size": m.announcementQueue.Len(),
+			}).Debug("match metrics")
+		}
+
+		// Remove the matched announcement from the queue. Send the offer to
+		// the announcement entry's offerChan, which will deliver it to the
+		// blocked Announce call. Add a pending answers entry to await the
+		// proxy's follow up Answer call. The TTL for the pending answer
+		// entry is set to the matched Offer call's ctx, as the answer is
+		// only useful as long as the client is still waiting.
+
+		announcementEntry := m.announcementQueue.At(j)
+
+		// NOTE(review): time.Until may yield a non-positive duration if the
+		// deadline passes after the ctx.Err check above; confirm how the
+		// cache's Set treats a non-positive TTL.
+
+		expiry := lrucache.DefaultExpiration
+		deadline, ok := offerEntry.ctx.Deadline()
+		if ok {
+			expiry = time.Until(deadline)
+		}
+
+		key := m.pendingAnswerKey(
+			announcementEntry.announcement.ProxyID,
+			announcementEntry.announcement.ConnectionID)
+
+		m.pendingAnswers.Set(
+			key,
+			&pendingAnswer{
+				announcement: announcementEntry.announcement,
+				answerChan:   offerEntry.answerChan,
+			},
+			expiry)
+
+		// offerChan is created with a buffer of one in Announce, and each
+		// entry is matched at most once, so this send does not block.
+
+		announcementEntry.offerChan <- offerEntry.offer
+
+		m.announcementQueue.Remove(j)
+		m.adjustAnnouncementCounts(announcementEntry, -1)
+
+		// Remove the matched offer from the queue and match the next offer,
+		// now first in the queue.
+
+		m.offerQueue.Remove(i)
+		end -= 1
+	}
+}
+
+// matchOffer selects the best announcement in the announcement queue for the
+// given offer. It returns the queue index of the selected announcement and
+// true, or -1 and false when no announcement matches.
+//
+// Announcements whose ctx is done are removed from the queue, and the
+// announcement counts are adjusted accordingly, as they are encountered.
+func (m *Matcher) matchOffer(offerEntry *offerEntry) (int, bool) {
+
+	// Assumes the caller has the queue mutexes locked.
+
+	// Check each announcement in turn, and select a match. There is an
+	// implicit preference for older proxy announcements, sooner to timeout,
+	// at the front of the queue.
+
+	// Future matching enhancements could include more sophisticated GeoIP
+	// rules, such as a configuration encoding knowledge of an ASN's NAT
+	// type, or preferred client/proxy country/ASN matches.
+
+	offerProperties := &offerEntry.offer.Properties
+
+	// Use the NAT traversal type counters to check if there's any preferred
+	// NAT match for this offer in the announcement queue. When there is, we
+	// will search beyond the first announcement.
+
+	existsPreferredNATMatch := offerProperties.ExistsPreferredNATMatch(
+		m.announcementsUnlimitedNATCount > 0,
+		m.announcementsPartiallyLimitedNATCount > 0,
+		m.announcementsStrictlyLimitedNATCount > 0)
+
+	bestMatch := -1
+	bestMatchNAT := false
+	bestMatchCompartment := false
+
+	// next is the queue index of the next announcement to examine. It is
+	// advanced only when the examined entry stays in the queue, so that an
+	// entry shifted into a removed slot is not skipped. Removals only happen
+	// at or beyond next, so a previously recorded bestMatch index stays
+	// valid.
+
+	next := 0
+
+	for next < m.announcementQueue.Len() {
+
+		announcementEntry := m.announcementQueue.At(next)
+
+		// Skip and remove this announcement if its deadline has already
+		// passed. There is no signal to the awaiting Announce function, as
+		// it will exit based on the same ctx. The counts must be adjusted
+		// here, as the awaiting Announce call will no longer find the entry
+		// in the queue.
+
+		if announcementEntry.ctx.Err() != nil {
+			m.announcementQueue.Remove(next)
+			m.adjustAnnouncementCounts(announcementEntry, -1)
+			continue
+		}
+
+		i := next
+		next++
+
+		announcementProperties := &announcementEntry.announcement.Properties
+
+		// Disallow matching the same country and ASN
+
+		if offerProperties.GeoIPData.Country ==
+			announcementProperties.GeoIPData.Country &&
+			offerProperties.GeoIPData.ASN ==
+				announcementProperties.GeoIPData.ASN {
+			continue
+		}
+
+		// There must be a compartment match. If there is a personal
+		// compartment match, this match will be preferred.
+
+		matchCommonCompartment := HaveCommonIDs(
+			announcementProperties.CommonCompartmentIDs, offerProperties.CommonCompartmentIDs)
+		matchPersonalCompartment := HaveCommonIDs(
+			announcementProperties.PersonalCompartmentIDs, offerProperties.PersonalCompartmentIDs)
+		if !matchCommonCompartment && !matchPersonalCompartment {
+			continue
+		}
+
+		// Check if this is a preferred NAT match. Ultimately, a match may be
+		// made with potentially incompatible NATs, but the client/proxy
+		// reported NAT types may be incorrect or unknown; the client will
+		// often skip NAT discovery.
+
+		matchNAT := offerProperties.IsPreferredNATMatch(announcementProperties)
+
+		// At this point, the candidate is a match. Determine if this is a new
+		// best match.
+
+		if bestMatch == -1 {
+
+			// This is a match, and there was no previous match, so it becomes
+			// the provisional best match.
+
+			bestMatch = i
+			bestMatchNAT = matchNAT
+			bestMatchCompartment = matchPersonalCompartment
+
+		} else if !bestMatchNAT && matchNAT {
+
+			// If there was a previous best match which was not a preferred
+			// NAT match, this becomes the new best match. The preferred NAT
+			// match is prioritized over personal compartment matching.
+
+			bestMatch = i
+			bestMatchNAT = true
+			bestMatchCompartment = matchPersonalCompartment
+
+		} else if !bestMatchCompartment && matchPersonalCompartment && (!bestMatchNAT || matchNAT) {
+
+			// If there was a previous best match which was not a personal
+			// compartment match, and as long as this match doesn't undo a
+			// better NAT match, this becomes the new best match.
+
+			bestMatch = i
+			bestMatchNAT = matchNAT
+			bestMatchCompartment = true
+		}
+
+		// Stop as soon as we have the best possible match. The check is made
+		// against the properties of the selected best match, not those of
+		// the current candidate, which may not have been selected.
+
+		if (bestMatchNAT || !existsPreferredNATMatch) &&
+			(bestMatchCompartment || m.announcementsPersonalCompartmentalizedCount == 0) {
+			break
+		}
+	}
+
+	return bestMatch, bestMatch != -1
+}
+
+// addAnnouncementEntry appends the entry to the announcement queue, updates
+// the announcement counts, and signals the match worker. It returns false,
+// without enqueuing, when the queue is already at its maximum size.
+func (m *Matcher) addAnnouncementEntry(announcementEntry *announcementEntry) bool {
+
+	m.announcementQueueMutex.Lock()
+	defer m.announcementQueueMutex.Unlock()
+
+	queueIsFull := m.announcementQueue.Len() >= matcherAnnouncementQueueMaxSize
+	if queueIsFull {
+		return false
+	}
+
+	m.announcementQueue.PushBack(announcementEntry)
+	m.adjustAnnouncementCounts(announcementEntry, 1)
+
+	// Non-blocking send: when a signal is already buffered, the worker will
+	// run a matching pass anyway.
+	select {
+	case m.matchSignal <- struct{}{}:
+	default:
+	}
+
+	return true
+}
+
+// removeAnnouncementEntry takes the entry back out of the announcement
+// queue and updates the announcement counts. When the entry is no longer in
+// the queue, any pending answer for the announcement is removed and the
+// waiting Offer is sent a failure signal.
+func (m *Matcher) removeAnnouncementEntry(announcementEntry *announcementEntry) {
+
+	m.announcementQueueMutex.Lock()
+	defer m.announcementQueueMutex.Unlock()
+
+	for index := 0; index < m.announcementQueue.Len(); index++ {
+		if m.announcementQueue.At(index) != announcementEntry {
+			continue
+		}
+		m.announcementQueue.Remove(index)
+		m.adjustAnnouncementCounts(announcementEntry, -1)
+		return
+	}
+
+	// The Announce call is aborting and taking its entry back out of the
+	// queue. If the entry is not found in the queue, then a concurrent
+	// Offer has matched the announcement. So check for the pending
+	// answer corresponding to the announcement and remove it and deliver
+	// a failure signal to the waiting Offer, so the client doesn't wait
+	// longer than necessary.
+
+	key := m.pendingAnswerKey(
+		announcementEntry.announcement.ProxyID,
+		announcementEntry.announcement.ConnectionID)
+
+	value, found := m.pendingAnswers.Get(key)
+	if !found {
+		return
+	}
+
+	close(value.(*pendingAnswer).answerChan)
+	m.pendingAnswers.Delete(key)
+}
+
+func (m *Matcher) adjustAnnouncementCounts(
+	announcementEntry *announcementEntry, delta int) {
+
+	// Assumes s.announcementQueueMutex lock is held.
+
+	if announcementEntry.announcement.Properties.IsPersonalCompartmentalized() {
+		m.announcementsPersonalCompartmentalizedCount += delta
+	}
+
+	switch announcementEntry.announcement.Properties.EffectiveNATType().Traversal() {
+	case NATTraversalUnlimited:
+		m.announcementsUnlimitedNATCount += delta
+	case NATTraversalPartiallyLimited:
+		m.announcementsPartiallyLimitedNATCount += delta
+	case NATTraversalStrictlyLimited:
+		m.announcementsStrictlyLimitedNATCount += delta
+	}
+}
+
+// addOfferEntry appends the entry to the offer queue and signals the match
+// worker. It returns false, without enqueuing, when the queue is already at
+// its maximum size.
+func (m *Matcher) addOfferEntry(offerEntry *offerEntry) bool {
+
+	m.offerQueueMutex.Lock()
+	defer m.offerQueueMutex.Unlock()
+
+	queueIsFull := m.offerQueue.Len() >= matcherOfferQueueMaxSize
+	if queueIsFull {
+		return false
+	}
+
+	m.offerQueue.PushBack(offerEntry)
+
+	// Non-blocking send: when a signal is already buffered, the worker will
+	// run a matching pass anyway.
+	select {
+	case m.matchSignal <- struct{}{}:
+	default:
+	}
+
+	return true
+}
+
+// removeOfferEntry takes the entry back out of the offer queue. It does
+// nothing when the entry is no longer in the queue.
+func (m *Matcher) removeOfferEntry(offerEntry *offerEntry) {
+
+	m.offerQueueMutex.Lock()
+	defer m.offerQueueMutex.Unlock()
+
+	for index := 0; index < m.offerQueue.Len(); index++ {
+		if m.offerQueue.At(index) != offerEntry {
+			continue
+		}
+		m.offerQueue.Remove(index)
+		return
+	}
+}
+
+// pendingAnswerKey builds the pending answers lookup key by concatenating
+// the bytes of the proxy ID and the connection ID.
+func (m *Matcher) pendingAnswerKey(proxyID ID, connectionID ID) string {
+
+	// The pending answer lookup key is used to associate announcements and
+	// subsequent answers. While the client learns the ConnectionID, only the
+	// proxy knows the ProxyID component, so only the correct proxy can match
+	// an answer to an announcement. The ConnectionID component is necessary
+	// as a proxy may have multiple, concurrent pending answers.
+
+	proxyComponent := string(proxyID[:])
+	connectionComponent := string(connectionID[:])
+
+	return proxyComponent + connectionComponent
+}

+ 458 - 0
psiphon/common/inproxy/matcher_test.go

@@ -0,0 +1,458 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+)
+
+// TestMatcher runs the matcher test scenarios and reports any failure.
+func TestMatcher(t *testing.T) {
+	err := runTestMatcher()
+	if err != nil {
+		// Use Error, not Errorf: the message is not a format string and may
+		// contain '%' characters.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+func runTestMatcher() error {
+
+	logger := newTestLogger()
+
+	m := NewMatcher(
+		&MatcherConfig{
+			Logger: logger,
+		})
+	err := m.Start()
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer m.Stop()
+
+	makeID := func() ID {
+		ID, err := MakeID()
+		if err != nil {
+			panic(err)
+		}
+		return ID
+	}
+
+	makeAnnouncement := func(properties *MatchProperties) *MatchAnnouncement {
+		return &MatchAnnouncement{
+			Properties:   *properties,
+			ProxyID:      makeID(),
+			ConnectionID: makeID(),
+		}
+	}
+
+	makeOffer := func(properties *MatchProperties) *MatchOffer {
+		return &MatchOffer{
+			Properties:                 *properties,
+			ClientProxyProtocolVersion: ProxyProtocolVersion1,
+		}
+	}
+
+	proxyFunc := func(
+		resultChan chan error,
+		matchProperties *MatchProperties,
+		timeout time.Duration,
+		waitBeforeAnswer chan struct{},
+		answerSuccess bool) {
+
+		ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
+		defer cancelFunc()
+
+		announcement := makeAnnouncement(matchProperties)
+		offer, err := m.Announce(ctx, announcement)
+		if err != nil {
+			resultChan <- errors.Trace(err)
+			return
+		}
+
+		if waitBeforeAnswer != nil {
+			<-waitBeforeAnswer
+		}
+
+		if answerSuccess {
+			err = m.Answer(
+				&MatchAnswer{
+					ProxyID:                      announcement.ProxyID,
+					ConnectionID:                 announcement.ConnectionID,
+					SelectedProxyProtocolVersion: offer.ClientProxyProtocolVersion,
+				})
+		} else {
+			m.AnswerError(announcement.ProxyID, announcement.ConnectionID)
+		}
+		resultChan <- errors.Trace(err)
+	}
+
+	clientFunc := func(
+		resultChan chan error,
+		matchProperties *MatchProperties,
+		timeout time.Duration) {
+
+		ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
+		defer cancelFunc()
+
+		offer := makeOffer(matchProperties)
+		answer, _, err := m.Offer(ctx, offer)
+		if err != nil {
+			resultChan <- errors.Trace(err)
+			return
+		}
+		if answer.SelectedProxyProtocolVersion != offer.ClientProxyProtocolVersion {
+			resultChan <- errors.TraceNew("unexpected selected proxy protocol version")
+			return
+		}
+		resultChan <- nil
+	}
+
+	// Test: announce timeout
+
+	proxyResultChan := make(chan error)
+
+	go proxyFunc(proxyResultChan, &MatchProperties{}, 1*time.Microsecond, nil, true)
+
+	err = <-proxyResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+	if m.announcementQueue.Len() != 0 {
+		return errors.TraceNew("unexpected queue size")
+	}
+
+	// Test: offer timeout
+
+	clientResultChan := make(chan error)
+
+	go clientFunc(clientResultChan, &MatchProperties{}, 1*time.Microsecond)
+
+	err = <-clientResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+	if m.offerQueue.Len() != 0 {
+		return errors.TraceNew("unexpected queue size")
+	}
+
+	// Test: basic match
+
+	basicCommonCompartmentIDs := []ID{makeID()}
+
+	geoIPData1 := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C1", ASN: "A1"},
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	geoIPData2 := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C2", ASN: "A2"},
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	go proxyFunc(proxyResultChan, geoIPData1, 10*time.Millisecond, nil, true)
+	go clientFunc(clientResultChan, geoIPData2, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-clientResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: answer error
+
+	go proxyFunc(proxyResultChan, geoIPData1, 10*time.Millisecond, nil, false)
+	go clientFunc(clientResultChan, geoIPData2, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-clientResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "no answer") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// Test: client is gone
+
+	waitBeforeAnswer := make(chan struct{})
+
+	go proxyFunc(proxyResultChan, geoIPData1, 100*time.Millisecond, waitBeforeAnswer, true)
+	go clientFunc(clientResultChan, geoIPData2, 10*time.Millisecond)
+
+	err = <-clientResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	close(waitBeforeAnswer)
+
+	err = <-proxyResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "no client") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// Test: no compartment match
+
+	compartment1 := &MatchProperties{
+		GeoIPData:              geoIPData1.GeoIPData,
+		CommonCompartmentIDs:   []ID{makeID()},
+		PersonalCompartmentIDs: []ID{makeID()},
+	}
+
+	compartment2 := &MatchProperties{
+		GeoIPData:              geoIPData2.GeoIPData,
+		CommonCompartmentIDs:   []ID{makeID()},
+		PersonalCompartmentIDs: []ID{makeID()},
+	}
+
+	go proxyFunc(proxyResultChan, compartment1, 10*time.Millisecond, nil, true)
+	go clientFunc(clientResultChan, compartment2, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	err = <-clientResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// Test: common compartment match
+
+	compartment1And2 := &MatchProperties{
+		GeoIPData:            geoIPData2.GeoIPData,
+		CommonCompartmentIDs: []ID{compartment1.CommonCompartmentIDs[0], compartment2.CommonCompartmentIDs[0]},
+	}
+
+	go proxyFunc(proxyResultChan, compartment1, 10*time.Millisecond, nil, true)
+	go clientFunc(clientResultChan, compartment1And2, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-clientResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: personal compartment match
+
+	compartment1And2 = &MatchProperties{
+		GeoIPData:              geoIPData2.GeoIPData,
+		PersonalCompartmentIDs: []ID{compartment1.PersonalCompartmentIDs[0], compartment2.PersonalCompartmentIDs[0]},
+	}
+
+	go proxyFunc(proxyResultChan, compartment1, 10*time.Millisecond, nil, true)
+	go clientFunc(clientResultChan, compartment1And2, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-clientResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: personal compartment preferred match
+
+	compartment1Common := &MatchProperties{
+		GeoIPData:            geoIPData1.GeoIPData,
+		CommonCompartmentIDs: []ID{compartment1.CommonCompartmentIDs[0]},
+	}
+
+	compartment1Personal := &MatchProperties{
+		GeoIPData:              geoIPData1.GeoIPData,
+		PersonalCompartmentIDs: []ID{compartment1.PersonalCompartmentIDs[0]},
+	}
+
+	compartment1CommonAndPersonal := &MatchProperties{
+		GeoIPData:              geoIPData2.GeoIPData,
+		CommonCompartmentIDs:   []ID{compartment1.CommonCompartmentIDs[0]},
+		PersonalCompartmentIDs: []ID{compartment1.PersonalCompartmentIDs[0]},
+	}
+
+	client1ResultChan := make(chan error)
+	client2ResultChan := make(chan error)
+
+	proxy1ResultChan := make(chan error)
+	proxy2ResultChan := make(chan error)
+
+	go proxyFunc(proxy1ResultChan, compartment1Common, 10*time.Millisecond, nil, true)
+	go proxyFunc(proxy2ResultChan, compartment1Personal, 10*time.Millisecond, nil, true)
+	time.Sleep(5 * time.Millisecond) // Hack to ensure both proxies are enqueued
+	go clientFunc(client1ResultChan, compartment1CommonAndPersonal, 10*time.Millisecond)
+
+	err = <-proxy1ResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// proxy2 should match since it has the preferred personal compartment ID
+	err = <-proxy2ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-client1ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: no same-ASN match
+
+	go proxyFunc(proxyResultChan, geoIPData1, 10*time.Millisecond, nil, true)
+	go clientFunc(clientResultChan, geoIPData1, 10*time.Millisecond)
+
+	err = <-proxyResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	err = <-clientResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// Test: proxy preferred NAT match
+
+	client1Properties := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C1", ASN: "A1"},
+		NATType:              NATTypeFullCone,
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	client2Properties := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C2", ASN: "A2"},
+		NATType:              NATTypeSymmetric,
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	proxy1Properties := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C3", ASN: "A3"},
+		NATType:              NATTypeNone,
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	proxy2Properties := &MatchProperties{
+		GeoIPData:            common.GeoIPData{Country: "C4", ASN: "A4"},
+		NATType:              NATTypeSymmetric,
+		CommonCompartmentIDs: basicCommonCompartmentIDs,
+	}
+
+	go proxyFunc(proxy1ResultChan, proxy1Properties, 10*time.Millisecond, nil, true)
+	go proxyFunc(proxy2ResultChan, proxy2Properties, 10*time.Millisecond, nil, true)
+	time.Sleep(5 * time.Millisecond) // Hack to ensure both proxies are enqueued
+	go clientFunc(client1ResultChan, client1Properties, 10*time.Millisecond)
+
+	err = <-proxy1ResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// proxy2 should match since it's the preferred NAT match
+	err = <-proxy2ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-client1ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: client preferred NAT match
+
+	go proxyFunc(client1ResultChan, client1Properties, 10*time.Millisecond, nil, true)
+	go proxyFunc(client2ResultChan, client2Properties, 10*time.Millisecond, nil, true)
+	time.Sleep(500 * time.Microsecond) // Hack to ensure both clients are enqueued
+	go clientFunc(proxy1ResultChan, proxy1Properties, 10*time.Millisecond)
+
+	err = <-proxy1ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = <-client1ResultChan
+	if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// client2 should match since it's the preferred NAT match
+	err = <-client2ResultChan
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: many matches
+
+	// Reduce test log noise for this phase of the test
+	logger.SetLogLevelDebug(false)
+
+	matchCount := 10000
+	proxyCount := matchCount
+	clientCount := matchCount
+
+	// Buffered so no goroutine will block reporting result
+	proxyResultChan = make(chan error, matchCount)
+	clientResultChan = make(chan error, matchCount)
+
+	for proxyCount > 0 || clientCount > 0 {
+
+		// Don't simply alternate enqueuing a proxy and a client
+		if proxyCount > 0 && (clientCount == 0 || prng.FlipCoin()) {
+			go proxyFunc(proxyResultChan, geoIPData1, 10*time.Second, nil, true)
+			proxyCount -= 1
+
+		} else if clientCount > 0 {
+			go clientFunc(clientResultChan, geoIPData2, 10*time.Second)
+			clientCount -= 1
+		}
+	}
+
+	for i := 0; i < matchCount; i++ {
+		err = <-proxyResultChan
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		err = <-clientResultChan
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	return nil
+}

+ 406 - 0
psiphon/common/inproxy/nat.go

@@ -0,0 +1,406 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"fmt"
+	"strings"
+)
+
+// NATMapping identifies one of the NAT mapping behaviors defined in RFC
+// 4787, section 4.1.
+type NATMapping int32
+
+const (
+	NATMappingUnknown NATMapping = iota
+	NATMappingEndpointIndependent
+	NATMappingAddressDependent
+	NATMappingAddressPortDependent
+)
+
+// String returns the name of the mapping behavior. An empty string is
+// returned for any value outside the defined set.
+func (m NATMapping) String() string {
+	if m == NATMappingUnknown {
+		return "MappingUnknown"
+	}
+	if m == NATMappingEndpointIndependent {
+		return "MappingEndpointIndependent"
+	}
+	if m == NATMappingAddressDependent {
+		return "MappingAddressDependent"
+	}
+	if m == NATMappingAddressPortDependent {
+		return "MappingAddressPortDependent"
+	}
+	return ""
+}
+
+// MarshalText emits the String form of the value, so that the name, rather
+// than the number, is logged in JSON.
+func (m NATMapping) MarshalText() ([]byte, error) {
+	name := m.String()
+	return []byte(name), nil
+}
+
+// IsValid reports whether m is one of the defined NATMapping values; these
+// are exactly the values for which String returns a non-empty name.
+func (m NATMapping) IsValid() bool {
+	return len(m.String()) > 0
+}
+
+// NATFiltering identifies one of the NAT filtering behaviors defined in RFC
+// 4787, section 5.
+type NATFiltering int32
+
+const (
+	NATFilteringUnknown NATFiltering = iota
+	NATFilteringEndpointIndependent
+	NATFilteringAddressDependent
+	NATFilteringAddressPortDependent
+)
+
+// String returns the name of the filtering behavior. An empty string is
+// returned for any value outside the defined set.
+func (f NATFiltering) String() string {
+	name := ""
+	switch f {
+	case NATFilteringUnknown:
+		name = "FilteringUnknown"
+	case NATFilteringEndpointIndependent:
+		name = "FilteringEndpointIndependent"
+	case NATFilteringAddressDependent:
+		name = "FilteringAddressDependent"
+	case NATFilteringAddressPortDependent:
+		name = "FilteringAddressPortDependent"
+	}
+	return name
+}
+
+// MarshalText emits the String form of the value, so that the name, rather
+// than the number, is logged in JSON.
+func (f NATFiltering) MarshalText() ([]byte, error) {
+	name := f.String()
+	return []byte(name), nil
+}
+
+// IsValid reports whether f is one of the defined NATFiltering values; these
+// are exactly the values for which String returns a non-empty name.
+func (f NATFiltering) IsValid() bool {
+	return len(f.String()) > 0
+}
+
+// NATType describes a network's NAT behavior as a pair of components: a
+// NATMapping and a NATFiltering, packed into a single integer.
+type NATType int32
+
+// MakeNATType packs the given mapping and filtering components into a
+// NATType. The filtering value occupies the low two bits and the mapping
+// value is stored above it.
+func MakeNATType(mapping NATMapping, filtering NATFiltering) NATType {
+	mappingBits := NATType(mapping) << 2
+	filteringBits := NATType(filtering)
+	return mappingBits | filteringBits
+}
+
+var (
+	NATTypeUnknown = MakeNATType(NATMappingUnknown, NATFilteringUnknown)
+
+	// NATTypePortMapping is a pseudo NATType, used in matching, that
+	// represents the relevant NAT behavior of a port mapping (e.g., UPnP-IGD).
+	NATTypePortMapping = MakeNATType(NATMappingEndpointIndependent, NATFilteringEndpointIndependent)
+
+	// NATTypeMobileNetwork is a pseudo NATType, used in matching, that
+	// represents the assumed and relevant NAT behavior of clients on mobile
+	// networks, presumed to be behind CGNAT when they report NATTypeUnknown.
+	NATTypeMobileNetwork = MakeNATType(NATMappingAddressPortDependent, NATFilteringAddressPortDependent)
+
+	// NATTypeNone and the NATType values that follow are used in testing.
+	// They are not entirely precise (a symmetric NAT may have a different
+	// mix of mapping and filtering values). The matching logic does not use
+	// specific NAT type definitions and instead considers the reported
+	// mapping and filtering values.
+	NATTypeNone               = MakeNATType(NATMappingEndpointIndependent, NATFilteringEndpointIndependent)
+	NATTypeFullCone           = MakeNATType(NATMappingEndpointIndependent, NATFilteringEndpointIndependent)
+	NATTypeRestrictedCone     = MakeNATType(NATMappingEndpointIndependent, NATFilteringAddressDependent)
+	NATTypePortRestrictedCone = MakeNATType(NATMappingEndpointIndependent, NATFilteringAddressPortDependent)
+	NATTypeSymmetric          = MakeNATType(NATMappingAddressPortDependent, NATFilteringAddressPortDependent)
+)
+
+// NeedsDiscovery reports whether the NATType is NATTypeUnknown, in which
+// case the NAT type should be discovered.
+func (t NATType) NeedsDiscovery() bool {
+	return NATTypeUnknown == t
+}
+
+// Mapping returns the NATMapping component packed into this NATType.
+func (t NATType) Mapping() NATMapping {
+	mappingBits := t >> 2
+	return NATMapping(mappingBits)
+}
+
+// Filtering returns the NATFiltering component packed into this NATType.
+func (t NATType) Filtering() NATFiltering {
+	filteringBits := t & 0x3
+	return NATFiltering(filteringBits)
+}
+
+// Traversal returns the NATTraversal classification of this NATType, as
+// computed by MakeTraversal.
+func (t NATType) Traversal() NATTraversal {
+	return MakeTraversal(t)
+}
+
+// Compatible reports whether the NATTraversal of this NATType is compatible
+// with the NATTraversal of t1.
+func (t NATType) Compatible(t1 NATType) bool {
+	local, peer := t.Traversal(), t1.Traversal()
+	return local.Compatible(peer)
+}
+
+// IsPreferredMatch reports whether the NATTraversal of the peer NATType, t1,
+// is a preferred match for the NATTraversal of this NATType.
+func (t NATType) IsPreferredMatch(t1 NATType) bool {
+	local, peer := t.Traversal(), t1.Traversal()
+	return local.IsPreferredMatch(peer)
+}
+
+// ExistsPreferredMatch reports whether a preferred match exists for the
+// NATTraversal of this NATType, given which classes of candidates are
+// available.
+func (t NATType) ExistsPreferredMatch(unlimited, partiallyLimited, limited bool) bool {
+	traversal := t.Traversal()
+	return traversal.ExistsPreferredMatch(unlimited, partiallyLimited, limited)
+}
+
+// String returns the names of the mapping and filtering components, joined
+// with a "/".
+func (t NATType) String() string {
+	mappingName := t.Mapping().String()
+	filteringName := t.Filtering().String()
+	return fmt.Sprintf("%s/%s", mappingName, filteringName)
+}
+
+// MarshalText emits the String form of the value, so that the names, rather
+// than the number, are logged in JSON.
+func (t NATType) MarshalText() ([]byte, error) {
+	name := t.String()
+	return []byte(name), nil
+}
+
+// IsValid reports whether both the mapping and the filtering components of
+// the NATType are valid.
+func (t NATType) IsValid() bool {
+	if !t.Mapping().IsValid() {
+		return false
+	}
+	return t.Filtering().IsValid()
+}
+
+// NATTraversal classifies the NAT traversal potential for a NATType. Two
+// NATTypes are considered compatible -- that is, a connection between the
+// corresponding networks can be established via STUN hole punching -- based
+// on their respective NATTraversal classifications.
+type NATTraversal int32
+
+const (
+	NATTraversalUnlimited NATTraversal = iota
+	NATTraversalPartiallyLimited
+	NATTraversalStrictlyLimited
+)
+
+// MakeTraversal returns the NATTraversal classification for the given
+// NATType.
+func MakeTraversal(t NATType) NATTraversal {
+
+	// Any mapping other than endpoint independent is, e.g., symmetric; or
+	// unknown -- where we assume the worst case.
+	if t.Mapping() != NATMappingEndpointIndependent {
+		return NATTraversalStrictlyLimited
+	}
+
+	// NAT type is, e.g., port restricted cone.
+	if t.Filtering() == NATFilteringAddressPortDependent {
+		return NATTraversalPartiallyLimited
+	}
+
+	// NAT type is, e.g., none, full cone, or restricted cone.
+	return NATTraversalUnlimited
+}
+
+// Compatible reports whether this NATTraversal and the peer NATTraversal,
+// t1, are compatible.
+func (t NATTraversal) Compatible(t1 NATTraversal) bool {
+
+	// See the NAT compatibility matrix here:
+	// https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/wikis/NAT-matching#nat-compatibility
+
+	if t == NATTraversalUnlimited {
+		// An unlimited t is compatible with any t1.
+		return true
+	}
+	if t == NATTraversalPartiallyLimited {
+		// A partially limited t needs an unlimited or partially limited t1.
+		return t1 != NATTraversalStrictlyLimited
+	}
+	if t == NATTraversalStrictlyLimited {
+		// A strictly limited t needs an unlimited t1.
+		return t1 == NATTraversalUnlimited
+	}
+	return false
+}
+
+// IsPreferredMatch reports whether the peer NATTraversal, t1, is a preferred
+// match for this NATTraversal. A match is preferred, and so prioritized,
+// when one of the two NATTraversals is more limited, but the pair is still
+// compatible. This preference attempts to reserve less limited match
+// candidates for those peers that need them.
+func (t NATTraversal) IsPreferredMatch(t1 NATTraversal) bool {
+	if t == NATTraversalUnlimited {
+		// Prefer matching unlimited peers with strictly limited peers.
+		// TODO: prefer matching unlimited with partially limited?
+		return t1 == NATTraversalStrictlyLimited
+	}
+	if t == NATTraversalPartiallyLimited {
+		// Prefer matching partially limited peers with unlimited or other
+		// partially limited peers.
+		return t1 == NATTraversalUnlimited || t1 == NATTraversalPartiallyLimited
+	}
+	if t == NATTraversalStrictlyLimited {
+		// Prefer matching strictly limited peers with unlimited peers.
+		return t1 == NATTraversalUnlimited
+	}
+	return false
+}
+
+// ExistsPreferredMatch reports whether a preferred match exists for this
+// NATTraversal, given whether unlimited, partiallyLimited, and
+// strictlyLimited candidates are available.
+func (t NATTraversal) ExistsPreferredMatch(unlimited, partiallyLimited, strictlyLimited bool) bool {
+	if t == NATTraversalUnlimited {
+		return strictlyLimited
+	}
+	if t == NATTraversalPartiallyLimited {
+		return unlimited || partiallyLimited
+	}
+	if t == NATTraversalStrictlyLimited {
+		return unlimited
+	}
+	return false
+}
+
+// PortMappingType identifies a port mapping protocol supported by a
+// network. Values include UPnP-IGD, NAT-PMP, and PCP.
+type PortMappingType int32
+
+const (
+	PortMappingTypeNone PortMappingType = iota
+	PortMappingTypeUPnP
+	PortMappingTypePMP
+	PortMappingTypePCP
+)
+
+// String returns the name of the port mapping type. An empty string is
+// returned for any value outside the defined set.
+func (t PortMappingType) String() string {
+	names := [...]string{
+		PortMappingTypeNone: "None",
+		PortMappingTypeUPnP: "UPnP-IGD",
+		PortMappingTypePMP:  "PMP",
+		PortMappingTypePCP:  "PCP",
+	}
+	if t < 0 || int(t) >= len(names) {
+		return ""
+	}
+	return names[t]
+}
+
+// MarshalText emits the String form of the value, so that the name, rather
+// than the number, is logged in JSON.
+func (t PortMappingType) MarshalText() ([]byte, error) {
+	name := t.String()
+	return []byte(name), nil
+}
+
+// IsValid reports whether t is one of the defined PortMappingType values;
+// these are exactly the values for which String returns a non-empty name.
+func (t PortMappingType) IsValid() bool {
+	return len(t.String()) > 0
+}
+
+// PortMappingTypes is a list of the port mapping protocols supported by a
+// network.
+type PortMappingTypes []PortMappingType
+
+// NeedsDiscovery reports whether the list of port mapping types is empty,
+// in which case the types should be discovered. If a network has no
+// supported port mapping types, its list will include PortMappingTypeNone.
+func (t PortMappingTypes) NeedsDiscovery() bool {
+	return 0 == len(t)
+}
+
+// Available reports whether at least one entry in the list is a port
+// mapping protocol, i.e., is greater than PortMappingTypeNone.
+func (t PortMappingTypes) Available() bool {
+	for i := range t {
+		if t[i] > PortMappingTypeNone {
+			return true
+		}
+	}
+	return false
+}
+
+// String returns the names of the entries, in order, separated by commas.
+func (t PortMappingTypes) String() string {
+	names := make([]string, 0, len(t))
+	for _, portMappingType := range t {
+		names = append(names, portMappingType.String())
+	}
+	return strings.Join(names, ",")
+}
+
+// MarshalText emits the String form of the value, so that the names, rather
+// than the numbers, are logged in JSON.
+func (t PortMappingTypes) MarshalText() ([]byte, error) {
+	names := t.String()
+	return []byte(names), nil
+}
+
+// IsValid reports whether every entry in the list is a valid
+// PortMappingType. An empty list is valid.
+func (t PortMappingTypes) IsValid() bool {
+	for i := range t {
+		if !t[i].IsValid() {
+			return false
+		}
+	}
+	return true
+}
+
+// ICECandidateType identifies an ICE candidate type: host for public
+// addresses, port mapping for when a port mapping protocol was used to
+// establish a public address, or server reflexive when STUN hole punching
+// was used to create a public address.
+type ICECandidateType int32
+
+const (
+	ICECandidateHost ICECandidateType = iota
+	ICECandidatePortMapping
+	ICECandidateServerReflexive
+)
+
+// String returns the name of the candidate type. An empty string is
+// returned for any value outside the defined set.
+func (t ICECandidateType) String() string {
+	switch {
+	case t == ICECandidateHost:
+		return "Host"
+	case t == ICECandidatePortMapping:
+		return "PortMapping"
+	case t == ICECandidateServerReflexive:
+		return "ServerReflexive"
+	default:
+		return ""
+	}
+}
+
+// MarshalText emits the String form of the value, so that the name, rather
+// than the number, is logged in JSON.
+func (t ICECandidateType) MarshalText() ([]byte, error) {
+	name := t.String()
+	return []byte(name), nil
+}
+
+// IsValid reports whether t is one of the defined ICECandidateType values;
+// these are exactly the values for which String returns a non-empty name.
+func (t ICECandidateType) IsValid() bool {
+	return len(t.String()) > 0
+}
+
+// ICECandidateTypes is a list of ICE candidate types.
+type ICECandidateTypes []ICECandidateType
+
+// String returns the names of the entries, in order, separated by commas.
+func (t ICECandidateTypes) String() string {
+	names := make([]string, 0, len(t))
+	for _, candidateType := range t {
+		names = append(names, candidateType.String())
+	}
+	return strings.Join(names, ",")
+}
+
+// MarshalText emits the String form of the value, so that the names, rather
+// than the numbers, are logged in JSON.
+func (t ICECandidateTypes) MarshalText() ([]byte, error) {
+	names := t.String()
+	return []byte(names), nil
+}
+
+// IsValid reports whether every entry in the list is a valid
+// ICECandidateType. An empty list is valid.
+func (t ICECandidateTypes) IsValid() bool {
+	for i := range t {
+		if !t[i].IsValid() {
+			return false
+		}
+	}
+	return true
+}

+ 400 - 0
psiphon/common/inproxy/obfuscation.go

@@ -0,0 +1,400 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/binary"
+	"io"
+	"sync"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/panmari/cuckoofilter"
+	"golang.org/x/crypto/hkdf"
+)
+
+// Parameters for the session packet obfuscation layer:
+//
+//   - obfuscationSessionPacketNonceSize is the size, in bytes, of the random
+//     nonce that prefixes every obfuscated session packet; the same bytes are
+//     used as the AES-GCM nonce and as the replay history key.
+//   - obfuscationAntiReplayTimePeriod is the duration from which
+//     antiReplayTimeFactorPeriodSeconds is initialized.
+//   - obfuscationAntiReplayHistorySize is the capacity requested for each
+//     cuckoo filter in an obfuscationReplayHistory.
+const (
+	obfuscationSessionPacketNonceSize = 12
+	obfuscationAntiReplayTimePeriod   = 20 * time.Minute
+	obfuscationAntiReplayHistorySize  = 10000000
+)
+
+// ObfuscationSecret is a shared, semisecret value used in obfuscation
+// layers.
+type ObfuscationSecret [32]byte
+
+// GenerateRootObfuscationSecret returns a new ObfuscationSecret filled with
+// bytes read from crypto/rand.
+func GenerateRootObfuscationSecret() (ObfuscationSecret, error) {
+
+	var rootSecret ObfuscationSecret
+
+	if _, err := rand.Read(rootSecret[:]); err != nil {
+		return rootSecret, errors.Trace(err)
+	}
+
+	return rootSecret, nil
+}
+
+// antiReplayTimeFactorPeriodSeconds is the length, in seconds, of the time
+// factor period used by deriveObfuscationSecret, and of the interval at
+// which obfuscationReplayHistory switches filters. It is initialized from
+// obfuscationAntiReplayTimePeriod.
+//
+// This is a variable, rather than a constant, to enable overriding the value
+// in tests. This value should not be overridden outside of test cases.
+var antiReplayTimeFactorPeriodSeconds = int64(
+	obfuscationAntiReplayTimePeriod / time.Second)
+
+// deriveObfuscationSecret uses HKDF-SHA256 to derive an obfuscation secret
+// from the root secret, a context string, and an optional time factor.
+//
+// When useTimeFactor is set, the current Unix time, rounded to the nearest
+// multiple of antiReplayTimeFactorPeriodSeconds, is mixed in as the HKDF
+// salt, so derived secrets remain valid only for a limited time period. Both
+// ends of an obfuscated communication will derive the same secret based on a
+// shared root secret, a common context, and local clocks. The rounding
+// allows one end's clock to be slightly ahead of or behind the other end's
+// clock.
+//
+// The time factor can be used in concert with a replay history, bounding the
+// number of historical messages that need to be retained in the history.
+func deriveObfuscationSecret(
+	rootObfuscationSecret ObfuscationSecret,
+	useTimeFactor bool,
+	context string) (ObfuscationSecret, error) {
+
+	// The salt remains nil when no time factor is requested.
+	var salt []byte
+
+	if useTimeFactor {
+
+		// Adding half a period before dividing rounds to the nearest period
+		// instead of truncating.
+		halfPeriod := antiReplayTimeFactorPeriodSeconds / 2
+		roundedTimePeriod := (time.Now().Unix() + halfPeriod) /
+			antiReplayTimeFactorPeriodSeconds
+
+		salt = make([]byte, 8)
+		binary.BigEndian.PutUint64(salt, uint64(roundedTimePeriod))
+	}
+
+	kdf := hkdf.New(
+		sha256.New, rootObfuscationSecret[:], salt, []byte(context))
+
+	var derivedSecret ObfuscationSecret
+
+	if _, err := io.ReadFull(kdf, derivedSecret[:]); err != nil {
+		return derivedSecret, errors.Trace(err)
+	}
+
+	return derivedSecret, nil
+}
+
+// deriveSessionPacketObfuscationSecret derives the session packet
+// obfuscation secret shared by the two ends of a session. Set isInitiator to
+// true when called by the initiator and false when called by a responder.
+// Set isObfuscating to true for packets being sent, and false for packets
+// being received.
+func deriveSessionPacketObfuscationSecret(
+	rootObfuscationSecret ObfuscationSecret,
+	isInitiator bool,
+	isObfuscating bool) (ObfuscationSecret, error) {
+
+	// Downstream is packets from the responder to the initiator: either the
+	// initiator is receiving or the responder is sending. Everything else is
+	// upstream, from the initiator to the responder.
+	isDownstream := isInitiator != isObfuscating
+
+	// Each flow direction gets its own key, to ensure that the two flows
+	// can't simply be xor'd.
+	//
+	// The context strings are inputs to the key derivation; both ends must
+	// use the exact same bytes, so the existing spelling must not be
+	// altered.
+	var context string
+	if isDownstream {
+		context = "in-proxy-session-packet-responder-to-initiator"
+	} else {
+		context = "in-proxy-session-packet-intiator-to-responder"
+	}
+
+	// Only upstream uses the time factor; the responder uses an anti-replay
+	// history for packets received from initiators.
+	useTimeFactor := !isDownstream
+
+	secret, err := deriveObfuscationSecret(
+		rootObfuscationSecret, useTimeFactor, context)
+	if err != nil {
+		return secret, errors.Trace(err)
+	}
+
+	return secret, nil
+}
+
+// obfuscateSessionPacket wraps a session packet with an obfuscation layer
+// which provides:
+//
+// - indistinguishability from fully random
+// - random padding
+// - anti-replay
+//
+// The full-random and padding properties make obfuscated packets appropriate
+// to embed in otherwise plaintext transports, such as HTTP, without being
+// trivially fingerprintable.
+//
+// While Noise protocol session messages have nonces and associated
+// anti-replay for nonces, this measure doesn't cover the session handshake,
+// so an independent anti-replay mechanism is implemented here.
+//
+// The output is a random nonce followed by the AES-GCM encryption of: a
+// uvarint padding size, that many zero bytes of padding, and the packet.
+// The padding size is selected with prng.Range(paddingMin, paddingMax).
+func obfuscateSessionPacket(
+	rootObfuscationSecret ObfuscationSecret,
+	isInitiator bool,
+	packet []byte,
+	paddingMin int,
+	paddingMax int) ([]byte, error) {
+
+	// For simplicity, the secret is derived here for each packet. Derived
+	// keys could be cached, but would need to be updated when a time factor
+	// is active. Typical in-proxy sessions will exchange only a handful of
+	// packets per event: the session handshake, and an API request round
+	// trip or two. We don't attempt to avoid allocations here.
+	//
+	// Benchmark for secret derivation:
+	//
+	//   BenchmarkDeriveObfuscationSecret
+	//   BenchmarkDeriveObfuscationSecret-8   	 1303953	       902.7 ns/op
+
+	secret, err := deriveSessionPacketObfuscationSecret(
+		rootObfuscationSecret, isInitiator, true)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The random nonce forms the prefix of the obfuscated packet.
+	nonce := make([]byte, obfuscationSessionPacketNonceSize)
+	if _, err := prng.Read(nonce); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Plaintext layout: uvarint padding size, padding, packet.
+	paddingSize := prng.Range(paddingMin, paddingMax)
+	plaintext := binary.AppendUvarint(nil, uint64(paddingSize))
+	plaintext = append(plaintext, make([]byte, paddingSize)...)
+	plaintext = append(plaintext, packet...)
+
+	block, err := aes.NewCipher(secret[:])
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	aead, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Seal appends the ciphertext and tag to its first argument, so the
+	// result is the nonce followed by the sealed plaintext.
+	return aead.Seal(nonce, nonce, plaintext, nil), nil
+}
+
+// deobfuscateSessionPacket deobfuscates a session packet obfuscated with
+// obfuscateSessionPacket and the same rootObfuscationSecret, and returns the
+// original packet with the nonce and padding removed.
+//
+// Set isInitiator to true when called by the initiator and false when called
+// by a responder.
+//
+// Responders must supply an obfuscationReplayHistory, which checks for
+// replayed session packets (within the time factor); initiators must pass
+// nil. Responders should drop into anti-probing response behavior when
+// deobfuscateSessionPacket returns an error: the obfuscated packet may have
+// been created by a prober without the correct secret; or replayed by a
+// prober.
+func deobfuscateSessionPacket(
+	rootObfuscationSecret ObfuscationSecret,
+	isInitiator bool,
+	replayHistory *obfuscationReplayHistory,
+	obfuscatedPacket []byte) ([]byte, error) {
+
+	// A responder must provide a replay history, and an initiator must not;
+	// otherwise the caller is misconfigured.
+	if isInitiator == (replayHistory != nil) {
+		return nil, errors.TraceNew("unexpected replay history")
+	}
+
+	// imitateDeobfuscateSessionPacketDuration is called in early failure
+	// cases to imitate the elapsed time of lookups and cryptographic
+	// operations that would otherwise be skipped. This is intended to
+	// mitigate timing attacks by probers.
+	//
+	// Limitation: this doesn't result in a constant time.
+
+	if len(obfuscatedPacket) < obfuscationSessionPacketNonceSize {
+		imitateDeobfuscateSessionPacketDuration(replayHistory)
+		return nil, errors.TraceNew("invalid nonce")
+	}
+
+	nonce := obfuscatedPacket[:obfuscationSessionPacketNonceSize]
+
+	if replayHistory != nil && replayHistory.Lookup(nonce) {
+		imitateDeobfuscateSessionPacketDuration(nil)
+		return nil, errors.TraceNew("replayed nonce")
+	}
+
+	key, err := deriveSessionPacketObfuscationSecret(
+		rootObfuscationSecret, isInitiator, false)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// As an AEAD, AES-GCM authenticates that the sender used the expected
+	// key, and so has the root obfuscation secret.
+
+	block, err := aes.NewCipher(key[:])
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	aesgcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	plaintext, err := aesgcm.Open(
+		nil,
+		nonce,
+		obfuscatedPacket[obfuscationSessionPacketNonceSize:],
+		nil)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Plaintext layout: uvarint padding size, padding, packet.
+
+	offset := 0
+	paddingSize, n := binary.Uvarint(plaintext[offset:])
+	if n < 1 {
+		return nil, errors.TraceNew("invalid padding size")
+	}
+	offset += n
+
+	// Compare in uint64 space, before converting to int. The padding size is
+	// controlled by the sender, and a value above the maximum int converts to
+	// a negative int, which would pass a signed length check and then move
+	// the offset backwards, producing an out of range slice index or a bogus
+	// payload.
+	if paddingSize > uint64(len(plaintext[offset:])) {
+		return nil, errors.TraceNew("invalid padding")
+	}
+	offset += int(paddingSize)
+
+	if replayHistory != nil {
+
+		// Now that it's validated, add this packet to the replay history. The
+		// nonce is expected to be unique, so it's used as the history key.
+
+		err = replayHistory.Insert(nonce)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	return plaintext[offset:], nil
+}
+
+// imitateDeobfuscateSessionPacketDuration performs a throwaway replay
+// history lookup, when replayHistory is not nil, and a throwaway
+// deobfuscation of an all-zeros packet with an all-zeros secret. The results
+// are discarded; the purpose is to spend time comparable to the operations
+// that an early failure in deobfuscateSessionPacket skips.
+func imitateDeobfuscateSessionPacketDuration(replayHistory *obfuscationReplayHistory) {
+
+	// Limitations: only one block is decrypted; crypto/aes or
+	// crypto/cipher.GCM may not be constant time, depending on hardware
+	// support; at best, this all-zeros invocation will make it as far as
+	// GCM.Open, and not check padding.
+
+	// The dummy packet is a nonce, one 16-byte block, and a 16-byte tag.
+	const dummyPacketSize = obfuscationSessionPacketNonceSize + 16 + 16
+
+	var (
+		zeroSecret ObfuscationSecret
+		zeroPacket [dummyPacketSize]byte
+	)
+
+	if replayHistory != nil {
+		dummyNonce := zeroPacket[:obfuscationSessionPacketNonceSize]
+		_ = replayHistory.Lookup(dummyNonce)
+	}
+
+	// Invoked as an initiator, with no replay history, so the call always
+	// proceeds to the cryptographic operations.
+	_, _ = deobfuscateSessionPacket(zeroSecret, true, nil, zeroPacket[:])
+}
+
+// obfuscationReplayHistory provides a lookup for recently observed obfuscated
+// session packet nonces. History is maintained for
+// 2*antiReplayTimeFactorPeriodSeconds; it's assumed that older packets, if
+// replayed, will fail to decrypt due to using an expired time factor.
+//
+// The mutex is held for the duration of Insert and Lookup. Insert adds a
+// value to both filters, while Lookup consults only the filter at index
+// currentFilter. switchTime records when the filters were last switched, or
+// when the history was created.
+type obfuscationReplayHistory struct {
+	mutex         sync.Mutex
+	filters       [2]*cuckoo.Filter
+	currentFilter int
+	switchTime    time.Time
+}
+
+// newObfuscationReplayHistory creates an empty obfuscationReplayHistory,
+// with two cuckoo filters of capacity obfuscationAntiReplayHistorySize and
+// with the filter switch timer starting now.
+func newObfuscationReplayHistory() *obfuscationReplayHistory {
+
+	// Replay history is implemented using cuckoo filters, which use fixed
+	// space overhead, and less space overhead than storing nonces explicitly
+	// under anticipated loads. With cuckoo filters, false positive lookups
+	// are possible, but false negative lookups are not. So there's a small
+	// chance that a non-replayed nonce will be flagged as in the history,
+	// but no chance that a replayed nonce will pass as not in the history.
+	//
+	// From github.com/panmari/cuckoofilter:
+	//   > With the 16 bit fingerprint size in this repository, you can expect r
+	//   > ~= 0.0001. Other implementations use 8 bit, which correspond to a
+	//   > false positive rate of r ~= 0.03. NewFilter returns a new
+	//   > cuckoofilter suitable for the given number of elements. When
+	//   > inserting more elements, insertion speed will drop significantly and
+	//   > insertions might fail altogether. A capacity of 1000000 is a normal
+	//   > default, which allocates about ~2MB on 64-bit machines.
+	//
+	// With obfuscationAntiReplayHistorySize set to 10M, the session_test test
+	// case with 10k clients making 100 requests each all within one time
+	// period consistently produces no false positives.
+	//
+	// To accommodate the rolling time factor window, there are two cuckoo
+	// filters, the "current" filter and the "next" filter. New nonces are
+	// inserted into both the current and next filter. Every
+	// antiReplayTimeFactorPeriodSeconds, the next filter replaces the
+	// current filter. The previous current filter is reset and becomes the
+	// new next filter.
+
+	history := &obfuscationReplayHistory{}
+
+	for i := range history.filters {
+		history.filters[i] = cuckoo.NewFilter(obfuscationAntiReplayHistorySize)
+	}
+
+	// The filter at index 0 starts out as the current filter.
+	history.currentFilter = 0
+	history.switchTime = time.Now()
+
+	return history
+}
+
+// Insert adds the value to both filters, after first switching filters if
+// the current period has elapsed. An error is returned as soon as either
+// filter fails to accept the value.
+func (h *obfuscationReplayHistory) Insert(value []byte) error {
+	h.mutex.Lock()
+	defer h.mutex.Unlock()
+
+	h.switchFilters()
+
+	for _, filter := range h.filters {
+		if !filter.Insert(value) {
+			return errors.TraceNew("replay history insert failed")
+		}
+	}
+
+	return nil
+}
+
+// Lookup reports whether the current filter contains the value, after first
+// switching filters if the current period has elapsed.
+func (h *obfuscationReplayHistory) Lookup(value []byte) bool {
+	h.mutex.Lock()
+	defer h.mutex.Unlock()
+
+	h.switchFilters()
+
+	current := h.filters[h.currentFilter]
+	return current.Lookup(value)
+}
+
+// switchFilters resets the current filter and makes the other filter the
+// current one, when more than antiReplayTimeFactorPeriodSeconds has elapsed
+// since the last switch. Otherwise it does nothing.
+func (h *obfuscationReplayHistory) switchFilters() {
+
+	// Assumes caller holds h.mutex lock.
+
+	now := time.Now()
+	period := time.Duration(antiReplayTimeFactorPeriodSeconds) * time.Second
+	cutoff := now.Add(-period)
+
+	if !h.switchTime.Before(cutoff) {
+		// The current period has not yet elapsed.
+		return
+	}
+
+	expired := h.currentFilter
+	h.filters[expired].Reset()
+	h.currentFilter = (expired + 1) % 2
+	h.switchTime = now
+}

+ 416 - 0
psiphon/common/inproxy/obfuscation_test.go

@@ -0,0 +1,416 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"bytes"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+)
+
+// FuzzSessionPacketDeobfuscation fuzzes deobfuscateSessionPacket. The corpus
+// is seeded with valid obfuscated packets, and the fuzz target checks that
+// only those exact seed packets deobfuscate successfully.
+func FuzzSessionPacketDeobfuscation(f *testing.F) {
+
+	packet := prng.Padding(100, 1000)
+	minPadding := 1
+	maxPadding := 1000
+
+	rootSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		// Use Fatal, not Fatalf: the error text is not a format string.
+		f.Fatal(errors.Trace(err).Error())
+	}
+
+	n := 10
+
+	originals := make([][]byte, n)
+
+	for i := 0; i < n; i++ {
+
+		obfuscatedPacket, err := obfuscateSessionPacket(
+			rootSecret, true, packet, minPadding, maxPadding)
+		if err != nil {
+			f.Fatal(errors.Trace(err).Error())
+		}
+
+		originals[i] = obfuscatedPacket
+
+		f.Add(obfuscatedPacket)
+	}
+
+	f.Fuzz(func(t *testing.T, obfuscatedPacket []byte) {
+
+		// Make a new history each time to bypass the replay check and focus
+		// on fuzzing the parsing code.
+
+		_, err := deobfuscateSessionPacket(
+			rootSecret,
+			false,
+			newObfuscationReplayHistory(),
+			obfuscatedPacket)
+
+		// Only the original, valid messages should successfully deobfuscate.
+
+		inOriginals := false
+		for i := 0; i < n; i++ {
+			if bytes.Equal(originals[i], obfuscatedPacket) {
+				inOriginals = true
+				break
+			}
+		}
+
+		if (err == nil) != inOriginals {
+			// Failures inside the fuzz target must be reported through t;
+			// *testing.F methods such as Errorf may not be called here.
+			t.Errorf("unexpected deobfuscation result")
+		}
+	})
+}
+
+// TestSessionPacketObfuscation runs the session packet obfuscation test
+// cases and reports any resulting error.
+func TestSessionPacketObfuscation(t *testing.T) {
+	err := runTestSessionPacketObfuscation()
+	if err != nil {
+		// Use Error, not Errorf: the error text is not a format string and
+		// any '%' it contains would otherwise be misinterpreted.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+// runTestSessionPacketObfuscation exercises obfuscateSessionPacket and
+// deobfuscateSessionPacket: round trips in both directions, replay
+// rejection, padding variation, direction mismatch, per-direction key
+// derivation, output randomness, and rejection of packets that are
+// obfuscated with the wrong secret, truncated, or modified. It returns the
+// first failure as an error, or nil when all cases pass.
+func runTestSessionPacketObfuscation() error {
+
+	// Use a replay time period factor more suitable for test runs.
+
+	originalAntiReplayTimeFactorPeriodSeconds := antiReplayTimeFactorPeriodSeconds
+	antiReplayTimeFactorPeriodSeconds = 1
+	defer func() {
+		antiReplayTimeFactorPeriodSeconds = originalAntiReplayTimeFactorPeriodSeconds
+	}()
+
+	rootSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	replayHistory := newObfuscationReplayHistory()
+
+	// Test: obfuscate/deobfuscate initiator -> responder
+
+	packet := prng.Bytes(1000)
+	minPadding := 1
+	maxPadding := 1000
+
+	obfuscatedPacket1, err := obfuscateSessionPacket(
+		rootSecret, true, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	packet1, err := deobfuscateSessionPacket(
+		rootSecret, false, replayHistory, obfuscatedPacket1)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if !bytes.Equal(packet1, packet) {
+		return errors.TraceNew("unexpected deobfuscated packet")
+	}
+
+	// Test: replay packet
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, replayHistory, obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected replay success")
+	}
+
+	// Test: replay packet after time factor period
+
+	time.Sleep(1 * time.Second)
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, replayHistory, obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected replay success")
+	}
+
+	// Test: different packet sizes (due to padding)
+
+	n := 10
+	for i := 0; i < n; i++ {
+		obfuscatedPacket2, err := obfuscateSessionPacket(
+			rootSecret, true, packet, minPadding, maxPadding)
+		if err != nil {
+			return errors.Trace(err)
+		}
+		if len(obfuscatedPacket1) != len(obfuscatedPacket2) {
+			break
+		}
+		if i == n-1 {
+			return errors.TraceNew("unexpected same size")
+		}
+	}
+
+	// Test: obfuscate/deobfuscate responder -> initiator
+
+	obfuscatedPacket2, err := obfuscateSessionPacket(
+		rootSecret, false, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	packet2, err := deobfuscateSessionPacket(
+		rootSecret, true, nil, obfuscatedPacket2)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if !bytes.Equal(packet2, packet) {
+		return errors.TraceNew("unexpected deobfuscated packet")
+	}
+
+	// Test: initiator -> initiator
+
+	obfuscatedPacket1, err = obfuscateSessionPacket(
+		rootSecret, true, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, true, nil, obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected initiator -> initiator success")
+	}
+
+	// Test: responder -> responder
+
+	obfuscatedPacket2, err = obfuscateSessionPacket(
+		rootSecret, false, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, newObfuscationReplayHistory(), obfuscatedPacket2)
+	if err == nil {
+		return errors.TraceNew("unexpected responder -> responder success")
+	}
+
+	// Test: distinct keys derived for each direction
+
+	isInitiator := true
+	secret1, err := deriveSessionPacketObfuscationSecret(
+		rootSecret, isInitiator, true)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	isInitiator = false
+	secret2, err := deriveSessionPacketObfuscationSecret(
+		rootSecret, isInitiator, true)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = testMostlyDifferent(secret1[:], secret2[:])
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: for identical packet with same padding and derived key, most
+	// bytes different (due to nonce)
+	//
+	// Both packets are obfuscated in the same direction, so the same
+	// derived key is used and any difference is attributable to the nonce.
+
+	padding := 100
+
+	obfuscatedPacket1, err = obfuscateSessionPacket(
+		rootSecret, true, packet, padding, padding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	obfuscatedPacket2, err = obfuscateSessionPacket(
+		rootSecret, true, packet, padding, padding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = testMostlyDifferent(obfuscatedPacket1, obfuscatedPacket2)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: uniformly random
+
+	for _, isInitiator := range []bool{true, false} {
+
+		err = testEntropy(func() ([]byte, error) {
+			obfuscatedPacket, err := obfuscateSessionPacket(
+				rootSecret, isInitiator, packet, padding, padding)
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			return obfuscatedPacket, nil
+		})
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	// Test: wrong obfuscation secret
+
+	wrongRootSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	obfuscatedPacket1, err = obfuscateSessionPacket(
+		wrongRootSecret, true, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, newObfuscationReplayHistory(), obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected wrong secret success")
+	}
+
+	// Test: truncated obfuscated packet
+
+	obfuscatedPacket1, err = obfuscateSessionPacket(
+		rootSecret, true, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	obfuscatedPacket1 = obfuscatedPacket1[:len(obfuscatedPacket1)-1]
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, newObfuscationReplayHistory(), obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected truncated packet success")
+	}
+
+	// Test: flip byte
+
+	obfuscatedPacket1, err = obfuscateSessionPacket(
+		rootSecret, true, packet, minPadding, maxPadding)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	obfuscatedPacket1[len(obfuscatedPacket1)-1] ^= 1
+
+	_, err = deobfuscateSessionPacket(
+		rootSecret, false, newObfuscationReplayHistory(), obfuscatedPacket1)
+	if err == nil {
+		return errors.TraceNew("unexpected modified packet success")
+	}
+
+	return nil
+}
+
+// TestObfuscationReplayHistory runs the replay history test cases and
+// reports any resulting error.
+func TestObfuscationReplayHistory(t *testing.T) {
+	err := runTestObfuscationReplayHistory()
+	if err != nil {
+		// Use Error, not Errorf: the error text is not a format string and
+		// any '%' it contains would otherwise be misinterpreted.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+// runTestObfuscationReplayHistory inserts random nonce-sized values into a
+// new replay history and checks that each value is absent before its insert
+// and present afterwards. It returns the first failure as an error.
+func runTestObfuscationReplayHistory() error {
+
+	history := newObfuscationReplayHistory()
+
+	valueSize := obfuscationSessionPacketNonceSize
+
+	iterations := int(obfuscationAntiReplayHistorySize / 100)
+
+	// Test: values found as expected; no false positives
+
+	for i := 0; i < iterations; i++ {
+
+		candidate := prng.Bytes(valueSize)
+
+		// A fresh random value must not already be reported as seen.
+		if history.Lookup(candidate) {
+			return errors.Tracef("value found on iteration %d", i)
+		}
+
+		if err := history.Insert(candidate); err != nil {
+			return errors.Trace(err)
+		}
+
+		// Once inserted, the value must be reported as seen.
+		if !history.Lookup(candidate) {
+			return errors.Tracef("value not found on iteration %d", i)
+		}
+	}
+
+	return nil
+}
+
+// testMostlyDifferent returns an error when a and b differ in length, or
+// when more than one tenth of their byte positions hold equal values.
+func testMostlyDifferent(a, b []byte) error {
+
+	if len(a) != len(b) {
+		return errors.TraceNew("unexpected different size")
+	}
+
+	matching := 0
+	for i := range a {
+		if a[i] == b[i] {
+			matching++
+		}
+	}
+
+	// TODO: use a stricter threshold?
+	limit := len(a) / 10
+	if matching > limit {
+		return errors.Tracef("unexpected similar bytes: %d/%d", matching, len(a))
+	}
+
+	return nil
+}
+
+func testEntropy(f func() ([]byte, error)) error {
+
+	bitCount := make(map[int]int)
+
+	n := 10000
+
+	for i := 0; i < n; i++ {
+
+		value, err := f()
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		for j := 0; j < len(value); j++ {
+			for k := 0; k < 8; k++ {
+				bit := (uint8(value[j]) >> k) & 0x1
+				bitCount[(j*8)+k] += int(bit)
+			}
+		}
+
+	}
+
+	// TODO: use a stricter threshold?
+	for index, count := range bitCount {
+		if count < n/3 || count > 2*n/3 {
+			return errors.Tracef("unexpected entropy at %d: %v", index, bitCount)
+		}
+	}
+
+	return nil
+}

+ 159 - 0
psiphon/common/inproxy/packet-trace.diff

@@ -0,0 +1,159 @@
+diff --git a/Temp/inproxy-proto2/inproxy/inproxy_test.go b/Temp/inproxy-proto2/inproxy/inproxy_test.go
+index 334befb1..6da022d3 100644
+--- a/Temp/inproxy-proto2/inproxy/inproxy_test.go
++++ b/Temp/inproxy-proto2/inproxy/inproxy_test.go
+@@ -80,12 +80,15 @@ func runTestInProxy() error {
+ 	//numProxies := 10
+ 	//proxyMaxClients := 5
+ 	//numClients := 100
+-	numProxies := 5
+-	proxyMaxClients := 2
+-	numClients := 10
++	numProxies := 1
++	proxyMaxClients := 1
++	numClients := 5
+ 
++	// *TEMP*
++	//bytesToSend := 1 << 20
++	//messageSize := 1 << 10
+ 	bytesToSend := 1 << 20
+-	messageSize := 1 << 10
++	messageSize := 1024
+ 	targetElapsedSeconds := 2
+ 
+ 	baseMetrics := common.APIParameters{
+@@ -349,6 +352,9 @@ func runTestInProxy() error {
+ 
+ 	makeClientFunc := func(isTCP bool) func() error {
+ 
++		// *TEMP*
++		isTCP = false
++
+ 		// *DOC* use echo server address as proxy destination; alternate TCP/UDP
+ 		var network, addr string
+ 		if isTCP {
+@@ -429,6 +435,10 @@ func runTestInProxy() error {
+ 					if bytesToSend-n < m {
+ 						m = bytesToSend - n
+ 					}
++
++					// *TEMP*
++					fmt.Printf("  > SEND: %x\n", prefix(sendBytes[n:n+m]))
++
+ 					_, err := conn.Write(sendBytes[n : n+m])
+ 					if err != nil {
+ 						return errors.Trace(err)
+@@ -446,7 +456,15 @@ func runTestInProxy() error {
+ 					if err != nil {
+ 						return errors.Trace(err)
+ 					}
++
++					// *TEMP*
++					fmt.Printf("  > RECV: %x\n", prefix(buf[:m]))
++
+ 					if !bytes.Equal(sendBytes[n:n+m], buf[:m]) {
++
++						// *TEMP*
++						fmt.Printf("  > RECV: !EQUAL\n")
++
+ 						// *DOC* index logged to diagnose out-of-order or dropped message vs. entirely wrong bytes
+ 						return errors.Tracef(
+ 							"unexpected bytes: expected at index %d, received at index %d",
+@@ -636,6 +654,13 @@ func runTCPEchoServer(listener net.Listener) {
+ 	}
+ }
+ 
++func prefix(b []byte) []byte {
++	if len(b) <= 8 {
++		return b
++	}
++	return b[:8]
++}
++
+ func runUDPEchoServer(packetConn net.PacketConn) {
+ 	buf := make([]byte, 65536)
+ 	for {
+@@ -643,6 +668,10 @@ func runUDPEchoServer(packetConn net.PacketConn) {
+ 		if err != nil {
+ 			return
+ 		}
++
++		// *TEMP*
++		fmt.Printf("  > RELAY: %x\n", prefix(buf[:n]))
++
+ 		_, err = packetConn.WriteTo(buf[:n], addr)
+ 		if err != nil {
+ 			return
+diff --git a/Temp/inproxy-proto2/inproxy/proxy.go b/Temp/inproxy-proto2/inproxy/proxy.go
+index fd6e158f..2531e276 100644
+--- a/Temp/inproxy-proto2/inproxy/proxy.go
++++ b/Temp/inproxy-proto2/inproxy/proxy.go
+@@ -27,6 +27,8 @@ import (
+ 	"sync/atomic"
+ 	"time"
+ 
++	"fmt"
++
+ 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+ 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+ 	"github.com/pion/webrtc/v3"
+@@ -403,11 +405,45 @@ func (p *Proxy) proxyOneClient(ctx context.Context) error {
+ 	waitGroup := new(sync.WaitGroup)
+ 	relayErrors := make(chan error, 2)
+ 
++	// *TEMP*
++	myCopy := func(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
++		for {
++			nr, er := src.Read(buf)
++			if nr > 0 {
++				fmt.Printf("  > COPY: %x\n", prefix(buf[:nr]))
++				nw, ew := dst.Write(buf[0:nr])
++				if nw < 0 || nr < nw {
++					nw = 0
++					if ew == nil {
++						ew = errors.TraceNew("write invalid")
++					}
++				}
++				written += int64(nw)
++				if ew != nil {
++					err = ew
++					break
++				}
++				if nr != nw {
++					err = io.ErrShortWrite
++					break
++				}
++			}
++			if er != nil {
++				if er != io.EOF {
++					err = er
++				}
++				break
++			}
++		}
++		return written, err
++	}
++
+ 	waitGroup.Add(1)
+ 	go func() {
+ 		defer waitGroup.Done()
+-		// *TODO* doc: for packet conn, io.Copy buffer must be packet MTU; it's 32K
+-		_, err = io.Copy(webRTCConn, upstreamConn)
++		// *TODO* doc: for packet conn, io.Copy buffer must be packet MTU; it's 32K; need 64K?
++		var buf [65536]byte
++		_, err = io.CopyBuffer(webRTCConn, upstreamConn, buf[:])
+ 		if err != nil {
+ 			relayErrors <- errors.Trace(err)
+ 			return
+@@ -417,7 +453,10 @@ func (p *Proxy) proxyOneClient(ctx context.Context) error {
+ 	waitGroup.Add(1)
+ 	go func() {
+ 		defer waitGroup.Done()
+-		_, err := io.Copy(upstreamConn, webRTCConn)
++		var buf [65536]byte
++		// *TEMP*
++		//_, err := io.CopyBuffer(upstreamConn, webRTCConn, buf[:])
++		_, err := myCopy(upstreamConn, webRTCConn, buf[:])
+ 		if err != nil {
+ 			relayErrors <- errors.Trace(err)
+ 			return

+ 250 - 0
psiphon/common/inproxy/portmapper.go

@@ -0,0 +1,250 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"tailscale.com/net/portmapper"
+	"tailscale.com/util/clientmetric"
+)
+
+// initPortMapper prepares the shared port mapping state for a dial. It
+// clears the port mapping metrics tied to the previous network when
+// DialParameters.NetworkID reports a different network, and it configures
+// the port mapping routines to use DialParameters.BindToDevice. Varying
+// DialParameters.BindToDevice between dials in a single process is not
+// supported.
+func initPortMapper(dialParams DialParameters) {
+
+	// Concurrent client dials may all call resetRespondingPortMappingTypes
+	// safely: while the network ID is unchanged, the call does not clear
+	// port mapping type metrics that were just recorded.
+	networkID := dialParams.NetworkID()
+	resetRespondingPortMappingTypes(networkID)
+
+	// tailscale.com/net/portmapper holds DialParameters.BindToDevice in a
+	// global variable. Calling setPortMapperBindToDevice repeatedly is safe
+	// on the assumption that DialParameters.BindToDevice is one and the
+	// same static function for all dials, which is true for Psiphon.
+	setPortMapperBindToDevice(dialParams)
+}
+
+// portMapper represents a UDP port mapping from a local port to an external,
+// publicly addressable IP and port. Port mapping is implemented using
+// tailscale.com/net/portmapper, which probes the local network and gateway
+// for UPnP-IGD, NAT-PMP, and PCP port mapping capabilities.
+type portMapper struct {
+	// havePortMappingOnce ensures that the mapping-changed callback set up
+	// in newPortMapper handles only the first notification.
+	havePortMappingOnce sync.Once
+	// portMappingAddress is a channel with buffer size 1 that receives the
+	// external address, as a string, once a port mapping is available.
+	portMappingAddress  chan string
+	// client is the underlying tailscale.com/net/portmapper client.
+	client              *portmapper.Client
+}
+
+// newPortMapper initializes a new port mapper, configured to map to the
+// specified localPort. newPortMapper does not initiate any network
+// operations (it's safe to call when DisablePortMapping is set).
+//
+// Port mapper log output is written to logger, prefixed with
+// "port mapping: ". Call start to begin the port mapping attempt and receive
+// the result from portMappingExternalAddress.
+func newPortMapper(
+	logger common.Logger,
+	localPort int) *portMapper {
+
+	portMappingLogger := func(format string, args ...any) {
+		// args must be expanded with "..."; passing the slice itself would
+		// format it as a single operand and garble the message.
+		logger.WithTrace().Info("port mapping: " + fmt.Sprintf(format, args...))
+	}
+
+	p := &portMapper{
+		portMappingAddress: make(chan string, 1),
+	}
+
+	// This code assumes the tailscale NewClient call does only
+	// initialization; this is the case as of tailscale.com/net/portmapper
+	// v1.36.2.
+	//
+	// This code further assumes that the onChanged callback passed to
+	// NewClient will not be invoked until after the
+	// GetCachedMappingOrStartCreatingOne call in portMapper.start; and so
+	// the p.client reference within callback will be valid.
+
+	client := portmapper.NewClient(portMappingLogger, nil, nil, func() {
+		p.havePortMappingOnce.Do(func() {
+			address, ok := p.client.GetCachedMappingOrStartCreatingOne()
+			if ok {
+				// With sync.Once and a buffer size of 1, this send won't block.
+				p.portMappingAddress <- address.String()
+			} else {
+
+				// This is not an expected case; there should be a port
+				// mapping when NewClient is invoked.
+				//
+				// TODO: deliver "" to the channel? Otherwise, receiving on
+				// portMapper.portMappingExternalAddress will hang, or block
+				// until a context is done.
+				portMappingLogger("unexpected missing port mapping")
+			}
+		})
+	})
+
+	p.client = client
+
+	p.client.SetLocalPort(uint16(localPort))
+
+	return p
+}
+
+// start initiates the port mapping attempt. The immediate result of the
+// underlying call is discarded; a successful mapping is delivered via the
+// channel returned by portMappingExternalAddress.
+func (p *portMapper) start() {
+	_, _ = p.client.GetCachedMappingOrStartCreatingOne()
+}
+
+// portMappingExternalAddress returns a channel which receives a successful
+// port mapping external address, if any.
+func (p *portMapper) portMappingExternalAddress() <-chan string {
+	return p.portMappingAddress
+}
+
+// close releases the port mapping by closing the underlying port mapper
+// client, and returns the traced result of that close.
+func (p *portMapper) close() error {
+	return errors.Trace(p.client.Close())
+}
+
+// probePortMapping discovers and reports which port mapping protocols are
+// supported on this network. probePortMapping does not establish a port mapping.
+//
+// It is intended that in-proxies make a blocking call to probePortMapping on
+// start up (and after a network change) in order to report fresh port
+// mapping type metrics, for matching optimization in the ProxyAnnounce
+// request. Clients don't incur the delay of a probe call -- which produces
+// no port mapping -- and instead opportunistically grab port mapping type
+// metrics via getRespondingPortMappingTypes.
+//
+// The returned list holds each protocol the probe reported as available, or
+// the single entry PortMappingTypeNone when the probe reported none. An
+// error is returned when the probe itself fails.
+func probePortMapping(
+	ctx context.Context,
+	logger common.Logger) (PortMappingTypes, error) {
+
+	portMappingLogger := func(format string, args ...any) {
+		// args must be expanded with "..."; passing the slice itself would
+		// format it as a single operand and garble the message.
+		logger.WithTrace().Info("port mapping probe: " + fmt.Sprintf(format, args...))
+	}
+
+	client := portmapper.NewClient(portMappingLogger, nil, nil, nil)
+	defer client.Close()
+
+	result, err := client.Probe(ctx)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	portMappingTypes := PortMappingTypes{}
+	if result.UPnP {
+		portMappingTypes = append(portMappingTypes, PortMappingTypeUPnP)
+	}
+	if result.PMP {
+		portMappingTypes = append(portMappingTypes, PortMappingTypePMP)
+	}
+	if result.PCP {
+		portMappingTypes = append(portMappingTypes, PortMappingTypePCP)
+	}
+
+	// An empty list means discovery is needed or the available port mappings
+	// are unknown; a list with None indicates that a probe returned no
+	// supported port mapping types.
+
+	if len(portMappingTypes) == 0 {
+		portMappingTypes = append(portMappingTypes, PortMappingTypeNone)
+	}
+
+	return portMappingTypes, nil
+}
+
+// respondingPortMappingTypesMutex guards respondingPortMappingTypesNetworkID
+// and serializes resetRespondingPortMappingTypes and
+// getRespondingPortMappingTypes.
+var respondingPortMappingTypesMutex sync.Mutex
+
+// respondingPortMappingTypesNetworkID is the network ID recorded by the most
+// recent resetRespondingPortMappingTypes call that reset the metrics.
+var respondingPortMappingTypesNetworkID string
+
+// resetRespondingPortMappingTypes zeroes the tailscale.com/net/portmapper
+// global metrics fields which indicate which port mapping types are
+// responding on the current network. The metrics are cleared only when
+// networkID differs from the network ID recorded by the previous reset; the
+// new networkID is then recorded.
+//
+// Limitations: there may be edge conditions where a
+// tailscale.com/net/portmapper client logs metrics concurrent to
+// resetRespondingPortMappingTypes being called with a new networkID. If
+// incorrect port mapping type metrics are reported, the Broker may log
+// incorrect statistics. However, Broker client/in-proxy matching is based on
+// actually established port mappings.
+func resetRespondingPortMappingTypes(networkID string) {
+
+	respondingPortMappingTypesMutex.Lock()
+	defer respondingPortMappingTypesMutex.Unlock()
+
+	if respondingPortMappingTypesNetworkID == networkID {
+		// Same network as the last reset; keep the recorded metrics.
+		return
+	}
+
+	// Iterating over all metric fields appears to be the only API available.
+	for _, m := range clientmetric.Metrics() {
+		name := m.Name()
+		if name == "portmap_upnp_ok" || name == "portmap_pmp_ok" || name == "portmap_pcp_ok" {
+			m.Set(0)
+		}
+	}
+
+	respondingPortMappingTypesNetworkID = networkID
+}
+
+// getRespondingPortMappingTypes returns the port mapping types that responded
+// during recent portMapper.start invocations as well as probePortMapping
+// invocations. The returned list is used for reporting metrics. See
+// resetRespondingPortMappingTypes for considerations due to accessing
+// tailscale.com/net/portmapper global metrics fields.
+//
+// To avoid delays, we do not run probePortMapping for regular client dials,
+// and so instead use this tailscale.com/net/portmapper metrics field
+// approach.
+//
+// An empty list is returned when networkID does not match the network ID
+// recorded by the last resetRespondingPortMappingTypes call.
+//
+// Limitations: the return value represents all port mapping types that
+// responded in this session, since the last network change
+// (resetRespondingPortMappingTypes call); and do not indicate which of
+// several port mapping types may have been used for a particular dial.
+func getRespondingPortMappingTypes(networkID string) PortMappingTypes {
+
+	respondingPortMappingTypesMutex.Lock()
+	defer respondingPortMappingTypesMutex.Unlock()
+
+	portMappingTypes := PortMappingTypes{}
+
+	if respondingPortMappingTypesNetworkID != networkID {
+		// The network changed since the last resetRespondingPortMappingTypes
+		// call, and resetRespondingPortMappingTypes has not yet been called
+		// again. Ignore the current metrics.
+		return portMappingTypes
+	}
+
+	// Iterating over all metric fields appears to be the only API available.
+	//
+	// The metrics are set to 0 by resetRespondingPortMappingTypes, so any
+	// value above 0 indicates at least one response since the last reset.
+	// NOTE(review): assumes each "_ok" metric is a counter incremented once
+	// per successful response -- confirm against tailscale clientmetric use.
+	for _, metric := range clientmetric.Metrics() {
+		if metric.Value() <= 0 {
+			continue
+		}
+		switch metric.Name() {
+		case "portmap_upnp_ok":
+			portMappingTypes = append(portMappingTypes, PortMappingTypeUPnP)
+		case "portmap_pmp_ok":
+			portMappingTypes = append(portMappingTypes, PortMappingTypePMP)
+		case "portmap_pcp_ok":
+			portMappingTypes = append(portMappingTypes, PortMappingTypePCP)
+		}
+	}
+	return portMappingTypes
+}

+ 30 - 0
psiphon/common/inproxy/portmapper_android.go

@@ -0,0 +1,30 @@
+//go:build android
+
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"tailscale.com/net/netns"
+)
+
+// setPortMapperBindToDevice registers dialParams.BindToDevice as the Android
+// protect function in tailscale.com/net/netns. This variant is compiled for
+// Android builds only.
+func setPortMapperBindToDevice(dialParams DialParameters) {
+	netns.SetAndroidProtectFunc(dialParams.BindToDevice)
+}

+ 28 - 0
psiphon/common/inproxy/portmapper_other.go

@@ -0,0 +1,28 @@
+//go:build !android
+
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+// setPortMapperBindToDevice is a no-op in this variant, which is compiled
+// for all non-Android builds.
+func setPortMapperBindToDevice(dialParams DialParameters) {
+	// BindToDevice is not applied on iOS as tailscale.com/net/netns does not
+	// have an equivalent to SetAndroidProtectFunc for iOS. At this time,
+	// BindToDevice operations on iOS are legacy code and not required.
+}

+ 630 - 0
psiphon/common/inproxy/proxy.go

@@ -0,0 +1,630 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/pion/webrtc/v3"
+)
+
+// Timeouts should be aligned with Broker timeouts.
+//
+// Each constant below is the fallback value passed to common.ValueOrDefault
+// alongside the corresponding DialParameters accessor (for example,
+// proxyAnnounceRequestTimeout is paired with
+// DialParameters.AnnounceRequestTimeout in proxyOneClient).
+
+const (
+	proxyAnnounceRequestTimeout = 2 * time.Minute
+	proxyAnnounceRetryDelay     = 1 * time.Second
+	proxyAnnounceRetryJitter    = 0.3
+	proxyWebRTCAnswerTimeout    = 20 * time.Second
+	proxyAnswerRequestTimeout   = 10 * time.Second
+	proxyClientConnectTimeout   = 30 * time.Second
+	proxyDestinationDialTimeout = 30 * time.Second
+)
+
+// Proxy is the in-proxy proxying component, which relays traffic from a
+// client to a Psiphon server.
+//
+// The counter fields are only accessed through sync/atomic in this file:
+// bytesUp and bytesDown accumulate relayed byte counts between activity
+// updates and are reset to zero by activityUpdate; peakBytesUp and
+// peakBytesDown record the largest per-update counts seen so far;
+// connectingClients and connectedClients track clients that are,
+// respectively, still being established and fully connected.
+type Proxy struct {
+	// Note: 64-bit ints used with atomic operations are placed
+	// at the start of struct to ensure 64-bit alignment.
+	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
+	bytesUp           int64
+	bytesDown         int64
+	peakBytesUp       int64
+	peakBytesDown     int64
+	connectingClients int32
+	connectedClients  int32
+
+	config                *ProxyConfig
+	brokerClient          *BrokerClient
+	activityUpdateWrapper *activityUpdateWrapper
+}
+
+// TODO: add PublicNetworkAddress/ListenNetworkAddress to facilitate manually
+// configured, permanent port mappings.
+
+// ProxyConfig specifies the configuration for a Proxy run.
+type ProxyConfig struct {
+
+	// Logger is used to log events.
+	Logger common.Logger
+
+	// BaseMetrics should be populated with Psiphon handshake metrics
+	// parameters. These will be sent to and logged by the Broker.
+	BaseMetrics common.APIParameters
+
+	// OperatorMessageHandler is a callback that is invoked with any user
+	// message JSON object that is sent to the Proxy from the Broker. This
+	// facility may be used to alert proxy operators when required. The JSON
+	// object schema is arbitrary and not defined here.
+	OperatorMessageHandler func(messageJSON string)
+
+	// DialParameters specifies specific broker and WebRTC dial configuration
+	// and strategies and settings; DialParameters also facilitates dial
+	// replay by receiving callbacks when individual dial steps succeed or
+	// fail.
+	//
+	// As a DialParameters is associated with one network ID, it is expected
+	// that the proxy will be stopped and restarted when a network change is
+	// detected.
+	DialParameters DialParameters
+
+	// MaxClients is the maximum number of clients that are allowed to connect
+	// to the proxy. Run starts this many proxying workers, each of which
+	// handles one client at a time.
+	MaxClients int
+
+	// LimitUpstreamBytesPerSecond limits the upstream data transfer rate for
+	// a single client. When 0, there is no limit.
+	LimitUpstreamBytesPerSecond int
+
+	// LimitDownstreamBytesPerSecond limits the downstream data transfer rate
+	// for a single client. When 0, there is no limit.
+	LimitDownstreamBytesPerSecond int
+
+	// ActivityUpdater specifies an ActivityUpdater for activity associated
+	// with this proxy.
+	ActivityUpdater ActivityUpdater
+}
+
+// ActivityUpdater is a callback that is invoked when clients connect and
+// disconnect and periodically with data transfer updates (unless idle). This
+// callback may be used to update an activity UI. This callback should post
+// this data to another thread or handler and return immediately and not
+// block on UI updates.
+//
+// connectingClients and connectedClients are the current client counts.
+// bytesUp and bytesDown are the byte counts accumulated since the previous
+// update, and bytesDuration is the update period those counts cover.
+type ActivityUpdater func(
+	connectingClients int32,
+	connectedClients int32,
+	bytesUp int64,
+	bytesDown int64,
+	bytesDuration time.Duration)
+
+// NewProxy creates a Proxy for the given configuration. It returns an error
+// if the broker client cannot be created from config.DialParameters.
+func NewProxy(config *ProxyConfig) (*Proxy, error) {
+
+	// A single BrokerClient is shared by every request made by this proxy.
+	// With a multiplexing round tripper, such as HTTP/2, concurrent requests
+	// can then reuse one TLS network connection and established session.
+	sharedBrokerClient, err := NewBrokerClient(config.DialParameters)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	proxy := new(Proxy)
+	proxy.config = config
+	proxy.brokerClient = sharedBrokerClient
+
+	// The wrapper needs a back-reference to the proxy, so it is attached
+	// after the proxy itself exists.
+	proxy.activityUpdateWrapper = &activityUpdateWrapper{p: proxy}
+
+	return proxy, nil
+}
+
+// activityUpdateWrapper receives bytes transferred updates from the
+// ActivityConns wrapping proxied traffic, implementing the
+// psiphon/common.ActivityUpdater interface on behalf of a Proxy. Using a
+// separate wrapper type keeps UpdateProgress out of Proxy's exported API.
+type activityUpdateWrapper struct {
+	p *Proxy
+}
+
+// UpdateProgress adds the reported byte counts to the proxy's running
+// totals: bytes read count as downstream and bytes written as upstream. The
+// third argument is ignored.
+func (w *activityUpdateWrapper) UpdateProgress(bytesRead, bytesWritten int64, _ int64) {
+	proxy := w.p
+	atomic.AddInt64(&proxy.bytesUp, bytesWritten)
+	atomic.AddInt64(&proxy.bytesDown, bytesRead)
+}
+
+// Run runs the Proxy. The proxy sends requests to the Broker announcing its
+// availability; the Broker matches the proxy with clients, and facilitates
+// an exchange of WebRTC connection information; the proxy and each client
+// attempt to establish a connection; and the client's traffic is relayed to
+// Psiphon server.
+//
+// Run ends when ctx is Done. When a network change is detected, Run should be
+// stopped and a new Proxy configured and started. This minimizes dangling
+// client connections running over the previous network; provides an
+// opportunity to gather fresh NAT/port mapping metrics for the new network;
+// and allows for a new DialParameters, associated with the new network, to
+// be configured.
+func (p *Proxy) Run(ctx context.Context) {
+
+	// Reset and configure port mapper component, as required. See
+	// initPortMapper comment.
+	initPortMapper(p.config.DialParameters)
+
+	// Gather local network NAT/port mapping metrics before sending any
+	// announce requests. NAT topology metrics are used by the Broker to
+	// optimize client and in-proxy matching. Unlike the client, we always
+	// perform this synchronous step here, since waiting doesn't necessarily
+	// block a client tunnel dial.
+
+	initWaitGroup := new(sync.WaitGroup)
+	initWaitGroup.Add(1)
+	go func() {
+		defer initWaitGroup.Done()
+
+		// NATDiscover may use cached NAT type/port mapping values from
+		// DialParameters, based on the network ID. If discovery is not
+		// successful, the proxy still proceeds to announce.
+
+		NATDiscover(
+			ctx,
+			&NATDiscoverConfig{
+				Logger:         p.config.Logger,
+				DialParameters: p.config.DialParameters,
+			})
+
+	}()
+	initWaitGroup.Wait()
+
+	// Run MaxClient proxying workers. Each worker handles one client at a time.
+
+	proxyWaitGroup := new(sync.WaitGroup)
+
+	for i := 0; i < p.config.MaxClients; i++ {
+		proxyWaitGroup.Add(1)
+		go func() {
+			defer proxyWaitGroup.Done()
+			p.proxyClients(ctx)
+		}()
+	}
+
+	// Capture activity updates every second, which is the required frequency
+	// for PeakUp/DownstreamBytesPerSecond. This is also a reasonable
+	// frequency for invoking the ActivityUpdater and updating UI widgets.
+
+	activityUpdatePeriod := 1 * time.Second
+	ticker := time.NewTicker(activityUpdatePeriod)
+	defer ticker.Stop()
+
+loop:
+	for {
+		select {
+		case <-ticker.C:
+			p.activityUpdate(activityUpdatePeriod)
+		case <-ctx.Done():
+			break loop
+		}
+	}
+
+	proxyWaitGroup.Wait()
+}
+
+// activityUpdate samples the proxy's activity counters for one update
+// period, records new peak byte counts, and reports the sample to the
+// configured ActivityUpdater.
+//
+// The bytesUp/bytesDown counters are reset to zero as they are read, so each
+// sample covers only the bytes relayed since the previous call. period is
+// passed through to the ActivityUpdater as the duration the sample covers.
+// The callback is skipped when the proxy is idle or when no ActivityUpdater
+// is configured.
+func (p *Proxy) activityUpdate(period time.Duration) {
+
+	connectingClients := atomic.LoadInt32(&p.connectingClients)
+	connectedClients := atomic.LoadInt32(&p.connectedClients)
+	bytesUp := atomic.SwapInt64(&p.bytesUp, 0)
+	bytesDown := atomic.SwapInt64(&p.bytesDown, 0)
+
+	// Peaks are tracked even when no ActivityUpdater is configured, as they
+	// are also reported in the proxy metrics.
+	greaterThanSwapInt64(&p.peakBytesUp, bytesUp)
+	greaterThanSwapInt64(&p.peakBytesDown, bytesDown)
+
+	if connectingClients == 0 &&
+		connectedClients == 0 &&
+		bytesUp == 0 &&
+		bytesDown == 0 {
+		// Skip the activity callback on idle.
+		return
+	}
+
+	if p.config.ActivityUpdater == nil {
+		// No callback was configured; invoking a nil func would panic.
+		return
+	}
+
+	p.config.ActivityUpdater(
+		connectingClients,
+		connectedClients,
+		bytesUp,
+		bytesDown,
+		period)
+}
+
+// greaterThanSwapInt64 atomically raises *addr to new when new is greater
+// than the current value. It returns true if this call stored new, and
+// false if the current value was already greater than or equal to new.
+//
+// The compare-and-swap is retried when another goroutine changes *addr
+// between the load and the swap, so with concurrent callers the greatest
+// value is never lost.
+func greaterThanSwapInt64(addr *int64, new int64) bool {
+	for {
+		old := atomic.LoadInt64(addr)
+		if new <= old {
+			return false
+		}
+		if atomic.CompareAndSwapInt64(addr, old, new) {
+			return true
+		}
+		// *addr changed underneath us; reload and re-evaluate.
+	}
+}
+
+// proxyClients is the body of one proxying worker. It proxies one client at
+// a time, over and over, until ctx is done. Failures are logged and followed
+// by a short, jittered delay before the next attempt.
+func (p *Proxy) proxyClients(ctx context.Context) {
+
+	// Each iteration begins by posting a long-polling announcement request.
+	// The broker responds with a matched client, and the proxy and client
+	// then attempt to establish a WebRTC connection for relaying traffic.
+	//
+	// Limitation: this design may not maximize the utility of the proxy.
+	// Some proxy/client connections will fail at the WebRTC stage due to NAT
+	// traversal failure, and at most MaxClient concurrent establishments are
+	// attempted. In addition, the Psiphon client horse race may start
+	// in-proxy dials and then abort them when some other tunnel protocol
+	// succeeds.
+	//
+	// As a future enhancement, consider using M announcement goroutines and N
+	// WebRTC dial goroutines. When an announcement gets a response,
+	// immediately announce again unless there are already MaxClient active
+	// connections established. This approach may require the proxy to
+	// backpedal and reject connections when establishment is too successful.
+	//
+	// Another enhancement could be a signal from the client, to the broker,
+	// relayed to the proxy, when a dial is aborted.
+
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+
+		err := p.proxyOneClient(ctx)
+		if err == nil || ctx.Err() != nil {
+			continue
+		}
+
+		fields := common.LogFields{
+			"error": err.Error(),
+		}
+		p.config.Logger.WithTraceFields(fields).Error("proxy client failed")
+
+		// Pause briefly so that a recurring failure case doesn't
+		// unintentionally overload the broker. The jitter avoids producing
+		// a regular traffic period.
+
+		dialParams := p.config.DialParameters
+		retryDelay := common.ValueOrDefault(dialParams.AnnounceRetryDelay(), proxyAnnounceRetryDelay)
+		retryJitter := common.ValueOrDefault(dialParams.AnnounceRetryJitter(), proxyAnnounceRetryJitter)
+		common.SleepWithJitter(ctx, retryDelay, retryJitter)
+	}
+}
+
+// proxyOneClient performs one complete proxying cycle:
+//
+//  1. Announce to the broker and await a matched client.
+//  2. Initialize WebRTC from the client's offer SDP and send the answer (or
+//     the initialization error) back to the broker.
+//  3. Await the WebRTC data channel, then dial the destination.
+//  4. Relay traffic in both directions until either direction ends or ctx
+//     is done.
+//
+// The returned error is non-nil when any step fails or when the relay ends
+// with an error. A relay that ends cleanly, or because ctx is done, returns
+// nil.
+func (p *Proxy) proxyOneClient(ctx context.Context) error {
+
+	// Send the announce request
+
+	// At this point, no NAT traversal operations have been performed by the
+	// proxy, since its announcement may sit idle for the long-polling period
+	// and NAT hole punches or port mappings could expire before the
+	// long-polling period.
+	//
+	// As a future enhancement, the proxy could begin gathering WebRTC ICE
+	// candidates while awaiting a client match, reducing the turn around
+	// time after a match. This would make sense if there's high demand for
+	// proxies, and so hole punches unlikely to expire while awaiting a client match.
+	//
+	// Another possibility may be to prepare and send a full offer SDP in the
+	// announcement; and have the broker modify either the proxy or client
+	// offer SDP to produce an answer SDP. In this case, the entire
+	// ProxyAnswerRequest could be skipped as the WebRTC dial can begin after
+	// the ProxyAnnounceRequest response (and ClientOfferRequest response).
+	//
+	// Furthermore, if a port mapping can be established, instead of using
+	// WebRTC the proxy could run a Psiphon tunnel protocol listener at the
+	// mapped port and send the dial information -- including some secret to
+	// authenticate the client -- in its announcement. The client would then
+	// receive this direct dial information from the broker and connect. The
+	// proxy should be able to send keep alives to extend the port mapping
+	// lifetime.
+
+	announceRequestCtx, announceRequestCancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(p.config.DialParameters.AnnounceRequestTimeout(), proxyAnnounceRequestTimeout))
+	defer announceRequestCancelFunc()
+
+	metrics, err := p.getMetrics()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// A proxy ID is implicitly sent with requests; it's the proxy's session
+	// public key.
+
+	announceResponse, err := p.brokerClient.ProxyAnnounce(
+		announceRequestCtx,
+		&ProxyAnnounceRequest{
+			PersonalCompartmentIDs: p.config.DialParameters.PersonalCompartmentIDs(),
+			Metrics:                metrics,
+		})
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if announceResponse.ClientProxyProtocolVersion != ProxyProtocolVersion1 {
+		return errors.Tracef(
+			"Unsupported proxy protocol version: %d",
+			announceResponse.ClientProxyProtocolVersion)
+	}
+
+	if announceResponse.OperatorMessageJSON != "" {
+		p.config.OperatorMessageHandler(announceResponse.OperatorMessageJSON)
+	}
+
+	// For activity updates, indicate that a client connection is now underway.
+
+	atomic.AddInt32(&p.connectingClients, 1)
+	connected := false
+	defer func() {
+		if !connected {
+			atomic.AddInt32(&p.connectingClients, -1)
+		}
+	}()
+
+	// Initialize WebRTC using the client's offer SDP
+
+	webRTCAnswerCtx, webRTCAnswerCancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(p.config.DialParameters.WebRTCAnswerTimeout(), proxyWebRTCAnswerTimeout))
+	defer webRTCAnswerCancelFunc()
+
+	webRTCConn, SDP, SDPMetrics, webRTCErr := NewWebRTCConnWithAnswer(
+		webRTCAnswerCtx,
+		&WebRTCConfig{
+			Logger:                      p.config.Logger,
+			DialParameters:              p.config.DialParameters,
+			ClientRootObfuscationSecret: announceResponse.ClientRootObfuscationSecret,
+		},
+		announceResponse.ClientOfferSDP)
+	var webRTCRequestErr string
+	if webRTCErr != nil {
+		webRTCErr = errors.Trace(webRTCErr)
+		webRTCRequestErr = webRTCErr.Error()
+		SDP = webrtc.SessionDescription{}
+		// Continue to report the error to the broker. The broker will respond
+		// with failure to the client's offer request.
+	} else {
+		// Only schedule the close when initialization succeeded. On failure
+		// the returned conn must not be used; it is presumably nil, in which
+		// case a deferred Close could panic.
+		defer webRTCConn.Close()
+	}
+
+	// Send answer request with SDP or error.
+
+	answerRequestCtx, answerRequestCancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(p.config.DialParameters.AnswerRequestTimeout(), proxyAnswerRequestTimeout))
+	defer answerRequestCancelFunc()
+
+	answerRequest := &ProxyAnswerRequest{
+		ConnectionID:                 announceResponse.ConnectionID,
+		SelectedProxyProtocolVersion: announceResponse.ClientProxyProtocolVersion,
+		ProxyAnswerSDP:               SDP,
+		AnswerError:                  webRTCRequestErr,
+	}
+	if webRTCErr == nil {
+		// The SDP metrics are only meaningful, and only safe to read, when
+		// WebRTC initialization succeeded. Otherwise ICECandidateTypes is
+		// left at its zero value.
+		answerRequest.ICECandidateTypes = SDPMetrics.ICECandidateTypes
+	}
+
+	_, err = p.brokerClient.ProxyAnswer(answerRequestCtx, answerRequest)
+	if err != nil {
+		if webRTCErr != nil {
+			// Prioritize returning any WebRTC error for logging.
+			return webRTCErr
+		}
+		return errors.Trace(err)
+	}
+
+	// Now that an answer is sent, stop if WebRTC initialization failed.
+
+	if webRTCErr != nil {
+		return webRTCErr
+	}
+
+	// Await the WebRTC connection.
+
+	// We could concurrently dial the destination, to have that network
+	// connection available immediately once the WebRTC channel is
+	// established. This would work only for TCP, not UDP, network protocols
+	// and could only include the TCP connection, as client traffic is
+	// required for all higher layers such as TLS, SSH, etc. This could also
+	// create wasted load on destination Psiphon servers, particularly when
+	// WebRTC connections fail.
+
+	clientConnectCtx, clientConnectCancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(p.config.DialParameters.ProxyClientConnectTimeout(), proxyClientConnectTimeout))
+	defer clientConnectCancelFunc()
+
+	err = webRTCConn.AwaitInitialDataChannel(clientConnectCtx)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	p.config.Logger.WithTraceFields(common.LogFields{
+		"connectionID": announceResponse.ConnectionID,
+	}).Info("WebRTC data channel established")
+
+	// Dial the destination, a Psiphon server. The broker validates that the
+	// dial destination is a Psiphon server.
+
+	destinationDialContext, destinationDialCancelFunc := context.WithTimeout(
+		ctx, common.ValueOrDefault(p.config.DialParameters.ProxyDestinationDialTimeout(), proxyDestinationDialTimeout))
+	defer destinationDialCancelFunc()
+
+	// Use the custom resolver when resolving destination hostnames, such as
+	// those used in domain fronted protocols.
+	//
+	// - Resolving at the in-proxy should yield a more optimal CDN edge, vs.
+	//   resolving at the client.
+	//
+	// - Sending unresolved hostnames to in-proxies can expose some domain
+	//   fronting configuration. This can be mitigated by enabling domain
+	//   fronting on this 2nd hop only when the in-proxy is located in a
+	//   region that may be censored or blocked; this is to be enforced by
+	//   the broker.
+	//
+	// - Any DNSResolverPreresolved tactics applied will be relative to the
+	//   in-proxy location.
+
+	destinationAddress, err := p.config.DialParameters.ResolveAddress(ctx, announceResponse.DestinationAddress)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	var dialer net.Dialer
+	destinationConn, err := dialer.DialContext(
+		destinationDialContext,
+		announceResponse.NetworkProtocol.String(),
+		destinationAddress)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer destinationConn.Close()
+
+	// For activity updates, indicate that a client connection is established.
+
+	connected = true
+	atomic.AddInt32(&p.connectingClients, -1)
+	atomic.AddInt32(&p.connectedClients, 1)
+	defer func() {
+		atomic.AddInt32(&p.connectedClients, -1)
+	}()
+
+	// Throttle the relay connection.
+	//
+	// Here, each client gets LimitUp/DownstreamBytesPerSecond. Proxy
+	// operators may want to limit their bandwidth usage with a single
+	// up/down value, an overall limit. The ProxyConfig can simply be
+	// generated by dividing the limit by MaxClients. This approach favors
+	// performance stability: each client gets the same throttling limits
+	// regardless of how many other clients are connected.
+
+	destinationConn = common.NewThrottledConn(
+		destinationConn,
+		common.RateLimits{
+			ReadBytesPerSecond:  int64(p.config.LimitUpstreamBytesPerSecond),
+			WriteBytesPerSecond: int64(p.config.LimitDownstreamBytesPerSecond),
+		})
+
+	// Hook up bytes transferred counting for activity updates.
+
+	// The ActivityMonitoredConn inactivity timeout is not configured, since
+	// the Psiphon server will close its connection to inactive clients on
+	// its own schedule.
+
+	destinationConn, err = common.NewActivityMonitoredConn(
+		destinationConn, 0, false, nil, p.activityUpdateWrapper)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Relay the client traffic to the destination. The client traffic is a
+	// standard Psiphon tunnel protocol destined for a Psiphon server. Any
+	// blocking/censorship at the 2nd hop will be mitigated by the use of
+	// Psiphon circumvention protocols and techniques.
+
+	// Limitation: clients may apply fragmentation to traffic relayed over the
+	// data channel, and there's no guarantee that the fragmentation write
+	// sizes or delays will carry over to the egress side.
+
+	// The proxy operator's ISP may be able to observe that the operator's
+	// host has nearly matching ingress and egress traffic. The traffic
+	// content won't be the same: the ingress traffic is wrapped in a WebRTC
+	// data channel, and the egress traffic is a Psiphon tunnel protocol. But
+	// the traffic shape will be close to the same. As a future enhancement,
+	// consider adding data channel padding and decoy traffic, which is
+	// dropped on egress. For performance, traffic shaping could be ceased
+	// after some time. Even with this measure, over time the number of bytes
+	// in and out of the proxy may still indicate proxying.
+
+	// Each relay goroutine always reports its result, nil included, so that
+	// a direction ending cleanly (io.Copy returns nil at EOF) also wakes the
+	// select below. Otherwise this worker would sit idle, still counted as a
+	// connected client, until ctx is done. The channel is buffered for both
+	// results so neither goroutine can block on its send.
+
+	waitGroup := new(sync.WaitGroup)
+	relayErrors := make(chan error, 2)
+
+	waitGroup.Add(1)
+	go func() {
+		defer waitGroup.Done()
+
+		// WebRTC data channels are based on SCTP, which is actually
+		// message-based, not a stream. The (default) max message size for
+		// pion/sctp is 65536:
+		// https://github.com/pion/sctp/blob/44ed465396c880e379aae9c1bf81809a9e06b580/association.go#L52.
+		//
+		// As io.Copy uses a buffer size of 32K, each relayed message will be
+		// less than the maximum. Calls to ClientConn.Write are also expected
+		// to use io.Copy, keeping messages at most 32K in size. Note that
+		// testing with io.CopyBuffer and a buffer of size 65536 actually
+		// yielded the pion error io.ErrShortBuffer, "short buffer", while a
+		// buffer of size 65535 worked.
+
+		_, err := io.Copy(webRTCConn, destinationConn)
+		if err != nil {
+			err = errors.Trace(err)
+		}
+		relayErrors <- err
+	}()
+
+	waitGroup.Add(1)
+	go func() {
+		defer waitGroup.Done()
+		_, err := io.Copy(destinationConn, webRTCConn)
+		if err != nil {
+			err = errors.Trace(err)
+		}
+		relayErrors <- err
+	}()
+
+	// Wait for the first relay direction to finish, or for shutdown. On
+	// shutdown, err keeps its current nil value.
+
+	select {
+	case err = <-relayErrors:
+	case <-ctx.Done():
+	}
+
+	// Interrupt the relay goroutines by closing the connections.
+	webRTCConn.Close()
+	destinationConn.Close()
+
+	waitGroup.Wait()
+
+	p.config.Logger.WithTraceFields(common.LogFields{
+		"connectionID": announceResponse.ConnectionID,
+	}).Info("connection closed")
+
+	return err
+}
+
+// getMetrics assembles the ProxyMetrics sent with an announce request. It
+// combines the encoded base metrics, the NAT type and port mapping types
+// reported by DialParameters, the configured limits, and the current client
+// counts and peak byte counters. An error is returned when the base metrics
+// cannot be encoded.
+func (p *Proxy) getMetrics() (*ProxyMetrics, error) {
+
+	encodedBaseMetrics, err := EncodeBaseMetrics(p.config.BaseMetrics)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	proxyConfig := p.config
+	dialParams := proxyConfig.DialParameters
+
+	metrics := new(ProxyMetrics)
+	metrics.BaseMetrics = encodedBaseMetrics
+	metrics.ProxyProtocolVersion = ProxyProtocolVersion1
+	metrics.NATType = dialParams.NATType()
+	metrics.PortMappingTypes = dialParams.PortMappingTypes()
+	metrics.MaxClients = int32(proxyConfig.MaxClients)
+	metrics.ConnectingClients = atomic.LoadInt32(&p.connectingClients)
+	metrics.ConnectedClients = atomic.LoadInt32(&p.connectedClients)
+	metrics.LimitUpstreamBytesPerSecond = int64(proxyConfig.LimitUpstreamBytesPerSecond)
+	metrics.LimitDownstreamBytesPerSecond = int64(proxyConfig.LimitDownstreamBytesPerSecond)
+	metrics.PeakUpstreamBytesPerSecond = atomic.LoadInt64(&p.peakBytesUp)
+	metrics.PeakDownstreamBytesPerSecond = atomic.LoadInt64(&p.peakBytesDown)
+
+	return metrics, nil
+}

+ 173 - 0
psiphon/common/inproxy/records.go

@@ -0,0 +1,173 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"encoding/binary"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/fxamacker/cbor/v2"
+)
+
+// Records are CBOR-encoded data with a preamble, or prefix, indicating the
+// encoding schema version, data type, and data length. Records include
+// session messages, as well as API requests and responses which are session
+// message payloads.
+
+const (
+	// recordVersion is the encoding schema version. addRecordPreamble
+	// writes it as the first preamble byte, and the preamble readers reject
+	// any other value.
+	recordVersion = 1
+
+	// Record type identifiers, written as the second preamble byte.
+	// recordTypeFirst and recordTypeLast bound the range of types accepted
+	// by peekRecordPreambleType.
+	recordTypeFirst                          = 1
+	recordTypeSessionPacket                  = 1
+	recordTypeSessionRoundTrip               = 2
+	recordTypeAPIProxyAnnounceRequest        = 3
+	recordTypeAPIProxyAnnounceResponse       = 4
+	recordTypeAPIProxyAnswerRequest          = 5
+	recordTypeAPIProxyAnswerResponse         = 6
+	recordTypeAPIClientOfferRequest          = 7
+	recordTypeAPIClientOfferResponse         = 8
+	recordTypeAPIClientRelayedPacketRequest  = 9
+	recordTypeAPIClientRelayedPacketResponse = 10
+	recordTypeAPIBrokerServerRequest         = 11
+	recordTypeAPIBrokerServerResponse        = 12
+	recordTypeLast                           = 12
+)
+
+// cborEncoding is the CBOR encoding mode used by marshalRecord, built from
+// the cbor package's CTAP2 encoding options. The error returned by EncMode
+// is discarded.
+// NOTE(review): presumably EncMode cannot fail for this fixed preset --
+// confirm against the cbor package documentation, since a failure here
+// would leave cborEncoding unusable.
+var cborEncoding, _ = cbor.CTAP2EncOptions().EncMode()
+
+// marshalRecord CBOR-encodes record and returns the encoding prefixed with
+// a record preamble for recordType. An error is returned if encoding fails
+// or if the preamble cannot be added.
+func marshalRecord(record interface{}, recordType int) ([]byte, error) {
+
+	encoded, err := cborEncoding.Marshal(record)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	framed, err := addRecordPreamble(recordType, encoded)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return framed, nil
+}
+
+// unmarshalRecord validates and strips the record preamble from payload,
+// which must carry expectedRecordType, and then CBOR-decodes the remaining
+// data into record.
+func unmarshalRecord(expectedRecordType int, payload []byte, record interface{}) error {
+
+	data, err := readRecordPreamble(expectedRecordType, payload)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if err := cbor.Unmarshal(data, record); err != nil {
+		return errors.Trace(err)
+	}
+
+	return nil
+}
+
+// addRecordPreamble returns the given record data with a record preamble
+// prepended. The caller supplies the recordType to encode; the version
+// number of the current encoding schema is filled in automatically.
+//
+// The input record buffer is modified, to avoid allocations where possible;
+// use like record = append(record, ...).
+func addRecordPreamble(
+	recordType int, record []byte) ([]byte, error) {
+
+	switch {
+	case recordVersion < 0 || recordVersion > 0xff:
+		return nil, errors.TraceNew("invalid record preamble version")
+	case recordType < 0 || recordType > 0xff:
+		return nil, errors.TraceNew("invalid record preamble type")
+	case len(record) > 0xffff:
+		return nil, errors.TraceNew("invalid record length")
+	}
+
+	dataLen := len(record)
+
+	// The preamble:
+	// [ 1 byte version ][ 1 byte type ][ varint record data length ][ ...record data ... ]
+
+	var header [2 + binary.MaxVarintLen64]byte
+	header[0] = byte(recordVersion)
+	header[1] = byte(recordType)
+	headerLen := 2 + binary.PutUvarint(header[2:], uint64(dataLen))
+
+	// Grow the input buffer by the preamble size; when the buffer already
+	// has sufficient capacity, no allocation occurs. Then shift the record
+	// data toward the end and write the preamble at the front.
+	record = append(record, header[:headerLen]...)
+	copy(record[headerLen:], record[:dataLen])
+	copy(record[:headerLen], header[:headerLen])
+
+	return record, nil
+}
+
+// peekRecordPreambleType reports the record type found in the preamble of
+// the record data payload, without consuming anything. It returns -1 and an
+// error when the payload is too short, the version is unrecognized, or the
+// type is outside the known range.
+func peekRecordPreambleType(payload []byte) (int, error) {
+
+	switch {
+	case len(payload) < 2:
+		return -1, errors.TraceNew("invalid record preamble length")
+	case int(payload[0]) != recordVersion:
+		return -1, errors.TraceNew("invalid record preamble version")
+	}
+
+	kind := int(payload[1])
+	if kind >= recordTypeFirst && kind <= recordTypeLast {
+		return kind, nil
+	}
+
+	return -1, errors.Tracef("invalid record preamble type: %d %x", kind, payload)
+}
+
+// readRecordPreamble consumes the record preamble from the given record data
+// payload and returns the remaining record. The record type must match
+// expectedRecordType and the version must match a known encoding schema
+// version. The data length declared in the preamble must equal the number
+// of bytes that follow it.
+//
+// To avoid allocations, readRecordPreamble returns a slice of the
+// input record buffer; use like record = record[n:].
+func readRecordPreamble(expectedRecordType int, payload []byte) ([]byte, error) {
+
+	if len(payload) < 2 {
+		return nil, errors.TraceNew("invalid record preamble length")
+	}
+
+	if int(payload[0]) != recordVersion {
+		return nil, errors.TraceNew("invalid record preamble version")
+	}
+
+	if int(payload[1]) != expectedRecordType {
+		return nil, errors.Tracef("unexpected record preamble type")
+	}
+
+	// binary.Uvarint returns n == 0 when the buffer is too small and n < 0
+	// on overflow; both cases are rejected here.
+	recordDataLength, n := binary.Uvarint(payload[2:])
+	if n < 1 || 2+n > len(payload) {
+		return nil, errors.Tracef("invalid record preamble data length")
+	}
+
+	record := payload[2+n:]
+
+	// In the future, the data length field may be used to implement framing
+	// for a stream of records. For now, this check is simply a sanity check.
+	//
+	// The comparison is made in uint64 so that a declared length exceeding
+	// the range of int cannot be truncated into a spurious match on 32-bit
+	// platforms.
+	if uint64(len(record)) != recordDataLength {
+		return nil, errors.TraceNew("unexpected record preamble data length")
+	}
+
+	return record, nil
+}

+ 173 - 0
psiphon/common/inproxy/server.go

@@ -0,0 +1,173 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+)
+
+// ServerBrokerSessions manages the secure sessions that handle
+// BrokerServerRequests from brokers. Each in-proxy-capable Psiphon server
+// maintains a ServerBrokerSessions, with a set of established sessions for
+// each broker. Session messages are relayed between the broker and the
+// server by the client.
+type ServerBrokerSessions struct {
+	// sessions is the underlying responder session set; HandlePacket
+	// delegates all session packet processing to it.
+	sessions *ResponderSessions
+}
+
+// NewServerBrokerSessions creates a new ServerBrokerSessions using the
+// specified server key material. Only brokers whose public keys appear in
+// brokerPublicKeys, an allow list, are accepted as session initiators.
+func NewServerBrokerSessions(
+	serverPrivateKey SessionPrivateKey,
+	serverRootObfuscationSecret ObfuscationSecret,
+	brokerPublicKeys []SessionPublicKey) (*ServerBrokerSessions, error) {
+
+	responderSessions, err := NewResponderSessionsForKnownInitiators(
+		serverPrivateKey,
+		serverRootObfuscationSecret,
+		brokerPublicKeys)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	serverSessions := &ServerBrokerSessions{sessions: responderSessions}
+
+	return serverSessions, nil
+}
+
+// ProxiedConnectionHandler is a callback, provided by the Psiphon server,
+// that receives information from a BrokerServerRequest for the client
+// associated with the callback.
+//
+// The server must use the brokerVerifiedOriginalClientIP for all GeoIP
+// operations associated with the client, including traffic rule selection
+// and client-side tactics selection.
+//
+// Since the BrokerServerRequest may be delivered later than the Psiphon
+// handshake -- in the case where the broker/server session needs to be
+// established there will be additional round trips -- the server should
+// delay traffic rule application, tactics responses, and allowing tunneled
+// traffic until after the ProxiedConnectionHandler callback is invoked for
+// the client. As a consequence, Psiphon Servers should be configured to
+// require Proxies to be used for designated protocols. It's expected that
+// server-side tactics such as packet manipulation will be applied based on
+// the proxy's IP address.
+//
+// The fields in logFields should be added to server_tunnel logs.
+//
+// ServerBrokerSessions.HandlePacket invokes the callback with the
+// BrokerServerRequest's ClientIP and its validated log fields, and only when
+// the connection ID in the request matches the connection ID supplied by the
+// relaying client.
+type ProxiedConnectionHandler func(
+	brokerVerifiedOriginalClientIP string,
+	logFields common.LogFields)
+
+// HandlePacket handles a broker/server session packet, which are relayed by
+// clients. In Psiphon, the packets may be sent in the Psiphon handshake, or
+// in subsequent requests; while responses should be returned in the
+// handshake response or responses for later requests. When the broker/server
+// session is already established, it's expected that the BrokerServerRequest
+// arrives in the packet that accompanies the Psiphon handshake, and so no
+// additional round trip is required.
+//
+// Once the session is established and a verified BrokerServerRequest arrives,
+// the information from that request is sent to the ProxiedConnectionHandler
+// callback. The callback should be associated with the client that is
+// relaying the packets.
+//
+// clientConnectionID is the in-proxy connection ID specified by the client in
+// its Psiphon handshake.
+//
+// When the retOut return value is not nil, it should be relayed back to the
+// client in the handshake response or other tunneled response. When retOut
+// is nil, the relay is complete.
+//
+// When the retErr return value is not nil, it should be logged, and an error
+// flag (but not the retErr value) relayed back to the client. retErr may be
+// non-nil in expected conditions, such as the broker attempting to use a
+// session which has expired.
+func (s *ServerBrokerSessions) HandlePacket(
+	logger common.Logger,
+	in []byte,
+	clientConnectionID ID,
+	handler ProxiedConnectionHandler) (retOut []byte, retErr error) {
+
+	handleUnwrappedRequest := func(initiatorID ID, unwrappedRequestPayload []byte) ([]byte, error) {
+
+		brokerRequest, err := UnmarshalBrokerServerRequest(unwrappedRequestPayload)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		logFields, err := brokerRequest.ValidateAndGetLogFields()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// The initiatorID is the broker's public key.
+		logFields["broker_id"] = initiatorID
+
+		var errorMessage string
+
+		// The client must supply same connection ID to server that the broker
+		// sends to the server.
+		if brokerRequest.ConnectionID != clientConnectionID {
+
+			// Errors such as this are not error return values, as this is not
+			// an error in the session protocol. Instead, a response is sent
+			// to the broker containing the error message, which the broker may log.
+
+			errorMessage = "connection ID mismatch"
+
+			logger.WithTraceFields(common.LogFields{
+				"client_connection_id": clientConnectionID,
+				"broker_connection_id": brokerRequest.ConnectionID,
+			}).Error(errorMessage)
+		}
+
+		if errorMessage == "" {
+
+			handler(brokerRequest.ClientIP, logFields)
+		}
+
+		brokerResponse, err := MarshalBrokerServerResponse(
+			&BrokerServerResponse{
+				ConnectionID: brokerRequest.ConnectionID,
+				ErrorMessage: errorMessage,
+			})
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		return brokerResponse, nil
+	}
+
+	// An error here may be due to the broker using a session that has
+	// expired. In that case, the client should relay back that the session
+	// failed, and the broker will start reestablishing the session.
+	//
+	// TODO: distinguish between session expired, an expected error, and
+	// unexpected errors and then log only unexpected errors? However, expiry
+	// may be rare, and still useful to log.
+
+	out, err := s.sessions.HandlePacket(
+		in, handleUnwrappedRequest)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return out, nil
+}

+ 1525 - 0
psiphon/common/inproxy/session.go

@@ -0,0 +1,1525 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"math"
+	"sync"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	lrucache "github.com/cognusion/go-cache-lru"
+	"github.com/flynn/noise"
+	"golang.org/x/crypto/curve25519"
+	"golang.zx2c4.com/wireguard/replay"
+)
+
+const (
+	// sessionsTTL and sessionsMaxSize are the expiration and maximum size
+	// passed to lrucache.NewWithLRU for the ResponderSessions session cache.
+	sessionsTTL     = 5 * time.Minute
+	sessionsMaxSize = 100000
+
+	// NOTE(review): presumably the bounds, in bytes, on padding added by the
+	// session packet obfuscation layer -- the usage is not visible here;
+	// confirm against the wrap/unwrap session packet code.
+	sessionObfuscationPaddingMinSize = 0
+	sessionObfuscationPaddingMaxSize = 256
+)
+
+// Session protocol identifiers.
+//
+// NOTE(review): presumably these populate the SessionProtocolName and
+// SessionProtocolVersion fields of SessionPrologue -- confirm at the point
+// where the prologue is constructed.
+const (
+	SessionProtocolName     = "psiphon-inproxy-session"
+	SessionProtocolVersion1 = 1
+)
+
+// SessionPrologue is a Noise protocol prologue, which binds the session ID to
+// the session.
+//
+// The struct is CBOR-encoded with integer keys; zero-value fields are omitted
+// from the encoding.
+type SessionPrologue struct {
+	SessionProtocolName    string `cbor:"1,keyasint,omitempty"`
+	SessionProtocolVersion uint32 `cbor:"2,keyasint,omitempty"`
+	SessionID              ID     `cbor:"3,keyasint,omitempty"`
+}
+
+// SessionPacket is a Noise protocol message, which may be a session handshake
+// message, or secured application data, a SessionRoundTrip.
+//
+// The SessionID is used by a responder to route the packet to an existing
+// session or to create a new one. The Payload is a Noise handshake message
+// or an encrypted request/response.
+//
+// The struct is CBOR-encoded with integer keys; zero-value fields are omitted
+// from the encoding.
+type SessionPacket struct {
+	SessionID ID     `cbor:"1,keyasint,omitempty"`
+	Nonce     uint64 `cbor:"2,keyasint,omitempty"`
+	Payload   []byte `cbor:"3,keyasint,omitempty"`
+}
+
+// SessionRoundTrip is an application data request or response, which is
+// secured by the Noise protocol session. Each request is assigned a unique
+// RoundTripID, and each corresponding response has the same RoundTripID.
+//
+// The initiator generates the RoundTripID with MakeID and rejects any
+// response whose RoundTripID does not match the one it sent.
+//
+// The struct is CBOR-encoded with integer keys; zero-value fields are omitted
+// from the encoding.
+type SessionRoundTrip struct {
+	RoundTripID ID     `cbor:"1,keyasint,omitempty"`
+	Payload     []byte `cbor:"2,keyasint,omitempty"`
+}
+
+// SessionPrivateKey is a Noise protocol private key.
+type SessionPrivateKey [32]byte
+
+// SessionPublicKey is a Noise protocol public key.
+type SessionPublicKey [32]byte
+
+// IsZero indicates if the private key is zero-value.
+func (k SessionPrivateKey) IsZero() bool {
+	var zero SessionPrivateKey
+	return bytes.Equal(k[:], zero[:])
+}
+
+// GenerateSessionPrivateKey generates a fresh Noise protocol session private
+// key by creating a DH25519 key pair with crypto/rand as the randomness
+// source and returning the private half. On error, the returned key is the
+// zero value.
+func GenerateSessionPrivateKey() (SessionPrivateKey, error) {
+
+	var key SessionPrivateKey
+
+	dhKey, err := noise.DH25519.GenerateKeypair(rand.Reader)
+	if err != nil {
+		return key, errors.Trace(err)
+	}
+
+	// Guard against a key pair whose private key does not fill the
+	// fixed-size array exactly.
+	generated := dhKey.Private
+	if len(generated) != len(key) {
+		return key, errors.TraceNew("unexpected private key length")
+	}
+
+	copy(key[:], generated)
+
+	return key, nil
+}
+
+// GetSessionPublicKey derives the public key for the given private key by
+// performing an X25519 scalar multiplication with the curve base point. On
+// error, the returned key is the zero value.
+func GetSessionPublicKey(privateKey SessionPrivateKey) (SessionPublicKey, error) {
+
+	var result SessionPublicKey
+
+	derived, err := curve25519.X25519(privateKey[:], curve25519.Basepoint)
+	if err != nil {
+		return result, errors.Trace(err)
+	}
+
+	// Guard against an output that does not fill the fixed-size array
+	// exactly.
+	if len(derived) != len(result) {
+		return result, errors.TraceNew("unexpected public key length")
+	}
+
+	copy(result[:], derived)
+
+	return result, nil
+}
+
+// InitiatorSessions is a set of secure Noise protocol sessions for an
+// initiator. For in-proxy, clients and proxies will initiate sessions with
+// one or more brokers and brokers will initiate sessions with multiple
+// Psiphon servers.
+//
+// Secure sessions provide encryption, authentication of the responder,
+// identity hiding for the initiator, forward secrecy, and anti-replay for
+// application data.
+//
+// Maintaining a set of established sessions minimizes round trips and
+// overhead, as established sessions can be shared and reused for many client
+// requests to one broker or many broker requests to one server.
+//
+// Currently, InitiatorSessions does not cap the number of sessions or use
+// an LRU cache since the number of peers is bounded in the in-proxy
+// architecture; clients will typically use one or no more than a handful of
+// brokers and brokers will exchange requests with a subset of Psiphon
+// servers bounded by the in-proxy capability.
+//
+// InitiatorSessions are used via the RoundTrip function or InitiatorRoundTrip
+// type. RoundTrip is a synchronous function which performs any necessary
+// session establishment handshake along with the request/response exchange.
+// InitiatorRoundTrip offers an iterator interface, with stepwise invocations
+// for each step of the handshake and round trip.
+//
+// All round trips attempt to share and reuse any existing, established
+// session to a given peer. For a given peer, the waitToShareSession option
+// determines whether round trips will block and wait if a session handshake
+// is already in progress, or proceed with a concurrent handshake. For
+// in-proxy, clients and proxies use waitToShareSession; as broker/server
+// round trips are relayed through clients, brokers do not use
+// waitToShareSession so as to not rely on any single client.
+//
+// Round trips can be performed concurrently and requests can arrive out-of-
+// order. The higher level transport for sessions is responsible for
+// multiplexing round trips and maintaining the association between a request
+// and its corresponding response.
+type InitiatorSessions struct {
+	// privateKey is the initiator's Noise private key, used for every new
+	// session.
+	privateKey SessionPrivateKey
+
+	// mutex guards sessions, which maps a responder public key to the
+	// session currently cached for that responder.
+	mutex    sync.Mutex
+	sessions sessionLookup
+}
+
+// NewInitiatorSessions returns an empty InitiatorSessions that will use the
+// specified initiator private key for all sessions it establishes.
+func NewInitiatorSessions(
+	initiatorPrivateKey SessionPrivateKey) *InitiatorSessions {
+
+	initiatorSessions := &InitiatorSessions{
+		privateKey: initiatorPrivateKey,
+		sessions:   make(sessionLookup),
+	}
+
+	return initiatorSessions
+}
+
+// RoundTrip sends the request to the specified responder and returns the
+// response.
+//
+// RoundTrip will establish a session when required, or reuse an existing
+// session when available.
+//
+// When waitToShareSession is true, RoundTrip will block until an existing,
+// non-established session is available to be shared.
+//
+// Each packet produced by the session layer is exchanged with the responder
+// via roundTripper.RoundTrip. If a transport round trip fails, the session
+// is reset and the round trip restarted once; a second failure is returned
+// to the caller.
+//
+// RoundTrip returns immediately when ctx becomes done.
+func (s *InitiatorSessions) RoundTrip(
+	ctx context.Context,
+	roundTripper RoundTripper,
+	responderPublicKey SessionPublicKey,
+	responderRootObfuscationSecret ObfuscationSecret,
+	waitToShareSession bool,
+	request []byte) ([]byte, error) {
+
+	rt, err := s.NewRoundTrip(
+		responderPublicKey,
+		responderRootObfuscationSecret,
+		waitToShareSession,
+		request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	didResetSession := false
+
+	// in is the most recent packet received from the responder; it is nil
+	// for the first Next call and for the first Next call after a reset.
+	var in []byte
+	for {
+		out, err := rt.Next(ctx, in)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		if out == nil {
+			// No further packet to send: the round trip is complete.
+			response, err := rt.Response()
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			return response, nil
+		}
+		in, err = roundTripper.RoundTrip(ctx, out)
+		if err != nil {
+
+			// Perform at most one session reset, to accommodate the expected
+			// case where the initiator reuses an established session that is
+			// expired for the responder.
+			//
+			// Higher levels implicitly provide additional retries to cover
+			// other cases; Psiphon client tunnel establishment will retry
+			// in-proxy dials; the proxy will retry its announce request if
+			// it fails -- after an appropriate delay.
+
+			if didResetSession {
+				return nil, errors.Trace(err)
+			}
+
+			// TODO: log reset
+			rt.ResetSession()
+			didResetSession = true
+
+			// After a reset, the following Next call starts a new session
+			// and requires a nil received packet. Discard whatever the
+			// failed transport round trip returned so that a non-nil value
+			// accompanying the error cannot cause Next to fail.
+			in = nil
+		}
+	}
+}
+
+// NewRoundTrip prepares an InitiatorRoundTrip that will carry out a
+// request/response round trip with the specified responder, sending the
+// input request. The InitiatorRoundTrip will establish a session when
+// required, or reuse an existing session when available.
+//
+// When waitToShareSession is true, InitiatorRoundTrip.Next will block until
+// an existing, non-established session is available to be shared.
+//
+// NewRoundTrip does not block or perform any session operations; the
+// operations begin on the first InitiatorRoundTrip.Next call. The content of
+// request should not be modified after calling NewRoundTrip.
+func (s *InitiatorSessions) NewRoundTrip(
+	responderPublicKey SessionPublicKey,
+	responderRootObfuscationSecret ObfuscationSecret,
+	waitToShareSession bool,
+	request []byte) (*InitiatorRoundTrip, error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	// Each round trip gets its own ID, which the response is expected to
+	// echo back. Checking the echoed ID detects any potential misrouting of
+	// multiplexed round trip exchanges.
+
+	id, err := MakeID()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The request is wrapped in a SessionRoundTrip record up front, so Next
+	// only needs to send the already-marshaled payload.
+
+	roundTripRecord := SessionRoundTrip{
+		RoundTripID: id,
+		Payload:     request,
+	}
+
+	payload, err := marshalRecord(roundTripRecord, recordTypeSessionRoundTrip)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	roundTrip := &InitiatorRoundTrip{
+		initiatorSessions:              s,
+		responderPublicKey:             responderPublicKey,
+		responderRootObfuscationSecret: responderRootObfuscationSecret,
+		waitToShareSession:             waitToShareSession,
+		roundTripID:                    id,
+		requestPayload:                 payload,
+	}
+
+	return roundTrip, nil
+}
+
+// getSession returns the session cached for the peer with the given public
+// key. When there is none, newSession is invoked to create one, which is
+// then stored under that key. retisNew indicates whether the returned
+// session was created by this call, and retIsReady indicates whether the
+// returned session is ready to be shared.
+func (s *InitiatorSessions) getSession(
+	publicKey SessionPublicKey,
+	newSession func() (*session, error)) (
+	retSession *session, retisNew bool, retIsReady bool, retErr error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	if existing, found := s.sessions[publicKey]; found {
+		return existing, false, existing.isReadyToShare(nil), nil
+	}
+
+	created, err := newSession()
+	if err != nil {
+		return nil, false, false, errors.Trace(err)
+	}
+
+	s.sessions[publicKey] = created
+
+	return created, true, created.isReadyToShare(nil), nil
+}
+
+// setSession stores the given session as the cached session for the peer
+// with the specified public key, replacing any session previously stored
+// for that key.
+func (s *InitiatorSessions) setSession(publicKey SessionPublicKey, session *session) {
+	s.mutex.Lock()
+	s.sessions[publicKey] = session
+	s.mutex.Unlock()
+}
+
+// removeIfSession deletes the cached session for the peer with the specified
+// public key, but only when the cached entry is the given session. A
+// different session cached under the same key is left in place.
+func (s *InitiatorSessions) removeIfSession(publicKey SessionPublicKey, session *session) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	if cached, found := s.sessions[publicKey]; found && cached == session {
+		delete(s.sessions, publicKey)
+	}
+}
+
+// InitiatorRoundTrip represents the state of a session round trip, including
+// a session handshake if required. The session handshake and round trip is
+// advanced by calling InitiatorRoundTrip.Next.
+type InitiatorRoundTrip struct {
+	// The following fields are set once, by InitiatorSessions.NewRoundTrip.
+	// requestPayload is the request already marshaled as a SessionRoundTrip
+	// record carrying roundTripID.
+	initiatorSessions              *InitiatorSessions
+	responderPublicKey             SessionPublicKey
+	responderRootObfuscationSecret ObfuscationSecret
+	waitToShareSession             bool
+	roundTripID                    ID
+	requestPayload                 []byte
+
+	// mutex is held by Next, ResetSession, and Response while they access
+	// the fields below. sharingSession is true when session is an existing
+	// session shared with other round trips rather than one this round trip
+	// is handshaking itself. didResetSession records that ResetSession has
+	// cleared session once. response is set by Next when the round trip
+	// completes.
+	mutex           sync.Mutex
+	sharingSession  bool
+	didResetSession bool
+	session         *session
+	response        []byte
+}
+
+// ResetSession clears the InitiatorRoundTrip session. Call ResetSession when
+// the responder indicates an error in response to a session packet. Errors
+// are sent at the transport level. An error is expected when the initiator
+// reuses an established session that is expired for the responder. After
+// calling ResetSession, the following Next call will begin establishing a
+// new session. The expected session expiry scenario should occur at most
+// once per round trip.
+//
+// Limitation: since session errors/failures are handled at the transport
+// level, they may be forged, depending on the security provided by the
+// transport layer. For client and proxy sessions with a broker, if domain
+// fronting is used then security depends on the HTTPS layer and CDNs can
+// forge a session error. For broker sessions with Psiphon servers, the
+// relaying client could forge a server error -- but that would deny service
+// to the client when the BrokerServerRequest fails.
+//
+// ResetSession is ignored if a response was already received, if
+// ResetSession was already called before, or if there is no current session.
+//
+// Higher levels implicitly provide additional round trip retries to cover
+// other cases; Psiphon client tunnel establishment will retry in-proxy
+// dials; the proxy will retry its announce request if it fails -- after an
+// appropriate delay.
+func (r *InitiatorRoundTrip) ResetSession() {
+
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	alreadyDone := r.didResetSession || r.response != nil
+	if alreadyDone || r.session == nil {
+		return
+	}
+
+	// Drop the session from the shared set only if it is still the cached
+	// session for this responder, then forget it locally so the next Next
+	// call starts over.
+	r.initiatorSessions.removeIfSession(r.responderPublicKey, r.session)
+	r.didResetSession = true
+	r.session = nil
+}
+
+// Next advances a round trip, as well as any session handshake that may be
+// first required. Next takes the next packet received from the responder and
+// returns the next packet to send to the responder. To begin, pass a nil
+// receivedPacket. The round trip is complete when Next returns nil for the
+// next packet to send; the response can be fetched from
+// InitiatorRoundTrip.Response.
+//
+// When waitToShareSession is set, Next will block until an existing,
+// non-established session is available to be shared.
+//
+// Multiple concurrent round trips are supported and requests from different
+// round trips can arrive at the responder out-of-order. The provided
+// transport is responsible for multiplexing round trips and maintaining an
+// association between sent and received packets for a given round trip.
+//
+// Next returns immediately when ctx becomes done. A nil ctx is tolerated and
+// is treated as a context that is never done.
+func (r *InitiatorRoundTrip) Next(
+	ctx context.Context,
+	receivedPacket []byte) (retSendPacket []byte, retErr error) {
+
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if ctx != nil {
+		err := ctx.Err()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+
+	if r.session == nil {
+
+		// If the session is nil, this is the first call to Next, and no
+		// packet from the peer is expected.
+
+		if receivedPacket != nil {
+			return nil, errors.TraceNew("unexpected received packet")
+		}
+
+		newSession := func() (*session, error) {
+			session, err := newSession(
+				true, // isInitiator
+				r.initiatorSessions.privateKey,
+				r.responderRootObfuscationSecret,
+				nil, // No obfuscation replay history
+				&r.responderPublicKey,
+				r.requestPayload,
+				nil,
+				nil)
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			return session, nil
+		}
+
+		// Check for an existing session, or create a new one if there's no
+		// existing session.
+		//
+		// To ensure the concurrent waitToShareSession cases don't start
+		// multiple handshakes, getSession populates the initiatorSessions
+		// session map with a new, unestablished session.
+
+		session, isNew, isReady, err := r.initiatorSessions.getSession(
+			r.responderPublicKey, newSession)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		if isNew {
+
+			// When isNew is true, this InitiatorRoundTrip owns the session
+			// and will perform the handshake.
+
+			r.session = session
+			r.sharingSession = false
+
+		} else {
+
+			if isReady {
+
+				// When isReady is true, this shared session is fully
+				// established and ready for immediate use.
+
+				r.session = session
+				r.sharingSession = true
+
+			} else {
+
+				// The existing session is not yet ready for use.
+
+				if r.waitToShareSession {
+
+					// Wait for the owning InitiatorRoundTrip to complete the
+					// session handshake and then share the session.
+
+					signal := make(chan struct{})
+					if !session.isReadyToShare(signal) {
+
+						// The ctx nil check above implies that a nil ctx is
+						// permitted, so ctx.Done must not be called
+						// unconditionally. Receiving from a nil channel
+						// blocks forever, so with a nil ctx only the signal
+						// can end the wait.
+
+						var done <-chan struct{}
+						if ctx != nil {
+							done = ctx.Done()
+						}
+
+						select {
+						case <-signal:
+						case <-done:
+							return nil, errors.Trace(ctx.Err())
+						}
+					}
+					r.session = session
+					r.sharingSession = true
+
+				} else {
+
+					// Don't wait: create a new, unshared session.
+
+					r.session, err = newSession()
+					if err != nil {
+						return nil, errors.Trace(err)
+					}
+					r.sharingSession = false
+				}
+			}
+		}
+
+		if r.sharingSession {
+
+			// The shared session was either ready for immediate use, or we
+			// waited. Send the round trip request payload.
+
+			sendPacket, err := r.session.sendPacket(r.requestPayload)
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			return sendPacket, nil
+		}
+
+		// Begin the handshake for a new session.
+
+		_, sendPacket, _, err := r.session.nextHandshakePacket(nil)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		return sendPacket, nil
+
+	}
+
+	// Not the first Next call, so a packet from the peer is expected.
+
+	if receivedPacket == nil {
+		return nil, errors.TraceNew("missing received packet")
+	}
+
+	if r.sharingSession || r.session.isEstablished() {
+
+		// When sharing an established and ready session, or once an owned
+		// session is established, the next packet is post-handshake and
+		// should be the round trip request response.
+
+		responsePayload, err := r.session.receivePacket(receivedPacket)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		var sessionRoundTrip SessionRoundTrip
+		err = unmarshalRecord(recordTypeSessionRoundTrip, responsePayload, &sessionRoundTrip)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		// Check that the response RoundTripID matches the request RoundTripID.
+
+		if sessionRoundTrip.RoundTripID != r.roundTripID {
+			return nil, errors.TraceNew("unexpected round trip ID")
+		}
+
+		// Store the response so it can be retrieved later.
+
+		r.response = sessionRoundTrip.Payload
+		return nil, nil
+	}
+
+	// Continue the handshake. Since the first payload is sent to the
+	// responder along with the initiator's last handshake message, there's
+	// no sendPacket call in the owned session case. The last
+	// nextHandshakePacket will bundle it. Also, the payload output of
+	// nextHandshakePacket is ignored, as only a responder will receive a
+	// payload in a handshake message.
+
+	isEstablished, sendPacket, _, err := r.session.nextHandshakePacket(receivedPacket)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if isEstablished {
+
+		// Retain the most recently established session as the cached session
+		// for reuse. This should be a no-op in the isNew case and only have
+		// an effect for !isNew and !waitToShareSession. Modifying the
+		// initiatorSessions map entry should not impact any concurrent
+		// handshakes, as each InitiatorRoundTrip maintains its own reference
+		// to its session.
+
+		r.initiatorSessions.setSession(r.responderPublicKey, r.session)
+	}
+
+	return sendPacket, nil
+}
+
+// Response returns the response stored by Next once the round trip has
+// completed. Call Response after Next returns nil for the next packet to
+// send. An error is returned when no response has been stored.
+func (r *InitiatorRoundTrip) Response() ([]byte, error) {
+
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if response := r.response; response != nil {
+		return response, nil
+	}
+
+	return nil, errors.TraceNew("no response")
+}
+
+// ResponderSessions is a set of secure Noise protocol sessions for a
+// responder. For in-proxy, brokers respond to clients and proxies and
+// servers respond to brokers.
+//
+// Secure sessions provide encryption, authentication of the responder,
+// identity hiding for the initiator, forward secrecy, and anti-replay for
+// application data.
+//
+// ResponderSessions maintains a cache of established sessions to minimize
+// round trips and overhead as initiators are expected to make multiple round
+// trips. The cache has a TTL and maximum size with LRU to cap overall memory
+// usage. A broker may receive requests from millions of clients and proxies
+// and so only more recent sessions will be retained. Servers will receive
+// requests from only a handful of brokers, and so the TTL is not applied.
+//
+// Multiple, concurrent sessions for a single initiator public key are
+// supported.
+type ResponderSessions struct {
+	// privateKey and rootObfuscationSecret are the responder's key material.
+	// HandlePacket uses rootObfuscationSecret and obfuscationReplayHistory
+	// to unwrap incoming session packets.
+	//
+	// NOTE(review): applyTTL presumably selects whether cached sessions
+	// expire -- its point of use is not visible here; confirm.
+	//
+	// expectedInitiatorPublicKeys is the initiator allow list populated by
+	// NewResponderSessionsForKnownInitiators; NewResponderSessions leaves it
+	// nil.
+	privateKey                  SessionPrivateKey
+	rootObfuscationSecret       ObfuscationSecret
+	applyTTL                    bool
+	obfuscationReplayHistory    *obfuscationReplayHistory
+	expectedInitiatorPublicKeys sessionPublicKeyLookup
+
+	// sessions is the LRU cache of sessions; mutex presumably guards access
+	// to it (the helpers that take the lock are outside this view).
+	mutex    sync.Mutex
+	sessions *lrucache.Cache
+}
+
+// NewResponderSessions returns a ResponderSessions that permits any
+// initiator to establish a session. A TTL is applied to cached sessions:
+// the session cache is created with the sessionsTTL expiration and the
+// sessionsMaxSize limit.
+func NewResponderSessions(
+	responderPrivateKey SessionPrivateKey,
+	responderRootObfuscationSecret ObfuscationSecret) (*ResponderSessions, error) {
+
+	replayHistory := newObfuscationReplayHistory()
+
+	sessionCache := lrucache.NewWithLRU(
+		sessionsTTL, 1*time.Minute, sessionsMaxSize)
+
+	responderSessions := &ResponderSessions{
+		privateKey:               responderPrivateKey,
+		rootObfuscationSecret:    responderRootObfuscationSecret,
+		applyTTL:                 true,
+		obfuscationReplayHistory: replayHistory,
+		sessions:                 sessionCache,
+	}
+
+	return responderSessions, nil
+}
+
+// NewResponderSessionsForKnownInitiators creates a new ResponderSessions
+// which allows only allow-listed initiators to establish a session. No TTL
+// is applied to cached sessions.
+//
+// The NewResponderSessionsForKnownInitiators configuration is for Psiphon
+// servers responding to brokers. Only a handful of brokers are expected to
+// be deployed. A relatively small allow list of expected broker public keys
+// is easy to manage, deploy, and update. No TTL is applied to keep the
+// sessions established as much as possible and avoid extra client-relayed
+// round trips for BrokerServerRequests.
+func NewResponderSessionsForKnownInitiators(
+	responderPrivateKey SessionPrivateKey,
+	responderRootObfuscationKey ObfuscationSecret,
+	initiatorPublicKeys []SessionPublicKey) (*ResponderSessions, error) {
+
+	s, err := NewResponderSessions(responderPrivateKey, responderRootObfuscationKey)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// NewResponderSessions enables the TTL. Disable it here so that this
+	// configuration matches its documented contract of applying no TTL to
+	// cached sessions.
+	s.applyTTL = false
+
+	// The allow list is always a non-nil map, even when initiatorPublicKeys
+	// is empty.
+	expectedPublicKeys := make(sessionPublicKeyLookup, len(initiatorPublicKeys))
+	for _, publicKey := range initiatorPublicKeys {
+		expectedPublicKeys[publicKey] = struct{}{}
+	}
+
+	s.expectedInitiatorPublicKeys = expectedPublicKeys
+
+	return s, nil
+}
+
+// RequestHandler is an application-level handler that receives the decrypted
+// request payload and returns a response payload to be encrypted and sent to
+// the initiator. The initiatorID is the authenticated identifier of the
+// initiator: client, proxy, or broker.
+//
+// A RequestHandler is supplied to ResponderSessions.HandlePacket, which
+// passes it the inner request payload once a session is fully established
+// and a request has been decrypted.
+type RequestHandler func(initiatorID ID, request []byte) ([]byte, error)
+
+// HandlePacket takes a session packet, as received at the transport level,
+// and handles session handshake and request decryption. While a session
+// handshakes, HandlePacket returns the next handshake message to be relayed
+// back to the initiator over the transport.
+//
+// Once a session is fully established and a request is decrypted, the inner
+// request payload is passed to the RequestHandler for application-level
+// processing. The response received from the RequestHandler will be
+// encrypted with the session and returned from HandlePacket as the next
+// packet to send back over the transport.
+//
+// The session packet contains a session ID that is used to route packets from
+// many initiators to the correct session state.
+//
+// Above the Noise protocol security layer, session packets have an
+// obfuscation layer. If a packet doesn't authenticate with the expected
+// obfuscation secret, or if a packet is replayed, HandlePacket returns an
+// error. The obfuscation anti-replay layer covers replays of Noise handshake
+// messages which aren't covered by the Noise nonce anti-replay. When
+// HandlePacket returns an error, the caller should invoke anti-probing
+// behavior, such as returning a generic 404 error from an HTTP server for
+// HTTPS transports.
+//
+// There is one expected error case with legitimate initiators: when an
+// initiator reuses a session that is expired or no longer in the responder
+// cache. In this case the error response should be the same; the initiator
+// knows to attempt one session re-establishment in this case.
+//
+// The HandlePacket caller should implement initiator rate limiting in its
+// transport level.
+func (s *ResponderSessions) HandlePacket(
+	inPacket []byte,
+	requestHandler RequestHandler) (retOutPacket []byte, retErr error) {
+
+	// Concurrency: no locks are held for this function, only in specific
+	// helper functions.
+
+	// unwrapSessionPacket deobfuscates the session packet, and unmarshals a
+	// SessionPacket. The SessionPacket.SessionID is used to route the
+	// session packet to an existing session or to create a new one. The
+	// SessionPacket.Payload is either a Noise handshake message or an
+	// encrypted request; both cases are handled below.
+
+	sessionPacket, err := unwrapSessionPacket(
+		s.rootObfuscationSecret, false, s.obfuscationReplayHistory, inPacket)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	sessionID := sessionPacket.SessionID
+
+	// Check for an existing session with this session ID, or create a new one
+	// if not found. If the session _was_ in the cache but is now expired, a
+	// new session is created, but subsequent Noise operations will fail.
+
+	session, err := s.getSession(sessionID)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The named result retErr is inspected here so that every error return
+	// below, including a RequestHandler error, evicts the session.
+	defer func() {
+		if retErr != nil {
+
+			// If an error is returned, the session has failed, so don't
+			// retain it in the cache as it could be more recently used than
+			// an older but still valid session.
+			//
+			// TODO: should we retain the session if it has completed the
+			// handshake? As with initiator error signals, and depending on
+			// the transport security level, a SessionPacket with a
+			// legitimate session ID but corrupt Noise payload could be
+			// forged, terminating a legitimate session.
+
+			s.removeSession(sessionID)
+		}
+	}()
+
+	var requestPayload []byte
+
+	if session.isEstablished() {
+
+		// When the session is already established, decrypt the packet to get
+		// the request.
+
+		payload, err := session.receiveUnmarshaledPacket(sessionPacket)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		requestPayload = payload
+
+	} else {
+
+		// When the session is not established, the packet is the next
+		// handshake message. The initiator appends the request payload to
+		// the end of its last XK handshake message, and in that case payload
+		// will contain the request.
+
+		isEstablished, outPacket, payload, err :=
+			session.nextUnmarshaledHandshakePacket(sessionPacket)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		if outPacket != nil {
+
+			// The handshake is not complete until outPacket is nil; send the
+			// next handshake packet.
+
+			if payload != nil {
+
+				// A payload is not expected unless the handshake is complete.
+				return nil, errors.TraceNew("unexpected handshake payload")
+			}
+
+			// The session TTL is not extended here. Initiators, including
+			// clients and proxies, are given sessionsTTL to complete the
+			// entire handshake.
+
+			return outPacket, nil
+		}
+
+		if !isEstablished || payload == nil {
+
+			// When outPacket is nil, the handshake should be complete --
+			// isEstablished -- and, by convention, the first request payload
+			// should be available.
+
+			return nil, errors.TraceNew("unexpected established state")
+		}
+
+		requestPayload = payload
+	}
+
+	// Extend the session TTL.
+	s.touchSession(sessionID, session)
+
+	// The initiator ID is the peer's static public key, which is only
+	// available once the session is established.
+	initiatorID, err := session.getPeerID()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The decrypted payload is a SessionRoundTrip record wrapping the
+	// application-level request.
+	var sessionRoundTrip SessionRoundTrip
+	err = unmarshalRecord(recordTypeSessionRoundTrip, requestPayload, &sessionRoundTrip)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	request := sessionRoundTrip.Payload
+
+	response, err := requestHandler(initiatorID, request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The response is assigned the same RoundTripID as the request.
+	sessionRoundTrip = SessionRoundTrip{
+		RoundTripID: sessionRoundTrip.RoundTripID,
+		Payload:     response,
+	}
+
+	responsePayload, err := marshalRecord(
+		sessionRoundTrip, recordTypeSessionRoundTrip)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Encrypt the response with the established session.
+	responsePacket, err := session.sendPacket(responsePayload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return responsePacket, nil
+}
+
+// touchSession stores the session in the cache under the given session ID,
+// replacing any existing entry and thereby refreshing its TTL. When applyTTL
+// is not set, the entry is stored with no expiration. Adding an entry to a
+// full cache may cause the LRU entry to be discarded.
+func (s *ResponderSessions) touchSession(sessionID ID, session *session) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	expiration := lrucache.NoExpiration
+	if s.applyTTL {
+		expiration = lrucache.DefaultExpiration
+	}
+
+	cacheKey := string(sessionID[:])
+	s.sessions.Set(cacheKey, session, expiration)
+}
+
+// getSession looks up the cached session for the given session ID. When no
+// session is cached, a new responder session bound to that session ID is
+// created, added to the cache with the default expiration, and returned.
+func (s *ResponderSessions) getSession(sessionID ID) (*session, error) {
+
+	cacheKey := string(sessionID[:])
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	if cached, found := s.sessions.Get(cacheKey); found {
+		return cached.(*session), nil
+	}
+
+	// Not cached: start a fresh responder session. The initiator-only
+	// parameters (expected responder public key, first payload) are nil.
+
+	responderSession, err := newSession(
+		false, // !isInitiator
+		s.privateKey,
+		s.rootObfuscationSecret,
+		s.obfuscationReplayHistory,
+		nil,
+		nil,
+		&sessionID,
+		s.expectedInitiatorPublicKeys)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	s.sessions.Set(cacheKey, responderSession, lrucache.DefaultExpiration)
+
+	return responderSession, nil
+}
+
+// removeSession deletes the cache entry, if any, for the given session ID.
+func (s *ResponderSessions) removeSession(sessionID ID) {
+
+	cacheKey := string(sessionID[:])
+
+	s.mutex.Lock()
+	s.sessions.Delete(cacheKey)
+	s.mutex.Unlock()
+}
+
+// sessionState identifies the next step a session will perform in the Noise
+// XK handshake, or that the handshake is complete.
+type sessionState int
+
+// The constants are explicitly typed as sessionState so that the compiler
+// rejects accidental mixing with unrelated integer values.
+const (
+
+	/*
+
+	   XK:
+	     <- s
+	     ...
+	     -> e, es
+	     <- e, ee
+	     -> s, se [+ first payload]
+
+	*/
+
+	// Initiator states, in handshake order.
+	sessionStateInitiator_XK_send_e_es sessionState = iota
+	sessionStateInitiator_XK_recv_e_ee_send_s_se_payload
+	sessionStateInitiator_XK_established
+
+	// Responder states, in handshake order.
+	sessionStateResponder_XK_recv_e_es_send_e_ee
+	sessionStateResponder_XK_recv_s_se_payload
+	sessionStateResponder_XK_established
+)
+
+// sessionPublicKeyLookup is a set of session public keys; it is the type of
+// the session's expectedInitiatorPublicKeys allow list.
+type sessionPublicKeyLookup map[SessionPublicKey]struct{}
+
+// sessionLookup maps a session public key to a session.
+type sessionLookup map[SessionPublicKey]*session
+
+// session represents a Noise protocol session, including its initial
+// handshake state.
+//
+// The XK pattern is used:
+//   - Initiators may have short-lived static keys (clients), or long-lived
+//     static keys (proxies and brokers). The initiator key is securely
+//     transmitted to the responder while hiding its value.
+//   - The responder static key is always known (K) and exchanged out of
+//     band.
+//   - Provides forward secrecy.
+//   - The round trip request can be appended to the initiator's final
+//     handshake message, eliminating an extra round trip.
+//
+// For in-proxy, any client or proxy can connect to a broker. Only allowed
+// brokers can connect to a server.
+//
+// To limit access to allowed brokers, expectedInitiatorPublicKeys is an allow
+// list of broker public keys. XK is still used for this case, instead of
+// KK:
+//   - With KK, the broker identity would have to be known before the Noise
+//     handshake begins
+//   - With XK, the broker proves possession of a private key corresponding to
+//     a broker public key on the allow list.
+//   - While KK will abort sooner than XK when an invalid broker key is used,
+//     completing the handshake and decrypting the first payload does not
+//     leak any information.
+//
+// There is no "close" operation for sessions. Responders will maintain a cache
+// of established sessions and discard the state for expired sessions or in
+// an LRU fashion. Initiators will reuse sessions until they are rejected by
+// a responder.
+//
+// There is no state for the obfuscation layer; each packet is obfuscated
+// independently since session packets may arrive at a peer out-of-order.
+type session struct {
+	// Fields set once in newSession. wrapPacket and unwrapPacket read these
+	// without taking the mutex.
+	isInitiator                 bool
+	sessionID                   ID
+	rootObfuscationSecret       ObfuscationSecret
+	replayHistory               *obfuscationReplayHistory
+	expectedInitiatorPublicKeys sessionPublicKeyLookup
+
+	// mutex is held by the session methods while they access the fields
+	// below.
+	mutex sync.Mutex
+	// state is the current handshake step.
+	state sessionState
+	// signalOnEstablished holds channels to close when the session becomes
+	// ready to share; it is set to nil at that point.
+	signalOnEstablished []chan struct{}
+	// handshake is the Noise handshake state; it is set to nil once the
+	// session is established.
+	handshake *noise.HandshakeState
+	// firstPayload is the initiator's first request payload, sent with its
+	// last handshake message; it is cleared once the session is established.
+	firstPayload []byte
+	// peerPublicKey is the peer's static public key, recorded when the
+	// session is established.
+	peerPublicKey []byte
+	// send and receive are the cipher states produced by the completed
+	// handshake.
+	send    *noise.CipherState
+	receive *noise.CipherState
+	// nonceReplay is used to detect replayed nonces on received packets.
+	nonceReplay replay.Filter
+}
+
+// newSession creates a session in its initial handshake state, for either
+// the initiator or the responder role.
+//
+// Initiators must supply expectedResponderPublicKey and firstPayload, and
+// must not supply peerSessionID or expectedInitiatorPublicKeys; a new session
+// ID is generated. Responders must supply peerSessionID, the session ID taken
+// from the initiator's packet, and must not supply expectedResponderPublicKey
+// or firstPayload; expectedInitiatorPublicKeys is optional.
+func newSession(
+	isInitiator bool,
+	privateKey SessionPrivateKey,
+	rootObfuscationSecret ObfuscationSecret,
+	replayHistory *obfuscationReplayHistory,
+
+	// Initiator
+	expectedResponderPublicKey *SessionPublicKey,
+	firstPayload []byte,
+
+	// Responder
+	peerSessionID *ID,
+	expectedInitiatorPublicKeys sessionPublicKeyLookup) (*session, error) {
+
+	// Check that exactly the parameters for the selected role are present.
+
+	if isInitiator {
+		validInitiator := peerSessionID == nil &&
+			expectedResponderPublicKey != nil &&
+			expectedInitiatorPublicKeys == nil &&
+			firstPayload != nil
+		if !validInitiator {
+			return nil, errors.TraceNew("unexpected initiator parameters")
+		}
+	} else {
+		validResponder := peerSessionID != nil &&
+			expectedResponderPublicKey == nil &&
+			firstPayload == nil
+		if !validResponder {
+			return nil, errors.TraceNew("unexpected responder parameters")
+		}
+	}
+
+	// Responders adopt the peer's session ID; initiators generate one.
+
+	var sessionID ID
+	if peerSessionID != nil {
+		sessionID = *peerSessionID
+	} else {
+		generatedID, err := MakeID()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		sessionID = generatedID
+	}
+
+	// The prologue binds the session ID and other meta data to the session.
+
+	sessionPrologue := SessionPrologue{
+		SessionProtocolName:    SessionProtocolName,
+		SessionProtocolVersion: SessionProtocolVersion1,
+		SessionID:              sessionID,
+	}
+	prologue, err := cborEncoding.Marshal(sessionPrologue)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	publicKey, err := GetSessionPublicKey(privateKey)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// SessionProtocolVersion1 implies this ciphersuite
+
+	cipherSuite := noise.NewCipherSuite(
+		noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b)
+
+	config := noise.Config{
+		CipherSuite: cipherSuite,
+		Pattern:     noise.HandshakeXK,
+		Initiator:   isInitiator,
+		Prologue:    prologue,
+		StaticKeypair: noise.DHKey{
+			Public:  publicKey[:],
+			Private: privateKey[:],
+		},
+	}
+
+	// Only initiators know the peer's static key in advance.
+	if expectedResponderPublicKey != nil {
+		config.PeerStatic = (*expectedResponderPublicKey)[:]
+	}
+
+	handshake, err := noise.NewHandshakeState(config)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	initialState := sessionState(sessionStateResponder_XK_recv_e_es_send_e_ee)
+	if isInitiator {
+		initialState = sessionState(sessionStateInitiator_XK_send_e_es)
+	}
+
+	newSession := &session{
+		isInitiator:                 isInitiator,
+		sessionID:                   sessionID,
+		rootObfuscationSecret:       rootObfuscationSecret,
+		replayHistory:               replayHistory,
+		expectedInitiatorPublicKeys: expectedInitiatorPublicKeys,
+		state:                       initialState,
+		signalOnEstablished:         make([]chan struct{}, 0), // must be non-nil
+		handshake:                   handshake,
+		firstPayload:                firstPayload,
+	}
+
+	return newSession, nil
+}
+
+// isEstablished reports whether the session handshake has completed.
+//
+// Note that an established session is not necessarily ready to share; see
+// isReadyToShare.
+func (s *session) isEstablished() bool {
+
+	s.mutex.Lock()
+	handshakeComplete := s.handshake == nil
+	s.mutex.Unlock()
+
+	return handshakeComplete
+}
+
+// isReadyToShare reports whether the session handshake is complete _and_ the
+// peer is known to have received and processed the final handshake message.
+// It always returns false for responder sessions.
+//
+// Once isReadyToShare is true, multiple round trips may use the session
+// concurrently, and requests from different round trips may arrive at the
+// peer out-of-order.
+//
+// Sessions are shared by initiators. In the XK handshake, the final step is
+// the initiator sending its last message to the responder. Although the
+// initiator session is "established" as soon as that message is output,
+// other round trips must not send session-encrypted packets until the
+// responder has actually received that final handshake message.
+//
+// The session becomes ready to share when the round trip performing the
+// handshake receives its response, which demonstrates that the responder
+// received the final message.
+//
+// When the session is not yet ready and signal is not nil, signal is
+// registered to be closed when the session becomes ready to share.
+func (s *session) isReadyToShare(signal chan struct{}) bool {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	switch {
+
+	case !s.isInitiator:
+		return false
+
+	case s.handshake == nil && s.signalOnEstablished == nil:
+		return true
+
+	case signal != nil:
+		s.signalOnEstablished = append(s.signalOnEstablished, signal)
+	}
+
+	return false
+}
+
+// getPeerID returns the peer's static public key as an ID. Only the peer
+// holding the corresponding private key can present a given peer ID. An
+// error is returned when the session is not yet established or the recorded
+// public key does not have the length of an ID.
+func (s *session) getPeerID() (ID, error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	var peerID ID
+
+	switch {
+	case s.handshake != nil:
+		return peerID, errors.TraceNew("not established")
+	case len(s.peerPublicKey) != len(peerID):
+		return peerID, errors.TraceNew("invalid peer public key")
+	}
+
+	copy(peerID[:], s.peerPublicKey)
+
+	return peerID, nil
+}
+
+// sendPacket encrypts the given round trip payload with the established
+// session and returns the resulting marshaled and obfuscated session packet,
+// ready to be sent to the peer. An error is returned when the handshake is
+// not complete.
+func (s *session) sendPacket(payload []byte) ([]byte, error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	switch {
+	case s.handshake != nil:
+		return nil, errors.TraceNew("not established")
+	case s.send == nil:
+		return nil, errors.Trace(s.unexpectedStateError())
+	}
+
+	// The nonce is read before encrypting and is carried in the packet so
+	// that the receiver can set it before decrypting.
+	packetNonce := s.send.Nonce()
+
+	// Unlike tunnels, for example, sessions are not for bulk data transfer
+	// and we don't aim for zero allocation or extensive buffer reuse.
+
+	ciphertext, err := s.send.Encrypt(nil, nil, payload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	packet := &SessionPacket{
+		SessionID: s.sessionID,
+		Nonce:     packetNonce,
+		Payload:   ciphertext,
+	}
+
+	wrappedPacket, err := s.wrapPacket(packet)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return wrappedPacket, nil
+}
+
+// receivePacket deobfuscates and decrypts a session packet received from the
+// peer, using the established session, and returns the round trip payload.
+//
+// Responders must inspect the packet and use its session ID to route it to
+// the correct session before decrypting, so responders call
+// receiveUnmarshaledPacket instead.
+func (s *session) receivePacket(packet []byte) ([]byte, error) {
+
+	unwrapped, err := s.unwrapPacket(packet)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	roundTripPayload, err := s.receiveUnmarshaledPacket(unwrapped)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return roundTripPayload, nil
+}
+
+// receiveUnmarshaledPacket decrypts the payload of an already deobfuscated
+// and unmarshaled session packet, using the established session, and returns
+// the round trip payload. The packet is rejected when its session ID does not
+// match this session, when decryption fails, or when its nonce is a replay.
+func (s *session) receiveUnmarshaledPacket(
+	sessionPacket *SessionPacket) ([]byte, error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	switch {
+	case s.receive == nil:
+		return nil, errors.Trace(s.unexpectedStateError())
+	case sessionPacket.SessionID != s.sessionID:
+		return nil, errors.Tracef("unexpected sessionID")
+	}
+
+	// The nonce is carried in the packet and set explicitly before
+	// decrypting.
+	packetNonce := sessionPacket.Nonce
+	s.receive.SetNonce(packetNonce)
+
+	plaintext, err := s.receive.Decrypt(nil, nil, sessionPacket.Payload)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// The replay filter is checked only after the packet has been
+	// successfully decrypted.
+	if accepted := s.nonceReplay.ValidateCounter(packetNonce, math.MaxUint64); !accepted {
+		return nil, errors.TraceNew("replay detected")
+	}
+
+	// The session is ready to share once it's received a post-handshake
+	// response from the peer.
+
+	s.readyToShare()
+
+	return plaintext, nil
+}
+
+// nextHandshakePacket advances the session handshake by one step. It takes
+// the handshake packet most recently received from the peer and returns the
+// next handshake packet to send to the peer. Pass nil for inPacket to start
+// the handshake. The handshake is complete when outPacket is nil.
+//
+// XK bundles the first initiator request payload with a handshake message,
+// and nextHandshakePacket outputs that payload to the responder when the
+// handshake is complete.
+//
+// After the handshake completes, further round trips are exchanged using
+// sendPacket and receivePacket.
+//
+// Responders must inspect the packet and use its session ID to route it to
+// the correct session, so responders call nextUnmarshaledHandshakePacket
+// instead.
+func (s *session) nextHandshakePacket(inPacket []byte) (
+	isEstablished bool, outPacket []byte, payload []byte, err error) {
+
+	// A nil inPacket is passed through as a nil SessionPacket.
+	var unwrapped *SessionPacket
+
+	if inPacket != nil {
+		unwrapped, err = s.unwrapPacket(inPacket)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+	}
+
+	isEstablished, outPacket, payload, err = s.nextUnmarshaledHandshakePacket(unwrapped)
+	if err != nil {
+		return false, nil, nil, errors.Trace(err)
+	}
+
+	return isEstablished, outPacket, payload, nil
+}
+
+// nextUnmarshaledHandshakePacket advances the handshake state machine by one
+// step, given the already deobfuscated and unmarshaled packet received from
+// the peer, or nil when there is no input packet. It returns whether the
+// session is now established, the next packet to send to the peer, if any,
+// and, for a responder completing the handshake, the initiator's first
+// payload.
+func (s *session) nextUnmarshaledHandshakePacket(sessionPacket *SessionPacket) (
+	isEstablished bool, outPacket []byte, payload []byte, err error) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	// Handshake packets must carry this session's ID and a zero nonce.
+	var in []byte
+	if sessionPacket != nil {
+		if sessionPacket.SessionID != s.sessionID {
+			return false, nil, nil, errors.Tracef("unexpected sessionID")
+		}
+		if sessionPacket.Nonce != 0 {
+			return false, nil, nil, errors.TraceNew("unexpected nonce")
+		}
+		in = sessionPacket.Payload
+	}
+
+	// Handle handshake state transitions.
+
+	switch s.state {
+
+	// Initiator
+
+	case sessionStateInitiator_XK_send_e_es:
+		out, _, _, err := s.handshake.WriteMessage(nil, nil)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		outPacket, err := s.wrapPacket(
+			&SessionPacket{SessionID: s.sessionID, Payload: out})
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		s.state = sessionStateInitiator_XK_recv_e_ee_send_s_se_payload
+		return false, outPacket, nil, nil
+
+	case sessionStateInitiator_XK_recv_e_ee_send_s_se_payload:
+		_, _, _, err := s.handshake.ReadMessage(nil, in)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		// The first request payload is sent with the final handshake message.
+		out, send, receive, err := s.handshake.WriteMessage(nil, s.firstPayload)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		outPacket, err := s.wrapPacket(
+			&SessionPacket{SessionID: s.sessionID, Payload: out})
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		s.state = sessionStateInitiator_XK_established
+		s.established(send, receive)
+		return true, outPacket, nil, nil
+
+	// Responder
+
+	case sessionStateResponder_XK_recv_e_es_send_e_ee:
+		_, _, _, err := s.handshake.ReadMessage(nil, in)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		out, _, _, err := s.handshake.WriteMessage(nil, nil)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		outPacket, err := s.wrapPacket(
+			&SessionPacket{SessionID: s.sessionID, Payload: out})
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		s.state = sessionStateResponder_XK_recv_s_se_payload
+		return false, outPacket, nil, nil
+
+	case sessionStateResponder_XK_recv_s_se_payload:
+		// Note the cipher state order: the responder takes the two states
+		// returned by ReadMessage as (receive, send), the reverse of the
+		// (send, receive) order the initiator takes from WriteMessage.
+		firstPayload, receive, send, err := s.handshake.ReadMessage(nil, in)
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+
+		// Check if the initiator's public key is on the allow list.
+		//
+		// Limitation: unlike with the KK pattern, the handshake completes and
+		// the initial payload is decrypted even when the initiator public
+		// key is not on the allow list.
+
+		err = s.checkExpectedInitiatorPublicKeys(s.handshake.PeerStatic())
+		if err != nil {
+			return false, nil, nil, errors.Trace(err)
+		}
+		s.state = sessionStateResponder_XK_established
+		s.established(send, receive)
+		return true, nil, firstPayload, nil
+	}
+
+	// Any other state, including the established states, is unexpected here.
+	return false, nil, nil, errors.Trace(s.unexpectedStateError())
+}
+
+// checkExpectedInitiatorPublicKeys checks the initiator's static public key
+// against the expectedInitiatorPublicKeys allow list. When no allow list is
+// configured, any initiator public key is accepted. An error is returned when
+// the key has an unexpected length or is not on the allow list.
+func (s *session) checkExpectedInitiatorPublicKeys(peerPublicKey []byte) error {
+
+	if s.expectedInitiatorPublicKeys == nil {
+		return nil
+	}
+
+	var publicKey SessionPublicKey
+
+	// Reject keys of the wrong length rather than letting copy silently
+	// zero-pad or truncate the value before the allow list lookup. This is
+	// the same length check that getPeerID applies to the peer public key.
+	if len(peerPublicKey) != len(publicKey) {
+		return errors.TraceNew("invalid initiator public key")
+	}
+
+	copy(publicKey[:], peerPublicKey)
+
+	if _, ok := s.expectedInitiatorPublicKeys[publicKey]; !ok {
+		return errors.TraceNew("unexpected initiator public key")
+	}
+
+	return nil
+}
+
+// established marks the session as established: the peer's static public
+// key is recorded, the given cipher states are installed, and the handshake
+// state and first payload are released.
+func (s *session) established(
+	send *noise.CipherState,
+	receive *noise.CipherState) {
+
+	// Assumes s.mutex lock is held.
+
+	// Read the peer's static key before discarding the handshake state.
+	peerStatic := s.handshake.PeerStatic()
+
+	s.send, s.receive = send, receive
+	s.peerPublicKey = peerStatic
+	s.firstPayload = nil
+	s.handshake = nil
+}
+
+// readyToShare marks the session as ready to share, closing every registered
+// signal channel. Calls after the first are no-ops.
+func (s *session) readyToShare() {
+
+	// Assumes s.mutex lock is held.
+
+	pendingSignals := s.signalOnEstablished
+	if pendingSignals == nil {
+		return
+	}
+
+	// A nil signalOnEstablished is what marks the session as ready to share.
+	s.signalOnEstablished = nil
+
+	for i := range pendingSignals {
+		close(pendingSignals[i])
+	}
+}
+
+// wrapPacket marshals a SessionPacket and applies the obfuscation layer,
+// returning the bytes to send over the transport.
+func (s *session) wrapPacket(sessionPacket *SessionPacket) ([]byte, error) {
+
+	// No lock. References only static session fields.
+
+	marshaled, err := marshalRecord(sessionPacket, recordTypeSessionPacket)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	obfuscated, err := obfuscateSessionPacket(
+		s.rootObfuscationSecret,
+		s.isInitiator,
+		marshaled,
+		sessionObfuscationPaddingMinSize,
+		sessionObfuscationPaddingMaxSize)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return obfuscated, nil
+}
+
+// unwrapPacket removes the obfuscation layer from a received packet and
+// unmarshals the SessionPacket, using this session's obfuscation secret,
+// role, and replay history.
+func (s *session) unwrapPacket(obfuscatedPacket []byte) (*SessionPacket, error) {
+
+	// No lock. References only static session fields.
+
+	unwrapped, err := unwrapSessionPacket(
+		s.rootObfuscationSecret, s.isInitiator, s.replayHistory, obfuscatedPacket)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return unwrapped, nil
+}
+
+// unwrapSessionPacket deobfuscates and unmarshals a SessionPacket. It is
+// called directly by responders, which must peek at the SessionPacket and
+// get the session ID in order to route packets to the correct session.
+func unwrapSessionPacket(
+	rootObfuscationSecret ObfuscationSecret,
+	isInitiator bool,
+	replayHistory *obfuscationReplayHistory,
+	obfuscatedPacket []byte) (*SessionPacket, error) {
+
+	deobfuscated, err := deobfuscateSessionPacket(
+		rootObfuscationSecret, isInitiator, replayHistory, obfuscatedPacket)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	var unmarshaled *SessionPacket
+
+	err = unmarshalRecord(recordTypeSessionPacket, deobfuscated, &unmarshaled)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return unmarshaled, nil
+}
+
+// unexpectedStateError creates an error that includes the current handshake
+// state.
+//
+// The caller must hold s.mutex.
+func (s *session) unexpectedStateError() error {
+
+	// Assumes s.mutex lock is held.
+	//
+	// s.mutex must not be locked here: sync.Mutex is not reentrant, and the
+	// callers (sendPacket, receiveUnmarshaledPacket, and
+	// nextUnmarshaledHandshakePacket) invoke this helper while already
+	// holding the lock, so locking again would deadlock on the error path.
+
+	return errors.Tracef("unexpected state: %v", s.state)
+}

+ 518 - 0
psiphon/common/inproxy/session_test.go

@@ -0,0 +1,518 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"fmt"
+	"math"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/flynn/noise"
+	"golang.zx2c4.com/wireguard/replay"
+)
+
+// TestSessions runs the session round trip scenarios in runTestSessions and
+// reports any failure.
+func TestSessions(t *testing.T) {
+	err := runTestSessions()
+	if err != nil {
+		// Use Error, not Errorf: the message is not a format string and may
+		// contain '%' characters.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+// runTestSessions exercises initiator/responder session round trips: a basic
+// round trip, re-establishment after the responder's session cache is
+// flushed, responders with an allow list of known initiators (both an
+// allowed and an unknown initiator), and many concurrent clients. It returns
+// the first error encountered, or nil when all scenarios pass.
+func runTestSessions() error {
+
+	// Test: basic round trip succeeds
+
+	responderPrivateKey, err := GenerateSessionPrivateKey()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	responderPublicKey, err := GetSessionPublicKey(responderPrivateKey)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	responderRootObfuscationSecret, err := GenerateRootObfuscationSecret()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	responderSessions, err := NewResponderSessions(
+		responderPrivateKey, responderRootObfuscationSecret)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	initiatorPrivateKey, err := GenerateSessionPrivateKey()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	initiatorPublicKey, err := GetSessionPublicKey(initiatorPrivateKey)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
+
+	roundTripper := newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
+
+	waitToShareSession := true
+
+	request := roundTripper.MakeRequest()
+
+	response, err := initiatorSessions.RoundTrip(
+		context.Background(),
+		roundTripper,
+		responderPublicKey,
+		responderRootObfuscationSecret,
+		waitToShareSession,
+		request)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
+		return errors.TraceNew("unexpected response")
+	}
+
+	// Test: session expires; new one negotiated
+
+	// Flushing the responder's cache simulates expiry of the session that
+	// the initiator still holds.
+	responderSessions.sessions.Flush()
+
+	request = roundTripper.MakeRequest()
+
+	response, err = initiatorSessions.RoundTrip(
+		context.Background(),
+		roundTripper,
+		responderPublicKey,
+		responderRootObfuscationSecret,
+		waitToShareSession,
+		request)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
+		return errors.TraceNew("unexpected response")
+	}
+
+	// Test: expected known initiator public key
+
+	initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
+
+	responderSessions, err = NewResponderSessionsForKnownInitiators(
+		responderPrivateKey,
+		responderRootObfuscationSecret,
+		[]SessionPublicKey{initiatorPublicKey})
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
+
+	request = roundTripper.MakeRequest()
+
+	response, err = initiatorSessions.RoundTrip(
+		context.Background(),
+		roundTripper,
+		responderPublicKey,
+		responderRootObfuscationSecret,
+		waitToShareSession,
+		request)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
+		return errors.TraceNew("unexpected response")
+	}
+
+	// Test: wrong known initiator public key
+
+	unknownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	unknownInitiatorSessions := NewInitiatorSessions(unknownInitiatorPrivateKey)
+
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancelFunc()
+
+	request = roundTripper.MakeRequest()
+
+	// The responder from the previous test only allows initiatorPublicKey,
+	// so this round trip is expected to fail.
+	response, err = unknownInitiatorSessions.RoundTrip(
+		ctx,
+		roundTripper,
+		responderPublicKey,
+		responderRootObfuscationSecret,
+		waitToShareSession,
+		request)
+	if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
+		return errors.Tracef("unexpected result: %v", err)
+	}
+
+	// Test: many concurrent sessions
+
+	responderSessions, err = NewResponderSessions(
+		responderPrivateKey, responderRootObfuscationSecret)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Each client below has its own key, so no specific initiator public key
+	// is expected by the round tripper.
+	roundTripper = newTestSessionRoundTripper(responderSessions, nil)
+
+	clientCount := 10000
+	requestCount := 100
+	concurrentRequestCount := 5
+
+	// Buffered so that every client goroutine can report its single result
+	// without blocking, even after an early return below.
+	resultChan := make(chan error, clientCount)
+
+	for i := 0; i < clientCount; i++ {
+
+		// Run clients concurrently
+
+		go func() {
+
+			initiatorPrivateKey, err := GenerateSessionPrivateKey()
+			if err != nil {
+				resultChan <- errors.Trace(err)
+				return
+			}
+
+			initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
+
+			// Requests are issued in batches of concurrentRequestCount.
+			for i := 0; i < requestCount; i += concurrentRequestCount {
+
+				requestResultChan := make(chan error, concurrentRequestCount)
+
+				for j := 0; j < concurrentRequestCount; j++ {
+
+					// Run some of each client's requests concurrently, to
+					// exercise waitToShareSession
+
+					go func(waitToShareSession bool) {
+
+						request := roundTripper.MakeRequest()
+
+						response, err := initiatorSessions.RoundTrip(
+							context.Background(),
+							roundTripper,
+							responderPublicKey,
+							responderRootObfuscationSecret,
+							waitToShareSession,
+							request)
+						if err != nil {
+							requestResultChan <- errors.Trace(err)
+							return
+						}
+
+						if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
+							requestResultChan <- errors.TraceNew("unexpected response")
+							return
+						}
+
+						requestResultChan <- nil
+						// i is the batch loop index here, so waitToShareSession
+						// is the same within a batch and alternates between
+						// batches.
+					}(i%2 == 0)
+				}
+
+				for i := 0; i < concurrentRequestCount; i++ {
+					err = <-requestResultChan
+					if err != nil {
+						resultChan <- errors.Trace(err)
+						return
+					}
+				}
+			}
+
+			resultChan <- nil
+		}()
+	}
+
+	for i := 0; i < clientCount; i++ {
+		err = <-resultChan
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	return nil
+}
+
+// testSessionRoundTripper is a test round tripper that delivers initiator
+// packets directly to a ResponderSessions instance, with no network
+// transport in between.
+type testSessionRoundTripper struct {
+	// sessions is the responder that handles each packet.
+	sessions *ResponderSessions
+	// expectedPeerPublicKey, when not nil, is the initiator ID that the
+	// request handler requires.
+	expectedPeerPublicKey *SessionPublicKey
+}
+
+// newTestSessionRoundTripper creates a testSessionRoundTripper that relays
+// packets to the given responder sessions. When expectedPeerPublicKey is not
+// nil, the request handler verifies that the initiator ID matches it.
+func newTestSessionRoundTripper(
+	sessions *ResponderSessions,
+	expectedPeerPublicKey *SessionPublicKey) *testSessionRoundTripper {
+
+	roundTripper := new(testSessionRoundTripper)
+	roundTripper.sessions = sessions
+	roundTripper.expectedPeerPublicKey = expectedPeerPublicKey
+
+	return roundTripper
+}
+
+// MakeRequest returns a request payload of prng.Bytes output, with a length
+// chosen by prng.Range(100, 1000).
+func (t *testSessionRoundTripper) MakeRequest() []byte {
+	requestSize := prng.Range(100, 1000)
+	return prng.Bytes(requestSize)
+}
+
+// ExpectedResponse returns the response expected for the given request
+// payload, which is the request bytes in reverse order.
+func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte {
+	last := len(requestPayload) - 1
+	reversed := make([]byte, len(requestPayload))
+	for i := range requestPayload {
+		reversed[last-i] = requestPayload[i]
+	}
+	return reversed
+}
+
+// RoundTrip passes the request packet directly to the responder's
+// HandlePacket and returns the resulting response packet. The request
+// handler checks the initiator ID, when one is expected, and responds with
+// ExpectedResponse. RoundTrip fails immediately when ctx is already done.
+func (t *testSessionRoundTripper) RoundTrip(ctx context.Context, requestPayload []byte) ([]byte, error) {
+
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		return nil, errors.Trace(ctxErr)
+	}
+
+	handler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {
+
+		expectedKey := t.expectedPeerPublicKey
+		if expectedKey != nil && !bytes.Equal(initiatorID[:], (*expectedKey)[:]) {
+			return nil, errors.TraceNew("unexpected initiator ID")
+		}
+
+		return t.ExpectedResponse(unwrappedRequest), nil
+	}
+
+	responsePacket, handleErr := t.sessions.HandlePacket(requestPayload, handler)
+	if handleErr != nil {
+		// Errors here are expected; e.g., in the session expired case.
+		fmt.Printf("HandlePacket failed: %v\n", handleErr)
+		return nil, errors.Trace(handleErr)
+	}
+
+	return responsePacket, nil
+}
+
+// Close drops the reference to the responder sessions. It always returns
+// nil.
+func (t *testSessionRoundTripper) Close() error {
+	t.sessions = nil
+	return nil
+}
+
+// TestNoise runs the Noise XK handshake checks in runTestNoise and reports
+// any failure.
+func TestNoise(t *testing.T) {
+	err := runTestNoise()
+	if err != nil {
+		// Use Error, not Errorf: the message is not a format string and may
+		// contain '%' characters.
+		t.Error(errors.Trace(err).Error())
+	}
+}
+
+// runTestNoise exercises the Noise XK handshake and post-handshake transport
+// directly against the noise package: it performs the three handshake
+// messages between an initiator and a responder (with a payload attached to
+// the final message), checks that the responder learns the initiator's
+// static public key, and then exchanges encrypted messages in both
+// directions using explicit nonces and replay filters. It returns nil when
+// every step succeeds and a traced error otherwise.
+func runTestNoise() error {
+
+	prologue := []byte("psiphon-inproxy-session")
+
+	initiatorKeys, err := noise.DH25519.GenerateKeypair(rand.Reader)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	responderKeys, err := noise.DH25519.GenerateKeypair(rand.Reader)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	initiatorHandshake, err := noise.NewHandshakeState(
+		noise.Config{
+			CipherSuite:   noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
+			Pattern:       noise.HandshakeXK,
+			Initiator:     true,
+			Prologue:      prologue,
+			StaticKeypair: initiatorKeys,
+			PeerStatic:    responderKeys.Public,
+		})
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	responderHandshake, err := noise.NewHandshakeState(
+		noise.Config{
+			CipherSuite:   noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
+			Pattern:       noise.HandshakeXK,
+			Initiator:     false,
+			Prologue:      prologue,
+			StaticKeypair: responderKeys,
+		})
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Noise XK: -> e, es
+
+	var initiatorMsg []byte
+	initiatorMsg, _, _, err = initiatorHandshake.WriteMessage(initiatorMsg, nil)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	var receivedPayload []byte
+	receivedPayload, _, _, err = responderHandshake.ReadMessage(nil, initiatorMsg)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if len(receivedPayload) > 0 {
+		return errors.TraceNew("unexpected payload")
+	}
+
+	// Noise XK: <- e, ee
+
+	var responderMsg []byte
+	responderMsg, _, _, err = responderHandshake.WriteMessage(responderMsg, nil)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	receivedPayload = nil
+	receivedPayload, _, _, err = initiatorHandshake.ReadMessage(nil, responderMsg)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if len(receivedPayload) > 0 {
+		return errors.TraceNew("unexpected payload")
+	}
+
+	// Noise XK: -> s, se + payload
+
+	sendPayload := prng.Bytes(1000)
+
+	var initiatorSend, initiatorReceive *noise.CipherState
+	var initiatorReplay replay.Filter
+
+	initiatorMsg = nil
+	initiatorMsg, initiatorSend, initiatorReceive, err = initiatorHandshake.WriteMessage(initiatorMsg, sendPayload)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if initiatorSend == nil || initiatorReceive == nil {
+		return errors.TraceNew("unexpected incomplete handshake")
+	}
+
+	var responderSend, responderReceive *noise.CipherState
+	var responderReplay replay.Filter
+
+	// Note the swapped order of the returned cipher states: the first is the
+	// initiator-to-responder direction, which is the responder's receive side.
+	receivedPayload = nil
+	receivedPayload, responderReceive, responderSend, err = responderHandshake.ReadMessage(receivedPayload, initiatorMsg)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if responderReceive == nil || responderSend == nil {
+		return errors.TraceNew("unexpected incomplete handshake")
+	}
+	if receivedPayload == nil {
+		return errors.TraceNew("missing payload")
+	}
+	if !bytes.Equal(sendPayload, receivedPayload) {
+		return errors.TraceNew("incorrect payload")
+	}
+
+	if !bytes.Equal(responderHandshake.PeerStatic(), initiatorKeys.Public) {
+		return errors.TraceNew("unexpected initiator static public key")
+	}
+
+	// post-handshake initiator <- responder
+
+	nonce := responderSend.Nonce()
+	responderMsg = nil
+	responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	initiatorReceive.SetNonce(nonce)
+	receivedPayload = nil
+	receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
+		return errors.TraceNew("replay detected")
+	}
+	if !bytes.Equal(sendPayload, receivedPayload) {
+		return errors.TraceNew("incorrect payload")
+	}
+
+	for i := 0; i < 100; i++ {
+
+		// post-handshake initiator -> responder
+
+		sendPayload = prng.Bytes(1000)
+
+		nonce = initiatorSend.Nonce()
+		initiatorMsg = nil
+		initiatorMsg, err = initiatorSend.Encrypt(initiatorMsg, nil, sendPayload)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		responderReceive.SetNonce(nonce)
+		receivedPayload = nil
+		receivedPayload, err = responderReceive.Decrypt(receivedPayload, nil, initiatorMsg)
+		if err != nil {
+			return errors.Trace(err)
+		}
+		if !responderReplay.ValidateCounter(nonce, math.MaxUint64) {
+			return errors.TraceNew("replay detected")
+		}
+		if !bytes.Equal(sendPayload, receivedPayload) {
+			return errors.TraceNew("incorrect payload")
+		}
+
+		// post-handshake initiator <- responder
+
+		nonce = responderSend.Nonce()
+		responderMsg = nil
+		responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		// The explicit nonce must be applied to the cipher state that
+		// performs the decryption, which is the initiator's receive side.
+		initiatorReceive.SetNonce(nonce)
+		receivedPayload = nil
+		receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
+		if err != nil {
+			return errors.Trace(err)
+		}
+		if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
+			return errors.TraceNew("replay detected")
+		}
+		// The responder echoed the payload it received, so it must match
+		// what the initiator sent in this iteration.
+		if !bytes.Equal(sendPayload, receivedPayload) {
+			return errors.TraceNew("incorrect payload")
+		}
+	}
+
+	return nil
+}

+ 1081 - 0
psiphon/common/inproxy/webrtc.go

@@ -0,0 +1,1081 @@
+/*
+ * Copyright (c) 2023, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package inproxy
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"net"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/pion/datachannel"
+	"github.com/pion/ice/v2"
+	"github.com/pion/sdp/v3"
+	"github.com/pion/webrtc/v3"
+	"github.com/wader/filtertransport"
+)
+
+const (
+	dataChannelBufferedAmountLowThreshold uint64 = 512 * 1024  // passed to SetBufferedAmountLowThreshold in setDataChannel
+	dataChannelMaxBufferedAmount          uint64 = 1024 * 1024 // Write waits when the buffered amount plus the new write exceeds this
+)
+
+// WebRTCConn is a WebRTC connection between two peers, with a data channel
+// used to relay streams or packets between them. WebRTCConn implements the
+// net.Conn interface.
+type WebRTCConn struct {
+	config                       *WebRTCConfig
+	mutex                        sync.Mutex // held by Close, SetRemoteSDP, and the data channel callbacks; Read and Write hold it only to snapshot fields
+	udpConn                      net.PacketConn
+	portMapper                   *portMapper
+	isClosed                     bool
+	closedSignal                 chan struct{} // closed by Close
+	peerConnection               *webrtc.PeerConnection
+	dataChannel                  *webrtc.DataChannel
+	dataChannelConn              datachannel.ReadWriteCloser // set in onDataChannelOpen when the data channel is detached
+	dataChannelOpenedSignal      chan struct{}               // closed, once, in onDataChannelOpen
+	dataChannelOpenedOnce        sync.Once
+	dataChannelWriteBufferSignal chan struct{} // signaled by the OnBufferedAmountLow callback; awaited in Write
+	messageMutex                 sync.Mutex    // held for the duration of Read; covers the message* fields below
+	messageBuffer                []byte
+	messageOffset                int
+	messageLength                int
+	messageError                 error
+}
+
+// WebRTCConfig specifies the configuration for a WebRTC dial.
+type WebRTCConfig struct {
+
+	// Logger is used to log events.
+	Logger common.Logger
+
+	// DialParameters specifies specific WebRTC dial strategies and
+	// settings; DialParameters also facilitates dial replay by receiving
+	// callbacks when individual dial steps succeed or fail.
+	DialParameters DialParameters
+
+	// ClientRootObfuscationSecret is generated (or replayed) by the client
+	// and sent to the proxy and used to drive obfuscation operations.
+	ClientRootObfuscationSecret ObfuscationSecret
+
+	// DoDTLSRandomization indicates whether to perform DTLS randomization.
+	DoDTLSRandomization bool
+
+	// ReliableTransport indicates whether to configure the WebRTC data
+	// channel to use reliable transport. Set ReliableTransport when proxying
+	// a TCP stream, and unset it when proxying a UDP packet flow with its
+	// own reliability layer, such as QUIC.
+	ReliableTransport bool
+}
+
+// NewWebRTCConnWithOffer initiates a new WebRTC connection. An offer SDP is
+// returned, to be sent to the peer. After the offer SDP is forwarded and an
+// answer SDP received in response, call SetRemoteSDP with the answer SDP and
+// then call AwaitInitialDataChannel to await the eventual WebRTC connection
+// establishment.
+func NewWebRTCConnWithOffer(
+	ctx context.Context,
+	config *WebRTCConfig) (
+	*WebRTCConn, webrtc.SessionDescription, *SDPMetrics, error) {
+
+	conn, SDP, metrics, err := newWebRTCConn(ctx, config, nil)
+	if err != nil {
+		return nil, webrtc.SessionDescription{}, nil, errors.Trace(err)
+	}
+	return conn, *SDP, metrics, nil
+}
+
+// NewWebRTCConnWithAnswer creates a new WebRTC connection initiated by a peer
+// that provided an offer SDP. An answer SDP is returned to be sent to the
+// peer. After the answer SDP is forwarded, call AwaitInitialDataChannel to
+// await the eventual WebRTC connection establishment.
+func NewWebRTCConnWithAnswer(
+	ctx context.Context,
+	config *WebRTCConfig,
+	peerSDP webrtc.SessionDescription) (
+	*WebRTCConn, webrtc.SessionDescription, *SDPMetrics, error) {
+
+	conn, SDP, metrics, err := newWebRTCConn(ctx, config, &peerSDP)
+	if err != nil {
+		return nil, webrtc.SessionDescription{}, nil, errors.Trace(err)
+	}
+	return conn, *SDP, metrics, nil
+}
+
+// newWebRTCConn creates a WebRTCConn and its local SDP. When peerSDP is nil,
+// the conn takes the offer role: it creates the data channel and returns an
+// offer SDP. When peerSDP is non-nil, it is applied as the remote offer and
+// an answer SDP is returned. In both cases the function waits for ICE
+// gathering to complete, for a port mapping to be established, or for ctx to
+// be done, whichever comes first, and then returns the local SDP after it
+// has been adjusted by PrepareSDPAddresses, along with the resulting
+// SDPMetrics.
+//
+// On any error return, all resources created here (the UDP socket, the port
+// mapper, and the peer connection) are released.
+func newWebRTCConn(
+	ctx context.Context,
+	config *WebRTCConfig,
+	peerSDP *webrtc.SessionDescription) (
+	retConn *WebRTCConn,
+	retSDP *webrtc.SessionDescription,
+	retMetrics *SDPMetrics,
+	retErr error) {
+
+	isOffer := peerSDP == nil
+
+	udpConn, err := config.DialParameters.UDPListen()
+	if err != nil {
+		return nil, nil, nil, errors.Trace(err)
+	}
+
+	// conn and mapper are declared up front so that the deferred cleanup
+	// below can see which resources exist. Until conn is created, nothing
+	// else owns udpConn or mapper, so they must be released here on an early
+	// error return. Once conn exists, conn.Close, invoked by the later
+	// deferred cleanup, takes care of them.
+	var conn *WebRTCConn
+	var mapper *portMapper
+
+	defer func() {
+		if retErr != nil && conn == nil {
+			if mapper != nil {
+				mapper.close()
+			}
+			udpConn.Close()
+		}
+	}()
+
+	// Facilitate DTLS Client/ServerHello randomization. The client decides
+	// whether to do DTLS randomization and generates and the proxy receives
+	// ClientRootObfuscationSecret, so the client can orchestrate replay on
+	// both ends of the connection by reusing an obfuscation secret. Derive a
+	// secret specific to DTLS. SetDTLSSeed will further derive a secure PRNG
+	// seed specific to either the client or proxy end of the connection
+	// (so each peer's randomization will be distinct).
+	//
+	// To avoid forking many pion repos in order to pass the seed through to
+	// the DTLS implementation, SetDTLSSeed populates a cache that's keyed by
+	// the UDP conn.
+	//
+	// TODO: pion/dtls is not forked yet, so this is a no-op at this time.
+
+	if config.DoDTLSRandomization {
+
+		dtlsObfuscationSecret, err := deriveObfuscationSecret(
+			config.ClientRootObfuscationSecret, false, "in-proxy-DTLS-seed")
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		// NOTE(review): when ctx has no deadline, ctx.Deadline returns the
+		// zero time and time.Until yields a large negative duration; confirm
+		// that SetDTLSSeed handles such a TTL as intended.
+		deadline, _ := ctx.Deadline()
+		err = SetDTLSSeed(udpConn, dtlsObfuscationSecret, isOffer, time.Until(deadline))
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+	}
+
+	// Initialize WebRTC
+
+	// There are no explicit anti-probing measures for the proxy side of the
+	// WebRTC connection, since each proxy "listener" is ephemeral, and since
+	// the WebRTC data channel protocol authenticates peers with
+	// certificates, so even if a probe were to find an ephemeral proxy
+	// listener, the listener can respond the same as a normal WebRTC end
+	// point would respond to a peer that doesn't have the correct credentials.
+	//
+	// pion's Mux API is used, as it enables providing a pre-created UDP
+	// socket which is configured with necessary BindToDevice settings. We do
+	// not actually multiplex multiple client connections on a single proxy
+	// connection. As a proxy creates a new UDP socket and Mux for each
+	// client, this currently open issue should not impact our
+	// implementation: "Listener doesn't process parallel handshakes",
+	// https://github.com/pion/dtls/issues/279.
+	//
+	// We detach data channels in order to use the standard Read/Write APIs.
+	// As detaching avoids using the pion DataChannel read loop, this
+	// currently open issue should not impact our
+	// implementation: "DataChannel.readLoop goroutine leak",
+	// https://github.com/pion/webrtc/issues/2098.
+
+	settingEngine := webrtc.SettingEngine{}
+	settingEngine.DetachDataChannels()
+	settingEngine.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled)
+	settingEngine.SetICEUDPMux(webrtc.NewICEUDPMux(&webrtcLogger{logger: config.Logger}, udpConn))
+
+	// Set this behavior to look like common web browser WebRTC stacks.
+	settingEngine.SetDTLSInsecureSkipHelloVerify(true)
+
+	webRTCAPI := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
+
+	dataChannelLabel := "in-proxy-data-channel"
+
+	// NAT traversal setup
+
+	// When DisableInboundForMobleNetworks is set, skip both STUN and port
+	// mapping for mobile networks. Most mobile networks use CGNAT and
+	// neither STUN nor port mapping will be effective. It's faster to not
+	// wait for something that ultimately won't work.
+
+	disableInbound := config.DialParameters.DisableInboundForMobleNetworks() &&
+		config.DialParameters.NetworkType() == NetworkTypeMobile
+
+	// Try to establish a port mapping (UPnP-IGD, PCP, or NAT-PMP). The port
+	// mapper will attempt to identify the local gateway and query various
+	// port mapping protocols. portMapper.start launches this process and
+	// does not block. Port mappings are not part of the WebRTC standard, or
+	// supported by pion/webrtc. Instead, if a port mapping is established,
+	// it's edited into the SDP as a new host-type ICE candidate.
+
+	// Use a checked type assertion so that an unexpected PacketConn
+	// implementation results in an error rather than a panic.
+	localUDPAddr, ok := udpConn.LocalAddr().(*net.UDPAddr)
+	if !ok {
+		return nil, nil, nil, errors.TraceNew("unexpected local address type")
+	}
+	localPort := localUDPAddr.Port
+	mapper = newPortMapper(config.Logger, localPort)
+
+	doPortMapping := !disableInbound && !config.DialParameters.DisablePortMapping()
+
+	if doPortMapping {
+		mapper.start()
+	}
+
+	// Select a STUN server for ICE hole punching. The STUN server to be used
+	// needs only support bind and not full RFC5780 NAT discovery.
+	//
+	// Each dial tries only one STUN server; in Psiphon tunnel establishment,
+	// other, concurrent in-proxy dials may select alternative STUN servers
+	// via DialParameters. When the STUN server operation is successful,
+	// DialParameters will be signaled so that it may configure the STUN
+	// server selection for replay.
+	//
+	// The STUN server will observe proxy IP addresses. Enumeration is
+	// mitigated by using various public STUN servers, including Psiphon STUN
+	// servers for proxies in non-censored regions. Proxies are also more
+	// ephemeral than Psiphon servers.
+
+	RFC5780 := false
+	stunServerAddress := config.DialParameters.STUNServerAddress(RFC5780)
+
+	// Proceed even when stunServerAddress is "" and !DisableSTUN, as ICE may
+	// find other host candidates.
+
+	doSTUN := stunServerAddress != "" && !disableInbound && !config.DialParameters.DisableSTUN()
+
+	var ICEServers []webrtc.ICEServer
+
+	if doSTUN {
+
+		// Use the Psiphon custom resolver to resolve any STUN server domains.
+		serverAddress, err := config.DialParameters.ResolveAddress(
+			ctx, stunServerAddress)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		ICEServers = []webrtc.ICEServer{
+			webrtc.ICEServer{
+				URLs: []string{"stun:" + serverAddress},
+			},
+		}
+	}
+
+	peerConnection, err := webRTCAPI.NewPeerConnection(
+		webrtc.Configuration{
+			ICEServers: ICEServers,
+		})
+	if err != nil {
+		return nil, nil, nil, errors.Trace(err)
+	}
+
+	conn = &WebRTCConn{
+		config:                       config,
+		udpConn:                      udpConn,
+		portMapper:                   mapper,
+		closedSignal:                 make(chan struct{}),
+		peerConnection:               peerConnection,
+		dataChannelOpenedSignal:      make(chan struct{}),
+		dataChannelWriteBufferSignal: make(chan struct{}, 1),
+
+		// A data channel uses SCTP and is message oriented. The maximum
+		// message size supported by pion/webrtc is 65536:
+		// https://github.com/pion/webrtc/blob/dce970438344727af9c9965f88d958c55d32e64d/datachannel.go#L19.
+		// This read buffer must be as large as the maximum message size or
+		// else a read may fail with io.ErrShortBuffer.
+		messageBuffer: make([]byte, math.MaxUint16),
+	}
+	defer func() {
+		if retErr != nil {
+			// Cleanup on early return
+			conn.Close()
+
+			// Notify the DialParameters that the operation failed so that it
+			// can clear replay for that STUN server selection.
+			//
+			// Limitation: the error here may be due to failures unrelated to
+			// the STUN server.
+
+			if ctx.Err() == nil && doSTUN {
+				config.DialParameters.STUNServerAddressFailed(RFC5780, stunServerAddress)
+			}
+		}
+	}()
+
+	conn.peerConnection.OnConnectionStateChange(conn.onConnectionStateChange)
+	conn.peerConnection.OnICECandidate(conn.onICECandidate)
+	conn.peerConnection.OnICEConnectionStateChange(conn.onICEConnectionStateChange)
+	conn.peerConnection.OnICEGatheringStateChange(conn.onICEGatheringStateChange)
+	conn.peerConnection.OnNegotiationNeeded(conn.onNegotiationNeeded)
+	conn.peerConnection.OnSignalingStateChange(conn.onSignalingStateChange)
+	conn.peerConnection.OnDataChannel(conn.onDataChannel)
+
+	// As a future enhancement, consider using media channels instead of data
+	// channels, as media channels may be more common. Proxied QUIC would
+	// work over an unreliable media channel. Note that a media channel is
+	// still prefixed with STUN and DTLS exchanges before SRTP begins, so the
+	// first few packets are the same as a data channel.
+
+	// The offer sets the data channel configuration.
+	if isOffer {
+
+		dataChannelInit := &webrtc.DataChannelInit{}
+		if !config.ReliableTransport {
+			ordered := false
+			dataChannelInit.Ordered = &ordered
+			maxRetransmits := uint16(0)
+			dataChannelInit.MaxRetransmits = &maxRetransmits
+		}
+
+		dataChannel, err := peerConnection.CreateDataChannel(
+			dataChannelLabel, dataChannelInit)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		conn.setDataChannel(dataChannel)
+	}
+
+	// Prepare to await full ICE completion, including STUN candidates.
+	// Trickle ICE is not used, simplifying the broker API. It's expected
+	// that most clients and proxies will be behind a NAT, and not have
+	// publicly addressable host candidates. TURN is not used. So most
+	// candidates will be STUN, or server-reflexive, candidates.
+	//
+	// Later, the first to complete out of ICE or port mapping is used.
+	//
+	// TODO: stop waiting if an IPv6 host candidate is found?
+
+	iceComplete := webrtc.GatheringCompletePromise(conn.peerConnection)
+
+	// Create an offer, or input a peer's offer to create an answer.
+
+	if isOffer {
+
+		offer, err := conn.peerConnection.CreateOffer(nil)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		err = conn.peerConnection.SetLocalDescription(offer)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+	} else {
+
+		err = conn.peerConnection.SetRemoteDescription(*peerSDP)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		answer, err := conn.peerConnection.CreateAnswer(nil)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+		err = conn.peerConnection.SetLocalDescription(answer)
+		if err != nil {
+			return nil, nil, nil, errors.Trace(err)
+		}
+
+	}
+
+	// Await either ICE or port mapping completion.
+
+	// As a future enhancement, track which of ICE or port mapping succeeds
+	// and is then followed by a failed WebRTC dial; stop trying the method
+	// that often fails.
+
+	iceCompleted := false
+	portMappingExternalAddr := ""
+
+	select {
+	case <-iceComplete:
+		iceCompleted = true
+
+	case portMappingExternalAddr = <-mapper.portMappingExternalAddress():
+
+		// Set responding port mapping types for metrics.
+		//
+		// Limitation: if there are multiple responding protocol types, it's
+		// not known here which was used for this dial.
+		config.DialParameters.SetPortMappingTypes(
+			getRespondingPortMappingTypes(config.DialParameters.NetworkID()))
+
+	case <-ctx.Done():
+		return nil, nil, nil, errors.Trace(ctx.Err())
+	}
+
+	// Release any port mapping resources when not using it. Clearing
+	// conn.portMapper ensures that conn.Close does not close it again.
+	if mapper != nil && portMappingExternalAddr == "" {
+		mapper.close()
+		conn.portMapper = nil
+	}
+
+	// Get the offer or answer, now populated with any ICE candidates.
+
+	localDescription := conn.peerConnection.LocalDescription()
+
+	// Adjust the SDP, removing local network addresses and adding any
+	// port mapping candidate.
+
+	adjustedSDP, metrics, err := PrepareSDPAddresses([]byte(
+		localDescription.SDP), portMappingExternalAddr)
+	if err != nil {
+		return nil, nil, nil, errors.Trace(err)
+	}
+
+	// When STUN was attempted, ICE completed, and a STUN server-reflexive
+	// candidate is present, notify the DialParameters so that it can set
+	// replay for that STUN server selection.
+
+	if iceCompleted && doSTUN {
+		hasServerReflexive := false
+		for _, candidateType := range metrics.ICECandidateTypes {
+			if candidateType == ICECandidateServerReflexive {
+				hasServerReflexive = true
+				break
+			}
+		}
+		if hasServerReflexive {
+			config.DialParameters.STUNServerAddressSucceeded(RFC5780, stunServerAddress)
+		} else {
+			config.DialParameters.STUNServerAddressFailed(RFC5780, stunServerAddress)
+		}
+	}
+
+	// The WebRTCConn is prepared, but the data channel is not yet connected.
+	// On the offer end, the peer's following answer must be input to
+	// SetRemoteSDP. And both ends must call AwaitInitialDataChannel to await
+	// the data channel establishment.
+
+	return conn,
+		&webrtc.SessionDescription{
+			Type: localDescription.Type,
+			SDP:  string(adjustedSDP),
+		},
+		metrics,
+		nil
+}
+
+// setDataChannel records dataChannel as the conn's data channel, registers
+// the open and close handlers on it, and configures the buffered-amount-low
+// callback that Write relies on for flow control.
+func (conn *WebRTCConn) setDataChannel(dataChannel *webrtc.DataChannel) {
+
+	// Assumes the caller holds conn.mutex, or is newWebRTCConn, creating the
+	// conn.
+
+	// Each handler is registered exactly once.
+	conn.dataChannel = dataChannel
+	conn.dataChannel.OnOpen(conn.onDataChannelOpen)
+	conn.dataChannel.OnClose(conn.onDataChannelClose)
+
+	// Set up flow control (see comment in conn.Write)
+	conn.dataChannel.SetBufferedAmountLowThreshold(dataChannelBufferedAmountLowThreshold)
+	conn.dataChannel.OnBufferedAmountLow(func() {
+		// Non-blocking send: the signal channel has a buffer of one, so a
+		// pending signal is sufficient and the callback never blocks.
+		select {
+		case conn.dataChannelWriteBufferSignal <- struct{}{}:
+		default:
+		}
+	})
+}
+
+// SetRemoteSDP applies the answer SDP that the peer returned in response to
+// this conn's offer SDP. On the offer end, this is the step that starts
+// WebRTC connection establishment.
+func (conn *WebRTCConn) SetRemoteSDP(peerSDP webrtc.SessionDescription) error {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	if err := conn.peerConnection.SetRemoteDescription(peerSDP); err != nil {
+		return errors.Trace(err)
+	}
+
+	return nil
+}
+
+// AwaitInitialDataChannel blocks until the data channel has been
+// established, in which case it returns nil, or until ctx is done or the
+// conn is closed, in which case it returns an error.
+func (conn *WebRTCConn) AwaitInitialDataChannel(ctx context.Context) error {
+
+	// conn.mutex is deliberately not held while waiting, or else necessary
+	// operations will deadlock.
+
+	select {
+	case <-conn.closedSignal:
+		return errors.TraceNew("connection has closed")
+
+	case <-ctx.Done():
+		return errors.Trace(ctx.Err())
+
+	case <-conn.dataChannelOpenedSignal:
+
+		// The data channel is connected.
+		//
+		// TODO: for metrics, determine which end was the network connection
+		// initiator; and determine which type of ICE candidate was
+		// successful (note that peer-reflexive candidates aren't in either
+		// SDP and emerge only during ICE negotiation).
+
+		return nil
+	}
+}
+
+// Close shuts down the conn: it closes the UDP socket, the port mapper, the
+// detached data channel conn, the data channel, and the peer connection,
+// whichever are present, and then closes closedSignal. Calls after the
+// first are no-ops. Close always returns nil.
+func (conn *WebRTCConn) Close() error {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	if conn.isClosed {
+		return nil
+	}
+
+	// The udpConn is closed first in order to interrupt any blocking DTLS
+	// handshake:
+	// https://github.com/pion/webrtc/blob/c1467e4871c78ee3f463b50d858d13dc6f2874a4/dtlstransport.go#L334-L340
+
+	if udpConn := conn.udpConn; udpConn != nil {
+		udpConn.Close()
+	}
+
+	if mapper := conn.portMapper; mapper != nil {
+		mapper.close()
+	}
+
+	if channelConn := conn.dataChannelConn; channelConn != nil {
+		channelConn.Close()
+	}
+
+	if channel := conn.dataChannel; channel != nil {
+		channel.Close()
+	}
+
+	if peerConnection := conn.peerConnection; peerConnection != nil {
+		peerConnection.Close()
+	}
+
+	// Wake anything waiting on the conn, and mark the conn closed so that
+	// later Close calls return immediately.
+	close(conn.closedSignal)
+	conn.isClosed = true
+
+	return nil
+}
+
+// Read copies data received on the data channel into p. Each data channel
+// message is read whole into an internal buffer and then handed out across
+// as many Read calls as p's size requires. It returns an error when the data
+// channel is not yet connected.
+func (conn *WebRTCConn) Read(p []byte) (int, error) {
+
+	// Take the lock only long enough to snapshot the data channel conn;
+	// holding it during the read would block concurrent Writes.
+	conn.mutex.Lock()
+	channelConn := conn.dataChannelConn
+	conn.mutex.Unlock()
+
+	if channelConn == nil {
+		return 0, errors.TraceNew("not connected")
+	}
+
+	// p may be shorter than the message read from the data channel, so the
+	// message is buffered and consumed over successive Read calls. As per
+	// https://pkg.go.dev/io#Reader, bytes returned by the underlying Read
+	// are processed even when that Read also returns an error; the error is
+	// stored and returned by the Read call that consumes the end of the
+	// message buffer.
+
+	conn.messageMutex.Lock()
+	defer conn.messageMutex.Unlock()
+
+	// The previous message has been fully consumed: fetch the next one.
+	if conn.messageOffset == conn.messageLength {
+		conn.messageLength, conn.messageError = channelConn.Read(conn.messageBuffer)
+		conn.messageOffset = 0
+	}
+
+	pending := conn.messageBuffer[conn.messageOffset:conn.messageLength]
+	copied := copy(p, pending)
+	conn.messageOffset += copied
+
+	// Surface the stored error only once the buffered message is drained.
+	var readErr error
+	if conn.messageOffset == conn.messageLength {
+		readErr = conn.messageError
+	}
+
+	return copied, errors.Trace(readErr)
+}
+
+func (conn *WebRTCConn) Write(p []byte) (int, error) {
+
+	// Don't hold this lock, or else concurrent Reads will be blocked.
+	conn.mutex.Lock()
+	isClosed := conn.isClosed
+	bufferedAmount := conn.dataChannel.BufferedAmount()
+	dataChannelConn := conn.dataChannelConn
+	conn.mutex.Unlock()
+
+	if dataChannelConn == nil {
+		return 0, errors.TraceNew("not connected")
+	}
+
+	// Flow control is required to ensure that Write calls don't result in
+	// unbounded buffering in pion/webrtc. Use similar logic and the same
+	// buffer size thresholds as the pion sample code.
+	//
+	// https://github.com/pion/webrtc/tree/master/examples/data-channels-flow-control#when-do-we-need-it:
+	// > Send or SendText methods are called on DataChannel to send data to
+	// > the connected peer. The methods return immediately, but it does not
+	// > mean the data was actually sent onto the wire. Instead, it is
+	// > queued in a buffer until it actually gets sent out to the wire.
+	// >
+	// > When you have a large amount of data to send, it is an application's
+	// > responsibility to control the buffered amount in order not to
+	// > indefinitely grow the buffer size to eventually exhaust the memory.
+
+	// If the pion write buffer is too full, wait for a signal that sufficient
+	// write data has been consumed before writing more.
+	if !isClosed && bufferedAmount+uint64(len(p)) > dataChannelMaxBufferedAmount {
+		select {
+		case <-conn.dataChannelWriteBufferSignal:
+		case <-conn.closedSignal:
+			return 0, errors.TraceNew("connection has closed")
+		}
+	}
+
+	// Limitation: if len(p) > 65536, the dataChannelConn.Write wil fail. In
+	// practise, this is not expected to happen with typical use cases such
+	// as io.Copy, which uses a 32K buffer.
+
+	n, err := dataChannelConn.Write(p)
+	return n, errors.Trace(err)
+}
+
+// LocalAddr returns the address of the conn's local UDP socket. This is not
+// the external, public address.
+func (conn *WebRTCConn) LocalAddr() net.Addr {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	localUDPAddr := conn.udpConn.LocalAddr()
+	return localUDPAddr
+}
+
+func (conn *WebRTCConn) RemoteAddr() net.Addr {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	// Not supported.
+	return nil
+}
+
+// SetDeadline is not supported and always returns an error.
+func (conn *WebRTCConn) SetDeadline(t time.Time) error {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	err := errors.TraceNew("not supported")
+	return err
+}
+
+// SetReadDeadline is not supported and always returns an error.
+func (conn *WebRTCConn) SetReadDeadline(t time.Time) error {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	err := errors.TraceNew("not supported")
+	return err
+}
+
+// SetWriteDeadline is not supported and always returns an error.
+func (conn *WebRTCConn) SetWriteDeadline(t time.Time) error {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	err := errors.TraceNew("not supported")
+	return err
+}
+
+// onConnectionStateChange is the peer connection state callback. It closes
+// the conn when the peer connection has failed, and logs every state change.
+func (conn *WebRTCConn) onConnectionStateChange(state webrtc.PeerConnectionState) {
+
+	switch state {
+	case webrtc.PeerConnectionStateFailed:
+		conn.Close()
+	}
+
+	logFields := common.LogFields{
+		"state": state.String(),
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("peer connection state changed")
+}
+
+// onICECandidate is the ICE candidate callback. It only logs the candidate.
+func (conn *WebRTCConn) onICECandidate(candidate *webrtc.ICECandidate) {
+
+	logFields := common.LogFields{
+		"candidate": candidate,
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("new ICE candidate")
+}
+
+// onICEConnectionStateChange is the ICE connection state callback. It only
+// logs the new state.
+func (conn *WebRTCConn) onICEConnectionStateChange(state webrtc.ICEConnectionState) {
+
+	logFields := common.LogFields{
+		"state": state.String(),
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("ICE connection state changed")
+}
+
+// onICEGatheringStateChange is the ICE gathering state callback. It only
+// logs the new state.
+func (conn *WebRTCConn) onICEGatheringStateChange(state webrtc.ICEGathererState) {
+
+	logFields := common.LogFields{
+		"state": state.String(),
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("ICE gathering state changed")
+}
+
+// onNegotiationNeeded is the negotiation-needed callback. It only logs the
+// event.
+func (conn *WebRTCConn) onNegotiationNeeded() {
+
+	logger := conn.config.Logger
+	logger.WithTrace().Info("negotiation needed")
+}
+
+// onSignalingStateChange is the signaling state callback. It only logs the
+// new state.
+func (conn *WebRTCConn) onSignalingStateChange(state webrtc.SignalingState) {
+
+	logFields := common.LogFields{
+		"state": state.String(),
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("signaling state changed")
+}
+
+// onDataChannel is the callback for a data channel announced by the peer
+// connection. It installs the data channel on the conn, under conn.mutex,
+// and logs its label and ID.
+func (conn *WebRTCConn) onDataChannel(dataChannel *webrtc.DataChannel) {
+
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	conn.setDataChannel(dataChannel)
+
+	logFields := common.LogFields{
+		"label": dataChannel.Label(),
+		"ID":    dataChannel.ID(),
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("new data channel")
+}
+
+// onDataChannelOpen is the data channel open callback. It detaches the data
+// channel and, when that succeeds, stores the detached conn for Read and
+// Write and closes dataChannelOpenedSignal (at most once). The detach
+// result is logged either way.
+func (conn *WebRTCConn) onDataChannelOpen() {
+
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	detachedConn, detachErr := conn.dataChannel.Detach()
+	if detachErr == nil {
+		conn.dataChannelConn = detachedConn
+
+		// TODO: can a data channel be connected, disconnected, and then
+		// reestablished in one session?
+
+		signalOpened := func() { close(conn.dataChannelOpenedSignal) }
+		conn.dataChannelOpenedOnce.Do(signalOpened)
+	}
+
+	logFields := common.LogFields{
+		"detachError": detachErr,
+	}
+	conn.config.Logger.WithTraceFields(logFields).Info("data channel open")
+}
+
+// onDataChannelClose is the data channel close callback. It only logs the
+// event, while holding conn.mutex.
+func (conn *WebRTCConn) onDataChannelClose() {
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+
+	logger := conn.config.Logger
+	logger.WithTrace().Info("data channel closed")
+}
+
+// PrepareSDPAddresses adjusts the SDP, pruning local network addresses and
+// adding any port mapping as a host candidate.
+func PrepareSDPAddresses(
+	encodedSDP []byte,
+	portMappingExternalAddr string) ([]byte, *SDPMetrics, error) {
+
+	modifiedSDP, metrics, err := processSDPAddresses(
+		encodedSDP, portMappingExternalAddr, false, nil, common.GeoIPData{})
+	return modifiedSDP, metrics, errors.Trace(err)
+}
+
+// ValidateSDPAddresses checks that the SDP does not contain an empty list of
+// candidates, bogon candidates, or candidates outside of the country and ASN
+// for the specified expectedGeoIPData.
+func ValidateSDPAddresses(
+	encodedSDP []byte,
+	lookupGeoIP LookupGeoIP,
+	expectedGeoIPData common.GeoIPData) (*SDPMetrics, error) {
+
+	_, metrics, err := processSDPAddresses(encodedSDP, "", true, lookupGeoIP, expectedGeoIPData)
+	return metrics, errors.Trace(err)
+}
+
+// SDPMetrics are network capability metrics values for an SDP; they are returned by PrepareSDPAddresses and ValidateSDPAddresses.
+type SDPMetrics struct {
+	ICECandidateTypes []ICECandidateType // newWebRTCConn checks these for ICECandidateServerReflexive
+	HasIPv6           bool
+}
+
+// processSDPAddresses is based on snowflake/common/util.StripLocalAddresses
+// https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/blob/v2.5.1/common/util/util.go#L70-99
+/*
+              This file contains the license for "Snowflake"
+     a free software project which provides a WebRTC pluggable transport.
+
+================================================================================
+Copyright (c) 2016, Serene Han, Arlo Breault
+Copyright (c) 2019-2020, The Tor Project, Inc
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+  * Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+  * Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+
+  * Neither the names of the copyright owners nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+================================================================================
+
+*/
+
+// processSDPAddresses parses encodedSDP, optionally injects port mapping
+// candidates, strips or rejects bogon candidates, optionally checks candidate
+// GeoIP data, and returns the re-encoded SDP along with its SDPMetrics.
+//
+// When portMappingExternalAddr is not "", it must be a host:port with a
+// numeric port; host candidates for that address are inserted into each media
+// description, before the first existing candidate or end-of-candidates
+// attribute. Only IPv4 port mapping hosts are supported; other hosts are
+// skipped without error.
+//
+// When errorOnBogon is true, any bogon candidate is an error; otherwise bogon
+// candidates are omitted from the output SDP and are not reflected in the
+// metrics. When lookupGeoIP is not nil, every retained candidate must map to
+// the country and ASN in expectedGeoIPData.
+//
+// An error is also returned when the SDP cannot be parsed, when a candidate
+// address is not an IP address, and when there are no candidates.
+func processSDPAddresses(
+	encodedSDP []byte,
+	portMappingExternalAddr string,
+	errorOnBogon bool,
+	lookupGeoIP LookupGeoIP,
+	expectedGeoIPData common.GeoIPData) ([]byte, *SDPMetrics, error) {
+
+	var sessionDescription sdp.SessionDescription
+	err := sessionDescription.Unmarshal(encodedSDP)
+	if err != nil {
+		return nil, nil, errors.Trace(err)
+	}
+
+	candidateTypes := map[ICECandidateType]bool{}
+	hasIPv6 := false
+
+	var portMappingICECandidates []sdp.Attribute
+	if portMappingExternalAddr != "" {
+
+		// Prepare ICE candidate attribute pair for the port mapping, modeled
+		// after the definition of host candidates.
+
+		host, portStr, err := net.SplitHostPort(portMappingExternalAddr)
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+
+		// Fail on a malformed or out-of-range port rather than silently
+		// advertising a candidate with port 0 or an invalid port.
+		port, err := strconv.Atoi(portStr)
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+		if port <= 0 || port > 65535 {
+			return nil, nil, errors.TraceNew("invalid port")
+		}
+
+		// Only IPv4 port mapping addresses are supported due to the
+		// NewCandidateHost limitation noted below. It is expected that port
+		// mappings will be IPv4, as NAT and IPv6 is not a typical combination.
+
+		hostIP := net.ParseIP(host)
+		if hostIP != nil && hostIP.To4() != nil {
+
+			for _, component := range []webrtc.ICEComponent{webrtc.ICEComponentRTP, webrtc.ICEComponentRTCP} {
+
+				// The candidate ID is generated and the priority and foundation
+				// use the default for hosts.
+				//
+				// Limitation: NewCandidateHost initializes the networkType to
+				// NetworkTypeUDP4, and this field is not-exported.
+				// https://github.com/pion/ice/blob/6d301287654b05a36248842c278d58d501454bff/candidate_host.go#L27-L64
+
+				iceCandidate, err := ice.NewCandidateHost(&ice.CandidateHostConfig{
+					Network:   "udp",
+					Address:   host,
+					Port:      port,
+					Component: uint16(component),
+				})
+				if err != nil {
+					return nil, nil, errors.Trace(err)
+				}
+
+				portMappingICECandidates = append(
+					portMappingICECandidates,
+					sdp.Attribute{Key: "candidate", Value: iceCandidate.Marshal()})
+			}
+
+			candidateTypes[ICECandidatePortMapping] = true
+		}
+	}
+
+	// The port mapping candidates count towards the "no candidates" check
+	// below.
+	candidateCount := len(portMappingICECandidates)
+
+	for _, mediaDescription := range sessionDescription.MediaDescriptions {
+
+		addPortMappingCandidates := len(portMappingICECandidates) > 0
+		var attributes []sdp.Attribute
+		for _, attribute := range mediaDescription.Attributes {
+
+			// Insert the port mapping candidate either before the
+			// first "a=candidate", or before "a=end-of-candidates"(there may
+			// be no "a=candidate" attributes).
+
+			if addPortMappingCandidates &&
+				(attribute.IsICECandidate() || attribute.Key == sdp.AttrKeyEndOfCandidates) {
+
+				attributes = append(attributes, portMappingICECandidates...)
+				addPortMappingCandidates = false
+			}
+
+			if attribute.IsICECandidate() {
+
+				candidate, err := ice.UnmarshalCandidate(attribute.Value)
+				if err != nil {
+					return nil, nil, errors.Trace(err)
+				}
+
+				candidateIP := net.ParseIP(candidate.Address())
+
+				if candidateIP == nil {
+					return nil, nil, errors.TraceNew("unexpected non-IP")
+				}
+
+				// Defer recording IPv6 support until the candidate has passed
+				// all of the checks below, so that the metric reflects only
+				// candidates that are retained in the output SDP.
+				candidateIsIPv6 := candidateIP.To4() == nil
+
+				// Strip non-routable bogons, including LAN addresses.
+				// Same-LAN client/proxy hops are not expected to be useful,
+				// and this also avoids unnecessary local network traffic.
+				//
+				// Well-behaved clients and proxies will strip these values;
+				// the broker enforces this and uses errorOnBogon.
+
+				if !getAllowLoopbackWebRTCConnections() &&
+					isBogon(candidateIP) {
+
+					if errorOnBogon {
+						return nil, nil, errors.TraceNew("unexpected bogon")
+					}
+					continue
+				}
+
+				// The broker will check that clients and proxies specify only
+				// candidates that map to the same GeoIP country and ASN as
+				// the client/proxy connection to the broker. This limits
+				// misuse of candidates to connect to other locations.
+				// Legitimate candidates will not all have the exact same IP
+				// address, as there could be a mix of IPv4 and IPv6, as well
+				// as potentially different NAT paths.
+
+				if lookupGeoIP != nil {
+					candidateGeoIPData := lookupGeoIP(candidate.Address())
+					if candidateGeoIPData.Country != expectedGeoIPData.Country {
+						return nil, nil, errors.TraceNew("unexpected GeoIP country")
+					}
+					if candidateGeoIPData.ASN != expectedGeoIPData.ASN {
+						return nil, nil, errors.TraceNew("unexpected GeoIP ASN")
+					}
+				}
+
+				if candidateIsIPv6 {
+					hasIPv6 = true
+				}
+
+				// These types are not reported:
+				// - CandidateTypeRelay: TURN servers are not used.
+				// - CandidateTypePeerReflexive: this candidate type only
+				//   emerges later in the connection process.
+
+				switch candidate.Type() {
+				case ice.CandidateTypeHost:
+					candidateTypes[ICECandidateHost] = true
+				case ice.CandidateTypeServerReflexive:
+					candidateTypes[ICECandidateServerReflexive] = true
+				}
+
+				candidateCount += 1
+			}
+
+			attributes = append(attributes, attribute)
+		}
+
+		mediaDescription.Attributes = attributes
+	}
+
+	if candidateCount == 0 {
+		return nil, nil, errors.TraceNew("no candidates")
+	}
+
+	encodedSDP, err = sessionDescription.Marshal()
+	if err != nil {
+		return nil, nil, errors.Trace(err)
+	}
+
+	// The candidate types are collected from a map, so their order in the
+	// metrics is not deterministic.
+	metrics := &SDPMetrics{
+		HasIPv6: hasIPv6,
+	}
+	for candidateType := range candidateTypes {
+		metrics.ICECandidateTypes = append(metrics.ICECandidateTypes, candidateType)
+	}
+
+	return encodedSDP, metrics, nil
+}
+
+// allowLoopbackWebRTCConnections is a flag, accessed atomically, where the
+// value 1 disables the bogon candidate check in processSDPAddresses.
+var allowLoopbackWebRTCConnections int32
+
+// getAllowLoopbackWebRTCConnections reports whether the flag is set.
+func getAllowLoopbackWebRTCConnections() bool {
+	flag := atomic.LoadInt32(&allowLoopbackWebRTCConnections)
+	return flag == 1
+}
+
+// setAllowLoopbackWebRTCConnections is for testing only, to allow the
+// end-to-end inproxy_test to run with a restrictive OS firewall in place. Do
+// not export.
+func setAllowLoopbackWebRTCConnections(allow bool) {
+	if allow {
+		atomic.StoreInt32(&allowLoopbackWebRTCConnections, 1)
+		return
+	}
+	atomic.StoreInt32(&allowLoopbackWebRTCConnections, 0)
+}
+
+// isBogon reports whether IP is found, by filtertransport.FindIPNet, in
+// filtertransport.DefaultFilteredNetworks. A nil IP is never a bogon.
+func isBogon(IP net.IP) bool {
+	return IP != nil &&
+		filtertransport.FindIPNet(filtertransport.DefaultFilteredNetworks, IP)
+}
+
+// webrtcLogger wraps common.Logger and implements
+// https://pkg.go.dev/github.com/pion/logging#LeveledLogger for passing into
+// pion.
+//
+// Trace-level messages are ignored. All other messages are forwarded to the
+// corresponding level of the wrapped logger, with a "webRTC: " prefix.
+type webrtcLogger struct {
+	logger common.Logger
+}
+
+// Trace ignores the message.
+func (l *webrtcLogger) Trace(msg string) {
+	// Ignored.
+}
+
+// Tracef ignores the message.
+func (l *webrtcLogger) Tracef(format string, args ...interface{}) {
+	// Ignored.
+}
+
+// Debug logs msg at debug level. The message is not built when the wrapped
+// logger reports that debug logging is off.
+func (l *webrtcLogger) Debug(msg string) {
+	// NOTE(review): assumes the wrapped logger discards debug-level output
+	// whenever IsLogLevelDebug returns false -- confirm for all Logger
+	// implementations.
+	if !l.logger.IsLogLevelDebug() {
+		return
+	}
+	l.logger.WithTrace().Debug("webRTC: " + msg)
+}
+
+// Debugf formats and logs a message at debug level. Formatting is skipped
+// when the wrapped logger reports that debug logging is off, which is the
+// case IsLogLevelDebug is intended for.
+func (l *webrtcLogger) Debugf(format string, args ...interface{}) {
+	if !l.logger.IsLogLevelDebug() {
+		return
+	}
+	l.logger.WithTrace().Debug("webRTC: " + fmt.Sprintf(format, args...))
+}
+
+// Info logs msg at info level.
+func (l *webrtcLogger) Info(msg string) {
+	l.logger.WithTrace().Info("webRTC: " + msg)
+}
+
+// Infof formats and logs a message at info level.
+func (l *webrtcLogger) Infof(format string, args ...interface{}) {
+	l.logger.WithTrace().Info("webRTC: " + fmt.Sprintf(format, args...))
+}
+
+// Warn logs msg at warning level.
+func (l *webrtcLogger) Warn(msg string) {
+	l.logger.WithTrace().Warning("webRTC: " + msg)
+}
+
+// Warnf formats and logs a message at warning level.
+func (l *webrtcLogger) Warnf(format string, args ...interface{}) {
+	l.logger.WithTrace().Warning("webRTC: " + fmt.Sprintf(format, args...))
+}
+
+// Error logs msg at error level.
+func (l *webrtcLogger) Error(msg string) {
+	l.logger.WithTrace().Error("webRTC: " + msg)
+}
+
+// Errorf formats and logs a message at error level.
+func (l *webrtcLogger) Errorf(format string, args ...interface{}) {
+	l.logger.WithTrace().Error("webRTC: " + fmt.Sprintf(format, args...))
+}

+ 4 - 0
psiphon/common/logger.go

@@ -28,6 +28,10 @@ type Logger interface {
 	WithTrace() LogTrace
 	WithTraceFields(fields LogFields) LogTrace
 	LogMetric(metric string, fields LogFields)
+
+	// IsLogLevelDebug is used to skip formatting debug-level log messages in
+	// cases where performance would be impacted.
+	IsLogLevelDebug() bool
 }
 
 // LogTrace is interface-compatible with the return values from

+ 5 - 0
psiphon/common/prng/prng.go

@@ -269,6 +269,11 @@ func (p *PRNG) Perm(n int) []int {
 	return p.rand.Perm(n)
 }
 
+// Shuffle is equivalent to math/rand.Shuffle.
+func (p *PRNG) Shuffle(n int, swap func(i, j int)) {
+	p.rand.Shuffle(n, swap)
+}
+
 // Range selects a random integer in [min, max].
 // If min < 0, min is set to 0. If max < min, min is returned.
 func (p *PRNG) Range(min, max int) int {

+ 2 - 0
psiphon/common/protocol/protocol.go

@@ -87,6 +87,8 @@ const (
 
 	CONJURE_TRANSPORT_MIN_OSSH   = "Min-OSSH"
 	CONJURE_TRANSPORT_OBFS4_OSSH = "Obfs4-OSSH"
+
+	CAPABILITY_INPROXY = "inproxy"
 )
 
 var SupportedServerEntrySources = []string{

+ 126 - 34
psiphon/common/protocol/serverEntry.go

@@ -32,6 +32,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"regexp"
 	"strings"
 	"time"
 
@@ -44,40 +45,42 @@ import (
 // several protocols. Server entries are JSON records downloaded from
 // various sources.
 type ServerEntry struct {
-	Tag                             string   `json:"tag"`
-	IpAddress                       string   `json:"ipAddress"`
-	WebServerPort                   string   `json:"webServerPort"` // not an int
-	WebServerSecret                 string   `json:"webServerSecret"`
-	WebServerCertificate            string   `json:"webServerCertificate"`
-	SshPort                         int      `json:"sshPort"`
-	SshUsername                     string   `json:"sshUsername"`
-	SshPassword                     string   `json:"sshPassword"`
-	SshHostKey                      string   `json:"sshHostKey"`
-	SshObfuscatedPort               int      `json:"sshObfuscatedPort"`
-	SshObfuscatedQUICPort           int      `json:"sshObfuscatedQUICPort"`
-	LimitQUICVersions               []string `json:"limitQUICVersions"`
-	SshObfuscatedTapDancePort       int      `json:"sshObfuscatedTapdancePort"`
-	SshObfuscatedConjurePort        int      `json:"sshObfuscatedConjurePort"`
-	SshObfuscatedKey                string   `json:"sshObfuscatedKey"`
-	Capabilities                    []string `json:"capabilities"`
-	Region                          string   `json:"region"`
-	FrontingProviderID              string   `json:"frontingProviderID"`
-	MeekServerPort                  int      `json:"meekServerPort"`
-	MeekCookieEncryptionPublicKey   string   `json:"meekCookieEncryptionPublicKey"`
-	MeekObfuscatedKey               string   `json:"meekObfuscatedKey"`
-	MeekFrontingHost                string   `json:"meekFrontingHost"`
-	MeekFrontingHosts               []string `json:"meekFrontingHosts"`
-	MeekFrontingDomain              string   `json:"meekFrontingDomain"`
-	MeekFrontingAddresses           []string `json:"meekFrontingAddresses"`
-	MeekFrontingAddressesRegex      string   `json:"meekFrontingAddressesRegex"`
-	MeekFrontingDisableSNI          bool     `json:"meekFrontingDisableSNI"`
-	TacticsRequestPublicKey         string   `json:"tacticsRequestPublicKey"`
-	TacticsRequestObfuscatedKey     string   `json:"tacticsRequestObfuscatedKey"`
-	ConfigurationVersion            int      `json:"configurationVersion"`
-	Signature                       string   `json:"signature"`
-	DisableHTTPTransforms           bool     `json:"disableHTTPTransforms"`
-	DisableObfuscatedQUICTransforms bool     `json:"disableObfuscatedQUICTransforms"`
-	DisableOSSHTransforms           bool     `json:"disableOSSHTransforms"`
+	Tag                                 string   `json:"tag"`
+	IpAddress                           string   `json:"ipAddress"`
+	WebServerPort                       string   `json:"webServerPort"` // not an int
+	WebServerSecret                     string   `json:"webServerSecret"`
+	WebServerCertificate                string   `json:"webServerCertificate"`
+	SshPort                             int      `json:"sshPort"`
+	SshUsername                         string   `json:"sshUsername"`
+	SshPassword                         string   `json:"sshPassword"`
+	SshHostKey                          string   `json:"sshHostKey"`
+	SshObfuscatedPort                   int      `json:"sshObfuscatedPort"`
+	SshObfuscatedQUICPort               int      `json:"sshObfuscatedQUICPort"`
+	LimitQUICVersions                   []string `json:"limitQUICVersions"`
+	SshObfuscatedTapDancePort           int      `json:"sshObfuscatedTapdancePort"`
+	SshObfuscatedConjurePort            int      `json:"sshObfuscatedConjurePort"`
+	SshObfuscatedKey                    string   `json:"sshObfuscatedKey"`
+	Capabilities                        []string `json:"capabilities"`
+	Region                              string   `json:"region"`
+	FrontingProviderID                  string   `json:"frontingProviderID"`
+	MeekServerPort                      int      `json:"meekServerPort"`
+	MeekCookieEncryptionPublicKey       string   `json:"meekCookieEncryptionPublicKey"`
+	MeekObfuscatedKey                   string   `json:"meekObfuscatedKey"`
+	MeekFrontingHost                    string   `json:"meekFrontingHost"`
+	MeekFrontingHosts                   []string `json:"meekFrontingHosts"`
+	MeekFrontingDomain                  string   `json:"meekFrontingDomain"`
+	MeekFrontingAddresses               []string `json:"meekFrontingAddresses"`
+	MeekFrontingAddressesRegex          string   `json:"meekFrontingAddressesRegex"`
+	MeekFrontingDisableSNI              bool     `json:"meekFrontingDisableSNI"`
+	TacticsRequestPublicKey             string   `json:"tacticsRequestPublicKey"`
+	TacticsRequestObfuscatedKey         string   `json:"tacticsRequestObfuscatedKey"`
+	ConfigurationVersion                int      `json:"configurationVersion"`
+	Signature                           string   `json:"signature"`
+	DisableHTTPTransforms               bool     `json:"disableHTTPTransforms"`
+	DisableObfuscatedQUICTransforms     bool     `json:"disableObfuscatedQUICTransforms"`
+	DisableOSSHTransforms               bool     `json:"disableOSSHTransforms"`
+	InProxySessionPublicKey             string   `json:"inProxySessionPublicKey"`
+	InProxySessionRootObfuscationSecret string   `json:"inProxySessionRootObfuscationSecret"`
 
 	// These local fields are not expected to be present in downloaded server
 	// entries. They are added by the client to record and report stats about
@@ -647,6 +650,89 @@ func (serverEntry *ServerEntry) GetDialPortNumber(tunnelProtocol string) (int, e
 	return 0, errors.TraceNew("unknown protocol")
 }
 
+// IsValidDialAddress reports whether the given dial destination, specified
+// as a network ("tcp" or "udp"), host, and port, is consistent with the dial
+// parameters of at least one tunnel protocol supported by the server entry.
+//
+// Limitations:
+// - TAPDANCE-OSSH and CONJURE-OSSH are not supported.
+// - The host header is not considered in the case of fronted protocols.
+func (serverEntry *ServerEntry) IsValidDialAddress(
+	networkProtocol string, dialHost string, dialPortNumber int) bool {
+
+	for _, candidateProtocol := range SupportedTunnelProtocols {
+
+		// Skip protocols that the server does not support. Also skip
+		// TAPDANCE-OSSH and CONJURE-OSSH: the TapDance and Conjure
+		// destination addresses are not included in the server entry, so
+		// those dial destinations cannot be validated here.
+
+		if !serverEntry.SupportsProtocol(candidateProtocol) ||
+			TunnelProtocolUsesRefractionNetworking(candidateProtocol) {
+			continue
+		}
+
+		// The network must be "tcp" for TCP-based protocols and "udp" for
+		// all others.
+
+		expectedNetwork := "udp"
+		if TunnelProtocolUsesTCP(candidateProtocol) {
+			expectedNetwork = "tcp"
+		}
+		if networkProtocol != expectedNetwork {
+			continue
+		}
+
+		// A GetDialPortNumber error is silently handled as a mismatch, as
+		// the server entry should be well-formed.
+
+		expectedPortNumber, err := serverEntry.GetDialPortNumber(candidateProtocol)
+		if err != nil || expectedPortNumber != dialPortNumber {
+			continue
+		}
+
+		isValidHost := false
+
+		switch {
+		case !TunnelProtocolUsesFrontedMeek(candidateProtocol):
+
+			// For all direct protocols, the destination host must be the
+			// server IP address.
+
+			isValidHost = dialHost == serverEntry.IpAddress
+
+		case len(serverEntry.MeekFrontingAddressesRegex) > 0:
+
+			// For fronted protocols, the destination host may be a domain
+			// and must match either MeekFrontingAddressesRegex or
+			// MeekFrontingAddresses. As in psiphon.selectFrontingParameters,
+			// MeekFrontingAddressesRegex takes precedence when not empty.
+			//
+			// As the host header value is not checked here, additional
+			// measures must be taken to ensure the destination is a Psiphon
+			// server.
+			//
+			// A regex that fails to compile matches nothing.
+
+			frontingRegex, err := regexp.Compile(serverEntry.MeekFrontingAddressesRegex)
+			if err == nil {
+
+				// The entire dialHost string must match the regex.
+				frontingRegex.Longest()
+				matched := frontingRegex.FindString(dialHost)
+				isValidHost = matched != "" && matched == dialHost
+			}
+
+		default:
+
+			// With no regex, the fronted destination host must be one of
+			// the listed MeekFrontingAddresses.
+
+			isValidHost = common.Contains(serverEntry.MeekFrontingAddresses, dialHost)
+		}
+
+		// When all of the checks pass for this protocol, the input is a valid
+		// dial destination.
+		if isValidHost {
+			return true
+		}
+	}
+
+	return false
+}
+
 // GetSupportedTacticsProtocols returns a list of tunnel protocols,
 // supported by the ServerEntry's capabilities, that may be used
 // for tactics requests.
@@ -699,6 +785,12 @@ func (serverEntry *ServerEntry) GetDiagnosticID() string {
 	return TagToDiagnosticID(serverEntry.Tag)
 }
 
+// SupportsInProxy returns true when the server entry has the in-proxy
+// capability, designating the server to receive connections via in-proxies.
+func (serverEntry *ServerEntry) SupportsInProxy() bool {
+	return serverEntry.hasCapability(CAPABILITY_INPROXY)
+}
+
 // GenerateServerEntryTag creates a server entry tag value that is
 // cryptographically derived from the IP address and web server secret in a
 // way that is difficult to reverse the IP address value from the tag or

+ 77 - 0
psiphon/common/protocol/serverEntry_test.go

@@ -295,3 +295,80 @@ func testServerEntryListSignatures(t *testing.T, setExplicitTag bool) {
 		t.Fatalf("AddSignature unexpectedly succeeded")
 	}
 }
+
+// TestIsValidDialAddress checks IsValidDialAddress against a server entry
+// supporting direct TCP, direct UDP (QUIC), and fronted meek protocols.
+func TestIsValidDialAddress(t *testing.T) {
+
+	serverEntry := &ServerEntry{
+		IpAddress:                  "192.168.0.1",
+		SshPort:                    1,
+		SshObfuscatedPort:          2,
+		SshObfuscatedQUICPort:      3,
+		Capabilities:               []string{"handshake", "SSH", "OSSH", "QUIC", "FRONTED-MEEK"},
+		MeekFrontingAddressesRegex: "[ab]+",
+		MeekServerPort:             443,
+	}
+
+	testCases := []struct {
+		description     string
+		networkProtocol string
+		dialHost        string
+		dialPortNumber  int
+		isValid         bool
+	}{
+		{
+			"valid IP dial",
+			"tcp", "192.168.0.1", 1,
+			true,
+		},
+		{
+			"valid domain dial",
+			"tcp", "aaabbbaaabbb", 443,
+			true,
+		},
+		{
+			// Exercises the UDP path: the QUIC port is dialed over "udp".
+			"valid UDP dial",
+			"udp", "192.168.0.1", 3,
+			true,
+		},
+		{
+			"invalid network dial",
+			"udp", "192.168.0.1", 1,
+			false,
+		},
+		{
+			"invalid IP dial",
+			"tcp", "192.168.0.2", 1,
+			false,
+		},
+		{
+			"invalid domain dial",
+			"tcp", "aaabbbcccbbb", 443,
+			false,
+		},
+		{
+			"invalid port dial",
+			"tcp", "192.168.0.1", 4,
+			false,
+		},
+		{
+			"invalid domain port dial",
+			"tcp", "aaabbbaaabbb", 80,
+			false,
+		},
+		{
+			"invalid domain newline dial",
+			"tcp", "aaabbbaaabbb\nccc", 443,
+			false,
+		},
+	}
+
+	for _, testCase := range testCases {
+		t.Run(testCase.description, func(t *testing.T) {
+			if testCase.isValid != serverEntry.IsValidDialAddress(
+				testCase.networkProtocol, testCase.dialHost, testCase.dialPortNumber) {
+
+				t.Errorf("unexpected IsValidDialAddress result")
+			}
+		})
+	}
+}

+ 22 - 0
psiphon/common/utils.go

@@ -34,6 +34,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/wildcard"
 )
 
@@ -238,3 +239,24 @@ func SleepWithContext(ctx context.Context, duration time.Duration) {
 	case <-ctx.Done():
 	}
 }
+
+// SleepWithJitter blocks until either the jittered duration has elapsed or
+// the input ctx is done, whichever happens first. The jittered duration is
+// computed by prng.JitterDuration from duration and jitter.
+func SleepWithJitter(ctx context.Context, duration time.Duration, jitter float64) {
+	jitteredDuration := prng.JitterDuration(duration, jitter)
+	timer := time.NewTimer(jitteredDuration)
+	defer timer.Stop()
+	select {
+	case <-timer.C:
+	case <-ctx.Done():
+	}
+}
+
+// ValueOrDefault returns defaultValue when value equals the zero value of
+// type T; any other value is returned unchanged.
+func ValueOrDefault[T comparable](value, defaultValue T) T {
+	var zeroValue T
+	if value != zeroValue {
+		return value
+	}
+	return defaultValue
+}