Browse Source

merged rod/master

Adam Pritchard 11 years ago
parent
commit
7664db038a

+ 13 - 18
README.md

@@ -33,7 +33,10 @@ Setup
         "LocalSocksProxyPort" : 0,
         "EgressRegion" : "",
         "TunnelProtocol" : "",
-        "ConnectionWorkerPoolSize" : 10
+        "ConnectionWorkerPoolSize" : 10,
+        "TunnelPoolSize" : 1,
+        "PortForwardFailureThreshold" : 10,
+        "UpstreamHttpProxyAddress" : ""
     }
     ```
 
@@ -43,23 +46,20 @@ Setup
 Roadmap
 --------------------------------------------------------------------------------
 
-### TODO (proof-of-concept)
+### TODO (short-term)
 
+* requirements for integrating with Windows client
+  * split tunnel support
+  * implement page view and bytes transferred stats
+  * resumable download of client upgrades
 * Android app
   * open home pages
   * Go binary PIE, or use a Go library and JNI
   * settings UI (e.g., region selection)
-* reconnection busy loop when no network available (ex. close laptop); should wait for network connectivity
 * sometimes fails to promptly detect loss of connection after device sleep
-* continuity and performance
-  * always-on local proxies
-  * multiplex across simultaneous tunnels
-  * monitor health of tunnels; for example fail-over to new server on "ssh: rejected: administratively prohibited (open failed)" error?
 * PendingConns: is interrupting connection establishment worth the extra code complexity?
-* prefilter entries by capability; don't log "server does not have sufficient capabilities"
 * log noise: "use of closed network connection"
 * log noise(?): 'Unsolicited response received on idle HTTP channel starting with "H"'
-* use ContextError in more places
 
 ### TODO (future)
 
@@ -70,19 +70,14 @@ Roadmap
   * unfronted meek almost makes this obsolete, since meek sessions survive underlying
      HTTP transport socket disconnects. The client could prefer unfronted meek protocol
      when handshake returns a preemptive_reconnect_lifetime_milliseconds.
-* split tunnel support
-* implement page view stats
+  * could also be accomplished with TunnelPoolSize > 1 and staggaring the establishment times
 * implement local traffic stats (e.g., to display bytes sent/received)
-* control interface (w/ event messages)?
-* upstream proxy support
-* support upgrades
-  * download entire client
-  * download core component only
+* more formal control interface (w/ event messages)?
+* support upgrading core only
 * try multiple protocols for each server (currently only tries one protocol per server)
 * support a config pushed by the network
   * server can push preferred/optimized settings; client should prefer over defaults
-  * e.g., etablish worker pool size; multiplex tunnel pool size
-* overlap between httpProxy.go and socksProxy.go: refactor?
+  * e.g., etablish worker pool size; tunnel pool size
 
 Licensing
 --------------------------------------------------------------------------------

+ 8 - 1
psiphon/LookupIP.go

@@ -23,7 +23,7 @@ package psiphon
 
 import (
 	"errors"
-	dns "github.com/miekg/dns"
+	dns "github.com/Psiphon-Inc/dns"
 	"net"
 	"os"
 	"syscall"
@@ -58,11 +58,13 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
 	// config.BindToDeviceDnsServer must be an IP address
 	ipAddr := net.ParseIP(config.BindToDeviceDnsServer)
 	if ipAddr == nil {
 		return nil, ContextError(errors.New("invalid IP address"))
 	}
+
 	// TODO: IPv6 support
 	var ip [4]byte
 	copy(ip[:], ipAddr.To4())
@@ -72,6 +74,7 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
 	// Convert the syscall socket to a net.Conn, for use in the dns package
 	file := os.NewFile(uintptr(socketFd), "")
 	defer file.Close()
@@ -79,9 +82,11 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
 	// Set DNS query timeouts, using the ConnectTimeout from the overall Dial
 	conn.SetReadDeadline(time.Now().Add(config.ConnectTimeout))
 	conn.SetWriteDeadline(time.Now().Add(config.ConnectTimeout))
+
 	// Make the DNS query
 	// TODO: make interruptible?
 	dnsConn := &dns.Conn{Conn: conn}
@@ -90,6 +95,8 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	query.SetQuestion(dns.Fqdn(host), dns.TypeA)
 	query.RecursionDesired = true
 	dnsConn.WriteMsg(query)
+
+	// Process the response
 	response, err := dnsConn.ReadMsg()
 	if err != nil {
 		return nil, ContextError(err)

+ 0 - 1
psiphon/TCPConn.go

@@ -54,7 +54,6 @@ func NewTCPDialer(config *DialConfig) Dialer {
 
 // TCPConn creates a new, connected TCPConn.
 func DialTCP(addr string, config *DialConfig) (conn *TCPConn, err error) {
-
 	conn, err = interruptibleTCPDial(addr, config)
 	if err != nil {
 		return nil, ContextError(err)

+ 25 - 2
psiphon/TCPConn_unix.go

@@ -39,6 +39,7 @@ type interruptibleTCPSocket struct {
 // syscall APIs are used. The sequence of syscalls in this implementation are
 // taken from: https://code.google.com/p/go/issues/detail?id=6966
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
+
 	// Create a socket and then, before connecting, add a TCPConn with
 	// the unconnected socket to pendingConns. This allows pendingConns to
 	// abort connections in progress.
@@ -52,6 +53,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			syscall.Close(socketFd)
 		}
 	}()
+
 	// Note: this step is not interruptible
 	if config.BindToDeviceServiceAddress != "" {
 		err = bindToDevice(socketFd, config)
@@ -59,8 +61,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			return nil, ContextError(err)
 		}
 	}
+
+	// When using an upstream HTTP proxy, first connect to the proxy,
+	// then use HTTP CONNECT to connect to the original destination.
+	dialAddr := addr
+	if config.UpstreamHttpProxyAddress != "" {
+		dialAddr = config.UpstreamHttpProxyAddress
+	}
+
 	// Get the remote IP and port, resolving a domain name if necessary
-	host, strPort, err := net.SplitHostPort(addr)
+	host, strPort, err := net.SplitHostPort(dialAddr)
 	if err != nil {
 		return nil, ContextError(err)
 	}
@@ -78,12 +88,15 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	// TODO: IPv6 support
 	var ip [4]byte
 	copy(ip[:], ipAddrs[0].To4())
+
 	// Enable interruption
 	conn = &TCPConn{
 		interruptible: interruptibleTCPSocket{socketFd: socketFd},
 		readTimeout:   config.ReadTimeout,
 		writeTimeout:  config.WriteTimeout}
 	config.PendingConns.Add(conn)
+	defer config.PendingConns.Remove(conn)
+
 	// Connect the socket
 	// TODO: adjust the timeout to account for time spent resolving hostname
 	sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
@@ -99,10 +112,10 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	} else {
 		err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
 	}
-	config.PendingConns.Remove(conn)
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
 	// Convert the syscall socket to a net.Conn
 	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
 	defer file.Close()
@@ -110,6 +123,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
+	// Going through upstream HTTP proxy
+	if config.UpstreamHttpProxyAddress != "" {
+		// This call can be interrupted by closing the pending conn
+		err := HttpProxyConnect(conn, addr)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+	}
+
 	return conn, nil
 }
 

+ 52 - 9
psiphon/TCPConn_windows.go

@@ -22,29 +22,72 @@
 package psiphon
 
 import (
+	"errors"
 	"net"
 )
 
+// interruptibleTCPSocket simulates interruptible semantics on Windows. A call
+// to interruptibleTCPClose doesn't actually interrupt a connect in progress,
+// but abandons a dial that's running in a goroutine.
+// Interruptible semantics are required by the controller for timely component
+// state changes.
+// TODO: implement true interruptible semantics on Windows; use syscall and
+// a HANDLE similar to how TCPConn_unix uses a file descriptor?
 type interruptibleTCPSocket struct {
+	results chan *interruptibleDialResult
+}
+
+type interruptibleDialResult struct {
+	netConn net.Conn
+	err     error
 }
 
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
 	if config.BindToDeviceServiceAddress != "" {
 		Fatal("psiphon.interruptibleTCPDial with bind not supported on Windows")
 	}
-	// Note: using standard net.Dial(); interruptible connections not supported on Windows
-	netConn, err := net.DialTimeout("tcp", addr, config.ConnectTimeout)
-	if err != nil {
-		return nil, ContextError(err)
-	}
+
 	conn = &TCPConn{
-		Conn:         netConn,
-		readTimeout:  config.ReadTimeout,
-		writeTimeout: config.WriteTimeout}
+		interruptible: interruptibleTCPSocket{results: make(chan *interruptibleDialResult, 2)},
+		readTimeout:   config.ReadTimeout,
+		writeTimeout:  config.WriteTimeout}
+	config.PendingConns.Add(conn)
+
+	// Call the blocking Dial in a goroutine
+	results := conn.interruptible.results
+	go func() {
+
+		// When using an upstream HTTP proxy, first connect to the proxy,
+		// then use HTTP CONNECT to connect to the original destination.
+		dialAddr := addr
+		if config.UpstreamHttpProxyAddress != "" {
+			dialAddr = config.UpstreamHttpProxyAddress
+		}
+
+		netConn, err := net.DialTimeout("tcp", dialAddr, config.ConnectTimeout)
+
+		if config.UpstreamHttpProxyAddress != "" {
+			err := HttpProxyConnect(netConn, addr)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+		}
+
+		results <- &interruptibleDialResult{netConn, err}
+	}()
+
+	// Block until Dial completes (or times out) or until interrupt
+	result := <-conn.interruptible.results
+	config.PendingConns.Remove(conn)
+	if result.err != nil {
+		return nil, ContextError(result.err)
+	}
+	conn.Conn = result.netConn
+
 	return conn, nil
 }
 
 func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
-	Fatal("psiphon.interruptibleTCPClose not supported on Windows")
+	interruptible.results <- &interruptibleDialResult{nil, errors.New("socket interrupted")}
 	return nil
 }

+ 23 - 7
psiphon/config.go

@@ -41,6 +41,9 @@ type Config struct {
 	ConnectionWorkerPoolSize           int
 	BindToDeviceServiceAddress         string
 	BindToDeviceDnsServer              string
+	TunnelPoolSize                     int
+	PortForwardFailureThreshold        int
+	UpstreamHttpProxyAddress           string
 }
 
 // LoadConfig reads, and parse, and validates a JSON format Psiphon config
@@ -48,31 +51,36 @@ type Config struct {
 func LoadConfig(filename string) (*Config, error) {
 	fileContents, err := ioutil.ReadFile(filename)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	var config Config
 	err = json.Unmarshal(fileContents, &config)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 
 	// These fields are required; the rest are optional
 	if config.PropagationChannelId == "" {
-		return nil, errors.New("propagation channel ID is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("propagation channel ID is missing from the configuration file"))
 	}
 	if config.SponsorId == "" {
-		return nil, errors.New("sponsor ID is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("sponsor ID is missing from the configuration file"))
 	}
 	if config.RemoteServerListUrl == "" {
-		return nil, errors.New("remote server list URL is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("remote server list URL is missing from the configuration file"))
 	}
 	if config.RemoteServerListSignaturePublicKey == "" {
-		return nil, errors.New("remote server list signature public key is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("remote server list signature public key is missing from the configuration file"))
 	}
 
 	if config.TunnelProtocol != "" {
 		if !Contains(SupportedTunnelProtocols, config.TunnelProtocol) {
-			return nil, errors.New("invalid tunnel protocol")
+			return nil, ContextError(
+				errors.New("invalid tunnel protocol"))
 		}
 	}
 
@@ -80,5 +88,13 @@ func LoadConfig(filename string) (*Config, error) {
 		config.ConnectionWorkerPoolSize = CONNECTION_WORKER_POOL_SIZE
 	}
 
+	if config.TunnelPoolSize == 0 {
+		config.TunnelPoolSize = TUNNEL_POOL_SIZE
+	}
+
+	if config.PortForwardFailureThreshold == 0 {
+		config.PortForwardFailureThreshold = PORT_FORWARD_FAILURE_THRESHOLD
+	}
+
 	return &config, nil
 }

+ 60 - 0
psiphon/conn.go

@@ -20,6 +20,9 @@
 package psiphon
 
 import (
+	"bytes"
+	"fmt"
+	"io"
 	"net"
 	"sync"
 	"time"
@@ -28,6 +31,12 @@ import (
 // DialConfig contains parameters to determine the behavior
 // of a Psiphon dialer (TCPDial, MeekDial, etc.)
 type DialConfig struct {
+
+	// UpstreamHttpProxyAddress specifies an HTTP proxy to connect through
+	// (the proxy must support HTTP CONNECT). The address may be a hostname
+	// or IP address and must include a port number.
+	UpstreamHttpProxyAddress string
+
 	ConnectTimeout time.Duration
 	ReadTimeout    time.Duration
 	WriteTimeout   time.Duration
@@ -99,3 +108,54 @@ func (conns *Conns) CloseAll() {
 	}
 	conns.conns = make(map[net.Conn]bool)
 }
+
+// Relay sends to remoteConn bytes received from localConn,
+// and sends to localConn bytes received from remoteConn.
+func Relay(localConn, remoteConn net.Conn) {
+	copyWaitGroup := new(sync.WaitGroup)
+	copyWaitGroup.Add(1)
+	go func() {
+		defer copyWaitGroup.Done()
+		_, err := io.Copy(localConn, remoteConn)
+		if err != nil {
+			Notice(NOTICE_ALERT, "%s", ContextError(err))
+		}
+	}()
+	_, err := io.Copy(remoteConn, localConn)
+	if err != nil {
+		Notice(NOTICE_ALERT, "%s", ContextError(err))
+	}
+	copyWaitGroup.Wait()
+}
+
+// HttpProxyConnect establishes a HTTP CONNECT tunnel to addr through
+// an established network connection to an HTTP proxy. It is assumed that
+// no payload bytes have been sent through the connection to the proxy.
+func HttpProxyConnect(rawConn net.Conn, addr string) (err error) {
+	hostname, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	// TODO: use the proxy request/response code from net/http/transport.go?
+	connectRequest := fmt.Sprintf(
+		"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
+		addr, hostname)
+	_, err = rawConn.Write([]byte(connectRequest))
+	if err != nil {
+		return ContextError(err)
+	}
+
+	expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
+	readBuffer := make([]byte, len(expectedResponse))
+	_, err = io.ReadFull(rawConn, readBuffer)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	if !bytes.Equal(readBuffer, expectedResponse) {
+		return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
+	}
+
+	return nil
+}

+ 615 - 0
psiphon/controller.go

@@ -0,0 +1,615 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package psiphon implements the core tunnel functionality of a Psiphon client.
+// The main function is RunForever, which runs a Controller that obtains lists of
+// servers, establishes tunnel connections, and runs local proxies through which
+// tunneled traffic may be sent.
+package psiphon
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"os"
+	"sync"
+	"time"
+)
+
+// Controller is a tunnel lifecycle coordinator. It manages lists of servers to
+// connect to; establishes and monitors tunnels; and runs local proxies which
+// route traffic through the tunnels.
+type Controller struct {
+	config                    *Config
+	failureSignal             chan struct{}
+	shutdownBroadcast         chan struct{}
+	runWaitGroup              *sync.WaitGroup
+	establishedTunnels        chan *Tunnel
+	failedTunnels             chan *Tunnel
+	tunnelMutex               sync.Mutex
+	tunnels                   []*Tunnel
+	nextTunnel                int
+	operateWaitGroup          *sync.WaitGroup
+	isEstablishing            bool
+	establishWaitGroup        *sync.WaitGroup
+	stopEstablishingBroadcast chan struct{}
+	candidateServerEntries    chan *ServerEntry
+	pendingConns              *Conns
+}
+
+// NewController initializes a new controller.
+func NewController(config *Config) (controller *Controller) {
+	return &Controller{
+		config: config,
+		// failureSignal receives a signal from a component (including socks and
+		// http local proxies) if they unexpectedly fail. Senders should not block.
+		// A buffer allows at least one stop signal to be sent before there is a receiver.
+		failureSignal:     make(chan struct{}, 1),
+		shutdownBroadcast: make(chan struct{}),
+		runWaitGroup:      new(sync.WaitGroup),
+		// establishedTunnels and failedTunnels buffer sizes are large enough to
+		// receive full pools of tunnels without blocking. Senders should not block.
+		establishedTunnels: make(chan *Tunnel, config.TunnelPoolSize),
+		failedTunnels:      make(chan *Tunnel, config.TunnelPoolSize),
+		tunnels:            make([]*Tunnel, 0),
+		operateWaitGroup:   new(sync.WaitGroup),
+		isEstablishing:     false,
+		pendingConns:       new(Conns),
+	}
+}
+
+// Run executes the controller. It launches components and then monitors
+// for a shutdown signal; after receiving the signal it shuts down the
+// controller.
+// The components include:
+// - the periodic remote server list fetcher
+// - the tunnel manager
+// - a local SOCKS proxy that port forwards through the pool of tunnels
+// - a local HTTP proxy that port forwards through the pool of tunnels
+func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
+	socksProxy, err := NewSocksProxy(controller)
+	if err != nil {
+		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
+		return
+	}
+	defer socksProxy.Close()
+	httpProxy, err := NewHttpProxy(controller)
+	if err != nil {
+		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
+		return
+	}
+	defer httpProxy.Close()
+
+	controller.runWaitGroup.Add(2)
+	go controller.remoteServerListFetcher()
+	go controller.runTunnels()
+
+	select {
+	case <-shutdownBroadcast:
+		Notice(NOTICE_INFO, "controller shutdown by request")
+	case <-controller.failureSignal:
+		Notice(NOTICE_ALERT, "controller shutdown due to failure")
+	}
+
+	// Note: in addition to establish(), this pendingConns will interrupt
+	// FetchRemoteServerList
+	controller.pendingConns.CloseAll()
+	close(controller.shutdownBroadcast)
+	controller.runWaitGroup.Wait()
+
+	Notice(NOTICE_INFO, "exiting controller")
+}
+
+// SignalFailure notifies the controller than a component has failed.
+// This will terminate the controller.
+func (controller *Controller) SignalFailure() {
+	select {
+	case controller.failureSignal <- *new(struct{}):
+	default:
+	}
+}
+
+// remoteServerListFetcher fetches an out-of-band list of server entries
+// for more tunnel candidates. It fetches immediately, retries after failure
+// with a wait period, and refetches after success with a longer wait period.
+func (controller *Controller) remoteServerListFetcher() {
+	defer controller.runWaitGroup.Done()
+
+	// Note: unlike existing Psiphon clients, this code
+	// always makes the fetch remote server list request
+loop:
+	for {
+		// TODO: FetchRemoteServerList should have its own pendingConns,
+		// otherwise it may needlessly abort when establish is stopped.
+		err := FetchRemoteServerList(controller.config, controller.pendingConns)
+		var duration time.Duration
+		if err != nil {
+			Notice(NOTICE_ALERT, "failed to fetch remote server list: %s", err)
+			duration = FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT
+		} else {
+			duration = FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT
+		}
+		timeout := time.After(duration)
+		select {
+		case <-timeout:
+			// Fetch again
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+
+	Notice(NOTICE_INFO, "exiting remote server list fetcher")
+}
+
+// runTunnels is the controller tunnel management main loop. It starts and stops
+// establishing tunnels based on the target tunnel pool size and the current size
+// of the pool. Tunnels are established asynchronously using worker goroutines.
+// When a tunnel is established, it's added to the active pool and a corresponding
+// operateTunnel goroutine is launched which starts a session in the tunnel and
+// monitors the tunnel for failures.
+// When a tunnel fails, it's removed from the pool and the establish process is
+// restarted to fill the pool.
+func (controller *Controller) runTunnels() {
+	defer controller.runWaitGroup.Done()
+
+	// Don't start establishing until there are some server candidates. The
+	// typical case is a client with no server entries which will wait for
+	// the first successful FetchRemoteServerList to populate the data store.
+	for {
+		if HasServerEntries(
+			controller.config.EgressRegion, controller.config.TunnelProtocol) {
+			break
+		}
+		// TODO: replace polling with signal
+		timeout := time.After(1 * time.Second)
+		select {
+		case <-timeout:
+		case <-controller.shutdownBroadcast:
+			return
+		}
+	}
+	controller.startEstablishing()
+loop:
+	for {
+		select {
+		case failedTunnel := <-controller.failedTunnels:
+			Notice(NOTICE_ALERT, "tunnel failed: %s", failedTunnel.serverEntry.IpAddress)
+			controller.terminateTunnel(failedTunnel)
+			// Note: only this goroutine may call startEstablishing/stopEstablishing and access
+			// isEstablishing.
+			if !controller.isEstablishing {
+				controller.startEstablishing()
+			}
+
+		// !TODO! design issue: might not be enough server entries with region/caps to ever fill tunnel slots
+		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
+		case establishedTunnel := <-controller.establishedTunnels:
+			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
+			// !TODO! design issue: activateTunnel makes tunnel avail for port forward *before* operates does handshake
+			// solution(?) distinguish between two stages or states: connected, and then active.
+			if controller.activateTunnel(establishedTunnel) {
+				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
+				controller.operateWaitGroup.Add(1)
+				go controller.operateTunnel(establishedTunnel)
+			} else {
+				controller.discardTunnel(establishedTunnel)
+			}
+			if controller.isFullyEstablished() {
+				controller.stopEstablishing()
+			}
+
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+	controller.stopEstablishing()
+	controller.terminateAllTunnels()
+	controller.operateWaitGroup.Wait()
+
+	// Drain tunnel channels
+	close(controller.establishedTunnels)
+	for tunnel := range controller.establishedTunnels {
+		controller.discardTunnel(tunnel)
+	}
+	close(controller.failedTunnels)
+	for tunnel := range controller.failedTunnels {
+		controller.discardTunnel(tunnel)
+	}
+
+	Notice(NOTICE_INFO, "exiting run tunnels")
+}
+
+// discardTunnel disposes of a successful connection that is no longer required.
+func (controller *Controller) discardTunnel(tunnel *Tunnel) {
+	Notice(NOTICE_INFO, "discard tunnel: %s", tunnel.serverEntry.IpAddress)
+	// TODO: not calling PromoteServerEntry, since that would rank the
+	// discarded tunnel before fully active tunnels. Can a discarded tunnel
+	// be promoted (since it connects), but with lower rank than all active
+	// tunnels?
+	tunnel.Close()
+}
+
+// activateTunnel adds the connected tunnel to the pool of active tunnels
+// which are used for port forwarding. Returns true if the pool has an empty
+// slot and false if the pool is full (caller should discard the tunnel).
+func (controller *Controller) activateTunnel(tunnel *Tunnel) bool {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	// !TODO! double check not already a tunnel to this server
+	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
+		return false
+	}
+	controller.tunnels = append(controller.tunnels, tunnel)
+	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+	return true
+}
+
+// isFullyEstablished indicates if the pool of active tunnels is full.
+func (controller *Controller) isFullyEstablished() bool {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	return len(controller.tunnels) >= controller.config.TunnelPoolSize
+}
+
+// terminateTunnel removes a tunnel from the pool of active tunnels
+// and closes the tunnel. The next-tunnel state used by getNextActiveTunnel
+// is adjusted as required.
+func (controller *Controller) terminateTunnel(tunnel *Tunnel) {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for index, activeTunnel := range controller.tunnels {
+		if tunnel == activeTunnel {
+			controller.tunnels = append(
+				controller.tunnels[:index], controller.tunnels[index+1:]...)
+			if controller.nextTunnel > index {
+				controller.nextTunnel--
+			}
+			if controller.nextTunnel >= len(controller.tunnels) {
+				controller.nextTunnel = 0
+			}
+			activeTunnel.Close()
+			Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+			break
+		}
+	}
+}
+
+// terminateAllTunnels empties the tunnel pool, closing all active tunnels.
+// This is used when shutting down the controller.
+func (controller *Controller) terminateAllTunnels() {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for _, activeTunnel := range controller.tunnels {
+		activeTunnel.Close()
+	}
+	controller.tunnels = make([]*Tunnel, 0)
+	controller.nextTunnel = 0
+	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+}
+
+// getNextActiveTunnel returns the next tunnel from the pool of active
+// tunnels. Currently, tunnel selection order is simple round-robin.
+func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	if len(controller.tunnels) == 0 {
+		return nil
+	}
+	tunnel = controller.tunnels[controller.nextTunnel]
+	controller.nextTunnel =
+		(controller.nextTunnel + 1) % len(controller.tunnels)
+	return tunnel
+}
+
+// getActiveTunnelServerEntries lists the Server Entries for
+// all the active tunnels. This is used to exclude those servers
+// from the set of candidates to establish connections to.
+func (controller *Controller) getActiveTunnelServerEntries() (serverEntries []*ServerEntry) {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	serverEntries = make([]*ServerEntry, 0)
+	for _, activeTunnel := range controller.tunnels {
+		serverEntries = append(serverEntries, activeTunnel.serverEntry)
+	}
+	return serverEntries
+}
+
+// operateTunnel starts a Psiphon session (handshake, etc.) on a newly
+// connected tunnel, and then monitors the tunnel for failures:
+//
+// 1. Overall tunnel failure: the tunnel sends a signal to the ClosedSignal
+// channel on keep-alive failure and other transport I/O errors. In case
+// of such a failure, the tunnel is marked as failed.
+//
+// 2. Tunnel port forward failures: the tunnel connection may stay up but
+// the client may still fail to establish port forwards due to server load
+// and other conditions. After a threshold number of such failures, the
+// overall tunnel is marked as failed.
+//
+// TODO: currently, any connect (dial), read, or write error associated with
+// a port forward is counted as a failure. It may be important to differentiate
+// between failures due to Psiphon server conditions and failures due to the
+// origin/target server (in the latter case, the tunnel is healthy). Here are
+// some typical error messages to consider matching against (or ignoring):
+//
+// - "ssh: rejected: administratively prohibited (open failed)"
+// - "ssh: rejected: connect failed (Connection timed out)"
+// - "write tcp ... broken pipe"
+// - "read tcp ... connection reset by peer"
+// - "ssh: unexpected packet in response to channel open: <nil>"
+//
+func (controller *Controller) operateTunnel(tunnel *Tunnel) {
+	defer controller.operateWaitGroup.Done()
+
+	tunnelClosedSignal := make(chan struct{}, 1)
+	err := tunnel.conn.SetClosedSignal(tunnelClosedSignal)
+	if err != nil {
+		err = fmt.Errorf("failed to set closed signal: %s", err)
+	}
+
+	Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
+	// TODO: NewSession server API calls may block shutdown
+	_, err = NewSession(controller.config, tunnel)
+	if err != nil {
+		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
+	}
+
+	// Promote this successful tunnel to first rank so it's one
+	// of the first candidates next time establish runs.
+	PromoteServerEntry(tunnel.serverEntry.IpAddress)
+
+	for err == nil {
+		select {
+		case failures := <-tunnel.portForwardFailures:
+			tunnel.portForwardFailureTotal += failures
+			if tunnel.portForwardFailureTotal > controller.config.PortForwardFailureThreshold {
+				err = errors.New("tunnel exceeded port forward failure threshold")
+			}
+
+		case <-tunnelClosedSignal:
+			// TODO: this signal can be received during a commanded shutdown due to
+			// how tunnels are closed; should rework this to avoid log noise.
+			err = errors.New("tunnel closed unexpectedly")
+
+		case <-controller.shutdownBroadcast:
+			Notice(NOTICE_INFO, "shutdown operate tunnel")
+			return
+		}
+	}
+
+	if err != nil {
+		Notice(NOTICE_ALERT, "operate tunnel error for %s: %s", tunnel.serverEntry.IpAddress, err)
+		// Don't block. Assumes the receiver has a buffer large enough for
+		// the typical number of operated tunnels. In case there's no room,
+		// terminate the tunnel (runTunnels won't get a signal in this case).
+		select {
+		case controller.failedTunnels <- tunnel:
+		default:
+			controller.terminateTunnel(tunnel)
+		}
+	}
+}
+
+// TunneledConn implements net.Conn and wraps a port foward connection.
+// It is used to hook into Read and Write to observe I/O errors and
+// report these errors back to the tunnel monitor as port forward failures.
+type TunneledConn struct {
+	net.Conn
+	tunnel *Tunnel
+}
+
+func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Read(buffer)
+	if err != nil {
+		// Report 1 new failure. Won't block; assumes the receiver
+		// has a sufficient buffer for the threshold number of reports.
+		// TODO: conditional on type of error or error message?
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
+}
+
+func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Write(buffer)
+	if err != nil {
+		// Same as TunneledConn.Read()
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
+}
+
+// dialWithTunnel selects an active tunnel and establishes a port forward
+// connection through the selected tunnel. Failure to connect is considered
+// a port foward failure, for the purpose of monitoring tunnel health.
+func (controller *Controller) dialWithTunnel(remoteAddr string) (conn net.Conn, err error) {
+	tunnel := controller.getNextActiveTunnel()
+	if tunnel == nil {
+		return nil, ContextError(errors.New("no active tunnels"))
+	}
+	sshPortForward, err := tunnel.sshClient.Dial("tcp", remoteAddr)
+	if err != nil {
+		// TODO: conditional on type of error or error message?
+		select {
+		case tunnel.portForwardFailures <- 1:
+		default:
+		}
+		return nil, ContextError(err)
+	}
+	return &TunneledConn{
+			Conn:   sshPortForward,
+			tunnel: tunnel},
+		nil
+}
+
+// startEstablishing creates a pool of worker goroutines which will
+// attempt to establish tunnels to candidate servers. The candidates
+// are generated by another goroutine.
+func (controller *Controller) startEstablishing() {
+	if controller.isEstablishing {
+		return
+	}
+	Notice(NOTICE_INFO, "start establishing")
+	controller.isEstablishing = true
+	controller.establishWaitGroup = new(sync.WaitGroup)
+	controller.stopEstablishingBroadcast = make(chan struct{})
+	controller.candidateServerEntries = make(chan *ServerEntry)
+
+	for i := 0; i < controller.config.ConnectionWorkerPoolSize; i++ {
+		controller.establishWaitGroup.Add(1)
+		go controller.establishTunnelWorker()
+	}
+
+	controller.establishWaitGroup.Add(1)
+	go controller.establishCandidateGenerator()
+}
+
+// stopEstablishing signals the establish goroutines to stop and waits
+// for the group to halt. pendingConns is used to interrupt any worker
+// blocked on a socket connect.
+func (controller *Controller) stopEstablishing() {
+	if !controller.isEstablishing {
+		return
+	}
+	Notice(NOTICE_INFO, "stop establishing")
+	// Note: on Windows, interruptibleTCPClose doesn't really interrupt socket connects
+	// and may leave goroutines running for a time after the Wait call.
+	controller.pendingConns.CloseAll()
+	close(controller.stopEstablishingBroadcast)
+	// Note: establishCandidateGenerator closes controller.candidateServerEntries
+	// (as it may be sending to that channel).
+	controller.establishWaitGroup.Wait()
+
+	controller.isEstablishing = false
+	controller.establishWaitGroup = nil
+	controller.stopEstablishingBroadcast = nil
+	controller.candidateServerEntries = nil
+}
+
+// establishCandidateGenerator populates the candidate queue with server entries
+// from the data store. Server entries are iterated in rank order, so that promoted
+// servers with higher rank are priority candidates.
+func (controller *Controller) establishCandidateGenerator() {
+	defer controller.establishWaitGroup.Done()
+loop:
+	for {
+		// Note: it's possible that an active tunnel in excludeServerEntries will
+		// fail during this iteration of server entries and in that case the
+		// cooresponding server will not be retried (within the same iteration).
+		// !TODO! is there also a race that can result in multiple tunnels to the same server
+		excludeServerEntries := controller.getActiveTunnelServerEntries()
+		iterator, err := NewServerEntryIterator(
+			controller.config.EgressRegion, controller.config.TunnelProtocol, excludeServerEntries)
+		if err != nil {
+			Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
+			controller.SignalFailure()
+			break loop
+		}
+		for {
+			serverEntry, err := iterator.Next()
+			if err != nil {
+				Notice(NOTICE_ALERT, "failed to get next candidate: %s", err)
+				controller.SignalFailure()
+				break loop
+			}
+			if serverEntry == nil {
+				// Completed this iteration
+				break
+			}
+			select {
+			case controller.candidateServerEntries <- serverEntry:
+			case <-controller.stopEstablishingBroadcast:
+				break loop
+			case <-controller.shutdownBroadcast:
+				break loop
+			}
+		}
+		iterator.Close()
+		// After a complete iteration of candidate servers, pause before iterating again.
+		// This helps avoid some busy wait loop conditions, and also allows some time for
+		// network conditions to change.
+		timeout := time.After(ESTABLISH_TUNNEL_PAUSE_PERIOD)
+		select {
+		case <-timeout:
+			// Retry iterating
+		case <-controller.stopEstablishingBroadcast:
+			break loop
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+	close(controller.candidateServerEntries)
+	Notice(NOTICE_INFO, "stopped candidate generator")
+}
+
+// establishTunnelWorker pulls candidates from the candidate queue, establishes
+// a connection to the tunnel server, and delivers the established tunnel to a channel.
+func (controller *Controller) establishTunnelWorker() {
+	defer controller.establishWaitGroup.Done()
+	for serverEntry := range controller.candidateServerEntries {
+		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
+		// select, since we want to prioritize receiving the stop signal
+		select {
+		case <-controller.stopEstablishingBroadcast:
+			return
+		default:
+		}
+		tunnel, err := EstablishTunnel(controller, serverEntry)
+		if err != nil {
+			// TODO: distingush case where conn is interrupted?
+			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
+		} else {
+			// Don't block. Assumes the receiver has a buffer large enough for
+			// the number of desired tunnels. If there's no room, the tunnel must
+			// not be required so it's discarded.
+			select {
+			case controller.establishedTunnels <- tunnel:
+			default:
+				controller.discardTunnel(tunnel)
+			}
+		}
+	}
+	Notice(NOTICE_INFO, "stopped establish worker")
+}
+
+// RunForever executes the main loop of the Psiphon client. It launches
+// the controller with a shutdown that it never signaled.
+func RunForever(config *Config) {
+
+	if config.LogFilename != "" {
+		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+		if err != nil {
+			Fatal("error opening log file: %s", err)
+		}
+		defer logFile.Close()
+		log.SetOutput(logFile)
+	}
+
+	Notice(NOTICE_VERSION, VERSION)
+
+	controller := NewController(config)
+	shutdownBroadcast := make(chan struct{})
+	controller.Run(shutdownBroadcast)
+}

+ 117 - 61
psiphon/dataStore.go

@@ -24,7 +24,8 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	sqlite3 "github.com/mattn/go-sqlite3"
+	sqlite3 "github.com/Psiphon-Inc/go-sqlite3"
+	"strings"
 	"sync"
 	"time"
 )
@@ -47,6 +48,9 @@ func initDataStore() {
              rank integer not null unique,
              region text not null,
              data blob not null);
+	    create table if not exists serverEntryProtocol
+	        (serverEntryId text not null,
+	         protocol text not null);
         create table if not exists keyValue
             (key text not null,
              value text not null);
@@ -130,7 +134,6 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		if serverEntryExists && !replaceIfExists {
 			return nil
 		}
-		// TODO: also skip updates if replaceIfExists but 'data' has not changed
 		_, err := transaction.Exec(`
             update serverEntry set rank = rank + 1
                 where id = (select id from serverEntry order by rank desc limit 1);
