Browse Source

Merge pull request #710 from adam-p/network-id

add networkid package and windows implementation
Rod Hynes 1 year ago
parent
commit
3061945719

+ 2 - 0
.github/workflows/tests.yml

@@ -100,6 +100,7 @@ jobs:
           go test -v -timeout 30m -race ./psiphon
           go test -v -race ./ClientLibrary/clientlib
           go test -v -race ./Server/logging/analysis
+          go test -v -race ./psiphon/common/networkid
 
       # TODO: fix and re-enable test
       # sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./psiphon/common/tun
@@ -133,6 +134,7 @@ jobs:
           go test -v -timeout 30m -covermode=count -coverprofile=psiphon.coverprofile ./psiphon
           go test -v -covermode=count -coverprofile=clientlib.coverprofile ./ClientLibrary/clientlib
           go test -v -covermode=count -coverprofile=analysis.coverprofile ./Server/logging/analysis
+          go test -v -covermode=count -coverprofile=networkid.coverprofile ./psiphon/common/networkid
           $GOPATH/bin/gover
           $GOPATH/bin/goveralls -coverprofile=gover.coverprofile -service=github -repotoken "$COVERALLS_TOKEN"
 

+ 2 - 2
go.mod

@@ -51,6 +51,7 @@ require (
 	github.com/florianl/go-nfqueue v1.1.1-0.20200829120558-a2f196e98ab0
 	github.com/flynn/noise v1.0.1-0.20220214164934-d803f5c4b0f4
 	github.com/fxamacker/cbor/v2 v2.5.0
+	github.com/go-ole/go-ole v1.3.0
 	github.com/gobwas/glob v0.2.4-0.20180402141543-f00a7392b439
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
 	github.com/google/gopacket v1.1.19
@@ -85,6 +86,7 @@ require (
 	golang.org/x/term v0.19.0
 	golang.org/x/time v0.5.0
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
+	golang.zx2c4.com/wireguard/windows v0.5.3
 	tailscale.com v1.58.2
 )
 
@@ -102,7 +104,6 @@ require (
 	github.com/dchest/siphash v1.2.3 // indirect
 	github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
 	github.com/gaukas/godicttls v0.0.4 // indirect
-	github.com/go-ole/go-ole v1.3.0 // indirect
 	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
@@ -152,7 +153,6 @@ require (
 	golang.org/x/mod v0.14.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 	golang.org/x/tools v0.15.0 // indirect
-	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 	google.golang.org/protobuf v1.31.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 32 - 0
psiphon/common/networkid/networkid_disabled.go

@@ -0,0 +1,32 @@
+//go:build !windows
+
+/*
+ * Copyright (c) 2024, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package networkid
+
+import "fmt"
+
+func Enabled() bool {
+	return false
+}
+
+func Get() (string, error) {
+	return "", fmt.Errorf("operation is not enabled")
+}

+ 300 - 0
psiphon/common/networkid/networkid_windows.go

@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2024, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package networkid
+
+import (
+	"net"
+	"net/netip"
+	"runtime"
+	"strings"
+	"sync"
+	"syscall"
+	"time"
+	"unsafe"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/go-ole/go-ole"
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+	"tailscale.com/wgengine/winnet"
+)
+
+func Enabled() bool {
+	return true
+}
+
+// Get address associated with the default interface.
+func getDefaultLocalAddr() (net.IP, error) {
+	// Note that this function has no Windows-specific code and could be used elsewhere.
+
+	// This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
+	// The basic idea is that we initialize a UDP connection and see what local
+	// address the system decides to use.
+	// Note that no actual network request is made by these calls. They can be performed
+	// with no network connectivity at all.
+	// TODO: Use common test IP addresses in that function and this.
+
+	// We'll prefer IPv4 and check it first (both might be available)
+	ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478"))
+	ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr)
+	if ipv4Err == nil {
+		ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP
+		ipv4UDPConn.Close()
+		return ip, nil
+	}
+
+	ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478"))
+	ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr)
+	if ipv6Err == nil {
+		ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP
+		ipv6UDPConn.Close()
+		return ip, nil
+	}
+
+	return nil, errors.Trace(ipv4Err)
+}
+
+// Given the IP of a local interface, get that interface info.
+func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) {
+	// Note that this function has no Windows-specific code and could be used elsewhere.
+
+	ifaces, err := net.Interfaces()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	for _, iface := range ifaces {
+		addrs, err := iface.Addrs()
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		for _, addr := range addrs {
+			addrIP, _, err := net.ParseCIDR(addr.String())
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+
+			if addrIP.Equal(ip) {
+				return &iface, nil
+			}
+		}
+	}
+
+	return nil, errors.TraceNew("not found")
+}
+
+// Given the interface index, get info about the interface and its network.
+func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) {
+	luid, err := winipcfg.LUIDFromIndex(uint32(index))
+	if err != nil {
+		return "", "", 0, errors.Trace(err)
+	}
+
+	ifrow, err := luid.Interface()
+	if err != nil {
+		return "", "", 0, errors.Trace(err)
+	}
+
+	description = ifrow.Description() + " " + ifrow.Alias()
+
+	ifType = ifrow.Type
+
+	var c ole.Connection
+	nlm, err := winnet.NewNetworkListManager(&c)
+	if err != nil {
+		return "", "", 0, errors.Trace(err)
+	}
+	defer nlm.Release()
+
+	netConns, err := nlm.GetNetworkConnections()
+	if err != nil {
+		return "", "", 0, errors.Trace(err)
+	}
+	defer netConns.Release()
+
+	for _, nc := range netConns {
+		ncAdapterID, err := nc.GetAdapterId()
+		if err != nil {
+			return "", "", 0, errors.Trace(err)
+		}
+		if ncAdapterID != ifrow.InterfaceGUID.String() {
+			continue
+		}
+
+		// Found the INetworkConnection for the target adapter.
+		// Get its network and network ID.
+
+		n, err := nc.GetNetwork()
+		if err != nil {
+			return "", "", 0, errors.Trace(err)
+		}
+		defer n.Release()
+
+		guid := ole.GUID{}
+		hr, _, _ := syscall.SyscallN(
+			n.VTable().GetNetworkId,
+			uintptr(unsafe.Pointer(n)),
+			uintptr(unsafe.Pointer(&guid)))
+		if hr != 0 {
+			return "", "", 0, errors.Tracef("GetNetworkId failed: %08x", hr)
+		}
+
+		networkID = guid.String()
+		return networkID, description, ifType, nil
+	}
+
+	return "", "", 0, errors.Tracef("network connection not found for interface %d", index)
+}
+
+// Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
+// interface type and description.
+// If the correct connection type can not be determined, "UNKNOWN" will be returned.
+func getConnectionType(ifType winipcfg.IfType, description string) string {
+	var connectionType string
+
+	switch ifType {
+	case winipcfg.IfTypeEthernetCSMACD:
+		connectionType = "WIRED"
+	case winipcfg.IfTypeIEEE80211:
+		connectionType = "WIFI"
+	case winipcfg.IfTypeWwanpp, winipcfg.IfTypeWwanpp2:
+		connectionType = "MOBILE"
+	case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
+		connectionType = "VPN"
+	default:
+		connectionType = "UNKNOWN"
+	}
+
+	if connectionType != "VPN" {
+		// The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
+		// back to checking for certain words in the description. This feels like a hack,
+		// but research suggests that it's the best we can do.
+
+		description = strings.ToLower(description)
+		if strings.Contains(description, "vpn") ||
+			strings.Contains(description, "tunnel") ||
+			strings.Contains(description, "virtual") ||
+			strings.Contains(description, "tap") ||
+			strings.Contains(description, "l2tp") ||
+			strings.Contains(description, "sstp") ||
+			strings.Contains(description, "pptp") ||
+			strings.Contains(description, "openvpn") {
+			connectionType = "VPN"
+		}
+	}
+
+	return connectionType
+}
+
+func getNetworkID() (string, error) {
+	localAddr, err := getDefaultLocalAddr()
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	iface, err := getInterfaceForLocalIP(localAddr)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	networkID, description, ifType, err := getInterfaceInfo(iface.Index)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	connectionType := getConnectionType(ifType, description)
+
+	compoundID := connectionType + "-" + strings.Trim(networkID, "{}")
+
+	return compoundID, nil
+}
+
+type result struct {
+	networkID string
+	err       error
+}
+
+var workThread struct {
+	init sync.Once
+	reqs chan (chan<- result)
+	err  error
+
+	cachedResult string
+	cacheExpiry  time.Time
+}
+
+// Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
+// This function is safe to call concurrently from multiple goroutines.
+// Note that if this function is called immediately after a network change (within ~2000ms)
+// a transitory Network ID may be returned that will change on the next call. The caller
+// may wish to delay responding to a new Network ID until the value is confirmed.
+func Get() (string, error) {
+	// It is not clear if the COM NetworkListManager calls are threadsafe. We're using them
+	// read-only and they're probably fine, but we're not sure. Additionally, our networkID
+	// retrieval code is somewhat slow: 3.5ms. This function gets called by each connection
+	// attempt (in the horse race, etc.), so this extra time might add ~10% to a such an
+	// attempt. The value is very unlikely to change in a short amount of time, so it seems
+	// like a good optimization to cache the result. We'll restrict our work to single
+	// thread to achieve both goals.
+	workThread.init.Do(func() {
+		workThread.reqs = make(chan (chan<- result))
+
+		go func() {
+			const resultCacheDuration = 500 * time.Millisecond
+
+			// Go can switch the execution of a goroutine from one OS thread to another
+			// at (almost) any time. This may or may not be risky to do for our win32
+			// (and especially COM) calls, so we're going to explicitly lock this goroutine
+			// to a single OS thread. This shouldn't have any real impact on performance
+			// and will help protect against difficult-to-reproduce errors.
+			runtime.LockOSThread()
+			defer runtime.UnlockOSThread()
+
+			if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil {
+				workThread.err = errors.Trace(err)
+				close(workThread.reqs)
+				return
+			}
+			defer windows.CoUninitialize()
+
+			for resCh := range workThread.reqs {
+				if workThread.cachedResult != "" && workThread.cacheExpiry.After(time.Now()) {
+					resCh <- result{workThread.cachedResult, nil}
+				} else {
+					networkID, err := getNetworkID()
+					resCh <- result{networkID, err}
+					workThread.cachedResult = networkID
+					workThread.cacheExpiry = time.Now().Add(resultCacheDuration)
+				}
+			}
+		}()
+	})
+
+	resCh := make(chan result)
+	workThread.reqs <- resCh
+	res := <-resCh
+
+	if res.err != nil {
+		return "", errors.Trace(res.err)
+	}
+
+	return res.networkID, nil
+}

+ 72 - 0
psiphon/common/networkid/networkid_windows_test.go

@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2024, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package networkid
+
+import (
+	"testing"
+	"time"
+)
+
+// prevent compiler optimization
+var networkID string
+var err error
+
+// This test doesn't show anything very useful, as it will mostly be getting cached results
+func BenchmarkGet(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		networkID, err = Get()
+		if err != nil {
+			b.Fatalf("error: %v", err)
+		}
+	}
+}
+
+func TestGet(t *testing.T) {
+	gotNetworkID, err := Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+
+	// Call again immediately to get a cached result
+	gotNetworkID, err = Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+
+	// Wait until the cached result expires, so we get another fresh value
+	time.Sleep(2 * time.Second)
+
+	gotNetworkID, err = Get()
+	if err != nil {
+		t.Errorf("error: %v", err)
+		return
+	}
+	if gotNetworkID == "" {
+		t.Error("got empty network ID")
+	}
+}

+ 1 - 0
vendor/modules.txt

@@ -688,6 +688,7 @@ tailscale.com/util/vizerror
 tailscale.com/util/winutil
 tailscale.com/version
 tailscale.com/version/distro
+tailscale.com/wgengine/winnet
 # gitlab.com/yawning/obfs4.git => github.com/jmwample/obfs4 v0.0.0-20230725223418-2d2e5b4a16ba
 # github.com/pion/dtls/v2 => ./replace/dtls
 # github.com/pion/ice/v2 => ./replace/ice

+ 191 - 0
vendor/tailscale.com/wgengine/winnet/winnet.go

@@ -0,0 +1,191 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package winnet
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+	"github.com/go-ole/go-ole/oleutil"
+)
+
+const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}"
+
+var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}")
+var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}")
+
+type NetworkListManager struct {
+	d *ole.Dispatch
+}
+
+type INetworkConnection struct {
+	ole.IDispatch
+}
+
+type ConnectionList []*INetworkConnection
+
+type INetworkConnectionVtbl struct {
+	ole.IDispatchVtbl
+	GetNetwork                uintptr
+	Get_IsConnectedToInternet uintptr
+	Get_IsConnected           uintptr
+	GetConnectivity           uintptr
+	GetConnectionId           uintptr
+	GetAdapterId              uintptr
+	GetDomainType             uintptr
+}
+
+type INetwork struct {
+	ole.IDispatch
+}
+
+type INetworkVtbl struct {
+	ole.IDispatchVtbl
+	GetName                    uintptr
+	SetName                    uintptr
+	GetDescription             uintptr
+	SetDescription             uintptr
+	GetNetworkId               uintptr
+	GetDomainType              uintptr
+	GetNetworkConnections      uintptr
+	GetTimeCreatedAndConnected uintptr
+	Get_IsConnectedToInternet  uintptr
+	Get_IsConnected            uintptr
+	GetConnectivity            uintptr
+	GetCategory                uintptr
+	SetCategory                uintptr
+}
+
+func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) {
+	err := c.Create(CLSID_NetworkListManager)
+	if err != nil {
+		return nil, err
+	}
+	defer c.Release()
+
+	d, err := c.Dispatch()
+	if err != nil {
+		return nil, err
+	}
+
+	return &NetworkListManager{
+		d: d,
+	}, nil
+}
+
+func (m *NetworkListManager) Release() {
+	m.d.Release()
+}
+
+func (cl ConnectionList) Release() {
+	for _, v := range cl {
+		v.Release()
+	}
+}
+
+func asIID(u ole.UnknownLike, iid *ole.GUID) (*ole.IDispatch, error) {
+	if u == nil {
+		return nil, fmt.Errorf("asIID: nil UnknownLike")
+	}
+
+	d, err := u.QueryInterface(iid)
+	u.Release()
+	if err != nil {
+		return nil, err
+	}
+	return d, nil
+}
+
+func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) {
+	ncraw, err := m.d.Call("GetNetworkConnections")
+	if err != nil {
+		return nil, err
+	}
+
+	nli := ncraw.ToIDispatch()
+	if nli == nil {
+		return nil, fmt.Errorf("GetNetworkConnections: not IDispatch")
+	}
+
+	cl := ConnectionList{}
+
+	err = oleutil.ForEach(nli, func(v *ole.VARIANT) error {
+		nc, err := asIID(v.ToIUnknown(), IID_INetworkConnection)
+		if err != nil {
+			return err
+		}
+		nco := (*INetworkConnection)(unsafe.Pointer(nc))
+		cl = append(cl, nco)
+		return nil
+	})
+
+	if err != nil {
+		cl.Release()
+		return nil, err
+	}
+	return cl, nil
+}
+
+func (n *INetwork) GetName() (string, error) {
+	v, err := n.CallMethod("GetName")
+	if err != nil {
+		return "", err
+	}
+	return v.ToString(), err
+}
+
+func (n *INetwork) GetCategory() (int32, error) {
+	var result int32
+
+	r, _, _ := syscall.SyscallN(
+		n.VTable().GetCategory,
+		uintptr(unsafe.Pointer(n)),
+		uintptr(unsafe.Pointer(&result)),
+	)
+	if int32(r) < 0 {
+		return 0, ole.NewError(r)
+	}
+
+	return result, nil
+}
+
+func (n *INetwork) SetCategory(v int32) error {
+	r, _, _ := syscall.SyscallN(
+		n.VTable().SetCategory,
+		uintptr(unsafe.Pointer(n)),
+		uintptr(v),
+	)
+	if int32(r) < 0 {
+		return ole.NewError(r)
+	}
+
+	return nil
+}
+
+func (n *INetwork) VTable() *INetworkVtbl {
+	return (*INetworkVtbl)(unsafe.Pointer(n.RawVTable))
+}
+
+func (v *INetworkConnection) VTable() *INetworkConnectionVtbl {
+	return (*INetworkConnectionVtbl)(unsafe.Pointer(v.RawVTable))
+}
+
+func (v *INetworkConnection) GetNetwork() (*INetwork, error) {
+	var result *INetwork
+
+	r, _, _ := syscall.SyscallN(
+		v.VTable().GetNetwork,
+		uintptr(unsafe.Pointer(v)),
+		uintptr(unsafe.Pointer(&result)),
+	)
+	if int32(r) < 0 {
+		return nil, ole.NewError(r)
+	}
+
+	return result, nil
+}

+ 26 - 0
vendor/tailscale.com/wgengine/winnet/winnet_windows.go

@@ -0,0 +1,26 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winnet
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+)
+
+func (v *INetworkConnection) GetAdapterId() (string, error) {
+	buf := ole.GUID{}
+	hr, _, _ := syscall.Syscall(
+		v.VTable().GetAdapterId,
+		2,
+		uintptr(unsafe.Pointer(v)),
+		uintptr(unsafe.Pointer(&buf)),
+		0)
+	if hr != 0 {
+		return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
+	}
+	return buf.String(), nil
+}