networkid_windows.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. /*
  2. * Copyright (c) 2024, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package networkid
  20. import (
  21. "net"
  22. "net/netip"
  23. "runtime"
  24. "strings"
  25. "sync"
  26. "syscall"
  27. "unsafe"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  29. "github.com/go-ole/go-ole"
  30. "golang.org/x/sys/windows"
  31. "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
  32. "tailscale.com/wgengine/winnet"
  33. )
  // Enabled always returns true in this implementation.
  //
  // NOTE(review): presumably callers use it to check whether network ID lookup
  // is available on the current platform; the callers are outside this file.
  func Enabled() bool {
  	return true
  }
  37. // Get address associated with the default interface.
  38. func getDefaultLocalAddr() (net.IP, error) {
  39. // Note that this function has no Windows-specific code and could be used elsewhere.
  40. // This approach is described in psiphon/common/inproxy/pionNetwork.Interfaces()
  41. // The basic idea is that we initialize a UDP connection and see what local
  42. // address the system decides to use.
  43. // Note that no actual network request is made by these calls. They can be performed
  44. // with no network connectivity at all.
  45. // TODO: Use common test IP addresses in that function and this.
  46. // We'll prefer IPv4 and check it first (both might be available)
  47. ipv4UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("93.184.216.34:3478"))
  48. ipv4UDPConn, ipv4Err := net.DialUDP("udp4", nil, ipv4UDPAddr)
  49. if ipv4Err == nil {
  50. ip := ipv4UDPConn.LocalAddr().(*net.UDPAddr).IP
  51. ipv4UDPConn.Close()
  52. return ip, nil
  53. }
  54. ipv6UDPAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("[2606:2800:220:1:248:1893:25c8:1946]:3478"))
  55. ipv6UDPConn, ipv6Err := net.DialUDP("udp6", nil, ipv6UDPAddr)
  56. if ipv6Err == nil {
  57. ip := ipv6UDPConn.LocalAddr().(*net.UDPAddr).IP
  58. ipv6UDPConn.Close()
  59. return ip, nil
  60. }
  61. return nil, errors.Trace(ipv4Err)
  62. }
  63. // Given the IP of a local interface, get that interface info.
  64. func getInterfaceForLocalIP(ip net.IP) (*net.Interface, error) {
  65. // Note that this function has no Windows-specific code and could be used elsewhere.
  66. ifaces, err := net.Interfaces()
  67. if err != nil {
  68. return nil, errors.Trace(err)
  69. }
  70. for _, iface := range ifaces {
  71. addrs, err := iface.Addrs()
  72. if err != nil {
  73. return nil, errors.Trace(err)
  74. }
  75. for _, addr := range addrs {
  76. addrIP, _, err := net.ParseCIDR(addr.String())
  77. if err != nil {
  78. return nil, errors.Trace(err)
  79. }
  80. if addrIP.Equal(ip) {
  81. return &iface, nil
  82. }
  83. }
  84. }
  85. return nil, errors.TraceNew("not found")
  86. }
  87. // Given the interface index, get info about the interface and its network.
  88. func getInterfaceInfo(index int) (networkID, description string, ifType winipcfg.IfType, err error) {
  89. luid, err := winipcfg.LUIDFromIndex(uint32(index))
  90. if err != nil {
  91. return "", "", 0, errors.Trace(err)
  92. }
  93. ifrow, err := luid.Interface()
  94. if err != nil {
  95. return "", "", 0, errors.Trace(err)
  96. }
  97. description = ifrow.Description() + " " + ifrow.Alias()
  98. ifType = ifrow.Type
  99. var c ole.Connection
  100. nlm, err := winnet.NewNetworkListManager(&c)
  101. if err != nil {
  102. return "", "", 0, errors.Trace(err)
  103. }
  104. defer nlm.Release()
  105. netConns, err := nlm.GetNetworkConnections()
  106. if err != nil {
  107. return "", "", 0, errors.Trace(err)
  108. }
  109. defer netConns.Release()
  110. for _, nc := range netConns {
  111. ncAdapterID, err := nc.GetAdapterId()
  112. if err != nil {
  113. return "", "", 0, errors.Trace(err)
  114. }
  115. if ncAdapterID != ifrow.InterfaceGUID.String() {
  116. continue
  117. }
  118. // Found the INetworkConnection for the target adapter.
  119. // Get its network and network ID.
  120. n, err := nc.GetNetwork()
  121. if err != nil {
  122. return "", "", 0, errors.Trace(err)
  123. }
  124. defer n.Release()
  125. guid := ole.GUID{}
  126. hr, _, _ := syscall.SyscallN(
  127. n.VTable().GetNetworkId,
  128. uintptr(unsafe.Pointer(n)),
  129. uintptr(unsafe.Pointer(&guid)))
  130. if hr != 0 {
  131. return "", "", 0, errors.Tracef("GetNetworkId failed: %08x", hr)
  132. }
  133. networkID = guid.String()
  134. return networkID, description, ifType, nil
  135. }
  136. return "", "", 0, errors.Tracef("network connection not found for interface %d", index)
  137. }
  138. // Get the connection type ("WIRED", "WIFI", "MOBILE", "VPN") of the network with the given
  139. // interface type and description.
  140. // If the correct connection type can not be determined, "UNKNOWN" will be returned.
  141. func getConnectionType(ifType winipcfg.IfType, description string) string {
  142. var connectionType string
  143. switch ifType {
  144. case winipcfg.IfTypeEthernetCSMACD, winipcfg.IfTypeEthernet3Mbit, winipcfg.IfTypeFastether, winipcfg.IfTypeFastetherFX, winipcfg.IfTypeGigabitethernet, winipcfg.IfTypeIEEE80212, winipcfg.IfTypeDigitalpowerline:
  145. connectionType = "WIRED"
  146. case winipcfg.IfTypeIEEE80211:
  147. connectionType = "WIFI"
  148. case winipcfg.IfTypeWwanpp, winipcfg.IfTypeWwanpp2:
  149. connectionType = "MOBILE"
  150. case winipcfg.IfTypePPP, winipcfg.IfTypePropVirtual, winipcfg.IfTypeTunnel:
  151. connectionType = "VPN"
  152. default:
  153. connectionType = "UNKNOWN"
  154. }
  155. if connectionType != "VPN" {
  156. // The ifType doesn't indicate a VPN, but that's not well-defined, so we'll fall
  157. // back to checking for certain words in the description. This feels like a hack,
  158. // but research suggests that it's the best we can do.
  159. description = strings.ToLower(description)
  160. if strings.Contains(description, "vpn") ||
  161. strings.Contains(description, "tunnel") ||
  162. strings.Contains(description, "virtual") ||
  163. strings.Contains(description, "tap") ||
  164. strings.Contains(description, "l2tp") ||
  165. strings.Contains(description, "sstp") ||
  166. strings.Contains(description, "pptp") ||
  167. strings.Contains(description, "openvpn") {
  168. connectionType = "VPN"
  169. }
  170. }
  171. return connectionType
  172. }
  173. func getNetworkID() (string, error) {
  174. localAddr, err := getDefaultLocalAddr()
  175. if err != nil {
  176. return "", errors.Trace(err)
  177. }
  178. iface, err := getInterfaceForLocalIP(localAddr)
  179. if err != nil {
  180. return "", errors.Trace(err)
  181. }
  182. networkID, description, ifType, err := getInterfaceInfo(iface.Index)
  183. if err != nil {
  184. return "", errors.Trace(err)
  185. }
  186. connectionType := getConnectionType(ifType, description)
  187. compoundID := connectionType + "-" + strings.Trim(networkID, "{}")
  188. return compoundID, nil
  189. }
  // result carries the outcome of one network ID lookup from the worker
  // goroutine back to the caller of Get.
  type result struct {
  	// networkID is the compound network ID produced by the lookup.
  	networkID string
  	// err is the error produced by the lookup, or nil.
  	err error
  }
  194. var workThread struct {
  195. init sync.Once
  196. reqs chan (chan<- result)
  197. err error
  198. }
  199. // Get returns the compound network ID; see [psiphon.NetworkIDGetter] for details.
  200. // This function is safe to call concurrently from multiple goroutines.
  201. // Note that if this function is called immediately after a network change (within ~2000ms)
  202. // a transitory Network ID may be returned that will change on the next call. The caller
  203. // may wish to delay responding to a new Network ID until the value is confirmed.
  204. func Get() (string, error) {
  205. // It is not clear if the COM NetworkListManager calls are threadsafe.
  206. // We're using them read-only and they're probably fine, but we're not
  207. // sure. We'll restrict our work to single thread.
  208. workThread.init.Do(func() {
  209. workThread.reqs = make(chan (chan<- result))
  210. go func() {
  211. // Go can switch the execution of a goroutine from one OS thread to another
  212. // at (almost) any time. This may or may not be risky to do for our win32
  213. // (and especially COM) calls, so we're going to explicitly lock this goroutine
  214. // to a single OS thread. This shouldn't have any real impact on performance
  215. // and will help protect against difficult-to-reproduce errors.
  216. runtime.LockOSThread()
  217. defer runtime.UnlockOSThread()
  218. if err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED); err != nil {
  219. workThread.err = errors.Trace(err)
  220. close(workThread.reqs)
  221. return
  222. }
  223. defer windows.CoUninitialize()
  224. for resCh := range workThread.reqs {
  225. networkID, err := getNetworkID()
  226. resCh <- result{networkID, err}
  227. }
  228. }()
  229. })
  230. resCh := make(chan result)
  231. workThread.reqs <- resCh
  232. res := <-resCh
  233. if res.err != nil {
  234. return "", errors.Trace(res.err)
  235. }
  236. return res.networkID, nil
  237. }