Răsfoiți Sursa

Merge pull request #339 from rod-hynes/master

new tcpDial logic; automated test fixes; client_build_rev
Rod Hynes 9 ani în urmă
părinte
comite
04d47a77cc

+ 4 - 4
psiphon/TCPConn.go

@@ -131,9 +131,9 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 		var netConn net.Conn
 		var netConn net.Conn
 		var err error
 		var err error
 		if config.UpstreamProxyUrl != "" {
 		if config.UpstreamProxyUrl != "" {
-			netConn, err = proxiedTcpDial(addr, config, conn.dialResult)
+			netConn, err = proxiedTcpDial(addr, config)
 		} else {
 		} else {
-			netConn, err = tcpDial(addr, config, conn.dialResult)
+			netConn, err = tcpDial(addr, config)
 		}
 		}
 
 
 		// Mutex is necessary for referencing conn.isClosed and conn.Conn as
 		// Mutex is necessary for referencing conn.isClosed and conn.Conn as
@@ -172,9 +172,9 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 
 
 // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
 // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
 func proxiedTcpDial(
 func proxiedTcpDial(
-	addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
+	addr string, config *DialConfig) (net.Conn, error) {
 	dialer := func(network, addr string) (net.Conn, error) {
 	dialer := func(network, addr string) (net.Conn, error) {
-		return tcpDial(addr, config, dialResult)
+		return tcpDial(addr, config)
 	}
 	}
 
 
 	upstreamDialer := upstreamproxy.NewProxyDialFunc(
 	upstreamDialer := upstreamproxy.NewProxyDialFunc(

+ 129 - 69
psiphon/TCPConn_bind.go

@@ -24,12 +24,14 @@ package psiphon
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"math/rand"
 	"net"
 	"net"
 	"os"
 	"os"
 	"strconv"
 	"strconv"
 	"syscall"
 	"syscall"
-	"time"
 
 
+	"github.com/Psiphon-Inc/goarista/monotime"
+	"github.com/Psiphon-Inc/goselect"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 )
 
 
@@ -37,19 +39,9 @@ import (
 //
 //
 // To implement socket device binding, the lower-level syscall APIs are used.
 // To implement socket device binding, the lower-level syscall APIs are used.
 // The sequence of syscalls in this implementation are taken from:
 // The sequence of syscalls in this implementation are taken from:
-// https://code.google.com/p/go/issues/detail?id=6966
-func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
-
-	// Like interruption, this timeout doesn't stop this connection goroutine,
-	// it just unblocks the calling interruptibleTCPDial.
-	if config.ConnectTimeout != 0 {
-		time.AfterFunc(config.ConnectTimeout, func() {
-			select {
-			case dialResult <- errors.New("connect timeout"):
-			default:
-			}
-		})
-	}
+// https://github.com/golang/go/issues/6966
+// (originally: https://code.google.com/p/go/issues/detail?id=6966)
+func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 
 	// Get the remote IP and port, resolving a domain name if necessary
 	// Get the remote IP and port, resolving a domain name if necessary
 	host, strPort, err := net.SplitHostPort(addr)
 	host, strPort, err := net.SplitHostPort(addr)
@@ -68,70 +60,138 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 		return nil, common.ContextError(errors.New("no IP address"))
 		return nil, common.ContextError(errors.New("no IP address"))
 	}
 	}
 
 
-	// Select an IP at random from the list, so we're not always
-	// trying the same IP (when > 1) which may be blocked.
-	// TODO: retry all IPs until one connects? For now, this retry
-	// will happen on subsequent TCPDial calls, when a different IP
-	// is selected.
-	index, err := common.MakeSecureRandomInt(len(ipAddrs))
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
+	// Iterate over a pseudorandom permutation of the destination
+	// IPs and attempt connections.
+	//
+	// Only continue retrying as long as the initial ConnectTimeout
+	// has not expired. Unlike net.Dial, we do not fractionalize the
+	// timeout, as the ConnectTimeout is generally intended to apply
+	// to a single attempt. So these serial retries are most useful
+	// in cases of immediate failure, such as "no route to host"
+	// errors when a host resolves to both IPv4 and IPv6 but IPv6
+	// addresses are unreachable.
+	// Retries at higher levels cover other cases: e.g.,
+	// Controller.remoteServerListFetcher will retry its entire
+	// operation and tcpDial will try a new permutation; or similarly,
+	// Controller.establishCandidateGenerator will retry a candidate
+	// tunnel server dials.
 
 
-	var ipv4 [4]byte
-	var ipv6 [16]byte
-	var domain int
-	ipAddr := ipAddrs[index]
-
-	// Get address type (IPv4 or IPv6)
-	if ipAddr != nil && ipAddr.To4() != nil {
-		copy(ipv4[:], ipAddr.To4())
-		domain = syscall.AF_INET
-	} else if ipAddr != nil && ipAddr.To16() != nil {
-		copy(ipv6[:], ipAddr.To16())
-		domain = syscall.AF_INET6
-	} else {
-		return nil, common.ContextError(fmt.Errorf("Got invalid IP address: %s", ipAddr.String()))
-	}
+	permutedIndexes := rand.Perm(len(ipAddrs))
 
 
-	// Create a socket and bind to device, when configured to do so
-	socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
-	if err != nil {
-		return nil, common.ContextError(err)
+	lastErr := errors.New("unknown error")
+
+	var deadline monotime.Time
+	if config.ConnectTimeout != 0 {
+		deadline = monotime.Now().Add(config.ConnectTimeout)
 	}
 	}
 
 
-	if config.DeviceBinder != nil {
-		// WARNING: this potentially violates the direction to not call into
-		// external components after the Controller may have been stopped.
-		// TODO: rework DeviceBinder as an internal 'service' which can trap
-		// external calls when they should not be made?
-		err = config.DeviceBinder.BindToDevice(socketFd)
+	for iteration, index := range permutedIndexes {
+
+		if iteration > 0 && deadline != 0 && monotime.Now().After(deadline) {
+			// lastErr should be set by the previous iteration
+			break
+		}
+
+		// Get address type (IPv4 or IPv6)
+
+		var ipv4 [4]byte
+		var ipv6 [16]byte
+		var domain int
+		var sockAddr syscall.Sockaddr
+
+		ipAddr := ipAddrs[index]
+		if ipAddr != nil && ipAddr.To4() != nil {
+			copy(ipv4[:], ipAddr.To4())
+			domain = syscall.AF_INET
+		} else if ipAddr != nil && ipAddr.To16() != nil {
+			copy(ipv6[:], ipAddr.To16())
+			domain = syscall.AF_INET6
+		} else {
+			lastErr = common.ContextError(fmt.Errorf("Got invalid IP address: %s", ipAddr.String()))
+			continue
+		}
+		if domain == syscall.AF_INET {
+			sockAddr = &syscall.SockaddrInet4{Addr: ipv4, Port: port}
+		} else if domain == syscall.AF_INET6 {
+			sockAddr = &syscall.SockaddrInet6{Addr: ipv6, Port: port}
+		}
+
+		// Create a socket and bind to device, when configured to do so
+
+		socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		if config.DeviceBinder != nil {
+			// WARNING: this potentially violates the direction to not call into
+			// external components after the Controller may have been stopped.
+			// TODO: rework DeviceBinder as an internal 'service' which can trap
+			// external calls when they should not be made?
+			err = config.DeviceBinder.BindToDevice(socketFd)
+			if err != nil {
+				syscall.Close(socketFd)
+				lastErr = common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+				continue
+			}
+		}
+
+		// Connect socket to the server's IP address
+
+		err = syscall.SetNonblock(socketFd, true)
 		if err != nil {
 		if err != nil {
 			syscall.Close(socketFd)
 			syscall.Close(socketFd)
-			return nil, common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+			lastErr = common.ContextError(err)
+			continue
 		}
 		}
-	}
 
 
-	// Connect socket to the server's IP address
-	if domain == syscall.AF_INET {
-		sockAddr := syscall.SockaddrInet4{Addr: ipv4, Port: port}
-		err = syscall.Connect(socketFd, &sockAddr)
-	} else if domain == syscall.AF_INET6 {
-		sockAddr := syscall.SockaddrInet6{Addr: ipv6, Port: port}
-		err = syscall.Connect(socketFd, &sockAddr)
-	}
-	if err != nil {
-		syscall.Close(socketFd)
-		return nil, common.ContextError(err)
-	}
+		err = syscall.Connect(socketFd, sockAddr)
+		if err != nil {
+			if errno, ok := err.(syscall.Errno); !ok || errno != syscall.EINPROGRESS {
+				syscall.Close(socketFd)
+				lastErr = common.ContextError(err)
+				continue
+			}
+		}
 
 
-	// Convert the socket fd to a net.Conn
-	file := os.NewFile(uintptr(socketFd), "")
-	netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
-	file.Close()                       // file.Close() closes socketFd
-	if err != nil {
-		return nil, common.ContextError(err)
+		fdset := &goselect.FDSet{}
+		fdset.Set(uintptr(socketFd))
+
+		timeout := config.ConnectTimeout
+		if config.ConnectTimeout == 0 {
+			timeout = -1
+		}
+
+		err = goselect.Select(socketFd+1, nil, fdset, nil, timeout)
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+		if !fdset.IsSet(uintptr(socketFd)) {
+			lastErr = common.ContextError(errors.New("file descriptor not set"))
+			continue
+		}
+
+		err = syscall.SetNonblock(socketFd, false)
+		if err != nil {
+			syscall.Close(socketFd)
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		// Convert the socket fd to a net.Conn
+
+		file := os.NewFile(uintptr(socketFd), "")
+		netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
+		file.Close()                       // file.Close() closes socketFd
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		return netConn, nil
 	}
 	}
 
 
-	return netConn, nil
+	return nil, lastErr
 }
 }

+ 1 - 1
psiphon/TCPConn_nobind.go

@@ -29,7 +29,7 @@ import (
 )
 )
 
 
 // tcpDial is the platform-specific part of interruptibleTCPDial
 // tcpDial is the platform-specific part of interruptibleTCPDial
