Browse Source

rewrite using UDP conn check

Adam Pritchard 1 year ago
parent
commit
31e3f08746

+ 2 - 2
go.mod

@@ -51,6 +51,7 @@ require (
 	github.com/florianl/go-nfqueue v1.1.1-0.20200829120558-a2f196e98ab0
 	github.com/florianl/go-nfqueue v1.1.1-0.20200829120558-a2f196e98ab0
 	github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4
 	github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4
 	github.com/fxamacker/cbor/v2 v2.5.0
 	github.com/fxamacker/cbor/v2 v2.5.0
+	github.com/go-ole/go-ole v1.3.0
 	github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439
 	github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
 	github.com/google/gopacket v1.1.19
 	github.com/google/gopacket v1.1.19
@@ -85,6 +86,7 @@ require (
 	golang.org/x/term v0.19.0
 	golang.org/x/term v0.19.0
 	golang.org/x/time v0.5.0
 	golang.org/x/time v0.5.0
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
+	golang.zx2c4.com/wireguard/windows v0.5.3
 	tailscale.com v1.58.2
 	tailscale.com v1.58.2
 )
 )
 
 
@@ -102,7 +104,6 @@ require (
 	github.com/dchest/siphash v1.2.3 // indirect
 	github.com/dchest/siphash v1.2.3 // indirect
 	github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
 	github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
 	github.com/gaukas/godicttls v0.0.4 // indirect
 	github.com/gaukas/godicttls v0.0.4 // indirect
-	github.com/go-ole/go-ole v1.3.0 // indirect
 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
@@ -152,7 +153,6 @@ require (
 	golang.org/x/mod v0.14.0 // indirect
 	golang.org/x/mod v0.14.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 	golang.org/x/tools v0.15.0 // indirect
 	golang.org/x/tools v0.15.0 // indirect
-	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 	google.golang.org/protobuf v1.31.0 // indirect
 	google.golang.org/protobuf v1.31.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
 )

+ 142 - 263
psiphon/common/networkid/networkid_windows.go

@@ -19,266 +19,122 @@
 
 
 package networkid
 package networkid
 
 
-func Enabled() bool {
-	return true
-}
-
 import (
 import (
 	"fmt"
 	"fmt"
+	"net"
 	"net/netip"
 	"net/netip"
-	"slices"
+	"runtime"
 	"strings"
 	"strings"
+	"sync"
 	"syscall"
 	"syscall"
+	"time"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/go-ole/go-ole"
 	"github.com/go-ole/go-ole"
 	"golang.org/x/sys/windows"
 	"golang.org/x/sys/windows"
 	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
 	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
 	"tailscale.com/wgengine/winnet"
 	"tailscale.com/wgengine/winnet"
 )
 )
 
 
-/*
-Here are the values we want, to construct our "network type" value:
-- ID that uniquely identifies the currently connected network, aka "network ID". We want this to be
-  stable for the same network over time.
-- Internet connection type: wi-fi, wired, or mobile. (TODO: Bluetooth? USB?)
-- Whether the internet connection is being tunneled through a VPN.
-
-We will define "currently connected network" as those with the default routes.
-
-1. Get the interfaces associated with the default routes. There might be more than one interface;
-   this can happen if there's also a VPN. Prefer IPv4 interfaces.
-2. Get the IP addresses associated with each interface. We'll need these to map the interface to
-   their adapters. (Recall that interfaces are logical and adapters are physical or virtual (VPNs).)
-3. For each interface, get the "interface type". This contributes to determining connection type
-   and VPN status.
-4. When determining which interface to get what data from, consider the metric (which determines
-   routing priority).
-5. Use interface IP addresses to find the associated adapter. From the adapter, get the **network ID**
-   and adapter description (which we'll use to check for VPN connection).
-6. Using interface type and adapter description, determine **connection type** and **VPN status**.
-*/
-
-type defaultRouteInfo struct {
-	interfaceLUID winipcfg.LUID
-	metric        uint32
-	family        winipcfg.AddressFamily
-	ifType        winipcfg.IfType
-}
-
-type interfaceInfo struct {
-	luid           winipcfg.LUID
-	description    string
-	ifType         winipcfg.IfType
-	metric         uint32
-	addresses      []netip.Addr
-	networkID      string
-	connectionType string
-	isVPN          bool
+func Enabled() bool {
+	return true
 }
 }
 
 
-// Gets information about the default routes (i.e., 0.0.0.0/0 for IPv4, ::/0 for IPv6),
-// for the given address family in metric order.
-func getDefaultRoutes(family winipcfg.AddressFamily) ([]defaultRouteInfo, error) {
-	adaptersAddrs, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeAllInterfaces)
-	if err != nil {
-		return nil, fmt.Errorf("GetAdaptersAddresses: %w", err)
-	}
-	if len(adaptersAddrs) == 0 {
-		return nil, fmt.Errorf("no adapters found")
+// Get address associated with the default interface.
+func getDefaultLocalAddr() (net.IP, error) {
+	// Note that this function has no Windows-specific code and could be used elsewhere.
+
+	// This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
+	// The basic idea is that we initialize a UDP connection and see what local
+	// address the system decides to use.
+	// TODO: Use common test IP addresses in that function and this.
+
+	// We'll prefer IPv4 and check it first (both might be available)
+	ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478"))
+	ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr)
+	if ipv4Err == nil {
+		ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP
+		ipv4UDPConn.Close()
+		return ip, nil
 	}
 	}
 
 
-	ipForwardTable, err := winipcfg.GetIPForwardTable2(family)
-	if err != nil {
-		return nil, fmt.Errorf("GetIPForwardTable2: %w", err)
+	ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478"))
+	ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr)
+	if ipv6Err == nil {
+		ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP
+		ipv6UDPConn.Close()
+		return ip, nil
 	}
 	}
 
 
