Просмотр исходного кода

Inproxy port mapping fixes

- Probe data must be present in order to establish an actual port mapping, at
  least in the case of UPnP. Instead of invoking Probe once per answer event,
  copy the necessary data from the upfront NAT discovery probe.

- Don't attempt port mappings for clients unless the optional NAT discovery is
  run.

- Await an expected port mapping even after ICE candidate gathering completes
  in the case where STUN is skipped.
Rod Hynes 1 год назад
Родитель
Сommit
75ef9756db

+ 9 - 13
psiphon/common/inproxy/client.go

@@ -178,25 +178,21 @@ func DialClient(
 		// synchronously, so that NAT topology metrics can be reported to the
 		// synchronously, so that NAT topology metrics can be reported to the
 		// broker in the ClientOffer request. For clients, NAT discovery is
 		// broker in the ClientOffer request. For clients, NAT discovery is
 		// intended to be performed at a low sampling rate, since the RFC5780
 		// intended to be performed at a low sampling rate, since the RFC5780
-		// traffic may be unusual(differs from standard STUN requests for
-		// ICE) and since this step delays the dial. Clients should to cache
-		// their NAT discovery outcomes, associated with the current network
-		// by network ID, so metrics can be reported even without a discovery
-		// step; this is facilitated by WebRTCDialCoordinator.
+		// traffic may be unusual (differs from standard STUN requests for
+		// ICE), the port mapping probe traffic may be unusual, and since
+		// this step delays the dial. Clients should to cache their NAT
+		// discovery outcomes, associated with the current network by network
+		// ID, so metrics can be reported even without a discovery step; this
+		// is facilitated by WebRTCDialCoordinator.
 		//
 		//
 		// NAT topology metrics are used by the broker to optimize client and
 		// NAT topology metrics are used by the broker to optimize client and
 		// in-proxy matching.
 		// in-proxy matching.
-		//
-		// For client NAT discovery, port mapping type discovery is skipped
-		// since port mappings are attempted when preparing the WebRTC offer,
-		// which also happens before the ClientOffer request.
 
 
 		NATDiscover(
 		NATDiscover(
 			ctx,
 			ctx,
 			&NATDiscoverConfig{
 			&NATDiscoverConfig{
 				Logger:                config.Logger,
 				Logger:                config.Logger,
 				WebRTCDialCoordinator: config.WebRTCDialCoordinator,
 				WebRTCDialCoordinator: config.WebRTCDialCoordinator,
-				SkipPortMapping:       true,
 			})
 			})
 	}
 	}
 
 
