Browse Source

Add resolver and transforms packages

Rod Hynes 3 years ago
parent
commit
b86f6f0f9a

+ 70 - 0
psiphon/common/parameters/parameters.go

@@ -55,6 +55,7 @@ package parameters
 
 import (
 	"encoding/json"
+	"net"
 	"net/http"
 	"reflect"
 	"sync/atomic"
@@ -65,6 +66,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 	"golang.org/x/net/bpf"
 )
 
@@ -301,6 +303,17 @@ const (
 	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
 	UpstreamProxyAllowAllServerEntrySources          = "UpstreamProxyAllowAllServerEntrySources"
 	DestinationBytesMetricsASN                       = "DestinationBytesMetricsASN"
+	DNSResolverAttemptsPerServer                     = "DNSResolverAttemptsPerServer"
+	DNSResolverRequestTimeout                        = "DNSResolverRequestTimeout"
+	DNSResolverAwaitTimeout                          = "DNSResolverAwaitTimeout"
+	DNSResolverPreresolvedIPAddressProbability       = "DNSResolverPreresolvedIPAddressProbability"
+	DNSResolverPreresolvedIPAddressCIDRs             = "DNSResolverPreresolvedIPAddressCIDRs"
+	DNSResolverAlternateServers                      = "DNSResolverAlternateServers"
+	DNSResolverPreferAlternateServerProbability      = "DNSResolverPreferAlternateServerProbability"
+	DNSResolverProtocolTransformProbability          = "DNSResolverProtocolTransformProbability"
+	DNSResolverProtocolTransformSpecs                = "DNSResolverProtocolTransformSpecs"
+	DNSResolverProtocolTransformScopedSpecNames      = "DNSResolverProtocolTransformScopedSpecNames"
+	DNSResolverIncludeEDNS0Probability               = "DNSResolverIncludeEDNS0Probability"
 )
 
 const (
@@ -637,6 +650,18 @@ var defaultParameters = map[string]struct {
 	UpstreamProxyAllowAllServerEntrySources: {value: false},
 
 	DestinationBytesMetricsASN: {value: "", flags: serverSideOnly},
+
+	DNSResolverAttemptsPerServer:                {value: 2, minimum: 1},
+	DNSResolverRequestTimeout:                   {value: 5 * time.Second, minimum: 100 * time.Millisecond, flags: useNetworkLatencyMultiplier},
+	DNSResolverAwaitTimeout:                     {value: 100 * time.Millisecond, minimum: 1 * time.Millisecond, flags: useNetworkLatencyMultiplier},
+	DNSResolverPreresolvedIPAddressProbability:  {value: 0.0, minimum: 0.0},
+	DNSResolverPreresolvedIPAddressCIDRs:        {value: LabeledCIDRs{}},
+	DNSResolverAlternateServers:                 {value: []string{}},
+	DNSResolverPreferAlternateServerProbability: {value: 0.0, minimum: 0.0},
+	DNSResolverProtocolTransformProbability:     {value: 0.0, minimum: 0.0},
+	DNSResolverProtocolTransformSpecs:           {value: transforms.Specs{}},
+	DNSResolverProtocolTransformScopedSpecNames: {value: transforms.ScopedSpecNames{}},
+	DNSResolverIncludeEDNS0Probability:          {value: 0.0, minimum: 0.0},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used
@@ -961,6 +986,14 @@ func (p *Parameters) Set(
 					}
 					return nil, errors.Trace(err)
 				}
+			case LabeledCIDRs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 
 			// Enforce any minimums. Assumes defaultParameters[name]
@@ -1447,3 +1480,40 @@ func (p ParametersAccessor) TunnelProtocolPortLists(name string) TunnelProtocolP
 	p.snapshot.getValue(name, &value)
 	return value
 }
+
+// *TODO* move to other file?
+//
+// LabeledCIDRs is a parameter value type that maps a label to a list of CIDR
+// strings. Each entry is expected to be parseable by net.ParseCIDR.
+type LabeledCIDRs map[string][]string
+
+// Validate checks that every entry under every label parses as a CIDR via
+// net.ParseCIDR. It returns the first parse error encountered, or nil when
+// all entries are well-formed.
+func (c LabeledCIDRs) Validate() error {
+	for _, entries := range c {
+		for i := range entries {
+			if _, _, err := net.ParseCIDR(entries[i]); err != nil {
+				return errors.Trace(err)
+			}
+		}
+	}
+	return nil
+}
+
+// LabeledCIDRs reads the LabeledCIDRs parameter identified by name and
+// returns the list of CIDRs stored under label. The result is nil when the
+// label has no entry.
+func (p ParametersAccessor) LabeledCIDRs(name, label string) []string {
+	labeled := LabeledCIDRs{}
+	p.snapshot.getValue(name, &labeled)
+	selected := labeled[label]
+	return selected
+}
+
+// ProtocolTransformSpecs returns the transforms.Specs value stored for the
+// parameter identified by name.
+func (p ParametersAccessor) ProtocolTransformSpecs(name string) transforms.Specs {
+	specs := transforms.Specs{}
+	p.snapshot.getValue(name, &specs)
+	return specs
+}
+
+// ProtocolTransformScopedSpecNames returns the transforms.ScopedSpecNames
+// value stored for the parameter identified by name.
+func (p ParametersAccessor) ProtocolTransformScopedSpecNames(name string) transforms.ScopedSpecNames {
+	scopedNames := transforms.ScopedSpecNames{}
+	p.snapshot.getValue(name, &scopedNames)
+	return scopedNames
+}

+ 1385 - 0
psiphon/common/resolver/resolver.go