-func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
+func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 
 	if config.DeviceBinder != nil {
 	if config.DeviceBinder != nil {
 		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))
 		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))

+ 2 - 3
psiphon/controller_test.go

@@ -448,8 +448,6 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	json.Unmarshal(configJSON, &modifyConfig)
 	json.Unmarshal(configJSON, &modifyConfig)
 	modifyConfig["DataStoreDirectory"] = testDataDirName
 	modifyConfig["DataStoreDirectory"] = testDataDirName
 	modifyConfig["RemoteServerListDownloadFilename"] = filepath.Join(testDataDirName, "server_list_compressed")
 	modifyConfig["RemoteServerListDownloadFilename"] = filepath.Join(testDataDirName, "server_list_compressed")
-	modifyConfig["ObfuscatedServerListDownloadDirectory"] = testDataDirName
-	modifyConfig["ObfuscatedServerListRootURL"] = "http://127.0.0.1/osl" // will fail
 	modifyConfig["UpgradeDownloadFilename"] = filepath.Join(testDataDirName, "upgrade")
 	modifyConfig["UpgradeDownloadFilename"] = filepath.Join(testDataDirName, "upgrade")
 	configJSON, _ = json.Marshal(modifyConfig)
 	configJSON, _ = json.Marshal(modifyConfig)
 
 
