Przeglądaj źródła

fixes resulting from feedback

Adam Pritchard 1 rok temu
rodzic
commit
560d1da978

+ 2 - 0
.github/workflows/tests.yml

@@ -100,6 +100,7 @@ jobs:
           go test -v -timeout 30m -race ./psiphon
           go test -v -race ./ClientLibrary/clientlib
           go test -v -race ./Server/logging/analysis
+          go test -v -race ./psiphon/common/networkid
 
       # TODO: fix and re-enable test
       # sudo -E env "PATH=$PATH" go test -v -covermode=count -coverprofile=tun.coverprofile ./psiphon/common/tun
@@ -133,6 +134,7 @@ jobs:
           go test -v -timeout 30m -covermode=count -coverprofile=psiphon.coverprofile ./psiphon
           go test -v -covermode=count -coverprofile=clientlib.coverprofile ./ClientLibrary/clientlib
           go test -v -covermode=count -coverprofile=analysis.coverprofile ./Server/logging/analysis
+          go test -v -covermode=count -coverprofile=networkid.coverprofile ./psiphon/common/networkid
           $GOPATH/bin/gover
           $GOPATH/bin/goveralls -coverprofile=gover.coverprofile -service=github -repotoken "$COVERALLS_TOKEN"
 

+ 10 - 4
psiphon/common/networkid/networkid_windows.go

@@ -20,7 +20,6 @@
 package networkid
 
 import (
-	"fmt"
 	"net"
 	"net/netip"
 	"runtime"
@@ -48,6 +47,8 @@ func getDefaultLocalAddr() (net.IP, error) {
 	// This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
 	// The basic idea is that we initialize a UDP connection and see what local
 	// address the system decides to use.
+	// Note that no actual network request is made by these calls. They can be performed
+	// with no network connectivity at all.
 	// TODO: Use common test IP addresses in that function and this.
 
 	// We'll prefer IPv4 and check it first (both might be available)
@@ -153,14 +154,14 @@ func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg
 			uintptr(unsafe.Pointer(n)),
 			uintptr(unsafe.Pointer(&guid)))
 		if hr != 0 {
-			return "", "", 0, fmt.Errorf("GetNetworkId failed: %08x", hr)
+			return "", "", 0, errors.Tracef("GetNetworkId failed: %08x", hr)
 		}
 
 		networkID = guid.String()
 		return networkID, description, ifType, nil
 	}
 
-	return "", "", 0, fmt.Errorf("network connection not found for interface %d", index)
+	return "", "", 0, errors.Tracef("network connection not found for interface %d", index)
 }
 
 // Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
