|
@@ -19,266 +19,122 @@
|
|
|
|
|
|
|
|
package networkid
|
|
package networkid
|
|
|
|
|
|
|
|
-func Enabled() bool {
|
|
|
|
|
- return true
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
import (
|
|
import (
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "net"
|
|
|
"net/netip"
|
|
"net/netip"
|
|
|
- "slices"
|
|
|
|
|
|
|
+ "runtime"
|
|
|
"strings"
|
|
"strings"
|
|
|
|
|
+ "sync"
|
|
|
"syscall"
|
|
"syscall"
|
|
|
|
|
+ "time"
|
|
|
"unsafe"
|
|
"unsafe"
|
|
|
|
|
|
|
|
|
|
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
|
|
|
"github.com/go-ole/go-ole"
|
|
"github.com/go-ole/go-ole"
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows"
|
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
|
"tailscale.com/wgengine/winnet"
|
|
"tailscale.com/wgengine/winnet"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-/*
|
|
|
|
|
-Here are the values we want, to construct our "network type" value:
|
|
|
|
|
-- ID that uniquely identifies the currently connected network, aka "network ID". We want this to be
|
|
|
|
|
- stable for the same network over time.
|
|
|
|
|
-- Internet connection type: wi-fi, wired, or mobile. (TODO: Bluetooth? USB?)
|
|
|
|
|
-- Whether the internet connection is being tunneled through a VPN.
|
|
|
|
|
-
|
|
|
|
|
-We will define "currently connected network" as those with the default routes.
|
|
|
|
|
-
|
|
|
|
|
-1. Get the interfaces associated with the default routes. There might be more than one interface;
|
|
|
|
|
- this can happen if there's also a VPN. Prefer IPv4 interfaces.
|
|
|
|
|
-2. Get the IP addresses associated with each interface. We'll need these to map the interface to
|
|
|
|
|
- their adapters. (Recall that interfaces are logical and adapters are physical or virtual (VPNs).)
|
|
|
|
|
-3. For each interface, get the "interface type". This contributes to determining connection type
|
|
|
|
|
- and VPN status.
|
|
|
|
|
-4. When determining which interface to get what data from, consider the metric (which determines
|
|
|
|
|
- routing priority).
|
|
|
|
|
-5. Use interface IP addresses to find the associated adapter. From the adapter, get the **network ID**
|
|
|
|
|
- and adapter description (which we'll use to check for VPN connection).
|
|
|
|
|
-6. Using interface type and adapter description, determine **connection type** and **VPN status**.
|
|
|
|
|
-*/
|
|
|
|
|
-
|
|
|
|
|
-type defaultRouteInfo struct {
|
|
|
|
|
- interfaceLUID winipcfg.LUID
|
|
|
|
|
- metric uint32
|
|
|
|
|
- family winipcfg.AddressFamily
|
|
|
|
|
- ifType winipcfg.IfType
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-type interfaceInfo struct {
|
|
|
|
|
- luid winipcfg.LUID
|
|
|
|
|
- description string
|
|
|
|
|
- ifType winipcfg.IfType
|
|
|
|
|
- metric uint32
|
|
|
|
|
- addresses []netip.Addr
|
|
|
|
|
- networkID string
|
|
|
|
|
- connectionType string
|
|
|
|
|
- isVPN bool
|
|
|
|
|
|
|
+func Enabled() bool {
|
|
|
|
|
+ return true
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Gets information about the default routes (i.e., 0.0.0.0/0 for IPv4, ::/0 for IPv6),
|
|
|
|
|
-// for the given address family in metric order.
|
|
|
|
|
-func getDefaultRoutes(family winipcfg.AddressFamily) ([]defaultRouteInfo, error) {
|
|
|
|
|
- adaptersAddrs, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeAllInterfaces)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, fmt.Errorf("GetAdaptersAddresses: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
- if len(adaptersAddrs) == 0 {
|
|
|
|
|
- return nil, fmt.Errorf("no adapters found")
|
|
|
|
|
|
|
+// Get address associated with the default interface.
|
|
|
|
|
+func getDefaultLocalAddr() (net.IP, error) {
|
|
|
|
|
+ // Note that this function has no Windows-specific code and could be used elsewhere.
|
|
|
|
|
+
|
|
|
|
|
+ // This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
|
|
|
|
|
+ // The basic idea is that we initialize a UDP connection and see what local
|
|
|
|
|
+ // address the system decides to use.
|
|
|
|
|
+ // TODO: Use common test IP addresses in that function and this.
|
|
|
|
|
+
|
|
|
|
|
+ // We'll prefer IPv4 and check it first (both might be available)
|
|
|
|
|
+ ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478"))
|
|
|
|
|
+ ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr)
|
|
|
|
|
+ if ipv4Err == nil {
|
|
|
|
|
+ ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP
|
|
|
|
|
+ ipv4UDPConn.Close()
|
|
|
|
|
+ return ip, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- ipForwardTable, err := winipcfg.GetIPForwardTable2(family)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, fmt.Errorf("GetIPForwardTable2: %w", err)
|
|
|
|
|
|
|
+ ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478"))
|
|
|
|
|
+ ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr)
|
|
|
|
|
+ if ipv6Err == nil {
|
|
|
|
|
+ ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP
|
|
|
|
|
+ ipv6UDPConn.Close()
|
|
|
|
|
+ return ip, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- var defaultRoutes []defaultRouteInfo
|
|
|
|
|
- for _, route := range ipForwardTable {
|
|
|
|
|
- if route.DestinationPrefix.PrefixLength != 0 ||
|
|
|
|
|
- (route.DestinationPrefix.RawPrefix.Family != windows.AF_INET &&
|
|
|
|
|
- route.DestinationPrefix.RawPrefix.Family != windows.AF_INET6) {
|
|
|
|
|
- // Not a default route.
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- var adapterAddrs *winipcfg.IPAdapterAddresses
|
|
|
|
|
- for _, iface := range adaptersAddrs {
|
|
|
|
|
- if iface.LUID == route.InterfaceLUID {
|
|
|
|
|
- // Found the interface for this route.
|
|
|
|
|
- adapterAddrs = iface
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- if adapterAddrs == nil {
|
|
|
|
|
- // No adapter found for this route.
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Don't add duplicates
|
|
|
|
|
- dup := slices.ContainsFunc(defaultRoutes, func(dr defaultRouteInfo) bool {
|
|
|
|
|
- return dr.interfaceLUID == route.InterfaceLUID
|
|
|
|
|
- })
|
|
|
|
|
- if dup {
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Microsoft docs say:
|
|
|
|
|
- //
|
|
|
|
|
- // "The actual route metric used to compute the route preferences for IPv4 is the
|
|
|
|
|
- // summation of the route metric offset specified in the Metric member of the
|
|
|
|
|
- // MIB_IPFORWARD_ROW2 structure and the interface metric specified in this member
|
|
|
|
|
- // for IPv4"
|
|
|
|
|
- metric := route.Metric
|
|
|
|
|
- switch family {
|
|
|
|
|
- case windows.AF_INET:
|
|
|
|
|
- metric += adapterAddrs.Ipv4Metric
|
|
|
|
|
- case windows.AF_INET6:
|
|
|
|
|
- metric += adapterAddrs.Ipv6Metric
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- defaultRoutes = append(defaultRoutes, defaultRouteInfo{
|
|
|
|
|
- family: route.DestinationPrefix.RawPrefix.Family,
|
|
|
|
|
- interfaceLUID: route.InterfaceLUID,
|
|
|
|
|
- metric: metric,
|
|
|
|
|
- ifType: adapterAddrs.IfType,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- slices.SortFunc(defaultRoutes, func(a, b defaultRouteInfo) int {
|
|
|
|
|
- if a.metric < b.metric {
|
|
|
|
|
- return -1
|
|
|
|
|
- } else if a.metric > b.metric {
|
|
|
|
|
- return 1
|
|
|
|
|
- }
|
|
|
|
|
- return 0
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- return defaultRoutes, nil
|
|
|
|
|
|
|
+ return nil, errors.Trace(ipv4Err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// The default routes (0.0.0.0/0 for IPv4, ::/0 for IPv6) have one or more interfaces associated
|
|
|
|
|
-// with them; more than one generally means there's an active VPN connection affecting all traffic.
|
|
|
|
|
-// This function returns the interfaces associated with the default routes. If there IPv4 routes
|
|
|
|
|
-// and interfaces, only those will be returned; otherwise, IPv6 interfaces will be returned.
|
|
|
|
|
-func getDefaultRouteInterfaces() ([]interfaceInfo, error) {
|
|
|
|
|
- var family winipcfg.AddressFamily = windows.AF_INET
|
|
|
|
|
- routes, err := getDefaultRoutes(family)
|
|
|
|
|
|
|
+// Given the IP of a local interface, get that interface info.
|
|
|
|
|
+func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) {
|
|
|
|
|
+ // Note that this function has no Windows-specific code and could be used elsewhere.
|
|
|
|
|
+
|
|
|
|
|
+ ifaces, err := net.Interfaces()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("getDefaultRoutes(ipv4): %w", err)
|
|
|
|
|
|
|
+ return nil, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if len(routes) == 0 {
|
|
|
|
|
- // No IPv4 default routes, try IPv6.
|
|
|
|
|
- family = windows.AF_INET6
|
|
|
|
|
- routes, err = getDefaultRoutes(family)
|
|
|
|
|
|
|
+ for _, iface := range ifaces {
|
|
|
|
|
+ addrs, err := iface.Addrs()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return nil, fmt.Errorf("getDefaultRoutes(ipv6): %w", err)
|
|
|
|
|
|
|
+ return nil, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- if len(routes) == 0 {
|
|
|
|
|
- // TODO: return an error or an empty slice?
|
|
|
|
|
- return nil, fmt.Errorf("no default routes found")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // Now we have the default routes, in metric order
|
|
|
|
|
-
|
|
|
|
|
- unicastIPAddresses, err := winipcfg.GetUnicastIPAddressTable(family)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, fmt.Errorf("GetAdaptersAddresses: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- interfaces := make([]interfaceInfo, 0, len(routes))
|
|
|
|
|
- for _, route := range routes {
|
|
|
|
|
- ifaceInfo := interfaceInfo{
|
|
|
|
|
- luid: route.interfaceLUID,
|
|
|
|
|
- metric: route.metric,
|
|
|
|
|
- ifType: route.ifType,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for _, addr := range addrs {
|
|
|
|
|
+ addrIP, _, err := net.ParseCIDR(addr.String())
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, errors.Trace(err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- for _, addr := range unicastIPAddresses {
|
|
|
|
|
- if addr.InterfaceLUID == route.interfaceLUID {
|
|
|
|
|
- ifaceInfo.addresses = append(ifaceInfo.addresses, addr.Address.Addr())
|
|
|
|
|
|
|
+ if addrIP.Equal(ip) {
|
|
|
|
|
+ return &iface, nil
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- interfaces = append(interfaces, ifaceInfo)
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return interfaces, nil
|
|
|
|
|
|
|
+ return nil, errors.TraceNew("not found")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// For the given set of interface IP addresses, get the network ID and description.
|
|
|
|
|
-func getNetworkInfo(interfaceIPAddrs []netip.Addr) (networkID, description string, err error) {
|
|
|
|
|
- if len(interfaceIPAddrs) == 0 {
|
|
|
|
|
- return "", "", fmt.Errorf("no addresses")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // The set of IP addresses seems to be the only reliable way to map an interface to an
|
|
|
|
|
- // adapter. (Remember that adapters and interfaces are not identical and have a
|
|
|
|
|
- // many-to-many relationship.)
|
|
|
|
|
- // We need the adapter to get the network ID.
|
|
|
|
|
-
|
|
|
|
|
- // Assume all provided addresses are the same family
|
|
|
|
|
- var family winipcfg.AddressFamily = windows.AF_INET
|
|
|
|
|
- if interfaceIPAddrs[0].Is6() {
|
|
|
|
|
- family = windows.AF_INET6
|
|
|
|
|
|
|
+// Given the interface index, get info about the interface and its network.
|
|
|
|
|
+func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) {
|
|
|
|
|
+ luid, err := winipcfg.LUIDFromIndex(uint32(index))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- adapterAddrs, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeAllInterfaces)
|
|
|
|
|
|
|
+ ifrow, err := luid.Interface()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", fmt.Errorf("GetAdaptersAddresses: %w", err)
|
|
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- var adapterGUID string
|
|
|
|
|
-adapterAddrsLoop:
|
|
|
|
|
- for _, adapterAddr := range adapterAddrs {
|
|
|
|
|
- for unicast := adapterAddr.FirstUnicastAddress; unicast != nil; unicast = unicast.Next {
|
|
|
|
|
- unicastNetIP, _ := netip.AddrFromSlice(unicast.Address.IP())
|
|
|
|
|
- for _, addr := range interfaceIPAddrs {
|
|
|
|
|
- if addr == unicastNetIP {
|
|
|
|
|
- // IP matches; found the adapter.
|
|
|
|
|
- // We are making the assumption that a single IP is enough to uniquely
|
|
|
|
|
- // identify an adapter, rather than the whole set of them. This seems
|
|
|
|
|
- // reasonable.
|
|
|
|
|
- guid, err := adapterAddr.LUID.GUID()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return "", "", fmt.Errorf("cannot convert adapter LUID to GUID")
|
|
|
|
|
- }
|
|
|
|
|
- if guid == nil {
|
|
|
|
|
- return "", "", fmt.Errorf("adapter LUID has no GUID")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- adapterGUID = guid.String()
|
|
|
|
|
- description = adapterAddr.Description() + " " + adapterAddr.FriendlyName()
|
|
|
|
|
- break adapterAddrsLoop
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ description = ifrow.Description() + " " + ifrow.Alias()
|
|
|
|
|
|
|
|
- // We have the description and the adapter GUID, which we can use to get the network ID.
|
|
|
|
|
|
|
+ ifType = ifrow.Type
|
|
|
|
|
|
|
|
var c ole.Connection
|
|
var c ole.Connection
|
|
|
nlm, err := winnet.NewNetworkListManager(&c)
|
|
nlm, err := winnet.NewNetworkListManager(&c)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", fmt.Errorf("NewNetworkListManager: %w", err)
|
|
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
defer nlm.Release()
|
|
defer nlm.Release()
|
|
|
|
|
|
|
|
netConns, err := nlm.GetNetworkConnections()
|
|
netConns, err := nlm.GetNetworkConnections()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", fmt.Errorf("GetNetworkConnections: %w", err)
|
|
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
defer netConns.Release()
|
|
defer netConns.Release()
|
|
|
|
|
|
|
|
for _, nc := range netConns {
|
|
for _, nc := range netConns {
|
|
|
- adapterID, err := nc.GetAdapterId()
|
|
|
|
|
|
|
+ ncAdapterID, err := nc.GetAdapterId()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", fmt.Errorf("GetAdapterId: %w", err)
|
|
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
- if adapterID != adapterGUID {
|
|
|
|
|
|
|
+ if ncAdapterID != ifrow.InterfaceGUID.String() {
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -287,7 +143,7 @@ adapterAddrsLoop:
|
|
|
|
|
|
|
|
n, err := nc.GetNetwork()
|
|
n, err := nc.GetNetwork()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", fmt.Errorf("GetNetwork: %w", err)
|
|
|
|
|
|
|
+ return "", "", 0, errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
defer n.Release()
|
|
defer n.Release()
|
|
|
|
|
|
|
@@ -297,20 +153,22 @@ adapterAddrsLoop:
|
|
|
uintptr(unsafe.Pointer(n)),
|
|
uintptr(unsafe.Pointer(n)),
|
|
|
uintptr(unsafe.Pointer(&guid)))
|
|
uintptr(unsafe.Pointer(&guid)))
|
|
|
if hr != 0 {
|
|
if hr != 0 {
|
|
|
- return "", "", fmt.Errorf("GetNetworkId failed: %08x", hr)
|
|
|
|
|
|
|
+ return "", "", 0, fmt.Errorf("GetNetworkId failed: %08x", hr)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
networkID = guid.String()
|
|
networkID = guid.String()
|
|
|
- break
|
|
|
|
|
|
|
+ return networkID, description, ifType, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return networkID, description, nil
|
|
|
|
|
|
|
+ return "", "", 0, fmt.Errorf("network connection not found for interface %d", index)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
|
|
// Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
|
|
|
-// interface type and description, and determine if it is a VPN.
|
|
|
|
|
-// If the correct connection type can not be determined, connectionType will be set to "UNKNOWN".
|
|
|
|
|
-func GetConnectionType(ifType winipcfg.IfType, description string) (connectionType string, isVPN bool) {
|
|
|
|
|
|
|
+// interface type and description.
|
|
|
|
|
+// If the correct connection type can not be determined, "UNKNOWN" will be returned.
|
|
|
|
|
+func getConnectionType(ifType winipcfg.IfType, description string) string {
|
|
|
|
|
+ var connectionType string
|
|
|
|
|
+
|
|
|
switch ifType {
|
|
switch ifType {
|
|
|
case winipcfg.IfTypeEthernetCSMACD:
|
|
case winipcfg.IfTypeEthernetCSMACD:
|
|
|
connectionType = "WIRED"
|
|
connectionType = "WIRED"
|
|
@@ -320,12 +178,11 @@ func GetConnectionType(ifType winipcfg.IfType, description string) (connectionTy
|
|
|
connectionType = "MOBILE"
|
|
connectionType = "MOBILE"
|
|
|
case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
|
|
case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
|
|
|
connectionType = "VPN"
|
|
connectionType = "VPN"
|
|
|
- isVPN = true
|
|
|
|
|
default:
|
|
default:
|
|
|
connectionType = "UNKNOWN"
|
|
connectionType = "UNKNOWN"
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if !isVPN {
|
|
|
|
|
|
|
+ if connectionType != "VPN" {
|
|
|
// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
|
|
// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
|
|
|
// back to checking for certain words in the description. This feels like a hack,
|
|
// back to checking for certain words in the description. This feels like a hack,
|
|
|
// but research suggests that it's the best we can do.
|
|
// but research suggests that it's the best we can do.
|
|
@@ -339,74 +196,96 @@ func GetConnectionType(ifType winipcfg.IfType, description string) (connectionTy
|
|
|
strings.Contains(description, "sstp") ||
|
|
strings.Contains(description, "sstp") ||
|
|
|
strings.Contains(description, "pptp") ||
|
|
strings.Contains(description, "pptp") ||
|
|
|
strings.Contains(description, "openvpn") {
|
|
strings.Contains(description, "openvpn") {
|
|
|
- isVPN = true
|
|
|
|
|
|
|
+ connectionType = "VPN"
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return connectionType, isVPN
|
|
|
|
|
|
|
+ return connectionType
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Get information about the current active network connection(s).
|
|
|
|
|
-// networkID: Unique ID for the highest-priority (lowest metric) network connection.
|
|
|
|
|
-// connectionType: The connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the
|
|
|
|
|
-// active network connection(s). (Guaranteed to be non-empty on success, but may be "UNKNOWN".)
|
|
|
|
|
-// isVPN: True if the active network connection is a VPN.
|
|
|
|
|
-func getNetworkType() (networkID, connectionType string, isVPN bool, err error) {
|
|
|
|
|
- // Initialize COM library
|
|
|
|
|
- err = windows.CoInitializeEx(0, windows.COINIT_APARTMENTTHREADED)
|
|
|
|
|
|
|
+func getNetworkID() (string, error) {
|
|
|
|
|
+ localAddr, err := getDefaultLocalAddr()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", false, fmt.Errorf("CoInitializeEx: %w", err)
|
|
|
|
|
|
|
+ return "", errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
- defer windows.CoUninitialize()
|
|
|
|
|
|
|
|
|
|
- interfaces, err := getDefaultRouteInterfaces()
|
|
|
|
|
|
|
+ iface, err := getInterfaceForLocalIP(localAddr)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return "", "", false, fmt.Errorf("getDefaultRouteInterfaces: %w", err)
|
|
|
|
|
|
|
+ return "", errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if len(interfaces) == 0 {
|
|
|
|
|
- return "", "", false, fmt.Errorf("no default route interfaces found")
|
|
|
|
|
|
|
+ networkID, description, ifType, err := getInterfaceInfo(iface.Index)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return "", errors.Trace(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // If we have a VPN but it is not the lowest metric, we will consider it to be not in use.
|
|
|
|
|
- // The lowest-metric non-VPN interface with a valid connection type (WIRED, WIFI, MOBILE) will be the one we use for that.
|
|
|
|
|
|
|
+ connectionType := getConnectionType(ifType, description)
|
|
|
|
|
|
|
|
- lowestMetric := true
|
|
|
|
|
- for _, iface := range interfaces {
|
|
|
|
|
- iface.networkID, iface.description, err = getNetworkInfo(iface.addresses)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return "", "", false, fmt.Errorf("getNetworkInfo: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ compoundID := connectionType + "-" + strings.Trim(networkID, "{}")
|
|
|
|
|
|
|
|
- iface.connectionType, iface.isVPN = GetConnectionType(iface.ifType, iface.description)
|
|
|
|
|
|
|
+ return compoundID, nil
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- if lowestMetric {
|
|
|
|
|
- networkID = iface.networkID
|
|
|
|
|
- connectionType = iface.connectionType
|
|
|
|
|
- isVPN = iface.isVPN
|
|
|
|
|
- } else if connectionType == "" || connectionType == "UNKNOWN" || (connectionType == "VPN" && iface.connectionType != "UNKNOWN") {
|
|
|
|
|
- // We got a better value for connection type
|
|
|
|
|
- connectionType = iface.connectionType
|
|
|
|
|
- }
|
|
|
|
|
- // else this is a higher-metric interface, and the lower ones already told us what we need to know.
|
|
|
|
|
|
|
+type result struct {
|
|
|
|
|
+ networkID string
|
|
|
|
|
+ err error
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- lowestMetric = false
|
|
|
|
|
- }
|
|
|
|
|
|
|
+var workThread struct {
|
|
|
|
|
+ init sync.Once
|
|
|
|
|
+ reqs chan (chan<- result)
|
|
|
|
|
+ err error
|
|
|
|
|
|
|
|
- return networkID, connectionType, isVPN, nil
|
|
|
|
|
|
|
+ cachedResult string
|
|
|
|
|
+ cacheExpiry time.Time
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
|
|
// Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
|
|
|
-// In that string, "VPN" takes precendence over "WIRED", "WIFI", and "MOBILE"; in that
|
|
|
|
|
-// case connectionType can be used to determine the underlying network type. (It might be
|
|
|
|
|
-// desirable to put that value into feedback, say.)
|
|
|
|
|
-func Get() (compoundID, connectionType string, isVPN bool, err error) {
|
|
|
|
|
- networkID, connectionType, isVPN, err := getNetworkType()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return "", "", false, fmt.Errorf("getNetworkType: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- compoundID = connectionType + "-" + strings.Trim(networkID, "{}")
|
|
|
|
|
-
|
|
|
|
|
- return compoundID, connectionType, isVPN, nil
|
|
|
|
|
|
|
+// This function is safe to call concurrently from multiple goroutines.
|
|
|
|
|
+func Get() (string, error) {
|
|
|
|
|
+ // It is not clear if the COM NetworkListManager calls are threadsafe. We're using them
|
|
|
|
|
+ // read-only and they're probably fine, but we're not sure. Additionally, our networkID
|
|
|
|
|
+ // retrieval code is somewhat slow: 3.5ms. This function gets called by each connection
|
|
|
|
|
+ // attempt (in the horse race, etc.), so this extra time might add ~10% to a such an
|
|
|
|
|
+ // attempt. The value is very unlikely to change in a short amount of time, so it seems
|
|
|
|
|
+ // like a good optimization to cache the result. We'll restrict our work to single
|
|
|
|
|
+ // thread to achieve both goals.
|
|
|
|
|
+ workThread.init.Do(func() {
|
|
|
|
|
+ workThread.reqs = make(chan (chan<- result))
|
|
|
|
|
+
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ const resultCacheDuration = time.Second
|
|
|
|
|
+
|
|
|
|
|
+ runtime.LockOSThread()
|
|
|
|
|
+ defer runtime.UnlockOSThread()
|
|
|
|
|
+
|
|
|
|
|
+ if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil {
|
|
|
|
|
+ workThread.err = errors.Trace(err)
|
|
|
|
|
+ close(workThread.reqs)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ defer windows.CoUninitialize()
|
|
|
|
|
+
|
|
|
|
|
+ for resCh := range workThread.reqs {
|
|
|
|
|
+ if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) {
|
|
|
|
|
+ resCh <- result{workThread.cachedResult, nil}
|
|
|
|
|
+ } else {
|
|
|
|
|
+ networkID, err := getNetworkID()
|
|
|
|
|
+ resCh <- result{networkID, err}
|
|
|
|
|
+ workThread.cachedResult = networkID
|
|
|
|
|
+ workThread.cacheExpiry = time.Now().Add(resultCacheDuration)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ resCh := make(chan result)
|
|
|
|
|
+ workThread.reqs <- resCh
|
|
|
|
|
+ res := <-resCh
|
|
|
|
|
+
|
|
|
|
|
+ if res.err != nil {
|
|
|
|
|
+ return "", errors.Trace(res.err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return res.networkID, nil
|
|
|
}
|
|
}
|