@@ -0,0 +1,1385 @@
+/*
+ * Copyright (c) 2022, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package resolver implements a DNS stub resolver, or DNS client, which
+// resolves domain names.
+//
+// The resolver is Psiphon-specific and oriented towards blocking resistance.
+// See ResolveIP for more details.
+package resolver
+
+import (
+	"context"
+	"encoding/hex"
+	"fmt"
+	"net"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
+	lrucache "github.com/cognusion/go-cache-lru"
+	"github.com/miekg/dns"
+)
+
+// Resolver defaults and limits.
+//
+// resolverCacheDefaultTTL, resolverCacheReapFrequency, and
+// resolverCacheMaxEntries configure the LRU cache created in
+// updateNetworkState. resolverDefaultAttemptsPerServer,
+// resolverDefaultRequestTimeout, and resolverDefaultAwaitTimeout are the
+// ResolveParameters applied when ResolveIP is called with nil params.
+const (
+	resolverCacheDefaultTTL          = 1 * time.Minute
+	resolverCacheReapFrequency       = 1 * time.Minute
+	resolverCacheMaxEntries          = 10000
+	resolverServersUpdateTTL         = 5 * time.Second
+	resolverDefaultAttemptsPerServer = 2
+	resolverDefaultRequestTimeout    = 5 * time.Second
+	resolverDefaultAwaitTimeout      = 100 * time.Millisecond
+	resolverDefaultAnswerTTL         = 1 * time.Minute
+	resolverDNSPort                  = "53"
+	udpPacketBufferSize              = 1232
+)
+
+// NetworkConfig specifies network-level configuration for a Resolver.
+type NetworkConfig struct {
+
+	// GetDNSServers returns a list of system DNS server addresses (IP:port,
+	// or IP only with port 53 assumed), as determined via OS APIs, in
+	// priority order. GetDNSServers may be nil.
+	GetDNSServers func() []string
+
+	// BindToDevice should ensure the input file descriptor, a UDP socket, is
+	// excluded from VPN routing. BindToDevice may be nil.
+	BindToDevice func(fd int) (string, error)
+
+	// IPv6Synthesize should apply NAT64 synthesis to the input IPv4 address,
+	// returning a synthesized IPv6 address that will route to the same
+	// endpoint. IPv6Synthesize may be nil.
+	IPv6Synthesize func(IPv4 string) string
+
+	// LogWarning is an optional callback which is used to log warnings and
+	// transient errors which would otherwise not be recorded or returned.
+	// LogWarning may be nil, in which case warnings are dropped.
+	LogWarning func(error)
+
+	// LogHostnames indicates whether to log hostnames in errors or not.
+	LogHostnames bool
+}
+
+// logWarning forwards err to the optional LogWarning callback; it is a no-op
+// when no callback is configured.
+func (c *NetworkConfig) logWarning(err error) {
+	if c.LogWarning == nil {
+		return
+	}
+	c.LogWarning(err)
+}
+
+// ResolveParameters specifies the configuration and behavior of a single
+// ResolveIP call, a single domain name resolution.
+//
+// New ResolveParameters may be generated by calling MakeResolveParameters,
+// which takes tactics parameters as an input.
+//
+// ResolveParameters may be persisted for replay.
+type ResolveParameters struct {
+
+	// AttemptsPerServer specifies how many requests to send to each DNS
+	// server before trying the next server. IPv4 and IPv6 requests are sent
+	// concurrently and count as one attempt.
+	AttemptsPerServer int
+
+	// RequestTimeout specifies how long to wait for a valid response before
+	// moving on to the next attempt.
+	RequestTimeout time.Duration
+
+	// AwaitTimeout specifies how long to await an additional response after
+	// the first response is received. This additional wait time applies only
+	// when there is no IPv4 or IPv6 response.
+	AwaitTimeout time.Duration
+
+	// PreresolvedIPAddress specifies an IP address result to be used in place
+	// of making a request.
+	PreresolvedIPAddress string
+
+	// AlternateDNSServer specifies an alternate DNS server to be used when
+	// either no system DNS servers are available or when
+	// PreferAlternateDNSServer is set.
+	AlternateDNSServer string
+
+	// PreferAlternateDNSServer indicates whether to prioritize using the
+	// AlternateDNSServer. When set, the AlternateDNSServer is attempted
+	// before any system DNS servers.
+	PreferAlternateDNSServer bool
+
+	// ProtocolTransformName specifies the name associated with
+	// ProtocolTransformSpec and is used for metrics.
+	ProtocolTransformName string
+
+	// ProtocolTransformSpec specifies a transform to apply to the DNS request packet.
+	// See: "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms".
+	//
+	// As transforms operate on strings and DNS requests are binary,
+	// transforms should be expressed using hex characters.
+	//
+	// DNS transforms include strategies discovered by the Geneva team,
+	// https://geneva.cs.umd.edu.
+	ProtocolTransformSpec transforms.Spec
+
+	// ProtocolTransformSeed specifies the seed to use for generating random
+	// data in the ProtocolTransformSpec transform. To replay a transform,
+	// specify the same seed.
+	ProtocolTransformSeed *prng.Seed
+
+	// IncludeEDNS0 indicates whether to include the EDNS(0) UDP maximum
+	// response size extension in DNS requests. The resolver can handle
+	// responses larger than 512 bytes (RFC 1035 maximum) regardless of
+	// whether the extension is included; the extension may be included as
+	// part of appearing similar to other DNS traffic.
+	IncludeEDNS0 bool
+}
+
+// Implementation note: Go's standard net.Resolver supports specifying a
+// custom Dial function. This could be used to implement at least a large
+// subset of the Resolver functionality on top of Go's standard library
+// resolver. However, net.Resolver is limited to using the CGO resolver on
+// Android, https://github.com/golang/go/issues/8877, in which case the
+// custom Dial function is not used. Furthermore, the pure Go resolver in
+// net/dnsclient_unix.go appears to not be used on Windows at this time.
+//
+// Go also provides golang.org/x/net/dns/dnsmessage, a DNS message marshaller,
+// which could potentially be used in place of github.com/miekg/dns.
+
+// Resolver is a DNS stub resolver, or DNS client, which resolves domain
+// names. A Resolver instance maintains a cache, a network state snapshot,
+// and metrics. All ResolveIP calls will share the same cache and state.
+// Multiple concurrent ResolveIP calls are supported.
+type Resolver struct {
+	// networkConfig is set once in NewResolver and is intentionally not
+	// cleared by Stop.
+	networkConfig *NetworkConfig
+
+	// mutex is held by Stop, GetMetrics, and updateNetworkState while they
+	// access the fields below.
+	mutex             sync.Mutex
+	networkID         string
+	hasIPv6Route      bool
+	systemServers     []string
+	lastServersUpdate time.Time
+	cache             *lrucache.Cache
+	metrics           resolverMetrics
+}
+
+// resolverMetrics holds the counters and RTT bounds that GetMetrics
+// summarizes. Stop resets all fields to their zero values.
+type resolverMetrics struct {
+	resolves      int
+	cacheHits     int
+	requestsIPv4  int
+	requestsIPv6  int
+	responsesIPv4 int
+	responsesIPv6 int
+	peakInFlight  int64
+	minRTT        time.Duration
+	maxRTT        time.Duration
+}
+
+// NewResolver creates a new Resolver for the given network configuration and
+// initializes its cache and network state, including system DNS servers,
+// for the specified networkID. An error is returned when that initial
+// network state update fails.
+func NewResolver(networkConfig *NetworkConfig, networkID string) (*Resolver, error) {
+
+	resolver := &Resolver{networkConfig: networkConfig}
+
+	// The first updateNetworkState call creates the cache and populates the
+	// network state.
+	if err := resolver.updateNetworkState(networkID); err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return resolver, nil
+}
+
+// Stop flushes the Resolver cache, resets metrics, and clears the recorded
+// network state. Stop must be called only after ceasing all in-flight
+// ResolveIP goroutines, or else the cache or metrics may repopulate. A
+// Resolver may be resumed after calling Stop, but Update must be called
+// first.
+func (r *Resolver) Stop() {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	// r.networkConfig is deliberately left in place: clearing it could cause
+	// nil pointer dereferences in concurrent ResolveIP calls.
+
+	r.cache.Flush()
+	r.metrics = resolverMetrics{}
+	r.networkID, r.hasIPv6Route, r.systemServers = "", false, nil
+}
+
+// MakeResolveParameters generates ResolveParameters using the input tactics
+// parameters and optional frontingProviderID context.
+func (r *Resolver) MakeResolveParameters(
+	p parameters.ParametersAccessor,
+	frontingProviderID string) (*ResolveParameters, error) {
+
+	params := &ResolveParameters{
+		AttemptsPerServer: p.Int(parameters.DNSResolverAttemptsPerServer),
+		RequestTimeout:    p.Duration(parameters.DNSResolverRequestTimeout),
+		AwaitTimeout:      p.Duration(parameters.DNSResolverAwaitTimeout),
+	}
+
+	// When a frontingProviderID is specified, generate a pre-resolved IP
+	// address, based on tactics configuration.
+	if frontingProviderID != "" {
+		if p.WeightedCoinFlip(parameters.DNSResolverPreresolvedIPAddressProbability) {
+			CIDRs := p.LabeledCIDRs(parameters.DNSResolverPreresolvedIPAddressCIDRs, frontingProviderID)
+			if len(CIDRs) > 0 {
+				CIDR := CIDRs[prng.Intn(len(CIDRs))]
+				IP, err := generateIPAddressFromCIDR(CIDR)
+				if err != nil {
+					return nil, errors.Trace(err)
+				}
+				params.PreresolvedIPAddress = IP.String()
+			}
+		}
+	}
+
+	// When PreresolvedIPAddress is set, there's no DNS request and the
+	// following params can be skipped.
+	if params.PreresolvedIPAddress != "" {
+		return params, nil
+	}
+
+	// Select an alternate DNS server, typically a public DNS server. Ensure
+	// tactics is configured with an empty DNSResolverAlternateServers list
+	// in cases where attempts to public DNS server are unwanted.
+	alternateServers := p.Strings(parameters.DNSResolverAlternateServers)
+	if len(alternateServers) > 0 {
+		params.AlternateDNSServer = alternateServers[prng.Intn(len(alternateServers))]
+	}
+	params.PreferAlternateDNSServer = p.WeightedCoinFlip(
+		parameters.DNSResolverPreferAlternateServerProbability)
+
+	// Select a DNS transform. DNS request transforms are "scoped" by
+	// alternate DNS server; that is, when an alternate DNS server is certain
+	// to be attempted first, a transform associated with and known to work
+	// with that DNS server will be selected. Otherwise, a transform from the
+	// default scope (transforms.SCOPE_ANY == "") is selected.
+	//
+	// In any case, ResolveIP will only apply a transform on the first request
+	// attempt.
+	if p.WeightedCoinFlip(parameters.DNSResolverProtocolTransformProbability) {
+
+		specs := p.ProtocolTransformSpecs(
+			parameters.DNSResolverProtocolTransformSpecs)
+		scopedSpecNames := p.ProtocolTransformScopedSpecNames(
+			parameters.DNSResolverProtocolTransformScopedSpecNames)
+
+		// The alternate DNS server will be the first attempt if
+		// PreferAlternateDNSServer or the list of system DNS servers is empty.
+		//
+		// Limitation: the system DNS server list may change, due to a later
+		// Resolver.update call when ResolveIP is called with these
+		// ResolveParameters.
+		_, systemServers := r.getNetworkState()
+		scope := transforms.SCOPE_ANY
+		if params.PreferAlternateDNSServer || len(systemServers) == 0 {
+			scope = params.AlternateDNSServer
+		}
+
+		name, spec := specs.Select(scope, scopedSpecNames)
+
+		if spec != nil {
+			params.ProtocolTransformName = name
+			params.ProtocolTransformSpec = spec
+			var err error
+			params.ProtocolTransformSeed, err = prng.NewSeed()
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+		}
+	}
+
+	if p.WeightedCoinFlip(parameters.DNSResolverIncludeEDNS0Probability) {
+		params.IncludeEDNS0 = true
+	}
+
+	return params, nil
+}
+
+// ResolveAddress resolves the host portion of a host:port address. It splits
+// the address, resolves the host with ResolveIP, picks one of the returned
+// IPs at random, and returns that IP rejoined with the original port.
+func (r *Resolver) ResolveAddress(
+	ctx context.Context,
+	networkID string,
+	params *ResolveParameters,
+	address string) (string, error) {
+
+	host, port, err := net.SplitHostPort(address)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	resolved, err := r.ResolveIP(ctx, networkID, params, host)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	selected := resolved[prng.Intn(len(resolved))]
+
+	return net.JoinHostPort(selected.String(), port), nil
+}
+
+// ResolveIP resolves a domain name.
+//
+// The input params may be nil, in which case default timeouts are used.
+//
+// ResolveIP performs concurrent A and AAAA lookups, returns any valid
+// response IPs, and caches results.
+//
+// ResolveIP is not a general purpose resolver and is Psiphon-specific. For
+// example, resolved domains are expected to exist; ResolveIP does not
+// fallback to TCP; does not consult any "hosts" file; does not perform RFC
+// 3484 sorting logic (see Go issue 18518); only implements a subset of
+// Go/glibc/resolv.conf(5) resolver parameters (attempts and timeouts, but
+// not rotate, single-request, etc.). ResolveIP does not implement singleflight
+// logic, as the Go resolver does, and allows multiple concurrent requests for
+// the same domain -- Psiphon won't often resolve the exact same domain
+// multiple times concurrently, and, when it does, there's a circumvention
+// benefit to attempting different DNS servers and protocol transforms.
+//
+// ResolveIP does not currently support DoT, DoH, or TCP; those protocols are
+// often blocked or less common. Instead, ResolveIP makes a best effort to
+// evade plaintext UDP DNS interference by ignoring invalid responses and by
+// optionally applying protocol transforms that may evade blocking.
+func (r *Resolver) ResolveIP(
+	ctx context.Context,
+	networkID string,
+	params *ResolveParameters,
+	hostname string) ([]net.IP, error) {
+
+	// ResolveIP does _not_ lock r.mutex for the lifetime of the function, to
+	// ensure many ResolveIP calls can run concurrently.
+
+	// Call updateNetworkState immediately before resolving, as a best effort
+	// to ensure that system DNS servers and IPv6 routing network state
+	// reflects the current network. updateNetworkState locks the Resolver
+	// mutex for its duration, and so concurrent ResolveIP calls may block at
+	// this point. However, all updateNetworkState operations are local to
+	// the host or device; and, if the networkID is unchanged since the last
+	// call, updateNetworkState may not perform any operations; and after the
+	// updateNetworkState call, ResolveIP proceeds without holding the mutex
+	// lock. As a result, this step should not prevent ResolveIP concurrency.
+	err := r.updateNetworkState(networkID)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if params == nil {
+		// Supply default ResolveParameters
+		params = &ResolveParameters{
+			AttemptsPerServer: resolverDefaultAttemptsPerServer,
+			RequestTimeout:    resolverDefaultRequestTimeout,
+			AwaitTimeout:      resolverDefaultAwaitTimeout,
+		}
+	}
+
+	// If the hostname is already an IP address, just return that. For
+	// metrics, this does not count as a resolve, as the caller may invoke
+	// ResolveIP for all dials.
+	IP := net.ParseIP(hostname)
+	if IP != nil {
+		return []net.IP{IP}, nil
+	}
+
+	// Count all resolves of an actual domain, including cached and
+	// pre-resolved cases.
+	r.updateMetricResolves()
+
+	// When PreresolvedIPAddress is set, tactics parameters determined the IP address
+	// in this case.
+	if params.PreresolvedIPAddress != "" {
+		IP := net.ParseIP(params.PreresolvedIPAddress)
+		if IP == nil {
+			// Unexpected case, as MakeResolveParameters selects the IP address.
+			return nil, errors.TraceNew("invalid IP address")
+		}
+		return []net.IP{IP}, nil
+	}
+
+	// Use a snapshot of the current network state, including IPv6 routing and
+	// system DNS servers.
+	//
+	// Limitation: these values are used even if the network changes in the
+	// middle of a ResolveIP call; ResolveIP is not interrupted if the
+	// network changes.
+	hasIPv6Route, systemServers := r.getNetworkState()
+
+	// Use the standard library resolver when there's no GetDNSServers, or the
+	// system server list is otherwise empty, and no alternate DNS server is
+	// configured.
+	//
+	// Note that in the case where there are no system DNS servers and there
+	// is an AlternateDNSServer, if the AlternateDNSServer attempt fails,
+	// control does not flow back to defaultResolverLookupIP. On platforms
+	// without GetDNSServers, the caller must arrange for distinct attempts
+	// that try an AlternateDNSServer, or just use the standard library
+	// resolver.
+	//
+	// ResolveIP should always be called, even when defaultResolverLookupIP
+	// will be used, to ensure correct metrics counts and ensure a consistent
+	// error message log stack for all DNS-related failures.
+	if len(systemServers) == 0 && params.AlternateDNSServer == "" {
+		IPs, err := defaultResolverLookupIP(ctx, hostname, r.networkConfig.LogHostnames)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		return IPs, err
+	}
+
+	// Consult the cache before making queries. This comes after the standard
+	// library case, to allow the standard library to provide its own caching
+	// logic.
+	IPs := r.getCache(hostname)
+	if IPs != nil {
+		return IPs, nil
+	}
+
+	// Set the list of DNS servers to attempt. AlternateDNSServer is used
+	// first when PreferAlternateDNSServer is set; otherwise
+	// AlternateDNSServer is used only when there is no system DNS server.
+	var servers []string
+	if params.AlternateDNSServer != "" &&
+		(len(systemServers) == 0 || params.PreferAlternateDNSServer) {
+		servers = []string{params.AlternateDNSServer}
+	}
+	servers = append(servers, systemServers...)
+	if len(servers) == 0 {
+		return nil, errors.TraceNew("no DNS servers")
+	}
+
+	// At least one attempt per server is required. Without this check, a
+	// zero value would skip the attempts loop entirely, leaving the timer
+	// below uninitialized, and a negative value would produce an invalid
+	// channel size.
+	if params.AttemptsPerServer < 1 {
+		return nil, errors.TraceNew("invalid attempts per server")
+	}
+
+	// Set the request timeout and set up a reusable timer for handling
+	// request and await timeouts.
+	//
+	// We expect to always have a request timeout. Handle the unexpected no
+	// timeout, 0, case by setting the longest timeout possible, ~290 years;
+	// always having a non-zero timeout makes the following code marginally
+	// simpler.
+	requestTimeout := params.RequestTimeout
+	if requestTimeout == 0 {
+		requestTimeout = 1<<63 - 1
+	}
+	var timer *time.Timer
+	timerDrained := true
+	resetTimer := func(timeout time.Duration) {
+		if timer == nil {
+			timer = time.NewTimer(timeout)
+		} else {
+			// Before Reset, the timer must be stopped and, if it already
+			// fired without its channel being read, drained.
+			if !timerDrained && !timer.Stop() {
+				<-timer.C
+			}
+			timer.Reset(timeout)
+		}
+		timerDrained = false
+	}
+
+	// Orchestrate the DNS requests
+
+	resolveCtx, cancelFunc := context.WithCancel(ctx)
+	defer cancelFunc()
+	waitGroup := new(sync.WaitGroup)
+	conns := common.NewConns()
+	type answer struct {
+		IPs  []net.IP
+		TTLs []time.Duration
+	}
+	maxAttempts := len(servers) * params.AttemptsPerServer
+	// Each attempt sends at most two requests (A and AAAA), so this buffer
+	// size ensures that worker sends never block.
+	answerChan := make(chan *answer, maxAttempts*2)
+	inFlight := int64(0)
+	awaitA := int32(1)
+	awaitAAAA := int32(1)
+	if !hasIPv6Route {
+		awaitAAAA = 0
+	}
+	var result *answer
+	var lastErr atomic.Value
+
+	stop := false
+	for i := 0; !stop && i < maxAttempts; i++ {
+
+		// Limitation: AttemptsPerServer applies for all servers, including
+		// the AlternateDNSServer. So in the PreferAlternateDNSServer case,
+		// that many attempts are made before falling back to system DNS servers.
+
+		server := servers[i/params.AttemptsPerServer]
+
+		// Only the first attempt pair tries transforms, as it's not certain
+		// the transforms will be compatible with DNS servers.
+		useProtocolTransform := (i == 0 && params.ProtocolTransformSpec != nil)
+
+		// Send A and AAAA requests concurrently.
+		questionTypes := []resolverQuestionType{resolverQuestionTypeA, resolverQuestionTypeAAAA}
+		if !hasIPv6Route {
+			questionTypes = questionTypes[0:1]
+		}
+
+		for _, questionType := range questionTypes {
+
+			waitGroup.Add(1)
+
+			// For metrics, track peak concurrent in-flight requests for
+			// a _single_ ResolveIP. inFlight for this ResolveIP is also used
+			// to determine whether to await additional responses once the
+			// first, valid response is received. For that logic to be
+			// correct, we must increment inFlight in this outer goroutine to
+			// ensure the await logic sees either inFlight > 0 or an answer
+			// in the channel.
+			r.updateMetricPeakInFlight(atomic.AddInt64(&inFlight, 1))
+
+			go func(questionType resolverQuestionType, useProtocolTransform bool) {
+				defer waitGroup.Done()
+
+				// We must decrement inFlight only after sending an answer and
+				// setting awaitA or awaitAAAA to ensure that the await logic
+				// in the outer goroutine will see inFlight 0 only once those
+				// operations are complete.
+				//
+				// We cannot wait and decrement inFlight when the outer
+				// goroutine receives answers, as no answer is sent in some
+				// cases, such as when the resolve fails due to NXDOMAIN.
+				defer atomic.AddInt64(&inFlight, -1)
+
+				// The request count metric counts the _intention_ to send
+				// requests, as there's a possibility that newResolverConn or
+				// performDNSQuery fail locally before sending a request packet.
+				switch questionType {
+				case resolverQuestionTypeA:
+					r.updateMetricRequestsIPv4()
+				case resolverQuestionTypeAAAA:
+					r.updateMetricRequestsIPv6()
+				}
+
+				// While it's possible, and potentially more optimal, to
+				// use the same UDP socket for both the A and AAAA request,
+				// we use a distinct socket per request, as common DNS clients do.
+				conn, err := r.newResolverConn(server)
+				if err != nil {
+					if resolveCtx.Err() == nil {
+						lastErr.Store(errors.Trace(err))
+					}
+					return
+				}
+				defer conn.Close()
+
+				// There's no context.Context support in the underlying API
+				// used by performDNSQuery, so instead collect all the
+				// request conns so that they can be closed, and any blocking
+				// network I/O interrupted, below, if resolveCtx is done.
+				if !conns.Add(conn) {
+					// Add fails when conns is already closed.
+					return
+				}
+
+				// performDNSQuery will send the request and read a response.
+				// performDNSQuery will continue reading responses until it
+				// receives a valid response, which can mitigate a subset of
+				// DNS injection attacks (to the limited extent possible for
+				// plaintext DNS).
+				//
+				// For IPv4, NXDOMAIN or a response with no IPs is not
+				// expected for domains resolved by Psiphon, so
+				// performDNSQuery treats such a response as invalid. For
+				// IPv6, a response with no IPs may be valid (even though the
+				// response could be forged); the resolver will continue its
+				// attempts loop if it has no other IPs.
+				//
+				// Each performDNSQuery has no timeout and runs
+				// until it has read a valid response or the resolveCtx is
+				// done. This allows for slow arriving, valid responses to
+				// eventually succeed, even if the read time exceeds
+				// requestTimeout, as long as the read time is less than the
+				// resolveCtx timeout.
+				//
+				// With this approach, the overall ResolveIP call may have
+				// more than 2 performDNSQuery requests in-flight at a time,
+				// as requestTimeout is used to schedule sending the next
+				// attempt but not cancel the current attempt. For
+				// connectionless UDP, the resulting network traffic should
+				// be similar to common DNS clients which do cancel requests
+				// before beginning the next attempt.
+				IPs, TTLs, RTT, err := performDNSQuery(
+					resolveCtx,
+					r.networkConfig.logWarning,
+					params,
+					useProtocolTransform,
+					conn,
+					questionType,
+					hostname)
+
+				// Update the min/max RTT metric as long as the reported RTT >
+				// 0, even if the result is an error -- i.e., even if
+				// there was an invalid response -- unless the resolveCtx is
+				// done (we don't want to consider the RTT in the case of
+				// cancellation). This assumes no actual RTT will be 0 nanoseconds.
+				//
+				// Limitation: since individual requests aren't cancelled
+				// after requestTimeout, RTT metrics won't reflect
+				// no-response cases, although request and response count
+				// disparities will still show up in the metrics.
+				if RTT > 0 && resolveCtx.Err() == nil {
+					r.updateMetricRTT(RTT)
+				}
+
+				if err != nil {
+					if resolveCtx.Err() == nil {
+						lastErr.Store(errors.Trace(err))
+					}
+					return
+				}
+
+				if len(IPs) > 0 {
+					select {
+					case answerChan <- &answer{IPs: IPs, TTLs: TTLs}:
+					default:
+					}
+				}
+
+				// Mark no longer awaiting A or AAAA as long as there is a
+				// valid response, even if there are no IPs in the IPv6 case.
+				switch questionType {
+				case resolverQuestionTypeA:
+					r.updateMetricResponsesIPv4()
+					atomic.StoreInt32(&awaitA, 0)
+				case resolverQuestionTypeAAAA:
+					r.updateMetricResponsesIPv6()
+					atomic.StoreInt32(&awaitAAAA, 0)
+				default:
+				}
+
+			}(questionType, useProtocolTransform)
+		}
+
+		resetTimer(requestTimeout)
+
+		select {
+		case result = <-answerChan:
+			// When the first answer, a response with valid IPs, arrives, exit
+			// the attempts loop. The following await branch may collect
+			// additional answers.
+			stop = true
+		case <-timer.C:
+			// When requestTimeout arrives, loop around and launch the next
+			// attempt; leave the existing requests running in case they
+			// eventually respond.
+			lastErr.Store(errors.TraceNew("timeout"))
+			timerDrained = true
+		case <-resolveCtx.Done():
+			// When resolveCtx is done, exit the attempts loop.
+			//
+			// TODO: retain previous lastErr, for failed_tunnel, if it
+			// provides more information about DNS interference? As part of
+			// this, have performDNSQuery also return relevant errors instead
+			// of simply calling LogWarning for invalid DNS responses.
+			lastErr.Store(errors.Trace(ctx.Err()))
+			stop = true
+		}
+	}
+
+	// Receive any additional answers, now present in the channel, which
+	// arrived concurrent with the first answer. This receive avoids a race
+	// condition where inFlight may now be 0, with additional answers
+	// enqueued, in which case the following await branch is not taken.
+	//
+	// The attempts loop may exit with no received answer, due to timeouts or
+	// cancellation, while an answer is concurrently sent to the channel. In
+	// that case result is nil; the enqueued answers are ignored and this
+	// remains a failed resolve.
+	if result != nil {
+		for drained := false; !drained; {
+			select {
+			case nextAnswer := <-answerChan:
+				result.IPs = append(result.IPs, nextAnswer.IPs...)
+				result.TTLs = append(result.TTLs, nextAnswer.TTLs...)
+			default:
+				drained = true
+			}
+		}
+	}
+
+	// When we have an answer, await -- for a short time,
+	// params.AwaitTimeout -- extra answers from any remaining in-flight
+	// requests. Only await if the request isn't cancelled and we don't
+	// already have at least one IPv4 and one IPv6 response; only await AAAA
+	// if it was sent; note that a valid AAAA response may include no IPs.
+	// lastErr is not set in timeout/cancelled cases here, since we already
+	// have an answer.
+	//
+	// All conditions must hold, including result != nil, as the await loop
+	// appends to result.
+	if result != nil &&
+		resolveCtx.Err() == nil &&
+		atomic.LoadInt64(&inFlight) > 0 &&
+		(atomic.LoadInt32(&awaitA) != 0 || atomic.LoadInt32(&awaitAAAA) != 0) &&
+		params.AwaitTimeout > 0 {
+
+		resetTimer(params.AwaitTimeout)
+
+		for {
+
+			stop := false
+			select {
+			case nextAnswer := <-answerChan:
+				result.IPs = append(result.IPs, nextAnswer.IPs...)
+				result.TTLs = append(result.TTLs, nextAnswer.TTLs...)
+			case <-timer.C:
+				timerDrained = true
+				stop = true
+			case <-resolveCtx.Done():
+				stop = true
+			}
+
+			if stop ||
+				atomic.LoadInt64(&inFlight) == 0 ||
+				(atomic.LoadInt32(&awaitA) == 0 && atomic.LoadInt32(&awaitAAAA) == 0) {
+				break
+			}
+		}
+	}
+
+	timer.Stop()
+
+	// Interrupt all workers.
+	cancelFunc()
+	conns.CloseAll()
+	waitGroup.Wait()
+
+	// When there's no answer, return the last error.
+	if result == nil {
+		err := lastErr.Load()
+		if err == nil {
+			err = errors.TraceNew("missing error")
+		}
+		if r.networkConfig.LogHostnames {
+			err = fmt.Errorf("resolve %s : %w", hostname, err.(error))
+		}
+		return nil, errors.Trace(err.(error))
+	}
+
+	// Update the cache now, after all results are gathered.
+	r.setCache(hostname, result.IPs, result.TTLs)
+
+	return result.IPs, nil
+}
+
+// GetMetrics returns a summary of DNS metrics.
+func (r *Resolver) GetMetrics() string {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	return fmt.Sprintf("resolves %d | hit %d | req v4/v6 %d/%d | resp %d/%d | peak %d | rtt %d - %d ms.",
+		r.metrics.resolves,
+		r.metrics.cacheHits,
+		r.metrics.requestsIPv4,
+		r.metrics.requestsIPv6,
+		r.metrics.responsesIPv4,
+		r.metrics.responsesIPv6,
+		r.metrics.peakInFlight,
+		r.metrics.minRTT/time.Millisecond,
+		r.metrics.maxRTT/time.Millisecond)
+}
+
+// updateNetworkState updates the system DNS server list, IPv6 state, and the
+// cache.
+func (r *Resolver) updateNetworkState(networkID string) error {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	// Only perform blocking/expensive update operations when necessary.
+	updateAll := false
+	updateIPv6Route := false
+	updateServers := false
+	flushCache := false
+
+	// If r.cache is nil, this is the first update call in NewResolver. Create
+	// the cache and perform all updates.
+	if r.cache == nil {
+		r.cache = lrucache.NewWithLRU(
+			resolverCacheDefaultTTL,
+			resolverCacheReapFrequency,
+			resolverCacheMaxEntries)
+		updateAll = true
+	}
+
+	// Perform all updates when the networkID has changed, which indicates a
+	// different network.
+	if r.networkID != networkID {
+		updateAll = true
+	}
+
+	if updateAll {
+		updateIPv6Route = true
+		updateServers = true
+		flushCache = true
+	}
+
+	// Even when the networkID has not changed, update DNS servers
+	// periodically. This is similar to how other DNS clients
+	// poll /etc/resolv.conf, including the period of 5s.
+	if time.Since(r.lastServersUpdate) > resolverServersUpdateTTL {
+		updateServers = true
+	}
+
+	// Update hasIPv6Route, which indicates whether the current network has an
+	// IPv6 route and so if DNS requests for AAAA records will be sent.
+	// There's no use for AAAA records on IPv4-only networks; and other
+	// common DNS clients omit AAAA requests on IPv4-only records, so these
+	// requests would otherwise be unusual.
+	//
+	// There's no hasIPv4Route as we always need to resolve A records,
+	// particularly for IPv4-only endpoints; for IPv6-only networks,
+	// NetworkConfig.IPv6Synthesize should be used to accomodate IPv4 DNS
+	// server addresses, and dials performed outside the Resolver will
+	// similarly use NAT 64 (on iOS; on Android, 464XLAT will handle this
+	// transparently).
+	if updateIPv6Route {
+		hasIPv6Route, err := hasRoutableIPv6Interface()
+		if err != nil {
+			// Log warning and proceed without IPv6.
+			r.networkConfig.logWarning(
+				errors.Tracef("unable to determine IPv6 route: %v", err))
+			hasIPv6Route = false
+		}
+		r.hasIPv6Route = hasIPv6Route
+	}
+
+	// Update the list of system DNS servers. It's not an error condition here
+	// if the list is empty: a subsequent ResolveIP may use
+	// ResolveParameters which specifies an AlternateDNSServer.
+	if updateServers && r.networkConfig.GetDNSServers != nil {
+
+		servers := r.networkConfig.GetDNSServers()
+
+		// Make a copy; don't rely on the slice from GetDNSServers being stable.
+		systemServers := append([]string(nil), servers...)
+
+		for i, systemServer := range systemServers {
+			host, _, err := net.SplitHostPort(systemServer)
+			if err != nil {
+				// Assume the SplitHostPort error is due to systemServer being
+				// an IP only, and append the default port, 53. If
+				// systemServer _isn't_ an IP, the following ParseIP will fail.
+				systemServers[i] = net.JoinHostPort(systemServer, resolverDNSPort)
+				host = systemServer
+			}
+			if net.ParseIP(host) == nil {
+				// Log warning and proceed without this DNS server.
+				r.networkConfig.logWarning(
+					errors.TraceNew("invalid DNS server IP address"))
+			}
+		}
+
+		// Check if the list of servers has changed, including order. If
+		// changed, flush the cache even if the networkID has not changed.
+		// Cached results are only considered valid as long as the system DNS
+		// configuration remains the same.
+		equal := len(r.systemServers) == len(systemServers)
+		if equal {
+			for i := 0; i < len(r.systemServers); i++ {
+				if r.systemServers[i] != systemServers[i] {
+					equal = false
+					break
+				}
+			}
+		}
+		flushCache = flushCache || !equal
+
+		// Concurrency note: once the r.systemServers slice is set, the
+		// contents of the backing array must not be modified due to
+		// concurrent ResolveIP calls.
+		r.systemServers = systemServers
+
+		r.lastServersUpdate = time.Now()
+	}
+
+	if flushCache {
+		r.cache.Flush()
+	}
+
+	// Set r.networkID only after all operations complete without errors; if
+	// r.networkID were set earlier, a subsequent
+	// ResolveIP/updateNetworkState call might proceed as if the network
+	// state were updated for the specified network ID.
+	r.networkID = networkID
+
+	return nil
+}
+
+// getNetworkState returns the current IPv6 route flag and the current system
+// DNS server list, both read while holding the resolver mutex.
+func (r *Resolver) getNetworkState() (bool, []string) {
+	r.mutex.Lock()
+	hasIPv6Route, systemServers := r.hasIPv6Route, r.systemServers
+	r.mutex.Unlock()
+
+	return hasIPv6Route, systemServers
+}
+
+// setCache stores the resolved IPs for hostname in the cache, under the
+// resolver mutex.
+//
+// The cache entry TTL is the shortest positive TTL in TTLs, bounded above by
+// resolverDefaultAnswerTTL. When TTLs is empty, or every TTL is zero,
+// resolverDefaultAnswerTTL is used.
+func (r *Resolver) setCache(hostname string, IPs []net.IP, TTLs []time.Duration) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	// In case a TTL is zero, use the default: zero TTLs are skipped so that
+	// they cannot replace the default or a valid, positive TTL. Otherwise,
+	// the shortest TTL is used.
+	TTL := resolverDefaultAnswerTTL
+	for _, answerTTL := range TTLs {
+		if answerTTL > 0 && answerTTL < TTL {
+			TTL = answerTTL
+		}
+	}
+
+	// Limitation: with concurrent ResolveIPs for the same domain, the last
+	// setCache call determines the cache value. The results are not merged.
+
+	r.cache.Set(hostname, IPs, TTL)
+}
+
+// getCache looks up hostname in the cache, under the resolver mutex. On a
+// hit, the cacheHits metric is incremented and the cached IPs are returned;
+// on a miss, nil is returned.
+func (r *Resolver) getCache(hostname string) []net.IP {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	cached, found := r.cache.Get(hostname)
+	if found {
+		r.metrics.cacheHits++
+		return cached.([]net.IP)
+	}
+
+	return nil
+}
+
+// newResolverConn creates a UDP socket that will send packets to serverAddr.
+// serverAddr is an IP:port, which allows specifying the port for testing or
+// in rare cases where the port isn't 53.
+//
+// When NetworkConfig.IPv6Synthesize is set, an IPv4 server address may be
+// replaced with a synthesized IPv6 address; in that mode, a serverAddr that
+// cannot be split into host and port is an error. When
+// NetworkConfig.BindToDevice is set, it is invoked with the socket's file
+// descriptor via the dialer's Control hook.
+func (r *Resolver) newResolverConn(serverAddr string) (net.Conn, error) {
+
+	// When configured, attempt to synthesize an IPv6 address from
+	// an IPv4 address for compatibility on DNS64/NAT64 networks.
+	// If synthesize fails, try the original address.
+	if r.networkConfig.IPv6Synthesize != nil {
+		serverIPStr, port, err := net.SplitHostPort(serverAddr)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		serverIP := net.ParseIP(serverIPStr)
+		// Only IPv4 addresses are candidates for synthesis.
+		if serverIP != nil && serverIP.To4() != nil {
+			synthesized := r.networkConfig.IPv6Synthesize(serverIPStr)
+			// Accept the synthesized value only when it is a parseable IP.
+			if synthesized != "" && net.ParseIP(synthesized) != nil {
+				serverAddr = net.JoinHostPort(synthesized, port)
+			}
+		}
+	}
+
+	dialer := &net.Dialer{}
+	if r.networkConfig.BindToDevice != nil {
+		dialer.Control = func(_, _ string, c syscall.RawConn) error {
+			// controlErr carries a BindToDevice failure out of the inner
+			// callback, which itself cannot return an error.
+			var controlErr error
+			err := c.Control(func(fd uintptr) {
+				_, err := r.networkConfig.BindToDevice(int(fd))
+				if err != nil {
+					controlErr = errors.Tracef("BindToDevice failed: %s", err)
+					return
+				}
+			})
+			// A BindToDevice failure takes precedence over any error from
+			// c.Control itself.
+			if controlErr != nil {
+				return errors.Trace(controlErr)
+			}
+			return errors.Trace(err)
+		}
+	}
+
+	// context.Background is ok in this case as the UDP dial is just a local
+	// syscall to create the socket.
+	conn, err := dialer.DialContext(context.Background(), "udp", serverAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return conn, nil
+}
+
+// updateMetricResolves increments the resolves counter, under the resolver
+// mutex.
+func (r *Resolver) updateMetricResolves() {
+	r.mutex.Lock()
+	r.metrics.resolves++
+	r.mutex.Unlock()
+}
+
+// updateMetricRequestsIPv4 increments the requestsIPv4 counter, under the
+// resolver mutex.
+func (r *Resolver) updateMetricRequestsIPv4() {
+	r.mutex.Lock()
+	r.metrics.requestsIPv4++
+	r.mutex.Unlock()
+}
+
+// updateMetricRequestsIPv6 increments the requestsIPv6 counter, under the
+// resolver mutex.
+func (r *Resolver) updateMetricRequestsIPv6() {
+	r.mutex.Lock()
+	r.metrics.requestsIPv6++
+	r.mutex.Unlock()
+}
+
+// updateMetricResponsesIPv4 increments the responsesIPv4 counter, under the
+// resolver mutex.
+func (r *Resolver) updateMetricResponsesIPv4() {
+	r.mutex.Lock()
+	r.metrics.responsesIPv4++
+	r.mutex.Unlock()
+}
+
+// updateMetricResponsesIPv6 increments the responsesIPv6 counter, under the
+// resolver mutex.
+func (r *Resolver) updateMetricResponsesIPv6() {
+	r.mutex.Lock()
+	r.metrics.responsesIPv6++
+	r.mutex.Unlock()
+}
+
+// updateMetricPeakInFlight records inFlight as the new peakInFlight metric
+// when it exceeds the current peak. The resolver mutex is held for the
+// update.
+func (r *Resolver) updateMetricPeakInFlight(inFlight int64) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if inFlight <= r.metrics.peakInFlight {
+		return
+	}
+	r.metrics.peakInFlight = inFlight
+}
+
+// updateMetricRTT folds rtt into the minRTT and maxRTT metrics, under the
+// resolver mutex.
+func (r *Resolver) updateMetricRTT(rtt time.Duration) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	// A minRTT of 0 is treated as "not yet set"; this assumes no actual RTT
+	// will be 0 nanoseconds.
+	minUnset := r.metrics.minRTT == 0
+	if minUnset || rtt < r.metrics.minRTT {
+		r.metrics.minRTT = rtt
+	}
+
+	if r.metrics.maxRTT < rtt {
+		r.metrics.maxRTT = rtt
+	}
+}
+
+func hasRoutableIPv6Interface() (bool, error) {
+
+	interfaces, err := net.Interfaces()
+	if err != nil {
+		return false, errors.Trace(err)
+	}
+
+	for _, in := range interfaces {
+
+		if (in.Flags&net.FlagUp == 0) ||
+			(in.Flags&(net.FlagLoopback|net.FlagPointToPoint)) != 0 {
+			continue
+		}
+
+		addrs, err := in.Addrs()
+		if err != nil {
+			return false, errors.Trace(err)
+		}
+
+		for _, addr := range addrs {
+			if IPNet, ok := addr.(*net.IPNet); ok &&
+				IPNet.IP.To4() == nil &&
+				!IPNet.IP.IsLinkLocalUnicast() {
+
+				return true, nil
+			}
+		}
+	}
+
+	return false, nil
+}
+
+func generateIPAddressFromCIDR(CIDR string) (net.IP, error) {
+	_, IPNet, err := net.ParseCIDR(CIDR)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	// A retry is required, since a CIDR may include broadcast IPs (a.b.c.0) or
+	// other invalid values. The number of retries is limited to ensure we
+	// don't hanging case of a misconfiguration
+	for i := 0; i < 10; i++ {
+		randBytes := prng.Bytes(len(IPNet.IP))
+		IP := make(net.IP, len(IPNet.IP))
+		// The 1 bits in the mask must apply to the IP in the CIDR and the 0
+		// bits in the mask are available to randomize.
+		for i := 0; i < len(IP); i++ {
+			IP[i] = (IPNet.IP[i] & IPNet.Mask[i]) | (randBytes[i] & ^IPNet.Mask[i])
+		}
+		if IP.IsGlobalUnicast() && !common.IsBogon(IP) {
+			return IP, nil
+		}
+	}
+	return nil, errors.TraceNew("failed to generate random IP")
+}
+
+// resolverQuestionType identifies which DNS record type a query requests.
+type resolverQuestionType int
+
+// NOTE(review): these constants are declared without an explicit type, so
+// they are untyped integer constants rather than resolverQuestionType values.
+const (
+	// resolverQuestionTypeA selects an A (IPv4 address) query.
+	resolverQuestionTypeA    = 0
+	// resolverQuestionTypeAAAA selects an AAAA (IPv6 address) query.
+	resolverQuestionTypeAAAA = 1
+)
+
+// performDNSQuery sends a single A or AAAA request for hostname over conn
+// and reads responses until a valid one arrives or resolveCtx is done.
+//
+// When useProtocolTransform is set, conn must be a *net.UDPConn and params
+// must specify both a transform spec and seed; writes are then passed
+// through the protocol transform. When params.IncludeEDNS0 is set, an
+// EDNS(0) OPT record is added to the request.
+//
+// Responses that fail to unmarshal, have a mismatched ID, have an unexpected
+// RCode, or contain no valid IPs (except for a valid, empty AAAA response)
+// are reported via logWarning and reading continues.
+//
+// Returns the answer IPs, the corresponding TTLs, the elapsed time since the
+// request was sent, and an error. conn is closed on return, except when the
+// protocol transform configuration checks fail.
+func performDNSQuery(
+	resolveCtx context.Context,
+	logWarning func(error),
+	params *ResolveParameters,
+	useProtocolTransform bool,
+	conn net.Conn,
+	questionType resolverQuestionType,
+	hostname string) ([]net.IP, []time.Duration, time.Duration, error) {
+
+	if useProtocolTransform {
+		if params.ProtocolTransformSpec == nil ||
+			params.ProtocolTransformSeed == nil {
+			return nil, nil, 0, errors.TraceNew("invalid protocol transform configuration")
+		}
+		// miekg/dns expects conn to be a net.PacketConn or else it writes the
+		// TCP length prefix
+		udpConn, ok := conn.(*net.UDPConn)
+		if !ok {
+			return nil, nil, 0, errors.TraceNew("conn is not a *net.UDPConn")
+		}
+		conn = &transformDNSPacketConn{
+			UDPConn:   udpConn,
+			transform: params.ProtocolTransformSpec,
+			seed:      params.ProtocolTransformSeed,
+		}
+	}
+
+	// UDPSize sets the receive buffer to > 512, even when we don't include
+	// EDNS(0), which will mitigate issues with RFC 1035 non-compliant
+	// servers. See Go issue 51127.
+	dnsConn := &dns.Conn{
+		Conn:    conn,
+		UDPSize: udpPacketBufferSize,
+	}
+	defer dnsConn.Close()
+
+	// SetQuestion initializes request.MsgHdr.Id to a random value
+	request := &dns.Msg{MsgHdr: dns.MsgHdr{RecursionDesired: true}}
+	switch questionType {
+	case resolverQuestionTypeA:
+		request.SetQuestion(dns.Fqdn(hostname), dns.TypeA)
+	case resolverQuestionTypeAAAA:
+		request.SetQuestion(dns.Fqdn(hostname), dns.TypeAAAA)
+	default:
+		return nil, nil, 0, errors.TraceNew("unknown DNS request question type")
+	}
+	if params.IncludeEDNS0 {
+		// miekg/dns: "RFC 6891, Section 6.1.1 allows the OPT record to appear
+		// anywhere in the additional record section, but it's usually at the
+		// end..."
+		request.SetEdns0(udpPacketBufferSize, false)
+	}
+
+	startTime := time.Now()
+
+	// Send the DNS request. If the request could not be written, fail
+	// immediately: no response will arrive, and the read loop below would
+	// otherwise run until resolveCtx is done.
+	if err := dnsConn.WriteMsg(request); err != nil {
+		return nil, nil, time.Since(startTime), errors.Trace(err)
+	}
+
+	// Read and process the DNS response
+	var IPs []net.IP
+	var TTLs []time.Duration
+	for {
+
+		// Stop when resolveCtx is done; the caller, ResolveIP, will also
+		// close conn, which will interrupt a blocking dnsConn.ReadMsg.
+		if resolveCtx.Err() != nil {
+			return nil, nil, time.Since(startTime), errors.Trace(resolveCtx.Err())
+		}
+
+		// Read a response
+		response, err := dnsConn.ReadMsg()
+		if err == nil && response.MsgHdr.Id != request.MsgHdr.Id {
+			err = dns.ErrId
+		}
+		if err != nil {
+			// Try reading again, in case the first response packet failed to
+			// unmarshal or had an invalid ID. The Go resolver also does this;
+			// see Go issue 13281.
+			if resolveCtx.Err() == nil {
+				// Only log if resolveCtx is not done; otherwise the error could
+				// be due to conn being closed by ResolveIP.
+				logWarning(errors.Tracef("invalid response: %v", err))
+			}
+			continue
+		}
+
+		// Check the RCode.
+		//
+		// For IPv4, we expect RCodeSuccess as Psiphon will typically only
+		// resolve domains that exist and have a valid IP (when this isn't
+		// the case, and we retry, the overall ResolveIP and its parent dial
+		// will still abort after resolveCtx is done, or RequestTimeout
+		// expires for maxAttempts).
+		//
+		// For IPv6, we should also expect RCodeSuccess even if there is no
+		// AAAA record, as long as the domain exists and has an A record.
+		// However, per RFC 6147 section 5.1.2, we may receive
+		// NXDOMAIN: "...some servers respond with RCODE=3 to a AAAA query
+		// even if there is an A record available for that owner name. Those
+		// servers are in clear violation of the meaning of RCODE 3...". In
+		// this case, we coalesce NXDOMAIN into success to treat the response
+		// the same as success with no AAAA record.
+		//
+		// All other RCodes, which are unexpected, lead to a read retry.
+		if response.MsgHdr.Rcode != dns.RcodeSuccess &&
+			!(questionType == resolverQuestionTypeAAAA && response.MsgHdr.Rcode == dns.RcodeNameError) {
+
+			errMsg, ok := dns.RcodeToString[response.MsgHdr.Rcode]
+			if !ok {
+				errMsg = fmt.Sprintf("Rcode: %d", response.MsgHdr.Rcode)
+			}
+			logWarning(errors.Tracef("unexpected RCode: %v", errMsg))
+			continue
+		}
+
+		// Extract all IP answers, along with corresponding TTLs for caching.
+		// Perform additional validation, which may lead to another read
+		// retry. However, if _any_ valid IP is found, stop reading and
+		// return that result. Again, the validation is only best effort.
+
+		checkFailed := false
+		for _, answer := range response.Answer {
+			var IP net.IP
+			var TTLSec uint32
+			switch questionType {
+			case resolverQuestionTypeA:
+				if a, ok := answer.(*dns.A); ok {
+					IP = a.A
+					TTLSec = a.Hdr.Ttl
+				}
+			case resolverQuestionTypeAAAA:
+				if aaaa, ok := answer.(*dns.AAAA); ok {
+					IP = aaaa.AAAA
+					TTLSec = aaaa.Hdr.Ttl
+				}
+			}
+			// Answers of any other record type leave IP nil, which
+			// checkDNSAnswerIP rejects.
+			err := checkDNSAnswerIP(IP)
+			if err != nil {
+				checkFailed = true
+				logWarning(errors.Tracef("invalid IP: %v", err))
+				// Check the next answer
+				continue
+			}
+			IPs = append(IPs, IP)
+			TTLs = append(TTLs, time.Duration(TTLSec)*time.Second)
+		}
+
+		// For IPv4, an IP is expected, as noted in the comment above.
+		//
+		// In potential cases where we resolve a domain that has only an IPv6
+		// address, the concurrent AAAA request will deliver its result to
+		// ResolveIP, and that answer will be selected, so only the "await"
+		// logic will delay the parent dial in that case.
+		if questionType == resolverQuestionTypeA && len(IPs) == 0 && !checkFailed {
+			checkFailed = true
+			logWarning(errors.TraceNew("unexpected empty A response"))
+		}
+
+		// Retry if there are no valid IPs and any error; if no error, this
+		// may be a valid AAAA response with no IPs, in which case return the
+		// result.
+		if len(IPs) == 0 && checkFailed {
+			continue
+		}
+
+		return IPs, TTLs, time.Since(startTime), nil
+	}
+}
+
+// checkDNSAnswerIP performs best-effort validation of an IP address received
+// in a DNS answer. The IP is rejected when it is nil, when it is a bogon, or
+// when a UDP socket to it cannot be created.
+func checkDNSAnswerIP(IP net.IP) error {
+
+	switch {
+	case IP == nil:
+		return errors.TraceNew("IP is nil")
+
+	case common.IsBogon(IP):
+		// Limitation: a non-bogon IP could still be a phony/injected
+		// response; that's not possible to verify with plaintext DNS. A
+		// "bogon" IP, however, is clearly invalid.
+		return errors.TraceNew("IP is bogon")
+	}
+
+	// Create a temporary socket to the destination IP as a check that the
+	// local host has a route to this IP. When there's no route, the IP is
+	// rejected, which prevents selecting an IP which is guaranteed to fail
+	// to dial. UDP is used as this results in no network traffic; the
+	// destination port is arbitrary. The Go resolver performs a similar
+	// operation.
+	//
+	// Limitations:
+	// - We may cache the IP and reuse it without checking routability again;
+	//   the cache should be flushed when network state changes.
+	// - Given that the AAAA is requested only when the host has an IPv6
+	//   route, we don't expect this to often fail with a _valid_ response.
+	//   However, this remains a possibility and in this case,
+	//   performDNSQuery will keep awaiting a response which can trigger
+	//   the "await" logic.
+	destination := &net.UDPAddr{IP: IP, Port: 443}
+	probeConn, err := net.DialUDP("udp", nil, destination)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	probeConn.Close()
+
+	return nil
+}
+
+// defaultResolverLookupIP resolves hostname using net.DefaultResolver and
+// returns the resulting IPs. Unless logHostnames is set, any error is first
+// passed through common.RedactNetError.
+func defaultResolverLookupIP(
+	ctx context.Context, hostname string, logHostnames bool) ([]net.IP, error) {
+
+	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
+	if err != nil {
+		if !logHostnames {
+			// Remove domain names from "net" error messages.
+			err = common.RedactNetError(err)
+		}
+		return nil, errors.Trace(err)
+	}
+
+	IPs := make([]net.IP, 0, len(addrs))
+	for _, addr := range addrs {
+		IPs = append(IPs, addr.IP)
+	}
+
+	return IPs, nil
+}
+
+// transformDNSPacketConn wraps a *net.UDPConn, intercepting Write calls and
+// applying the specified protocol transform.
+//
+// As transforms operate on strings and DNS requests are binary, the transform
+// should be expressed using hex characters. The DNS packet to be written
+// (the input to Write) is converted to hex, transformed, and converted back
+// to binary and then actually written to the UDP socket.
+type transformDNSPacketConn struct {
+	// UDPConn is the underlying socket; all methods other than Write are
+	// promoted from it unchanged.
+	*net.UDPConn
+	// transform is the spec applied to the hex encoding of each written
+	// packet.
+	transform transforms.Spec
+	// seed is passed to transform.Apply on each Write.
+	seed      *prng.Seed
+}
+
+// Write hex encodes b, applies the protocol transform to the hex string,
+// decodes the result back to binary, and writes that packet to the
+// underlying UDP socket. On success, the reported byte count is len(b), the
+// pre-transform length; on any failure, the reported byte count is 0.
+func (conn *transformDNSPacketConn) Write(b []byte) (int, error) {
+
+	// Limitation: there is no check that a transformed packet remains within
+	// the network packet MTU.
+
+	transformed, err := conn.transform.Apply(conn.seed, hex.EncodeToString(b))
+	if err != nil {
+		return 0, errors.Trace(err)
+	}
+
+	packet, err := hex.DecodeString(transformed)
+	if err != nil {
+		return 0, errors.Trace(err)
+	}
+
+	// In the error case, don't report bytes written as the number could
+	// exceed the pre-transform length.
+	if _, err = conn.UDPConn.Write(packet); err != nil {
+		return 0, errors.Trace(err)
+	}
+
+	// The caller may check that the requested len(b) bytes were written, so
+	// report the pre-transform length as bytes written.
+	return len(b), nil
+}

+ 730 - 0
psiphon/common/resolver/resolver_test.go

@@ -0,0 +1,730 @@
+/*
+ * Copyright (c) 2022, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package resolver
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
+	"github.com/miekg/dns"
+)
+
+// TestMakeResolveParameters runs runTestMakeResolveParameters and fails the
+// test on any returned error.
+func TestMakeResolveParameters(t *testing.T) {
+	err := runTestMakeResolveParameters()
+	if err != nil {
+		// Use a constant format string so that any '%' in the error message
+		// is printed verbatim rather than interpreted as a formatting verb.
+		t.Fatalf("%v", errors.Trace(err))
+	}
+}
+
+// TestResolver runs runTestResolver and fails the test on any returned
+// error.
+func TestResolver(t *testing.T) {
+	err := runTestResolver()
+	if err != nil {
+		// Use a constant format string so that any '%' in the error message
+		// is printed verbatim rather than interpreted as a formatting verb.
+		t.Fatalf("%v", errors.Trace(err))
+	}
+}
+
+// TestPublicDNSServers runs runTestPublicDNSServers, fails the test on any
+// returned error, and otherwise logs the returned IPs and metrics.
+func TestPublicDNSServers(t *testing.T) {
+	IPs, metrics, err := runTestPublicDNSServers()
+	if err != nil {
+		// Use a constant format string so that any '%' in the error message
+		// is printed verbatim rather than interpreted as a formatting verb.
+		t.Fatalf("%v", errors.Trace(err))
+	}
+	t.Logf("IPs: %v", IPs)
+	t.Logf("Metrics: %v", metrics)
+}
+
+// runTestMakeResolveParameters checks that Resolver.MakeResolveParameters
+// produces the expected ResolveParameters for three tactics configurations:
+// a preresolved IP address; alternate server, protocol transform, and
+// EDNS(0) all selected; and none of those selected. It also exercises
+// generateIPAddressFromCIDR against the example IPv4 and IPv6 CIDRs. The
+// first unexpected result is returned as an error.
+func runTestMakeResolveParameters() error {
+
+	frontingProviderID := "frontingProvider"
+	alternateDNSServer := "172.16.0.1"
+	transformName := "exampleTransform"
+
+	// Initially, every probability is 1.0.
+	paramValues := map[string]interface{}{
+		"DNSResolverPreresolvedIPAddressProbability":  1.0,
+		"DNSResolverPreresolvedIPAddressCIDRs":        parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}},
+		"DNSResolverAlternateServers":                 []string{alternateDNSServer},
+		"DNSResolverPreferAlternateServerProbability": 1.0,
+		"DNSResolverProtocolTransformProbability":     1.0,
+		"DNSResolverProtocolTransformSpecs":           transforms.Specs{transformName: exampleTransform},
+		"DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{alternateDNSServer: []string{transformName}},
+		"DNSResolverIncludeEDNS0Probability":          1.0,
+	}
+
+	params, err := parameters.NewParameters(nil)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	_, err = params.Set("", false, paramValues)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	resolver, err := NewResolver(&NetworkConfig{}, "")
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer resolver.Stop()
+
+	resolverParams, err := resolver.MakeResolveParameters(
+		params.Get(), frontingProviderID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: PreresolvedIPAddress
+
+	// NOTE(review): the ParseCIDR error is ignored, so an invalid CIDR
+	// argument would result in a nil IPNet dereference.
+	CIDRContainsIP := func(CIDR, IP string) bool {
+		_, IPNet, _ := net.ParseCIDR(CIDR)
+		return IPNet.Contains(net.ParseIP(IP))
+	}
+
+	// In this case, the preresolved IP address is expected to be set and the
+	// alternate server, transform, and EDNS(0) fields are expected to be
+	// unset. The attempts and timeout values are not set in paramValues;
+	// presumably these are the parameter defaults -- TODO confirm.
+	if resolverParams.AttemptsPerServer != 2 ||
+		resolverParams.RequestTimeout != 5*time.Second ||
+		resolverParams.AwaitTimeout != 100*time.Millisecond ||
+		!CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) ||
+		resolverParams.AlternateDNSServer != "" ||
+		resolverParams.PreferAlternateDNSServer != false ||
+		resolverParams.ProtocolTransformName != "" ||
+		resolverParams.ProtocolTransformSpec != nil ||
+		resolverParams.IncludeEDNS0 != false {
+		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
+	}
+
+	// Test: additional generateIPAddressFromCIDR cases
+
+	// Each generated IP must be within its CIDR and must not be a bogon.
+	for i := 0; i < 10000; i++ {
+		for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} {
+			IP, err := generateIPAddressFromCIDR(CIDR)
+			if err != nil {
+				return errors.Trace(err)
+			}
+			if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) {
+				return errors.Tracef(
+					"invalid generated IP address %v for CIDR %v", IP, CIDR)
+			}
+		}
+	}
+
+	// Test: Alternate/Transform/EDNS(0)
+
+	// With the preresolved probability at 0.0 and the remaining
+	// probabilities still at 1.0, the alternate server, transform, and
+	// EDNS(0) are all expected to be selected.
+	paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
+
+	_, err = params.Set("", false, paramValues)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	resolverParams, err = resolver.MakeResolveParameters(
+		params.Get(), frontingProviderID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if resolverParams.AttemptsPerServer != 2 ||
+		resolverParams.RequestTimeout != 5*time.Second ||
+		resolverParams.AwaitTimeout != 100*time.Millisecond ||
+		resolverParams.PreresolvedIPAddress != "" ||
+		resolverParams.AlternateDNSServer != alternateDNSServer ||
+		resolverParams.PreferAlternateDNSServer != true ||
+		resolverParams.ProtocolTransformName != transformName ||
+		resolverParams.ProtocolTransformSpec == nil ||
+		resolverParams.IncludeEDNS0 != true {
+		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
+	}
+
+	// Test: No Alternate/Transform/EDNS(0)
+
+	// With all probabilities at 0.0, the alternate server is still expected
+	// to be populated, but not preferred, and no transform or EDNS(0) is
+	// expected.
+	paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
+	paramValues["DNSResolverProtocolTransformProbability"] = 0.0
+	paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
+
+	_, err = params.Set("", false, paramValues)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	resolverParams, err = resolver.MakeResolveParameters(
+		params.Get(), frontingProviderID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if resolverParams.AttemptsPerServer != 2 ||
+		resolverParams.RequestTimeout != 5*time.Second ||
+		resolverParams.AwaitTimeout != 100*time.Millisecond ||
+		resolverParams.PreresolvedIPAddress != "" ||
+		resolverParams.AlternateDNSServer != alternateDNSServer ||
+		resolverParams.PreferAlternateDNSServer != false ||
+		resolverParams.ProtocolTransformName != "" ||
+		resolverParams.ProtocolTransformSpec != nil ||
+		resolverParams.IncludeEDNS0 != false {
+		return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
+	}
+
+	return nil
+}
+
+func runTestResolver() error {
+
+	// noResponseServer will not respond to requests
+	noResponseServer, err := newTestDNSServer(false, false, false)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer noResponseServer.stop()
+
+	// invalidIPServer will respond with an invalid IP
+	invalidIPServer, err := newTestDNSServer(true, false, false)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer invalidIPServer.stop()
+
+	// okServer will respond to correct requests (expected domain) with the
+	// correct response (expected IPv4 or IPv6 address)
+	okServer, err := newTestDNSServer(true, true, false)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer okServer.stop()
+
+	// alternateOkServer behaves like okServer; getRequestCount is used to
+	// confirm that the alternate server was indeed used
+	alternateOkServer, err := newTestDNSServer(true, true, false)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer alternateOkServer.stop()
+
+	// transformOkServer behaves like okServer but only responds if the
+	// transform was applied; other servers do not respond if the transform
+	// is applied
+	transformOkServer, err := newTestDNSServer(true, true, true)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer transformOkServer.stop()
+
+	servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
+
+	networkConfig := &NetworkConfig{
+		GetDNSServers: func() []string { return servers },
+		LogWarning:    func(err error) { fmt.Printf("LogWarning: %v\n", err) },
+	}
+
+	networkID := "networkID-1"
+
+	resolver, err := NewResolver(networkConfig, networkID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	defer resolver.Stop()
+
+	params := &ResolveParameters{
+		AttemptsPerServer: 1,
+		RequestTimeout:    250 * time.Millisecond,
+		AwaitTimeout:      250 * time.Millisecond,
+		IncludeEDNS0:      true,
+	}
+
+	checkResult := func(IPs []net.IP) error {
+		var IPv4, IPv6 net.IP
+		for _, IP := range IPs {
+			if IP.To4() != nil {
+				IPv4 = IP
+			} else {
+				IPv6 = IP
+			}
+		}
+		if IPv4 == nil {
+			return errors.TraceNew("missing IPv4 response")
+		}
+		if IPv4.String() != exampleIPv4 {
+			return errors.TraceNew("unexpected IPv4 response")
+		}
+		if resolver.hasIPv6Route {
+			if IPv6 == nil {
+				return errors.TraceNew("missing IPv6 response")
+			}
+			if IPv6.String() != exampleIPv6 {
+				return errors.TraceNew("unexpected IPv6 response")
+			}
+		}
+		return nil
+	}
+
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancelFunc()
+
+	// Test: should retry until okServer responds
+
+	IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if resolver.metrics.resolves != 1 ||
+		resolver.metrics.cacheHits != 0 ||
+		resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 ||
+		(resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) {
+		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
+	}
+
+	// Test: cached response
+
+	beforeMetrics := resolver.metrics
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
+		resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
+		resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
+		resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
+		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
+	}
+
+	// Test: PreresolvedIPAddress
+
+	beforeMetrics = resolver.metrics
+
+	params.PreresolvedIPAddress = exampleIPv4
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
+		return errors.TraceNew("unexpected preresolved response")
+	}
+
+	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
+		resolver.metrics.cacheHits != beforeMetrics.cacheHits ||
+		resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
+		resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
+		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
+	}
+
+	params.PreresolvedIPAddress = ""
+
+	// Test: change network ID, which must clear cache
+
+	beforeMetrics = resolver.metrics
+
+	networkID = "networkID-2"
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
+		resolver.metrics.cacheHits != beforeMetrics.cacheHits {
+		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
+	}
+
+	// Test: PreferAlternateDNSServer
+
+	if alternateOkServer.getRequestCount() != 0 {
+		return errors.TraceNew("unexpected alternate server request count")
+	}
+
+	resolver.cache.Flush()
+
+	params.AlternateDNSServer = alternateOkServer.getAddr()
+	params.PreferAlternateDNSServer = true
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if alternateOkServer.getRequestCount() < 1 {
+		return errors.TraceNew("unexpected alternate server request count")
+	}
+
+	params.AlternateDNSServer = ""
+	params.PreferAlternateDNSServer = false
+
+	// Test: fall over to AlternateDNSServer when no system servers
+
+	beforeCount := alternateOkServer.getRequestCount()
+
+	previousGetDNSServers := networkConfig.GetDNSServers
+
+	networkConfig.GetDNSServers = func() []string { return nil }
+
+	// Force system servers update
+	networkID = "networkID-3"
+
+	resolver.cache.Flush()
+
+	params.AlternateDNSServer = alternateOkServer.getAddr()
+	params.PreferAlternateDNSServer = false
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if alternateOkServer.getRequestCount() <= beforeCount {
+		return errors.TraceNew("unexpected alterate server request count")
+	}
+
+	// Test: use default, standard resolver when no servers
+
+	resolver.cache.Flush()
+
+	params.AlternateDNSServer = ""
+	params.PreferAlternateDNSServer = false
+
+	if len(resolver.systemServers) != 0 {
+		return errors.TraceNew("unexpected server count")
+	}
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if len(IPs) == 0 {
+		return errors.TraceNew("unexpected response")
+	}
+
+	// Test: ResolveAddress
+
+	networkConfig.GetDNSServers = previousGetDNSServers
+
+	// Force system servers update
+	networkID = "networkID-4"
+
+	domainAddress := net.JoinHostPort(exampleDomain, "443")
+
+	address, err := resolver.ResolveAddress(ctx, networkID, params, domainAddress)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	host, port, err := net.SplitHostPort(address)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	IP := net.ParseIP(host)
+
+	if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" {
+		return errors.TraceNew("unexpected response")
+	}
+
+	// Test: protocol transform
+
+	if transformOkServer.getRequestCount() != 0 {
+		return errors.TraceNew("unexpected transform server request count")
+	}
+
+	resolver.cache.Flush()
+
+	params.AlternateDNSServer = transformOkServer.getAddr()
+	params.PreferAlternateDNSServer = true
+
+	seed, err := prng.NewSeed()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	params.ProtocolTransformName = "exampleTransform"
+	params.ProtocolTransformSpec = exampleTransform
+	params.ProtocolTransformSeed = seed
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if transformOkServer.getRequestCount() < 1 {
+		return errors.TraceNew("unexpected transform server request count")
+	}
+
+	params.AlternateDNSServer = ""
+	params.PreferAlternateDNSServer = false
+	params.ProtocolTransformName = ""
+	params.ProtocolTransformSpec = nil
+	params.ProtocolTransformSeed = nil
+
+	// Test: EDNS(0)
+
+	resolver.cache.Flush()
+
+	params.IncludeEDNS0 = true
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = checkResult(IPs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	params.IncludeEDNS0 = false
+
+	// Test: input IP address
+
+	beforeMetrics = resolver.metrics
+
+	resolver.cache.Flush()
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
+		return errors.TraceNew("unexpected IPv4 response")
+	}
+
+	if resolver.metrics.resolves != beforeMetrics.resolves {
+		return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
+	}
+
+	// Test: cancel context
+
+	resolver.cache.Flush()
+
+	cancelFunc()
+
+	IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
+	if err == nil {
+		return errors.TraceNew("unexpected success")
+	}
+
+	return nil
+}
+
+func runTestPublicDNSServers() ([]net.IP, string, error) {
+
+	networkConfig := &NetworkConfig{
+		GetDNSServers: getPublicDNSServers,
+	}
+
+	networkID := "networkID-1"
+
+	resolver, err := NewResolver(networkConfig, networkID)
+	if err != nil {
+		return nil, "", errors.Trace(err)
+	}
+	defer resolver.Stop()
+
+	params := &ResolveParameters{
+		AttemptsPerServer: 1,
+		RequestTimeout:    5 * time.Second,
+		AwaitTimeout:      1 * time.Second,
+		IncludeEDNS0:      true,
+	}
+
+	IPs, err := resolver.ResolveIP(
+		context.Background(), networkID, params, exampleDomain)
+	if err != nil {
+		return nil, "", errors.Trace(err)
+	}
+
+	gotIPv4 := false
+	gotIPv6 := false
+	for _, IP := range IPs {
+		if IP.To4() != nil {
+			gotIPv4 = true
+		} else {
+			gotIPv6 = true
+		}
+	}
+	if !gotIPv4 {
+		return nil, "", errors.TraceNew("missing IPv4 response")
+	}
+	if !gotIPv6 && resolver.hasIPv6Route {
+		return nil, "", errors.TraceNew("missing IPv6 response")
+	}
+
+	return IPs, resolver.GetMetrics(), nil
+}
+
+// getPublicDNSServers returns a fixed list of public DNS server addresses,
+// ordered by a fresh random permutation from prng.Perm on each call.
+func getPublicDNSServers() []string {
+	candidates := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
+	order := prng.Perm(len(candidates))
+	shuffled := make([]string, 0, len(order))
+	for _, index := range order {
+		shuffled = append(shuffled, candidates[index])
+	}
+	return shuffled
+}
+
+// Fixed values shared by the resolver tests. exampleDomain is the only name
+// the test DNS server answers for; when configured with validResponse it
+// answers A queries with exampleIPv4 and all other queries with exampleIPv6.
+// NOTE(review): the CIDR constants presumably cover the corresponding example
+// addresses; their use is outside this view -- confirm.
+const (
+	exampleDomain   = "example.com"
+	exampleIPv4     = "93.184.216.34"
+	exampleIPv4CIDR = "93.184.216.0/24"
+	exampleIPv6     = "2606:2800:220:1:248:1893:25c8:1946"
+	exampleIPv6CIDR = "2606:2800:220::/48"
+)
+
+// Set the reserved Z flag
+//
+// exampleTransform matches an input that begins with four lowercase hex
+// characters followed by "0100", keeps the four characters, and rewrites
+// "0100" to "0140". NOTE(review): presumably this is applied to a hex-encoded
+// DNS request, where the change sets the Z flag that the test server checks
+// via MsgHdr.Zero -- confirm against the resolver's transform call site.
+var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
+
+// testDNSServer is a local UDP DNS server used by the resolver tests. Its
+// ServeDNS method is installed as the handler of the wrapped dns.Server.
+type testDNSServer struct {
+	// respond, when false, makes ServeDNS count the request but send no reply.
+	respond         bool
+	// validResponse, when false, makes ServeDNS answer with a loopback
+	// address instead of the example address.
+	validResponse   bool
+	// expectTransform must equal the request's MsgHdr.Zero flag, otherwise
+	// ServeDNS sends no reply.
+	expectTransform bool
+	// addr is the local address of the listening UDP socket.
+	addr            string
+	// requestCount is the number of requests received; accessed atomically.
+	requestCount    int32
+	// server is the underlying DNS server bound to the UDP socket.
+	server          *dns.Server
+}
+
+// newTestDNSServer starts a testDNSServer listening for UDP on an ephemeral
+// port on 127.0.0.1 and begins serving in a new goroutine. The respond,
+// validResponse, and expectTransform flags are stored on the returned server
+// and control how ServeDNS handles each request. An error is returned when
+// the listen address cannot be resolved or the socket cannot be opened.
+func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) {
+
+	listenAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	conn, err := net.ListenUDP("udp", listenAddr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	testServer := &testDNSServer{
+		respond:         respond,
+		validResponse:   validResponse,
+		expectTransform: expectTransform,
+		addr:            conn.LocalAddr().String(),
+	}
+
+	// The test server is its own request handler.
+	testServer.server = &dns.Server{
+		PacketConn: conn,
+		Handler:    testServer,
+	}
+
+	go testServer.server.ActivateAndServe()
+
+	return testServer, nil
+}
+
+func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	atomic.AddInt32(&s.requestCount, 1)
+
+	if !s.respond {
+		return
+	}
+
+	// Check the reserved Z flag
+	if s.expectTransform != r.MsgHdr.Zero {
+		return
+	}
+
+	if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) {
+		return
+	}
+
+	m := new(dns.Msg)
+	m.SetReply(r)
+	m.Answer = make([]dns.RR, 1)
+	if r.Question[0].Qtype == dns.TypeA {
+		IP := net.ParseIP(exampleIPv4)
+		if !s.validResponse {
+			IP = net.ParseIP("127.0.0.1")
+		}
+		m.Answer[0] = &dns.A{
+			Hdr: dns.RR_Header{
+				Name:   r.Question[0].Name,
+				Rrtype: dns.TypeA,
+				Class:  dns.ClassINET,
+				Ttl:    60},
+			A: IP,
+		}
+	} else {
+		IP := net.ParseIP(exampleIPv6)
+		if !s.validResponse {
+			IP = net.ParseIP("::1")
+		}
+		m.Answer[0] = &dns.AAAA{
+			Hdr: dns.RR_Header{
+				Name:   r.Question[0].Name,
+				Rrtype: dns.TypeAAAA,
+				Class:  dns.ClassINET,
+				Ttl:    60},
+			AAAA: IP,
+		}
+	}
+
+	w.WriteMsg(m)
+}
+
+// getAddr returns the local address of the server's UDP socket, as recorded
+// when the server was created.
+func (s *testDNSServer) getAddr() string {
+	return s.addr
+}
+
+// getRequestCount returns the number of requests received so far. The
+// counter is loaded atomically.
+func (s *testDNSServer) getRequestCount() int {
+	return int(atomic.LoadInt32(&s.requestCount))
+}
+
+// stop closes the server's packet connection and then shuts down the DNS
+// server. Errors from both calls are ignored.
+func (s *testDNSServer) stop() {
+	s.server.PacketConn.Close()
+	s.server.Shutdown()
+}

+ 174 - 0
psiphon/common/transforms/transforms.go

@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2022, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package transforms provides a mechanism to define and apply string data
+// transformations, with the transformations defined by regular expressions
+// to match data to be transformed, and regular expression generators to
+// specify additional or replacement data.
+package transforms
+
+import (
+	"regexp"
+	"regexp/syntax"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	regen "github.com/zach-klippenstein/goregen"
+)
+
+const (
+	// SCOPE_ANY is the empty scope name. Select uses its ScopedSpecNames
+	// entry when the input scope is SCOPE_ANY or has no entry of its own.
+	SCOPE_ANY = ""
+)
+
+// Spec is a transform spec. A spec is a list of individual transforms to be
+// applied in order. Each transform is defined by two elements: a regular
+// expression to be matched against the input; and a regular expression
+// generator which generates new data. Subgroups from the regular expression
+// may be specified in the regular expression generator, and are populated
+// with the subgroup match, and in this way parts of the original matching
+// data may be retained in the transformed data.
+//
+// For example, with the transform [2]string{"([a-b])", "\\$\\{1\\}c"}, each
+// 'a' or 'b' character in the input is replaced by that same character
+// followed by a single character 'c'.
+type Spec [][2]string
+
+// Specs is a set of named Specs, keyed by Spec name.
+type Specs map[string]Spec
+
+// Validate checks that all entries in a set of Specs are well-formed, with
+// valid regular expressions. Each Spec is checked by applying it to an empty
+// input with a freshly generated seed; the first error reported is returned.
+func (specs Specs) Validate() error {
+	validationSeed, err := prng.NewSeed()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	for _, candidate := range specs {
+		// Apply compiles the regular expressions and generators; its output
+		// is not needed here.
+		if _, err := candidate.Apply(validationSeed, ""); err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	return nil
+}
+
+// ScopedSpecNames groups lists of Specs, referenced by their Spec names,
+// with each group identified by a scope. The meaning of scope depends on the
+// context in which the transforms are to be used.
+//
+// For example, in the context of DNS request transforms, the scope is the DNS
+// server for which a specific group of transforms is known to be effective.
+//
+// The scope name "" is SCOPE_ANY, and matches any input scope name when there
+// is no specific entry for that scope name in ScopedSpecNames.
+type ScopedSpecNames map[string][]string
+
+// Validate checks that the ScopedSpecNames is well-formed: every Spec name
+// listed under any scope must be defined in the input specs. An error naming
+// the first undefined Spec name encountered is returned.
+func (scopedSpecs ScopedSpecNames) Validate(specs Specs) error {
+
+	for _, names := range scopedSpecs {
+		for _, name := range names {
+			if _, defined := specs[name]; !defined {
+				return errors.Tracef("undefined spec name: %s", name)
+			}
+		}
+	}
+
+	return nil
+}
+
+// Select picks a Spec from Specs based on the input scope and scoping rules.
+// If the input scope name is defined in scopedSpecs, that match takes
+// precedence. Otherwise SCOPE_ANY is selected, when present.
+//
+// After the scope is resolved, Select randomly selects from the matching Spec
+// list.
+//
+// Select will return "", nil when no selection can be made.
+func (specs Specs) Select(scope string, scopedSpecs ScopedSpecNames) (string, Spec) {
+
+	// Look up the specific scope first, and consult SCOPE_ANY only when the
+	// specific scope has no entry at all. A specific scope that is defined
+	// but empty means select nothing -- it does not fall back to SCOPE_ANY.
+	candidates, found := scopedSpecs[scope]
+	if !found && scope != SCOPE_ANY {
+		candidates, found = scopedSpecs[SCOPE_ANY]
+	}
+	if !found || len(candidates) == 0 {
+		return "", nil
+	}
+
+	name := candidates[prng.Intn(len(candidates))]
+
+	selected, found := specs[name]
+	if !found {
+		// name is not found in specs, which should not happen if Validate
+		// passes; select nothing in this case.
+		return "", nil
+	}
+	return name, selected
+}
+
+// Apply applies the Spec to the input string, producing the output string.
+// The transforms are applied in order, each one operating on the output of
+// the previous one.
+//
+// The input seed is used for all random generation. The same seed can be
+// supplied to produce the same output, for replay.
+//
+// Apply returns an empty string and an error when a transform's regular
+// expression or regular expression generator fails to compile. It does not
+// panic on invalid specs, so callers such as Specs.Validate can use it to
+// reject malformed input.
+func (spec Spec) Apply(seed *prng.Seed, input string) (string, error) {
+
+	// TODO: compiled regexp and regen could be cached, but the seed is an
+	// issue with the regen.
+
+	value := input
+	for _, transform := range spec {
+
+		// Each transform gets its own PRNG created from the same seed.
+		args := &regen.GeneratorArgs{
+			RngSource: prng.NewPRNGWithSeed(seed),
+			Flags:     syntax.OneLine | syntax.NonGreedy,
+		}
+		rg, err := regen.NewGenerator(transform[1], args)
+		if err != nil {
+			return "", errors.Trace(err)
+		}
+
+		// The replacement is generated once per transform, so every match of
+		// this transform's expression receives the same replacement template.
+		replacement := rg.Generate()
+
+		re, err := regexp.Compile(transform[0])
+		if err != nil {
+			return "", errors.Trace(err)
+		}
+		value = re.ReplaceAllString(value, replacement)
+	}
+	return value, nil
+}

+ 140 - 0
psiphon/common/transforms/transforms_test.go

@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) 2022, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package transforms
+
+import (
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+)
+
+// TestTransforms runs runTestTransforms and fails the test with the traced
+// error message when it reports an error.
+func TestTransforms(t *testing.T) {
+	err := runTestTransforms()
+	if err != nil {
+		// Use Fatal, not Fatalf: the error text is not a format string, and
+		// any '%' in it would otherwise be misinterpreted as a verb.
+		t.Fatal(errors.Trace(err).Error())
+	}
+}
+
+// runTestTransforms exercises the transforms package end to end: validation
+// of Specs and ScopedSpecNames, scope-based selection, the output of
+// Spec.Apply for a known input, and the dependence of that output on the
+// seed. It returns a traced error describing the first failed check, or nil
+// when all checks pass.
+func runTestTransforms() error {
+
+	transformNameAny := "exampleTransform1"
+	transformNameScoped := "exampleTransform2"
+	scopeName := "exampleScope"
+
+	specs := Specs{
+		transformNameAny: Spec{[2]string{"x", "y"}},
+		transformNameScoped: Spec{
+			[2]string{"aa", "cc"},
+			[2]string{"bb", "(dd|ee)"},
+			[2]string{"^([c0]{6})", "\\$\\{1\\}ff0"},
+		},
+	}
+
+	scopedSpecs := ScopedSpecNames{
+		SCOPE_ANY: []string{transformNameAny},
+		scopeName: []string{transformNameScoped},
+	}
+
+	// Test: validation
+
+	err := specs.Validate()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = scopedSpecs.Validate(specs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Test: select based on scope
+
+	name, spec := specs.Select(SCOPE_ANY, scopedSpecs)
+	if name != transformNameAny || !reflect.DeepEqual(spec, specs[transformNameAny]) {
+		return errors.TraceNew("unexpected select result")
+	}
+
+	name, spec = specs.Select(scopeName, scopedSpecs)
+	if name != transformNameScoped || !reflect.DeepEqual(spec, specs[transformNameScoped]) {
+		return errors.TraceNew("unexpected select result")
+	}
+
+	// Test: correct transform (assumes spec is transformNameScoped)
+
+	seed, err := prng.NewSeed()
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	input := "aa0aa0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa"
+	output, err := spec.Apply(seed, input)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	// Expect: the third transform inserted "ff0" after the leading "cc0cc0";
+	// no 'a' or 'b' remains; and "bb" was replaced by either "dd" or "ee".
+	if !strings.HasPrefix(output, "cc0cc0ff0") ||
+		strings.IndexAny(output, "ab") != -1 ||
+		strings.IndexAny(output, "de") == -1 {
+		return errors.Tracef("unexpected apply result: %s", output)
+	}
+
+	// Test: same result with same seed
+
+	previousOutput := output
+
+	output, err = spec.Apply(seed, input)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if output != previousOutput {
+		return errors.Tracef("unexpected different apply result")
+	}
+
+	// Test: different result with different seed (with high probability)
+
+	// The only random choice in this spec is (dd|ee), and Apply generates a
+	// replacement once per transform, so a single new seed can easily yield
+	// the same output. Retry with many fresh seeds before failing.
+	different := false
+	for i := 0; i < 1000; i++ {
+
+		seed, err = prng.NewSeed()
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		output, err = spec.Apply(seed, input)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		if output != previousOutput {
+			different = true
+			break
+		}
+	}
+
+	if !different {
+		return errors.Tracef("unexpected identical apply result")
+	}
+
+	return nil
+}