Просмотр исходного кода

Merge pull request #742 from geebee/feature/programmatic-network-config

replace RunNetworkConfigCommand calls with netlink/writing to proc
Rod Hynes 8 месяцев назад
Родитель
Сommit
5ef8b1b706
2 измененных файлов с 117 добавлено и 86 удалено
  1. 1 1
      go.mod
  2. 116 85
      psiphon/common/tun/tun_linux.go

+ 1 - 1
go.mod

@@ -83,6 +83,7 @@ require (
 	github.com/sirupsen/logrus v1.9.3
 	github.com/sirupsen/logrus v1.9.3
 	github.com/stretchr/testify v1.9.0
 	github.com/stretchr/testify v1.9.0
 	github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8
 	github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8
+	github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85
 	github.com/wader/filtertransport v0.0.0-20200316221534-bdd9e61eee78
 	github.com/wader/filtertransport v0.0.0-20200316221534-bdd9e61eee78
 	github.com/wlynxg/anet v0.0.5
 	github.com/wlynxg/anet v0.0.5
 	golang.org/x/crypto v0.35.0
 	golang.org/x/crypto v0.35.0
@@ -143,7 +144,6 @@ require (
 	github.com/shadowsocks/go-shadowsocks2 v0.1.5 // indirect
 	github.com/shadowsocks/go-shadowsocks2 v0.1.5 // indirect
 	github.com/shoenig/go-m1cpu v0.1.6 // indirect
 	github.com/shoenig/go-m1cpu v0.1.6 // indirect
 	github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect
 	github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect
-	github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect
 	github.com/tklauser/go-sysconf v0.3.12 // indirect
 	github.com/tklauser/go-sysconf v0.3.12 // indirect
 	github.com/tklauser/numcpus v0.6.1 // indirect
 	github.com/tklauser/numcpus v0.6.1 // indirect
 	github.com/vishvananda/netlink v1.2.1-beta.2 // indirect
 	github.com/vishvananda/netlink v1.2.1-beta.2 // indirect

+ 116 - 85
psiphon/common/tun/tun_linux.go

@@ -23,13 +23,14 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"os"
 	"os"
-	"strconv"
+	"path/filepath"
 	"strings"
 	"strings"
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/tailscale/netlink"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
@@ -160,66 +161,94 @@ func (device *Device) writeTunPacket(packet []byte) error {
 func resetNATTables(
 func resetNATTables(
 	config *ServerConfig,
 	config *ServerConfig,
 	IPAddress net.IP) error {
 	IPAddress net.IP) error {
-
-	// Uses the "conntrack" command, which is often not installed by default.
-
 	// conntrack --delete -src-nat --orig-src <address> will clear NAT tables of existing
 	// conntrack --delete -src-nat --orig-src <address> will clear NAT tables of existing
 	// connections, making it less likely that traffic for a previous client using the
 	// connections, making it less likely that traffic for a previous client using the
 	// specified address will be forwarded to a new client using this address. This is in
 	// specified address will be forwarded to a new client using this address. This is in
 	// the already unlikely event that there's still in-flight traffic when the address is
 	// the already unlikely event that there's still in-flight traffic when the address is
 	// recycled.
 	// recycled.
 
 
-	err := common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"conntrack",
-		"--delete",
-		"--src-nat",
-		"--orig-src",
-		IPAddress.String())
-	if err != nil {
+	// The netlink library does not expose the facilities for conclusively determining if
+	// src-nat has been applied to an individual flow, so replacing the previous call to
+	// the conntrack binary (see the comment above) with the code below is not a 1-to-1
+	// replacement. Since no other non-SNAT flows for these IPs that might exist need to
+	// be retained at the time resetNATTables is called, we're now skipping that check.
+
+	var family netlink.InetFamily
+	if IPAddress.To4() != nil {
+		family = unix.AF_INET
+	} else if IPAddress.To16() != nil {
+		family = unix.AF_INET6
+	} else {
+		return errors.TraceNew("invalid IP address family")
+	}
 
 
-		// conntrack exits with this error message when there are no flows
-		// to delete, which is not a failure condition.
-		if strings.Contains(err.Error(), "0 flow entries have been deleted") {
-			return nil
-		}
+	filter := &netlink.ConntrackFilter{}
+	_ = filter.AddIP(netlink.ConntrackOrigSrcIP, IPAddress)
 
 
+	_, err := netlink.ConntrackDeleteFilter(netlink.ConntrackTable, family, filter)
+	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
+func setSysctl(key, value string) error {
+	err := os.WriteFile(
+		filepath.Join("/proc/sys", strings.ReplaceAll(key, ".", "/")),
+		[]byte(value),
+		0o644,
+	)
+	if err != nil {
+		return errors.Tracef("failed to write sysctl %s=%s: %w", key, value, err)
+	}
+
+	return nil
+}
+
 func configureServerInterface(
 func configureServerInterface(
 	config *ServerConfig,
 	config *ServerConfig,
 	tunDeviceName string) error {
 	tunDeviceName string) error {
 
 
 	// Set tun device network addresses and MTU
 	// Set tun device network addresses and MTU
 
 
-	IPv4Address, IPv4Netmask, err := splitIPMask(serverIPv4AddressCIDR)
+	link, err := netlink.LinkByName(tunDeviceName)
 	if err != nil {
 	if err != nil {
-		return errors.Trace(err)
+		return errors.Tracef("failed to get interface %s: %w", tunDeviceName, err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"ifconfig",
-		tunDeviceName,
-		IPv4Address, "netmask", IPv4Netmask,
-		"mtu", strconv.Itoa(getMTU(config.MTU)),
-		"up")
+	_, ipv4Net, err := net.ParseCIDR(serverIPv4AddressCIDR)
 	if err != nil {
 	if err != nil {
-		return errors.Trace(err)
+		return errors.Tracef("failed to parse server IPv4 address: %s: %w", serverIPv4AddressCIDR, err)
+	}
+
+	ipv4Addr := &netlink.Addr{IPNet: ipv4Net}
+	err = netlink.AddrAdd(link, ipv4Addr)
+	if err != nil {
+		return errors.Tracef("failed to add IPv4 address to interface: %s: %w", ipv4Net.String(), err)
+	}
+
+	err = netlink.LinkSetMTU(link, getMTU(config.MTU))
+	if err != nil {
+		return errors.Tracef("failed to set interface MTU: %d: %w", config.MTU, err)
+	}
+
+	err = netlink.LinkSetUp(link)
+	if err != nil {
+		return errors.Tracef("failed to set interface up: %w", err)
+	}
+
+	_, ipv6Net, err := net.ParseCIDR(serverIPv6AddressCIDR)
+	if err != nil {
+		err = errors.Tracef("failed to parse server IPv6 address: %s: %w", serverIPv4AddressCIDR, err)
+	} else {
+		ipv6Addr := &netlink.Addr{IPNet: ipv6Net}
+		err = netlink.AddrAdd(link, ipv6Addr)
+		if err != nil {
+			err = errors.Tracef("failed to add IPv6 address to interface: %s: %w", ipv6Net.String(), err)
+		}
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"ifconfig",
-		tunDeviceName,
-		"add", serverIPv6AddressCIDR)
 	if err != nil {
 	if err != nil {
 		if config.AllowNoIPv6NetworkConfiguration {
 		if config.AllowNoIPv6NetworkConfiguration {
 			config.Logger.WithTraceFields(
 			config.Logger.WithTraceFields(
@@ -240,20 +269,12 @@ func configureServerInterface(
 
 
 	// TODO: need only set forwarding for specific interfaces?
 	// TODO: need only set forwarding for specific interfaces?
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"sysctl",
-		"net.ipv4.conf.all.forwarding=1")
+	err = setSysctl("net.ipv4.conf.all.forwarding", "1")
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"sysctl",
-		"net.ipv6.conf.all.forwarding=1")
+	err = setSysctl("net.ipv6.conf.all.forwarding", "1")
 	if err != nil {
 	if err != nil {
 		if config.AllowNoIPv6NetworkConfiguration {
 		if config.AllowNoIPv6NetworkConfiguration {
 			config.Logger.WithTraceFields(
 			config.Logger.WithTraceFields(
@@ -311,31 +332,40 @@ func configureClientInterface(
 	tunDeviceName string) error {
 	tunDeviceName string) error {
 
 
 	// Set tun device network addresses and MTU
 	// Set tun device network addresses and MTU
+	link, err := netlink.LinkByName(tunDeviceName)
+	if err != nil {
+		return errors.Trace(fmt.Errorf("failed to get interface %s: %w", tunDeviceName, err))
+	}
 
 
-	IPv4Address, IPv4Netmask, err := splitIPMask(config.IPv4AddressCIDR)
+	_, ipv4Net, err := net.ParseCIDR(config.IPv4AddressCIDR)
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"ifconfig",
-		tunDeviceName,
-		IPv4Address,
-		"netmask", IPv4Netmask,
-		"mtu", strconv.Itoa(getMTU(config.MTU)),
-		"up")
-	if err != nil {
+	ipv4Addr := &netlink.Addr{IPNet: ipv4Net}
+	if err := netlink.AddrAdd(link, ipv4Addr); err != nil {
+		return errors.Trace(err)
+	}
+
+	if err := netlink.LinkSetMTU(link, getMTU(config.MTU)); err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		config.Logger,
-		config.SudoNetworkConfigCommands,
-		"ifconfig",
-		tunDeviceName,
-		"add", config.IPv6AddressCIDR)
+	if err := netlink.LinkSetUp(link); err != nil {
+		return errors.Trace(err)
+	}
+
+	_, ipv6Net, err := net.ParseCIDR(config.IPv6AddressCIDR)
+	if err != nil {
+		err = errors.Trace(err)
+	} else {
+		ipv6Addr := &netlink.Addr{IPNet: ipv6Net}
+		err = netlink.AddrAdd(link, ipv6Addr)
+		if err != nil {
+			err = errors.Trace(err)
+		}
+	}
+
 	if err != nil {
 	if err != nil {
 		if config.AllowNoIPv6NetworkConfiguration {
 		if config.AllowNoIPv6NetworkConfiguration {
 			config.Logger.WithTraceFields(
 			config.Logger.WithTraceFields(
@@ -371,14 +401,27 @@ func configureClientInterface(
 		// Note: use "replace" instead of "add" as route from
 		// Note: use "replace" instead of "add" as route from
 		// previous run (e.g., tun_test case) may not yet be cleared.
 		// previous run (e.g., tun_test case) may not yet be cleared.
 
 
-		err = common.RunNetworkConfigCommand(
-			config.Logger,
-			config.SudoNetworkConfigCommands,
-			"ip",
-			"-6",
-			"route", "replace",
-			destination,
-			"dev", tunDeviceName)
+		link, err := netlink.LinkByName(tunDeviceName)
+		if err != nil {
+			err = errors.Trace(err)
+		} else {
+			_, destNet, parseErr := net.ParseCIDR(destination)
+			if parseErr != nil {
+				err = errors.Trace(err)
+			} else {
+				route := &netlink.Route{
+					LinkIndex: link.Attrs().Index,
+					Dst:       destNet,
+					Family:    netlink.FAMILY_V6,
+				}
+
+				err = netlink.RouteReplace(route)
+				if err != nil {
+					err = errors.Trace(err)
+				}
+			}
+		}
+
 		if err != nil {
 		if err != nil {
 			if config.AllowNoIPv6NetworkConfiguration {
 			if config.AllowNoIPv6NetworkConfiguration {
 				config.Logger.WithTraceFields(
 				config.Logger.WithTraceFields(
@@ -413,29 +456,17 @@ func fixBindToDevice(logger common.Logger, useSudo bool, tunDeviceName string) e
 	// > https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt and
 	// > https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt and
 	// > RFC3704)
 	// > RFC3704)
 
 
-	err := common.RunNetworkConfigCommand(
-		logger,
-		useSudo,
-		"sysctl",
-		"net.ipv4.conf.all.accept_local=1")
+	err := setSysctl("net.ipv4.conf.all.accept_local", "1")
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		logger,
-		useSudo,
-		"sysctl",
-		"net.ipv4.conf.all.rp_filter=0")
+	err = setSysctl("net.ipv4.conf.all.rp_filter", "0")
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
-	err = common.RunNetworkConfigCommand(
-		logger,
-		useSudo,
-		"sysctl",
-		fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", tunDeviceName))
+	err = setSysctl(fmt.Sprintf("net.ipv4.conf.%s.rp_filter", tunDeviceName), "0")
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}