@@ -337,7 +333,7 @@ func dialClientWebRTCConn(
 	trafficShapingParameters := config.WebRTCDialCoordinator.DataChannelTrafficShapingParameters()
 	trafficShapingParameters := config.WebRTCDialCoordinator.DataChannelTrafficShapingParameters()
 	clientRootObfuscationSecret := config.WebRTCDialCoordinator.ClientRootObfuscationSecret()
 	clientRootObfuscationSecret := config.WebRTCDialCoordinator.ClientRootObfuscationSecret()
 
 
-	webRTCConn, SDP, SDPMetrics, err := newWebRTCConnWithOffer(
+	webRTCConn, SDP, SDPMetrics, err := newWebRTCConnForOffer(
 		ctx, &webRTCConfig{
 		ctx, &webRTCConfig{
 			Logger:                      config.Logger,
 			Logger:                      config.Logger,
 			EnableDebugLogging:          config.EnableWebRTCDebugLogging,
 			EnableDebugLogging:          config.EnableWebRTCDebugLogging,
@@ -367,8 +363,8 @@ func dialClientWebRTCConn(
 
 
 	// Here, WebRTCDialCoordinator.NATType may be populated from discovery, or
 	// Here, WebRTCDialCoordinator.NATType may be populated from discovery, or
 	// replayed from a previous run on the same network ID.
 	// replayed from a previous run on the same network ID.
-	// WebRTCDialCoordinator.PortMappingTypes may be populated via
-	// newWebRTCConnWithOffer.
+	// WebRTCDialCoordinator.PortMappingTypes/PortMappingProbe may be
+	// populated via the optional NATDiscover run above or in a previous dial.
 
 
 	// ClientOffer applies BrokerDialCoordinator.OfferRequestTimeout or
 	// ClientOffer applies BrokerDialCoordinator.OfferRequestTimeout or
 	// OfferRequestPersonalTimeout as the request timeout.
 	// OfferRequestPersonalTimeout as the request timeout.

+ 10 - 0
psiphon/common/inproxy/coordinator.go

@@ -305,6 +305,15 @@ type WebRTCDialCoordinator interface {
 	// re-run port mapping discovery.
 	// re-run port mapping discovery.
 	SetPortMappingTypes(t PortMappingTypes)
 	SetPortMappingTypes(t PortMappingTypes)
 
 
+	// PortMappingProbe returns any persisted PortMappingProbe for the current
+	// network, which is used to establish port mappings.
+	PortMappingProbe() *PortMappingProbe
+
+	// SetPortMappingProbe receives a PortMappingProbe instance, which caches
+	// complete port mapping service details and is a required input for
+	// subsequent port mapping establishment on the current network.
+	SetPortMappingProbe(p *PortMappingProbe)
+
 	// ResolveAddress resolves a domain and returns its IP address. Clients
 	// ResolveAddress resolves a domain and returns its IP address. Clients
 	// and proxies may use this to hook into the Psiphon custom resolver. The
 	// and proxies may use this to hook into the Psiphon custom resolver. The
 	// provider adds the custom resolver tactics and network ID parameters
 	// provider adds the custom resolver tactics and network ID parameters
@@ -343,6 +352,7 @@ type WebRTCDialCoordinator interface {
 
 
 	DiscoverNATTimeout() time.Duration
 	DiscoverNATTimeout() time.Duration
 	WebRTCAnswerTimeout() time.Duration
 	WebRTCAnswerTimeout() time.Duration
+	WebRTCAwaitPortMappingTimeout() time.Duration
 	WebRTCAwaitDataChannelTimeout() time.Duration
 	WebRTCAwaitDataChannelTimeout() time.Duration
 	ProxyDestinationDialTimeout() time.Duration
 	ProxyDestinationDialTimeout() time.Duration
 	ProxyRelayInactivityTimeout() time.Duration
 	ProxyRelayInactivityTimeout() time.Duration

+ 20 - 0
psiphon/common/inproxy/coordinator_test.go

@@ -203,10 +203,12 @@ type testWebRTCDialCoordinator struct {
 	natType                         NATType
 	natType                         NATType
 	setNATType                      func(NATType)
 	setNATType                      func(NATType)
 	portMappingTypes                PortMappingTypes
 	portMappingTypes                PortMappingTypes
+	portMappingProbe                *PortMappingProbe
 	setPortMappingTypes             func(PortMappingTypes)
 	setPortMappingTypes             func(PortMappingTypes)
 	bindToDevice                    func(int) error
 	bindToDevice                    func(int) error
 	discoverNATTimeout              time.Duration
 	discoverNATTimeout              time.Duration
 	webRTCAnswerTimeout             time.Duration
 	webRTCAnswerTimeout             time.Duration
+	webRTCAwaitPortMappingTimeout   time.Duration
 	webRTCAwaitDataChannelTimeout   time.Duration
 	webRTCAwaitDataChannelTimeout   time.Duration
 	proxyDestinationDialTimeout     time.Duration
 	proxyDestinationDialTimeout     time.Duration
 	proxyRelayInactivityTimeout     time.Duration
 	proxyRelayInactivityTimeout     time.Duration
@@ -319,6 +321,18 @@ func (t *testWebRTCDialCoordinator) SetPortMappingTypes(portMappingTypes PortMap
 	t.setPortMappingTypes(portMappingTypes)
 	t.setPortMappingTypes(portMappingTypes)
 }
 }
 
 
+func (t *testWebRTCDialCoordinator) PortMappingProbe() *PortMappingProbe {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	return t.portMappingProbe
+}
+
+func (t *testWebRTCDialCoordinator) SetPortMappingProbe(portMappingProbe *PortMappingProbe) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	t.portMappingProbe = portMappingProbe
+}
+
 func (t *testWebRTCDialCoordinator) ResolveAddress(ctx context.Context, network, address string) (string, error) {
 func (t *testWebRTCDialCoordinator) ResolveAddress(ctx context.Context, network, address string) (string, error) {
 
 
 	// Note: can't use common/resolver due to import cycle
 	// Note: can't use common/resolver due to import cycle
@@ -389,6 +403,12 @@ func (t *testWebRTCDialCoordinator) WebRTCAnswerTimeout() time.Duration {
 	return t.webRTCAnswerTimeout
 	return t.webRTCAnswerTimeout
 }
 }
 
 
+func (t *testWebRTCDialCoordinator) WebRTCAwaitPortMappingTimeout() time.Duration {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	return t.webRTCAwaitPortMappingTimeout
+}
+
 func (t *testWebRTCDialCoordinator) WebRTCAwaitDataChannelTimeout() time.Duration {
 func (t *testWebRTCDialCoordinator) WebRTCAwaitDataChannelTimeout() time.Duration {
 	t.mutex.Lock()
 	t.mutex.Lock()
 	defer t.mutex.Unlock()
 	defer t.mutex.Unlock()

+ 10 - 8
psiphon/common/inproxy/discovery.go

@@ -66,7 +66,8 @@ func NATDiscover(
 	// mapping discovery are run concurrently.
 	// mapping discovery are run concurrently.
 
 
 	discoverCtx, cancelFunc := context.WithTimeout(
 	discoverCtx, cancelFunc := context.WithTimeout(
-		ctx, common.ValueOrDefault(config.WebRTCDialCoordinator.DiscoverNATTimeout(), discoverNATTimeout))
+		ctx, common.ValueOrDefault(
+			config.WebRTCDialCoordinator.DiscoverNATTimeout(), discoverNATTimeout))
 	defer cancelFunc()
 	defer cancelFunc()
 
 
 	discoveryWaitGroup := new(sync.WaitGroup)
 	discoveryWaitGroup := new(sync.WaitGroup)
@@ -102,13 +103,14 @@ func NATDiscover(
 		go func() {
 		go func() {
 			defer discoveryWaitGroup.Done()
 			defer discoveryWaitGroup.Done()
 
 
-			portMappingTypes, err := discoverPortMappingTypes(
+			portMappingTypes, portMapperProbe, err := discoverPortMappingTypes(
 				discoverCtx, config.Logger)
 				discoverCtx, config.Logger)
 
 
 			if err == nil {
 			if err == nil {
-				// Deliver the result. The WebRTCDialCoordinator provider may cache
-				// this result, associated wih the current networkID.
+				// Deliver the results. The WebRTCDialCoordinator provider
+				// should cache this data, associated wih the current networkID.
 				config.WebRTCDialCoordinator.SetPortMappingTypes(portMappingTypes)
 				config.WebRTCDialCoordinator.SetPortMappingTypes(portMappingTypes)
+				config.WebRTCDialCoordinator.SetPortMappingProbe(portMapperProbe)
 			}
 			}
 
 
 			config.Logger.WithTraceFields(common.LogFields{
 			config.Logger.WithTraceFields(common.LogFields{
@@ -255,12 +257,12 @@ func discoverNATType(
 
 
 func discoverPortMappingTypes(
 func discoverPortMappingTypes(
 	ctx context.Context,
 	ctx context.Context,
-	logger common.Logger) (PortMappingTypes, error) {
+	logger common.Logger) (PortMappingTypes, *PortMappingProbe, error) {
 
 
-	portMappingTypes, err := probePortMapping(ctx, logger)
+	portMappingTypes, portMapperProbe, err := probePortMapping(ctx, logger)
 	if err != nil {
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, nil, errors.Trace(err)
 	}
 	}
 
 
-	return portMappingTypes, nil
+	return portMappingTypes, portMapperProbe, nil
 }
 }

+ 7 - 4
psiphon/common/inproxy/inproxy_disabled.go

@@ -128,7 +128,7 @@ type webRTCSDPMetrics struct {
 	filteredICECandidates []string
 	filteredICECandidates []string
 }
 }
 
 
-func newWebRTCConnWithOffer(
+func newWebRTCConnForOffer(
 	ctx context.Context,
 	ctx context.Context,
 	config *webRTCConfig,
 	config *webRTCConfig,
 	hasPersonalCompartmentIDs bool) (
 	hasPersonalCompartmentIDs bool) (
@@ -136,7 +136,7 @@ func newWebRTCConnWithOffer(
 	return nil, WebRTCSessionDescription{}, nil, errors.Trace(errNotEnabled)
 	return nil, WebRTCSessionDescription{}, nil, errors.Trace(errNotEnabled)
 }
 }
 
 
-func newWebRTCConnWithAnswer(
+func newWebRTCConnForAnswer(
 	ctx context.Context,
 	ctx context.Context,
 	config *webRTCConfig,
 	config *webRTCConfig,
 	peerSDP WebRTCSessionDescription,
 	peerSDP WebRTCSessionDescription,
@@ -159,11 +159,14 @@ func filterSDPAddresses(
 func initPortMapper(coordinator WebRTCDialCoordinator) {
 func initPortMapper(coordinator WebRTCDialCoordinator) {
 }
 }
 
 
+type PortMappingProbe struct {
+}
+
 func probePortMapping(
 func probePortMapping(
 	ctx context.Context,
 	ctx context.Context,
-	logger common.Logger) (PortMappingTypes, error) {
+	logger common.Logger) (PortMappingTypes, *PortMappingProbe, error) {
 
 
-	return nil, errors.Trace(errNotEnabled)
+	return nil, nil, errors.Trace(errNotEnabled)
 }
 }
 
 
 func discoverNATMapping(
 func discoverNATMapping(

+ 163 - 17
psiphon/common/inproxy/portmapper.go

@@ -24,8 +24,11 @@ package inproxy
 import (
 import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
+	"reflect"
+	"runtime/debug"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
+	"unsafe"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -68,10 +71,23 @@ type portMapper struct {
 
 
 // newPortMapper initializes a new port mapper, configured to map to the
 // newPortMapper initializes a new port mapper, configured to map to the
 // specified localPort. newPortMapper does not initiate any network
 // specified localPort. newPortMapper does not initiate any network
-// operations (it's safe to call when DisablePortMapping is set).
+// operations.
+//
+// newPortMapper requires a PortMappingProbe initialized by probePortMapping,
+// as the underlying portmapper.Client.GetCachedMappingOrStartCreatingOne
+// requires data populated by Client.Probe, such as UPnP service
+// information.
+//
+// Rather that run a full Client.Probe per port mapping, the service data from
+// one probe run is reused.
 func newPortMapper(
 func newPortMapper(
 	logger common.Logger,
 	logger common.Logger,
-	localPort int) *portMapper {
+	probe *PortMappingProbe,
+	localPort int) (*portMapper, error) {
+
+	if probe == nil {
+		return nil, errors.TraceNew("missing probe")
+	}
 
 
 	portMappingLogger := func(format string, args ...any) {
 	portMappingLogger := func(format string, args ...any) {
 		logger.WithTrace().Info(
 		logger.WithTrace().Info(
@@ -93,6 +109,9 @@ func newPortMapper(
 	// the p.client reference within callback will be valid.
 	// the p.client reference within callback will be valid.
 
 
 	client := portmapper.NewClient(portMappingLogger, nil, nil, nil, func() {
 	client := portmapper.NewClient(portMappingLogger, nil, nil, nil, func() {
+		if !p.client.HaveMapping() {
+			return
+		}
 		p.havePortMappingOnce.Do(func() {
 		p.havePortMappingOnce.Do(func() {
 			address, ok := p.client.GetCachedMappingOrStartCreatingOne()
 			address, ok := p.client.GetCachedMappingOrStartCreatingOne()
 			if ok {
 			if ok {
@@ -116,12 +135,114 @@ func newPortMapper(
 
 
 	p.client.SetLocalPort(uint16(localPort))
 	p.client.SetLocalPort(uint16(localPort))
 
 
-	return p
+	// Copy the port mapping service data from the input probe.
+	err := p.cloneProbe(probe)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return p, nil
+}
+
+var portmapperDependencyVersionCheck bool
+
+func init() {
+	buildInfo, ok := debug.ReadBuildInfo()
+	if !ok {
+		return
+	}
+	for _, dep := range buildInfo.Deps {
+		if dep.Path == "tailscale.com" && dep.Version == "v1.58.2" {
+			portmapperDependencyVersionCheck = true
+			return
+		}
+	}
+}
+
+// cloneProbe copies the port mapping service data gather by Client.Probe from
+// the input probe client.
+func (p *portMapper) cloneProbe(probe *PortMappingProbe) error {
+
+	// The required portmapper.Client fields are not exported by
+	// tailscale/net/portmapper, so unsafe reflection is used to copy the
+	// values. A simple portmapper.Client struct copy can't be performed as
+	// the struct contain a sync.Mutex field.
+	//
+	// The following is assumed, based on the pinned dependency version:
+	//
+	// - portmapper.Client.Probe is synchronous, so once probe.client.Probe is
+	//   complete, it's safe to read its fields
+	//
+	// - portmapping.Probe does not create a cached mapping.
+	//
+	// - Only Probe populates the copied fields and
+	//   portmapper.Client.GetCachedMappingOrStartCreatingOne merely reads
+	//   them (or clears them, in invalidateMappingsLocked)
+	//
+	// We further assume that the caller synchronizes access to the input
+	// probe, so the probe is idle when cloned
+	// (see Proxy.networkDiscoveryMutex).
+	//
+	// An explicit dependency version pin check is made since potential logic
+	// changes in future versions of the dependency may break the above
+	// assumptions while the reflect operation might still succeed.
+	//
+	// TODO: fork the dependency to add internal support for shared probe
+	// state, trim additional tailscale dependencies, use Psiphon's custom
+	// dialer, and remove globals (see clientmetric.Metrics below).
+
+	if !portmapperDependencyVersionCheck {
+		return errors.TraceNew("dependency version check failed")
+	}
+
+	src := reflect.ValueOf(probe.client).Elem()
+	dst := reflect.ValueOf(p.client).Elem()
+
+	shallowCloneField := func(name string) error {
+		srcField := src.FieldByName(name)
+		dstField := dst.FieldByName(name)
+		// Bypass "reflect: reflect.Value.Set using value obtained using
+		// unexported field" restriction.
+		srcField = reflect.NewAt(
+			srcField.Type(), unsafe.Pointer(srcField.UnsafeAddr())).Elem()
+		dstField = reflect.NewAt(
+			dstField.Type(), unsafe.Pointer(dstField.UnsafeAddr())).Elem()
+		if !srcField.CanSet() || !dstField.CanSet() {
+			return errors.Tracef("%s: cannot set field", name)
+		}
+		dstField.Set(srcField)
+		return nil
+	}
+
+	// As of the pinned dependency version,
+	// portmapper.invalidateMappingsLocked sets uPnPMetas to nil, but doesn't
+	// write to the original slice elements, so a shallow copy is sufficient.
+
+	for _, fieldName := range []string{
+		"lastMyIP",
+		"lastGW",
+		"lastProbe",
+		"pmpPubIP",
+		"pmpPubIPTime",
+		"pmpLastEpoch",
+		"pcpSawTime",
+		"pcpLastEpoch",
+		"uPnPSawTime",
+		"uPnPMetas",
+	} {
+		err := shallowCloneField(fieldName)
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	return nil
 }
 }
 
 
 // start initiates the port mapping attempt.
 // start initiates the port mapping attempt.
 func (p *portMapper) start() {
 func (p *portMapper) start() {
 	p.portMappingLogger("started")
 	p.portMappingLogger("started")
+	// There is no cached mapping at this point.
 	_, _ = p.client.GetCachedMappingOrStartCreatingOne()
 	_, _ = p.client.GetCachedMappingOrStartCreatingOne()
 }
 }
 
 
@@ -133,6 +254,11 @@ func (p *portMapper) portMappingExternalAddress() <-chan string {
 
 
 // close releases the port mapping
 // close releases the port mapping
 func (p *portMapper) close() error {
 func (p *portMapper) close() error {
+
+	// TODO: it's not clear whether a concurrent portmapper.Client.createOrGetMapping,
+	// in progress at the time of the portmapper.Client call, will dispose of
+	// any created mapping if it completes after Close.
+
 	err := p.client.Close()
 	err := p.client.Close()
 	p.portMappingLogger("closed")
 	p.portMappingLogger("closed")
 	return errors.Trace(err)
 	return errors.Trace(err)
@@ -147,18 +273,24 @@ func formatPortMappingLog(format string, args ...any) string {
 	return fmt.Sprintf(format, args...)
 	return fmt.Sprintf(format, args...)
 }
 }
 
 
+// PortMappingProbe records information about the port mapping services found
+// in a port mapping service probe.
+type PortMappingProbe struct {
+	client *portmapper.Client
+}
+
 // probePortMapping discovers and reports which port mapping protocols are
 // probePortMapping discovers and reports which port mapping protocols are
-// supported on this network. probePortMapping does not establish a port mapping.
+// supported on this network. probePortMapping does not establish a port
+// mapping. probePortMapping caches a PortMappingProbe for use in subsequent
+// port mapping establishment.
 //
 //
-// It is intended that in-proxies amake a blocking call to probePortMapping on
-// start up (and after a network change) in order to report fresh port
-// mapping type metrics, for matching optimization in the ProxyAnnounce
-// request. Clients don't incur the delay of a probe call -- which produces
-// no port mapping -- and instead opportunistically grab port mapping type
-// metrics via getRespondingPortMappingTypes.
+// It is intended that in-proxy proxies make a blocking call to
+// probePortMapping on start up (and after a network change) in order to
+// report fresh port mapping type metrics, for matching optimization in the
+// ProxyAnnounce request.
 func probePortMapping(
 func probePortMapping(
 	ctx context.Context,
 	ctx context.Context,
-	logger common.Logger) (PortMappingTypes, error) {
+	logger common.Logger) (PortMappingTypes, *PortMappingProbe, error) {
 
 
 	portMappingLogger := func(format string, args ...any) {
 	portMappingLogger := func(format string, args ...any) {
 		logger.WithTrace().Info(
 		logger.WithTrace().Info(
@@ -166,11 +298,10 @@ func probePortMapping(
 	}
 	}
 
 
 	client := portmapper.NewClient(portMappingLogger, nil, nil, nil, nil)
 	client := portmapper.NewClient(portMappingLogger, nil, nil, nil, nil)
-	defer client.Close()
 
 
 	result, err := client.Probe(ctx)
 	result, err := client.Probe(ctx)
 	if err != nil {
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, nil, errors.Trace(err)
 	}
 	}
 
 
 	portMappingTypes := PortMappingTypes{}
 	portMappingTypes := PortMappingTypes{}
@@ -184,15 +315,30 @@ func probePortMapping(
 		portMappingTypes = append(portMappingTypes, PortMappingTypePCP)
 		portMappingTypes = append(portMappingTypes, PortMappingTypePCP)
 	}
 	}
 
 
-	// An empty lists means discovery is needed or the available port mappings
-	// are unknown; a list with None indicates that a probe returned no
-	// supported port mapping types.
+	var probe *PortMappingProbe
 
 
 	if len(portMappingTypes) == 0 {
 	if len(portMappingTypes) == 0 {
+
+		// An empty lists means discovery is needed or the available port mappings
+		// are unknown; a list with None indicates that a probe returned no
+		// supported port mapping types.
+
 		portMappingTypes = append(portMappingTypes, PortMappingTypeNone)
 		portMappingTypes = append(portMappingTypes, PortMappingTypeNone)
+
+	} else {
+
+		// Return a probe for use in subsequent port mappings only when
+		// services were found.
+		//
+		// It is not necessary to call PortMappingProbe.client.Close, as it is
+		// not holding open any actual mappings.
+
+		probe = &PortMappingProbe{
+			client: client,
+		}
 	}
 	}
 
 
-	return portMappingTypes, nil
+	return portMappingTypes, probe, nil
 }
 }
 
 
 var respondingPortMappingTypesMutex sync.Mutex
 var respondingPortMappingTypesMutex sync.Mutex

+ 8 - 7
psiphon/common/inproxy/proxy.go

@@ -451,6 +451,8 @@ func (p *Proxy) doNetworkDiscovery(
 	if p.networkDiscoveryRunOnce &&
 	if p.networkDiscoveryRunOnce &&
 		p.networkDiscoveryNetworkID == networkID {
 		p.networkDiscoveryNetworkID == networkID {
 		// Already ran discovery for this network.
 		// Already ran discovery for this network.
+		//
+		// TODO: periodically re-probe for port mapping services?
 		return
 		return
 	}
 	}
 
 
@@ -458,11 +460,11 @@ func (p *Proxy) doNetworkDiscovery(
 	// initPortMapper comment.
 	// initPortMapper comment.
 	initPortMapper(webRTCCoordinator)
 	initPortMapper(webRTCCoordinator)
 
 
-	// Gather local network NAT/port mapping metrics before sending any
-	// announce requests. NAT topology metrics are used by the Broker to
-	// optimize client and in-proxy matching. Unlike the client, we always
-	// perform this synchronous step here, since waiting doesn't necessarily
-	// block a client tunnel dial.
+	// Gather local network NAT/port mapping metrics and configuration before
+	// sending any announce requests. NAT topology metrics are used by the
+	// Broker to optimize client and in-proxy matching. Unlike the client, we
+	// always perform this synchronous step here, since waiting doesn't
+	// necessarily block a client tunnel dial.
 
 
 	waitGroup := new(sync.WaitGroup)
 	waitGroup := new(sync.WaitGroup)
 	waitGroup.Add(1)
 	waitGroup.Add(1)
@@ -472,7 +474,6 @@ func (p *Proxy) doNetworkDiscovery(
 		// NATDiscover may use cached NAT type/port mapping values from
 		// NATDiscover may use cached NAT type/port mapping values from
 		// DialParameters, based on the network ID. If discovery is not
 		// DialParameters, based on the network ID. If discovery is not
 		// successful, the proxy still proceeds to announce.
 		// successful, the proxy still proceeds to announce.
-
 		NATDiscover(
 		NATDiscover(
 			ctx,
 			ctx,
 			&NATDiscoverConfig{
 			&NATDiscoverConfig{
@@ -717,7 +718,7 @@ func (p *Proxy) proxyOneClient(
 	// included in SDPs.
 	// included in SDPs.
 	hasPersonalCompartmentIDs := len(personalCompartmentIDs) > 0
 	hasPersonalCompartmentIDs := len(personalCompartmentIDs) > 0
 
 
-	webRTCConn, SDP, sdpMetrics, webRTCErr := newWebRTCConnWithAnswer(
+	webRTCConn, SDP, sdpMetrics, webRTCErr := newWebRTCConnForAnswer(
 		webRTCAnswerCtx,
 		webRTCAnswerCtx,
 		&webRTCConfig{
 		&webRTCConfig{
 			Logger:                      p.config.Logger,
 			Logger:                      p.config.Logger,

+ 2 - 2
psiphon/common/inproxy/sdp_test.go

@@ -60,7 +60,7 @@ func runTestProcessSDP() error {
 	SetAllowBogonWebRTCConnections(true)
 	SetAllowBogonWebRTCConnections(true)
 	defer SetAllowBogonWebRTCConnections(false)
 	defer SetAllowBogonWebRTCConnections(false)
 
 
-	conn, webRTCSDP, metrics, err := newWebRTCConnWithOffer(
+	conn, webRTCSDP, metrics, err := newWebRTCConnForOffer(
 		context.Background(), config, hasPersonalCompartmentIDs)
 		context.Background(), config, hasPersonalCompartmentIDs)
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
@@ -194,7 +194,7 @@ func runTestProcessSDP() error {
 	allowPrivateIPAddressCandidates = true
 	allowPrivateIPAddressCandidates = true
 	filterPrivateIPAddressCandidates = true
 	filterPrivateIPAddressCandidates = true
 
 
-	conn, webRTCSDP, metrics, err = newWebRTCConnWithOffer(
+	conn, webRTCSDP, metrics, err = newWebRTCConnForOffer(
 		context.Background(), config, hasPersonalCompartmentIDs)
 		context.Background(), config, hasPersonalCompartmentIDs)
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)

+ 91 - 28
psiphon/common/inproxy/webrtc.go

@@ -51,6 +51,8 @@ import (
 )
 )
 
 
 const (
 const (
+	portMappingAwaitTimeout = 2 * time.Second
+
 	dataChannelAwaitTimeout                      = 20 * time.Second
 	dataChannelAwaitTimeout                      = 20 * time.Second
 	dataChannelBufferedAmountLowThreshold uint64 = 512 * 1024
 	dataChannelBufferedAmountLowThreshold uint64 = 512 * 1024
 	dataChannelMaxBufferedAmount          uint64 = 1024 * 1024
 	dataChannelMaxBufferedAmount          uint64 = 1024 * 1024
@@ -151,7 +153,7 @@ type webRTCConfig struct {
 // answer SDP received in response, call SetRemoteSDP with the answer SDP and
 // answer SDP received in response, call SetRemoteSDP with the answer SDP and
 // then call AwaitInitialDataChannel to await the eventual WebRTC connection
 // then call AwaitInitialDataChannel to await the eventual WebRTC connection
 // establishment.
 // establishment.
-func newWebRTCConnWithOffer(
+func newWebRTCConnForOffer(
 	ctx context.Context,
 	ctx context.Context,
 	config *webRTCConfig,
 	config *webRTCConfig,
 	hasPersonalCompartmentIDs bool) (
 	hasPersonalCompartmentIDs bool) (
@@ -169,7 +171,7 @@ func newWebRTCConnWithOffer(
 // that provided an offer SDP. An answer SDP is returned to be sent to the
 // that provided an offer SDP. An answer SDP is returned to be sent to the
 // peer. After the answer SDP is forwarded, call AwaitInitialDataChannel to
 // peer. After the answer SDP is forwarded, call AwaitInitialDataChannel to
 // await the eventual WebRTC connection establishment.
 // await the eventual WebRTC connection establishment.
-func newWebRTCConnWithAnswer(
+func newWebRTCConnForAnswer(
 	ctx context.Context,
 	ctx context.Context,
 	config *webRTCConfig,
 	config *webRTCConfig,
 	peerSDP WebRTCSessionDescription,
 	peerSDP WebRTCSessionDescription,
@@ -461,20 +463,34 @@ func newWebRTCConn(
 	disableInbound := config.WebRTCDialCoordinator.DisableInboundForMobileNetworks() &&
 	disableInbound := config.WebRTCDialCoordinator.DisableInboundForMobileNetworks() &&
 		config.WebRTCDialCoordinator.NetworkType() == NetworkTypeMobile
 		config.WebRTCDialCoordinator.NetworkType() == NetworkTypeMobile
 
 
-	// Try to establish a port mapping (UPnP-IGD, PCP, or NAT-PMP). The port
-	// mapper will attempt to identify the local gateway and query various
-	// port mapping protocols. portMapper.start launches this process and
-	// does not block. Port mappings are not part of the WebRTC standard, or
-	// supported by pion/webrtc. Instead, if a port mapping is established,
-	// it's edited into the SDP as a new host-type ICE candidate.
+	// Try to establish a port mapping (UPnP-IGD, PCP, or NAT-PMP), using port
+	// mapping services previously found and recorded in PortMappingProbe.
+	// Note that portMapper may perform additional probes. portMapper.start
+	// launches the process of creating a new port mapping and does not
+	// block. Port mappings are not part of the WebRTC standard, or supported
+	// by pion/webrtc. Instead, if a port mapping is established, it's edited
+	// into the SDP as a new host-type ICE candidate.
 
 
-	localPort := udpConn.LocalAddr().(*net.UDPAddr).Port
-	portMapper := newPortMapper(config.Logger, localPort)
+	portMappingProbe := config.WebRTCDialCoordinator.PortMappingProbe()
 
 
-	doPortMapping := !disableInbound && !config.WebRTCDialCoordinator.DisablePortMapping()
+	doPortMapping := !disableInbound &&
+		!config.WebRTCDialCoordinator.DisablePortMapping() &&
+		portMappingProbe != nil
 
 
+	var portMapper *portMapper
 	if doPortMapping {
 	if doPortMapping {
-		portMapper.start()
+		localPort := udpConn.LocalAddr().(*net.UDPAddr).Port
+		portMapper, err = newPortMapper(config.Logger, portMappingProbe, localPort)
+		if err != nil {
+			config.Logger.WithTraceFields(common.LogFields{
+				"error": err,
+			}).Warning("newPortMapper failed")
+			// Continue without port mapper
+		} else {
+			portMapper.start()
+			// On early return, portMapper will be closed by the following
+			// deferred conn.Close.
+		}
 	}
 	}
 
 
 	// Select a STUN server for ICE hole punching. The STUN server to be used
 	// Select a STUN server for ICE hole punching. The STUN server to be used
@@ -688,27 +704,74 @@ func newWebRTCConn(
 	iceCompleted := false
 	iceCompleted := false
 	portMappingExternalAddr := ""
 	portMappingExternalAddr := ""
 
 
-	select {
-	case <-iceComplete:
-		iceCompleted = true
+	if portMapper == nil {
 
 
-	case portMappingExternalAddr = <-portMapper.portMappingExternalAddress():
+		select {
+		case <-iceComplete:
+			iceCompleted = true
+		case <-ctx.Done():
+			return nil, nil, nil, errors.Trace(ctx.Err())
+		}
+
+	} else {
+
+		select {
+		case <-iceComplete:
+			iceCompleted = true
+		case portMappingExternalAddr = <-portMapper.portMappingExternalAddress():
+		case <-ctx.Done():
+			return nil, nil, nil, errors.Trace(ctx.Err())
+		}
 
 
-		// Set responding port mapping types for metrics.
+		// When STUN is skipped and a port mapping is expected to be
+		// available, await a port mapping for a short period. In this
+		// scenario, pion ICE gathering may complete first, since it's only
+		// gathering local host candidates.
+		//
+		// It remains possible that these local candidates are sufficient, if
+		// they are public IPs or private IPs on the same LAN as the peer in
+		// the case of personal pairing. For that reason, the await timeout
+		// should be no more than a couple of seconds.
 		//
 		//
-		// Limitation: if there are multiple responding protocol types, it's
-		// not known here which was used for this dial.
-		config.WebRTCDialCoordinator.SetPortMappingTypes(
-			getRespondingPortMappingTypes(config.WebRTCDialCoordinator.NetworkID()))
+		// TODO: also await port mappings when doSTUN, in case there are no
+		// STUN candidates; see hasServerReflexive check below; as it stands,
+		// in this case, it's more likely that port mapping won the previous
+		// select race.
+
+		if iceCompleted && portMappingExternalAddr == "" && !doSTUN && doPortMapping {
+
+			timer := time.NewTimer(
+				common.ValueOrDefault(
+					config.WebRTCDialCoordinator.WebRTCAwaitPortMappingTimeout(),
+					portMappingAwaitTimeout))
+			defer timer.Stop()
+
+			select {
+			case portMappingExternalAddr = <-portMapper.portMappingExternalAddress():
+			case <-timer.C:
+				// Continue without port mapping
+			case <-ctx.Done():
+				return nil, nil, nil, errors.Trace(ctx.Err())
+			}
+			timer.Stop()
+		}
 
 
-	case <-ctx.Done():
-		return nil, nil, nil, errors.Trace(ctx.Err())
-	}
+		if portMapper != nil && portMappingExternalAddr == "" {
 
 
-	// Release any port mapping resources when not using it.
-	if portMapper != nil && portMappingExternalAddr == "" {
-		portMapper.close()
-		conn.portMapper = nil
+			// Release any port mapping resources when not using it.
+			portMapper.close()
+			conn.portMapper = nil
+
+		} else if portMappingExternalAddr != "" {
+
+			// Update responding port mapping types for metrics.
+			//
+			// Limitation: if there are multiple responding protocol types, it's
+			// not known here which was used for this dial.
+			config.WebRTCDialCoordinator.SetPortMappingTypes(
+				getRespondingPortMappingTypes(config.WebRTCDialCoordinator.NetworkID()))
+
+		}
 	}
 	}
 
 
 	config.Logger.WithTraceFields(common.LogFields{
 	config.Logger.WithTraceFields(common.LogFields{

+ 2 - 0
psiphon/common/parameters/parameters.go

@@ -437,6 +437,7 @@ const (
 	InproxyProxyDiscoverNATTimeout                     = "InproxyProxyDiscoverNATTimeout"
 	InproxyProxyDiscoverNATTimeout                     = "InproxyProxyDiscoverNATTimeout"
 	InproxyClientDiscoverNATTimeout                    = "InproxyClientDiscoverNATTimeout"
 	InproxyClientDiscoverNATTimeout                    = "InproxyClientDiscoverNATTimeout"
 	InproxyWebRTCAnswerTimeout                         = "InproxyWebRTCAnswerTimeout"
 	InproxyWebRTCAnswerTimeout                         = "InproxyWebRTCAnswerTimeout"
+	InproxyWebRTCAwaitPortMappingTimeout               = "InproxyWebRTCAwaitPortMappingTimeout"
 	InproxyProxyWebRTCAwaitDataChannelTimeout          = "InproxyProxyWebRTCAwaitDataChannelTimeout"
 	InproxyProxyWebRTCAwaitDataChannelTimeout          = "InproxyProxyWebRTCAwaitDataChannelTimeout"
 	InproxyClientWebRTCAwaitDataChannelTimeout         = "InproxyClientWebRTCAwaitDataChannelTimeout"
 	InproxyClientWebRTCAwaitDataChannelTimeout         = "InproxyClientWebRTCAwaitDataChannelTimeout"
 	InproxyProxyDestinationDialTimeout                 = "InproxyProxyDestinationDialTimeout"
 	InproxyProxyDestinationDialTimeout                 = "InproxyProxyDestinationDialTimeout"
@@ -945,6 +946,7 @@ var defaultParameters = map[string]struct {
 	InproxyProxyDiscoverNATTimeout:                     {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyDiscoverNATTimeout:                     {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyClientDiscoverNATTimeout:                    {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyClientDiscoverNATTimeout:                    {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyWebRTCAnswerTimeout:                         {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyWebRTCAnswerTimeout:                         {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
+	InproxyWebRTCAwaitPortMappingTimeout:               {value: 2 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyWebRTCAwaitDataChannelTimeout:          {value: 30 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyWebRTCAwaitDataChannelTimeout:          {value: 30 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyClientWebRTCAwaitDataChannelTimeout:         {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyClientWebRTCAwaitDataChannelTimeout:         {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyDestinationDialTimeout:                 {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyDestinationDialTimeout:                 {value: 20 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},

+ 1 - 1
psiphon/controller.go

@@ -2799,7 +2799,7 @@ func (controller *Controller) runInproxyProxy() {
 	allowProxy := p.Bool(parameters.InproxyAllowProxy)
 	allowProxy := p.Bool(parameters.InproxyAllowProxy)
 	p.Close()
 	p.Close()
 
 
-	// Running an unstream proxy is also an incompatible case.
+	// Running an upstream proxy is also an incompatible case.
 
 
 	useUpstreamProxy := controller.config.UseUpstreamProxy()
 	useUpstreamProxy := controller.config.UseUpstreamProxy()
 
 

+ 50 - 2
psiphon/inproxy.go

@@ -1667,6 +1667,7 @@ type InproxyWebRTCDialInstance struct {
 	disableIPv6ICECandidates        bool
 	disableIPv6ICECandidates        bool
 	discoverNATTimeout              time.Duration
 	discoverNATTimeout              time.Duration
 	webRTCAnswerTimeout             time.Duration
 	webRTCAnswerTimeout             time.Duration
+	webRTCAwaitPortMappingTimeout   time.Duration
 	awaitDataChannelTimeout         time.Duration
 	awaitDataChannelTimeout         time.Duration
 	proxyDestinationDialTimeout     time.Duration
 	proxyDestinationDialTimeout     time.Duration
 	proxyRelayInactivityTimeout     time.Duration
 	proxyRelayInactivityTimeout     time.Duration
@@ -1770,6 +1771,7 @@ func NewInproxyWebRTCDialInstance(
 		disableIPv6ICECandidates:        disableIPv6ICECandidates,
 		disableIPv6ICECandidates:        disableIPv6ICECandidates,
 		discoverNATTimeout:              discoverNATTimeout,
 		discoverNATTimeout:              discoverNATTimeout,
 		webRTCAnswerTimeout:             p.Duration(parameters.InproxyWebRTCAnswerTimeout),
 		webRTCAnswerTimeout:             p.Duration(parameters.InproxyWebRTCAnswerTimeout),
+		webRTCAwaitPortMappingTimeout:   p.Duration(parameters.InproxyWebRTCAwaitPortMappingTimeout),
 		awaitDataChannelTimeout:         awaitDataChannelTimeout,
 		awaitDataChannelTimeout:         awaitDataChannelTimeout,
 		proxyDestinationDialTimeout:     p.Duration(parameters.InproxyProxyDestinationDialTimeout),
 		proxyDestinationDialTimeout:     p.Duration(parameters.InproxyProxyDestinationDialTimeout),
 		proxyRelayInactivityTimeout:     p.Duration(parameters.InproxyProxyRelayInactivityTimeout),
 		proxyRelayInactivityTimeout:     p.Duration(parameters.InproxyProxyRelayInactivityTimeout),
@@ -1894,10 +1896,22 @@ func (w *InproxyWebRTCDialInstance) PortMappingTypes() inproxy.PortMappingTypes
 }
 }
 
 
 // Implements the inproxy.WebRTCDialCoordinator interface.
 // Implements the inproxy.WebRTCDialCoordinator interface.
-func (w *InproxyWebRTCDialInstance) SetPortMappingTypes(portMappingTypes inproxy.PortMappingTypes) {
+func (w *InproxyWebRTCDialInstance) SetPortMappingTypes(
+	portMappingTypes inproxy.PortMappingTypes) {
 	w.natStateManager.setPortMappingTypes(w.networkID, portMappingTypes)
 	w.natStateManager.setPortMappingTypes(w.networkID, portMappingTypes)
 }
 }
 
 
+// Implements the inproxy.WebRTCDialCoordinator interface.
+func (w *InproxyWebRTCDialInstance) PortMappingProbe() *inproxy.PortMappingProbe {
+	return w.natStateManager.getPortMappingProbe(w.networkID)
+}
+
+// Implements the inproxy.WebRTCDialCoordinator interface.
+func (w *InproxyWebRTCDialInstance) SetPortMappingProbe(
+	portMappingProbe *inproxy.PortMappingProbe) {
+	w.natStateManager.setPortMappingProbe(w.networkID, portMappingProbe)
+}
+
 // Implements the inproxy.WebRTCDialCoordinator interface.
 // Implements the inproxy.WebRTCDialCoordinator interface.
 func (w *InproxyWebRTCDialInstance) ResolveAddress(ctx context.Context, network, address string) (string, error) {
 func (w *InproxyWebRTCDialInstance) ResolveAddress(ctx context.Context, network, address string) (string, error) {
 
 
@@ -2061,6 +2075,11 @@ func (w *InproxyWebRTCDialInstance) WebRTCAnswerTimeout() time.Duration {
 	return w.webRTCAnswerTimeout
 	return w.webRTCAnswerTimeout
 }
 }
 
 
+// Implements the inproxy.WebRTCDialCoordinator interface.
+func (w *InproxyWebRTCDialInstance) WebRTCAwaitPortMappingTimeout() time.Duration {
+	return w.webRTCAwaitPortMappingTimeout
+}
+
 // Implements the inproxy.WebRTCDialCoordinator interface.
 // Implements the inproxy.WebRTCDialCoordinator interface.
 func (w *InproxyWebRTCDialInstance) WebRTCAwaitDataChannelTimeout() time.Duration {
 func (w *InproxyWebRTCDialInstance) WebRTCAwaitDataChannelTimeout() time.Duration {
 	return w.awaitDataChannelTimeout
 	return w.awaitDataChannelTimeout
@@ -2298,6 +2317,7 @@ type InproxyNATStateManager struct {
 	networkID        string
 	networkID        string
 	natType          inproxy.NATType
 	natType          inproxy.NATType
 	portMappingTypes inproxy.PortMappingTypes
 	portMappingTypes inproxy.PortMappingTypes
+	portMappingProbe *inproxy.PortMappingProbe
 }
 }
 
 
 // NewInproxyNATStateManager creates a new InproxyNATStateManager.
 // NewInproxyNATStateManager creates a new InproxyNATStateManager.
@@ -2374,7 +2394,8 @@ func (s *InproxyNATStateManager) getPortMappingTypes(
 }
 }
 
 
 func (s *InproxyNATStateManager) setPortMappingTypes(
 func (s *InproxyNATStateManager) setPortMappingTypes(
-	networkID string, portMappingTypes inproxy.PortMappingTypes) {
+	networkID string,
+	portMappingTypes inproxy.PortMappingTypes) {
 
 
 	s.mutex.Lock()
 	s.mutex.Lock()
 	defer s.mutex.Unlock()
 	defer s.mutex.Unlock()
@@ -2386,6 +2407,33 @@ func (s *InproxyNATStateManager) setPortMappingTypes(
 	s.portMappingTypes = portMappingTypes
 	s.portMappingTypes = portMappingTypes
 }
 }
 
 
+func (s *InproxyNATStateManager) getPortMappingProbe(
+	networkID string) *inproxy.PortMappingProbe {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	if s.networkID != networkID {
+		return nil
+	}
+
+	return s.portMappingProbe
+}
+
+func (s *InproxyNATStateManager) setPortMappingProbe(
+	networkID string,
+	portMappingProbe *inproxy.PortMappingProbe) {
+
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	if s.networkID != networkID {
+		return
+	}
+
+	s.portMappingProbe = portMappingProbe
+}
+
 // inproxyUDPConn is based on NewUDPConn and includes the write timeout
 // inproxyUDPConn is based on NewUDPConn and includes the write timeout
 // workaround from common.WriteTimeoutUDPConn.
 // workaround from common.WriteTimeoutUDPConn.
 //
 //