Просмотр исходного кода

Fixes for tcpDial (bind version)
- Implement a retry when a hostname is
specified and it resolves to multiple
IP addresses. Fixes the case where an
IPv6 address is chosen and fails due
to "no route to host", where before no
attempt was made to use an IPv4 address.
- Implement real connect timeout, to avoid
infinitely dangling file descriptors and
go routines when a timeout is specified.

Rod Hynes 9 лет назад
Родитель
Сommit
ea5b1cbd54
3 измененных файлов с 134 добавлено и 74 удалено
  1. 4 4
      psiphon/TCPConn.go
  2. 129 69
      psiphon/TCPConn_bind.go
  3. 1 1
      psiphon/TCPConn_nobind.go

+ 4 - 4
psiphon/TCPConn.go

@@ -131,9 +131,9 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 		var netConn net.Conn
 		var netConn net.Conn
 		var err error
 		var err error
 		if config.UpstreamProxyUrl != "" {
 		if config.UpstreamProxyUrl != "" {
-			netConn, err = proxiedTcpDial(addr, config, conn.dialResult)
+			netConn, err = proxiedTcpDial(addr, config)
 		} else {
 		} else {
-			netConn, err = tcpDial(addr, config, conn.dialResult)
+			netConn, err = tcpDial(addr, config)
 		}
 		}
 
 
 		// Mutex is necessary for referencing conn.isClosed and conn.Conn as
 		// Mutex is necessary for referencing conn.isClosed and conn.Conn as
@@ -172,9 +172,9 @@ func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
 
 
 // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
 // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
 func proxiedTcpDial(
 func proxiedTcpDial(
-	addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
+	addr string, config *DialConfig) (net.Conn, error) {
 	dialer := func(network, addr string) (net.Conn, error) {
 	dialer := func(network, addr string) (net.Conn, error) {
-		return tcpDial(addr, config, dialResult)
+		return tcpDial(addr, config)
 	}
 	}
 
 
 	dialHeaders, _ := common.UserAgentIfUnset(config.UpstreamProxyCustomHeaders)
 	dialHeaders, _ := common.UserAgentIfUnset(config.UpstreamProxyCustomHeaders)

+ 129 - 69
psiphon/TCPConn_bind.go

@@ -24,12 +24,14 @@ package psiphon
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"math/rand"
 	"net"
 	"net"
 	"os"
 	"os"
 	"strconv"
 	"strconv"
 	"syscall"
 	"syscall"
-	"time"
 
 
+	"github.com/Psiphon-Inc/goarista/monotime"
+	"github.com/Psiphon-Inc/goselect"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 )
 
 
@@ -37,19 +39,9 @@ import (
 //
 //
 // To implement socket device binding, the lower-level syscall APIs are used.
 // To implement socket device binding, the lower-level syscall APIs are used.
 // The sequence of syscalls in this implementation are taken from:
 // The sequence of syscalls in this implementation are taken from:
-// https://code.google.com/p/go/issues/detail?id=6966
-func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
-
-	// Like interruption, this timeout doesn't stop this connection goroutine,
-	// it just unblocks the calling interruptibleTCPDial.
-	if config.ConnectTimeout != 0 {
-		time.AfterFunc(config.ConnectTimeout, func() {
-			select {
-			case dialResult <- errors.New("connect timeout"):
-			default:
-			}
-		})
-	}
+// https://github.com/golang/go/issues/6966
+// (originally: https://code.google.com/p/go/issues/detail?id=6966)
+func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 
 	// Get the remote IP and port, resolving a domain name if necessary
 	// Get the remote IP and port, resolving a domain name if necessary
 	host, strPort, err := net.SplitHostPort(addr)
 	host, strPort, err := net.SplitHostPort(addr)
@@ -68,70 +60,138 @@ func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn,
 		return nil, common.ContextError(errors.New("no IP address"))
 		return nil, common.ContextError(errors.New("no IP address"))
 	}
 	}
 
 
-	// Select an IP at random from the list, so we're not always
-	// trying the same IP (when > 1) which may be blocked.
-	// TODO: retry all IPs until one connects? For now, this retry
-	// will happen on subsequent TCPDial calls, when a different IP
-	// is selected.
-	index, err := common.MakeSecureRandomInt(len(ipAddrs))
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
+	// Iterate over a pseudorandom permutation of the destination
+	// IPs and attempt connections.
+	//
+	// Only continue retrying as long as the initial ConnectTimeout
+	// has not expired. Unlike net.Dial, we do not fractionalize the
+	// timeout, as the ConnectTimeout is generally intended to apply
+	// to a single attempt. So these serial retries are most useful
+	// in cases of immediate failure, such as "no route to host"
+	// errors when a host resolves to both IPv4 and IPv6 but IPv6
+	// addresses are unreachable.
+	// Retries at higher levels cover other cases: e.g.,
+	// Controller.remoteServerListFetcher will retry its entire
+	// operation and tcpDial will try a new permutation; or similarly,
+	// Controller.establishCandidateGenerator will retry a candidate
+	// tunnel server dials.
 
 
-	var ipv4 [4]byte
-	var ipv6 [16]byte
-	var domain int
-	ipAddr := ipAddrs[index]
-
-	// Get address type (IPv4 or IPv6)
-	if ipAddr != nil && ipAddr.To4() != nil {
-		copy(ipv4[:], ipAddr.To4())
-		domain = syscall.AF_INET
-	} else if ipAddr != nil && ipAddr.To16() != nil {
-		copy(ipv6[:], ipAddr.To16())
-		domain = syscall.AF_INET6
-	} else {
-		return nil, common.ContextError(fmt.Errorf("Got invalid IP address: %s", ipAddr.String()))
-	}
+	permutedIndexes := rand.Perm(len(ipAddrs))
 
 
-	// Create a socket and bind to device, when configured to do so
-	socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
-	if err != nil {
-		return nil, common.ContextError(err)
+	lastErr := errors.New("unknown error")
+
+	var deadline monotime.Time
+	if config.ConnectTimeout != 0 {
+		deadline = monotime.Now().Add(config.ConnectTimeout)
 	}
 	}
 
 