-	var defaultRoutes []defaultRouteInfo
-	for _, route := range ipForwardTable {
-		if route.DestinationPrefix.PrefixLength != 0 ||
-			(route.DestinationPrefix.RawPrefix.Family != windows.AF_INET &&
-				route.DestinationPrefix.RawPrefix.Family != windows.AF_INET6) {
-			// Not a default route.
-			continue
-		}
-
-		var adapterAddrs *winipcfg.IPAdapterAddresses
-		for _, iface := range adaptersAddrs {
-			if iface.LUID == route.InterfaceLUID {
-				// Found the interface for this route.
-				adapterAddrs = iface
-				break
-			}
-		}
-		if adapterAddrs == nil {
-			// No adapter found for this route.
-			continue
-		}
-
-		// Don't add duplicates
-		dup := slices.ContainsFunc(defaultRoutes, func(dr defaultRouteInfo) bool {
-			return dr.interfaceLUID == route.InterfaceLUID
-		})
-		if dup {
-			continue
-		}
-
-		// Microsoft docs say:
-		//
-		// "The actual route metric used to compute the route preferences for IPv4 is the
-		// summation of the route metric offset specified in the Metric member of the
-		// MIB_IPFORWARD_ROW2 structure and the interface metric specified in this member
-		// for IPv4"
-		metric := route.Metric
-		switch family {
-		case windows.AF_INET:
-			metric += adapterAddrs.Ipv4Metric
-		case windows.AF_INET6:
-			metric += adapterAddrs.Ipv6Metric
-		}
-
-		defaultRoutes = append(defaultRoutes, defaultRouteInfo{
-			family:        route.DestinationPrefix.RawPrefix.Family,
-			interfaceLUID: route.InterfaceLUID,
-			metric:        metric,
-			ifType:        adapterAddrs.IfType,
-		})
-	}
-
-	slices.SortFunc(defaultRoutes, func(a, b defaultRouteInfo) int {
-		if a.metric < b.metric {
-			return -1
-		} else if a.metric > b.metric {
-			return 1
-		}
-		return 0
-	})
-
-	return defaultRoutes, nil
+	return nil, errors.Trace(ipv4Err)
 }
 }
 
 