@@ -970,7 +968,8 @@ func initDisruptor() {
 				defer localConn.Close()
 				defer localConn.Close()
 				remoteConn, err := net.Dial("tcp", localConn.Req.Target)
 				remoteConn, err := net.Dial("tcp", localConn.Req.Target)
 				if err != nil {
 				if err != nil {
-					fmt.Printf("disruptor proxy dial error: %s\n", err)
+					// TODO: log "err" without logging server IPs
+					fmt.Printf("disruptor proxy dial error\n")
 					return
 					return
 				}
 				}
 				defer remoteConn.Close()
 				defer remoteConn.Close()

+ 4 - 1
psiphon/remoteServerList_test.go

@@ -74,6 +74,8 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 			EnableSSHAPIRequests: true,
 			EnableSSHAPIRequests: true,
 			WebServerPort:        8001,
 			WebServerPort:        8001,
 			TunnelProtocolPorts:  map[string]int{"OSSH": 4001},
 			TunnelProtocolPorts:  map[string]int{"OSSH": 4001},
+			LogFilename:          "psiphond.log",
+			LogLevel:             "debug",
 		})
 		})
 	if err != nil {
 	if err != nil {
 		t.Fatalf("error generating server config: %s", err)
 		t.Fatalf("error generating server config: %s", err)
@@ -371,7 +373,8 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 			case "RemoteServerListResourceDownloadedBytes":
 			case "RemoteServerListResourceDownloadedBytes":
 				// TODO: check for resumed download for each URL
 				// TODO: check for resumed download for each URL
 				//url := payload["url"].(string)
 				//url := payload["url"].(string)
-				printNotice = true
+				//printNotice = true
+				printNotice = false
 			case "RemoteServerListResourceDownloaded":
 			case "RemoteServerListResourceDownloaded":
 				printNotice = true
 				printNotice = true
 			}
 			}