-	if config.DeviceBinder != nil {
-		// WARNING: this potentially violates the direction to not call into
-		// external components after the Controller may have been stopped.
-		// TODO: rework DeviceBinder as an internal 'service' which can trap
-		// external calls when they should not be made?
-		err = config.DeviceBinder.BindToDevice(socketFd)
+	for iteration, index := range permutedIndexes {
+
+		if iteration > 0 && deadline != 0 && monotime.Now().After(deadline) {
+			// lastErr should be set by the previous iteration
+			break
+		}
+
+		// Get address type (IPv4 or IPv6)
+
+		var ipv4 [4]byte
+		var ipv6 [16]byte
+		var domain int
+		var sockAddr syscall.Sockaddr
+
+		ipAddr := ipAddrs[index]
+		if ipAddr != nil && ipAddr.To4() != nil {
+			copy(ipv4[:], ipAddr.To4())
+			domain = syscall.AF_INET
+		} else if ipAddr != nil && ipAddr.To16() != nil {
+			copy(ipv6[:], ipAddr.To16())
+			domain = syscall.AF_INET6
+		} else {
+			lastErr = common.ContextError(fmt.Errorf("Got invalid IP address: %s", ipAddr.String()))
+			continue
+		}
+		if domain == syscall.AF_INET {
+			sockAddr = &syscall.SockaddrInet4{Addr: ipv4, Port: port}
+		} else if domain == syscall.AF_INET6 {
+			sockAddr = &syscall.SockaddrInet6{Addr: ipv6, Port: port}
+		}
+
+		// Create a socket and bind to device, when configured to do so
+
+		socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		if config.DeviceBinder != nil {
+			// WARNING: this potentially violates the direction to not call into
+			// external components after the Controller may have been stopped.
+			// TODO: rework DeviceBinder as an internal 'service' which can trap
+			// external calls when they should not be made?
+			err = config.DeviceBinder.BindToDevice(socketFd)
+			if err != nil {
+				syscall.Close(socketFd)
+				lastErr = common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+				continue
+			}
+		}
+
+		// Connect socket to the server's IP address
+
+		err = syscall.SetNonblock(socketFd, true)
 		if err != nil {
 		if err != nil {
 			syscall.Close(socketFd)
 			syscall.Close(socketFd)
-			return nil, common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
+			lastErr = common.ContextError(err)
+			continue
 		}
 		}
-	}
 
 
-	// Connect socket to the server's IP address
-	if domain == syscall.AF_INET {
-		sockAddr := syscall.SockaddrInet4{Addr: ipv4, Port: port}
-		err = syscall.Connect(socketFd, &sockAddr)
-	} else if domain == syscall.AF_INET6 {
-		sockAddr := syscall.SockaddrInet6{Addr: ipv6, Port: port}
-		err = syscall.Connect(socketFd, &sockAddr)
-	}
-	if err != nil {
-		syscall.Close(socketFd)
-		return nil, common.ContextError(err)
-	}
+		err = syscall.Connect(socketFd, sockAddr)
+		if err != nil {
+			if errno, ok := err.(syscall.Errno); !ok || errno != syscall.EINPROGRESS {
+				syscall.Close(socketFd)
+				lastErr = common.ContextError(err)
+				continue
+			}
+		}
 
 
-	// Convert the socket fd to a net.Conn
-	file := os.NewFile(uintptr(socketFd), "")
-	netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
-	file.Close()                       // file.Close() closes socketFd
-	if err != nil {
-		return nil, common.ContextError(err)
+		fdset := &goselect.FDSet{}
+		fdset.Set(uintptr(socketFd))
+
+		timeout := config.ConnectTimeout
+		if config.ConnectTimeout == 0 {
+			timeout = -1
+		}
+
+		err = goselect.Select(socketFd+1, nil, fdset, nil, timeout)
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+		if !fdset.IsSet(uintptr(socketFd)) {
+			lastErr = common.ContextError(errors.New("file descriptor not set"))
+			continue
+		}
+
+		err = syscall.SetNonblock(socketFd, false)
+		if err != nil {
+			syscall.Close(socketFd)
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		// Convert the socket fd to a net.Conn
+
+		file := os.NewFile(uintptr(socketFd), "")
+		netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
+		file.Close()                       // file.Close() closes socketFd
+		if err != nil {
+			lastErr = common.ContextError(err)
+			continue
+		}
+
+		return netConn, nil
 	}
 	}
 
 
-	return netConn, nil
+	return nil, lastErr
 }
 }

+ 1 - 1
psiphon/TCPConn_nobind.go

@@ -29,7 +29,7 @@ import (
 )
 )
 
 
 // tcpDial is the platform-specific part of interruptibleTCPDial
 // tcpDial is the platform-specific part of interruptibleTCPDial
-func tcpDial(addr string, config *DialConfig, dialResult chan error) (net.Conn, error) {
+func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 
 	if config.DeviceBinder != nil {
 	if config.DeviceBinder != nil {
 		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))
 		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))