-// The default routes (0.0.0.0/0 for IPv4, ::/0 for IPv6) have one or more interfaces associated
-// with them; more than one generally means there's an active VPN connection affecting all traffic.
-// This function returns the interfaces associated with the default routes. If there IPv4 routes
-// and interfaces, only those will be returned; otherwise, IPv6 interfaces will be returned.
-func getDefaultRouteInterfaces() ([]interfaceInfo, error) {
-	var family winipcfg.AddressFamily = windows.AF_INET
-	routes, err := getDefaultRoutes(family)
+// Given the IP of a local interface, get that interface info.
+func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) {
+	// Note that this function has no Windows-specific code and could be used elsewhere.
+
+	ifaces, err := net.Interfaces()
 	if err != nil {
 	if err != nil {
-		return nil, fmt.Errorf("getDefaultRoutes(ipv4): %w", err)
+		return nil, errors.Trace(err)
 	}
 	}
 
 
-	if len(routes) == 0 {
-		// No IPv4 default routes, try IPv6.
-		family = windows.AF_INET6
-		routes, err = getDefaultRoutes(family)
+	for _, iface := range ifaces {
+		addrs, err := iface.Addrs()
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("getDefaultRoutes(ipv6): %w", err)
+			return nil, errors.Trace(err)
 		}
 		}
-	}
 
 
-	if len(routes) == 0 {
-		// TODO: return an error or an empty slice?
-		return nil, fmt.Errorf("no default routes found")
-	}
-
-	// Now we have the default routes, in metric order
-
-	unicastIPAddresses, err := winipcfg.GetUnicastIPAddressTable(family)
-	if err != nil {
-		return nil, fmt.Errorf("GetAdaptersAddresses: %w", err)
-	}
-
-	interfaces := make([]interfaceInfo, 0, len(routes))
-	for _, route := range routes {
-		ifaceInfo := interfaceInfo{
-			luid:   route.interfaceLUID,
-			metric: route.metric,
-			ifType: route.ifType,
-		}
+		for _, addr := range addrs {
+			addrIP, _, err := net.ParseCIDR(addr.String())
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
 
 
-		for _, addr := range unicastIPAddresses {
-			if addr.InterfaceLUID == route.interfaceLUID {
-				ifaceInfo.addresses = append(ifaceInfo.addresses, addr.Address.Addr())
+			if addrIP.Equal(ip) {
+				return &iface, nil
 			}
 			}
 		}
 		}
-
-		interfaces = append(interfaces, ifaceInfo)
 	}
 	}
 
 
-	return interfaces, nil
+	return nil, errors.TraceNew("not found")
 }
 }
 
 
-// For the given set of interface IP addresses, get the network ID and description.
-func getNetworkInfo(interfaceIPAddrs []netip.Addr) (networkID, description string, err error) {
-	if len(interfaceIPAddrs) == 0 {
-		return "", "", fmt.Errorf("no addresses")
-	}
-
-	// The set of IP addresses seems to be the only reliable way to map an interface to an
-	// adapter. (Remember that adapters and interfaces are not identical and have a
-	// many-to-many relationship.)
-	// We need the adapter to get the network ID.
-
-	// Assume all provided addresses are the same family
-	var family winipcfg.AddressFamily = windows.AF_INET
-	if interfaceIPAddrs[0].Is6() {
-		family = windows.AF_INET6
+// Given the interface index, get info about the interface and its network.
+func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) {
+	luid, err := winipcfg.LUIDFromIndex(uint32(index))
+	if err != nil {
+		return "", "", 0, errors.Trace(err)
 	}
 	}
 
 
-	adapterAddrs, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeAllInterfaces)
+	ifrow, err := luid.Interface()
 	if err != nil {
 	if err != nil {
-		return "", "", fmt.Errorf("GetAdaptersAddresses: %w", err)
+		return "", "", 0, errors.Trace(err)
 	}
 	}
 
 
