Browse Source

Merge pull request #429 from rod-hynes/master

Replace pending conns with Context
Rod Hynes 8 years ago
parent
commit
d0f09bb2cf

+ 14 - 8
ConsoleClient/main.go

@@ -21,6 +21,7 @@ package main
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -271,26 +272,31 @@ func main() {
 		os.Exit(1)
 	}
 
-	controllerStopSignal := make(chan struct{}, 1)
-	shutdownBroadcast := make(chan struct{})
+	controllerCtx, stopController := context.WithCancel(context.Background())
+	defer stopController()
+
 	controllerWaitGroup := new(sync.WaitGroup)
 	controllerWaitGroup.Add(1)
 	go func() {
 		defer controllerWaitGroup.Done()
-		controller.Run(shutdownBroadcast)
-		controllerStopSignal <- *new(struct{})
-	}()
+		controller.Run(controllerCtx)
 
-	// Wait for an OS signal or a Run stop signal, then stop Psiphon and exit
+		// Signal the <-controllerCtx.Done() case below. If the <-systemStopSignal
+		// case already called stopController, this is a noop.
+		stopController()
+	}()
 
 	systemStopSignal := make(chan os.Signal, 1)
 	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)
+
+	// Wait for an OS signal or a Run stop signal, then stop Psiphon and exit
+
 	select {
 	case <-systemStopSignal:
 		psiphon.NoticeInfo("shutdown by system")
-		close(shutdownBroadcast)
+		stopController()
 		controllerWaitGroup.Wait()
-	case <-controllerStopSignal:
+	case <-controllerCtx.Done():
 		psiphon.NoticeInfo("shutdown by controller")
 	}
 }

+ 9 - 5
MobileLibrary/psi/psi.go