@@ -150,6 +153,20 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		if err != nil {
 			return err
 		}
+		for _, protocol := range SupportedTunnelProtocols {
+			// Note: for meek, the capabilities are FRONTED-MEEK and UNFRONTED-MEEK
+			// and the additonal OSSH service is assumed to be available internally.
+			requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
+			if Contains(serverEntry.Capabilities, requiredCapability) {
+				_, err = transaction.Exec(`
+		            insert or ignore into serverEntryProtocol (serverEntryId, protocol)
+		            values (?, ?);
+		            `, serverEntry.IpAddress, protocol)
+				if err != nil {
+					return err
+				}
+			}
+		}
 		// TODO: post notice after commit
 		if !serverEntryExists {
 			Notice(NOTICE_INFO, "stored server %s", serverEntry.IpAddress)
@@ -176,90 +193,91 @@ func PromoteServerEntry(ipAddress string) error {
 	})
 }
 
-// ServerEntryCycler is used to continuously iterate over
+// ServerEntryIterator is used to iterate over
 // stored server entries in rank order.
-type ServerEntryCycler struct {
+type ServerEntryIterator struct {
 	region      string
+	protocol    string
+	excludeIds  []string
 	transaction *sql.Tx
 	cursor      *sql.Rows
-	isReset     bool
 }
 
-// NewServerEntryCycler creates a new ServerEntryCycler
-func NewServerEntryCycler(region string) (cycler *ServerEntryCycler, err error) {
+// NewServerEntryIterator creates a new NewServerEntryIterator
+func NewServerEntryIterator(
+	region, protocol string,
+	excludeServerEntries []*ServerEntry) (iterator *ServerEntryIterator, err error) {
+
 	initDataStore()
-	cycler = &ServerEntryCycler{region: region}
-	err = cycler.Reset()
+	excludeIds := make([]string, len(excludeServerEntries))
+	for index, serverEntry := range excludeServerEntries {
+		excludeIds[index] = serverEntry.IpAddress
+	}
+	iterator = &ServerEntryIterator{
+		region:     region,
+		protocol:   protocol,
+		excludeIds: excludeIds,
+	}
+	err = iterator.Reset()
 	if err != nil {
 		return nil, err
 	}
-	return cycler, nil
+	return iterator, nil
 }
 
-// Reset a ServerEntryCycler to the start of its cycle. The next
+// Reset a NewServerEntryIterator to the start of its cycle. The next
 // call to Next will return the first server entry.
-func (cycler *ServerEntryCycler) Reset() error {
-	cycler.Close()
+func (iterator *ServerEntryIterator) Reset() error {
+	iterator.Close()
 	transaction, err := singleton.db.Begin()
 	if err != nil {
 		return ContextError(err)
 	}
 	var cursor *sql.Rows
-	if cycler.region == "" {
-		cursor, err = transaction.Query(
-			"select data from serverEntry order by rank desc;")
-	} else {
-		cursor, err = transaction.Query(
-			"select data from serverEntry where region = ? order by rank desc;",
-			cycler.region)
-	}
+	whereClause, whereParams := makeServerEntryWhereClause(
+		iterator.region, iterator.protocol, iterator.excludeIds)
+	query := "select data from serverEntry" + whereClause + " order by rank desc;"
+	cursor, err = transaction.Query(query, whereParams...)
 	if err != nil {
 		transaction.Rollback()
 		return ContextError(err)
 	}
-	cycler.isReset = true
-	cycler.transaction = transaction
-	cycler.cursor = cursor
+	iterator.transaction = transaction
+	iterator.cursor = cursor
 	return nil
 }
 
-// Close cleans up resources associated with a ServerEntryCycler.
-func (cycler *ServerEntryCycler) Close() {
-	if cycler.cursor != nil {
-		cycler.cursor.Close()
+// Close cleans up resources associated with a ServerEntryIterator.
+func (iterator *ServerEntryIterator) Close() {
+	if iterator.cursor != nil {
+		iterator.cursor.Close()
 	}
-	cycler.cursor = nil
-	if cycler.transaction != nil {
-		cycler.transaction.Rollback()
+	iterator.cursor = nil
+	if iterator.transaction != nil {
+		iterator.transaction.Rollback()
 	}
-	cycler.transaction = nil
+	iterator.transaction = nil
 }
 
-// Next returns the next server entry, by rank, for a ServerEntryCycler. When
-// the ServerEntryCycler has worked through all known server entries, Next will
-// call Reset and start over and return the first server entry again.
-func (cycler *ServerEntryCycler) Next() (serverEntry *ServerEntry, err error) {
+// Next returns the next server entry, by rank, for a ServerEntryIterator.
+// Returns nil with no error when there is no next item.
+func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
 	defer func() {
 		if err != nil {
-			cycler.Close()
+			iterator.Close()
 		}
 	}()
-	for !cycler.cursor.Next() {
-		err = cycler.cursor.Err()
-		if err != nil {
-			return nil, ContextError(err)
-		}
-		if cycler.isReset {
-			return nil, ContextError(errors.New("no server entries"))
-		}
-		err = cycler.Reset()
+	if !iterator.cursor.Next() {
+		err = iterator.cursor.Err()
 		if err != nil {
 			return nil, ContextError(err)
 		}
+		// There is no next item
+		return nil, nil
 	}
-	cycler.isReset = false
+
 	var data []byte
-	err = cycler.cursor.Scan(&data)
+	err = iterator.cursor.Scan(&data)
 	if err != nil {
 		return nil, ContextError(err)
 	}
@@ -271,24 +289,62 @@ func (cycler *ServerEntryCycler) Next() (serverEntry *ServerEntry, err error) {
 	return serverEntry, nil
 }
 
+func makeServerEntryWhereClause(
+	region, protocol string, excludeIds []string) (whereClause string, whereParams []interface{}) {
+	whereClause = ""
+	whereParams = make([]interface{}, 0)
+	if region != "" {
+		whereClause += " where region = ?"
+		whereParams = append(whereParams, region)
+	}
+	if protocol != "" {
+		if len(whereClause) > 0 {
+			whereClause += " and"
+		} else {
+			whereClause += " where"
+		}
+		whereClause +=
+			" exists (select 1 from serverEntryProtocol where protocol = ? and serverEntryId = serverEntry.id)"
+		whereParams = append(whereParams, protocol)
+	}
+	if len(excludeIds) > 0 {
+		if len(whereClause) > 0 {
+			whereClause += " and"
+		} else {
+			whereClause += " where"
+		}
+		whereClause += " id in ("
+		for index, id := range excludeIds {
+			if index > 0 {
+				whereClause += ", "
+			}
+			whereClause += "?"
+			whereParams = append(whereParams, id)
+		}
+		whereClause += ")"
+	}
+	return whereClause, whereParams
+}
+
 // HasServerEntries returns true if the data store contains at
-// least one server entry (for the specified region, in not blank).
-func HasServerEntries(region string) bool {
+// least one server entry (for the specified region and/or protocol,
+// when not blank).
+func HasServerEntries(region, protocol string) bool {
 	initDataStore()
-	var err error
 	var count int
+	whereClause, whereParams := makeServerEntryWhereClause(region, protocol, nil)
+	query := "select count(*) from serverEntry" + whereClause
+	err := singleton.db.QueryRow(query, whereParams...).Scan(&count)
+
 	if region == "" {
-		err = singleton.db.QueryRow("select count(*) from serverEntry;").Scan(&count)
-		if err == nil {
-			Notice(NOTICE_INFO, "servers: %d", count)
-		}
-	} else {
-		err = singleton.db.QueryRow(
-			"select count(*) from serverEntry where region = ?;", region).Scan(&count)
-		if err == nil {
-			Notice(NOTICE_INFO, "servers for region %s: %d", region, count)
-		}
+		region = "(any)"
+	}
+	if protocol == "" {
+		protocol = "(any)"
 	}
+	Notice(NOTICE_INFO, "servers for region %s and protocol %s: %d",
+		region, protocol, count)
+
 	return err == nil && count > 0
 }
 

+ 7 - 4
psiphon/defaults.go

@@ -24,19 +24,22 @@ import (
 )
 
 const (
-	VERSION                                  = "0.0.2"
+	VERSION                                  = "0.0.3"
 	DATA_STORE_FILENAME                      = "psiphon.db"
-	FETCH_REMOTE_SERVER_LIST_TIMEOUT         = 5 * time.Second
+	CONNECTION_WORKER_POOL_SIZE              = 10
+	TUNNEL_POOL_SIZE                         = 1
 	TUNNEL_CONNECT_TIMEOUT                   = 15 * time.Second
 	TUNNEL_READ_TIMEOUT                      = 0 * time.Second
 	TUNNEL_WRITE_TIMEOUT                     = 5 * time.Second
 	TUNNEL_SSH_KEEP_ALIVE_PERIOD             = 60 * time.Second
 	ESTABLISH_TUNNEL_TIMEOUT                 = 60 * time.Second
-	CONNECTION_WORKER_POOL_SIZE              = 10
+	ESTABLISH_TUNNEL_PAUSE_PERIOD            = 10 * time.Second
+	PORT_FORWARD_FAILURE_THRESHOLD           = 10
 	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT         = 15 * time.Second
+	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST = 50
+	FETCH_REMOTE_SERVER_LIST_TIMEOUT         = 5 * time.Second
 	FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT   = 5 * time.Second
 	FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT   = 6 * time.Hour
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH     = 16
 	PSIPHON_API_SERVER_TIMEOUT               = 20 * time.Second
-	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST = 50
 )

+ 38 - 37
psiphon/httpProxy.go

@@ -31,39 +31,39 @@ import (
 // HttpProxy is a HTTP server that relays HTTP requests through
 // the tunnel SSH client.
 type HttpProxy struct {
-	tunnel        *Tunnel
-	stoppedSignal chan struct{}
-	listener      net.Listener
-	waitGroup     *sync.WaitGroup
-	httpRelay     *http.Transport
-	openConns     *Conns
+	controller     *Controller
+	listener       net.Listener
+	serveWaitGroup *sync.WaitGroup
+	httpRelay      *http.Transport
+	openConns      *Conns
 }
 
 // NewHttpProxy initializes and runs a new HTTP proxy server.
-func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (proxy *HttpProxy, err error) {
-	listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
+func NewHttpProxy(controller *Controller) (proxy *HttpProxy, err error) {
+	listener, err := net.Listen(
+		"tcp", fmt.Sprintf("127.0.0.1:%d", controller.config.LocalHttpProxyPort))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
-	tunnelledDialer := func(_, targetAddress string) (conn net.Conn, err error) {
+	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		// TODO: connect timeout?
-		return tunnel.sshClient.Dial("tcp", targetAddress)
+		return controller.dialWithTunnel(addr)
 	}
+	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
-		Dial:                  tunnelledDialer,
+		Dial:                  tunneledDialer,
 		MaxIdleConnsPerHost:   HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST,
 		ResponseHeaderTimeout: HTTP_PROXY_ORIGIN_SERVER_TIMEOUT,
 	}
 	proxy = &HttpProxy{
-		tunnel:        tunnel,
-		stoppedSignal: stoppedSignal,
-		listener:      listener,
-		waitGroup:     new(sync.WaitGroup),
-		httpRelay:     transport,
-		openConns:     new(Conns),
-	}
-	proxy.waitGroup.Add(1)
-	go proxy.serveHttpRequests()
+		controller:     controller,
+		listener:       listener,
+		serveWaitGroup: new(sync.WaitGroup),
+		httpRelay:      transport,
+		openConns:      new(Conns),
+	}
+	proxy.serveWaitGroup.Add(1)
+	go proxy.serve()
 	Notice(NOTICE_HTTP_PROXY, "local HTTP proxy running at address %s", proxy.listener.Addr().String())
 	return proxy, nil
 }
@@ -71,7 +71,7 @@ func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (
 // Close terminates the HTTP server.
 func (proxy *HttpProxy) Close() {
 	proxy.listener.Close()
-	proxy.waitGroup.Wait()
+	proxy.serveWaitGroup.Wait()
 	// Close local->proxy persistent connections
 	proxy.openConns.CloseAll()
 	// Close idle proxy->origin persistent connections
@@ -105,7 +105,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 			return
 		}
 		go func() {
-			err := proxy.httpConnectHandler(proxy.tunnel, conn, request.URL.Host)
+			err := proxy.httpConnectHandler(conn, request.URL.Host)
 			if err != nil {
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 			}
@@ -117,12 +117,14 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 		http.Error(responseWriter, "", http.StatusInternalServerError)
 		return
 	}
+
 	// Transform request struct before using as input to relayed request
 	request.Close = false
 	request.RequestURI = ""
 	for _, key := range hopHeaders {
 		request.Header.Del(key)
 	}
+
 	// Relay the HTTP request and get the response
 	response, err := proxy.httpRelay.RoundTrip(request)
 	if err != nil {
@@ -131,6 +133,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 		return
 	}
 	defer response.Body.Close()
+
 	// Relay the remote response headers
 	for _, key := range hopHeaders {
 		response.Header.Del(key)
@@ -143,6 +146,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 			responseWriter.Header().Add(key, value)
 		}
 	}
+
 	// Relay the response code and body
 	responseWriter.WriteHeader(response.StatusCode)
 	_, err = io.Copy(responseWriter, response.Body)
@@ -179,20 +183,20 @@ var hopHeaders = []string{
 	"Upgrade",
 }
 
-func (proxy *HttpProxy) httpConnectHandler(tunnel *Tunnel, localHttpConn net.Conn, target string) (err error) {
-	defer localHttpConn.Close()
-	defer proxy.openConns.Remove(localHttpConn)
-	proxy.openConns.Add(localHttpConn)
-	remoteSshForward, err := tunnel.sshClient.Dial("tcp", target)
+func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (err error) {
+	defer localConn.Close()
+	defer proxy.openConns.Remove(localConn)
+	proxy.openConns.Add(localConn)
+	remoteConn, err := proxy.controller.dialWithTunnel(target)
 	if err != nil {
 		return ContextError(err)
 	}
-	defer remoteSshForward.Close()
-	_, err = localHttpConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+	defer remoteConn.Close()
+	_, err = localConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
 	if err != nil {
 		return ContextError(err)
 	}
-	relayPortForward(localHttpConn, remoteSshForward)
+	Relay(localConn, remoteConn)
 	return nil
 }
 
@@ -213,9 +217,9 @@ func (proxy *HttpProxy) httpConnStateCallback(conn net.Conn, connState http.Conn
 	}
 }
 
-func (proxy *HttpProxy) serveHttpRequests() {
+func (proxy *HttpProxy) serve() {
 	defer proxy.listener.Close()
-	defer proxy.waitGroup.Done()
+	defer proxy.serveWaitGroup.Done()
 	httpServer := &http.Server{
 		Handler:   proxy,
 		ConnState: proxy.httpConnStateCallback,
@@ -223,10 +227,7 @@ func (proxy *HttpProxy) serveHttpRequests() {
 	// Note: will be interrupted by listener.Close() call made by proxy.Close()
 	err := httpServer.Serve(proxy.listener)
 	if err != nil {
-		select {
-		case proxy.stoppedSignal <- *new(struct{}):
-		default:
-		}
+		proxy.controller.SignalFailure()
 		Notice(NOTICE_ALERT, "%s", ContextError(err))
 	}
 	Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")

+ 9 - 1
psiphon/meekConn.go

@@ -91,6 +91,7 @@ type MeekConn struct {
 func DialMeek(
 	serverEntry *ServerEntry, sessionId string,
 	useFronting bool, config *DialConfig) (meek *MeekConn, err error) {
+
 	// Configure transport
 	// Note: MeekConn has its own PendingConns to manage the underlying HTTP transport connections,
 	// which may be interrupted on MeekConn.Close(). This code previously used the establishTunnel
@@ -121,6 +122,7 @@ func DialMeek(
 		host = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
 		dialer = NewTCPDialer(configCopy)
 	}
+
 	// Scheme is always "http". Otherwise http.Transport will try to do another TLS
 	// handshake inside the explicit TLS session (in fronting mode).
 	url := &url.URL{
@@ -132,10 +134,12 @@ func DialMeek(
 	if err != nil {
 		return nil, ContextError(err)
 	}
+	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
 		Dial: dialer,
 		ResponseHeaderTimeout: TUNNEL_WRITE_TIMEOUT,
 	}
+
 	// The main loop of a MeekConn is run in the relay() goroutine.
 	// A MeekConn implements net.Conn concurrency semantics:
 	// "Multiple goroutines may invoke methods on a Conn simultaneously."
@@ -312,7 +316,7 @@ func (meek *MeekConn) replaceSendBuffer(sendBuffer *bytes.Buffer) {
 	}
 }
 
-// relay sends and receives tunnelled traffic (payload). An HTTP request is
+// relay sends and receives tunneled traffic (payload). An HTTP request is
 // triggered when data is in the write queue or at a polling interval.
 // There's a geometric increase, up to a maximum, in the polling interval when
 // no data is exchanged. Only one HTTP request is in flight at a time.
@@ -448,6 +452,7 @@ type meekCookieData struct {
 // In unfronted meek mode, the cookie is visible over the adversary network, so the
 // cookie is encrypted and obfuscated.
 func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie, err error) {
+
 	// Make the JSON data
 	serverAddress := fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
 	cookieData := &meekCookieData{
@@ -459,6 +464,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
 	// Encrypt the JSON data
 	// NaCl box is used for encryption. The peer public key comes from the server entry.
 	// Nonce is always all zeros, and is not sent in the cookie (the server also uses an all-zero nonce).
@@ -481,6 +487,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	encryptedCookie := make([]byte, 32+len(box))
 	copy(encryptedCookie[0:32], ephemeralPublicKey[0:32])
 	copy(encryptedCookie[32:], box)
+
 	// Obfuscate the encrypted data
 	obfuscator, err := NewObfuscator(
 		&ObfuscatorConfig{Keyword: serverEntry.MeekObfuscatedKey, MaxPadding: MEEK_COOKIE_MAX_PADDING})
@@ -491,6 +498,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	seedLen := len(obfuscatedCookie)
 	obfuscatedCookie = append(obfuscatedCookie, encryptedCookie...)
 	obfuscator.ObfuscateClientToServer(obfuscatedCookie[seedLen:])
+
 	// Format the HTTP cookie
 	// The format is <random letter 'A'-'Z'>=<base64 data>, which is intended to match common cookie formats.
 	A := int('A')

+ 24 - 12
psiphon/obfuscatedSshConn.go

@@ -84,7 +84,7 @@ const (
 func NewObfuscatedSshConn(conn net.Conn, obfuscationKeyword string) (*ObfuscatedSshConn, error) {
 	obfuscator, err := NewObfuscator(&ObfuscatorConfig{Keyword: obfuscationKeyword})
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	return &ObfuscatedSshConn{
 		Conn:       conn,
@@ -111,7 +111,7 @@ func (conn *ObfuscatedSshConn) Write(buffer []byte) (n int, err error) {
 	}
 	err = conn.transformAndWrite(buffer)
 	if err != nil {
-		return 0, err
+		return 0, ContextError(err)
 	}
 	// Reports that we wrote all the bytes
 	// (althogh we may have buffered some or all)
@@ -157,6 +157,7 @@ func (conn *ObfuscatedSshConn) Write(buffer []byte) (n int, err error) {
 // packet may need to be buffered due to partial reading.
 func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error) {
 	nextState := conn.readState
+
 	switch conn.readState {
 	case OBFUSCATION_READ_STATE_SERVER_IDENTIFICATION_LINE:
 		if len(conn.readBuffer) == 0 {
@@ -167,7 +168,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 				for len(conn.readBuffer) < SSH_MAX_SERVER_LINE_LENGTH {
 					_, err := io.ReadFull(conn.Conn, oneByte[:])
 					if err != nil {
-						return 0, err
+						return 0, ContextError(err)
 					}
 					conn.obfuscator.ObfuscateServerToClient(oneByte[:])
 					conn.readBuffer = append(conn.readBuffer, oneByte[0])
@@ -177,7 +178,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 					}
 				}
 				if !validLine {
-					return 0, errors.New("ObfuscatedSshConn: invalid server line")
+					return 0, ContextError(errors.New("ObfuscatedSshConn: invalid server line"))
 				}
 				if bytes.HasPrefix(conn.readBuffer, []byte("SSH-")) {
 					break
@@ -187,23 +188,24 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 			}
 		}
 		nextState = OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS
+
 	case OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS:
 		if len(conn.readBuffer) == 0 {
 			prefix := make([]byte, SSH_PACKET_PREFIX_LENGTH)
 			_, err := io.ReadFull(conn.Conn, prefix)
 			if err != nil {
-				return 0, err
+				return 0, ContextError(err)
 			}
 			conn.obfuscator.ObfuscateServerToClient(prefix)
 			packetLength, _, payloadLength, messageLength := getSshPacketPrefix(prefix)
 			if packetLength > SSH_MAX_PACKET_LENGTH {
-				return 0, errors.New("ObfuscatedSshConn: ssh packet length too large")
+				return 0, ContextError(errors.New("ObfuscatedSshConn: ssh packet length too large"))
 			}
 			conn.readBuffer = make([]byte, messageLength)
 			copy(conn.readBuffer, prefix)
 			_, err = io.ReadFull(conn.Conn, conn.readBuffer[len(prefix):])
 			if err != nil {
-				return 0, err
+				return 0, ContextError(err)
 			}
 			conn.obfuscator.ObfuscateServerToClient(conn.readBuffer[len(prefix):])
 			if payloadLength > 0 {
@@ -213,11 +215,14 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 				}
 			}
 		}
+
 	case OBFUSCATION_READ_STATE_FLUSH:
 		nextState = OBFUSCATION_READ_STATE_FINISHED
+
 	case OBFUSCATION_READ_STATE_FINISHED:
 		panic("ObfuscatedSshConn: invalid read state")
 	}
+
 	n = copy(buffer, conn.readBuffer)
 	conn.readBuffer = conn.readBuffer[n:]
 	if len(conn.readBuffer) == 0 {
@@ -258,15 +263,18 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 // (The transformer can do this since only the payload and not the padding of
 // these packets is authenticated in the "exchange hash").
 func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
+
 	if conn.writeState == OBFUSCATION_WRITE_STATE_SEND_CLIENT_SEED_MESSAGE {
 		_, err = conn.Conn.Write(conn.obfuscator.ConsumeSeedMessage())
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 		conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE
 	}
+
 	conn.writeBuffer = append(conn.writeBuffer, buffer...)
 	var messageBuffer []byte
+
 	switch conn.writeState {
 	case OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE:
 		index := bytes.Index(conn.writeBuffer, []byte("\r\n"))
@@ -276,6 +284,7 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 			conn.writeBuffer = conn.writeBuffer[messageLength:]
 			conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS
 		}
+
 	case OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS:
 		for len(conn.writeBuffer) >= SSH_PACKET_PREFIX_LENGTH {
 			packetLength, paddingLength, payloadLength, messageLength := getSshPacketPrefix(conn.writeBuffer)
@@ -297,33 +306,36 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 			if possiblePaddings > 0 {
 				selectedPadding, err := MakeSecureRandomInt(possiblePaddings)
 				if err != nil {
-					return err
+					return ContextError(err)
 				}
 				extraPaddingLength := selectedPadding * SSH_PADDING_MULTIPLE
 				extraPadding, err := MakeSecureRandomBytes(extraPaddingLength)
 				if err != nil {
-					return err
+					return ContextError(err)
 				}
 				setSshPacketPrefix(
 					messageBuffer, packetLength+extraPaddingLength, paddingLength+extraPaddingLength)
 				messageBuffer = append(messageBuffer, extraPadding...)
 			}
 		}
+
 	case OBFUSCATION_WRITE_STATE_FINISHED:
 		panic("ObfuscatedSshConn: invalid write state")
 	}
+
 	if messageBuffer != nil {
 		conn.obfuscator.ObfuscateClientToServer(messageBuffer)
 		_, err := conn.Conn.Write(messageBuffer)
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 	}
+
 	if conn.writeState == OBFUSCATION_WRITE_STATE_FINISHED {
 		// After SSH_MSG_NEWKEYS, any remaining bytes are un-obfuscated
 		_, err := conn.Conn.Write(conn.writeBuffer)
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 		// The buffer memory is no longer used
 		conn.writeBuffer = nil

+ 13 - 13
psiphon/obfuscator.go

@@ -57,23 +57,23 @@ type ObfuscatorConfig struct {
 func NewObfuscator(config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 	seed, err := MakeSecureRandomBytes(OBFUSCATE_SEED_LENGTH)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	clientToServerKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_CLIENT_TO_SERVER_IV))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	serverToClientKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_SERVER_TO_CLIENT_IV))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	clientToServerCipher, err := rc4.NewCipher(clientToServerKey)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	serverToClientCipher, err := rc4.NewCipher(serverToClientKey)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	maxPadding := OBFUSCATE_MAX_PADDING
 	if config.MaxPadding > 0 {
@@ -81,7 +81,7 @@ func NewObfuscator(config *ObfuscatorConfig) (obfuscator *Obfuscator, err error)
 	}
 	seedMessage, err := makeSeedMessage(maxPadding, seed, clientToServerCipher)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	return &Obfuscator{
 		seedMessage:          seedMessage,
@@ -119,7 +119,7 @@ func deriveKey(seed, keyword, iv []byte) ([]byte, error) {
 		digest = h.Sum(nil)
 	}
 	if len(digest) < OBFUSCATE_KEY_LENGTH {
-		return nil, errors.New("insufficient bytes for obfuscation key")
+		return nil, ContextError(errors.New("insufficient bytes for obfuscation key"))
 	}
 	return digest[0:OBFUSCATE_KEY_LENGTH], nil
 }
@@ -127,28 +127,28 @@ func deriveKey(seed, keyword, iv []byte) ([]byte, error) {
 func makeSeedMessage(maxPadding int, seed []byte, clientToServerCipher *rc4.Cipher) ([]byte, error) {
 	paddingLength, err := MakeSecureRandomInt(maxPadding)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	padding, err := MakeSecureRandomBytes(paddingLength)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	buffer := new(bytes.Buffer)
 	err = binary.Write(buffer, binary.BigEndian, seed)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(OBFUSCATE_MAGIC_VALUE))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(paddingLength))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	err = binary.Write(buffer, binary.BigEndian, padding)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	seedMessage := buffer.Bytes()
 	clientToServerCipher.XORKeyStream(seedMessage[len(seed):], seedMessage[len(seed):])

+ 18 - 2
psiphon/remoteServerList.go

@@ -45,20 +45,35 @@ type RemoteServerList struct {
 // config.RemoteServerListUrl; validates its digital signature using the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // data field into ServerEntry records.
-func FetchRemoteServerList(config *Config) (err error) {
+func FetchRemoteServerList(config *Config, pendingConns *Conns) (err error) {
 	Notice(NOTICE_INFO, "fetching remote server list")
+
+	// Note: pendingConns may be used to interrupt the fetch remote server list
+	// request. BindToDevice may be used to exclude requests from VPN routing.
+	dialConfig := &DialConfig{
+		PendingConns:               pendingConns,
+		BindToDeviceServiceAddress: config.BindToDeviceServiceAddress,
+		BindToDeviceDnsServer:      config.BindToDeviceDnsServer,
+	}
+	transport := &http.Transport{
+		Dial: NewTCPDialer(dialConfig),
+	}
 	httpClient := http.Client{
-		Timeout: FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Timeout:   FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Transport: transport,
 	}
+
 	response, err := httpClient.Get(config.RemoteServerListUrl)
 	if err != nil {
 		return ContextError(err)
 	}
 	defer response.Body.Close()
+
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		return ContextError(err)
 	}
+
 	var remoteServerList *RemoteServerList
 	err = json.Unmarshal(body, &remoteServerList)
 	if err != nil {
@@ -68,6 +83,7 @@ func FetchRemoteServerList(config *Config) (err error) {
 	if err != nil {
 		return ContextError(err)
 	}
+
 	for _, encodedServerEntry := range strings.Split(remoteServerList.Data, "\n") {
 		serverEntry, err := DecodeServerEntry(encodedServerEntry)
 		if err != nil {

+ 0 - 227
psiphon/runTunnel.go

@@ -1,227 +0,0 @@
-/*
- * Copyright (c) 2014, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-// Package psiphon implements the core tunnel functionality of a Psiphon client.
-// The main interface is RunTunnelForever, which obtains lists of servers,
-// establishes tunnel connections, and runs local proxies through which
-// tunnelled traffic may be sent.
-package psiphon
-
-import (
-	"errors"
-	"fmt"
-	"log"
-	"os"
-	"sync"
-	"time"
-)
-
-// establishTunnelWorker pulls candidates from the potential tunnel queue, establishes
-// a connection to the tunnel server, and delivers the established tunnel to a channel,
-// if there's not already an established tunnel. This function is to be used in a pool
-// of goroutines.
-func establishTunnelWorker(
-	config *Config,
-	sessionId string,
-	workerWaitGroup *sync.WaitGroup,
-	candidateServerEntries chan *ServerEntry,
-	broadcastStopWorkers chan struct{},
-	pendingConns *Conns,
-	establishedTunnels chan *Tunnel) {
-
-	defer workerWaitGroup.Done()
-	for serverEntry := range candidateServerEntries {
-		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
-		// select, since we want to prioritize receiving the stop signal
-		select {
-		case <-broadcastStopWorkers:
-			return
-		default:
-		}
-		tunnel, err := EstablishTunnel(config, sessionId, serverEntry, pendingConns)
-		if err != nil {
-			// TODO: distingush case where conn is interrupted?
-			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
-		} else {
-			Notice(NOTICE_INFO, "successfully connected to %s", serverEntry.IpAddress)
-			select {
-			case establishedTunnels <- tunnel:
-			default:
-				discardTunnel(tunnel)
-			}
-		}
-	}
-}
-
-// discardTunnel is used to dispose of a successful connection that is
-// no longer required (another tunnel has already been selected). Since
-// the connection was successful, the server entry is still promoted.
-func discardTunnel(tunnel *Tunnel) {
-	Notice(NOTICE_INFO, "discard connection to %s", tunnel.serverEntry.IpAddress)
-	PromoteServerEntry(tunnel.serverEntry.IpAddress)
-	tunnel.Close()
-}
-
-// establishTunnel coordinates a worker pool of goroutines to attempt several
-// tunnel connections in parallel, and this process is stopped once the first
-// tunnel is established.
-func establishTunnel(config *Config, sessionId string) (tunnel *Tunnel, err error) {
-	workerWaitGroup := new(sync.WaitGroup)
-	candidateServerEntries := make(chan *ServerEntry)
-	pendingConns := new(Conns)
-	establishedTunnels := make(chan *Tunnel, 1)
-	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
-	broadcastStopWorkers := make(chan struct{})
-	for i := 0; i < config.ConnectionWorkerPoolSize; i++ {
-		workerWaitGroup.Add(1)
-		go establishTunnelWorker(
-			config, sessionId,
-			workerWaitGroup, candidateServerEntries, broadcastStopWorkers,
-			pendingConns, establishedTunnels)
-	}
-	// TODO: add a throttle after each full cycle?
-	// Note: errors fall through to ensure worker and channel cleanup (is started, at least)
-	var selectedTunnel *Tunnel
-	cycler, err := NewServerEntryCycler(config.EgressRegion)
-	for selectedTunnel == nil && err == nil {
-		var serverEntry *ServerEntry
-		// Note: don't mask err here, we want to reference it after the loop
-		serverEntry, err = cycler.Next()
-		if err != nil {
-			break
-		}
-		select {
-		case candidateServerEntries <- serverEntry:
-		case selectedTunnel = <-establishedTunnels:
-			Notice(NOTICE_INFO, "selected connection to %s", selectedTunnel.serverEntry.IpAddress)
-		case <-timeout:
-			err = errors.New("timeout establishing tunnel")
-		}
-	}
-	if cycler != nil {
-		cycler.Close()
-	}
-	close(candidateServerEntries)
-	close(broadcastStopWorkers)
-	// Clean up is now asynchronous since Windows doesn't support interruptible connections
-	go func() {
-		// Interrupt any partial connections in progress, so that
-		// the worker will terminate immediately
-		pendingConns.CloseAll()
-		workerWaitGroup.Wait()
-		// Drain any excess tunnels
-		close(establishedTunnels)
-		for tunnel := range establishedTunnels {
-			discardTunnel(tunnel)
-		}
-		// Note: only call this PromoteServerEntry after all discards so the selected
-		// tunnel is the top ranked
-		if selectedTunnel != nil {
-			PromoteServerEntry(selectedTunnel.serverEntry.IpAddress)
-		}
-	}()
-	// Note: end of error fall through
-	if err != nil {
-		return nil, ContextError(err)
-	}
-	return selectedTunnel, nil
-}
-
-// runTunnel establishes a tunnel session and runs local proxies that make use of
-// that tunnel. The tunnel connection is monitored and this function returns an
-// error when the tunnel unexpectedly disconnects.
-func runTunnel(config *Config) error {
-	Notice(NOTICE_INFO, "establishing tunnel")
-	sessionId, err := MakeSessionId()
-	if err != nil {
-		return ContextError(err)
-	}
-	tunnel, err := establishTunnel(config, sessionId)
-	if err != nil {
-		return ContextError(err)
-	}
-	defer tunnel.Close()
-	// Tunnel connection and local proxies will send signals to this channel
-	// when they close or stop. Signal senders should not block. Allows at
-	// least one stop signal to be sent before there is a receiver.
-	stopTunnelSignal := make(chan struct{}, 1)
-	err = tunnel.conn.SetClosedSignal(stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("failed to set closed signal: %s", err)
-	}
-	socksProxy, err := NewSocksProxy(config.LocalSocksProxyPort, tunnel, stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
-	}
-	defer socksProxy.Close()
-	httpProxy, err := NewHttpProxy(config.LocalHttpProxyPort, tunnel, stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("error initializing local HTTP proxy: %s", err)
-	}
-	defer httpProxy.Close()
-	Notice(NOTICE_INFO, "starting session")
-	localHttpProxyAddress := httpProxy.listener.Addr().String()
-	_, err = NewSession(config, tunnel, localHttpProxyAddress, sessionId)
-	if err != nil {
-		return fmt.Errorf("error starting session: %s", err)
-	}
-	Notice(NOTICE_TUNNEL, "tunnel started")
-	Notice(NOTICE_INFO, "monitoring tunnel")
-	<-stopTunnelSignal
-	Notice(NOTICE_TUNNEL, "tunnel stopped")
-	return nil
-}
-
-// RunTunnelForever executes the main loop of the Psiphon client. It establishes
-// a tunnel and reconnects when the tunnel unexpectedly disconnects.
-// FetchRemoteServerList is used to obtain a fresh list of servers to attempt
-// to connect to.
-func RunTunnelForever(config *Config) {
-	if config.LogFilename != "" {
-		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
-		if err != nil {
-			Fatal("error opening log file: %s", err)
-		}
-		defer logFile.Close()
-		log.SetOutput(logFile)
-	}
-	Notice(NOTICE_VERSION, VERSION)
-	// TODO: unlike existing Psiphon clients, this code
-	// always makes the fetch remote server list request
-	go func() {
-		for {
-			err := FetchRemoteServerList(config)
-			if err != nil {
-				Notice(NOTICE_ALERT, "failed to fetch remote server list: %s", err)
-				time.Sleep(FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT)
-			} else {
-				time.Sleep(FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT)
-			}
-		}
-	}()
-	for {
-		if HasServerEntries(config.EgressRegion) {
-			err := runTunnel(config)
-			if err != nil {
-				Notice(NOTICE_ALERT, "run tunnel error: %s", err)
-			}
-		}
-		time.Sleep(1 * time.Second)
-	}
-}

+ 17 - 33
psiphon/serverApi.go

@@ -25,6 +25,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"strconv"
 )
@@ -32,12 +33,10 @@ import (
 // Session is a utility struct which holds all of the data associated
 // with a Psiphon session. In addition to the established tunnel, this
 // includes the session ID (used for Psiphon API requests) and a http
-// client configured to make tunnelled Psiphon API requests.
+// client configured to make tunneled Psiphon API requests.
 type Session struct {
-	sessionId          string
 	config             *Config
 	tunnel             *Tunnel
-	pendingConns       *Conns
 	psiphonHttpsClient *http.Client
 }
 
@@ -45,21 +44,15 @@ type Session struct {
 // Psiphon server and returns a Session struct, initialized with the
 // session ID, for use with subsequent Psiphon server API requests (e.g.,
 // periodic status requests).
-func NewSession(
-	config *Config,
-	tunnel *Tunnel,
-	localHttpProxyAddress, sessionId string) (session *Session, err error) {
+func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 
-	pendingConns := new(Conns)
-	psiphonHttpsClient, err := makePsiphonHttpsClient(tunnel, pendingConns, localHttpProxyAddress)
+	psiphonHttpsClient, err := makePsiphonHttpsClient(tunnel)
 	if err != nil {
 		return nil, ContextError(err)
 	}
 	session = &Session{
-		sessionId:          sessionId,
 		config:             config,
 		tunnel:             tunnel,
-		pendingConns:       pendingConns,
 		psiphonHttpsClient: psiphonHttpsClient,
 	}
 	// Sending two seperate requests is a legacy from when the handshake was
@@ -174,7 +167,7 @@ func (session *Session) doConnectedRequest() error {
 	}
 	url := session.buildRequestUrl(
 		"connected",
-		&ExtraParam{"session_id", session.sessionId},
+		&ExtraParam{"session_id", session.tunnel.sessionId},
 		&ExtraParam{"last_connected", lastConnected})
 	responseBody, err := session.doGetRequest(url)
 	if err != nil {
@@ -210,7 +203,7 @@ func (session *Session) buildRequestUrl(path string, extraParams ...*ExtraParam)
 	requestUrl.WriteString("/")
 	requestUrl.WriteString(path)
 	requestUrl.WriteString("?client_session_id=")
-	requestUrl.WriteString(session.sessionId)
+	requestUrl.WriteString(session.tunnel.sessionId)
 	requestUrl.WriteString("&server_secret=")
 	requestUrl.WriteString(session.tunnel.serverEntry.WebServerSecret)
 	requestUrl.WriteString("&propagation_channel_id=")
@@ -253,36 +246,24 @@ func (session *Session) doGetRequest(requestUrl string) (responseBody []byte, er
 	return body, nil
 }
 
-// makeHttpsClient creates a Psiphon HTTPS client that uses the local http proxy to tunnel
-// requests and which validates the web server using the Psiphon server entry web server certificate.
+// makeHttpsClient creates a Psiphon HTTPS client that tunnels requests and which validates
+// the web server using the Psiphon server entry web server certificate.
 // This is not a general purpose HTTPS client.
 // As the custom dialer makes an explicit TLS connection, URLs submitted to the returned
 // http.Client should use the "http://" scheme. Otherwise http.Transport will try to do another TLS
 // handshake inside the explicit TLS session.
-func makePsiphonHttpsClient(
-	tunnel *Tunnel, pendingConns *Conns,
-	localHttpProxyAddress string) (httpsClient *http.Client, err error) {
-
+func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error) {
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	if err != nil {
 		return nil, ContextError(err)
 	}
-	// Note: This use of readTimeout will tear down persistent HTTP connections, which is not the
-	// intended purpose. The readTimeout is to abort NewSession when the Psiphon server responds to
-	// handshake/connected requests but fails to deliver the response body (e.g., ResponseHeaderTimeout
-	// is not sufficient to timeout this case).
-	tcpDialer := NewTCPDialer(
-		&DialConfig{
-			ConnectTimeout: PSIPHON_API_SERVER_TIMEOUT,
-			ReadTimeout:    PSIPHON_API_SERVER_TIMEOUT,
-			WriteTimeout:   PSIPHON_API_SERVER_TIMEOUT,
-			PendingConns:   pendingConns,
-		})
+	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
+		return tunnel.sshClient.Dial("tcp", addr)
+	}
 	dialer := NewCustomTLSDialer(
 		&CustomTLSConfig{
-			Dial:                    tcpDialer,
+			Dial:                    tunneledDialer,
 			Timeout:                 PSIPHON_API_SERVER_TIMEOUT,
-			HttpProxyAddress:        localHttpProxyAddress,
 			SendServerName:          false,
 			VerifyLegacyCertificate: certificate,
 		})
@@ -290,5 +271,8 @@ func makePsiphonHttpsClient(
 		Dial: dialer,
 		ResponseHeaderTimeout: PSIPHON_API_SERVER_TIMEOUT,
 	}
-	return &http.Client{Transport: transport}, nil
+	return &http.Client{
+		Transport: transport,
+		Timeout:   PSIPHON_API_SERVER_TIMEOUT,
+	}, nil
 }

+ 27 - 52
psiphon/socksProxy.go

@@ -22,7 +22,6 @@ package psiphon
 import (
 	"fmt"
 	socks "github.com/Psiphon-Inc/goptlib"
-	"io"
 	"net"
 	"sync"
 )
@@ -32,30 +31,29 @@ import (
 // the tunnel SSH client and relays traffic through the port
 // forward.
 type SocksProxy struct {
-	tunnel        *Tunnel
-	stoppedSignal chan struct{}
-	listener      *socks.SocksListener
-	waitGroup     *sync.WaitGroup
-	openConns     *Conns
+	controller     *Controller
+	listener       *socks.SocksListener
+	serveWaitGroup *sync.WaitGroup
+	openConns      *Conns
 }
 
 // NewSocksProxy initializes a new SOCKS server. It begins listening for
 // connections, starts a goroutine that runs an accept loop, and returns
 // leaving the accept loop running.
-func NewSocksProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (proxy *SocksProxy, err error) {
-	listener, err := socks.ListenSocks("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
+func NewSocksProxy(controller *Controller) (proxy *SocksProxy, err error) {
+	listener, err := socks.ListenSocks(
+		"tcp", fmt.Sprintf("127.0.0.1:%d", controller.config.LocalSocksProxyPort))
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	proxy = &SocksProxy{
-		tunnel:        tunnel,
-		stoppedSignal: stoppedSignal,
-		listener:      listener,
-		waitGroup:     new(sync.WaitGroup),
-		openConns:     new(Conns),
+		controller:     controller,
+		listener:       listener,
+		serveWaitGroup: new(sync.WaitGroup),
+		openConns:      new(Conns),
 	}
-	proxy.waitGroup.Add(1)
-	go proxy.acceptSocksConnections()
+	proxy.serveWaitGroup.Add(1)
+	go proxy.serve()
 	Notice(NOTICE_SOCKS_PROXY, "local SOCKS proxy running at address %s", proxy.listener.Addr().String())
 	return proxy, nil
 }
@@ -64,60 +62,37 @@ func NewSocksProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{})
 // goroutine to complete.
 func (proxy *SocksProxy) Close() {
 	proxy.listener.Close()
-	proxy.waitGroup.Wait()
+	proxy.serveWaitGroup.Wait()
 	proxy.openConns.CloseAll()
 }
 
-func (proxy *SocksProxy) socksConnectionHandler(tunnel *Tunnel, localSocksConn *socks.SocksConn) (err error) {
-	defer localSocksConn.Close()
-	defer proxy.openConns.Remove(localSocksConn)
-	proxy.openConns.Add(localSocksConn)
-	remoteSshForward, err := tunnel.sshClient.Dial("tcp", localSocksConn.Req.Target)
+func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err error) {
+	defer localConn.Close()
+	defer proxy.openConns.Remove(localConn)
+	proxy.openConns.Add(localConn)
+	remoteConn, err := proxy.controller.dialWithTunnel(localConn.Req.Target)
 	if err != nil {
 		return ContextError(err)
 	}
-	defer remoteSshForward.Close()
-	err = localSocksConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
+	defer remoteConn.Close()
+	err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
 	if err != nil {
 		return ContextError(err)
 	}
-	relayPortForward(localSocksConn, remoteSshForward)
+	Relay(localConn, remoteConn)
 	return nil
 }
 
-// relayPortForward is also used by HttpProxy
-func relayPortForward(local, remote net.Conn) {
-	// TODO: page view stats would be done here
-	// TODO: interrupt and stop on proxy.Close()
-	waitGroup := new(sync.WaitGroup)
-	waitGroup.Add(1)
-	go func() {
-		defer waitGroup.Done()
-		_, err := io.Copy(local, remote)
-		if err != nil {
-			Notice(NOTICE_ALERT, "%s", ContextError(err))
-		}
-	}()
-	_, err := io.Copy(remote, local)
-	if err != nil {
-		Notice(NOTICE_ALERT, "%s", ContextError(err))
-	}
-	waitGroup.Wait()
-}
-
-func (proxy *SocksProxy) acceptSocksConnections() {
+func (proxy *SocksProxy) serve() {
 	defer proxy.listener.Close()
-	defer proxy.waitGroup.Done()
+	defer proxy.serveWaitGroup.Done()
 	for {
 		// Note: will be interrupted by listener.Close() call made by proxy.Close()
 		socksConnection, err := proxy.listener.AcceptSocks()
 		if err != nil {
 			Notice(NOTICE_ALERT, "SOCKS proxy accept error: %s", err)
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
-				select {
-				case proxy.stoppedSignal <- *new(struct{}):
-				default:
-				}
+				proxy.controller.SignalFailure()
 				// Fatal error, stop the proxy
 				break
 			}
@@ -125,7 +100,7 @@ func (proxy *SocksProxy) acceptSocksConnections() {
 			continue
 		}
 		go func() {
-			err := proxy.socksConnectionHandler(proxy.tunnel, socksConnection)
+			err := proxy.socksConnectionHandler(socksConnection)
 			if err != nil {
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 			}

+ 25 - 53
psiphon/tlsDialer.go

@@ -75,10 +75,7 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
-	"fmt"
-	"io"
 	"net"
-	"strings"
 	"time"
 )
 
@@ -91,25 +88,28 @@ func (timeoutError) Temporary() bool { return true }
 // CustomTLSConfig contains parameters to determine the behavior
 // of CustomTLSDial.
 type CustomTLSConfig struct {
+
 	// Dial is the network connection dialer. TLS is layered on
 	// top of a new network connection created with dialer.
 	Dial Dialer
+
 	// Timeout is and optional timeout for combined network
 	// connection dial and TLS handshake.
 	Timeout time.Duration
+
 	// FrontingAddr overrides the "addr" input to Dial when specified
 	FrontingAddr string
-	// HttpProxyAddress specifies an HTTP proxy to be used
-	// (with HTTP CONNECT).
-	HttpProxyAddress string
+
 	// SendServerName specifies whether to use SNI
 	// (tlsdialer functionality)
 	SendServerName bool
+
 	// VerifyLegacyCertificate is a special case self-signed server
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate chain. Just checks that the server presented the
 	// specified certificate.
 	VerifyLegacyCertificate *x509.Certificate
+
 	// TlsConfig is a tls.Config to use in the
 	// non-verifyLegacyCertificate case.
 	TlsConfig *tls.Config
@@ -141,45 +141,36 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 	}
 
 	dialAddr := addr
-	if config.HttpProxyAddress != "" {
-		dialAddr = config.HttpProxyAddress
-	} else if config.FrontingAddr != "" {
+	if config.FrontingAddr != "" {
 		dialAddr = config.FrontingAddr
 	}
 
 	rawConn, err := config.Dial(network, dialAddr)
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 
-	targetAddr := addr
-	if config.FrontingAddr != "" {
-		targetAddr = config.FrontingAddr
-	}
-
-	colonPos := strings.LastIndex(targetAddr, ":")
-	if colonPos == -1 {
-		colonPos = len(targetAddr)
+	hostname, _, err := net.SplitHostPort(dialAddr)
+	if err != nil {
+		return nil, ContextError(err)
 	}
-	hostname := targetAddr[:colonPos]
 
 	tlsConfig := config.TlsConfig
 	if tlsConfig == nil {
 		tlsConfig = &tls.Config{}
 	}
 
-	serverName := tlsConfig.ServerName
+	// Copy config so we can tweak it
+	tlsConfigCopy := new(tls.Config)
+	*tlsConfigCopy = *tlsConfig
 
+	serverName := tlsConfig.ServerName
 	// If no ServerName is set, infer the ServerName
 	// from the hostname we're connecting to.
 	if serverName == "" {
 		serverName = hostname
 	}
 
-	// Copy config so we can tweak it
-	tlsConfigCopy := new(tls.Config)
-	*tlsConfigCopy = *tlsConfig
-
 	if config.SendServerName {
 		// Set the ServerName and rely on the usual logic in
 		// tls.Conn.Handshake() to do its verification
@@ -192,34 +183,11 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 
 	conn := tls.Client(rawConn, tlsConfigCopy)
 
-	establishConnection := func(rawConn net.Conn, conn *tls.Conn) error {
-		// TODO: use the proxy request/response code from net/http/transport.go
-		if config.HttpProxyAddress != "" {
-			connectRequest := fmt.Sprintf(
-				"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
-				targetAddr, hostname)
-			_, err := rawConn.Write([]byte(connectRequest))
-			if err != nil {
-				return err
-			}
-			expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
-			readBuffer := make([]byte, len(expectedResponse))
-			_, err = io.ReadFull(rawConn, readBuffer)
-			if err != nil {
-				return err
-			}
-			if !bytes.Equal(readBuffer, expectedResponse) {
-				return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
-			}
-		}
-		return conn.Handshake()
-	}
-
 	if config.Timeout == 0 {
-		err = establishConnection(rawConn, conn)
+		err = conn.Handshake()
 	} else {
 		go func() {
-			errChannel <- establishConnection(rawConn, conn)
+			errChannel <- conn.Handshake()
 		}()
 		err = <-errChannel
 	}
@@ -233,7 +201,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 
 	if err != nil {
 		rawConn.Close()
-		return nil, err
+		return nil, ContextError(err)
 	}
 
 	return conn, nil
@@ -242,10 +210,10 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 func verifyLegacyCertificate(conn *tls.Conn, expectedCertificate *x509.Certificate) error {
 	certs := conn.ConnectionState().PeerCertificates
 	if len(certs) < 1 {
-		return errors.New("no certificate to verify")
+		return ContextError(errors.New("no certificate to verify"))
 	}
 	if !bytes.Equal(certs[0].Raw, expectedCertificate.Raw) {
-		return errors.New("unexpected certificate")
+		return ContextError(errors.New("unexpected certificate"))
 	}
 	return nil
 }
@@ -266,6 +234,10 @@ func verifyServerCerts(conn *tls.Conn, serverName string, config *tls.Config) er
 		}
 		opts.Intermediates.AddCert(cert)
 	}
+
 	_, err := certs[0].Verify(opts)
-	return err
+	if err != nil {
+		return ContextError(err)
+	}
+	return nil
 }

+ 45 - 17
psiphon/tunnel.go

@@ -23,6 +23,7 @@ import (
 	"bytes"
 	"code.google.com/p/go.crypto/ssh"
 	"encoding/base64"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"net"
@@ -49,11 +50,14 @@ var SupportedTunnelProtocols = []string{
 // tunnel includes a network connection to the specified server
 // and an SSH session built on top of that transport.
 type Tunnel struct {
-	serverEntry      *ServerEntry
-	protocol         string
-	conn             Conn
-	sshClient        *ssh.Client
-	sshKeepAliveQuit chan struct{}
+	serverEntry             *ServerEntry
+	sessionId               string
+	protocol                string
+	conn                    Conn
+	sshClient               *ssh.Client
+	sshKeepAliveQuit        chan struct{}
+	portForwardFailures     chan int
+	portForwardFailureTotal int
 }
 
 // Close terminates the tunnel.
@@ -77,20 +81,18 @@ func (tunnel *Tunnel) Close() {
 // the first protocol in SupportedTunnelProtocols that's also in the
 // server capabilities is used.
 func EstablishTunnel(
-	config *Config,
-	sessionId string,
-	serverEntry *ServerEntry,
-	pendingConns *Conns) (tunnel *Tunnel, err error) {
+	controller *Controller, serverEntry *ServerEntry) (tunnel *Tunnel, err error) {
+
 	// Select the protocol
 	var selectedProtocol string
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
-	if config.TunnelProtocol != "" {
-		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
+	if controller.config.TunnelProtocol != "" {
+		requiredCapability := strings.TrimSuffix(controller.config.TunnelProtocol, "-OSSH")
 		if !Contains(serverEntry.Capabilities, requiredCapability) {
 			return nil, ContextError(fmt.Errorf("server does not have required capability"))
 		}
-		selectedProtocol = config.TunnelProtocol
+		selectedProtocol = controller.config.TunnelProtocol
 	} else {
 		// Order of SupportedTunnelProtocols is default preference order
 		for _, protocol := range SupportedTunnelProtocols {
@@ -106,6 +108,7 @@ func EstablishTunnel(
 	}
 	Notice(NOTICE_INFO, "connecting to %s in region %s using %s",
 		serverEntry.IpAddress, serverEntry.Region, selectedProtocol)
+
 	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
 	// So depending on which protocol is used, multiple layers are initialized.
 	port := 0
@@ -127,14 +130,23 @@ func EstablishTunnel(
 	case TUNNEL_PROTOCOL_SSH:
 		port = serverEntry.SshPort
 	}
+
+	// Generate a session Id for the Psiphon server API. This is generated now so
+	// that it can be sent with the SSH password payload, which helps the server
+	// associate client geo location, used in server API stats, with the session ID.
+	sessionId, err := MakeSessionId()
+	if err != nil {
+		return nil, ContextError(err)
+	}
+
 	// Create the base transport: meek or direct connection
 	dialConfig := &DialConfig{
 		ConnectTimeout:             TUNNEL_CONNECT_TIMEOUT,
 		ReadTimeout:                TUNNEL_READ_TIMEOUT,
 		WriteTimeout:               TUNNEL_WRITE_TIMEOUT,
-		PendingConns:               pendingConns,
-		BindToDeviceServiceAddress: config.BindToDeviceServiceAddress,
-		BindToDeviceDnsServer:      config.BindToDeviceDnsServer,
+		PendingConns:               controller.pendingConns,
+		BindToDeviceServiceAddress: controller.config.BindToDeviceServiceAddress,
+		BindToDeviceDnsServer:      controller.config.BindToDeviceDnsServer,
 	}
 	var conn Conn
 	if useMeek {
@@ -157,6 +169,7 @@ func EstablishTunnel(
 			conn.Close()
 		}
 	}()
+
 	// Add obfuscated SSH layer
 	var sshConn net.Conn
 	sshConn = conn
@@ -166,6 +179,7 @@ func EstablishTunnel(
 			return nil, ContextError(err)
 		}
 	}
+
 	// Now establish the SSH session over the sshConn transport
 	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
@@ -179,10 +193,18 @@ func EstablishTunnel(
 			return nil
 		},
 	}
+	sshPasswordPayload, err := json.Marshal(
+		struct {
+			SessionId   string `json:"SessionId"`
+			SshPassword string `json:"SshPassword"`
+		}{sessionId, serverEntry.SshPassword})
+	if err != nil {
+		return nil, ContextError(err)
+	}
 	sshClientConfig := &ssh.ClientConfig{
 		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
-			ssh.Password(serverEntry.SshPassword),
+			ssh.Password(string(sshPasswordPayload)),
 		},
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 	}
@@ -194,6 +216,7 @@ func EstablishTunnel(
 		return nil, ContextError(err)
 	}
 	sshClient := ssh.NewClient(sshClientConn, sshChans, sshReqs)
+
 	// Run a goroutine to periodically execute SSH keepalive
 	sshKeepAliveQuit := make(chan struct{})
 	sshKeepAliveTicker := time.NewTicker(TUNNEL_SSH_KEEP_ALIVE_PERIOD)
@@ -214,11 +237,16 @@ func EstablishTunnel(
 			}
 		}
 	}()
+
 	return &Tunnel{
 			serverEntry:      serverEntry,
+			sessionId:        sessionId,
 			protocol:         selectedProtocol,
 			conn:             conn,
 			sshClient:        sshClient,
-			sshKeepAliveQuit: sshKeepAliveQuit},
+			sshKeepAliveQuit: sshKeepAliveQuit,
+			// portForwardFailures buffer size is large enough to receive the thresold number
+			// of failure reports without blocking. Senders can drop failures without blocking.
+			portForwardFailures: make(chan int, controller.config.PortForwardFailureThreshold)},
 		nil
 }

+ 1 - 1
psiphonTunnelCore.go → psiphonClient.go

@@ -36,5 +36,5 @@ func main() {
 	if err != nil {
 		log.Fatalf("error loading configuration file: %s", err)
 	}
-	psiphon.RunTunnelForever(config)
+	psiphon.RunForever(config)
 }

+ 0 - 28
psiphonTunnelCore_test.go

@@ -1,28 +0,0 @@
-/*
- * Copyright (c) 2014, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-package main
-
-import (
-	//"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
-	"testing"
-)
-
-func TestPsiphon(t *testing.T) {
-}