-	var adapterGUID string
-adapterAddrsLoop:
-	for _, adapterAddr := range adapterAddrs {
-		for unicast := adapterAddr.FirstUnicastAddress; unicast != nil; unicast = unicast.Next {
-			unicastNetIP, _ := netip.AddrFromSlice(unicast.Address.IP())
-			for _, addr := range interfaceIPAddrs {
-				if addr == unicastNetIP {
-					// IP matches; found the adapter.
-					// We are making the assumption that a single IP is enough to uniquely
-					// identify an adapter, rather than the whole set of them. This seems
-					// reasonable.
-					guid, err := adapterAddr.LUID.GUID()
-					if err != nil {
-						return "", "", fmt.Errorf("cannot convert adapter LUID to GUID")
-					}
-					if guid == nil {
-						return "", "", fmt.Errorf("adapter LUID has no GUID")
-					}
-
-					adapterGUID = guid.String()
-					description = adapterAddr.Description() + " " + adapterAddr.FriendlyName()
-					break adapterAddrsLoop
-				}
-			}
-		}
-	}
+	description = ifrow.Description() + " " + ifrow.Alias()
 
 
-	// We have the description and the adapter GUID, which we can use to get the network ID.
+	ifType = ifrow.Type
 
 
 	var c ole.Connection
 	var c ole.Connection
 	nlm, err := winnet.NewNetworkListManager(&c)
 	nlm, err := winnet.NewNetworkListManager(&c)
 	if err != nil {
 	if err != nil {
-		return "", "", fmt.Errorf("NewNetworkListManager: %w", err)
+		return "", "", 0, errors.Trace(err)
 	}
 	}
 	defer nlm.Release()
 	defer nlm.Release()
 
 
 	netConns, err := nlm.GetNetworkConnections()
 	netConns, err := nlm.GetNetworkConnections()
 	if err != nil {
 	if err != nil {
-		return "", "", fmt.Errorf("GetNetworkConnections: %w", err)
+		return "", "", 0, errors.Trace(err)
 	}
 	}
 	defer netConns.Release()
 	defer netConns.Release()
 
 
 	for _, nc := range netConns {
 	for _, nc := range netConns {
-		adapterID, err := nc.GetAdapterId()
+		ncAdapterID, err := nc.GetAdapterId()
 		if err != nil {
 		if err != nil {
-			return "", "", fmt.Errorf("GetAdapterId: %w", err)
+			return "", "", 0, errors.Trace(err)
 		}
 		}
-		if adapterID != adapterGUID {
+		if ncAdapterID != ifrow.InterfaceGUID.String() {
 			continue
 			continue
 		}
 		}
 
 
@@ -287,7 +143,7 @@ adapterAddrsLoop:
 
 
 		n, err := nc.GetNetwork()
 		n, err := nc.GetNetwork()
 		if err != nil {
 		if err != nil {
-			return "", "", fmt.Errorf("GetNetwork: %w", err)
+			return "", "", 0, errors.Trace(err)
 		}
 		}
 		defer n.Release()
 		defer n.Release()
 
 
@@ -297,20 +153,22 @@ adapterAddrsLoop:
 			uintptr(unsafe.Pointer(n)),
 			uintptr(unsafe.Pointer(n)),
 			uintptr(unsafe.Pointer(&guid)))
 			uintptr(unsafe.Pointer(&guid)))
 		if hr != 0 {
 		if hr != 0 {
-			return "", "", fmt.Errorf("GetNetworkId failed: %08x", hr)
+			return "", "", 0, fmt.Errorf("GetNetworkId failed: %08x", hr)
 		}
 		}
 
 
 		networkID = guid.String()
 		networkID = guid.String()
-		break
+		return networkID, description, ifType, nil
 	}
 	}
 
 
-	return networkID, description, nil
+	return "", "", 0, fmt.Errorf("network connection not found for interface %d", index)
 }
 }
 
 
 // Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
 // Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