@@ -254,8 +255,13 @@ func Get() (string, error) {
 		workThread.reqs = make(chan (chan<- result))
 
 		go func() {
-			const resultCacheDuration = time.Second
+			const resultCacheDuration = 500 * time.Millisecond
 
+			// Go can switch the execution of a goroutine from one OS thread to another
+			// at (almost) any time. This may or may not be risky to do for our win32
+			// (and especially COM) calls, so we're going to explicitly lock this goroutine
+			// to a single OS thread. This shouldn't have any real impact on performance
+			// and will help protect against difficult-to-reproduce errors.
 			runtime.LockOSThread()
 			defer runtime.UnlockOSThread()
 

+ 191 - 0
vendor/tailscale.com/wgengine/winnet/winnet.go

@@ -0,0 +1,191 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build windows
+
+package winnet
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+	"github.com/go-ole/go-ole/oleutil"
+)
+
+const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}"
+
+var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}")
+var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}")
+
+type NetworkListManager struct {
+	d *ole.Dispatch
+}
+
+type INetworkConnection struct {
+	ole.IDispatch
+}
+
+type ConnectionList []*INetworkConnection
+
+type INetworkConnectionVtbl struct {
+	ole.IDispatchVtbl
+	GetNetwork                uintptr
+	Get_IsConnectedToInternet uintptr
+	Get_IsConnected           uintptr
+	GetConnectivity           uintptr
+	GetConnectionId           uintptr
+	GetAdapterId              uintptr
+	GetDomainType             uintptr
+}
+
+type INetwork struct {
+	ole.IDispatch
+}
+
+type INetworkVtbl struct {
+	ole.IDispatchVtbl
+	GetName                    uintptr
+	SetName                    uintptr
+	GetDescription             uintptr
+	SetDescription             uintptr
+	GetNetworkId               uintptr
+	GetDomainType              uintptr
+	GetNetworkConnections      uintptr
+	GetTimeCreatedAndConnected uintptr
+	Get_IsConnectedToInternet  uintptr
+	Get_IsConnected            uintptr
+	GetConnectivity            uintptr
+	GetCategory                uintptr
+	SetCategory                uintptr
+}
+
+func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) {
+	err := c.Create(CLSID_NetworkListManager)
+	if err != nil {
+		return nil, err
+	}
+	defer c.Release()
+
+	d, err := c.Dispatch()
+	if err != nil {
+		return nil, err
+	}
+
+	return &NetworkListManager{
+		d: d,
+	}, nil
+}
+
+func (m *NetworkListManager) Release() {
+	m.d.Release()
+}
+
+func (cl ConnectionList) Release() {
+	for _, v := range cl {
+		v.Release()
+	}
+}
+
+func asIID(u ole.UnknownLike, iid *ole.GUID) (*ole.IDispatch, error) {
+	if u == nil {
+		return nil, fmt.Errorf("asIID: nil UnknownLike")
+	}
+
+	d, err := u.QueryInterface(iid)
+	u.Release()
+	if err != nil {
+		return nil, err
+	}
+	return d, nil
+}
+
+func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) {
+	ncraw, err := m.d.Call("GetNetworkConnections")
+	if err != nil {
+		return nil, err
+	}
+
+	nli := ncraw.ToIDispatch()
+	if nli == nil {
+		return nil, fmt.Errorf("GetNetworkConnections: not IDispatch")
+	}
+
+	cl := ConnectionList{}
+
+	err = oleutil.ForEach(nli, func(v *ole.VARIANT) error {
+		nc, err := asIID(v.ToIUnknown(), IID_INetworkConnection)
+		if err != nil {
+			return err
+		}
+		nco := (*INetworkConnection)(unsafe.Pointer(nc))
+		cl = append(cl, nco)
+		return nil
+	})
+
+	if err != nil {
+		cl.Release()
+		return nil, err
+	}
+	return cl, nil
+}
+
+func (n *INetwork) GetName() (string, error) {
+	v, err := n.CallMethod("GetName")
+	if err != nil {
+		return "", err
+	}
+	return v.ToString(), err
+}
+
+func (n *INetwork) GetCategory() (int32, error) {
+	var result int32
+
+	r, _, _ := syscall.SyscallN(
+		n.VTable().GetCategory,
+		uintptr(unsafe.Pointer(n)),
+		uintptr(unsafe.Pointer(&result)),
+	)
+	if int32(r) < 0 {
+		return 0, ole.NewError(r)
+	}
+
+	return result, nil
+}
+
+func (n *INetwork) SetCategory(v int32) error {
+	r, _, _ := syscall.SyscallN(
+		n.VTable().SetCategory,
+		uintptr(unsafe.Pointer(n)),
+		uintptr(v),
+	)
+	if int32(r) < 0 {
+		return ole.NewError(r)
+	}
+
+	return nil
+}
+
+func (n *INetwork) VTable() *INetworkVtbl {
+	return (*INetworkVtbl)(unsafe.Pointer(n.RawVTable))
+}
+
+func (v *INetworkConnection) VTable() *INetworkConnectionVtbl {
+	return (*INetworkConnectionVtbl)(unsafe.Pointer(v.RawVTable))
+}
+
+func (v *INetworkConnection) GetNetwork() (*INetwork, error) {
+	var result *INetwork
+
+	r, _, _ := syscall.SyscallN(
+		v.VTable().GetNetwork,
+		uintptr(unsafe.Pointer(v)),
+		uintptr(unsafe.Pointer(&result)),
+	)
+	if int32(r) < 0 {
+		return nil, ole.NewError(r)
+	}
+
+	return result, nil
+}

+ 26 - 0
vendor/tailscale.com/wgengine/winnet/winnet_windows.go

@@ -0,0 +1,26 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package winnet
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+)
+
+func (v *INetworkConnection) GetAdapterId() (string, error) {
+	buf := ole.GUID{}
+	hr, _, _ := syscall.Syscall(
+		v.VTable().GetAdapterId,
+		2,
+		uintptr(unsafe.Pointer(v)),
+		uintptr(unsafe.Pointer(&buf)),
+		0)
+	if hr != 0 {
+		return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
+	}
+	return buf.String(), nil
+}