@@ -25,6 +25,7 @@ package psi
 // Start/Stop interface on top of a single Controller instance.
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"sync"
@@ -64,7 +65,8 @@ func NoticeUserLog(message string) {
 
 var controllerMutex sync.Mutex
 var controller *psiphon.Controller
-var shutdownBroadcast chan struct{}
+var controllerCtx context.Context
+var stopController context.CancelFunc
 var controllerWaitGroup *sync.WaitGroup
 
 func Start(
@@ -128,12 +130,13 @@ func Start(
 		return fmt.Errorf("error initializing controller: %s", err)
 	}
 
-	shutdownBroadcast = make(chan struct{})
+	controllerCtx, stopController = context.WithCancel(context.Background())
+
 	controllerWaitGroup = new(sync.WaitGroup)
 	controllerWaitGroup.Add(1)
 	go func() {
 		defer controllerWaitGroup.Done()
-		controller.Run(shutdownBroadcast)
+		controller.Run(controllerCtx)
 	}()
 
 	return nil
@@ -145,10 +148,11 @@ func Stop() {
 	defer controllerMutex.Unlock()
 
 	if controller != nil {
-		close(shutdownBroadcast)
+		stopController()
 		controllerWaitGroup.Wait()
 		controller = nil
-		shutdownBroadcast = nil
+		controllerCtx = nil
+		stopController = nil
 		controllerWaitGroup = nil
 	}
 }

+ 59 - 25
psiphon/LookupIP.go

@@ -22,12 +22,12 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net"
 	"os"
 	"syscall"
-	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
@@ -37,38 +37,53 @@ import (
 // When BindToDevice is required, LookupIP explicitly creates a UDP
 // socket, binds it to the device, and makes an explicit DNS request
 // to the specified DNS resolver.
-func LookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
+func LookupIP(ctx context.Context, host string, config *DialConfig) ([]net.IP, error) {
 
-	// When the input host is an IP address, echo it back
-	ipAddr := net.ParseIP(host)
-	if ipAddr != nil {
-		return []net.IP{ipAddr}, nil
+	ip := net.ParseIP(host)
+	if ip != nil {
+		return []net.IP{ip}, nil
 	}
 
 	if config.DeviceBinder != nil {
-		addrs, err = bindLookupIP(host, config.DnsServerGetter.GetPrimaryDnsServer(), config)
+
+		dnsServer := config.DnsServerGetter.GetPrimaryDnsServer()
+
+		ips, err := bindLookupIP(ctx, host, dnsServer, config)
 		if err == nil {
-			if len(addrs) == 0 {
+			if len(ips) == 0 {
 				err = errors.New("empty address list")
 			} else {
-				return addrs, err
+				return ips, err
 			}
 		}
-		NoticeAlert("retry resolve host %s: %s", host, err)
-		dnsServer := config.DnsServerGetter.GetSecondaryDnsServer()
+
+		dnsServer = config.DnsServerGetter.GetSecondaryDnsServer()
 		if dnsServer == "" {
-			return addrs, err
+			return ips, err
 		}
-		return bindLookupIP(host, dnsServer, config)
+
+		NoticeAlert("retry resolve host %s: %s", host, err)
+
+		return bindLookupIP(ctx, host, dnsServer, config)
+	}
+
+	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+	if err != nil {
+		return nil, common.ContextError(err)
 	}
-	return net.LookupIP(host)
+
+	ips := make([]net.IP, len(addrs))
+	for i, addr := range addrs {
+		ips[i] = addr.IP
+	}
+
+	return ips, nil
 }
 
 // bindLookupIP implements the BindToDevice LookupIP case.
 // To implement socket device binding, the lower-level syscall APIs are used.
-// The sequence of syscalls in this implementation are taken from:
-// https://code.google.com/p/go/issues/detail?id=6966
-func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, err error) {
+func bindLookupIP(
+	ctx context.Context, host, dnsServer string, config *DialConfig) ([]net.IP, error) {
 
 	// config.DnsServerGetter.GetDnsServers() must return IP addresses
 	ipAddr := net.ParseIP(dnsServer)
@@ -130,6 +145,9 @@ func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, e
 	}
 
 	// Convert the syscall socket to a net.Conn, for use in the dns package
+	// This code block is from:
+	// https://github.com/golang/go/issues/6966
+
 	file := os.NewFile(uintptr(socketFd), "")
 	netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
 	file.Close()                       // file.Close() closes socketFd
@@ -137,17 +155,33 @@ func bindLookupIP(host, dnsServer string, config *DialConfig) (addrs []net.IP, e
 		return nil, common.ContextError(err)
 	}
 
-	// Set DNS query timeouts, using the ConnectTimeout from the overall Dial
-	if config.ConnectTimeout != 0 {
-		netConn.SetReadDeadline(time.Now().Add(config.ConnectTimeout))
-		netConn.SetWriteDeadline(time.Now().Add(config.ConnectTimeout))
+	type resolveIPResult struct {
+		ips []net.IP
+		err error
 	}
 
-	addrs, _, err = ResolveIP(host, netConn)
-	netConn.Close()
-	if err != nil {
+	resultChannel := make(chan resolveIPResult)
+
+	go func() {
+		ips, _, err := ResolveIP(host, netConn)
+		netConn.Close()
+		resultChannel <- resolveIPResult{ips: ips, err: err}
+	}()
+
+	var result resolveIPResult
+
+	select {
+	case result = <-resultChannel:
+	case <-ctx.Done():
+		result.err = ctx.Err()
+		// Interrupt the goroutine
+		netConn.Close()
+		<-resultChannel
+	}
+
+	if result.err != nil {
 		return nil, common.ContextError(err)
 	}
 
-	return addrs, nil
+	return result.ips, nil
 }

+ 15 - 2
psiphon/LookupIP_nobind.go

@@ -22,6 +22,7 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
 	"net"
 
@@ -30,9 +31,21 @@ import (
 
 // LookupIP resolves a hostname. When BindToDevice is not required, it
 // simply uses net.LookupIP.
-func LookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
+func LookupIP(ctx context.Context, host string, config *DialConfig) ([]net.IP, error) {
+
 	if config.DeviceBinder != nil {
 		return nil, common.ContextError(errors.New("LookupIP with DeviceBinder not supported on this platform"))
 	}
-	return net.LookupIP(host)
+
+	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	ips := make([]net.IP, len(addrs))
+	for i, addr := range addrs {
+		ips[i] = addr.IP
+	}
+
+	return ips, nil
 }

+ 94 - 125
psiphon/TCPConn.go

@@ -20,137 +20,92 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
+	"fmt"
 	"net"
-	"sync"
+	"sync/atomic"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/upstreamproxy"
 )
 
-// TCPConn is a customized TCP connection that:
-// - can be interrupted while dialing;
-// - implements a connect timeout;
-// - uses an upstream proxy when specified, and includes
-//   upstream proxy dialing in the connect timeout;
-// - can be bound to a specific system device (for Android VpnService
-//   routing compatibility, for example);
+// TCPConn is a customized TCP connection that supports the Closer interface
+// and which may be created using options in DialConfig, including
+// UpstreamProxyUrl, DeviceBinder, IPv6Synthesizer, and ResolvedIPCallback.
+// DeviceBinder is implemented using SO_BINDTODEVICE/IP_BOUND_IF, which
+// requires syscall-level socket code.
 type TCPConn struct {
 	net.Conn
-	mutex      sync.Mutex
-	isClosed   bool
-	dialResult chan error
+	isClosed int32
 }
 
 // NewTCPDialer creates a TCPDialer.
+//
+// Note: do not set an UpstreamProxyUrl in the config when using NewTCPDialer
+// as a custom dialer for NewProxyAuthTransport (or http.Transport with a
+// ProxyUrl), as that would result in double proxy chaining.
 func NewTCPDialer(config *DialConfig) Dialer {
-	return makeTCPDialer(config)
-}
-
-// DialTCP creates a new, connected TCPConn.
-func DialTCP(addr string, config *DialConfig) (conn net.Conn, err error) {
-	return makeTCPDialer(config)("tcp", addr)
-}
-
-// makeTCPDialer creates a custom dialer which creates TCPConn.
-func makeTCPDialer(config *DialConfig) func(network, addr string) (net.Conn, error) {
-	return func(network, addr string) (net.Conn, error) {
+	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 		if network != "tcp" {
-			return nil, errors.New("unsupported network type in TCPConn dialer")
-		}
-		conn, err := interruptibleTCPDial(addr, config)
-		if err != nil {
-			return nil, common.ContextError(err)
-		}
-		// Note: when an upstream proxy is used, we don't know what IP address
-		// was resolved, by the proxy, for that destination.
-		if config.ResolvedIPCallback != nil && config.UpstreamProxyUrl == "" {
-			ipAddress := common.IPAddressFromAddr(conn.RemoteAddr())
-			if ipAddress != "" {
-				config.ResolvedIPCallback(ipAddress)
-			}
+			return nil, common.ContextError(fmt.Errorf("%s unsupported", network))
 		}
-		return conn, nil
+		return DialTCP(ctx, addr, config)
 	}
 }
 
-// interruptibleTCPDial establishes a TCP network connection. A conn is added
-// to config.PendingConns before blocking on network I/O, which enables interruption.
-// The caller is responsible for removing an established conn from PendingConns.
-// An upstream proxy is used when specified.
-//
-// Note: do not to set a UpstreamProxyUrl in the config when using
-// NewTCPDialer as a custom dialer for NewProxyAuthTransport (or http.Transport
-// with a ProxyUrl), as that would result in double proxy chaining.
-//
-// Note: interruption does not actually cancel a connection in progress; it
-// stops waiting for the goroutine blocking on connect()/Dial.
-func interruptibleTCPDial(addr string, config *DialConfig) (*TCPConn, error) {
+// DialTCP creates a new, connected TCPConn.
+func DialTCP(
+	ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
 
-	// Buffers the first result; senders should discard results when
-	// sending would block, as that means the first result is already set.
-	conn := &TCPConn{dialResult: make(chan error, 1)}
+	var conn net.Conn
+	var err error
 
-	// Enable interruption
-	if config.PendingConns != nil && !config.PendingConns.Add(conn) {
-		return nil, common.ContextError(errors.New("pending connections already closed"))
+	if config.UpstreamProxyUrl != "" {
+		conn, err = proxiedTcpDial(ctx, addr, config)
+	} else {
+		conn, err = tcpDial(ctx, addr, config)
 	}
 
-	// Call the blocking Connect() in a goroutine. ConnectTimeout is handled
-	// in the platform-specific tcpDial helper function.
-	// Note: since this goroutine may be left running after an interrupt, don't
-	// call Notice() or perform other actions unexpected after a Controller stops.
-	// The lifetime of the goroutine may depend on the host OS TCP connect timeout
-	// when tcpDial, among other things, when makes a blocking syscall.Connect()
-	// call.
-	go func() {
-		var netConn net.Conn
-		var err error
-		if config.UpstreamProxyUrl != "" {
-			netConn, err = proxiedTcpDial(addr, config)
-		} else {
-			netConn, err = tcpDial(addr, config)
-		}
-
-		// Mutex is necessary for referencing conn.isClosed and conn.Conn as
-		// TCPConn.Close may be called while this goroutine is running.
-		conn.mutex.Lock()
-
-		// If already interrupted, cleanup the net.Conn resource and discard.
-		if conn.isClosed && netConn != nil {
-			netConn.Close()
-			conn.mutex.Unlock()
-			return
-		}
-
-		conn.Conn = netConn
-		conn.mutex.Unlock()
-
-		select {
-		case conn.dialResult <- err:
-		default:
-		}
-	}()
-
-	// Wait until Dial completes (or times out) or until interrupt
-	err := <-conn.dialResult
 	if err != nil {
-		if config.PendingConns != nil {
-			config.PendingConns.Remove(conn)
-		}
 		return nil, common.ContextError(err)
 	}
 
-	// TODO: now allow conn.dialResult to be garbage collected?
-
+	// Note: when an upstream proxy is used, we don't know what IP address
+	// was resolved, by the proxy, for that destination.
+	if config.ResolvedIPCallback != nil && config.UpstreamProxyUrl == "" {
+		ipAddress := common.IPAddressFromAddr(conn.RemoteAddr())
+		if ipAddress != "" {
+			config.ResolvedIPCallback(ipAddress)
+		}
+	}
 	return conn, nil
 }
 
 // proxiedTcpDial wraps a tcpDial call in an upstreamproxy dial.
 func proxiedTcpDial(
-	addr string, config *DialConfig) (net.Conn, error) {
+	ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
+
+	var interruptConns common.Conns
+
+	// Note: using interruptConns to interrupt a proxy dial assumes
+	// that the underlying proxy code will immediately exit with an
+	// error when all underlying conns unexpectedly close; e.g.,
+	// the proxy handshake won't keep retrying to dial new conns.
+
 	dialer := func(network, addr string) (net.Conn, error) {
-		return tcpDial(addr, config)
+		conn, err := tcpDial(ctx, addr, config)
+		if conn != nil {
+			if !interruptConns.Add(conn) {
+				err = errors.New("already interrupted")
+				conn.Close()
+				conn = nil
+			}
+		}
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+		return conn, nil
 	}
 
 	upstreamDialer := upstreamproxy.NewProxyDialFunc(
@@ -159,56 +114,70 @@ func proxiedTcpDial(
 			ProxyURIString:  config.UpstreamProxyUrl,
 			CustomHeaders:   config.CustomHeaders,
 		})
-	netConn, err := upstreamDialer("tcp", addr)
-	if _, ok := err.(*upstreamproxy.Error); ok {
-		NoticeUpstreamProxyError(err)
+
+	type upstreamDialResult struct {
+		conn net.Conn
+		err  error
 	}
-	return netConn, err
-}
 
-// Close terminates a connected TCPConn or interrupts a dialing TCPConn.
-func (conn *TCPConn) Close() (err error) {
-	conn.mutex.Lock()
-	defer conn.mutex.Unlock()
+	resultChannel := make(chan upstreamDialResult)
 
-	if conn.isClosed {
-		return
+	go func() {
+		conn, err := upstreamDialer("tcp", addr)
+		if _, ok := err.(*upstreamproxy.Error); ok {
+			NoticeUpstreamProxyError(err)
+		}
+		resultChannel <- upstreamDialResult{
+			conn: conn,
+			err:  err,
+		}
+	}()
+
+	var result upstreamDialResult
+
+	select {
+	case result = <-resultChannel:
+	case <-ctx.Done():
+		result.err = ctx.Err()
+		// Interrupt the goroutine
+		interruptConns.CloseAll()
+		<-resultChannel
 	}
-	conn.isClosed = true
 
-	if conn.Conn != nil {
-		err = conn.Conn.Close()
+	if result.err != nil {
+		return nil, common.ContextError(result.err)
 	}
 
-	select {
-	case conn.dialResult <- errors.New("dial interrupted"):
-	default:
+	return result.conn, nil
+}
+
+// Close terminates a connected TCPConn or interrupts a dialing TCPConn.
+func (conn *TCPConn) Close() (err error) {
+
+	if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
+		return nil
 	}
 
-	return err
+	return conn.Conn.Close()
 }
 
 // IsClosed implements the Closer iterface. The return value
 // indicates whether the TCPConn has been closed.
 func (conn *TCPConn) IsClosed() bool {
-	conn.mutex.Lock()
-	defer conn.mutex.Unlock()
-	return conn.isClosed
+	return atomic.LoadInt32(&conn.isClosed) == 1
 }
 
 // CloseWrite calls net.TCPConn.CloseWrite when the underlying
 // conn is a *net.TCPConn.
 func (conn *TCPConn) CloseWrite() (err error) {
-	conn.mutex.Lock()
-	defer conn.mutex.Unlock()
 
-	if conn.isClosed {
-		return errors.New("already closed")
+	if conn.IsClosed() {
+		return common.ContextError(errors.New("already closed"))
 	}
 
 	tcpConn, ok := conn.Conn.(*net.TCPConn)
 	if !ok {
-		return errors.New("conn is not a *net.TCPConn")
+		return common.ContextError(errors.New("conn is not a *net.TCPConn"))
 	}
 
 	return tcpConn.CloseWrite()

+ 95 - 48
psiphon/TCPConn_bind.go

@@ -22,6 +22,7 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math/rand"
@@ -30,18 +31,17 @@ import (
 	"strconv"
 	"syscall"
 
-	"github.com/Psiphon-Inc/goarista/monotime"
 	"github.com/Psiphon-Inc/goselect"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
-// tcpDial is the platform-specific part of interruptibleTCPDial
+// tcpDial is the platform-specific part of DialTCP
 //
 // To implement socket device binding, the lower-level syscall APIs are used.
 // The sequence of syscalls in this implementation are taken from:
 // https://github.com/golang/go/issues/6966
 // (originally: https://code.google.com/p/go/issues/detail?id=6966)
-func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
+func tcpDial(ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
 
 	// Get the remote IP and port, resolving a domain name if necessary
 	host, strPort, err := net.SplitHostPort(addr)
@@ -52,7 +52,7 @@ func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
-	ipAddrs, err := LookupIP(host, config)
+	ipAddrs, err := LookupIP(ctx, host, config)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -80,13 +80,14 @@ func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 	// Iterate over a pseudorandom permutation of the destination
 	// IPs and attempt connections.
 	//
-	// Only continue retrying as long as the initial ConnectTimeout
-	// has not expired. Unlike net.Dial, we do not fractionalize the
-	// timeout, as the ConnectTimeout is generally intended to apply
-	// to a single attempt. So these serial retries are most useful
-	// in cases of immediate failure, such as "no route to host"
+	// Only continue retrying as long as the dial context is not
+	// done. Unlike net.Dial, we do not fractionalize the context
+	// deadline, as the dial is generally intended to apply to a
+	// single attempt. So these serial retries are most useful in
+	// cases of immediate failure, such as "no route to host"
 	// errors when a host resolves to both IPv4 and IPv6 but IPv6
 	// addresses are unreachable.
+	//
 	// Retries at higher levels cover other cases: e.g.,
 	// Controller.remoteServerListFetcher will retry its entire
 	// operation and tcpDial will try a new permutation; or similarly,
@@ -97,17 +98,7 @@ func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 	lastErr := errors.New("unknown error")
 
-	var deadline monotime.Time
-	if config.ConnectTimeout != 0 {
-		deadline = monotime.Now().Add(config.ConnectTimeout)
-	}
-
-	for iteration, index := range permutedIndexes {
-
-		if iteration > 0 && deadline != 0 && monotime.Now().After(deadline) {
-			// lastErr should be set by the previous iteration
-			break
-		}
+	for _, index := range permutedIndexes {
 
 		// Get address type (IPv4 or IPv6)
 
@@ -135,22 +126,20 @@ func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 		// Create a socket and bind to device, when configured to do so
 
-		socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
+		socketFD, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
 		if err != nil {
 			lastErr = common.ContextError(err)
 			continue
 		}
 
-		tcpDialSetAdditionalSocketOptions(socketFd)
+		syscall.CloseOnExec(socketFD)
+
+		tcpDialSetAdditionalSocketOptions(socketFD)
 
 		if config.DeviceBinder != nil {
-			// WARNING: this potentially violates the direction to not call into
-			// external components after the Controller may have been stopped.
-			// TODO: rework DeviceBinder as an internal 'service' which can trap
-			// external calls when they should not be made?
-			err = config.DeviceBinder.BindToDevice(socketFd)
+			err = config.DeviceBinder.BindToDevice(socketFD)
 			if err != nil {
-				syscall.Close(socketFd)
+				syscall.Close(socketFD)
 				lastErr = common.ContextError(fmt.Errorf("BindToDevice failed: %s", err))
 				continue
 			}
@@ -158,61 +147,119 @@ func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
 
 		// Connect socket to the server's IP address
 
-		err = syscall.SetNonblock(socketFd, true)
+		err = syscall.SetNonblock(socketFD, true)
 		if err != nil {
-			syscall.Close(socketFd)
+			syscall.Close(socketFD)
 			lastErr = common.ContextError(err)
 			continue
 		}
 
-		err = syscall.Connect(socketFd, sockAddr)
+		err = syscall.Connect(socketFD, sockAddr)
 		if err != nil {
 			if errno, ok := err.(syscall.Errno); !ok || errno != syscall.EINPROGRESS {
-				syscall.Close(socketFd)
+				syscall.Close(socketFD)
 				lastErr = common.ContextError(err)
 				continue
 			}
 		}
 
-		fdset := &goselect.FDSet{}
-		fdset.Set(uintptr(socketFd))
+		// Use a control pipe to interrupt if the dial context is done (timeout or
+		// interrupted) before the TCP connection is established.
+
+		var controlFDs [2]int
+		err = syscall.Pipe(controlFDs[:])
+		if err != nil {
+			syscall.Close(socketFD)
+			lastErr = common.ContextError(err)
+			continue
 
-		timeout := config.ConnectTimeout
-		if config.ConnectTimeout == 0 {
-			timeout = -1
 		}
 
-		err = goselect.Select(socketFd+1, nil, fdset, nil, timeout)
+		for _, controlFD := range controlFDs {
+			syscall.CloseOnExec(controlFD)
+			err = syscall.SetNonblock(controlFD, true)
+			if err != nil {
+				break
+			}
+		}
+
 		if err != nil {
-			syscall.Close(socketFd)
+			syscall.Close(socketFD)
 			lastErr = common.ContextError(err)
 			continue
 		}
 
-		if !fdset.IsSet(uintptr(socketFd)) {
-			syscall.Close(socketFd)
-			lastErr = common.ContextError(errors.New("connect timed out"))
+		resultChannel := make(chan error)
+
+		go func() {
+
+			readSet := goselect.FDSet{}
+			readSet.Set(uintptr(controlFDs[0]))
+			writeSet := goselect.FDSet{}
+			writeSet.Set(uintptr(socketFD))
+
+			max := socketFD
+			if controlFDs[0] > max {
+				max = controlFDs[0]
+			}
+
+			err := goselect.Select(max+1, &readSet, &writeSet, nil, -1)
+
+			if err == nil && !writeSet.IsSet(uintptr(socketFD)) {
+				err = errors.New("interrupted")
+			}
+
+			resultChannel <- err
+		}()
+
+		done := false
+		select {
+		case err = <-resultChannel:
+		case <-ctx.Done():
+			err = ctx.Err()
+			// Interrupt the goroutine
+			// TODO: if this Write fails, abandon the goroutine instead of hanging?
+			var b [1]byte
+			syscall.Write(controlFDs[1], b[:])
+			<-resultChannel
+			done = true
+		}
+
+		syscall.Close(controlFDs[0])
+		syscall.Close(controlFDs[1])
+
+		if err != nil {
+			syscall.Close(socketFD)
+
+			if done {
+				// Skip retry as dial context has timed out or been canceled.
+				return nil, common.ContextError(err)
+			}
+
+			lastErr = common.ContextError(err)
 			continue
 		}
 
-		err = syscall.SetNonblock(socketFd, false)
+		err = syscall.SetNonblock(socketFD, false)
 		if err != nil {
-			syscall.Close(socketFd)
+			syscall.Close(socketFD)
 			lastErr = common.ContextError(err)
 			continue
 		}
 
 		// Convert the socket fd to a net.Conn
+		// This code block is from:
+		// https://github.com/golang/go/issues/6966
 
-		file := os.NewFile(uintptr(socketFd), "")
-		netConn, err := net.FileConn(file) // net.FileConn() dups socketFd
-		file.Close()                       // file.Close() closes socketFd
+		file := os.NewFile(uintptr(socketFD), "")
+		conn, err := net.FileConn(file) // net.FileConn() dups socketFD
+		file.Close()                    // file.Close() closes socketFD
 		if err != nil {
 			lastErr = common.ContextError(err)
 			continue
 		}
 
-		return netConn, nil
+		return &TCPConn{Conn: conn}, nil
 	}
 
 	return nil, lastErr

+ 12 - 3
psiphon/TCPConn_nobind.go

@@ -22,18 +22,27 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
 	"net"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
-// tcpDial is the platform-specific part of interruptibleTCPDial
-func tcpDial(addr string, config *DialConfig) (net.Conn, error) {
+// tcpDial is the platform-specific part of DialTCP
+func tcpDial(ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
 
 	if config.DeviceBinder != nil {
 		return nil, common.ContextError(errors.New("psiphon.interruptibleTCPDial with DeviceBinder not supported"))
 	}
 
-	return net.DialTimeout("tcp", addr, config.ConnectTimeout)
+	dialer := net.Dialer{}
+
+	conn, err := dialer.DialContext(ctx, "tcp", addr)
+
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	return &TCPConn{Conn: conn}, nil
 }

+ 74 - 139
psiphon/controller.go

@@ -24,6 +24,7 @@
 package psiphon
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math/rand"
@@ -43,8 +44,8 @@ import (
 type Controller struct {
 	config                             *Config
 	sessionId                          string
-	componentFailureSignal             chan struct{}
-	shutdownBroadcast                  chan struct{}
+	runCtx                             context.Context
+	stopRunning                        context.CancelFunc
 	runWaitGroup                       *sync.WaitGroup
 	connectedTunnels                   chan *Tunnel
 	failedTunnels                      chan *Tunnel
@@ -59,11 +60,10 @@ type Controller struct {
 	concurrentMeekEstablishTunnels     int
 	peakConcurrentEstablishTunnels     int
 	peakConcurrentMeekEstablishTunnels int
+	establishCtx                       context.Context
+	stopEstablish                      context.CancelFunc
 	establishWaitGroup                 *sync.WaitGroup
-	stopEstablishingBroadcast          chan struct{}
 	candidateServerEntries             chan *candidateServerEntry
-	establishPendingConns              *common.Conns
-	untunneledPendingConns             *common.Conns
 	untunneledDialConfig               *DialConfig
 	splitTunnelClassifier              *SplitTunnelClassifier
 	signalFetchCommonRemoteServerList  chan struct{}
@@ -93,16 +93,9 @@ func NewController(config *Config) (controller *Controller, err error) {
 	// tunnels established by the controller.
 	NoticeSessionId(config.SessionID)
 
-	// untunneledPendingConns may be used to interrupt the fetch remote server list
-	// request and other untunneled connection establishments. BindToDevice may be
-	// used to exclude these requests and connection from VPN routing.
-	// TODO: fetch remote server list and untunneled upgrade download should remove
-	// their completed conns from untunneledPendingConns.
-	untunneledPendingConns := new(common.Conns)
 	untunneledDialConfig := &DialConfig{
 		UpstreamProxyUrl:              config.UpstreamProxyUrl,
 		CustomHeaders:                 config.CustomHeaders,
-		PendingConns:                  untunneledPendingConns,
 		DeviceBinder:                  config.DeviceBinder,
 		DnsServerGetter:               config.DnsServerGetter,
 		IPv6Synthesizer:               config.IPv6Synthesizer,
@@ -112,14 +105,9 @@ func NewController(config *Config) (controller *Controller, err error) {
 	}
 
 	controller = &Controller{
-		config:    config,
-		sessionId: config.SessionID,
-		// componentFailureSignal receives a signal from a component (including socks and
-		// http local proxies) if they unexpectedly fail. Senders should not block.
-		// Buffer allows at least one stop signal to be sent before there is a receiver.
-		componentFailureSignal: make(chan struct{}, 1),
-		shutdownBroadcast:      make(chan struct{}),
-		runWaitGroup:           new(sync.WaitGroup),
+		config:       config,
+		sessionId:    config.SessionID,
+		runWaitGroup: new(sync.WaitGroup),
 		// connectedTunnels and failedTunnels buffer sizes are large enough to
 		// receive full pools of tunnels without blocking. Senders should not block.
 		connectedTunnels:               make(chan *Tunnel, config.TunnelPoolSize),
@@ -128,8 +116,6 @@ func NewController(config *Config) (controller *Controller, err error) {
 		establishedOnce:                false,
 		startedConnectedReporter:       false,
 		isEstablishing:                 false,
-		establishPendingConns:          new(common.Conns),
-		untunneledPendingConns:         untunneledPendingConns,
 		untunneledDialConfig:           untunneledDialConfig,
 		impairedProtocolClassification: make(map[string]int),
 		// TODO: Add a buffer of 1 so we don't miss a signal while receiver is
@@ -172,19 +158,18 @@ func NewController(config *Config) (controller *Controller, err error) {
 	return controller, nil
 }
 
-// Run executes the controller. It launches components and then monitors
-// for a shutdown signal; after receiving the signal it shuts down the
-// controller.
-// The components include:
-// - the periodic remote server list fetcher
-// - the connected reporter
-// - the tunnel manager
-// - a local SOCKS proxy that port forwards through the pool of tunnels
-// - a local HTTP proxy that port forwards through the pool of tunnels
-func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
+// Run executes the controller. Run exits if a controller
+// component fails or the parent context is canceled.
+func (controller *Controller) Run(ctx context.Context) {
 
 	ReportAvailableRegions()
 
+	runCtx, stopRunning := context.WithCancel(ctx)
+	defer stopRunning()
+
+	controller.runCtx = runCtx
+	controller.stopRunning = stopRunning
+
 	// Start components
 
 	// TODO: IPv6 support
@@ -215,8 +200,7 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 	}
 
 	if !controller.config.DisableLocalHTTPProxy {
-		httpProxy, err := NewHttpProxy(
-			controller.config, controller.untunneledDialConfig, controller, listenIP)
+		httpProxy, err := NewHttpProxy(controller.config, controller, listenIP)
 		if err != nil {
 			NoticeAlert("error initializing local HTTP proxy: %s", err)
 			return
@@ -272,38 +256,18 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 
 	// Wait while running
 
-	select {
-	case <-shutdownBroadcast:
-		NoticeInfo("controller shutdown by request")
-	case <-controller.componentFailureSignal:
-		NoticeAlert("controller shutdown due to component failure")
-	}
-
-	close(controller.shutdownBroadcast)
+	<-controller.runCtx.Done()
+	NoticeInfo("controller stopped")
 
 	if controller.packetTunnelClient != nil {
 		controller.packetTunnelClient.Stop()
 	}
 
-	// Interrupts and stops establish workers blocking on
-	// tunnel establishment network operations.
-	controller.establishPendingConns.CloseAll()
-
-	// Interrupts and stops workers blocking on untunneled
-	// network operations. This includes fetch remote server
-	// list and untunneled uprade download.
-	// Note: this doesn't interrupt the final, untunneled status
-	// requests started in operateTunnel after shutdownBroadcast.
-	// This is by design -- we want to give these requests a short
-	// timer period to succeed and deliver stats. These particular
-	// requests opt out of untunneledPendingConns and use the
-	// PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT timeout (see
-	// doUntunneledStatusRequest).
-	controller.untunneledPendingConns.CloseAll()
-
-	// Now with all workers signaled to stop and with all
-	// blocking network operations interrupted, wait for
-	// all workers to terminate.
+	// All workers -- runTunnels, establishment workers, and auxiliary
+	// workers such as fetch remote server list and untunneled upgrade
+	// download -- operate with the controller run context and will all
+	// be interrupted when the run context is done.
+
 	controller.runWaitGroup.Wait()
 
 	controller.splitTunnelClassifier.Shutdown()
@@ -316,10 +280,8 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 // SignalComponentFailure notifies the controller that an associated component has failed.
 // This will terminate the controller.
 func (controller *Controller) SignalComponentFailure() {
-	select {
-	case controller.componentFailureSignal <- *new(struct{}):
-	default:
-	}
+	// The alert is emitted on every call, before stopRunning is invoked.
+	// NOTE(review): stopRunning is presumably the cancel func paired with
+	// controller.runCtx, which Run waits on -- confirm at its declaration.
+	NoticeAlert("controller shutdown due to component failure")
+	controller.stopRunning()
 }
 
 // SetClientVerificationPayloadForActiveTunnels sets the client verification
@@ -363,7 +325,7 @@ fetcherLoop:
 		// Wait for a signal before fetching
 		select {
 		case <-signal:
-		case <-controller.shutdownBroadcast:
+		case <-controller.runCtx.Done():
 			break fetcherLoop
 		}
 
@@ -379,8 +341,8 @@ fetcherLoop:
 			// Don't attempt to fetch while there is no network connectivity,
 			// to avoid alert notice noise.
 			if !WaitForNetworkConnectivity(
-				controller.config.NetworkConnectivityChecker,
-				controller.shutdownBroadcast) {
+				controller.runCtx,
+				controller.config.NetworkConnectivityChecker) {
 				break fetcherLoop
 			}
 
@@ -389,6 +351,7 @@ fetcherLoop:
 			tunnel := controller.getNextActiveTunnel()
 
 			err := fetcher(
+				controller.runCtx,
 				controller.config,
 				attempt,
 				tunnel,
@@ -404,7 +367,7 @@ fetcherLoop:
 			timer := time.NewTimer(retryPeriod)
 			select {
 			case <-timer.C:
-			case <-controller.shutdownBroadcast:
+			case <-controller.runCtx.Done():
 				timer.Stop()
 				break fetcherLoop
 			}
@@ -432,7 +395,7 @@ func (controller *Controller) establishTunnelWatcher() {
 			NoticeAlert("failed to establish tunnel before timeout")
 			controller.SignalComponentFailure()
 		}
-	case <-controller.shutdownBroadcast:
+	case <-controller.runCtx.Done():
 	}
 
 	NoticeInfo("exiting establish tunnel watcher")
@@ -477,7 +440,7 @@ loop:
 		case <-controller.signalReportConnected:
 		case <-timer.C:
 			// Make another connected request
-		case <-controller.shutdownBroadcast:
+		case <-controller.runCtx.Done():
 			doBreak = true
 		}
 		timer.Stop()
@@ -537,7 +500,7 @@ downloadLoop:
 		var handshakeVersion string
 		select {
 		case handshakeVersion = <-controller.signalDownloadUpgrade:
-		case <-controller.shutdownBroadcast:
+		case <-controller.runCtx.Done():
 			break downloadLoop
 		}
 
@@ -554,8 +517,8 @@ downloadLoop:
 			// Don't attempt to download while there is no network connectivity,
 			// to avoid alert notice noise.
 			if !WaitForNetworkConnectivity(
-				controller.config.NetworkConnectivityChecker,
-				controller.shutdownBroadcast) {
+				controller.runCtx,
+				controller.config.NetworkConnectivityChecker) {
 				break downloadLoop
 			}
 
@@ -564,6 +527,7 @@ downloadLoop:
 			tunnel := controller.getNextActiveTunnel()
 
 			err := DownloadUpgrade(
+				controller.runCtx,
 				controller.config,
 				attempt,
 				handshakeVersion,
@@ -581,7 +545,7 @@ downloadLoop:
 				time.Duration(*controller.config.DownloadUpgradeRetryPeriodSeconds) * time.Second)
 			select {
 			case <-timer.C:
-			case <-controller.shutdownBroadcast:
+			case <-controller.runCtx.Done():
 				timer.Stop()
 				break downloadLoop
 			}
@@ -620,19 +584,6 @@ loop:
 			NoticeAlert("tunnel failed: %s", failedTunnel.serverEntry.IpAddress)
 			controller.terminateTunnel(failedTunnel)
 
-			// Note: we make this extra check to ensure the shutdown signal takes priority
-			// and that we do not start establishing. Critically, startEstablishing() calls
-			// establishPendingConns.Reset() which clears the closed flag in
-			// establishPendingConns; this causes the pendingConns.Add() within
-			// interruptibleTCPDial to succeed instead of aborting, and the result
-			// is that it's possible for establish goroutines to run all the way through
-			// NewServerContext before being discarded... delaying shutdown.
-			select {
-			case <-controller.shutdownBroadcast:
-				break loop
-			default:
-			}
-
 			controller.classifyImpairedProtocol(failedTunnel)
 
 			// Clear the reference to this tunnel before calling startEstablishing,
@@ -697,7 +648,7 @@ loop:
 					controller.stopEstablishing()
 				}
 
-				err := connectedTunnel.Activate(controller, controller.shutdownBroadcast)
+				err := connectedTunnel.Activate(controller.runCtx, controller)
 
 				if err != nil {
 
@@ -801,7 +752,7 @@ loop:
 		case clientVerificationPayload = <-controller.newClientVerificationPayload:
 			controller.setClientVerificationPayloadForActiveTunnels(clientVerificationPayload)
 
-		case <-controller.shutdownBroadcast:
+		case <-controller.runCtx.Done():
 			break loop
 		}
 	}
@@ -1108,8 +1059,7 @@ func (controller *Controller) Dial(
 		// relative to the outbound network.
 
 		if controller.splitTunnelClassifier.IsUntunneled(host) {
-			// TODO: track downstreamConn and close it when the DialTCP conn closes, as with tunnel.Dial conns?
-			return DialTCP(remoteAddr, controller.untunneledDialConfig)
+			return controller.DirectDial(remoteAddr)
 		}
 	}
 
@@ -1121,6 +1071,11 @@ func (controller *Controller) Dial(
 	return tunneledConn, nil
 }
 
+// DirectDial dials an untunneled TCP connection to remoteAddr. It calls
+// DialTCP with the controller run context and untunneledDialConfig and
+// returns DialTCP's result unchanged.
+func (controller *Controller) DirectDial(remoteAddr string) (conn net.Conn, err error) {
+	return DialTCP(controller.runCtx, remoteAddr, controller.untunneledDialConfig)
+}
+
 // startEstablishing creates a pool of worker goroutines which will
 // attempt to establish tunnels to candidate servers. The candidates
 // are generated by another goroutine.
@@ -1140,11 +1095,13 @@ func (controller *Controller) startEstablishing() {
 	aggressiveGarbageCollection()
 	emitMemoryMetrics()
 
+	// Note: the establish context cancelFunc, controller.stopEstablish,
+	// is called in controller.stopEstablishing.
+
 	controller.isEstablishing = true
+	controller.establishCtx, controller.stopEstablish = context.WithCancel(controller.runCtx)
 	controller.establishWaitGroup = new(sync.WaitGroup)
-	controller.stopEstablishingBroadcast = make(chan struct{})
 	controller.candidateServerEntries = make(chan *candidateServerEntry)
-	controller.establishPendingConns.Reset()
 
 	// The server affinity mechanism attempts to favor the previously
 	// used server when reconnecting. This is beneficial for user
@@ -1184,25 +1141,22 @@ func (controller *Controller) startEstablishing() {
 }
 
 // stopEstablishing signals the establish goroutines to stop and waits
-// for the group to halt. pendingConns is used to interrupt any worker
-// blocked on a socket connect.
+// for the group to halt.
 func (controller *Controller) stopEstablishing() {
 	if !controller.isEstablishing {
 		return
 	}
 	NoticeInfo("stop establishing")
-	close(controller.stopEstablishingBroadcast)
-	// Note: interruptibleTCPClose doesn't really interrupt socket connects
-	// and may leave goroutines running for a time after the Wait call.
-	controller.establishPendingConns.CloseAll()
+	controller.stopEstablish()
 	// Note: establishCandidateGenerator closes controller.candidateServerEntries
 	// (as it may be sending to that channel).
 	controller.establishWaitGroup.Wait()
 	NoticeInfo("stopped establishing")
 
 	controller.isEstablishing = false
+	controller.establishCtx = nil
+	controller.stopEstablish = nil
 	controller.establishWaitGroup = nil
-	controller.stopEstablishingBroadcast = nil
 	controller.candidateServerEntries = nil
 	controller.serverAffinityDoneBroadcast = nil
 
@@ -1260,9 +1214,8 @@ loop:
 		networkWaitStartTime := monotime.Now()
 
 		if !WaitForNetworkConnectivity(
-			controller.config.NetworkConnectivityChecker,
-			controller.stopEstablishingBroadcast,
-			controller.shutdownBroadcast) {
+			controller.establishCtx,
+			controller.config.NetworkConnectivityChecker) {
 			break loop
 		}
 
@@ -1325,9 +1278,7 @@ loop:
 
 			select {
 			case controller.candidateServerEntries <- candidate:
-			case <-controller.stopEstablishingBroadcast:
-				break loop
-			case <-controller.shutdownBroadcast:
+			case <-controller.establishCtx.Done():
 				break loop
 			}
 
@@ -1344,37 +1295,27 @@ loop:
 				// and the grace period has elapsed.
 
 				timer := time.NewTimer(ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD)
-				doBreak := false
 				select {
 				case <-timer.C:
 				case <-controller.serverAffinityDoneBroadcast:
-				case <-controller.stopEstablishingBroadcast:
-					doBreak = true
-				case <-controller.shutdownBroadcast:
-					doBreak = true
-				}
-				timer.Stop()
-				if doBreak {
+				case <-controller.establishCtx.Done():
+					timer.Stop()
 					break loop
 				}
+				timer.Stop()
 			} else if controller.config.StaggerConnectionWorkersMilliseconds != 0 {
 
 				// Stagger concurrent connection workers.
 
 				timer := time.NewTimer(time.Millisecond * time.Duration(
 					controller.config.StaggerConnectionWorkersMilliseconds))
-				doBreak := false
 				select {
 				case <-timer.C:
-				case <-controller.stopEstablishingBroadcast:
-					doBreak = true
-				case <-controller.shutdownBroadcast:
-					doBreak = true
-				}
-				timer.Stop()
-				if doBreak {
+				case <-controller.establishCtx.Done():
+					timer.Stop()
 					break loop
 				}
+				timer.Stop()
 			}
 		}
 
@@ -1416,19 +1357,14 @@ loop:
 		// be more rounds if required).
 		timer := time.NewTimer(
 			time.Duration(*controller.config.EstablishTunnelPausePeriodSeconds) * time.Second)
-		doBreak := false
 		select {
 		case <-timer.C:
 			// Retry iterating
-		case <-controller.stopEstablishingBroadcast:
-			doBreak = true
-		case <-controller.shutdownBroadcast:
-			doBreak = true
-		}
-		timer.Stop()
-		if doBreak {
+		case <-controller.establishCtx.Done():
+			timer.Stop()
 			break loop
 		}
+		timer.Stop()
 
 		iterator.Reset()
 	}
@@ -1440,9 +1376,9 @@ func (controller *Controller) establishTunnelWorker() {
 	defer controller.establishWaitGroup.Done()
 loop:
 	for candidateServerEntry := range controller.candidateServerEntries {
-		// Note: don't receive from candidateServerEntries and stopEstablishingBroadcast
+		// Note: don't receive from candidateServerEntries and establishCtx.Done()
 		// in the same select, since we want to prioritize receiving the stop signal
-		if controller.isStopEstablishingBroadcast() {
+		if controller.isStopEstablishing() {
 			break loop
 		}
 
@@ -1517,10 +1453,9 @@ loop:
 			controller.concurrentEstablishTunnelsMutex.Unlock()
 
 			tunnel, err = ConnectTunnel(
+				controller.establishCtx,
 				controller.config,
-				controller.untunneledDialConfig,
 				controller.sessionId,
-				controller.establishPendingConns,
 				candidateServerEntry.serverEntry,
 				selectedProtocol,
 				candidateServerEntry.adjustedEstablishStartTime)
@@ -1534,7 +1469,7 @@ loop:
 		}
 
 		// Periodically emit memory metrics during the establishment cycle.
-		if !controller.isStopEstablishingBroadcast() {
+		if !controller.isStopEstablishing() {
 			emitMemoryMetrics()
 		}
 
@@ -1559,7 +1494,7 @@ loop:
 
 			// Before emitting error, check if establish interrupted, in which
 			// case the error is noise.
-			if controller.isStopEstablishingBroadcast() {
+			if controller.isStopEstablishing() {
 				break loop
 			}
 
@@ -1590,9 +1525,9 @@ loop:
 	}
 }
 
-func (controller *Controller) isStopEstablishingBroadcast() bool {
+func (controller *Controller) isStopEstablishing() bool {
 	select {
-	case <-controller.stopEstablishingBroadcast:
+	case <-controller.establishCtx.Done():
 		return true
 	default:
 	}

+ 11 - 3
psiphon/controller_test.go

@@ -20,6 +20,7 @@
 package psiphon
 
 import (
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -666,18 +667,21 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 
 	// Run controller, which establishes tunnels
 
-	shutdownBroadcast := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
 	controllerWaitGroup := new(sync.WaitGroup)
+
 	controllerWaitGroup.Add(1)
 	go func() {
 		defer controllerWaitGroup.Done()
-		controller.Run(shutdownBroadcast)
+		controller.Run(ctx)
 	}()
 
 	defer func() {
+
 		// Test: shutdown must complete within 20 seconds
 
-		close(shutdownBroadcast)
+		cancelFunc()
 
 		shutdownTimeout := time.NewTimer(20 * time.Second)
 
@@ -1036,6 +1040,10 @@ func initDisruptor() {
 		for {
 			localConn, err := listener.AcceptSocks()
 			if err != nil {
+				if e, ok := err.(net.Error); ok && e.Temporary() {
+					fmt.Printf("disruptor proxy temporary accept error: %s", err)
+					continue
+				}
 				fmt.Printf("disruptor proxy accept error: %s\n", err)
 				return
 			}

+ 13 - 4
psiphon/feedback.go

@@ -21,6 +21,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/hmac"
@@ -111,7 +112,6 @@ func SendFeedback(configJson, diagnosticsJson, b64EncodedPublicKey, uploadServer
 	untunneledDialConfig := &DialConfig{
 		UpstreamProxyUrl:              config.UpstreamProxyUrl,
 		CustomHeaders:                 config.CustomHeaders,
-		PendingConns:                  nil,
 		DeviceBinder:                  nil,
 		IPv6Synthesizer:               nil,
 		DnsServerGetter:               nil,
@@ -157,13 +157,22 @@ func SendFeedback(configJson, diagnosticsJson, b64EncodedPublicKey, uploadServer
 
 // Attempt to upload feedback data to server.
 func uploadFeedback(config *DialConfig, feedbackData []byte, url, userAgent string, headerPieces []string) error {
-	client, parsedUrl, err := MakeUntunneledHttpsClient(
-		config, nil, url, false, time.Duration(FEEDBACK_UPLOAD_TIMEOUT_SECONDS*time.Second))
+
+	ctx, cancelFunc := context.WithTimeout(
+		context.Background(),
+		time.Duration(FEEDBACK_UPLOAD_TIMEOUT_SECONDS*time.Second))
+	defer cancelFunc()
+
+	client, err := MakeUntunneledHTTPClient(
+		ctx,
+		config,
+		nil,
+		false)
 	if err != nil {
 		return err
 	}
 
-	req, err := http.NewRequest("PUT", parsedUrl, bytes.NewBuffer(feedbackData))
+	req, err := http.NewRequest("PUT", url, bytes.NewBuffer(feedbackData))
 	if err != nil {
 		return common.ContextError(err)
 	}

+ 1 - 3
psiphon/httpProxy.go

@@ -89,7 +89,6 @@ var _HTTP_PROXY_TYPE = "HTTP"
 // NewHttpProxy initializes and runs a new HTTP proxy server.
 func NewHttpProxy(
 	config *Config,
-	untunneledDialConfig *DialConfig,
 	tunneler Tunneler,
 	listenIP string) (proxy *HttpProxy, err error) {
 
@@ -106,11 +105,10 @@ func NewHttpProxy(
 		// downstreamConn is not set in this case, as there is not a fixed
 		// association between a downstream client connection and a particular
 		// tunnel.
-		// TODO: connect timeout?
 		return tunneler.Dial(addr, false, nil)
 	}
 	directDialer := func(_, addr string) (conn net.Conn, err error) {
-		return DialTCP(addr, untunneledDialConfig)
+		return tunneler.DirectDial(addr)
 	}
 
 	responseHeaderTimeout := time.Duration(*config.HttpProxyOriginServerTimeoutSeconds) * time.Second

+ 227 - 0
psiphon/interrupt_dials_test.go

@@ -0,0 +1,227 @@
+/*
+ * Copyright (c) 2017, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"runtime"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Inc/goarista/monotime"
+)
+
+// TestInterruptDials runs the dial interruption check once per dialer
+// configuration and interruption mode (context timeout, explicit cancel),
+// each as its own subtest.
+func TestInterruptDials(t *testing.T) {
+
+	// The proxied factories use the mock server address as the upstream
+	// proxy address; the plain TCP and TLS factories ignore it.
+	proxiedDialer := func(scheme string) func(string) Dialer {
+		return func(mockServerAddr string) Dialer {
+			return NewTCPDialer(
+				&DialConfig{
+					UpstreamProxyUrl: scheme + mockServerAddr,
+				})
+		}
+	}
+
+	dialerFactories := map[string]func(string) Dialer{
+		"TCP": func(string) Dialer {
+			return NewTCPDialer(&DialConfig{})
+		},
+		"SOCKS4-Proxied":       proxiedDialer("socks4a://"),
+		"SOCKS5-Proxied":       proxiedDialer("socks5://"),
+		"HTTP-CONNECT-Proxied": proxiedDialer("http://"),
+		// TODO: test upstreamproxy.ProxyAuthTransport
+		"TLS": func(string) Dialer {
+			return NewCustomTLSDialer(
+				&CustomTLSConfig{
+					Dial: NewTCPDialer(&DialConfig{}),
+				})
+		},
+	}
+
+	// Function names searched for in goroutine stacks after interruption.
+	goroutineFuncNames := []string{"NewTCPDialer", "NewCustomTLSDialer"}
+
+	for name, factory := range dialerFactories {
+		for _, withTimeout := range []bool{true, false} {
+			subtestName := fmt.Sprintf("%s-timeout-%+v", name, withTimeout)
+			t.Run(subtestName, func(t *testing.T) {
+				runInterruptDials(t, withTimeout, factory, goroutineFuncNames)
+			})
+		}
+	}
+}
+
+// runInterruptDials starts one dial to each of two local listeners -- one on
+// which Accept is never called and one that accepts but never writes a
+// response -- then interrupts both dials through their shared context. The
+// test fails if the dials take longer than 100ms to terminate after the
+// interrupt, or if a goroutine whose stack contains any of
+// dialGoroutineFunctionNames is still found afterwards.
+//
+// doTimeout selects interruption by context deadline (true) or by an explicit
+// cancel (false). makeDialer builds the dialer under test from the target
+// address.
+func runInterruptDials(
+	t *testing.T,
+	doTimeout bool,
+	makeDialer func(string) Dialer,
+	dialGoroutineFunctionNames []string) {
+
+	t.Logf("Test timeout: %+v", doTimeout)
+
+	// Accept is never called on this listener.
+	noAcceptListener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	defer noAcceptListener.Close()
+
+	// This listener accepts a connection but only reads from it.
+	noResponseListener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	defer noResponseListener.Close()
+
+	listenerAccepted := make(chan struct{}, 1)
+
+	// Deferred calls run in LIFO order, so this Wait runs before the deferred
+	// listener Close calls above; the goroutine has to exit through its read
+	// error path rather than through a failed Accept.
+	noResponseListenerWaitGroup := new(sync.WaitGroup)
+	noResponseListenerWaitGroup.Add(1)
+	defer noResponseListenerWaitGroup.Wait()
+	go func() {
+		defer noResponseListenerWaitGroup.Done()
+		for {
+			conn, err := noResponseListener.Accept()
+			if err != nil {
+				return
+			}
+			listenerAccepted <- struct{}{}
+
+			// Read and discard until the read fails, then exit the
+			// goroutine; only the first accepted connection is served.
+			var b [1024]byte
+			for {
+				_, err := conn.Read(b[:])
+				if err != nil {
+					conn.Close()
+					return
+				}
+			}
+		}
+	}()
+
+	var ctx context.Context
+	var cancelFunc context.CancelFunc
+
+	timeout := 100 * time.Millisecond
+
+	if doTimeout {
+		ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
+	} else {
+		ctx, cancelFunc = context.WithCancel(context.Background())
+	}
+
+	addrs := []string{
+		noAcceptListener.Addr().String(),
+		noResponseListener.Addr().String()}
+
+	// Buffered so that each dial goroutine can report without blocking.
+	dialTerminated := make(chan struct{}, len(addrs))
+
+	for _, addr := range addrs {
+		go func(addr string) {
+			conn, err := makeDialer(addr)(ctx, "tcp", addr)
+			if err == nil {
+				conn.Close()
+			}
+			dialTerminated <- struct{}{}
+		}(addr)
+	}
+
+	// Wait for noResponseListener to accept to ensure that we exercise
+	// post-TCP-dial interruption in the case of TLS and proxy dialers that
+	// do post-TCP-dial handshake I/O as part of their dial.
+
+	<-listenerAccepted
+
+	if doTimeout {
+		// ctx was created with this same timeout, so after sleeping for
+		// the full duration its deadline has passed.
+		time.Sleep(timeout)
+		defer cancelFunc()
+	} else {
+		// No timeout, so interrupt with cancel
+		cancelFunc()
+	}
+
+	startWaiting := monotime.Now()
+
+	for _ = range addrs {
+		<-dialTerminated
+	}
+
+	// Test: dial interrupt must complete quickly
+
+	interruptDuration := monotime.Since(startWaiting)
+
+	if interruptDuration > 100*time.Millisecond {
+		t.Fatalf("interrupt duration too long: %s", interruptDuration)
+	}
+
+	// Test: interrupted dialers must not leave goroutines running
+
+	if findGoroutines(t, dialGoroutineFunctionNames) {
+		t.Fatalf("unexpected dial goroutines")
+	}
+}
+
+// findGoroutines reports whether any current goroutine has a stack frame
+// whose function name contains one of targets. Each match is logged via t.
+func findGoroutines(t *testing.T, targets []string) bool {
+
+	// GoroutineProfile fills the slice, and reports ok, only when the slice
+	// is large enough to hold every record. Goroutines may start between the
+	// sizing call and the capture call, so allocate some headroom and retry
+	// until a complete profile is captured; otherwise an unfilled slice
+	// would make this check pass without inspecting anything.
+	var records []runtime.StackRecord
+	for {
+		n, _ := runtime.GoroutineProfile(nil)
+		records = make([]runtime.StackRecord, n+10)
+		n, ok := runtime.GoroutineProfile(records)
+		if ok {
+			// Only the first n records were written.
+			records = records[:n]
+			break
+		}
+	}
+
+	found := false
+	for _, g := range records {
+		stack := g.Stack()
+		funcNames := make([]string, len(stack))
+		for i := 0; i < len(stack); i++ {
+			funcNames[i] = getFunctionName(stack[i])
+		}
+		s := strings.Join(funcNames, ", ")
+		for _, target := range targets {
+			if strings.Contains(s, target) {
+				t.Logf("found dial goroutine: %s", s)
+				found = true
+			}
+		}
+	}
+	return found
+}
+
+// getFunctionName returns the name of the function containing pc, with
+// everything up to and including the last "/" removed.
+func getFunctionName(pc uintptr) string {
+	name := runtime.FuncForPC(pc).Name()
+	if slash := strings.LastIndex(name, "/"); slash != -1 {
+		return name[slash+1:]
+	}
+	return name
+}

+ 81 - 73
psiphon/meekConn.go

@@ -141,12 +141,11 @@ type MeekConn struct {
 	url                     *url.URL
 	additionalHeaders       http.Header
 	cookie                  *http.Cookie
-	pendingConns            *common.Conns
 	cachedTLSDialer         *cachedTLSDialer
 	transport               transporter
 	mutex                   sync.Mutex
 	isClosed                bool
-	runContext              context.Context
+	runCtx                  context.Context
 	stopRunning             context.CancelFunc
 	relayWaitGroup          *sync.WaitGroup
 	fullReceiveBufferLength int
@@ -173,35 +172,33 @@ type transporter interface {
 // When frontingAddress is not "", fronting is used. This option assumes caller has
 // already checked server entry capabilities.
 func DialMeek(
+	ctx context.Context,
 	meekConfig *MeekConfig,
 	dialConfig *DialConfig) (meek *MeekConn, err error) {
 
-	// Configure transport
-	// Note: MeekConn has its own PendingConns to manage the underlying HTTP transport connections,
-	// which may be interrupted on MeekConn.Close(). This code previously used the establishTunnel
-	// pendingConns here, but that was a lifecycle mismatch: we don't want to abort HTTP transport
-	// connections while MeekConn is still in use.
-	pendingConns := new(common.Conns)
+	runCtx, stopRunning := context.WithCancel(context.Background())
 
-	// Use a copy of DialConfig with the meek pendingConns
-	meekDialConfig := new(DialConfig)
-	*meekDialConfig = *dialConfig
-	meekDialConfig.PendingConns = pendingConns
-
-	var scheme string
+	cleanupStopRunning := true
 	cleanupCachedTLSDialer := true
 	var cachedTLSDialer *cachedTLSDialer
-	var transport transporter
-	var additionalHeaders http.Header
-	var proxyUrl func(*http.Request) (*url.URL, error)
 
-	// Close any cached pre-dialed conn in error cases
+	// Cleanup in error cases
 	defer func() {
+		if cleanupStopRunning {
+			stopRunning()
+		}
 		if cleanupCachedTLSDialer && cachedTLSDialer != nil {
-			cachedTLSDialer.Close()
+			cachedTLSDialer.close()
 		}
 	}()
 
+	// Configure transport: HTTP or HTTPS
+
+	var scheme string
+	var transport transporter
+	var additionalHeaders http.Header
+	var proxyUrl func(*http.Request) (*url.URL, error)
+
 	if meekConfig.UseHTTPS {
 
 		// Custom TLS dialer:
@@ -241,13 +238,12 @@ func DialMeek(
 
 		tlsConfig := &CustomTLSConfig{
 			DialAddr:                      meekConfig.DialAddress,
-			Dial:                          NewTCPDialer(meekDialConfig),
-			Timeout:                       meekDialConfig.ConnectTimeout,
+			Dial:                          NewTCPDialer(dialConfig),
 			SNIServerName:                 meekConfig.SNIServerName,
 			SkipVerify:                    true,
-			UseIndistinguishableTLS:       meekDialConfig.UseIndistinguishableTLS,
+			UseIndistinguishableTLS:       dialConfig.UseIndistinguishableTLS,
 			TLSProfile:                    meekConfig.TLSProfile,
-			TrustedCACertificatesFilename: meekDialConfig.TrustedCACertificatesFilename,
+			TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
 		}
 
 		if meekConfig.UseObfuscatedSessionTickets {
@@ -270,7 +266,7 @@ func DialMeek(
 		// return the cached pre-dialed connection to its first Dial caller, and
 		// use the tlsDialer for all other Dials.
 		//
-		// cachedTLSDialer.Close() must be called on all exits paths from this
+		// cachedTLSDialer.close() must be called on all exit paths from this
 		// function and in meek.Close() to ensure the cached conn is closed in
 		// any case where no Dial call is made.
 		//
@@ -286,20 +282,20 @@ func DialMeek(
 		// that the underlying TCPDial may still try multiple IP addresses when
 		// the destination is a domain and it resolves to multiple IP addresses.
 
-		preConfig := &CustomTLSConfig{}
-		*preConfig = *tlsConfig
-		preConfig.Dial = NewTCPDialer(dialConfig)
-		preDialer := NewCustomTLSDialer(preConfig)
+		// The pre-dial is made within the parent dial context, so that DialMeek
+		// may be interrupted. Subsequent dials are made within the meek round trip
+		// request context. Since http.DialTLS doesn't take a context argument
+		// (yet; as of Go 1.9 this issue is still open: https://github.com/golang/go/issues/21526),
+		// cachedTLSDialer is used as a conduit to send the request context.
+		// meekConn.roundTrip sets its request context into cachedTLSDialer, and
+		// cachedTLSDialer.dial uses that context.
 
 		// As DialAddr is set in the CustomTLSConfig, no address is required here.
-		preConn, err := preDialer("tcp", "")
+		preConn, err := tlsDialer(ctx, "tcp", "")
 		if err != nil {
 			return nil, common.ContextError(err)
 		}
 
-		// Cancel interruptibility to keep this connection alive after establishment.
-		dialConfig.PendingConns.Remove(preConn)
-
 		isHTTP2 := false
 		if tlsConn, ok := preConn.(*tls.Conn); ok {
 			state := tlsConn.ConnectionState()
@@ -309,18 +305,20 @@ func DialMeek(
 			}
 		}
 
-		cachedTLSDialer = NewCachedTLSDialer(preConn, tlsDialer)
+		cachedTLSDialer = newCachedTLSDialer(preConn, tlsDialer)
 
 		if isHTTP2 {
 			NoticeInfo("negotiated HTTP/2 for %s", meekConfig.DialAddress)
 			transport = &http2.Transport{
 				DialTLS: func(network, addr string, _ *golangtls.Config) (net.Conn, error) {
-					return cachedTLSDialer.Dial(network, addr)
+					return cachedTLSDialer.dial(network, addr)
 				},
 			}
 		} else {
 			transport = &http.Transport{
-				DialTLS: cachedTLSDialer.Dial,
+				DialTLS: func(network, addr string) (net.Conn, error) {
+					return cachedTLSDialer.dial(network, addr)
+				},
 			}
 		}
 
@@ -330,8 +328,8 @@ func DialMeek(
 
 		// The dialer ignores address that http.Transport will pass in (derived
 		// from the HTTP request URL) and always dials meekConfig.DialAddress.
-		dialer := func(string, string) (net.Conn, error) {
-			return NewTCPDialer(meekDialConfig)("tcp", meekConfig.DialAddress)
+		dialer := func(ctx context.Context, network, _ string) (net.Conn, error) {
+			return NewTCPDialer(dialConfig)(ctx, network, meekConfig.DialAddress)
 		}
 
 		// For HTTP, and when the meekConfig.DialAddress matches the
@@ -339,29 +337,33 @@ func DialMeek(
 		// http.Transport will put the HTTP server address in the HTTP
 		// request line. In this one case, we can use an HTTP proxy that does
 		// not offer CONNECT support.
-		if strings.HasPrefix(meekDialConfig.UpstreamProxyUrl, "http://") &&
+		if strings.HasPrefix(dialConfig.UpstreamProxyUrl, "http://") &&
 			(meekConfig.DialAddress == meekConfig.HostHeader ||
 				meekConfig.DialAddress == meekConfig.HostHeader+":80") {
-			url, err := url.Parse(meekDialConfig.UpstreamProxyUrl)
+
+			url, err := url.Parse(dialConfig.UpstreamProxyUrl)
 			if err != nil {
 				return nil, common.ContextError(err)
 			}
 			proxyUrl = http.ProxyURL(url)
-			meekDialConfig.UpstreamProxyUrl = ""
 
 			// Here, the dialer must use the address that http.Transport
 			// passes in (which will be proxy address).
-			dialer = NewTCPDialer(meekDialConfig)
+			copyDialConfig := new(DialConfig)
+			*copyDialConfig = *dialConfig
+			copyDialConfig.UpstreamProxyUrl = ""
+
+			dialer = NewTCPDialer(copyDialConfig)
 		}
 
-		// TODO: wrap in an http.Client and use http.Client.Timeout which actually covers round trip
 		httpTransport := &http.Transport{
-			Proxy: proxyUrl,
-			Dial:  dialer,
+			Proxy:       proxyUrl,
+			DialContext: dialer,
 		}
+
 		if proxyUrl != nil {
 			// Wrap transport with a transport that can perform HTTP proxy auth negotiation
-			transport, err = upstreamproxy.NewProxyAuthTransport(httpTransport, meekDialConfig.CustomHeaders)
+			transport, err = upstreamproxy.NewProxyAuthTransport(httpTransport, dialConfig.CustomHeaders)
 			if err != nil {
 				return nil, common.ContextError(err)
 			}
@@ -386,7 +388,7 @@ func DialMeek(
 		}
 	} else {
 		if proxyUrl == nil {
-			additionalHeaders = meekDialConfig.CustomHeaders
+			additionalHeaders = dialConfig.CustomHeaders
 		}
 	}
 
@@ -395,8 +397,6 @@ func DialMeek(
 		return nil, common.ContextError(err)
 	}
 
-	runContext, stopRunning := context.WithCancel(context.Background())
-
 	// The main loop of a MeekConn is run in the relay() goroutine.
 	// A MeekConn implements net.Conn concurrency semantics:
 	// "Multiple goroutines may invoke methods on a Conn simultaneously."
@@ -417,11 +417,10 @@ func DialMeek(
 		url:                     url,
 		additionalHeaders:       additionalHeaders,
 		cookie:                  cookie,
-		pendingConns:            pendingConns,
 		cachedTLSDialer:         cachedTLSDialer,
 		transport:               transport,
 		isClosed:                false,
-		runContext:              runContext,
+		runCtx:                  runCtx,
 		stopRunning:             stopRunning,
 		relayWaitGroup:          new(sync.WaitGroup),
 		fullReceiveBufferLength: FULL_RECEIVE_BUFFER_LENGTH,
@@ -434,7 +433,8 @@ func DialMeek(
 		fullSendBuffer:          make(chan *bytes.Buffer, 1),
 	}
 
-	// cachedTLSDialer will now be closed in meek.Close()
+	// stopRunning and cachedTLSDialer will now be closed in meek.Close()
+	cleanupStopRunning = false
 	cleanupCachedTLSDialer = false
 
 	meek.emptyReceiveBuffer <- new(bytes.Buffer)
@@ -448,38 +448,41 @@ func DialMeek(
 
 	go meek.relay()
 
-	// Enable interruption
-	if !dialConfig.PendingConns.Add(meek) {
-		meek.Close()
-		return nil, common.ContextError(errors.New("pending connections already closed"))
-	}
-
 	return meek, nil
 }
 
+// cachedTLSDialer hands its cached connection to the first dial call and
+// uses dialer for every later call.
 type cachedTLSDialer struct {
+	// usedCachedConn is switched from 0 to 1 with an atomic compare-and-swap
+	// by whichever of dial or close claims cachedConn first.
 	usedCachedConn int32
 	cachedConn     net.Conn
+	// requestContext holds the context.Context stored by setRequestContext
+	// and read by dial.
+	requestContext atomic.Value
 	dialer         Dialer
 }
 
-func NewCachedTLSDialer(cachedConn net.Conn, dialer Dialer) *cachedTLSDialer {
+// newCachedTLSDialer returns a cachedTLSDialer holding cachedConn and dialer.
+// No request context is set at construction.
+func newCachedTLSDialer(cachedConn net.Conn, dialer Dialer) *cachedTLSDialer {
 	return &cachedTLSDialer{
 		cachedConn: cachedConn,
 		dialer:     dialer,
 	}
 }
 
-func (c *cachedTLSDialer) Dial(network, addr string) (net.Conn, error) {
+// setRequestContext records requestContext as the context that dial passes
+// to the underlying dialer.
+//
+// NOTE(review): atomic.Value.Store panics when given nil or a value whose
+// concrete type differs from the previously stored one, so callers need to
+// pass non-nil contexts of one consistent concrete type -- confirm callers.
+func (c *cachedTLSDialer) setRequestContext(requestContext context.Context) {
+	c.requestContext.Store(requestContext)
+}
+
+// dial returns the cached pre-dialed connection to its first caller. Every
+// later call dials a new connection with the underlying dialer, using the
+// context most recently stored by setRequestContext, or
+// context.Background() when none has been stored.
+func (c *cachedTLSDialer) dial(network, addr string) (net.Conn, error) {
 	if atomic.CompareAndSwapInt32(&c.usedCachedConn, 0, 1) {
 		conn := c.cachedConn
 		c.cachedConn = nil
 		return conn, nil
 	}
-	return c.dialer(network, addr)
+	// Use the two-value type assertion: Load returns a nil interface until
+	// setRequestContext has been called, and a single-value assertion on a
+	// nil interface panics before the fallback below could apply.
+	ctx, _ := c.requestContext.Load().(context.Context)
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	return c.dialer(ctx, network, addr)
 }
 
-func (c *cachedTLSDialer) Close() {
+func (c *cachedTLSDialer) close() {
 	if atomic.CompareAndSwapInt32(&c.usedCachedConn, 0, 1) {
 		c.cachedConn.Close()
 		c.cachedConn = nil
@@ -498,9 +501,8 @@ func (meek *MeekConn) Close() (err error) {
 
 	if !isClosed {
 		meek.stopRunning()
-		meek.pendingConns.CloseAll()
 		if meek.cachedTLSDialer != nil {
-			meek.cachedTLSDialer.Close()
+			meek.cachedTLSDialer.close()
 		}
 		meek.relayWaitGroup.Wait()
 		meek.transport.CloseIdleConnections()
@@ -530,7 +532,7 @@ func (meek *MeekConn) Read(buffer []byte) (n int, err error) {
 	select {
 	case receiveBuffer = <-meek.partialReceiveBuffer:
 	case receiveBuffer = <-meek.fullReceiveBuffer:
-	case <-meek.runContext.Done():
+	case <-meek.runCtx.Done():
 		return 0, common.ContextError(errors.New("meek connection has closed"))
 	}
 	n, err = receiveBuffer.Read(buffer)
@@ -552,7 +554,7 @@ func (meek *MeekConn) Write(buffer []byte) (n int, err error) {
 		select {
 		case sendBuffer = <-meek.emptySendBuffer:
 		case sendBuffer = <-meek.partialSendBuffer:
-		case <-meek.runContext.Done():
+		case <-meek.runCtx.Done():
 			return 0, common.ContextError(errors.New("meek connection has closed"))
 		}
 		writeLen := MAX_SEND_PAYLOAD_LENGTH - sendBuffer.Len()
@@ -641,13 +643,13 @@ func (meek *MeekConn) relay() {
 		case sendBuffer = <-meek.fullSendBuffer:
 		case <-timeout.C:
 			// In the polling case, send an empty payload
-		case <-meek.runContext.Done():
+		case <-meek.runCtx.Done():
 			// Drop through to second Done() check
 		}
 
 		// Check Done() again, to ensure it takes precedence
 		select {
-		case <-meek.runContext.Done():
+		case <-meek.runCtx.Done():
 			return
 		default:
 		}
@@ -670,7 +672,7 @@ func (meek *MeekConn) relay() {
 
 		if err != nil {
 			select {
-			case <-meek.runContext.Done():
+			case <-meek.runCtx.Done():
 				// In this case, meek.roundTrip encountered Done(). Exit without logging error.
 				return
 			default:
@@ -827,7 +829,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 			// still reading the current round trip response. signaller provides
 			// the hook for awaiting RoundTrip's call to Close.
 
-			signaller = NewReadCloseSignaller(meek.runContext, bytes.NewReader(sendBuffer.Bytes()))
+			signaller = NewReadCloseSignaller(meek.runCtx, bytes.NewReader(sendBuffer.Bytes()))
 			requestBody = signaller
 			contentLength = sendBuffer.Len()
 		}
@@ -848,9 +850,15 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 		// - meek.stopRunning() will abort a round trip in flight
 		// - round trip will abort if it exceeds MEEK_ROUND_TRIP_TIMEOUT
 		requestContext, cancelFunc := context.WithTimeout(
-			meek.runContext,
+			meek.runCtx,
 			MEEK_ROUND_TRIP_TIMEOUT)
 		defer cancelFunc()
+
+		// Ensure TLS dials are made within the current request context.
+		if meek.cachedTLSDialer != nil {
+			meek.cachedTLSDialer.setRequestContext(requestContext)
+		}
+
 		request = request.WithContext(requestContext)
 
 		meek.addAdditionalHeaders(request)
@@ -887,7 +895,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 
 		if err != nil {
 			select {
-			case <-meek.runContext.Done():
+			case <-meek.runCtx.Done():
 				// Exit without retrying and without logging error.
 				return 0, common.ContextError(err)
 			default:
@@ -971,7 +979,7 @@ func (meek *MeekConn) roundTrip(sendBuffer *bytes.Buffer) (int64, error) {
 
 		select {
 		case <-delayTimer.C:
-		case <-meek.runContext.Done():
+		case <-meek.runCtx.Done():
 			delayTimer.Stop()
 			return 0, common.ContextError(err)
 		}
@@ -1026,7 +1034,7 @@ func (meek *MeekConn) readPayload(
 		select {
 		case receiveBuffer = <-meek.emptyReceiveBuffer:
 		case receiveBuffer = <-meek.partialReceiveBuffer:
-		case <-meek.runContext.Done():
+		case <-meek.runCtx.Done():
 			return 0, nil
 		}
 		// Note: receiveBuffer size may exceed meek.fullReceiveBufferLength by up to the size

+ 63 - 118
psiphon/net.go

@@ -20,6 +20,7 @@
 package psiphon
 
 import (
+	"context"
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
@@ -28,9 +29,7 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
-	"net/url"
 	"os"
-	"reflect"
 	"sync"
 	"time"
 
@@ -60,15 +59,6 @@ type DialConfig struct {
 	// upstream proxy when specified by UpstreamProxyUrl.
 	CustomHeaders http.Header
 
-	ConnectTimeout time.Duration
-
-	// PendingConns is used to track and interrupt dials in progress.
-	// Dials may be interrupted using PendingConns.CloseAll(). Once instantiated,
-	// a conn is added to pendingConns before the network connect begins and
-	// removed from pendingConns once the connect succeeds or fails.
-	// May be nil.
-	PendingConns *common.Conns
-
 	// BindToDevice parameters are used to exclude connections and
 	// associated DNS requests from VPN routing.
 	// When DeviceBinder is set, any underlying socket is
@@ -129,8 +119,8 @@ type IPv6Synthesizer interface {
 	IPv6Synthesize(IPv4Addr string) string
 }
 
-// Dialer is a custom dialer compatible with http.Transport.Dial.
-type Dialer func(string, string) (net.Conn, error)
+// Dialer is a custom network dialer.
+type Dialer func(context.Context, string, string) (net.Conn, error)
 
 // LocalProxyRelay sends to remoteConn bytes received from localConn,
 // and sends to localConn bytes received from remoteConn.
@@ -158,9 +148,9 @@ func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
 // no NetworkConnectivityChecker is provided (waiting is disabled)
 // or when NetworkConnectivityChecker.HasNetworkConnectivity()
 // indicates connectivity. It waits and polls the checker once a second.
-// If any stop is broadcast, false is returned immediately.
+// When the context is done, false is returned immediately.
 func WaitForNetworkConnectivity(
-	connectivityChecker NetworkConnectivityChecker, stopBroadcasts ...<-chan struct{}) bool {
+	ctx context.Context, connectivityChecker NetworkConnectivityChecker) bool {
 
 	if connectivityChecker == nil || 1 == connectivityChecker.HasNetworkConnectivity() {
 		return true
@@ -169,25 +159,17 @@ func WaitForNetworkConnectivity(
 	NoticeInfo("waiting for network connectivity")
 
 	ticker := time.NewTicker(1 * time.Second)
-
-	selectCases := make([]reflect.SelectCase, 1+len(stopBroadcasts))
-	selectCases[0] = reflect.SelectCase{
-		Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ticker.C)}
-	for i, stopBroadcast := range stopBroadcasts {
-		selectCases[i+1] = reflect.SelectCase{
-			Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stopBroadcast)}
-	}
+	defer ticker.Stop()
 
 	for {
 		if 1 == connectivityChecker.HasNetworkConnectivity() {
 			return true
 		}
 
-		chosen, _, ok := reflect.Select(selectCases)
-		if chosen == 0 && ok {
-			// Ticker case, so check again
-		} else {
-			// Stop case
+		select {
+		case <-ticker.C:
+			// Check HasNetworkConnectivity again
+		case <-ctx.Done():
 			return false
 		}
 	}
@@ -225,88 +207,64 @@ func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration
 	return addrs, ttls, nil
 }
 
-// MakeUntunneledHttpsClient returns a net/http.Client which is
+// MakeUntunneledHTTPClient returns a net/http.Client which is
 // configured to use custom dialing features -- including BindToDevice,
-// UseIndistinguishableTLS, etc. -- for a specific HTTPS request URL.
-// If verifyLegacyCertificate is not nil, it's used for certificate
-// verification.
-//
-// Because UseIndistinguishableTLS requires a hack to work with
-// net/http, MakeUntunneledHttpClient may return a modified request URL
-// to be used. Callers should always use this return value to make
-// requests, not the input value.
-//
-// MakeUntunneledHttpsClient ignores the input requestUrl scheme,
-// which may be "http" or "https", and always performs HTTPS requests.
-func MakeUntunneledHttpsClient(
-	dialConfig *DialConfig,
+// UseIndistinguishableTLS, etc. If verifyLegacyCertificate is not nil,
+// it's used for certificate verification.
+// The context is applied to underlying TCP dials. The caller is responsible
+// for applying the context to requests made with the returned http.Client.
+func MakeUntunneledHTTPClient(
+	ctx context.Context,
+	untunneledDialConfig *DialConfig,
 	verifyLegacyCertificate *x509.Certificate,
-	requestUrl string,
-	skipVerify bool,
-	requestTimeout time.Duration) (*http.Client, string, error) {
-
-	// Change the scheme to "http"; otherwise http.Transport will try to do
-	// another TLS handshake inside the explicit TLS session. Also need to
-	// force an explicit port, as the default for "http", 80, won't talk TLS.
-	//
-	// TODO: set http.Transport.DialTLS instead of Dial to avoid this hack?
-	// See: https://golang.org/pkg/net/http/#Transport. DialTLS was added in
-	// Go 1.4 but this code may pre-date that.
-
-	urlComponents, err := url.Parse(requestUrl)
-	if err != nil {
-		return nil, "", common.ContextError(err)
-	}
-
-	urlComponents.Scheme = "http"
-	host, port, err := net.SplitHostPort(urlComponents.Host)
-	if err != nil {
-		// Assume there's no port
-		host = urlComponents.Host
-		port = ""
-	}
-	if port == "" {
-		port = "443"
-	}
-	urlComponents.Host = net.JoinHostPort(host, port)
+	skipVerify bool) (*http.Client, error) {
 
 	// Note: IndistinguishableTLS mode doesn't support VerifyLegacyCertificate
-	useIndistinguishableTLS := dialConfig.UseIndistinguishableTLS && verifyLegacyCertificate == nil
+	useIndistinguishableTLS := untunneledDialConfig.UseIndistinguishableTLS &&
+		verifyLegacyCertificate == nil
 
-	dialer := NewCustomTLSDialer(
+	dialer := NewTCPDialer(untunneledDialConfig)
+
+	tlsDialer := NewCustomTLSDialer(
 		// Note: when verifyLegacyCertificate is not nil, some
 		// of the other CustomTLSConfig is overridden.
 		&CustomTLSConfig{
-			Dial: NewTCPDialer(dialConfig),
+			Dial: dialer,
 			VerifyLegacyCertificate:       verifyLegacyCertificate,
-			SNIServerName:                 host,
+			UseDialAddrSNI:                true,
+			SNIServerName:                 "",
 			SkipVerify:                    skipVerify,
 			UseIndistinguishableTLS:       useIndistinguishableTLS,
-			TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
+			TrustedCACertificatesFilename: untunneledDialConfig.TrustedCACertificatesFilename,
 		})
 
 	transport := &http.Transport{
-		Dial: dialer,
+		Dial: func(network, addr string) (net.Conn, error) {
+			return dialer(ctx, network, addr)
+		},
+		DialTLS: func(network, addr string) (net.Conn, error) {
+			return tlsDialer(ctx, network, addr)
+		},
 	}
+
 	httpClient := &http.Client{
-		Timeout:   requestTimeout,
 		Transport: transport,
 	}
 
-	return httpClient, urlComponents.String(), nil
+	return httpClient, nil
 }
 
-// MakeTunneledHttpClient returns a net/http.Client which is
+// MakeTunneledHTTPClient returns a net/http.Client which is
 // configured to use custom dialing features including tunneled
 // dialing and, optionally, UseTrustedCACertificatesForStockTLS.
-// Unlike MakeUntunneledHttpsClient and makePsiphonHttpsClient,
-// This http.Client uses stock TLS and no scheme transformation
-// hack is required.
-func MakeTunneledHttpClient(
+// This http.Client uses stock TLS for HTTPS.
+func MakeTunneledHTTPClient(
 	config *Config,
 	tunnel *Tunnel,
-	skipVerify bool,
-	requestTimeout time.Duration) (*http.Client, error) {
+	skipVerify bool) (*http.Client, error) {
+
+	// Note: there is no dial context since SSH port forward dials cannot
+	// be interrupted directly. Closing the tunnel will interrupt the dials.
 
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		return tunnel.sshClient.Dial("tcp", addr)
@@ -337,55 +295,39 @@ func MakeTunneledHttpClient(
 
 	return &http.Client{
 		Transport: transport,
-		Timeout:   requestTimeout,
 	}, nil
 }
 
-// MakeDownloadHttpClient is a resusable helper that sets up a
-// http.Client for use either untunneled or through a tunnel.
-// See MakeUntunneledHttpsClient for a note about request URL
-// rewritting.
-func MakeDownloadHttpClient(
+// MakeDownloadHTTPClient is a helper that sets up a http.Client
+// for use either untunneled or through a tunnel.
+func MakeDownloadHTTPClient(
+	ctx context.Context,
 	config *Config,
 	tunnel *Tunnel,
 	untunneledDialConfig *DialConfig,
-	requestUrl string,
-	skipVerify bool,
-	requestTimeout time.Duration) (*http.Client, string, error) {
+	skipVerify bool) (*http.Client, error) {
 
 	var httpClient *http.Client
 	var err error
 
 	if tunnel != nil {
-		// MakeTunneledHttpClient works with both "http" and "https" schemes
-		httpClient, err = MakeTunneledHttpClient(
-			config, tunnel, skipVerify, requestTimeout)
+
+		httpClient, err = MakeTunneledHTTPClient(
+			config, tunnel, skipVerify)
 		if err != nil {
-			return nil, "", common.ContextError(err)
+			return nil, common.ContextError(err)
 		}
+
 	} else {
-		urlComponents, err := url.Parse(requestUrl)
+
+		httpClient, err = MakeUntunneledHTTPClient(
+			ctx, untunneledDialConfig, nil, skipVerify)
 		if err != nil {
-			return nil, "", common.ContextError(err)
-		}
-		// MakeUntunneledHttpsClient works only with "https" schemes
-		if urlComponents.Scheme == "https" {
-			httpClient, requestUrl, err = MakeUntunneledHttpsClient(
-				untunneledDialConfig, nil, requestUrl, skipVerify, requestTimeout)
-			if err != nil {
-				return nil, "", common.ContextError(err)
-			}
-		} else {
-			httpClient = &http.Client{
-				Timeout: requestTimeout,
-				Transport: &http.Transport{
-					Dial: NewTCPDialer(untunneledDialConfig),
-				},
-			}
+			return nil, common.ContextError(err)
 		}
 	}
 
-	return httpClient, requestUrl, nil
+	return httpClient, nil
 }
 
 // ResumeDownload is a reusable helper that downloads requestUrl via the
@@ -403,8 +345,9 @@ func MakeDownloadHttpClient(
 // partial download is in progress.
 //
 func ResumeDownload(
+	ctx context.Context,
 	httpClient *http.Client,
-	requestUrl string,
+	downloadURL string,
 	userAgent string,
 	downloadFilename string,
 	ifNoneMatchETag string) (int64, string, error) {
@@ -456,11 +399,13 @@ func ResumeDownload(
 		}
 	}
 
-	request, err := http.NewRequest("GET", requestUrl, nil)
+	request, err := http.NewRequest("GET", downloadURL, nil)
 	if err != nil {
 		return 0, "", common.ContextError(err)
 	}
 
+	request = request.WithContext(ctx)
+
 	request.Header.Set("User-Agent", userAgent)
 
 	request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size()))

+ 6 - 6
psiphon/packetTunnelTransport.go

@@ -33,7 +33,7 @@ import (
 // disconnect from and reconnect to the same or different Psiphon servers. PacketTunnelTransport
 // allows the Psiphon client to substitute new transport channels on-the-fly.
 type PacketTunnelTransport struct {
-	runContext    context.Context
+	runCtx        context.Context
 	stopRunning   context.CancelFunc
 	workers       *sync.WaitGroup
 	readMutex     sync.Mutex
@@ -47,10 +47,10 @@ type PacketTunnelTransport struct {
 // NewPacketTunnelTransport initializes a PacketTunnelTransport.
 func NewPacketTunnelTransport() *PacketTunnelTransport {
 
-	runContext, stopRunning := context.WithCancel(context.Background())
+	runCtx, stopRunning := context.WithCancel(context.Background())
 
 	return &PacketTunnelTransport{
-		runContext:   runContext,
+		runCtx:       runCtx,
 		stopRunning:  stopRunning,
 		workers:      new(sync.WaitGroup),
 		channelReady: sync.NewCond(new(sync.Mutex)),
@@ -122,7 +122,7 @@ func (p *PacketTunnelTransport) Close() error {
 	p.workers.Wait()
 
 	// This broadcast is to wake up reads or writes blocking in getChannel; those
-	// getChannel calls should then abort on the p.runContext.Done() check.
+	// getChannel calls should then abort on the p.runCtx.Done() check.
 	p.channelReady.Broadcast()
 
 	p.channelMutex.Lock()
@@ -173,7 +173,7 @@ func (p *PacketTunnelTransport) setChannel(
 	// UseTunnel call concurrent with a Close call doesn't leave a channel
 	// set.
 	select {
-	case <-p.runContext.Done():
+	case <-p.runCtx.Done():
 		p.channelMutex.Unlock()
 		return
 	default:
@@ -202,7 +202,7 @@ func (p *PacketTunnelTransport) getChannel() (net.Conn, *Tunnel, error) {
 	for {
 
 		select {
-		case <-p.runContext.Done():
+		case <-p.runCtx.Done():
 			return nil, nil, common.ContextError(errors.New("already closed"))
 		default:
 		}

+ 0 - 76
psiphon/pluginProtocol.go

@@ -1,76 +0,0 @@
-/*
- * Copyright (c) 2017, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-package psiphon
-
-import (
-	"io"
-	"net"
-	"sync/atomic"
-
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
-)
-
-var registeredPluginProtocolDialer atomic.Value
-
-// PluginProtocolDialer creates a connection to addr over a
-// plugin protocol. It uses dialConfig to create its base network
-// connection(s) and sends its log messages to loggerOutput.
-//
-// To ensure timely interruption and shutdown, each
-// PluginProtocolDialerimplementation must:
-//
-// - Places its outer net.Conn in pendingConns and leave it
-//   there unless an error occurs
-// - Replace the dialConfig.pendingConns with its own
-//   PendingConns and use that to ensure base network
-//   connections are interrupted when Close() is invoked on
-//   the returned net.Conn.
-//
-// PluginProtocolDialer returns true if it attempts to create
-// a connection, or false if it decides not to attempt a connection.
-type PluginProtocolDialer func(
-	config *Config,
-	loggerOutput io.Writer,
-	pendingConns *common.Conns,
-	addr string,
-	dialConfig *DialConfig) (bool, net.Conn, error)
-
-// RegisterPluginProtocol sets the current plugin protocol
-// dialer.
-func RegisterPluginProtocol(protocolDialer PluginProtocolDialer) {
-	registeredPluginProtocolDialer.Store(protocolDialer)
-}
-
-// DialPluginProtocol uses the current plugin protocol dialer,
-// if set, to connect to addr over the plugin protocol.
-func DialPluginProtocol(
-	config *Config,
-	loggerOutput io.Writer,
-	pendingConns *common.Conns,
-	addr string,
-	dialConfig *DialConfig) (bool, net.Conn, error) {
-
-	dialer := registeredPluginProtocolDialer.Load()
-	if dialer != nil {
-		return dialer.(PluginProtocolDialer)(
-			config, loggerOutput, pendingConns, addr, dialConfig)
-	}
-	return false, nil, nil
-}

+ 20 - 6
psiphon/remoteServerList.go

@@ -20,6 +20,7 @@
 package psiphon
 
 import (
+	"context"
 	"encoding/hex"
 	"errors"
 	"fmt"
@@ -32,7 +33,7 @@ import (
 )
 
 type RemoteServerListFetcher func(
-	config *Config, attempt int, tunnel *Tunnel, untunneledDialConfig *DialConfig) error
+	ctx context.Context, config *Config, attempt int, tunnel *Tunnel, untunneledDialConfig *DialConfig) error
 
 // FetchCommonRemoteServerList downloads the common remote server list from
 // config.RemoteServerListUrl. It validates its digital signature using the
@@ -42,6 +43,7 @@ type RemoteServerListFetcher func(
 // download. As the download is resumed after failure, this filename must
 // be unique and persistent.
 func FetchCommonRemoteServerList(
+	ctx context.Context,
 	config *Config,
 	attempt int,
 	tunnel *Tunnel,
@@ -52,6 +54,7 @@ func FetchCommonRemoteServerList(
 	downloadURL, canonicalURL, skipVerify := selectDownloadURL(attempt, config.RemoteServerListURLs)
 
 	newETag, err := downloadRemoteServerListFile(
+		ctx,
 		config,
 		tunnel,
 		untunneledDialConfig,
@@ -116,6 +119,7 @@ func FetchCommonRemoteServerList(
 // downloaded files. As  downloads are resumed after failure, this directory
 // must be unique and persistent.
 func FetchObfuscatedServerLists(
+	ctx context.Context,
 	config *Config,
 	attempt int,
 	tunnel *Tunnel,
@@ -144,6 +148,7 @@ func FetchObfuscatedServerLists(
 	registryFilename := cachedFilename
 
 	newETag, err := downloadRemoteServerListFile(
+		ctx,
 		config,
 		tunnel,
 		untunneledDialConfig,
@@ -219,6 +224,7 @@ func FetchObfuscatedServerLists(
 		sourceETag := fmt.Sprintf("\"%s\"", hex.EncodeToString(oslFileSpec.MD5Sum))
 
 		newETag, err := downloadRemoteServerListFile(
+			ctx,
 			config,
 			tunnel,
 			untunneledDialConfig,
@@ -324,6 +330,7 @@ func FetchObfuscatedServerLists(
 // The caller is responsible for calling SetUrlETag once the file
 // content has been validated.
 func downloadRemoteServerListFile(
+	ctx context.Context,
 	config *Config,
 	tunnel *Tunnel,
 	untunneledDialConfig *DialConfig,
@@ -349,23 +356,30 @@ func downloadRemoteServerListFile(
 		return "", nil
 	}
 
+	if *config.FetchRemoteServerListTimeoutSeconds > 0 {
+		var cancelFunc context.CancelFunc
+		ctx, cancelFunc = context.WithTimeout(
+			ctx, time.Duration(*config.FetchRemoteServerListTimeoutSeconds)*time.Second)
+		defer cancelFunc()
+	}
+
 	// MakeDownloadHttpClient will select either a tunneled
 	// or untunneled configuration.
 
-	httpClient, requestURL, err := MakeDownloadHttpClient(
+	httpClient, err := MakeDownloadHTTPClient(
+		ctx,
 		config,
 		tunnel,
 		untunneledDialConfig,
-		sourceURL,
-		skipVerify,
-		time.Duration(*config.FetchRemoteServerListTimeoutSeconds)*time.Second)
+		skipVerify)
 	if err != nil {
 		return "", common.ContextError(err)
 	}
 
 	n, responseETag, err := ResumeDownload(
+		ctx,
 		httpClient,
-		requestURL,
+		sourceURL,
 		MakePsiphonUserAgent(config),
 		destinationFilename,
 		lastETag)

+ 5 - 1
psiphon/remoteServerList_test.go

@@ -21,6 +21,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"crypto/md5"
 	"encoding/base64"
 	"encoding/hex"
@@ -393,8 +394,11 @@ func TestObfuscatedRemoteServerLists(t *testing.T) {
 			}
 		}))
 
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+
 	go func() {
-		controller.Run(make(chan struct{}))
+		controller.Run(ctx)
 	}()
 
 	establishTimeout := time.NewTimer(30 * time.Second)

+ 6 - 2
psiphon/server/meek_test.go

@@ -21,6 +21,7 @@ package server
 
 import (
 	"bytes"
+	"context"
 	crypto_rand "crypto/rand"
 	"encoding/base64"
 	"fmt"
@@ -294,7 +295,6 @@ func TestMeekResiliency(t *testing.T) {
 	// Run meek client
 
 	dialConfig := &psiphon.DialConfig{
-		PendingConns:            new(common.Conns),
 		UseIndistinguishableTLS: true,
 		DeviceBinder:            new(fileDescriptorInterruptor),
 	}
@@ -308,7 +308,11 @@ func TestMeekResiliency(t *testing.T) {
 		MeekObfuscatedKey:             meekObfuscatedKey,
 	}
 
-	clientConn, err := psiphon.DialMeek(meekConfig, dialConfig)
+	ctx, cancelFunc := context.WithTimeout(
+		context.Background(), time.Second*5)
+	defer cancelFunc()
+
+	clientConn, err := psiphon.DialMeek(ctx, meekConfig, dialConfig)
 	if err != nil {
 		t.Fatalf("psiphon.DialMeek failed: %s", err)
 	}

+ 7 - 3
psiphon/server/server_test.go

@@ -20,6 +20,7 @@
 package server
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"flag"
@@ -500,15 +501,18 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			}
 		}))
 
-	controllerShutdownBroadcast := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
 	controllerWaitGroup := new(sync.WaitGroup)
+
 	controllerWaitGroup.Add(1)
 	go func() {
 		defer controllerWaitGroup.Done()
-		controller.Run(controllerShutdownBroadcast)
+		controller.Run(ctx)
 	}()
+
 	defer func() {
-		close(controllerShutdownBroadcast)
+		cancelFunc()
 
 		shutdownTimeout := time.NewTimer(20 * time.Second)
 

+ 9 - 9
psiphon/server/tunnelServer.go

@@ -761,7 +761,7 @@ type sshClient struct {
 	tcpPortForwardLRU                    *common.LRUConns
 	oslClientSeedState                   *osl.ClientSeedState
 	signalIssueSLOKs                     chan struct{}
-	runContext                           context.Context
+	runCtx                               context.Context
 	stopRunning                          context.CancelFunc
 	tcpPortForwardDialingAvailableSignal context.CancelFunc
 }
@@ -799,7 +799,7 @@ type handshakeState struct {
 func newSshClient(
 	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
 
-	runContext, stopRunning := context.WithCancel(context.Background())
+	runCtx, stopRunning := context.WithCancel(context.Background())
 
 	client := &sshClient{
 		sshServer:         sshServer,
@@ -807,7 +807,7 @@ func newSshClient(
 		geoIPData:         geoIPData,
 		tcpPortForwardLRU: common.NewLRUConns(),
 		signalIssueSLOKs:  make(chan struct{}, 1),
-		runContext:        runContext,
+		runCtx:            runCtx,
 		stopRunning:       stopRunning,
 	}
 
@@ -1256,7 +1256,7 @@ func (sshClient *sshClient) runTunnel(
 
 			if sshClient.isTCPDialingPortForwardLimitExceeded() {
 				blockStartTime := monotime.Now()
-				ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+				ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
 				sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
 				<-ctx.Done()
 				sshClient.setTCPPortForwardDialingAvailableSignal(nil)
@@ -1511,7 +1511,7 @@ func (sshClient *sshClient) runOSLSender() {
 		// TODO: use reflect.SelectCase, and optionally await timer here?
 		select {
 		case <-sshClient.signalIssueSLOKs:
-		case <-sshClient.runContext.Done():
+		case <-sshClient.runCtx.Done():
 			return
 		}
 
@@ -1529,7 +1529,7 @@ func (sshClient *sshClient) runOSLSender() {
 			select {
 			case <-retryTimer.C:
 			case <-sshClient.signalIssueSLOKs:
-			case <-sshClient.runContext.Done():
+			case <-sshClient.runCtx.Done():
 				retryTimer.Stop()
 				return
 			}
@@ -2088,14 +2088,14 @@ func (sshClient *sshClient) handleTCPChannel(
 	// Hostname resolution is performed explicitly, as a separate step, as the target IP
 	// address is used for traffic rules (AllowSubnets) and OSL seed progress.
 	//
-	// Contexts are used for cancellation (via sshClient.runContext, which is cancelled
+	// Contexts are used for cancellation (via sshClient.runCtx, which is cancelled
 	// when the client is stopping) and timeouts.
 
 	dialStartTime := monotime.Now()
 
 	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 
-	ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+	ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
 	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
 	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 
@@ -2154,7 +2154,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
-	ctx, cancelCtx = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+	ctx, cancelCtx = context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
 	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
 	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
 

+ 22 - 101
psiphon/serverApi.go

@@ -21,6 +21,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
@@ -33,7 +34,6 @@ import (
 	"net/url"
 	"strconv"
 	"sync/atomic"
-	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
@@ -475,92 +475,6 @@ func confirmStatusRequestPayload(payloadInfo *statusRequestPayloadInfo) {
 	}
 }
 
-// TryUntunneledStatusRequest makes direct connections to the specified
-// server (if supported) in an attempt to send useful bytes transferred
-// and tunnel duration stats after a tunnel has alreay failed.
-// The tunnel is assumed to be closed, but its config, protocol, and
-// context values must still be valid.
-// TryUntunneledStatusRequest emits notices detailing failed attempts.
-func (serverContext *ServerContext) TryUntunneledStatusRequest(isShutdown bool) error {
-
-	for _, port := range serverContext.tunnel.serverEntry.GetUntunneledWebRequestPorts() {
-		err := serverContext.doUntunneledStatusRequest(port, isShutdown)
-		if err == nil {
-			return nil
-		}
-		NoticeAlert("doUntunneledStatusRequest failed for %s:%s: %s",
-			serverContext.tunnel.serverEntry.IpAddress, port, err)
-	}
-
-	return errors.New("all attempts failed")
-}
-
-// doUntunneledStatusRequest attempts an untunneled status request.
-func (serverContext *ServerContext) doUntunneledStatusRequest(
-	port string, isShutdown bool) error {
-
-	tunnel := serverContext.tunnel
-
-	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
-	if err != nil {
-		return common.ContextError(err)
-	}
-
-	timeout := time.Duration(*tunnel.config.PsiphonApiServerTimeoutSeconds) * time.Second
-
-	dialConfig := tunnel.untunneledDialConfig
-
-	if isShutdown {
-		timeout = PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT
-
-		// Use a copy of DialConfig without pendingConns. This ensures
-		// this request isn't interrupted/canceled. This measure should
-		// be used only with the very short PSIPHON_API_SHUTDOWN_SERVER_TIMEOUT.
-		dialConfig = new(DialConfig)
-		*dialConfig = *tunnel.untunneledDialConfig
-	}
-
-	url := makeRequestUrl(tunnel, port, "status", serverContext.getStatusParams(false))
-
-	httpClient, url, err := MakeUntunneledHttpsClient(
-		dialConfig,
-		certificate,
-		url,
-		false,
-		timeout)
-	if err != nil {
-		return common.ContextError(err)
-	}
-
-	statusPayload, statusPayloadInfo, err := makeStatusRequestPayload(tunnel.serverEntry.IpAddress)
-	if err != nil {
-		return common.ContextError(err)
-	}
-
-	bodyType := "application/json"
-	body := bytes.NewReader(statusPayload)
-
-	response, err := httpClient.Post(url, bodyType, body)
-	if err == nil && response.StatusCode != http.StatusOK {
-		response.Body.Close()
-		err = fmt.Errorf("HTTP POST request failed with response code: %d", response.StatusCode)
-	}
-	if err != nil {
-
-		// Resend the transfer stats and tunnel stats later
-		// Note: potential duplicate reports if the server received and processed
-		// the request but the client failed to receive the response.
-		putBackStatusRequestPayload(statusPayloadInfo)
-
-		// Trim this error since it may include long URLs
-		return common.ContextError(TrimError(err))
-	}
-	confirmStatusRequestPayload(statusPayloadInfo)
-	response.Body.Close()
-
-	return nil
-}
-
 // RecordTunnelStat records a tunnel duration and bytes
 // sent and received for subsequent reporting and quality
 // analysis.
@@ -909,9 +823,7 @@ func makeRequestUrl(tunnel *Tunnel, port, path string, params requestJSONObject)
 		port = tunnel.serverEntry.WebServerPort
 	}
 
-	// Note: don't prefix with HTTPS scheme, see comment in doGetRequest.
-	// e.g., don't do this: requestUrl.WriteString("https://")
-	requestUrl.WriteString("http://")
+	requestUrl.WriteString("https://")
 	requestUrl.WriteString(tunnel.serverEntry.IpAddress)
 	requestUrl.WriteString(":")
 	requestUrl.WriteString(port)
@@ -948,32 +860,41 @@ func makeRequestUrl(tunnel *Tunnel, port, path string, params requestJSONObject)
 
 // makePsiphonHttpsClient creates a Psiphon HTTPS client that tunnels web service API
 // requests and which validates the web server using the Psiphon server entry web server
-// certificate. This is not a general purpose HTTPS client.
-// As the custom dialer makes an explicit TLS connection, URLs submitted to the returned
-// http.Client should use the "http://" scheme. Otherwise http.Transport will try to do another TLS
-// handshake inside the explicit TLS session.
+// certificate.
 func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error) {
+
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
-	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
-		// TODO: check tunnel.isClosed, and apply TUNNEL_PORT_FORWARD_DIAL_TIMEOUT as in Tunnel.Dial?
+
+	tunneledDialer := func(_ context.Context, _, addr string) (conn net.Conn, err error) {
 		return tunnel.sshClient.Dial("tcp", addr)
 	}
-	timeout := time.Duration(*tunnel.config.PsiphonApiServerTimeoutSeconds) * time.Second
+
+	// Note: as with SSH API requests, there is no dial context here. SSH port forward dials
+	// cannot be interrupted directly. Closing the tunnel will interrupt both the dial and
+	// the request. While it's possible to add a timeout here, we leave it with no explicit
+	// timeout which is the same as SSH API requests: if the tunnel has stalled then SSH keep
+	// alives will cause the tunnel to close.
+
 	dialer := NewCustomTLSDialer(
 		&CustomTLSConfig{
-			Dial:                    tunneledDialer,
-			Timeout:                 timeout,
+			Dial: tunneledDialer,
 			VerifyLegacyCertificate: certificate,
 		})
+
 	transport := &http.Transport{
-		Dial: dialer,
+		DialTLS: func(network, addr string) (net.Conn, error) {
+			return dialer(context.Background(), network, addr)
+		},
+		Dial: func(network, addr string) (net.Conn, error) {
+			return nil, errors.New("HTTP not supported")
+		},
 	}
+
 	return &http.Client{
 		Transport: transport,
-		Timeout:   timeout,
 	}, nil
 }
 

+ 30 - 28
psiphon/tlsDialer.go

@@ -72,6 +72,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"crypto/x509"
 	"encoding/hex"
 	"errors"
@@ -95,17 +96,19 @@ type CustomTLSConfig struct {
 	// top of a new network connection created with dialer.
 	Dial Dialer
 
-	// Timeout is and optional timeout for combined network
-	// connection dial and TLS handshake.
-	Timeout time.Duration
-
 	// DialAddr overrides the "addr" input to Dial when specified
 	DialAddr string
 
+	// UseDialAddrSNI specifies whether to always use the dial "addr"
+	// host name in the SNI server_name field. When DialAddr is set,
+	// its host name is used.
+	UseDialAddrSNI bool
+
 	// SNIServerName specifies the value to set in the SNI
 	// server_name field. When blank, SNI is omitted. Note that
 	// underlying TLS code also automatically omits SNI when
 	// the server_name is an IP address.
+	// SNIServerName is ignored when UseDialAddrSNI is true.
 	SNIServerName string
 
 	// SkipVerify completely disables server certificate verification.
@@ -170,8 +173,8 @@ func SelectTLSProfile(
 }
 
 func NewCustomTLSDialer(config *CustomTLSConfig) Dialer {
-	return func(network, addr string) (net.Conn, error) {
-		return CustomTLSDial(network, addr, config)
+	return func(ctx context.Context, network, addr string) (net.Conn, error) {
+		return CustomTLSDial(ctx, network, addr, config)
 	}
 }
 
@@ -187,26 +190,17 @@ type handshakeConn interface {
 // tlsdialer comment:
 //   Note - if sendServerName is false, the VerifiedChains field on the
 //   connection's ConnectionState will never get populated.
-func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, error) {
-
-	// We want the Timeout and Deadline values from dialer to cover the
-	// whole process: TCP connection and TLS handshake. This means that we
-	// also need to start our own timers now.
-	var errChannel chan error
-	if config.Timeout != 0 {
-		errChannel = make(chan error, 2)
-		timeoutFunc := time.AfterFunc(config.Timeout, func() {
-			errChannel <- errors.New("timed out")
-		})
-		defer timeoutFunc.Stop()
-	}
+func CustomTLSDial(
+	ctx context.Context,
+	network, addr string,
+	config *CustomTLSConfig) (net.Conn, error) {
 
 	dialAddr := addr
 	if config.DialAddr != "" {
 		dialAddr = config.DialAddr
 	}
 
-	rawConn, err := config.Dial(network, dialAddr)
+	rawConn, err := config.Dial(ctx, network, dialAddr)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -258,7 +252,9 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 		tlsConfig.InsecureSkipVerify = true
 	}
 
-	if config.SNIServerName != "" && config.VerifyLegacyCertificate == nil {
+	if config.UseDialAddrSNI {
+		tlsConfig.ServerName = hostname
+	} else if config.SNIServerName != "" && config.VerifyLegacyCertificate == nil {
 		// Set the ServerName and rely on the usual logic in
 		// tls.Conn.Handshake() to do its verification.
 		// Note: Go TLS will automatically omit this ServerName when it's an IP address
@@ -302,13 +298,19 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (net.Conn, err
 		conn = tls.Client(rawConn, tlsConfig)
 	}
 
-	if config.Timeout == 0 {
-		err = conn.Handshake()
-	} else {
-		go func() {
-			errChannel <- conn.Handshake()
-		}()
-		err = <-errChannel
+	resultChannel := make(chan error)
+
+	go func() {
+		resultChannel <- conn.Handshake()
+	}()
+
+	select {
+	case err = <-resultChannel:
+	case <-ctx.Done():
+		err = ctx.Err()
+		// Interrupt the goroutine
+		rawConn.Close()
+		<-resultChannel
 	}
 
 	// openSSLConns complete verification automatically. For Go TLS,

+ 109 - 170
psiphon/tunnel.go

@@ -21,6 +21,7 @@ package psiphon
 
 import (
 	"bytes"
+	"context"
 	"encoding/base64"
 	"encoding/json"
 	"errors"
@@ -44,15 +45,22 @@ import (
 // Components which use this interface may be serviced by a single Tunnel instance,
 // or a Controller which manages a pool of tunnels, or any other object which
 // implements Tunneler.
-// alwaysTunnel indicates that the connection should always be tunneled. If this
-// is not set, the connection may be made directly, depending on split tunnel
-// classification, when that feature is supported and active.
-// downstreamConn is an optional parameter which specifies a connection to be
-// explicitly closed when the Dialed connection is closed. For instance, this
-// is used to close downstreamConn App<->LocalProxy connections when the related
-// LocalProxy<->SshPortForward connections close.
 type Tunneler interface {
+
+	// Dial creates a tunneled connection.
+	//
+	// alwaysTunnel indicates that the connection should always be tunneled. If this
+	// is not set, the connection may be made directly, depending on split tunnel
+	// classification, when that feature is supported and active.
+	//
+	// downstreamConn is an optional parameter which specifies a connection to be
+	// explicitly closed when the Dialed connection is closed. For instance, this
+	// is used to close downstreamConn App<->LocalProxy connections when the related
+	// LocalProxy<->SshPortForward connections close.
 	Dial(remoteAddr string, alwaysTunnel bool, downstreamConn net.Conn) (conn net.Conn, err error)
+
+	DirectDial(remoteAddr string) (conn net.Conn, err error)
+
 	SignalComponentFailure()
 }
 
@@ -70,7 +78,6 @@ type TunnelOwner interface {
 type Tunnel struct {
 	mutex                        *sync.Mutex
 	config                       *Config
-	untunneledDialConfig         *DialConfig
 	isActivated                  bool
 	isDiscarded                  bool
 	isClosed                     bool
@@ -82,7 +89,8 @@ type Tunnel struct {
 	sshClient                    *ssh.Client
 	sshServerRequests            <-chan *ssh.Request
 	operateWaitGroup             *sync.WaitGroup
-	shutdownOperateBroadcast     chan struct{}
+	operateCtx                   context.Context
+	stopOperate                  context.CancelFunc
 	signalPortForwardFailure     chan struct{}
 	totalPortForwardFailures     int
 	adjustedEstablishStartTime   monotime.Time
@@ -123,7 +131,6 @@ type TunnelDialStats struct {
 // HTTP (meek protocol).
 // When requiredProtocol is not blank, that protocol is used. Otherwise,
 // the a random supported protocol is used.
-// untunneledDialConfig is used for untunneled final status requests.
 //
 // Call Activate on a connected tunnel to complete its establishment
 // before using.
@@ -136,10 +143,9 @@ type TunnelDialStats struct {
 // as necessary.
 //
 func ConnectTunnel(
+	ctx context.Context,
 	config *Config,
-	untunneledDialConfig *DialConfig,
 	sessionId string,
-	pendingConns *common.Conns,
 	serverEntry *protocol.ServerEntry,
 	selectedProtocol string,
 	adjustedEstablishStartTime monotime.Time) (*Tunnel, error) {
@@ -151,27 +157,21 @@ func ConnectTunnel(
 	// Build transport layers and establish SSH connection. Note that
 	// dialConn and monitoredConn are the same network connection.
 	dialResult, err := dialSsh(
-		config, pendingConns, serverEntry, selectedProtocol, sessionId)
+		ctx, config, serverEntry, selectedProtocol, sessionId)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
 
-	// Now that connection dials are complete, cancel interruptibility
-	pendingConns.Remove(dialResult.dialConn)
-
 	// The tunnel is now connected
 	return &Tunnel{
-		mutex:                    new(sync.Mutex),
-		config:                   config,
-		untunneledDialConfig:     untunneledDialConfig,
-		sessionId:                sessionId,
-		serverEntry:              serverEntry,
-		protocol:                 selectedProtocol,
-		conn:                     dialResult.monitoredConn,
-		sshClient:                dialResult.sshClient,
-		sshServerRequests:        dialResult.sshRequests,
-		operateWaitGroup:         new(sync.WaitGroup),
-		shutdownOperateBroadcast: make(chan struct{}),
+		mutex:             new(sync.Mutex),
+		config:            config,
+		sessionId:         sessionId,
+		serverEntry:       serverEntry,
+		protocol:          selectedProtocol,
+		conn:              dialResult.monitoredConn,
+		sshClient:         dialResult.sshClient,
+		sshServerRequests: dialResult.sshRequests,
 		// A buffer allows at least one signal to be sent even when the receiver is
 		// not listening. Senders should not block.
 		signalPortForwardFailure:   make(chan struct{}, 1),
@@ -187,8 +187,8 @@ func ConnectTunnel(
 // request and starting operateTunnel, the worker that monitors the tunnel
 // and handles periodic management.
 func (tunnel *Tunnel) Activate(
-	tunnelOwner TunnelOwner,
-	shutdownBroadcast chan struct{}) error {
+	ctx context.Context,
+	tunnelOwner TunnelOwner) error {
 
 	// Create a new Psiphon API server context for this tunnel. This includes
 	// performing a handshake request. If the handshake fails, this activation
@@ -207,6 +207,13 @@ func (tunnel *Tunnel) Activate(
 		// request. At this point, there is no operateTunnel monitor that will detect
 		// this condition with SSH keep alives.
 
+		if *tunnel.config.PsiphonApiServerTimeoutSeconds > 0 {
+			var cancelFunc context.CancelFunc
+			ctx, cancelFunc = context.WithTimeout(
+				ctx, time.Second*time.Duration(*tunnel.config.PsiphonApiServerTimeoutSeconds))
+			defer cancelFunc()
+		}
+
 		type newServerContextResult struct {
 			serverContext *ServerContext
 			err           error
@@ -224,37 +231,13 @@ func (tunnel *Tunnel) Activate(
 
 		var result newServerContextResult
 
-		if *tunnel.config.PsiphonApiServerTimeoutSeconds > 0 {
-
-			timer := time.NewTimer(
-				time.Second *
-					time.Duration(
-						*tunnel.config.PsiphonApiServerTimeoutSeconds))
-
-			select {
-			case result = <-resultChannel:
-			case <-timer.C:
-				result.err = errors.New("timed out")
-				// Interrupt the Activate goroutine and await its completion.
-				tunnel.Close(true)
-				<-resultChannel
-			case <-shutdownBroadcast:
-				result.err = errors.New("shutdown")
-				tunnel.Close(true)
-				<-resultChannel
-			}
-
-			timer.Stop()
-
-		} else {
-
-			select {
-			case result = <-resultChannel:
-			case <-shutdownBroadcast:
-				result.err = errors.New("shutdown")
-				tunnel.Close(true)
-				<-resultChannel
-			}
+		select {
+		case result = <-resultChannel:
+		case <-ctx.Done():
+			result.err = ctx.Err()
+			// Interrupt the goroutine
+			tunnel.Close(true)
+			<-resultChannel
 		}
 
 		if result.err != nil {
@@ -289,6 +272,11 @@ func (tunnel *Tunnel) Activate(
 	tunnel.establishDuration = monotime.Since(tunnel.adjustedEstablishStartTime)
 	tunnel.establishedTime = monotime.Now()
 
+	// Use the Background context instead of the controller run context, as tunnels
+	// are terminated when the controller calls tunnel.Close.
+	tunnel.operateCtx, tunnel.stopOperate = context.WithCancel(context.Background())
+	tunnel.operateWaitGroup = new(sync.WaitGroup)
+
 	// Spawn the operateTunnel goroutine, which monitors the tunnel and handles periodic
 	// stats updates.
 	tunnel.operateWaitGroup.Add(1)
@@ -318,12 +306,9 @@ func (tunnel *Tunnel) Close(isDiscarded bool) {
 		// shutdown.
 		// A timer is set, so if operateTunnel takes too long to stop, the
 		// tunnel is closed, which will interrupt any slow final status request.
-		// In effect, the TUNNEL_OPERATE_SHUTDOWN_TIMEOUT value will take
-		// precedence over the PSIPHON_API_SERVER_TIMEOUT http.Client.Timeout
-		// value set in makePsiphonHttpsClient.
 		if isActivated {
 			afterFunc := time.AfterFunc(TUNNEL_OPERATE_SHUTDOWN_TIMEOUT, func() { tunnel.conn.Close() })
-			close(tunnel.shutdownOperateBroadcast)
+			tunnel.stopOperate()
 			tunnel.operateWaitGroup.Wait()
 			afterFunc.Stop()
 		}
@@ -386,17 +371,28 @@ func (tunnel *Tunnel) Dial(
 		sshPortForwardConn net.Conn
 		err                error
 	}
-	resultChannel := make(chan *tunnelDialResult, 2)
+
+	// Note: there is no dial context since SSH port forward dials cannot
+	// be interrupted directly. Closing the tunnel will interrupt the dials.
+	// A timeout is set to unblock this function, but the goroutine may
+	// not exit until the tunnel is closed.
+
+	// Use a buffer of 1 as there are two senders and only one guaranteed receive.
+
+	resultChannel := make(chan *tunnelDialResult, 1)
+
 	if *tunnel.config.TunnelPortForwardDialTimeoutSeconds > 0 {
 		afterFunc := time.AfterFunc(time.Duration(*tunnel.config.TunnelPortForwardDialTimeoutSeconds)*time.Second, func() {
 			resultChannel <- &tunnelDialResult{nil, errors.New("tunnel dial timeout")}
 		})
 		defer afterFunc.Stop()
 	}
+
 	go func() {
 		sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
 		resultChannel <- &tunnelDialResult{sshPortForwardConn, err}
 	}()
+
 	result := <-resultChannel
 
 	if result.err != nil {
@@ -739,12 +735,19 @@ type dialResult struct {
 // base dial conn. The *ActivityMonitoredConn return value is the layered conn passed into
 // the ssh.Client.
 func dialSsh(
+	ctx context.Context,
 	config *Config,
-	pendingConns *common.Conns,
 	serverEntry *protocol.ServerEntry,
 	selectedProtocol,
 	sessionId string) (*dialResult, error) {
 
+	if *config.TunnelConnectTimeoutSeconds > 0 {
+		var cancelFunc context.CancelFunc
+		ctx, cancelFunc = context.WithTimeout(
+			ctx, time.Second*time.Duration(*config.TunnelConnectTimeoutSeconds))
+		defer cancelFunc()
+	}
+
 	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
 	// So depending on which protocol is used, multiple layers are initialized.
 
@@ -808,8 +811,6 @@ func dialSsh(
 	dialConfig := &DialConfig{
 		UpstreamProxyUrl:              config.UpstreamProxyUrl,
 		CustomHeaders:                 dialCustomHeaders,
-		ConnectTimeout:                time.Duration(*config.TunnelConnectTimeoutSeconds) * time.Second,
-		PendingConns:                  pendingConns,
 		DeviceBinder:                  config.DeviceBinder,
 		DnsServerGetter:               config.DnsServerGetter,
 		IPv6Synthesizer:               config.IPv6Synthesizer,
@@ -869,41 +870,12 @@ func dialSsh(
 
 	var dialConn net.Conn
 	if meekConfig != nil {
-		dialConn, err = DialMeek(meekConfig, dialConfig)
+		dialConn, err = DialMeek(ctx, meekConfig, dialConfig)
 		if err != nil {
 			return nil, common.ContextError(err)
 		}
 	} else {
-
-		// For some direct connect servers, DialPluginProtocol
-		// will layer on another obfuscation protocol.
-
-		// Use a copy of DialConfig without pendingConns; the
-		// DialPluginProtocol must supply and manage its own
-		// for its base network connections.
-		pluginDialConfig := new(DialConfig)
-		*pluginDialConfig = *dialConfig
-		pluginDialConfig.PendingConns = nil
-
-		var dialedPlugin bool
-		dialedPlugin, dialConn, err = DialPluginProtocol(
-			config,
-			NewNoticeWriter("DialPluginProtocol"),
-			pendingConns,
-			directTCPDialAddress,
-			dialConfig)
-
-		if !dialedPlugin && err != nil {
-			NoticeInfo("DialPluginProtocol intialization failed: %s", err)
-		}
-
-		if dialedPlugin {
-			NoticeInfo("using DialPluginProtocol for %s", serverEntry.IpAddress)
-		} else {
-			// Standard direct connection.
-			dialConn, err = DialTCP(directTCPDialAddress, dialConfig)
-		}
-
+		dialConn, err = DialTCP(ctx, directTCPDialAddress, dialConfig)
 		if err != nil {
 			return nil, common.ContextError(err)
 		}
@@ -920,7 +892,6 @@ func dialSsh(
 		// Cleanup on error
 		if cleanupConn != nil {
 			cleanupConn.Close()
-			pendingConns.Remove(cleanupConn)
 		}
 	}()
 
@@ -967,6 +938,7 @@ func dialSsh(
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
+
 	sshClientConfig := &ssh.ClientConfig{
 		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
@@ -985,20 +957,19 @@ func dialSsh(
 	// Note: TCP handshake timeouts are provided by TCPConn, and session
 	// timeouts *after* ssh establishment are provided by the ssh keep alive
 	// in operate tunnel.
-	// TODO: adjust the timeout to account for time-elapsed-from-start
 
 	type sshNewClientResult struct {
 		sshClient   *ssh.Client
 		sshRequests <-chan *ssh.Request
 		err         error
 	}
-	resultChannel := make(chan *sshNewClientResult, 2)
-	if *config.TunnelConnectTimeoutSeconds > 0 {
-		afterFunc := time.AfterFunc(time.Duration(*config.TunnelConnectTimeoutSeconds)*time.Second, func() {
-			resultChannel <- &sshNewClientResult{nil, nil, errors.New("ssh dial timeout")}
-		})
-		defer afterFunc.Stop()
-	}
+
+	resultChannel := make(chan sshNewClientResult)
+
+	// Call NewClientConn in a goroutine, as it blocks on SSH handshake network
+	// operations, and would block canceling or shutdown. If the parent context
+	// is canceled, close the net.Conn underlying SSH, which will interrupt the
+	// SSH handshake that may be blocking NewClientConn.
 
 	go func() {
 		// The following is adapted from ssh.Dial(), here using a custom conn
@@ -1020,10 +991,20 @@ func dialSsh(
 
 			sshClient = ssh.NewClient(sshClientConn, sshChannels, noRequests)
 		}
-		resultChannel <- &sshNewClientResult{sshClient, sshRequests, err}
+		resultChannel <- sshNewClientResult{sshClient, sshRequests, err}
 	}()
 
-	result := <-resultChannel
+	var result sshNewClientResult
+
+	select {
+	case result = <-resultChannel:
+	case <-ctx.Done():
+		result.err = ctx.Err()
+		// Interrupt the goroutine
+		sshConn.Close()
+		<-resultChannel
+	}
+
 	if result.err != nil {
 		return nil, common.ContextError(result.err)
 	}
@@ -1135,14 +1116,8 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 	statsTimer := time.NewTimer(nextStatusRequestPeriod())
 	defer statsTimer.Stop()
 
-	// Schedule an immediate status request to deliver any unreported
+	// Schedule an almost-immediate status request to deliver any unreported
 	// persistent stats.
-	// Note: this may not be effective when there's an outstanding
-	// asynchronous untunneled final status request is holding the
-	// persistent stats records. It may also conflict with other
-	// tunnel candidates which attempt to send an immediate request
-	// before being discarded. For now, we mitigate this with a short,
-	// random delay.
 	unreported := CountUnreportedPersistentStats()
 	if unreported > 0 {
 		NoticeInfo("Unreported persistent stats: %d", unreported)
@@ -1319,7 +1294,7 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 				}
 			}
 
-		case <-tunnel.shutdownOperateBroadcast:
+		case <-tunnel.operateCtx.Done():
 			shutdown = true
 		}
 	}
@@ -1393,40 +1368,18 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 		}
 	}
 
-	// Final status request notes:
-	//
-	// It's highly desirable to send a final status request in order to report
-	// domain bytes transferred stats as well as to report tunnel stats as
-	// soon as possible. For this reason, we attempt untunneled requests when
-	// the tunneled request isn't possible or has failed.
-	//
-	// In an orderly shutdown (err == nil), the Controller is stopping and
-	// everything must be wrapped up quickly. Also, we still have a working
-	// tunnel. So we first attempt a tunneled status request (with a short
-	// timeout) and then attempt, synchronously -- otherwise the Contoller's
-	// runWaitGroup.Wait() will return while a request is still in progress
-	// -- untunneled requests (also with short timeouts). Note that in this
-	// case the untunneled request will opt out of untunneledPendingConns so
-	// that it's not inadvertently canceled by the Controller shutdown
-	// sequence (see doUntunneledStatusRequest).
-	//
-	// If the tunnel has failed, the Controller may continue working. We want
-	// to re-establish as soon as possible (so don't want to block on status
-	// requests, even for a second). We may have a long time to attempt
-	// untunneled requests in the background. And there is no tunnel through
-	// which to attempt tunneled requests. So we spawn a goroutine to run the
-	// untunneled requests, which are allowed a longer timeout. These requests
-	// will be interrupted by the Controller's untunneledPendingConns.CloseAll()
-	// in the case of a shutdown.
-
 	if err == nil {
 		NoticeInfo("shutdown operate tunnel")
-		if !sendStats(tunnel) {
-			sendUntunneledStats(tunnel, true)
-		}
+
+		// Send a final status request in order to report any outstanding
+		// domain bytes transferred stats as well as to report session stats
+		// as soon as possible.
+		// This request will be interrupted when the tunnel is closed after
+		// TUNNEL_OPERATE_SHUTDOWN_TIMEOUT.
+		sendStats(tunnel)
+
 	} else {
 		NoticeAlert("operate tunnel error for %s: %s", tunnel.serverEntry.IpAddress, err)
-		go sendUntunneledStats(tunnel, false)
 		tunnelOwner.SignalTunnelFailure(tunnel)
 	}
 }
@@ -1438,7 +1391,14 @@ func (tunnel *Tunnel) operateTunnel(tunnelOwner TunnelOwner) {
 func sendSshKeepAlive(
 	sshClient *ssh.Client, conn net.Conn, timeout time.Duration) error {
 
-	errChannel := make(chan error, 2)
+	// Note: there is no request context since SSH requests cannot be
+	// interrupted directly. Closing the tunnel will interrupt the request.
+	// A timeout is set to unblock this function, but the goroutine may
+	// not exit until the tunnel is closed.
+
+	// Use a buffer of 1 as there are two senders and only one guaranteed receive.
+
+	errChannel := make(chan error, 1)
 	if timeout > 0 {
 		afterFunc := time.AfterFunc(timeout, func() {
 			errChannel <- errors.New("timed out")
@@ -1490,27 +1450,6 @@ func sendStats(tunnel *Tunnel) bool {
 	return err == nil
 }
 
-// sendUntunnelStats sends final status requests directly to Psiphon
-// servers after the tunnel has already failed. This is an attempt
-// to retain useful bytes transferred stats.
-func sendUntunneledStats(tunnel *Tunnel, isShutdown bool) {
-
-	// Tunnel does not have a serverContext when DisableApi is set
-	if tunnel.serverContext == nil {
-		return
-	}
-
-	// Skip when tunnel is discarded
-	if tunnel.IsDiscarded() {
-		return
-	}
-
-	err := tunnel.serverContext.TryUntunneledStatusRequest(isShutdown)
-	if err != nil {
-		NoticeAlert("TryUntunneledStatusRequest failed for %s: %s", tunnel.serverEntry.IpAddress, err)
-	}
-}
-
 // sendClientVerification is a helper for sending a client verification request
 // to the server.
 func sendClientVerification(tunnel *Tunnel, clientVerificationPayload string) bool {

+ 19 - 6
psiphon/upgradeDownload.go

@@ -20,6 +20,7 @@
 package psiphon
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"os"
@@ -55,6 +56,7 @@ import (
 // necessary to re-download; (b) newer upgrades will be downloaded even when an older
 // upgrade is still pending install by the outer client.
 func DownloadUpgrade(
+	ctx context.Context,
 	config *Config,
 	attempt int,
 	handshakeVersion string,
@@ -71,28 +73,38 @@ func DownloadUpgrade(
 		return nil
 	}
 
+	if *config.DownloadUpgradeTimeoutSeconds > 0 {
+		var cancelFunc context.CancelFunc
+		ctx, cancelFunc = context.WithTimeout(
+			ctx, time.Duration(*config.DownloadUpgradeTimeoutSeconds)*time.Second)
+		defer cancelFunc()
+	}
 	// Select tunneled or untunneled configuration
 
 	downloadURL, _, skipVerify := selectDownloadURL(attempt, config.UpgradeDownloadURLs)
 
-	httpClient, requestUrl, err := MakeDownloadHttpClient(
+	httpClient, err := MakeDownloadHTTPClient(
+		ctx,
 		config,
 		tunnel,
 		untunneledDialConfig,
-		downloadURL,
-		skipVerify,
-		time.Duration(*config.DownloadUpgradeTimeoutSeconds)*time.Second)
+		skipVerify)
 
 	// If no handshake version is supplied, make an initial HEAD request
 	// to get the current version from the version header.
 
 	availableClientVersion := handshakeVersion
 	if availableClientVersion == "" {
-		request, err := http.NewRequest("HEAD", requestUrl, nil)
+
+		request, err := http.NewRequest("HEAD", downloadURL, nil)
 		if err != nil {
 			return common.ContextError(err)
 		}
+
+		request = request.WithContext(ctx)
+
 		response, err := httpClient.Do(request)
+
 		if err == nil && response.StatusCode != http.StatusOK {
 			response.Body.Close()
 			err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
@@ -139,8 +151,9 @@ func DownloadUpgrade(
 		"%s.%s", config.UpgradeDownloadFilename, availableClientVersion)
 
 	n, _, err := ResumeDownload(
+		ctx,
 		httpClient,
-		requestUrl,
+		downloadURL,
 		MakePsiphonUserAgent(config),
 		downloadFilename,
 		"")

+ 91 - 41
psiphon/upstreamproxy/transport_proxy_auth.go

@@ -22,6 +22,7 @@ package upstreamproxy
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -38,7 +39,6 @@ const HTTP_STAT_LINE_LENGTH = 12
 // when requested by server
 type ProxyAuthTransport struct {
 	*http.Transport
-	Dial          DialFunc
 	Username      string
 	Password      string
 	Authenticator HttpAuthenticator
@@ -46,34 +46,53 @@ type ProxyAuthTransport struct {
 	CustomHeaders http.Header
 }
 
-func NewProxyAuthTransport(rawTransport *http.Transport, customHeaders http.Header) (*ProxyAuthTransport, error) {
-	dialFn := rawTransport.Dial
-	if dialFn == nil {
-		dialFn = net.Dial
+func NewProxyAuthTransport(
+	rawTransport *http.Transport,
+	customHeaders http.Header) (*ProxyAuthTransport, error) {
+
+	if rawTransport.DialContext == nil {
+		return nil, fmt.Errorf("rawTransport must have DialContext")
+	}
+
+	if rawTransport.Proxy == nil {
+		return nil, fmt.Errorf("rawTransport must have Proxy")
 	}
-	tr := &ProxyAuthTransport{Dial: dialFn, CustomHeaders: customHeaders}
-	proxyUrlFn := rawTransport.Proxy
-	if proxyUrlFn != nil {
-		wrappedDialFn := tr.wrapTransportDial()
-		rawTransport.Dial = wrappedDialFn
-		proxyUrl, err := proxyUrlFn(nil)
+
+	tr := &ProxyAuthTransport{
+		Transport:     rawTransport,
+		CustomHeaders: customHeaders,
+	}
+
+	// Wrap the original transport's custom dialed conns in transportConns,
+	// which handle connection-based authentication.
+	originalDialContext := rawTransport.DialContext
+	rawTransport.DialContext = func(
+		ctx context.Context, network, addr string) (net.Conn, error) {
+		conn, err := originalDialContext(ctx, "tcp", addr)
 		if err != nil {
 			return nil, err
 		}
-		if proxyUrl.Scheme != "http" {
-			return nil, fmt.Errorf("Only HTTP proxy supported, for SOCKS use http.Transport with custom dialers & upstreamproxy.NewProxyDialFunc")
-		}
-		if proxyUrl.User != nil {
-			tr.Username = proxyUrl.User.Username()
-			tr.Password, _ = proxyUrl.User.Password()
-		}
-		// strip username and password from the proxyURL because
-		// we do not want the wrapped transport to handle authentication
-		proxyUrl.User = nil
-		rawTransport.Proxy = http.ProxyURL(proxyUrl)
+		// Any additional dials made by transportConn are within
+		// the original dial context.
+		return newTransportConn(ctx, conn, tr), nil
+	}
+
+	proxyUrl, err := rawTransport.Proxy(nil)
+	if err != nil {
+		return nil, err
+	}
+	if proxyUrl.Scheme != "http" {
+		return nil, fmt.Errorf("%s unsupported", proxyUrl.Scheme)
+	}
+	if proxyUrl.User != nil {
+		tr.Username = proxyUrl.User.Username()
+		tr.Password, _ = proxyUrl.User.Password()
 	}
+	// strip username and password from the proxyURL because
+	// we do not want the wrapped transport to handle authentication
+	proxyUrl.User = nil
+	rawTransport.Proxy = http.ProxyURL(proxyUrl)
 
-	tr.Transport = rawTransport
 	return tr, nil
 }
 
@@ -88,7 +107,7 @@ func (tr *ProxyAuthTransport) preAuthenticateRequest(req *http.Request) error {
 
 func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
 	if req.URL.Scheme != "http" {
-		return nil, fmt.Errorf("Only plain HTTP supported, for HTTPS use http.Transport with DialTLS & upstreamproxy.NewProxyDialFunc")
+		return nil, fmt.Errorf("%s unsupported", req.URL.Scheme)
 	}
 	err = tr.preAuthenticateRequest(req)
 	if err != nil {
@@ -143,21 +162,6 @@ func (tr *ProxyAuthTransport) RoundTrip(req *http.Request) (resp *http.Response,
 
 }
 
-// wrapTransportDial wraps original transport Dial function
-// and returns a new net.Conn interface provided by transportConn
-// that allows us to intercept both outgoing requests and incoming
-// responses and examine / mutate them
-func (tr *ProxyAuthTransport) wrapTransportDial() DialFunc {
-	return func(network, addr string) (net.Conn, error) {
-		c, err := tr.Dial("tcp", addr)
-		if err != nil {
-			return nil, err
-		}
-		tc := newTransportConn(c, tr)
-		return tc, nil
-	}
-}
-
 // Based on https://github.com/golang/oauth2/blob/master/transport.go
 // Copyright 2014 The Go Authors. All rights reserved.
 func cloneRequest(r *http.Request, ch http.Header) *http.Request {
@@ -197,11 +201,16 @@ func cloneRequest(r *http.Request, ch http.Header) *http.Request {
 
 		r2.Body = ioutil.NopCloser(bytes.NewReader(body))
 	}
+
+	// A replayed request inherits the original request's deadline (and interruptibility).
+	r2 = r2.WithContext(r.Context())
+
 	return r2
 }
 
 type transportConn struct {
 	net.Conn
+	ctx                context.Context
 	requestInterceptor io.Writer
 	reqDone            chan struct{}
 	errChannel         chan error
@@ -210,9 +219,42 @@ type transportConn struct {
 	transport          *ProxyAuthTransport
 }
 
-func newTransportConn(c net.Conn, tr *ProxyAuthTransport) *transportConn {
+func newTransportConn(
+	ctx context.Context,
+	c net.Conn,
+	tr *ProxyAuthTransport) *transportConn {
+
+	// TODOs:
+	//
+	// - Additional dials made by transportConn, for authentication, use the
+	//   original conn's dial context. If authentication can be requested at any
+	//   time, instead of just at the start of a connection, then any deadline for
+	//   this context will be inappropriate.
+	//
+	// - The "intercept" goroutine spawned below will never terminate? Even if the
+	//   transportConn is closed, nothing will unblock reads of the pipe made by
+	//   http.ReadRequest. There should be a call to pw.Close() in transportConn.Close().
+	//
+	// - The ioutil.ReadAll in the "intercept" goroutine allocates new buffers for
+	//   every request. To avoid GC churn it should use a bytes.Buffer to reuse a
+	//   single buffer. In practice, there will be a reasonably small maximum request
+	//   body size, so it's better to retain and reuse a buffer than to continuously
+	//   reallocate.
+	//
+	// - transportConn.Read will not do anything if the caller passes in a very small
+	//   read buffer. This should be documented, as it's assuming that the caller is
+	//   fully reading at least HTTP_STAT_LINE_LENGTH at the start of request.
+	//
+	// - As a net.Conn, transportConn.Read should always be interrupted by a call to
+	//   Close, but it may be possible for Read to remain blocked:
+	//   1. caller writes less than a full request to Write
+	//   2. "intercept" call to http.ReadRequest will not return
+	//   3. caller calls Close, which just calls transportConn.Conn.Close
+	//   4. any existing call to Read remains blocked in the select
+
 	tc := &transportConn{
 		Conn:       c,
+		ctx:        ctx,
 		reqDone:    make(chan struct{}),
 		errChannel: make(chan error),
 		transport:  tr,
@@ -301,7 +343,15 @@ func (tc *transportConn) Read(p []byte) (n int, readErr error) {
 				// dial a new one
 				addr := tc.Conn.RemoteAddr()
 				tc.Conn.Close()
-				tc.Conn, err = tc.transport.Dial(addr.Network(), addr.String())
+
+				// Additional dials are made within the context of the dial of the
+				// outer conn this transportConn is wrapping, so the scope of outer
+				// dial timeouts includes these additional dials. This is also to
+				// ensure these dials are interrupted when the context is canceled.
+
+				tc.Conn, err = tc.transport.Transport.DialContext(
+					tc.ctx, addr.Network(), addr.String())
+
 				if err != nil {
 					return 0, err
 				}

+ 7 - 3
psiphon/userAgent_test.go

@@ -20,6 +20,7 @@
 package psiphon
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"net/http"
@@ -233,19 +234,22 @@ func attemptConnectionsWithUserAgent(
 		t.Fatalf("error creating client controller: %s", err)
 	}
 
-	controllerShutdownBroadcast := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
 	controllerWaitGroup := new(sync.WaitGroup)
+
 	controllerWaitGroup.Add(1)
 	go func() {
 		defer controllerWaitGroup.Done()
-		controller.Run(controllerShutdownBroadcast)
+		controller.Run(ctx)
 	}()
 
 	// repeat attempts for long enough to select each user agent
 
 	time.Sleep(20 * time.Second)
 
-	close(controllerShutdownBroadcast)
+	cancelFunc()
+
 	controllerWaitGroup.Wait()
 
 	checkUserAgentCounts(t, isCONNECT)