-// interface type and description, and determine if it is a VPN.
-// If the correct connection type can not be determined, connectionType will be set to "UNKNOWN".
-func GetConnectionType(ifType winipcfg.IfType, description string) (connectionType string, isVPN bool) {
+// interface type and description.
+// If the correct connection type can not be determined, "UNKNOWN" will be returned.
+func getConnectionType(ifType winipcfg.IfType, description string) string {
+	var connectionType string
+
 	switch ifType {
 	switch ifType {
 	case winipcfg.IfTypeEthernetCSMACD:
 	case winipcfg.IfTypeEthernetCSMACD:
 		connectionType = "WIRED"
 		connectionType = "WIRED"
@@ -320,12 +178,11 @@ func GetConnectionType(ifType winipcfg.IfType, description string) (connectionTy
 		connectionType = "MOBILE"
 		connectionType = "MOBILE"
 	case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
 	case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
 		connectionType = "VPN"
 		connectionType = "VPN"
-		isVPN = true
 	default:
 	default:
 		connectionType = "UNKNOWN"
 		connectionType = "UNKNOWN"
 	}
 	}
 
 
-	if !isVPN {
+	if connectionType != "VPN" {
 		// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
 		// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
 		// back to checking for certain words in the description. This feels like a hack,
 		// back to checking for certain words in the description. This feels like a hack,
 		// but research suggests that it's the best we can do.
 		// but research suggests that it's the best we can do.
@@ -339,74 +196,96 @@ func GetConnectionType(ifType winipcfg.IfType, description string) (connectionTy
 			strings.Contains(description, "sstp") ||
 			strings.Contains(description, "sstp") ||
 			strings.Contains(description, "pptp") ||
 			strings.Contains(description, "pptp") ||
 			strings.Contains(description, "openvpn") {
 			strings.Contains(description, "openvpn") {
-			isVPN = true
+			connectionType = "VPN"
 		}
 		}
 	}
 	}
 
 
-	return connectionType, isVPN
+	return connectionType
 }
 }
 
 
-// Get information about the current active network connection(s).
-// networkID: Unique ID for the highest-priority (lowest metric) network connection.
-// connectionType: The connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the
-// active network connection(s). (Guaranteed to be non-empty on success, but may be "UNKNOWN".)
-// isVPN: True if the active network connection is a VPN.
-func getNetworkType() (networkID, connectionType string, isVPN bool, err error) {
-	// Initialize COM library
-	err = windows.CoInitializeEx(0, windows.COINIT_APARTMENTTHREADED)
+func getNetworkID() (string, error) {
+	localAddr, err := getDefaultLocalAddr()
 	if err != nil {
 	if err != nil {
-		return "", "", false, fmt.Errorf("CoInitializeEx: %w", err)
+		return "", errors.Trace(err)
 	}
 	}
-	defer windows.CoUninitialize()
 
 
-	interfaces, err := getDefaultRouteInterfaces()
+	iface, err := getInterfaceForLocalIP(localAddr)
 	if err != nil {
 	if err != nil {
-		return "", "", false, fmt.Errorf("getDefaultRouteInterfaces: %w", err)
+		return "", errors.Trace(err)
 	}
 	}
 
 
-	if len(interfaces) == 0 {
-		return "", "", false, fmt.Errorf("no default route interfaces found")
+	networkID, description, ifType, err := getInterfaceInfo(iface.Index)
+	if err != nil {
+		return "", errors.Trace(err)
 	}
 	}
 
 
-	// If we have a VPN but it is not the lowest metric, we will consider it to be not in use.
-	// The lowest-metric non-VPN interface with a valid connection type (WIRED, WIFI, MOBILE) will be the one we use for that.
+	connectionType := getConnectionType(ifType, description)
 
 
-	lowestMetric := true
-	for _, iface := range interfaces {
-		iface.networkID, iface.description, err = getNetworkInfo(iface.addresses)
-		if err != nil {
-			return "", "", false, fmt.Errorf("getNetworkInfo: %w", err)
-		}
+	compoundID := connectionType + "-" + strings.Trim(networkID, "{}")
 
 
-		iface.connectionType, iface.isVPN = GetConnectionType(iface.ifType, iface.description)
+	return compoundID, nil
+}
 
 
-		if lowestMetric {
-			networkID = iface.networkID
-			connectionType = iface.connectionType
-			isVPN = iface.isVPN
-		} else if connectionType == "" || connectionType == "UNKNOWN" || (connectionType == "VPN" && iface.connectionType != "UNKNOWN") {
-			// We got a better value for connection type
-			connectionType = iface.connectionType
-		}
-		// else this is a higher-metric interface, and the lower ones already told us what we need to know.
+type result struct {
+	networkID string
+	err       error
+}
 
 
-		lowestMetric = false
-	}
+var workThread struct {
+	init sync.Once
+	reqs chan (chan<- result)
+	err  error
 
 
-	return networkID, connectionType, isVPN, nil
+	cachedResult string
+	cacheExpiry  time.Time
 }
 }
 
 
 // Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
 // Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
-// In that string, "VPN" takes precendence over "WIRED", "WIFI", and "MOBILE"; in that
-// case connectionType can be used to determine the underlying network type. (It might be
-// desirable to put that value into feedback, say.)
-func Get() (compoundID, connectionType string, isVPN bool, err error) {
-	 networkID, connectionType, isVPN, err := getNetworkType()
-	 if err != nil {
-		 return "", "", false, fmt.Errorf("getNetworkType: %w", err)
-	 }
-
-	 compoundID = connectionType + "-" + strings.Trim(networkID, "{}")
-
-	 return compoundID, connectionType, isVPN, nil
+// This function is safe to call concurrently from multiple goroutines.
+func Get() (string, error) {
+	// It is not clear if the COM NetworkListManager calls are threadsafe. We're using them
+	// read-only and they're probably fine, but we're not sure. Additionally, our networkID
+	// retrieval code is somewhat slow: 3.5ms. This function gets called by each connection
+	// attempt (in the horse race, etc.), so this extra time might add ~10% to a such an
+	// attempt. The value is very unlikely to change in a short amount of time, so it seems
+	// like a good optimization to cache the result. We'll restrict our work to single
+	// thread to achieve both goals.
+	workThread.init.Do(func() {
+		workThread.reqs = make(chan (chan<- result))
+
+		go func() {
+			const resultCacheDuration = time.Second
+
+			runtime.LockOSThread()
+			defer runtime.UnlockOSThread()
+
+			if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil {
+				workThread.err = errors.Trace(err)
+				close(workThread.reqs)
+				return
+			}
+			defer windows.CoUninitialize()
+
+			for resCh := range workThread.reqs {
+				if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) {
+					resCh <- result{workThread.cachedResult, nil}
+				} else {
+					networkID, err := getNetworkID()
+					resCh <- result{networkID, err}
+					workThread.cachedResult = networkID
+					workThread.cacheExpiry = time.Now().Add(resultCacheDuration)
+				}
+			}
+		}()
+	})
+
+	resCh := make(chan result)
+	workThread.reqs <- resCh
+	res := <-resCh
+
+	if res.err != nil {
+		return "", errors.Trace(res.err)
+	}
+
+	return res.networkID, nil
 }
 }

+ 72 - 0
psiphon/common/networkid/networkid_windows_test.go

@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2024, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package networkid
+
+import (
+	"testing"
+	"time"
+)
+
+// prevent compiler optimization
+var networkID string
+var err error
+
+// This test doesn't show anything very useful, as it will mostly be getting cached results
+func BenchmarkGet(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		networkID, err = Get()
+		if err != nil {
+			b.Fatalf("error: %v", err)
+		}
+	}
+}
+
+func TestGet(t *testing.T) {
+	gotNetworkID, err := Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+
+	// Call again immediately to get a cached result
+	gotNetworkID, err = Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+
+	// Wait until the cached result expires, so we get another fresh value
+	time.Sleep(2 * time.Second)
+
+	gotNetworkID, err = Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+}

+ 1 - 0
vendor/modules.txt

@@ -688,6 +688,7 @@ tailscale.com/util/vizerror
 tailscale.com/util/winutil
 tailscale.com/util/winutil
 tailscale.com/version
 tailscale.com/version
 tailscale.com/version/distro
 tailscale.com/version/distro
+tailscale.com/wgengine/winnet
 # gitlab.com/yawning/obfs4.git => github.com/jmwample/obfs4 v0.0.0-20230725223418-2d2e5b4a16ba
 # gitlab.com/yawning/obfs4.git => github.com/jmwample/obfs4 v0.0.0-20230725223418-2d2e5b4a16ba
 # github.com/pion/dtls/v2 => ./replace/dtls
 # github.com/pion/dtls/v2 => ./replace/dtls
 # github.com/pion/ice/v2 => ./replace/ice
 # github.com/pion/ice/v2 => ./replace/ice