Browse Source

Merge pull request #713 from rod-hynes/master

Integrate common/networkid
Rod Hynes 1 year ago
parent
commit
016ba0bb43

+ 6 - 21
psiphon/common/networkid/networkid_windows.go

@@ -26,7 +26,6 @@ import (
 	"strings"
 	"sync"
 	"syscall"
-	"time"
 	"unsafe"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -236,9 +235,6 @@ var workThread struct {
 	init sync.Once
 	reqs chan (chan<- result)
 	err  error
-
-	cachedResult string
-	cacheExpiry  time.Time
 }
 
 // Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
@@ -247,19 +243,14 @@ var workThread struct {
 // a transitory Network ID may be returned that will change on the next call. The caller
 // may wish to delay responding to a new Network ID until the value is confirmed.
 func Get() (string, error) {
-	// It is not clear if the COM NetworkListManager calls are threadsafe. We're using them
-	// read-only and they're probably fine, but we're not sure. Additionally, our networkID
-	// retrieval code is somewhat slow: 3.5ms. This function gets called by each connection
-	// attempt (in the horse race, etc.), so this extra time might add ~10% to a such an
-	// attempt. The value is very unlikely to change in a short amount of time, so it seems
-	// like a good optimization to cache the result. We'll restrict our work to single
-	// thread to achieve both goals.
+
+	// It is not clear if the COM NetworkListManager calls are threadsafe.
+	// We're using them read-only and they're probably fine, but we're not
+	// sure. We'll restrict our work to single thread.
 	workThread.init.Do(func() {
 		workThread.reqs = make(chan (chan<- result))
 
 		go func() {
-			const resultCacheDuration = 500 * time.Millisecond
-
 			// Go can switch the execution of a goroutine from one OS thread to another
 			// at (almost) any time. This may or may not be risky to do for our win32
 			// (and especially COM) calls, so we're going to explicitly lock this goroutine
@@ -276,14 +267,8 @@ func Get() (string, error) {
 			defer windows.CoUninitialize()
 
 			for resCh := range workThread.reqs {
-				if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) {
-					resCh <- result{workThread.cachedResult, nil}
-				} else {
-					networkID, err := getNetworkID()
-					resCh <- result{networkID, err}
-					workThread.cachedResult = networkID
-					workThread.cacheExpiry = time.Now().Add(resultCacheDuration)
-				}
+				networkID, err := getNetworkID()
+				resCh <- result{networkID, err}
 			}
 		}()
 	})

+ 3 - 0
psiphon/common/parameters/parameters.go

@@ -477,6 +477,7 @@ const (
 	InproxyProxyOnBrokerClientFailedRetryPeriod        = "InproxyProxyOnBrokerClientFailedRetryPeriod"
 	InproxyProxyIncompatibleNetworkTypes               = "InproxyProxyIncompatibleNetworkTypes"
 	InproxyClientIncompatibleNetworkTypes              = "InproxyClientIncompatibleNetworkTypes"
+	NetworkIDCacheTTL                                  = "NetworkIDCacheTTL"
 
 	// Retired parameters
 
@@ -1017,6 +1018,8 @@ var defaultParameters = map[string]struct {
 	InproxyProxyOnBrokerClientFailedRetryPeriod:        {value: 30 * time.Second, minimum: time.Duration(0)},
 	InproxyProxyIncompatibleNetworkTypes:               {value: []string{}},
 	InproxyClientIncompatibleNetworkTypes:              {value: []string{}},
+
+	NetworkIDCacheTTL: {value: 500 * time.Millisecond, minimum: time.Duration(0)},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used

+ 104 - 15
psiphon/config.go

@@ -35,11 +35,13 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"time"
 	"unicode"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/inproxy"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/networkid"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/resolver"
@@ -318,7 +320,8 @@ type Config struct {
 
 	// NetworkID, when not blank, is used as the identifier for the host's
 	// current active network.
-	// NetworkID is ignored when NetworkIDGetter is set.
+	// NetworkID is ignored when NetworkIDGetter is set, or when
+	// common/networkid is enabled.
 	NetworkID string
 
 	// DisableTactics disables tactics operations including requests, payload
@@ -1060,9 +1063,10 @@ type Config struct {
 	InproxyProxyOnBrokerClientFailedRetryPeriodMilliseconds *int
 	InproxyProxyIncompatibleNetworkTypes                    []string
 	InproxyClientIncompatibleNetworkTypes                   []string
+	InproxySkipAwaitFullyConnected                          bool
+	InproxyEnableWebRTCDebugLogging                         bool
 
-	InproxySkipAwaitFullyConnected  bool
-	InproxyEnableWebRTCDebugLogging bool
+	NetworkIDCacheTTLMilliseconds *int
 
 	// params is the active parameters.Parameters with defaults, config values,
 	// and, optionally, tactics applied.
@@ -1079,7 +1083,7 @@ type Config struct {
 	authorizations     []string
 
 	deviceBinder    DeviceBinder
-	networkIDGetter NetworkIDGetter
+	networkIDGetter *cachingNetworkIDGetter
 
 	clientFeatures []string
 
@@ -1507,6 +1511,9 @@ func (config *Config) Commit(migrateFromLegacyFields bool) error {
 	// wrap config.DeviceBinder and config.NetworkIDGetter/NetworkID with
 	// loggers.
 	//
+	// The network ID getter is further wrapped with a cache (see
+	// cachingNetworkIDGetter doc).
+	//
 	// New variables are set to avoid mutating input config fields.
 	// Internally, code must use config.deviceBinder and
 	// config.networkIDGetter and not the input/exported fields.
@@ -1518,17 +1525,23 @@ func (config *Config) Commit(migrateFromLegacyFields bool) error {
 	networkIDGetter := config.NetworkIDGetter
 
 	if networkIDGetter == nil {
-		// Limitation: unlike NetworkIDGetter, which calls back to platform APIs
-		// this method of network identification is not dynamic and will not reflect
-		// network changes that occur while running.
-		if config.NetworkID != "" {
-			networkIDGetter = newStaticNetworkGetter(config.NetworkID)
+		if networkid.Enabled() {
+			networkIDGetter = newCommonNetworkIDGetter()
 		} else {
-			networkIDGetter = newStaticNetworkGetter("UNKNOWN")
+			// Limitation: unlike NetworkIDGetter, which calls back to platform APIs
+			// this method of network identification is not dynamic and will not reflect
+			// network changes that occur while running.
+			if config.NetworkID != "" {
+				networkIDGetter = newStaticNetworkIDGetter(config.NetworkID)
+			} else {
+				networkIDGetter = newStaticNetworkIDGetter(unknownNetworkID)
+			}
 		}
 	}
 
-	config.networkIDGetter = newLoggingNetworkIDGetter(networkIDGetter)
+	config.networkIDGetter = newCachingNetworkIDGetter(
+		config,
+		newLoggingNetworkIDGetter(networkIDGetter))
 
 	// Initialize config.clientFeatures, which adds feature names on top of
 	// those specified by the host application in config.ClientFeatures.
@@ -2760,6 +2773,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.InproxyClientIncompatibleNetworkTypes] = config.InproxyClientIncompatibleNetworkTypes
 	}
 
+	if config.NetworkIDCacheTTLMilliseconds != nil {
+		applyParameters[parameters.NetworkIDCacheTTL] = fmt.Sprintf("%dms", *config.NetworkIDCacheTTLMilliseconds)
+	}
+
 	// When adding new config dial parameters that may override tactics, also
 	// update setDialParametersHash.
 
@@ -3655,18 +3672,36 @@ func (d *loggingDeviceBinder) BindToDevice(fileDescriptor int) (string, error) {
 	return deviceInfo, err
 }
 
-type staticNetworkGetter struct {
+const unknownNetworkID = "UNKNOWN"
+
+type staticNetworkIDGetter struct {
 	networkID string
 }
 
-func newStaticNetworkGetter(networkID string) *staticNetworkGetter {
-	return &staticNetworkGetter{networkID: networkID}
+func newStaticNetworkIDGetter(networkID string) *staticNetworkIDGetter {
+	return &staticNetworkIDGetter{networkID: networkID}
 }
 
-func (n *staticNetworkGetter) GetNetworkID() string {
+func (n *staticNetworkIDGetter) GetNetworkID() string {
 	return n.networkID
 }
 
+type commonNetworkIDGetter struct {
+}
+
+func newCommonNetworkIDGetter() *commonNetworkIDGetter {
+	return &commonNetworkIDGetter{}
+}
+
+func (n *commonNetworkIDGetter) GetNetworkID() string {
+	networkID, err := networkid.Get()
+	if err != nil {
+		NoticeError("networkid.Get failed: %v", errors.Trace(err))
+		return unknownNetworkID
+	}
+	return networkID
+}
+
 type loggingNetworkIDGetter struct {
 	n NetworkIDGetter
 }
@@ -3694,6 +3729,60 @@ func (n *loggingNetworkIDGetter) GetNetworkID() string {
 	return networkID
 }
 
+// cachingNetworkIDGetter caches the GetNetworkID result from the underlying
+// network ID getter. The current GetNetworkID implementations take in the
+// range of 1-7ms (Android); 2-3ms (iOS); ~3.5ms (Windows) to execute, on
+// modern devices. To minimize delaying dials and other operations that start
+// with fetching the current network ID, the return values are cached for a
+// short time. On platforms that invoke NetworkChanged, the cache is flushed
+// immediately upon a network change.
+type cachingNetworkIDGetter struct {
+	config *Config
+	n      NetworkIDGetter
+
+	mutex           sync.Mutex
+	cachedNetworkID string
+	cacheExpiry     time.Time
+}
+
+func newCachingNetworkIDGetter(
+	config *Config, n NetworkIDGetter) *cachingNetworkIDGetter {
+
+	return &cachingNetworkIDGetter{
+		config: config,
+		n:      n,
+	}
+}
+
+func (n *cachingNetworkIDGetter) GetNetworkID() string {
+	n.mutex.Lock()
+	defer n.mutex.Unlock()
+
+	if n.cachedNetworkID != "" && n.cacheExpiry.After(time.Now()) {
+		return n.cachedNetworkID
+	}
+
+	networkID := n.n.GetNetworkID()
+
+	p := n.config.GetParameters().Get()
+	ttl := p.Duration(parameters.NetworkIDCacheTTL)
+
+	if ttl > 0 {
+		n.cachedNetworkID = networkID
+		n.cacheExpiry = time.Now().Add(ttl)
+	}
+
+	return networkID
+}
+
+func (n *cachingNetworkIDGetter) FlushCache() {
+	n.mutex.Lock()
+	defer n.mutex.Unlock()
+
+	n.cachedNetworkID = ""
+	n.cacheExpiry = time.Time{}
+}
+
 // migrationsFromLegacyNoticeFilePaths returns the file migrations which must be
 // performed to move notice files from legacy file paths, which were configured
 // with the legacy config fields HomepageNoticesFilename and

+ 8 - 3
psiphon/controller.go

@@ -462,20 +462,25 @@ func (controller *Controller) SetDynamicConfig(sponsorID string, authorizations
 func (controller *Controller) NetworkChanged() {
 
 	// Explicitly reset components that don't use the current network context.
+
 	controller.TerminateNextActiveTunnel()
+
 	if controller.inproxyProxyBrokerClientManager != nil {
 		controller.inproxyProxyBrokerClientManager.NetworkChanged()
 	}
 	controller.inproxyClientBrokerClientManager.NetworkChanged()
 
+	controller.config.networkIDGetter.FlushCache()
+
+	// Cancel the previous current network context, which will interrupt any
+	// operations using this context. Then create a new context for the new
+	// current network.
+
 	controller.currentNetworkMutex.Lock()
 	defer controller.currentNetworkMutex.Unlock()
 
-	// Cancel the previous current network context, which will interrupt any
-	// operations using this context.
 	controller.currentNetworkCancelFunc()
 
-	// Create a new context for the new current network.
 	controller.currentNetworkCtx, controller.currentNetworkCancelFunc =
 		context.WithCancel(context.Background())
 }

+ 1 - 0
psiphon/dialParameters_test.go

@@ -307,6 +307,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	dialParams.Succeeded()
 
 	testNetworkID = prng.HexString(8)
+	clientConfig.networkIDGetter.FlushCache()
 
 	dialParams, err = MakeDialParameters(
 		clientConfig, steeringIPCache, nil, nil, nil, canReplay, selectProtocol, serverEntries[0], nil, nil, false, 0, 0)

+ 3 - 3
psiphon/inproxy_test.go

@@ -192,7 +192,7 @@ func runInproxyBrokerDialParametersTest(t *testing.T) error {
 	previousBrokerClient := brokerClient
 	previousNetworkID := networkID
 	networkID = "NETWORK2"
-	config.networkIDGetter = newStaticNetworkGetter(networkID)
+	config.networkIDGetter = newCachingNetworkIDGetter(config, newStaticNetworkIDGetter(networkID))
 	config.SetResolver(resolver.NewResolver(&resolver.NetworkConfig{}, networkID))
 
 	brokerClient, brokerDialParams, err = manager.GetBrokerClient(networkID)
@@ -217,7 +217,7 @@ func runInproxyBrokerDialParametersTest(t *testing.T) error {
 	// Test: another replay after switch back to previous network ID
 
 	networkID = previousNetworkID
-	config.networkIDGetter = newStaticNetworkGetter(networkID)
+	config.networkIDGetter = newCachingNetworkIDGetter(config, newStaticNetworkIDGetter(networkID))
 
 	brokerClient, brokerDialParams, err = manager.GetBrokerClient(networkID)
 	if err != nil {
@@ -422,7 +422,7 @@ func runInproxyNATStateTest() error {
 	// Test: reset
 
 	networkID = "NETWORK2"
-	config.networkIDGetter = newStaticNetworkGetter(networkID)
+	config.networkIDGetter = newCachingNetworkIDGetter(config, newStaticNetworkIDGetter(networkID))
 
 	manager.reset()