+ 1 - 0
psiphon/server/api.go

@@ -509,6 +509,7 @@ var baseRequestParams = []requestParamSpec{
 	requestParamSpec{"sponsor_id", isHexDigits, 0},
 	requestParamSpec{"sponsor_id", isHexDigits, 0},
 	requestParamSpec{"client_version", isIntString, 0},
 	requestParamSpec{"client_version", isIntString, 0},
 	requestParamSpec{"client_platform", isClientPlatform, 0},
 	requestParamSpec{"client_platform", isClientPlatform, 0},
+	requestParamSpec{"client_build_rev", isHexDigits, requestParamOptional},
 	requestParamSpec{"relay_protocol", isRelayProtocol, 0},
 	requestParamSpec{"relay_protocol", isRelayProtocol, 0},
 	requestParamSpec{"tunnel_whole_device", isBooleanFlag, requestParamOptional},
 	requestParamSpec{"tunnel_whole_device", isBooleanFlag, requestParamOptional},
 	requestParamSpec{"device_region", isRegionCode, requestParamOptional},
 	requestParamSpec{"device_region", isRegionCode, requestParamOptional},

+ 7 - 1
psiphon/server/config.go

@@ -339,6 +339,7 @@ func validateNetworkAddress(address string, requireIPaddress bool) error {
 // a generated server config.
 // a generated server config.
 type GenerateConfigParams struct {
 type GenerateConfigParams struct {
 	LogFilename          string
 	LogFilename          string
+	LogLevel             string
 	ServerIPAddress      string
 	ServerIPAddress      string
 	WebServerPort        int
 	WebServerPort        int
 	EnableSSHAPIRequests bool
 	EnableSSHAPIRequests bool
@@ -486,8 +487,13 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 	// Note: this config is intended for either testing or as an illustrative
 	// Note: this config is intended for either testing or as an illustrative
 	// example or template and is not intended for production deployment.
 	// example or template and is not intended for production deployment.
 
 
+	logLevel := params.LogLevel
+	if logLevel == "" {
+		logLevel = "info"
+	}
+
 	config := &Config{
 	config := &Config{
-		LogLevel:                       "info",
+		LogLevel:                       logLevel,
 		LogFilename:                    params.LogFilename,
 		LogFilename:                    params.LogFilename,
 		GeoIPDatabaseFilenames:         nil,
 		GeoIPDatabaseFilenames:         nil,
 		HostID:                         "example-host-id",
 		HostID:                         "example-host-id",

+ 2 - 1
psiphon/server/server_test.go

@@ -302,7 +302,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverConfig["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig["OSLConfigFilename"] = oslConfigFilename
 	serverConfig["OSLConfigFilename"] = oslConfigFilename
-	serverConfig["LogLevel"] = "error"
+	serverConfig["LogFilename"] = "psiphond.log"
+	serverConfig["LogLevel"] = "debug"
 
 
 	serverConfigJSON, _ = json.Marshal(serverConfig)
 	serverConfigJSON, _ = json.Marshal(serverConfig)
 
 

+ 9 - 2
psiphon/server/tunnelServer.go

@@ -1237,10 +1237,14 @@ func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
 const (
 const (
 	portForwardTypeTCP = iota
 	portForwardTypeTCP = iota
 	portForwardTypeUDP
 	portForwardTypeUDP
+	portForwardTypeTransparentDNS
 )
 )
 
 
 func (sshClient *sshClient) isPortForwardPermitted(
 func (sshClient *sshClient) isPortForwardPermitted(
-	portForwardType int, remoteIP net.IP, port int) bool {
+	portForwardType int,
+	isTransparentDNSForwarding bool,
+	remoteIP net.IP,
+	port int) bool {
 
 
 	sshClient.Lock()
 	sshClient.Lock()
 	defer sshClient.Unlock()
 	defer sshClient.Unlock()
@@ -1251,7 +1255,9 @@ func (sshClient *sshClient) isPortForwardPermitted(
 
 
 	// Disallow connection to loopback. This is a failsafe. The server
 	// Disallow connection to loopback. This is a failsafe. The server
 	// should be run on a host with correctly configured firewall rules.
 	// should be run on a host with correctly configured firewall rules.
-	if remoteIP.IsLoopback() {
+	// And exception is made in the case of tranparent DNS forwarding,
+	// where the remoteIP has been rewritten.
+	if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
 		return false
 		return false
 	}
 	}
 
 
@@ -1423,6 +1429,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	if !isWebServerPortForward &&
 	if !isWebServerPortForward &&
 		!sshClient.isPortForwardPermitted(
 		!sshClient.isPortForwardPermitted(
 			portForwardTypeTCP,
 			portForwardTypeTCP,
+			false,
 			lookupResult.IP,
 			lookupResult.IP,
 			portToConnect) {
 			portToConnect) {
 
 

+ 1 - 1
psiphon/server/udp.go

@@ -163,7 +163,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			}
 			}
 
 
 			if !mux.sshClient.isPortForwardPermitted(
 			if !mux.sshClient.isPortForwardPermitted(
-				portForwardTypeUDP, dialIP, int(message.remotePort)) {
+				portForwardTypeUDP, message.forwardDNS, dialIP, int(message.remotePort)) {
 				// The udpgw protocol has no error response, so
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				// we just discard the message and read another.
 				continue
 				continue

+ 1 - 0
psiphon/serverApi.go

@@ -816,6 +816,7 @@ func (serverContext *ServerContext) getBaseParams() requestJSONObject {
 	// TODO: client_tunnel_core_version?
 	// TODO: client_tunnel_core_version?
 	params["relay_protocol"] = tunnel.protocol
 	params["relay_protocol"] = tunnel.protocol
 	params["client_platform"] = tunnel.config.ClientPlatform
 	params["client_platform"] = tunnel.config.ClientPlatform
+	params["client_build_rev"] = common.GetBuildInfo().BuildRev
 	params["tunnel_whole_device"] = strconv.Itoa(tunnel.config.TunnelWholeDevice)
 	params["tunnel_whole_device"] = strconv.Itoa(tunnel.config.TunnelWholeDevice)
 
 
 	// The following parameters may be blank and must
 	// The following parameters may be blank and must