Browse Source

Merge branch 'master' of https://github.com/Psiphon-Labs/psiphon-tunnel-core

Eugene Fryntov 11 years ago
parent
commit
6738229f31

+ 13 - 18
README.md

@@ -33,7 +33,10 @@ Setup
         "LocalSocksProxyPort" : 0,
         "LocalSocksProxyPort" : 0,
         "EgressRegion" : "",
         "EgressRegion" : "",
         "TunnelProtocol" : "",
         "TunnelProtocol" : "",
-        "ConnectionWorkerPoolSize" : 10
+        "ConnectionWorkerPoolSize" : 10,
+        "TunnelPoolSize" : 1,
+        "PortForwardFailureThreshold" : 10,
+        "UpstreamHttpProxyAddress" : ""
     }
     }
     ```
     ```
 
 
@@ -43,23 +46,20 @@ Setup
 Roadmap
 Roadmap
 --------------------------------------------------------------------------------
 --------------------------------------------------------------------------------
 
 
-### TODO (proof-of-concept)
+### TODO (short-term)
 
 
+* requirements for integrating with Windows client
+  * split tunnel support
+  * implement page view and bytes transferred stats
+  * resumable download of client upgrades
 * Android app
 * Android app
   * open home pages
   * open home pages
   * Go binary PIE, or use a Go library and JNI
   * Go binary PIE, or use a Go library and JNI
   * settings UI (e.g., region selection)
   * settings UI (e.g., region selection)
-* reconnection busy loop when no network available (ex. close laptop); should wait for network connectivity
 * sometimes fails to promptly detect loss of connection after device sleep
 * sometimes fails to promptly detect loss of connection after device sleep
-* continuity and performance
-  * always-on local proxies
-  * multiplex across simultaneous tunnels
-  * monitor health of tunnels; for example fail-over to new server on "ssh: rejected: administratively prohibited (open failed)" error?
 * PendingConns: is interrupting connection establishment worth the extra code complexity?
 * PendingConns: is interrupting connection establishment worth the extra code complexity?
-* prefilter entries by capability; don't log "server does not have sufficient capabilities"
 * log noise: "use of closed network connection"
 * log noise: "use of closed network connection"
 * log noise(?): 'Unsolicited response received on idle HTTP channel starting with "H"'
 * log noise(?): 'Unsolicited response received on idle HTTP channel starting with "H"'
-* use ContextError in more places
 
 
 ### TODO (future)
 ### TODO (future)
 
 
@@ -70,19 +70,14 @@ Roadmap
   * unfronted meek almost makes this obsolete, since meek sessions survive underlying
   * unfronted meek almost makes this obsolete, since meek sessions survive underlying
      HTTP transport socket disconnects. The client could prefer unfronted meek protocol
      HTTP transport socket disconnects. The client could prefer unfronted meek protocol
      when handshake returns a preemptive_reconnect_lifetime_milliseconds.
      when handshake returns a preemptive_reconnect_lifetime_milliseconds.
-* split tunnel support
-* implement page view stats
+  * could also be accomplished with TunnelPoolSize > 1 and staggaring the establishment times
 * implement local traffic stats (e.g., to display bytes sent/received)
 * implement local traffic stats (e.g., to display bytes sent/received)
-* control interface (w/ event messages)?
-* upstream proxy support
-* support upgrades
-  * download entire client
-  * download core component only
+* more formal control interface (w/ event messages)?
+* support upgrading core only
 * try multiple protocols for each server (currently only tries one protocol per server)
 * try multiple protocols for each server (currently only tries one protocol per server)
 * support a config pushed by the network
 * support a config pushed by the network
   * server can push preferred/optimized settings; client should prefer over defaults
   * server can push preferred/optimized settings; client should prefer over defaults
-  * e.g., etablish worker pool size; multiplex tunnel pool size
-* overlap between httpProxy.go and socksProxy.go: refactor?
+  * e.g., etablish worker pool size; tunnel pool size
 
 
 Licensing
 Licensing
 --------------------------------------------------------------------------------
 --------------------------------------------------------------------------------

+ 8 - 1
psiphon/LookupIP.go

@@ -23,7 +23,7 @@ package psiphon
 
 
 import (
 import (
 	"errors"
 	"errors"
-	dns "github.com/miekg/dns"
+	dns "github.com/Psiphon-Inc/dns"
 	"net"
 	"net"
 	"os"
 	"os"
 	"syscall"
 	"syscall"
@@ -58,11 +58,13 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	// config.BindToDeviceDnsServer must be an IP address
 	// config.BindToDeviceDnsServer must be an IP address
 	ipAddr := net.ParseIP(config.BindToDeviceDnsServer)
 	ipAddr := net.ParseIP(config.BindToDeviceDnsServer)
 	if ipAddr == nil {
 	if ipAddr == nil {
 		return nil, ContextError(errors.New("invalid IP address"))
 		return nil, ContextError(errors.New("invalid IP address"))
 	}
 	}
+
 	// TODO: IPv6 support
 	// TODO: IPv6 support
 	var ip [4]byte
 	var ip [4]byte
 	copy(ip[:], ipAddr.To4())
 	copy(ip[:], ipAddr.To4())
@@ -72,6 +74,7 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	// Convert the syscall socket to a net.Conn, for use in the dns package
 	// Convert the syscall socket to a net.Conn, for use in the dns package
 	file := os.NewFile(uintptr(socketFd), "")
 	file := os.NewFile(uintptr(socketFd), "")
 	defer file.Close()
 	defer file.Close()
@@ -79,9 +82,11 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	// Set DNS query timeouts, using the ConnectTimeout from the overall Dial
 	// Set DNS query timeouts, using the ConnectTimeout from the overall Dial
 	conn.SetReadDeadline(time.Now().Add(config.ConnectTimeout))
 	conn.SetReadDeadline(time.Now().Add(config.ConnectTimeout))
 	conn.SetWriteDeadline(time.Now().Add(config.ConnectTimeout))
 	conn.SetWriteDeadline(time.Now().Add(config.ConnectTimeout))
+
 	// Make the DNS query
 	// Make the DNS query
 	// TODO: make interruptible?
 	// TODO: make interruptible?
 	dnsConn := &dns.Conn{Conn: conn}
 	dnsConn := &dns.Conn{Conn: conn}
@@ -90,6 +95,8 @@ func bindLookupIP(host string, config *DialConfig) (addrs []net.IP, err error) {
 	query.SetQuestion(dns.Fqdn(host), dns.TypeA)
 	query.SetQuestion(dns.Fqdn(host), dns.TypeA)
 	query.RecursionDesired = true
 	query.RecursionDesired = true
 	dnsConn.WriteMsg(query)
 	dnsConn.WriteMsg(query)
+
+	// Process the response
 	response, err := dnsConn.ReadMsg()
 	response, err := dnsConn.ReadMsg()
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)

+ 0 - 1
psiphon/TCPConn.go

@@ -54,7 +54,6 @@ func NewTCPDialer(config *DialConfig) Dialer {
 
 
 // TCPConn creates a new, connected TCPConn.
 // TCPConn creates a new, connected TCPConn.
 func DialTCP(addr string, config *DialConfig) (conn *TCPConn, err error) {
 func DialTCP(addr string, config *DialConfig) (conn *TCPConn, err error) {
-
 	conn, err = interruptibleTCPDial(addr, config)
 	conn, err = interruptibleTCPDial(addr, config)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)

+ 27 - 4
psiphon/TCPConn_unix.go

@@ -39,6 +39,7 @@ type interruptibleTCPSocket struct {
 // syscall APIs are used. The sequence of syscalls in this implementation are
 // syscall APIs are used. The sequence of syscalls in this implementation are
 // taken from: https://code.google.com/p/go/issues/detail?id=6966
 // taken from: https://code.google.com/p/go/issues/detail?id=6966
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
+
 	// Create a socket and then, before connecting, add a TCPConn with
 	// Create a socket and then, before connecting, add a TCPConn with
 	// the unconnected socket to pendingConns. This allows pendingConns to
 	// the unconnected socket to pendingConns. This allows pendingConns to
 	// abort connections in progress.
 	// abort connections in progress.
@@ -52,6 +53,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			syscall.Close(socketFd)
 			syscall.Close(socketFd)
 		}
 		}
 	}()
 	}()
+
 	// Note: this step is not interruptible
 	// Note: this step is not interruptible
 	if config.BindToDeviceServiceAddress != "" {
 	if config.BindToDeviceServiceAddress != "" {
 		err = bindToDevice(socketFd, config)
 		err = bindToDevice(socketFd, config)
@@ -59,8 +61,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			return nil, ContextError(err)
 			return nil, ContextError(err)
 		}
 		}
 	}
 	}
+
+	// When using an upstream HTTP proxy, first connect to the proxy,
+	// then use HTTP CONNECT to connect to the original destination.
+	dialAddr := addr
+	if config.UpstreamHttpProxyAddress != "" {
+		dialAddr = config.UpstreamHttpProxyAddress
+	}
+
 	// Get the remote IP and port, resolving a domain name if necessary
 	// Get the remote IP and port, resolving a domain name if necessary
-	host, strPort, err := net.SplitHostPort(addr)
+	host, strPort, err := net.SplitHostPort(dialAddr)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
@@ -78,12 +88,15 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	// TODO: IPv6 support
 	// TODO: IPv6 support
 	var ip [4]byte
 	var ip [4]byte
 	copy(ip[:], ipAddrs[0].To4())
 	copy(ip[:], ipAddrs[0].To4())
+
 	// Enable interruption
 	// Enable interruption
 	conn = &TCPConn{
 	conn = &TCPConn{
 		interruptible: interruptibleTCPSocket{socketFd: socketFd},
 		interruptible: interruptibleTCPSocket{socketFd: socketFd},
 		readTimeout:   config.ReadTimeout,
 		readTimeout:   config.ReadTimeout,
 		writeTimeout:  config.WriteTimeout}
 		writeTimeout:  config.WriteTimeout}
 	config.PendingConns.Add(conn)
 	config.PendingConns.Add(conn)
+	defer config.PendingConns.Remove(conn)
+
 	// Connect the socket
 	// Connect the socket
 	// TODO: adjust the timeout to account for time spent resolving hostname
 	// TODO: adjust the timeout to account for time spent resolving hostname
 	sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
 	sockAddr := syscall.SockaddrInet4{Addr: ip, Port: port}
@@ -93,16 +106,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			errChannel <- errors.New("connect timeout")
 			errChannel <- errors.New("connect timeout")
 		})
 		})
 		go func() {
 		go func() {
-			errChannel <- syscall.Connect(conn.interruptible.socketFd, &sockAddr)
+			errChannel <- syscall.Connect(socketFd, &sockAddr)
 		}()
 		}()
 		err = <-errChannel
 		err = <-errChannel
 	} else {
 	} else {
-		err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
+		err = syscall.Connect(socketFd, &sockAddr)
 	}
 	}
-	config.PendingConns.Remove(conn)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	// Convert the syscall socket to a net.Conn
 	// Convert the syscall socket to a net.Conn
 	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
 	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
 	defer file.Close()
 	defer file.Close()
@@ -110,6 +123,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
+	// Going through upstream HTTP proxy
+	if config.UpstreamHttpProxyAddress != "" {
+		// This call can be interrupted by closing the pending conn
+		err := HttpProxyConnect(conn, addr)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+	}
+
 	return conn, nil
 	return conn, nil
 }
 }
 
 

+ 52 - 9
psiphon/TCPConn_windows.go

@@ -22,29 +22,72 @@
 package psiphon
 package psiphon
 
 
 import (
 import (
+	"errors"
 	"net"
 	"net"
 )
 )
 
 
+// interruptibleTCPSocket simulates interruptible semantics on Windows. A call
+// to interruptibleTCPClose doesn't actually interrupt a connect in progress,
+// but abandons a dial that's running in a goroutine.
+// Interruptible semantics are required by the controller for timely component
+// state changes.
+// TODO: implement true interruptible semantics on Windows; use syscall and
+// a HANDLE similar to how TCPConn_unix uses a file descriptor?
 type interruptibleTCPSocket struct {
 type interruptibleTCPSocket struct {
+	results chan *interruptibleDialResult
+}
+
+type interruptibleDialResult struct {
+	netConn net.Conn
+	err     error
 }
 }
 
 
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
 func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err error) {
 	if config.BindToDeviceServiceAddress != "" {
 	if config.BindToDeviceServiceAddress != "" {
 		Fatal("psiphon.interruptibleTCPDial with bind not supported on Windows")
 		Fatal("psiphon.interruptibleTCPDial with bind not supported on Windows")
 	}
 	}
-	// Note: using standard net.Dial(); interruptible connections not supported on Windows
-	netConn, err := net.DialTimeout("tcp", addr, config.ConnectTimeout)
-	if err != nil {
-		return nil, ContextError(err)
-	}
+
 	conn = &TCPConn{
 	conn = &TCPConn{
-		Conn:         netConn,
-		readTimeout:  config.ReadTimeout,
-		writeTimeout: config.WriteTimeout}
+		interruptible: interruptibleTCPSocket{results: make(chan *interruptibleDialResult, 2)},
+		readTimeout:   config.ReadTimeout,
+		writeTimeout:  config.WriteTimeout}
+	config.PendingConns.Add(conn)
+
+	// Call the blocking Dial in a goroutine
+	results := conn.interruptible.results
+	go func() {
+
+		// When using an upstream HTTP proxy, first connect to the proxy,
+		// then use HTTP CONNECT to connect to the original destination.
+		dialAddr := addr
+		if config.UpstreamHttpProxyAddress != "" {
+			dialAddr = config.UpstreamHttpProxyAddress
+		}
+
+		netConn, err := net.DialTimeout("tcp", dialAddr, config.ConnectTimeout)
+
+		if config.UpstreamHttpProxyAddress != "" {
+			err := HttpProxyConnect(netConn, addr)
+			if err != nil {
+				netConn = nil
+			}
+		}
+
+		results <- &interruptibleDialResult{netConn, err}
+	}()
+
+	// Block until Dial completes (or times out) or until interrupt
+	result := <-conn.interruptible.results
+	config.PendingConns.Remove(conn)
+	if result.err != nil {
+		return nil, ContextError(result.err)
+	}
+	conn.Conn = result.netConn
+
 	return conn, nil
 	return conn, nil
 }
 }
 
 
 func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
 func interruptibleTCPClose(interruptible interruptibleTCPSocket) error {
-	Fatal("psiphon.interruptibleTCPClose not supported on Windows")
+	interruptible.results <- &interruptibleDialResult{nil, errors.New("socket interrupted")}
 	return nil
 	return nil
 }
 }

+ 23 - 7
psiphon/config.go

@@ -41,6 +41,9 @@ type Config struct {
 	ConnectionWorkerPoolSize           int
 	ConnectionWorkerPoolSize           int
 	BindToDeviceServiceAddress         string
 	BindToDeviceServiceAddress         string
 	BindToDeviceDnsServer              string
 	BindToDeviceDnsServer              string
+	TunnelPoolSize                     int
+	PortForwardFailureThreshold        int
+	UpstreamHttpProxyAddress           string
 }
 }
 
 
 // LoadConfig reads, and parse, and validates a JSON format Psiphon config
 // LoadConfig reads, and parse, and validates a JSON format Psiphon config
@@ -48,31 +51,36 @@ type Config struct {
 func LoadConfig(filename string) (*Config, error) {
 func LoadConfig(filename string) (*Config, error) {
 	fileContents, err := ioutil.ReadFile(filename)
 	fileContents, err := ioutil.ReadFile(filename)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	var config Config
 	var config Config
 	err = json.Unmarshal(fileContents, &config)
 	err = json.Unmarshal(fileContents, &config)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 
 
 	// These fields are required; the rest are optional
 	// These fields are required; the rest are optional
 	if config.PropagationChannelId == "" {
 	if config.PropagationChannelId == "" {
-		return nil, errors.New("propagation channel ID is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("propagation channel ID is missing from the configuration file"))
 	}
 	}
 	if config.SponsorId == "" {
 	if config.SponsorId == "" {
-		return nil, errors.New("sponsor ID is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("sponsor ID is missing from the configuration file"))
 	}
 	}
 	if config.RemoteServerListUrl == "" {
 	if config.RemoteServerListUrl == "" {
-		return nil, errors.New("remote server list URL is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("remote server list URL is missing from the configuration file"))
 	}
 	}
 	if config.RemoteServerListSignaturePublicKey == "" {
 	if config.RemoteServerListSignaturePublicKey == "" {
-		return nil, errors.New("remote server list signature public key is missing from the configuration file")
+		return nil, ContextError(
+			errors.New("remote server list signature public key is missing from the configuration file"))
 	}
 	}
 
 
 	if config.TunnelProtocol != "" {
 	if config.TunnelProtocol != "" {
 		if !Contains(SupportedTunnelProtocols, config.TunnelProtocol) {
 		if !Contains(SupportedTunnelProtocols, config.TunnelProtocol) {
-			return nil, errors.New("invalid tunnel protocol")
+			return nil, ContextError(
+				errors.New("invalid tunnel protocol"))
 		}
 		}
 	}
 	}
 
 
@@ -80,5 +88,13 @@ func LoadConfig(filename string) (*Config, error) {
 		config.ConnectionWorkerPoolSize = CONNECTION_WORKER_POOL_SIZE
 		config.ConnectionWorkerPoolSize = CONNECTION_WORKER_POOL_SIZE
 	}
 	}
 
 
+	if config.TunnelPoolSize == 0 {
+		config.TunnelPoolSize = TUNNEL_POOL_SIZE
+	}
+
+	if config.PortForwardFailureThreshold == 0 {
+		config.PortForwardFailureThreshold = PORT_FORWARD_FAILURE_THRESHOLD
+	}
+
 	return &config, nil
 	return &config, nil
 }
 }

+ 19 - 12
psiphon/config_test.go

@@ -27,17 +27,15 @@ intended to be something to learn from and derive other test sets.
 */
 */
 
 
 import (
 import (
-	"encoding/json"
-	"errors"
-	"github.com/stretchr/testify/suite"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
-	"reflect"
 	"testing"
 	"testing"
+
+	"github.com/stretchr/testify/suite"
 )
 )
 
 
 const (
 const (
-	_TEST_DIR = "testfiles"
+	_TEST_DIR = "./testfiles"
 )
 )
 
 
 type ConfigTestSuite struct {
 type ConfigTestSuite struct {
@@ -68,14 +66,28 @@ func (suite *ConfigTestSuite) Test_LoadConfig_BadPath() {
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
 }
 }
 
 
+// Tests good config file path
+func (suite *ConfigTestSuite) Test_LoadConfig_GoodPath() {
+	filename := filepath.Join(_TEST_DIR, "good.json")
+	writeConfigFile(filename, `{"PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
+
+	// Use absolute path
+	abspath, _ := filepath.Abs(filename)
+	_, err := LoadConfig(abspath)
+	suite.Nil(err, "error should not be set")
+
+	// Use relative path
+	suite.False(filepath.IsAbs(filename))
+	_, err = LoadConfig(filename)
+	suite.Nil(err, "error should not be set")
+}
+
 // Tests non-JSON file contents
 // Tests non-JSON file contents
 func (suite *ConfigTestSuite) Test_LoadConfig_BadFileContents() {
 func (suite *ConfigTestSuite) Test_LoadConfig_BadFileContents() {
 	filename := filepath.Join(_TEST_DIR, "junk.json")
 	filename := filepath.Join(_TEST_DIR, "junk.json")
 	writeConfigFile(filename, "**ohhi**")
 	writeConfigFile(filename, "**ohhi**")
 	_, err := LoadConfig(filename)
 	_, err := LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	// TODO: Is it worthwhile to test error types?
-	suite.Equal(reflect.TypeOf(json.SyntaxError{}).Name(), reflect.TypeOf(err).Elem().Name())
 }
 }
 
 
 // Tests config file with JSON contents that don't match our structure
 // Tests config file with JSON contents that don't match our structure
@@ -86,25 +98,21 @@ func (suite *ConfigTestSuite) Test_LoadConfig_BadJson() {
 	writeConfigFile(filename, `{"f1": 11, "f2": "two"}`)
 	writeConfigFile(filename, `{"f1": 11, "f2": "two"}`)
 	_, err := LoadConfig(filename)
 	_, err := LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	suite.Equal(reflect.TypeOf(errors.New("")).Elem().Name(), reflect.TypeOf(err).Elem().Name())
 
 
 	// Has one of our required fields, but wrong type
 	// Has one of our required fields, but wrong type
 	writeConfigFile(filename, `{"PropagationChannelId": 11, "f2": "two"}`)
 	writeConfigFile(filename, `{"PropagationChannelId": 11, "f2": "two"}`)
 	_, err = LoadConfig(filename)
 	_, err = LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	suite.Equal(reflect.TypeOf(json.UnmarshalTypeError{}).Name(), reflect.TypeOf(err).Elem().Name())
 
 
 	// Has one of our required fields, but null
 	// Has one of our required fields, but null
 	writeConfigFile(filename, `{"PropagationChannelId": null, "f2": "two"}`)
 	writeConfigFile(filename, `{"PropagationChannelId": null, "f2": "two"}`)
 	_, err = LoadConfig(filename)
 	_, err = LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	suite.Equal(reflect.TypeOf(errors.New("")).Elem().Name(), reflect.TypeOf(err).Elem().Name())
 
 
 	// Has one of our required fields, but empty string
 	// Has one of our required fields, but empty string
 	writeConfigFile(filename, `{"PropagationChannelId": "", "f2": "two"}`)
 	writeConfigFile(filename, `{"PropagationChannelId": "", "f2": "two"}`)
 	_, err = LoadConfig(filename)
 	_, err = LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	suite.Equal(reflect.TypeOf(errors.New("")).Elem().Name(), reflect.TypeOf(err).Elem().Name())
 
 
 	// Has all of our required fields, but no optional fields
 	// Has all of our required fields, but no optional fields
 	writeConfigFile(filename, `{"PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
 	writeConfigFile(filename, `{"PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
@@ -116,7 +124,6 @@ func (suite *ConfigTestSuite) Test_LoadConfig_BadJson() {
 	writeConfigFile(filename, `{"ClientVersion": "string, not int", "PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
 	writeConfigFile(filename, `{"ClientVersion": "string, not int", "PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
 	_, err = LoadConfig(filename)
 	_, err = LoadConfig(filename)
 	suite.NotNil(err, "error should be set")
 	suite.NotNil(err, "error should be set")
-	suite.Equal(reflect.TypeOf(json.UnmarshalTypeError{}).Name(), reflect.TypeOf(err).Elem().Name())
 
 
 	// Has null for optional field
 	// Has null for optional field
 	writeConfigFile(filename, `{"ClientVersion": null, "PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)
 	writeConfigFile(filename, `{"ClientVersion": null, "PropagationChannelId": "xyz", "SponsorId": "xyz", "RemoteServerListUrl": "xyz", "RemoteServerListSignaturePublicKey": "xyz"}`)

+ 60 - 0
psiphon/conn.go

@@ -20,6 +20,9 @@
 package psiphon
 package psiphon
 
 
 import (
 import (
+	"bytes"
+	"fmt"
+	"io"
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -28,6 +31,12 @@ import (
 // DialConfig contains parameters to determine the behavior
 // DialConfig contains parameters to determine the behavior
 // of a Psiphon dialer (TCPDial, MeekDial, etc.)
 // of a Psiphon dialer (TCPDial, MeekDial, etc.)
 type DialConfig struct {
 type DialConfig struct {
+
+	// UpstreamHttpProxyAddress specifies an HTTP proxy to connect through
+	// (the proxy must support HTTP CONNECT). The address may be a hostname
+	// or IP address and must include a port number.
+	UpstreamHttpProxyAddress string
+
 	ConnectTimeout time.Duration
 	ConnectTimeout time.Duration
 	ReadTimeout    time.Duration
 	ReadTimeout    time.Duration
 	WriteTimeout   time.Duration
 	WriteTimeout   time.Duration
@@ -99,3 +108,54 @@ func (conns *Conns) CloseAll() {
 	}
 	}
 	conns.conns = make(map[net.Conn]bool)
 	conns.conns = make(map[net.Conn]bool)
 }
 }
+
+// Relay sends to remoteConn bytes received from localConn,
+// and sends to localConn bytes received from remoteConn.
+func Relay(localConn, remoteConn net.Conn) {
+	copyWaitGroup := new(sync.WaitGroup)
+	copyWaitGroup.Add(1)
+	go func() {
+		defer copyWaitGroup.Done()
+		_, err := io.Copy(localConn, remoteConn)
+		if err != nil {
+			Notice(NOTICE_ALERT, "Relay failed: %s", ContextError(err))
+		}
+	}()
+	_, err := io.Copy(remoteConn, localConn)
+	if err != nil {
+		Notice(NOTICE_ALERT, "Relay failed: %s", ContextError(err))
+	}
+	copyWaitGroup.Wait()
+}
+
+// HttpProxyConnect establishes a HTTP CONNECT tunnel to addr through
+// an established network connection to an HTTP proxy. It is assumed that
+// no payload bytes have been sent through the connection to the proxy.
+func HttpProxyConnect(rawConn net.Conn, addr string) (err error) {
+	hostname, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	// TODO: use the proxy request/response code from net/http/transport.go?
+	connectRequest := fmt.Sprintf(
+		"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
+		addr, hostname)
+	_, err = rawConn.Write([]byte(connectRequest))
+	if err != nil {
+		return ContextError(err)
+	}
+
+	expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
+	readBuffer := make([]byte, len(expectedResponse))
+	_, err = io.ReadFull(rawConn, readBuffer)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	if !bytes.Equal(readBuffer, expectedResponse) {
+		return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
+	}
+
+	return nil
+}

+ 616 - 0
psiphon/controller.go

@@ -0,0 +1,616 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package psiphon implements the core tunnel functionality of a Psiphon client.
+// The main function is RunForever, which runs a Controller that obtains lists of
+// servers, establishes tunnel connections, and runs local proxies through which
+// tunneled traffic may be sent.
+package psiphon
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"time"
+)
+
+// Controller is a tunnel lifecycle coordinator. It manages lists of servers to
+// connect to; establishes and monitors tunnels; and runs local proxies which
+// route traffic through the tunnels.
+type Controller struct {
+	config                    *Config
+	failureSignal             chan struct{}
+	shutdownBroadcast         chan struct{}
+	runWaitGroup              *sync.WaitGroup
+	establishedTunnels        chan *Tunnel
+	failedTunnels             chan *Tunnel
+	tunnelMutex               sync.Mutex
+	tunnels                   []*Tunnel
+	nextTunnel                int
+	operateWaitGroup          *sync.WaitGroup
+	isEstablishing            bool
+	establishWaitGroup        *sync.WaitGroup
+	stopEstablishingBroadcast chan struct{}
+	candidateServerEntries    chan *ServerEntry
+	pendingConns              *Conns
+}
+
+// NewController initializes a new controller.
+func NewController(config *Config) (controller *Controller) {
+	return &Controller{
+		config: config,
+		// failureSignal receives a signal from a component (including socks and
+		// http local proxies) if they unexpectedly fail. Senders should not block.
+		// A buffer allows at least one stop signal to be sent before there is a receiver.
+		failureSignal:     make(chan struct{}, 1),
+		shutdownBroadcast: make(chan struct{}),
+		runWaitGroup:      new(sync.WaitGroup),
+		// establishedTunnels and failedTunnels buffer sizes are large enough to
+		// receive full pools of tunnels without blocking. Senders should not block.
+		establishedTunnels: make(chan *Tunnel, config.TunnelPoolSize),
+		failedTunnels:      make(chan *Tunnel, config.TunnelPoolSize),
+		tunnels:            make([]*Tunnel, 0),
+		operateWaitGroup:   new(sync.WaitGroup),
+		isEstablishing:     false,
+		pendingConns:       new(Conns),
+	}
+}
+
+// Run executes the controller. It launches components and then monitors
+// for a shutdown signal; after receiving the signal it shuts down the
+// controller.
+// The components include:
+// - the periodic remote server list fetcher
+// - the tunnel manager
+// - a local SOCKS proxy that port forwards through the pool of tunnels
+// - a local HTTP proxy that port forwards through the pool of tunnels
+func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
+
+	Notice(NOTICE_VERSION, VERSION)
+
+	socksProxy, err := NewSocksProxy(controller.config, controller)
+	if err != nil {
+		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
+		return
+	}
+	defer socksProxy.Close()
+	httpProxy, err := NewHttpProxy(controller.config, controller)
+	if err != nil {
+		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
+		return
+	}
+	defer httpProxy.Close()
+
+	controller.runWaitGroup.Add(2)
+	go controller.remoteServerListFetcher()
+	go controller.runTunnels()
+
+	select {
+	case <-shutdownBroadcast:
+		Notice(NOTICE_INFO, "controller shutdown by request")
+	case <-controller.failureSignal:
+		Notice(NOTICE_ALERT, "controller shutdown due to failure")
+	}
+
+	// Note: in addition to establish(), this pendingConns will interrupt
+	// FetchRemoteServerList
+	controller.pendingConns.CloseAll()
+	close(controller.shutdownBroadcast)
+	controller.runWaitGroup.Wait()
+
+	Notice(NOTICE_INFO, "exiting controller")
+}
+
+// SignalFailure notifies the controller that an associated component has failed.
+// This will terminate the controller.
+func (controller *Controller) SignalFailure() {
+	select {
+	case controller.failureSignal <- *new(struct{}):
+	default:
+	}
+}
+
+// remoteServerListFetcher fetches an out-of-band list of server entries
+// for more tunnel candidates. It fetches immediately, retries after failure
+// with a wait period, and refetches after success with a longer wait period.
+func (controller *Controller) remoteServerListFetcher() {
+	defer controller.runWaitGroup.Done()
+
+	// Note: unlike existing Psiphon clients, this code
+	// always makes the fetch remote server list request
+loop:
+	for {
+		// TODO: FetchRemoteServerList should have its own pendingConns,
+		// otherwise it may needlessly abort when establish is stopped.
+		err := FetchRemoteServerList(controller.config, controller.pendingConns)
+		var duration time.Duration
+		if err != nil {
+			Notice(NOTICE_ALERT, "failed to fetch remote server list: %s", err)
+			duration = FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT
+		} else {
+			duration = FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT
+		}
+		timeout := time.After(duration)
+		select {
+		case <-timeout:
+			// Fetch again
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+
+	Notice(NOTICE_INFO, "exiting remote server list fetcher")
+}
+
+// runTunnels is the controller tunnel management main loop. It starts and stops
+// establishing tunnels based on the target tunnel pool size and the current size
+// of the pool. Tunnels are established asynchronously using worker goroutines.
+// When a tunnel is established, it's added to the active pool and a corresponding
+// operateTunnel goroutine is launched which starts a session in the tunnel and
+// monitors the tunnel for failures.
+// When a tunnel fails, it's removed from the pool and the establish process is
+// restarted to fill the pool.
+func (controller *Controller) runTunnels() {
+	defer controller.runWaitGroup.Done()
+
+	// Don't start establishing until there are some server candidates. The
+	// typical case is a client with no server entries which will wait for
+	// the first successful FetchRemoteServerList to populate the data store.
+	for {
+		if HasServerEntries(
+			controller.config.EgressRegion, controller.config.TunnelProtocol) {
+			break
+		}
+		// TODO: replace polling with signal
+		timeout := time.After(1 * time.Second)
+		select {
+		case <-timeout:
+		case <-controller.shutdownBroadcast:
+			return
+		}
+	}
+	controller.startEstablishing()
+loop:
+	for {
+		select {
+		case failedTunnel := <-controller.failedTunnels:
+			Notice(NOTICE_ALERT, "tunnel failed: %s", failedTunnel.serverEntry.IpAddress)
+			controller.terminateTunnel(failedTunnel)
+			// Note: only this goroutine may call startEstablishing/stopEstablishing and access
+			// isEstablishing.
+			if !controller.isEstablishing {
+				controller.startEstablishing()
+			}
+
+		// !TODO! design issue: might not be enough server entries with region/caps to ever fill tunnel slots
+		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
+		case establishedTunnel := <-controller.establishedTunnels:
+			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
+			if controller.registerTunnel(establishedTunnel) {
+				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
+				controller.operateWaitGroup.Add(1)
+				go controller.operateTunnel(establishedTunnel)
+			} else {
+				controller.discardTunnel(establishedTunnel)
+			}
+			if controller.isFullyEstablished() {
+				controller.stopEstablishing()
+			}
+
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+	controller.stopEstablishing()
+	controller.terminateAllTunnels()
+	controller.operateWaitGroup.Wait()
+
+	// Drain tunnel channels
+	close(controller.establishedTunnels)
+	for tunnel := range controller.establishedTunnels {
+		controller.discardTunnel(tunnel)
+	}
+	close(controller.failedTunnels)
+	for tunnel := range controller.failedTunnels {
+		controller.discardTunnel(tunnel)
+	}
+
+	Notice(NOTICE_INFO, "exiting run tunnels")
+}
+
+// discardTunnel disposes of a successful connection that is no longer required.
+func (controller *Controller) discardTunnel(tunnel *Tunnel) {
+	Notice(NOTICE_INFO, "discard tunnel: %s", tunnel.serverEntry.IpAddress)
+	// TODO: not calling PromoteServerEntry, since that would rank the
+	// discarded tunnel before fully active tunnels. Can a discarded tunnel
+	// be promoted (since it connects), but with lower rank than all active
+	// tunnels?
+	tunnel.Close()
+}
+
+// registerTunnel adds the connected tunnel to the pool of active tunnels
+// which are candidates for port forwarding. Returns true if the pool has an
+// empty slot and false if the pool is full (caller should discard the tunnel).
+func (controller *Controller) registerTunnel(tunnel *Tunnel) bool {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
+		return false
+	}
+	// Perform a final check just in case we've established
+	// a duplicate connection.
+	for _, activeTunnel := range controller.tunnels {
+		if activeTunnel.serverEntry.IpAddress == tunnel.serverEntry.IpAddress {
+			Notice(NOTICE_ALERT, "duplicate tunnel: %s", tunnel.serverEntry.IpAddress)
+			return false
+		}
+	}
+	controller.tunnels = append(controller.tunnels, tunnel)
+	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+	return true
+}
+
+// isFullyEstablished indicates if the pool of active tunnels is full.
+func (controller *Controller) isFullyEstablished() bool {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	return len(controller.tunnels) >= controller.config.TunnelPoolSize
+}
+
+// terminateTunnel removes a tunnel from the pool of active tunnels
+// and closes the tunnel. The next-tunnel state used by getNextActiveTunnel
+// is adjusted as required.
+func (controller *Controller) terminateTunnel(tunnel *Tunnel) {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for index, activeTunnel := range controller.tunnels {
+		if tunnel == activeTunnel {
+			controller.tunnels = append(
+				controller.tunnels[:index], controller.tunnels[index+1:]...)
+			if controller.nextTunnel > index {
+				controller.nextTunnel--
+			}
+			if controller.nextTunnel >= len(controller.tunnels) {
+				controller.nextTunnel = 0
+			}
+			activeTunnel.Close()
+			Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+			break
+		}
+	}
+}
+
+// terminateAllTunnels empties the tunnel pool, closing all active tunnels.
+// This is used when shutting down the controller.
+func (controller *Controller) terminateAllTunnels() {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for _, activeTunnel := range controller.tunnels {
+		activeTunnel.Close()
+	}
+	controller.tunnels = make([]*Tunnel, 0)
+	controller.nextTunnel = 0
+	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
+}
+
+// getNextActiveTunnel returns the next tunnel from the pool of active
+// tunnels. Currently, tunnel selection order is simple round-robin.
+func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for i := len(controller.tunnels); i > 0; i-- {
+		tunnel = controller.tunnels[controller.nextTunnel]
+		controller.nextTunnel =
+			(controller.nextTunnel + 1) % len(controller.tunnels)
+		// A tunnel must[*] have started its session (performed the server
+		// API handshake sequence) before it may be used for tunneling traffic
+		// [*]currently not enforced by the server, but may be in the future.
+		if tunnel.IsSessionStarted() {
+			return tunnel
+		}
+	}
+	return nil
+}
+
+// isActiveTunnelServerEntries is used to check if there's already
+// an existing tunnel to a candidate server.
+func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry) bool {
+	controller.tunnelMutex.Lock()
+	defer controller.tunnelMutex.Unlock()
+	for _, activeTunnel := range controller.tunnels {
+		if activeTunnel.serverEntry.IpAddress == serverEntry.IpAddress {
+			return true
+		}
+	}
+	return false
+}
+
+// operateTunnel starts a Psiphon session (handshake, etc.) on a newly
+// connected tunnel, and then monitors the tunnel for failures:
+//
+// 1. Overall tunnel failure: the tunnel sends a signal to the ClosedSignal
+// channel on keep-alive failure and other transport I/O errors. In case
+// of such a failure, the tunnel is marked as failed.
+//
+// 2. Tunnel port forward failures: the tunnel connection may stay up but
+// the client may still fail to establish port forwards due to server load
+// and other conditions. After a threshold number of such failures, the
+// overall tunnel is marked as failed.
+//
+// TODO: currently, any connect (dial), read, or write error associated with
+// a port forward is counted as a failure. It may be important to differentiate
+// between failures due to Psiphon server conditions and failures due to the
+// origin/target server (in the latter case, the tunnel is healthy). Here are
+// some typical error messages to consider matching against (or ignoring):
+//
+// - "ssh: rejected: administratively prohibited (open failed)"
+// - "ssh: rejected: connect failed (Connection timed out)"
+// - "write tcp ... broken pipe"
+// - "read tcp ... connection reset by peer"
+// - "ssh: unexpected packet in response to channel open: <nil>"
+//
+func (controller *Controller) operateTunnel(tunnel *Tunnel) {
+	defer controller.operateWaitGroup.Done()
+
+	tunnelClosedSignal := make(chan struct{}, 1)
+	err := tunnel.conn.SetClosedSignal(tunnelClosedSignal)
+	if err != nil {
+		err = fmt.Errorf("failed to set closed signal: %s", err)
+	}
+
+	Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
+	// TODO: NewSession server API calls may block shutdown
+	_, err = NewSession(controller.config, tunnel)
+	if err != nil {
+		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
+	}
+
+	// Tunnel may now be used for port forwarding
+	tunnel.SetSessionStarted()
+
+	// Promote this successful tunnel to first rank so it's one
+	// of the first candidates next time establish runs.
+	PromoteServerEntry(tunnel.serverEntry.IpAddress)
+
+	for err == nil {
+		select {
+		case failures := <-tunnel.portForwardFailures:
+			tunnel.portForwardFailureTotal += failures
+			Notice(
+				NOTICE_INFO, "port forward failures for %s: %d",
+				tunnel.serverEntry.IpAddress, tunnel.portForwardFailureTotal)
+			if tunnel.portForwardFailureTotal > controller.config.PortForwardFailureThreshold {
+				err = errors.New("tunnel exceeded port forward failure threshold")
+			}
+
+		case <-tunnelClosedSignal:
+			// TODO: this signal can be received during a commanded shutdown due to
+			// how tunnels are closed; should rework this to avoid log noise.
+			err = errors.New("tunnel closed unexpectedly")
+
+		case <-controller.shutdownBroadcast:
+			Notice(NOTICE_INFO, "shutdown operate tunnel")
+			return
+		}
+	}
+
+	if err != nil {
+		Notice(NOTICE_ALERT, "operate tunnel error for %s: %s", tunnel.serverEntry.IpAddress, err)
+		// Don't block. Assumes the receiver has a buffer large enough for
+		// the typical number of operated tunnels. In case there's no room,
+		// terminate the tunnel (runTunnels won't get a signal in this case).
+		select {
+		case controller.failedTunnels <- tunnel:
+		default:
+			controller.terminateTunnel(tunnel)
+		}
+	}
+}
+
+// TunneledConn implements net.Conn and wraps a port foward connection.
+// It is used to hook into Read and Write to observe I/O errors and
+// report these errors back to the tunnel monitor as port forward failures.
+type TunneledConn struct {
+	net.Conn
+	tunnel *Tunnel
+}
+
+func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Read(buffer)
+	if err != nil && err != io.EOF {
+		// Report 1 new failure. Won't block; assumes the receiver
+		// has a sufficient buffer for the threshold number of reports.
+		// TODO: conditional on type of error or error message?
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
+}
+
+func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Write(buffer)
+	if err != nil && err != io.EOF {
+		// Same as TunneledConn.Read()
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
+}
+
+// Dial selects an active tunnel and establishes a port forward
+// connection through the selected tunnel. Failure to connect is considered
+// a port foward failure, for the purpose of monitoring tunnel health.
+func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error) {
+	tunnel := controller.getNextActiveTunnel()
+	if tunnel == nil {
+		return nil, ContextError(errors.New("no active tunnels"))
+	}
+	tunnelConn, err := tunnel.Dial(remoteAddr)
+	if err != nil {
+		// TODO: conditional on type of error or error message?
+		select {
+		case tunnel.portForwardFailures <- 1:
+		default:
+		}
+		return nil, ContextError(err)
+	}
+	return &TunneledConn{
+			Conn:   tunnelConn,
+			tunnel: tunnel},
+		nil
+}
+
+// startEstablishing creates a pool of worker goroutines which will
+// attempt to establish tunnels to candidate servers. The candidates
+// are generated by another goroutine.
+func (controller *Controller) startEstablishing() {
+	if controller.isEstablishing {
+		return
+	}
+	Notice(NOTICE_INFO, "start establishing")
+	controller.isEstablishing = true
+	controller.establishWaitGroup = new(sync.WaitGroup)
+	controller.stopEstablishingBroadcast = make(chan struct{})
+	controller.candidateServerEntries = make(chan *ServerEntry)
+
+	for i := 0; i < controller.config.ConnectionWorkerPoolSize; i++ {
+		controller.establishWaitGroup.Add(1)
+		go controller.establishTunnelWorker()
+	}
+
+	controller.establishWaitGroup.Add(1)
+	go controller.establishCandidateGenerator()
+}
+
+// stopEstablishing signals the establish goroutines to stop and waits
+// for the group to halt. pendingConns is used to interrupt any worker
+// blocked on a socket connect.
+func (controller *Controller) stopEstablishing() {
+	if !controller.isEstablishing {
+		return
+	}
+	Notice(NOTICE_INFO, "stop establishing")
+	// Note: on Windows, interruptibleTCPClose doesn't really interrupt socket connects
+	// and may leave goroutines running for a time after the Wait call.
+	controller.pendingConns.CloseAll()
+	close(controller.stopEstablishingBroadcast)
+	// Note: establishCandidateGenerator closes controller.candidateServerEntries
+	// (as it may be sending to that channel).
+	controller.establishWaitGroup.Wait()
+
+	controller.isEstablishing = false
+	controller.establishWaitGroup = nil
+	controller.stopEstablishingBroadcast = nil
+	controller.candidateServerEntries = nil
+}
+
+// establishCandidateGenerator populates the candidate queue with server entries
+// from the data store. Server entries are iterated in rank order, so that promoted
+// servers with higher rank are priority candidates.
+func (controller *Controller) establishCandidateGenerator() {
+	defer controller.establishWaitGroup.Done()
+loop:
+	for {
+		// Note: it's possible that an active tunnel in excludeServerEntries will
+		// fail during this iteration of server entries and in that case the
+		// cooresponding server will not be retried (within the same iteration).
+		iterator, err := NewServerEntryIterator(
+			controller.config.EgressRegion, controller.config.TunnelProtocol)
+		if err != nil {
+			Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
+			controller.SignalFailure()
+			break loop
+		}
+		for {
+			serverEntry, err := iterator.Next()
+			if err != nil {
+				Notice(NOTICE_ALERT, "failed to get next candidate: %s", err)
+				controller.SignalFailure()
+				break loop
+			}
+			if serverEntry == nil {
+				// Completed this iteration
+				break
+			}
+			select {
+			case controller.candidateServerEntries <- serverEntry:
+			case <-controller.stopEstablishingBroadcast:
+				break loop
+			case <-controller.shutdownBroadcast:
+				break loop
+			}
+		}
+		iterator.Close()
+		// After a complete iteration of candidate servers, pause before iterating again.
+		// This helps avoid some busy wait loop conditions, and also allows some time for
+		// network conditions to change.
+		timeout := time.After(ESTABLISH_TUNNEL_PAUSE_PERIOD)
+		select {
+		case <-timeout:
+			// Retry iterating
+		case <-controller.stopEstablishingBroadcast:
+			break loop
+		case <-controller.shutdownBroadcast:
+			break loop
+		}
+	}
+	close(controller.candidateServerEntries)
+	Notice(NOTICE_INFO, "stopped candidate generator")
+}
+
+// establishTunnelWorker pulls candidates from the candidate queue, establishes
+// a connection to the tunnel server, and delivers the established tunnel to a channel.
+func (controller *Controller) establishTunnelWorker() {
+	defer controller.establishWaitGroup.Done()
+	for serverEntry := range controller.candidateServerEntries {
+		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
+		// select, since we want to prioritize receiving the stop signal
+		select {
+		case <-controller.stopEstablishingBroadcast:
+			return
+		default:
+		}
+		// There may already be a tunnel to this candidate. If so, skip it.
+		if controller.isActiveTunnelServerEntry(serverEntry) {
+			continue
+		}
+		tunnel, err := EstablishTunnel(
+			controller.config, controller.pendingConns, serverEntry)
+		if err != nil {
+			// TODO: distingush case where conn is interrupted?
+			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
+		} else {
+			// Don't block. Assumes the receiver has a buffer large enough for
+			// the number of desired tunnels. If there's no room, the tunnel must
+			// not be required so it's discarded.
+			select {
+			case controller.establishedTunnels <- tunnel:
+			default:
+				controller.discardTunnel(tunnel)
+			}
+		}
+	}
+	Notice(NOTICE_INFO, "stopped establish worker")
+}

+ 109 - 61
psiphon/dataStore.go

@@ -24,7 +24,8 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	sqlite3 "github.com/mattn/go-sqlite3"
+	sqlite3 "github.com/Psiphon-Inc/go-sqlite3"
+	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -47,6 +48,9 @@ func initDataStore() {
              rank integer not null unique,
              rank integer not null unique,
              region text not null,
              region text not null,
              data blob not null);
              data blob not null);
+	    create table if not exists serverEntryProtocol
+	        (serverEntryId text not null,
+	         protocol text not null);
         create table if not exists keyValue
         create table if not exists keyValue
             (key text not null,
             (key text not null,
              value text not null);
              value text not null);
@@ -130,7 +134,6 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		if serverEntryExists && !replaceIfExists {
 		if serverEntryExists && !replaceIfExists {
 			return nil
 			return nil
 		}
 		}
-		// TODO: also skip updates if replaceIfExists but 'data' has not changed
 		_, err := transaction.Exec(`
 		_, err := transaction.Exec(`
             update serverEntry set rank = rank + 1
             update serverEntry set rank = rank + 1
                 where id = (select id from serverEntry order by rank desc limit 1);
                 where id = (select id from serverEntry order by rank desc limit 1);
@@ -150,6 +153,20 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
+		for _, protocol := range SupportedTunnelProtocols {
+			// Note: for meek, the capabilities are FRONTED-MEEK and UNFRONTED-MEEK
+			// and the additonal OSSH service is assumed to be available internally.
+			requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
+			if Contains(serverEntry.Capabilities, requiredCapability) {
+				_, err = transaction.Exec(`
+		            insert or ignore into serverEntryProtocol (serverEntryId, protocol)
+		            values (?, ?);
+		            `, serverEntry.IpAddress, protocol)
+				if err != nil {
+					return err
+				}
+			}
+		}
 		// TODO: post notice after commit
 		// TODO: post notice after commit
 		if !serverEntryExists {
 		if !serverEntryExists {
 			Notice(NOTICE_INFO, "stored server %s", serverEntry.IpAddress)
 			Notice(NOTICE_INFO, "stored server %s", serverEntry.IpAddress)
@@ -176,90 +193,83 @@ func PromoteServerEntry(ipAddress string) error {
 	})
 	})
 }
 }
 
 
-// ServerEntryCycler is used to continuously iterate over
+// ServerEntryIterator is used to iterate over
 // stored server entries in rank order.
 // stored server entries in rank order.
-type ServerEntryCycler struct {
+type ServerEntryIterator struct {
 	region      string
 	region      string
+	protocol    string
+	excludeIds  []string
 	transaction *sql.Tx
 	transaction *sql.Tx
 	cursor      *sql.Rows
 	cursor      *sql.Rows
-	isReset     bool
 }
 }
 
 
-// NewServerEntryCycler creates a new ServerEntryCycler
-func NewServerEntryCycler(region string) (cycler *ServerEntryCycler, err error) {
+// NewServerEntryIterator creates a new NewServerEntryIterator
+func NewServerEntryIterator(region, protocol string) (iterator *ServerEntryIterator, err error) {
 	initDataStore()
 	initDataStore()
-	cycler = &ServerEntryCycler{region: region}
-	err = cycler.Reset()
+	iterator = &ServerEntryIterator{
+		region:   region,
+		protocol: protocol,
+	}
+	err = iterator.Reset()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return cycler, nil
+	return iterator, nil
 }
 }
 
 
-// Reset a ServerEntryCycler to the start of its cycle. The next
+// Reset a NewServerEntryIterator to the start of its cycle. The next
 // call to Next will return the first server entry.
 // call to Next will return the first server entry.
-func (cycler *ServerEntryCycler) Reset() error {
-	cycler.Close()
+func (iterator *ServerEntryIterator) Reset() error {
+	iterator.Close()
 	transaction, err := singleton.db.Begin()
 	transaction, err := singleton.db.Begin()
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
 	var cursor *sql.Rows
 	var cursor *sql.Rows
-	if cycler.region == "" {
-		cursor, err = transaction.Query(
-			"select data from serverEntry order by rank desc;")
-	} else {
-		cursor, err = transaction.Query(
-			"select data from serverEntry where region = ? order by rank desc;",
-			cycler.region)
-	}
+	whereClause, whereParams := makeServerEntryWhereClause(
+		iterator.region, iterator.protocol, nil)
+	query := "select data from serverEntry" + whereClause + " order by rank desc;"
+	cursor, err = transaction.Query(query, whereParams...)
 	if err != nil {
 	if err != nil {
 		transaction.Rollback()
 		transaction.Rollback()
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
-	cycler.isReset = true
-	cycler.transaction = transaction
-	cycler.cursor = cursor
+	iterator.transaction = transaction
+	iterator.cursor = cursor
 	return nil
 	return nil
 }
 }
 
 
-// Close cleans up resources associated with a ServerEntryCycler.
-func (cycler *ServerEntryCycler) Close() {
-	if cycler.cursor != nil {
-		cycler.cursor.Close()
+// Close cleans up resources associated with a ServerEntryIterator.
+func (iterator *ServerEntryIterator) Close() {
+	if iterator.cursor != nil {
+		iterator.cursor.Close()
 	}
 	}
-	cycler.cursor = nil
-	if cycler.transaction != nil {
-		cycler.transaction.Rollback()
+	iterator.cursor = nil
+	if iterator.transaction != nil {
+		iterator.transaction.Rollback()
 	}
 	}
-	cycler.transaction = nil
+	iterator.transaction = nil
 }
 }
 
 
-// Next returns the next server entry, by rank, for a ServerEntryCycler. When
-// the ServerEntryCycler has worked through all known server entries, Next will
-// call Reset and start over and return the first server entry again.
-func (cycler *ServerEntryCycler) Next() (serverEntry *ServerEntry, err error) {
+// Next returns the next server entry, by rank, for a ServerEntryIterator.
+// Returns nil with no error when there is no next item.
+func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
 	defer func() {
 	defer func() {
 		if err != nil {
 		if err != nil {
-			cycler.Close()
+			iterator.Close()
 		}
 		}
 	}()
 	}()
-	for !cycler.cursor.Next() {
-		err = cycler.cursor.Err()
-		if err != nil {
-			return nil, ContextError(err)
-		}
-		if cycler.isReset {
-			return nil, ContextError(errors.New("no server entries"))
-		}
-		err = cycler.Reset()
+	if !iterator.cursor.Next() {
+		err = iterator.cursor.Err()
 		if err != nil {
 		if err != nil {
 			return nil, ContextError(err)
 			return nil, ContextError(err)
 		}
 		}
+		// There is no next item
+		return nil, nil
 	}
 	}
-	cycler.isReset = false
+
 	var data []byte
 	var data []byte
-	err = cycler.cursor.Scan(&data)
+	err = iterator.cursor.Scan(&data)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
@@ -271,24 +281,62 @@ func (cycler *ServerEntryCycler) Next() (serverEntry *ServerEntry, err error) {
 	return serverEntry, nil
 	return serverEntry, nil
 }
 }
 
 
+func makeServerEntryWhereClause(
+	region, protocol string, excludeIds []string) (whereClause string, whereParams []interface{}) {
+	whereClause = ""
+	whereParams = make([]interface{}, 0)
+	if region != "" {
+		whereClause += " where region = ?"
+		whereParams = append(whereParams, region)
+	}
+	if protocol != "" {
+		if len(whereClause) > 0 {
+			whereClause += " and"
+		} else {
+			whereClause += " where"
+		}
+		whereClause +=
+			" exists (select 1 from serverEntryProtocol where protocol = ? and serverEntryId = serverEntry.id)"
+		whereParams = append(whereParams, protocol)
+	}
+	if len(excludeIds) > 0 {
+		if len(whereClause) > 0 {
+			whereClause += " and"
+		} else {
+			whereClause += " where"
+		}
+		whereClause += " id in ("
+		for index, id := range excludeIds {
+			if index > 0 {
+				whereClause += ", "
+			}
+			whereClause += "?"
+			whereParams = append(whereParams, id)
+		}
+		whereClause += ")"
+	}
+	return whereClause, whereParams
+}
+
 // HasServerEntries returns true if the data store contains at
 // HasServerEntries returns true if the data store contains at
-// least one server entry (for the specified region, in not blank).
-func HasServerEntries(region string) bool {
+// least one server entry (for the specified region and/or protocol,
+// when not blank).
+func HasServerEntries(region, protocol string) bool {
 	initDataStore()
 	initDataStore()
-	var err error
 	var count int
 	var count int
+	whereClause, whereParams := makeServerEntryWhereClause(region, protocol, nil)
+	query := "select count(*) from serverEntry" + whereClause
+	err := singleton.db.QueryRow(query, whereParams...).Scan(&count)
+
 	if region == "" {
 	if region == "" {
-		err = singleton.db.QueryRow("select count(*) from serverEntry;").Scan(&count)
-		if err == nil {
-			Notice(NOTICE_INFO, "servers: %d", count)
-		}
-	} else {
-		err = singleton.db.QueryRow(
-			"select count(*) from serverEntry where region = ?;", region).Scan(&count)
-		if err == nil {
-			Notice(NOTICE_INFO, "servers for region %s: %d", region, count)
-		}
+		region = "(any)"
 	}
 	}
+	if protocol == "" {
+		protocol = "(any)"
+	}
+	Notice(NOTICE_INFO, "servers for region %s and protocol %s: %d",
+		region, protocol, count)
+
 	return err == nil && count > 0
 	return err == nil && count > 0
 }
 }
 
 

+ 7 - 4
psiphon/defaults.go

@@ -24,19 +24,22 @@ import (
 )
 )
 
 
 const (
 const (
-	VERSION                                  = "0.0.2"
+	VERSION                                  = "0.0.3"
 	DATA_STORE_FILENAME                      = "psiphon.db"
 	DATA_STORE_FILENAME                      = "psiphon.db"
-	FETCH_REMOTE_SERVER_LIST_TIMEOUT         = 5 * time.Second
+	CONNECTION_WORKER_POOL_SIZE              = 10
+	TUNNEL_POOL_SIZE                         = 1
 	TUNNEL_CONNECT_TIMEOUT                   = 15 * time.Second
 	TUNNEL_CONNECT_TIMEOUT                   = 15 * time.Second
 	TUNNEL_READ_TIMEOUT                      = 0 * time.Second
 	TUNNEL_READ_TIMEOUT                      = 0 * time.Second
 	TUNNEL_WRITE_TIMEOUT                     = 5 * time.Second
 	TUNNEL_WRITE_TIMEOUT                     = 5 * time.Second
 	TUNNEL_SSH_KEEP_ALIVE_PERIOD             = 60 * time.Second
 	TUNNEL_SSH_KEEP_ALIVE_PERIOD             = 60 * time.Second
 	ESTABLISH_TUNNEL_TIMEOUT                 = 60 * time.Second
 	ESTABLISH_TUNNEL_TIMEOUT                 = 60 * time.Second
-	CONNECTION_WORKER_POOL_SIZE              = 10
+	ESTABLISH_TUNNEL_PAUSE_PERIOD            = 10 * time.Second
+	PORT_FORWARD_FAILURE_THRESHOLD           = 10
 	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT         = 15 * time.Second
 	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT         = 15 * time.Second
+	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST = 50
+	FETCH_REMOTE_SERVER_LIST_TIMEOUT         = 5 * time.Second
 	FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT   = 5 * time.Second
 	FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT   = 5 * time.Second
 	FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT   = 6 * time.Hour
 	FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT   = 6 * time.Hour
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH     = 16
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH     = 16
 	PSIPHON_API_SERVER_TIMEOUT               = 20 * time.Second
 	PSIPHON_API_SERVER_TIMEOUT               = 20 * time.Second
-	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST = 50
 )
 )

+ 38 - 37
psiphon/httpProxy.go

@@ -31,39 +31,39 @@ import (
 // HttpProxy is a HTTP server that relays HTTP requests through
 // HttpProxy is a HTTP server that relays HTTP requests through
 // the tunnel SSH client.
 // the tunnel SSH client.
 type HttpProxy struct {
 type HttpProxy struct {
-	tunnel        *Tunnel
-	stoppedSignal chan struct{}
-	listener      net.Listener
-	waitGroup     *sync.WaitGroup
-	httpRelay     *http.Transport
-	openConns     *Conns
+	tunneler       Tunneler
+	listener       net.Listener
+	serveWaitGroup *sync.WaitGroup
+	httpRelay      *http.Transport
+	openConns      *Conns
 }
 }
 
 
 // NewHttpProxy initializes and runs a new HTTP proxy server.
 // NewHttpProxy initializes and runs a new HTTP proxy server.
-func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (proxy *HttpProxy, err error) {
-	listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
+func NewHttpProxy(config *Config, tunneler Tunneler) (proxy *HttpProxy, err error) {
+	listener, err := net.Listen(
+		"tcp", fmt.Sprintf("127.0.0.1:%d", config.LocalHttpProxyPort))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
-	tunnelledDialer := func(_, targetAddress string) (conn net.Conn, err error) {
+	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		// TODO: connect timeout?
 		// TODO: connect timeout?
-		return tunnel.sshClient.Dial("tcp", targetAddress)
+		return tunneler.Dial(addr)
 	}
 	}
+	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
 	transport := &http.Transport{
-		Dial:                  tunnelledDialer,
+		Dial:                  tunneledDialer,
 		MaxIdleConnsPerHost:   HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST,
 		MaxIdleConnsPerHost:   HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST,
 		ResponseHeaderTimeout: HTTP_PROXY_ORIGIN_SERVER_TIMEOUT,
 		ResponseHeaderTimeout: HTTP_PROXY_ORIGIN_SERVER_TIMEOUT,
 	}
 	}
 	proxy = &HttpProxy{
 	proxy = &HttpProxy{
-		tunnel:        tunnel,
-		stoppedSignal: stoppedSignal,
-		listener:      listener,
-		waitGroup:     new(sync.WaitGroup),
-		httpRelay:     transport,
-		openConns:     new(Conns),
-	}
-	proxy.waitGroup.Add(1)
-	go proxy.serveHttpRequests()
+		tunneler:       tunneler,
+		listener:       listener,
+		serveWaitGroup: new(sync.WaitGroup),
+		httpRelay:      transport,
+		openConns:      new(Conns),
+	}
+	proxy.serveWaitGroup.Add(1)
+	go proxy.serve()
 	Notice(NOTICE_HTTP_PROXY, "local HTTP proxy running at address %s", proxy.listener.Addr().String())
 	Notice(NOTICE_HTTP_PROXY, "local HTTP proxy running at address %s", proxy.listener.Addr().String())
 	return proxy, nil
 	return proxy, nil
 }
 }
@@ -71,7 +71,7 @@ func NewHttpProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (
 // Close terminates the HTTP server.
 // Close terminates the HTTP server.
 func (proxy *HttpProxy) Close() {
 func (proxy *HttpProxy) Close() {
 	proxy.listener.Close()
 	proxy.listener.Close()
-	proxy.waitGroup.Wait()
+	proxy.serveWaitGroup.Wait()
 	// Close local->proxy persistent connections
 	// Close local->proxy persistent connections
 	proxy.openConns.CloseAll()
 	proxy.openConns.CloseAll()
 	// Close idle proxy->origin persistent connections
 	// Close idle proxy->origin persistent connections
@@ -105,7 +105,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 			return
 			return
 		}
 		}
 		go func() {
 		go func() {
-			err := proxy.httpConnectHandler(proxy.tunnel, conn, request.URL.Host)
+			err := proxy.httpConnectHandler(conn, request.URL.Host)
 			if err != nil {
 			if err != nil {
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 			}
 			}
@@ -117,12 +117,14 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 		http.Error(responseWriter, "", http.StatusInternalServerError)
 		http.Error(responseWriter, "", http.StatusInternalServerError)
 		return
 		return
 	}
 	}
+
 	// Transform request struct before using as input to relayed request
 	// Transform request struct before using as input to relayed request
 	request.Close = false
 	request.Close = false
 	request.RequestURI = ""
 	request.RequestURI = ""
 	for _, key := range hopHeaders {
 	for _, key := range hopHeaders {
 		request.Header.Del(key)
 		request.Header.Del(key)
 	}
 	}
+
 	// Relay the HTTP request and get the response
 	// Relay the HTTP request and get the response
 	response, err := proxy.httpRelay.RoundTrip(request)
 	response, err := proxy.httpRelay.RoundTrip(request)
 	if err != nil {
 	if err != nil {
@@ -131,6 +133,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 		return
 		return
 	}
 	}
 	defer response.Body.Close()
 	defer response.Body.Close()
+
 	// Relay the remote response headers
 	// Relay the remote response headers
 	for _, key := range hopHeaders {
 	for _, key := range hopHeaders {
 		response.Header.Del(key)
 		response.Header.Del(key)
@@ -143,6 +146,7 @@ func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *h
 			responseWriter.Header().Add(key, value)
 			responseWriter.Header().Add(key, value)
 		}
 		}
 	}
 	}
+
 	// Relay the response code and body
 	// Relay the response code and body
 	responseWriter.WriteHeader(response.StatusCode)
 	responseWriter.WriteHeader(response.StatusCode)
 	_, err = io.Copy(responseWriter, response.Body)
 	_, err = io.Copy(responseWriter, response.Body)
@@ -179,20 +183,20 @@ var hopHeaders = []string{
 	"Upgrade",
 	"Upgrade",
 }
 }
 
 
-func (proxy *HttpProxy) httpConnectHandler(tunnel *Tunnel, localHttpConn net.Conn, target string) (err error) {
-	defer localHttpConn.Close()
-	defer proxy.openConns.Remove(localHttpConn)
-	proxy.openConns.Add(localHttpConn)
-	remoteSshForward, err := tunnel.sshClient.Dial("tcp", target)
+func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (err error) {
+	defer localConn.Close()
+	defer proxy.openConns.Remove(localConn)
+	proxy.openConns.Add(localConn)
+	remoteConn, err := proxy.tunneler.Dial(target)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
-	defer remoteSshForward.Close()
-	_, err = localHttpConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+	defer remoteConn.Close()
+	_, err = localConn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
-	relayPortForward(localHttpConn, remoteSshForward)
+	Relay(localConn, remoteConn)
 	return nil
 	return nil
 }
 }
 
 
@@ -213,9 +217,9 @@ func (proxy *HttpProxy) httpConnStateCallback(conn net.Conn, connState http.Conn
 	}
 	}
 }
 }
 
 
-func (proxy *HttpProxy) serveHttpRequests() {
+func (proxy *HttpProxy) serve() {
 	defer proxy.listener.Close()
 	defer proxy.listener.Close()
-	defer proxy.waitGroup.Done()
+	defer proxy.serveWaitGroup.Done()
 	httpServer := &http.Server{
 	httpServer := &http.Server{
 		Handler:   proxy,
 		Handler:   proxy,
 		ConnState: proxy.httpConnStateCallback,
 		ConnState: proxy.httpConnStateCallback,
@@ -223,10 +227,7 @@ func (proxy *HttpProxy) serveHttpRequests() {
 	// Note: will be interrupted by listener.Close() call made by proxy.Close()
 	// Note: will be interrupted by listener.Close() call made by proxy.Close()
 	err := httpServer.Serve(proxy.listener)
 	err := httpServer.Serve(proxy.listener)
 	if err != nil {
 	if err != nil {
-		select {
-		case proxy.stoppedSignal <- *new(struct{}):
-		default:
-		}
+		proxy.tunneler.SignalFailure()
 		Notice(NOTICE_ALERT, "%s", ContextError(err))
 		Notice(NOTICE_ALERT, "%s", ContextError(err))
 	}
 	}
 	Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")
 	Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")

+ 13 - 3
psiphon/meekConn.go

@@ -91,6 +91,7 @@ type MeekConn struct {
 func DialMeek(
 func DialMeek(
 	serverEntry *ServerEntry, sessionId string,
 	serverEntry *ServerEntry, sessionId string,
 	useFronting bool, config *DialConfig) (meek *MeekConn, err error) {
 	useFronting bool, config *DialConfig) (meek *MeekConn, err error) {
+
 	// Configure transport
 	// Configure transport
 	// Note: MeekConn has its own PendingConns to manage the underlying HTTP transport connections,
 	// Note: MeekConn has its own PendingConns to manage the underlying HTTP transport connections,
 	// which may be interrupted on MeekConn.Close(). This code previously used the establishTunnel
 	// which may be interrupted on MeekConn.Close(). This code previously used the establishTunnel
@@ -121,6 +122,7 @@ func DialMeek(
 		host = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
 		host = fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.MeekServerPort)
 		dialer = NewTCPDialer(configCopy)
 		dialer = NewTCPDialer(configCopy)
 	}
 	}
+
 	// Scheme is always "http". Otherwise http.Transport will try to do another TLS
 	// Scheme is always "http". Otherwise http.Transport will try to do another TLS
 	// handshake inside the explicit TLS session (in fronting mode).
 	// handshake inside the explicit TLS session (in fronting mode).
 	url := &url.URL{
 	url := &url.URL{
@@ -132,10 +134,12 @@ func DialMeek(
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
 	transport := &http.Transport{
 		Dial: dialer,
 		Dial: dialer,
 		ResponseHeaderTimeout: TUNNEL_WRITE_TIMEOUT,
 		ResponseHeaderTimeout: TUNNEL_WRITE_TIMEOUT,
 	}
 	}
+
 	// The main loop of a MeekConn is run in the relay() goroutine.
 	// The main loop of a MeekConn is run in the relay() goroutine.
 	// A MeekConn implements net.Conn concurrency semantics:
 	// A MeekConn implements net.Conn concurrency semantics:
 	// "Multiple goroutines may invoke methods on a Conn simultaneously."
 	// "Multiple goroutines may invoke methods on a Conn simultaneously."
@@ -312,7 +316,7 @@ func (meek *MeekConn) replaceSendBuffer(sendBuffer *bytes.Buffer) {
 	}
 	}
 }
 }
 
 
-// relay sends and receives tunnelled traffic (payload). An HTTP request is
+// relay sends and receives tunneled traffic (payload). An HTTP request is
 // triggered when data is in the write queue or at a polling interval.
 // triggered when data is in the write queue or at a polling interval.
 // There's a geometric increase, up to a maximum, in the polling interval when
 // There's a geometric increase, up to a maximum, in the polling interval when
 // no data is exchanged. Only one HTTP request is in flight at a time.
 // no data is exchanged. Only one HTTP request is in flight at a time.
@@ -321,14 +325,16 @@ func (meek *MeekConn) relay() {
 	// (using goroutines) since Close() will wait on this WaitGroup.
 	// (using goroutines) since Close() will wait on this WaitGroup.
 	defer meek.relayWaitGroup.Done()
 	defer meek.relayWaitGroup.Done()
 	interval := MIN_POLL_INTERVAL
 	interval := MIN_POLL_INTERVAL
-	var sendPayload = make([]byte, MAX_SEND_PAYLOAD_LENGTH)
+	timeout := time.NewTimer(interval)
+	sendPayload := make([]byte, MAX_SEND_PAYLOAD_LENGTH)
 	for {
 	for {
+		timeout.Reset(interval)
 		// Block until there is payload to send or it is time to poll
 		// Block until there is payload to send or it is time to poll
 		var sendBuffer *bytes.Buffer
 		var sendBuffer *bytes.Buffer
 		select {
 		select {
 		case sendBuffer = <-meek.partialSendBuffer:
 		case sendBuffer = <-meek.partialSendBuffer:
 		case sendBuffer = <-meek.fullSendBuffer:
 		case sendBuffer = <-meek.fullSendBuffer:
-		case <-time.After(interval):
+		case <-timeout.C:
 			// In the polling case, send an empty payload
 			// In the polling case, send an empty payload
 		case <-meek.broadcastClosed:
 		case <-meek.broadcastClosed:
 			return
 			return
@@ -458,6 +464,7 @@ type meekCookieData struct {
 // In unfronted meek mode, the cookie is visible over the adversary network, so the
 // In unfronted meek mode, the cookie is visible over the adversary network, so the
 // cookie is encrypted and obfuscated.
 // cookie is encrypted and obfuscated.
 func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie, err error) {
 func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie, err error) {
+
 	// Make the JSON data
 	// Make the JSON data
 	serverAddress := fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
 	serverAddress := fmt.Sprintf("%s:%d", serverEntry.IpAddress, serverEntry.SshObfuscatedPort)
 	cookieData := &meekCookieData{
 	cookieData := &meekCookieData{
@@ -469,6 +476,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	// Encrypt the JSON data
 	// Encrypt the JSON data
 	// NaCl box is used for encryption. The peer public key comes from the server entry.
 	// NaCl box is used for encryption. The peer public key comes from the server entry.
 	// Nonce is always all zeros, and is not sent in the cookie (the server also uses an all-zero nonce).
 	// Nonce is always all zeros, and is not sent in the cookie (the server also uses an all-zero nonce).
@@ -491,6 +499,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	encryptedCookie := make([]byte, 32+len(box))
 	encryptedCookie := make([]byte, 32+len(box))
 	copy(encryptedCookie[0:32], ephemeralPublicKey[0:32])
 	copy(encryptedCookie[0:32], ephemeralPublicKey[0:32])
 	copy(encryptedCookie[32:], box)
 	copy(encryptedCookie[32:], box)
+
 	// Obfuscate the encrypted data
 	// Obfuscate the encrypted data
 	obfuscator, err := NewObfuscator(
 	obfuscator, err := NewObfuscator(
 		&ObfuscatorConfig{Keyword: serverEntry.MeekObfuscatedKey, MaxPadding: MEEK_COOKIE_MAX_PADDING})
 		&ObfuscatorConfig{Keyword: serverEntry.MeekObfuscatedKey, MaxPadding: MEEK_COOKIE_MAX_PADDING})
@@ -501,6 +510,7 @@ func makeCookie(serverEntry *ServerEntry, sessionId string) (cookie *http.Cookie
 	seedLen := len(obfuscatedCookie)
 	seedLen := len(obfuscatedCookie)
 	obfuscatedCookie = append(obfuscatedCookie, encryptedCookie...)
 	obfuscatedCookie = append(obfuscatedCookie, encryptedCookie...)
 	obfuscator.ObfuscateClientToServer(obfuscatedCookie[seedLen:])
 	obfuscator.ObfuscateClientToServer(obfuscatedCookie[seedLen:])
+
 	// Format the HTTP cookie
 	// Format the HTTP cookie
 	// The format is <random letter 'A'-'Z'>=<base64 data>, which is intended to match common cookie formats.
 	// The format is <random letter 'A'-'Z'>=<base64 data>, which is intended to match common cookie formats.
 	A := int('A')
 	A := int('A')

+ 24 - 12
psiphon/obfuscatedSshConn.go

@@ -84,7 +84,7 @@ const (
 func NewObfuscatedSshConn(conn net.Conn, obfuscationKeyword string) (*ObfuscatedSshConn, error) {
 func NewObfuscatedSshConn(conn net.Conn, obfuscationKeyword string) (*ObfuscatedSshConn, error) {
 	obfuscator, err := NewObfuscator(&ObfuscatorConfig{Keyword: obfuscationKeyword})
 	obfuscator, err := NewObfuscator(&ObfuscatorConfig{Keyword: obfuscationKeyword})
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	return &ObfuscatedSshConn{
 	return &ObfuscatedSshConn{
 		Conn:       conn,
 		Conn:       conn,
@@ -111,7 +111,7 @@ func (conn *ObfuscatedSshConn) Write(buffer []byte) (n int, err error) {
 	}
 	}
 	err = conn.transformAndWrite(buffer)
 	err = conn.transformAndWrite(buffer)
 	if err != nil {
 	if err != nil {
-		return 0, err
+		return 0, ContextError(err)
 	}
 	}
 	// Reports that we wrote all the bytes
 	// Reports that we wrote all the bytes
 	// (althogh we may have buffered some or all)
 	// (althogh we may have buffered some or all)
@@ -157,6 +157,7 @@ func (conn *ObfuscatedSshConn) Write(buffer []byte) (n int, err error) {
 // packet may need to be buffered due to partial reading.
 // packet may need to be buffered due to partial reading.
 func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error) {
 func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error) {
 	nextState := conn.readState
 	nextState := conn.readState
+
 	switch conn.readState {
 	switch conn.readState {
 	case OBFUSCATION_READ_STATE_SERVER_IDENTIFICATION_LINE:
 	case OBFUSCATION_READ_STATE_SERVER_IDENTIFICATION_LINE:
 		if len(conn.readBuffer) == 0 {
 		if len(conn.readBuffer) == 0 {
@@ -167,7 +168,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 				for len(conn.readBuffer) < SSH_MAX_SERVER_LINE_LENGTH {
 				for len(conn.readBuffer) < SSH_MAX_SERVER_LINE_LENGTH {
 					_, err := io.ReadFull(conn.Conn, oneByte[:])
 					_, err := io.ReadFull(conn.Conn, oneByte[:])
 					if err != nil {
 					if err != nil {
-						return 0, err
+						return 0, ContextError(err)
 					}
 					}
 					conn.obfuscator.ObfuscateServerToClient(oneByte[:])
 					conn.obfuscator.ObfuscateServerToClient(oneByte[:])
 					conn.readBuffer = append(conn.readBuffer, oneByte[0])
 					conn.readBuffer = append(conn.readBuffer, oneByte[0])
@@ -177,7 +178,7 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 					}
 					}
 				}
 				}
 				if !validLine {
 				if !validLine {
-					return 0, errors.New("ObfuscatedSshConn: invalid server line")
+					return 0, ContextError(errors.New("ObfuscatedSshConn: invalid server line"))
 				}
 				}
 				if bytes.HasPrefix(conn.readBuffer, []byte("SSH-")) {
 				if bytes.HasPrefix(conn.readBuffer, []byte("SSH-")) {
 					break
 					break
@@ -187,23 +188,24 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 			}
 			}
 		}
 		}
 		nextState = OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS
 		nextState = OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS
+
 	case OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS:
 	case OBFUSCATION_READ_STATE_SERVER_KEX_PACKETS:
 		if len(conn.readBuffer) == 0 {
 		if len(conn.readBuffer) == 0 {
 			prefix := make([]byte, SSH_PACKET_PREFIX_LENGTH)
 			prefix := make([]byte, SSH_PACKET_PREFIX_LENGTH)
 			_, err := io.ReadFull(conn.Conn, prefix)
 			_, err := io.ReadFull(conn.Conn, prefix)
 			if err != nil {
 			if err != nil {
-				return 0, err
+				return 0, ContextError(err)
 			}
 			}
 			conn.obfuscator.ObfuscateServerToClient(prefix)
 			conn.obfuscator.ObfuscateServerToClient(prefix)
 			packetLength, _, payloadLength, messageLength := getSshPacketPrefix(prefix)
 			packetLength, _, payloadLength, messageLength := getSshPacketPrefix(prefix)
 			if packetLength > SSH_MAX_PACKET_LENGTH {
 			if packetLength > SSH_MAX_PACKET_LENGTH {
-				return 0, errors.New("ObfuscatedSshConn: ssh packet length too large")
+				return 0, ContextError(errors.New("ObfuscatedSshConn: ssh packet length too large"))
 			}
 			}
 			conn.readBuffer = make([]byte, messageLength)
 			conn.readBuffer = make([]byte, messageLength)
 			copy(conn.readBuffer, prefix)
 			copy(conn.readBuffer, prefix)
 			_, err = io.ReadFull(conn.Conn, conn.readBuffer[len(prefix):])
 			_, err = io.ReadFull(conn.Conn, conn.readBuffer[len(prefix):])
 			if err != nil {
 			if err != nil {
-				return 0, err
+				return 0, ContextError(err)
 			}
 			}
 			conn.obfuscator.ObfuscateServerToClient(conn.readBuffer[len(prefix):])
 			conn.obfuscator.ObfuscateServerToClient(conn.readBuffer[len(prefix):])
 			if payloadLength > 0 {
 			if payloadLength > 0 {
@@ -213,11 +215,14 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 				}
 				}
 			}
 			}
 		}
 		}
+
 	case OBFUSCATION_READ_STATE_FLUSH:
 	case OBFUSCATION_READ_STATE_FLUSH:
 		nextState = OBFUSCATION_READ_STATE_FINISHED
 		nextState = OBFUSCATION_READ_STATE_FINISHED
+
 	case OBFUSCATION_READ_STATE_FINISHED:
 	case OBFUSCATION_READ_STATE_FINISHED:
 		panic("ObfuscatedSshConn: invalid read state")
 		panic("ObfuscatedSshConn: invalid read state")
 	}
 	}
+
 	n = copy(buffer, conn.readBuffer)
 	n = copy(buffer, conn.readBuffer)
 	conn.readBuffer = conn.readBuffer[n:]
 	conn.readBuffer = conn.readBuffer[n:]
 	if len(conn.readBuffer) == 0 {
 	if len(conn.readBuffer) == 0 {
@@ -258,15 +263,18 @@ func (conn *ObfuscatedSshConn) readAndTransform(buffer []byte) (n int, err error
 // (The transformer can do this since only the payload and not the padding of
 // (The transformer can do this since only the payload and not the padding of
 // these packets is authenticated in the "exchange hash").
 // these packets is authenticated in the "exchange hash").
 func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
+
 	if conn.writeState == OBFUSCATION_WRITE_STATE_SEND_CLIENT_SEED_MESSAGE {
 	if conn.writeState == OBFUSCATION_WRITE_STATE_SEND_CLIENT_SEED_MESSAGE {
 		_, err = conn.Conn.Write(conn.obfuscator.ConsumeSeedMessage())
 		_, err = conn.Conn.Write(conn.obfuscator.ConsumeSeedMessage())
 		if err != nil {
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 		}
 		conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE
 		conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE
 	}
 	}
+
 	conn.writeBuffer = append(conn.writeBuffer, buffer...)
 	conn.writeBuffer = append(conn.writeBuffer, buffer...)
 	var messageBuffer []byte
 	var messageBuffer []byte
+
 	switch conn.writeState {
 	switch conn.writeState {
 	case OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE:
 	case OBFUSCATION_WRITE_STATE_CLIENT_IDENTIFICATION_LINE:
 		index := bytes.Index(conn.writeBuffer, []byte("\r\n"))
 		index := bytes.Index(conn.writeBuffer, []byte("\r\n"))
@@ -276,6 +284,7 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 			conn.writeBuffer = conn.writeBuffer[messageLength:]
 			conn.writeBuffer = conn.writeBuffer[messageLength:]
 			conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS
 			conn.writeState = OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS
 		}
 		}
+
 	case OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS:
 	case OBFUSCATION_WRITE_STATE_CLIENT_KEX_PACKETS:
 		for len(conn.writeBuffer) >= SSH_PACKET_PREFIX_LENGTH {
 		for len(conn.writeBuffer) >= SSH_PACKET_PREFIX_LENGTH {
 			packetLength, paddingLength, payloadLength, messageLength := getSshPacketPrefix(conn.writeBuffer)
 			packetLength, paddingLength, payloadLength, messageLength := getSshPacketPrefix(conn.writeBuffer)
@@ -297,33 +306,36 @@ func (conn *ObfuscatedSshConn) transformAndWrite(buffer []byte) (err error) {
 			if possiblePaddings > 0 {
 			if possiblePaddings > 0 {
 				selectedPadding, err := MakeSecureRandomInt(possiblePaddings)
 				selectedPadding, err := MakeSecureRandomInt(possiblePaddings)
 				if err != nil {
 				if err != nil {
-					return err
+					return ContextError(err)
 				}
 				}
 				extraPaddingLength := selectedPadding * SSH_PADDING_MULTIPLE
 				extraPaddingLength := selectedPadding * SSH_PADDING_MULTIPLE
 				extraPadding, err := MakeSecureRandomBytes(extraPaddingLength)
 				extraPadding, err := MakeSecureRandomBytes(extraPaddingLength)
 				if err != nil {
 				if err != nil {
-					return err
+					return ContextError(err)
 				}
 				}
 				setSshPacketPrefix(
 				setSshPacketPrefix(
 					messageBuffer, packetLength+extraPaddingLength, paddingLength+extraPaddingLength)
 					messageBuffer, packetLength+extraPaddingLength, paddingLength+extraPaddingLength)
 				messageBuffer = append(messageBuffer, extraPadding...)
 				messageBuffer = append(messageBuffer, extraPadding...)
 			}
 			}
 		}
 		}
+
 	case OBFUSCATION_WRITE_STATE_FINISHED:
 	case OBFUSCATION_WRITE_STATE_FINISHED:
 		panic("ObfuscatedSshConn: invalid write state")
 		panic("ObfuscatedSshConn: invalid write state")
 	}
 	}
+
 	if messageBuffer != nil {
 	if messageBuffer != nil {
 		conn.obfuscator.ObfuscateClientToServer(messageBuffer)
 		conn.obfuscator.ObfuscateClientToServer(messageBuffer)
 		_, err := conn.Conn.Write(messageBuffer)
 		_, err := conn.Conn.Write(messageBuffer)
 		if err != nil {
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 		}
 	}
 	}
+
 	if conn.writeState == OBFUSCATION_WRITE_STATE_FINISHED {
 	if conn.writeState == OBFUSCATION_WRITE_STATE_FINISHED {
 		// After SSH_MSG_NEWKEYS, any remaining bytes are un-obfuscated
 		// After SSH_MSG_NEWKEYS, any remaining bytes are un-obfuscated
 		_, err := conn.Conn.Write(conn.writeBuffer)
 		_, err := conn.Conn.Write(conn.writeBuffer)
 		if err != nil {
 		if err != nil {
-			return err
+			return ContextError(err)
 		}
 		}
 		// The buffer memory is no longer used
 		// The buffer memory is no longer used
 		conn.writeBuffer = nil
 		conn.writeBuffer = nil

+ 13 - 13
psiphon/obfuscator.go

@@ -57,23 +57,23 @@ type ObfuscatorConfig struct {
 func NewObfuscator(config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 func NewObfuscator(config *ObfuscatorConfig) (obfuscator *Obfuscator, err error) {
 	seed, err := MakeSecureRandomBytes(OBFUSCATE_SEED_LENGTH)
 	seed, err := MakeSecureRandomBytes(OBFUSCATE_SEED_LENGTH)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	clientToServerKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_CLIENT_TO_SERVER_IV))
 	clientToServerKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_CLIENT_TO_SERVER_IV))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	serverToClientKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_SERVER_TO_CLIENT_IV))
 	serverToClientKey, err := deriveKey(seed, []byte(config.Keyword), []byte(OBFUSCATE_SERVER_TO_CLIENT_IV))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	clientToServerCipher, err := rc4.NewCipher(clientToServerKey)
 	clientToServerCipher, err := rc4.NewCipher(clientToServerKey)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	serverToClientCipher, err := rc4.NewCipher(serverToClientKey)
 	serverToClientCipher, err := rc4.NewCipher(serverToClientKey)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	maxPadding := OBFUSCATE_MAX_PADDING
 	maxPadding := OBFUSCATE_MAX_PADDING
 	if config.MaxPadding > 0 {
 	if config.MaxPadding > 0 {
@@ -81,7 +81,7 @@ func NewObfuscator(config *ObfuscatorConfig) (obfuscator *Obfuscator, err error)
 	}
 	}
 	seedMessage, err := makeSeedMessage(maxPadding, seed, clientToServerCipher)
 	seedMessage, err := makeSeedMessage(maxPadding, seed, clientToServerCipher)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	return &Obfuscator{
 	return &Obfuscator{
 		seedMessage:          seedMessage,
 		seedMessage:          seedMessage,
@@ -119,7 +119,7 @@ func deriveKey(seed, keyword, iv []byte) ([]byte, error) {
 		digest = h.Sum(nil)
 		digest = h.Sum(nil)
 	}
 	}
 	if len(digest) < OBFUSCATE_KEY_LENGTH {
 	if len(digest) < OBFUSCATE_KEY_LENGTH {
-		return nil, errors.New("insufficient bytes for obfuscation key")
+		return nil, ContextError(errors.New("insufficient bytes for obfuscation key"))
 	}
 	}
 	return digest[0:OBFUSCATE_KEY_LENGTH], nil
 	return digest[0:OBFUSCATE_KEY_LENGTH], nil
 }
 }
@@ -127,28 +127,28 @@ func deriveKey(seed, keyword, iv []byte) ([]byte, error) {
 func makeSeedMessage(maxPadding int, seed []byte, clientToServerCipher *rc4.Cipher) ([]byte, error) {
 func makeSeedMessage(maxPadding int, seed []byte, clientToServerCipher *rc4.Cipher) ([]byte, error) {
 	paddingLength, err := MakeSecureRandomInt(maxPadding)
 	paddingLength, err := MakeSecureRandomInt(maxPadding)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	padding, err := MakeSecureRandomBytes(paddingLength)
 	padding, err := MakeSecureRandomBytes(paddingLength)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	buffer := new(bytes.Buffer)
 	buffer := new(bytes.Buffer)
 	err = binary.Write(buffer, binary.BigEndian, seed)
 	err = binary.Write(buffer, binary.BigEndian, seed)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(OBFUSCATE_MAGIC_VALUE))
 	err = binary.Write(buffer, binary.BigEndian, uint32(OBFUSCATE_MAGIC_VALUE))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	err = binary.Write(buffer, binary.BigEndian, uint32(paddingLength))
 	err = binary.Write(buffer, binary.BigEndian, uint32(paddingLength))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	err = binary.Write(buffer, binary.BigEndian, padding)
 	err = binary.Write(buffer, binary.BigEndian, padding)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	seedMessage := buffer.Bytes()
 	seedMessage := buffer.Bytes()
 	clientToServerCipher.XORKeyStream(seedMessage[len(seed):], seedMessage[len(seed):])
 	clientToServerCipher.XORKeyStream(seedMessage[len(seed):], seedMessage[len(seed):])

+ 18 - 2
psiphon/remoteServerList.go

@@ -45,20 +45,35 @@ type RemoteServerList struct {
 // config.RemoteServerListUrl; validates its digital signature using the
 // config.RemoteServerListUrl; validates its digital signature using the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // data field into ServerEntry records.
 // data field into ServerEntry records.
-func FetchRemoteServerList(config *Config) (err error) {
+func FetchRemoteServerList(config *Config, pendingConns *Conns) (err error) {
 	Notice(NOTICE_INFO, "fetching remote server list")
 	Notice(NOTICE_INFO, "fetching remote server list")
+
+	// Note: pendingConns may be used to interrupt the fetch remote server list
+	// request. BindToDevice may be used to exclude requests from VPN routing.
+	dialConfig := &DialConfig{
+		PendingConns:               pendingConns,
+		BindToDeviceServiceAddress: config.BindToDeviceServiceAddress,
+		BindToDeviceDnsServer:      config.BindToDeviceDnsServer,
+	}
+	transport := &http.Transport{
+		Dial: NewTCPDialer(dialConfig),
+	}
 	httpClient := http.Client{
 	httpClient := http.Client{
-		Timeout: FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Timeout:   FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Transport: transport,
 	}
 	}
+
 	response, err := httpClient.Get(config.RemoteServerListUrl)
 	response, err := httpClient.Get(config.RemoteServerListUrl)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
 	defer response.Body.Close()
 	defer response.Body.Close()
+
 	body, err := ioutil.ReadAll(response.Body)
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
+
 	var remoteServerList *RemoteServerList
 	var remoteServerList *RemoteServerList
 	err = json.Unmarshal(body, &remoteServerList)
 	err = json.Unmarshal(body, &remoteServerList)
 	if err != nil {
 	if err != nil {
@@ -68,6 +83,7 @@ func FetchRemoteServerList(config *Config) (err error) {
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
+
 	for _, encodedServerEntry := range strings.Split(remoteServerList.Data, "\n") {
 	for _, encodedServerEntry := range strings.Split(remoteServerList.Data, "\n") {
 		serverEntry, err := DecodeServerEntry(encodedServerEntry)
 		serverEntry, err := DecodeServerEntry(encodedServerEntry)
 		if err != nil {
 		if err != nil {

+ 0 - 227
psiphon/runTunnel.go

@@ -1,227 +0,0 @@
-/*
- * Copyright (c) 2014, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-// Package psiphon implements the core tunnel functionality of a Psiphon client.
-// The main interface is RunTunnelForever, which obtains lists of servers,
-// establishes tunnel connections, and runs local proxies through which
-// tunnelled traffic may be sent.
-package psiphon
-
-import (
-	"errors"
-	"fmt"
-	"log"
-	"os"
-	"sync"
-	"time"
-)
-
-// establishTunnelWorker pulls candidates from the potential tunnel queue, establishes
-// a connection to the tunnel server, and delivers the established tunnel to a channel,
-// if there's not already an established tunnel. This function is to be used in a pool
-// of goroutines.
-func establishTunnelWorker(
-	config *Config,
-	sessionId string,
-	workerWaitGroup *sync.WaitGroup,
-	candidateServerEntries chan *ServerEntry,
-	broadcastStopWorkers chan struct{},
-	pendingConns *Conns,
-	establishedTunnels chan *Tunnel) {
-
-	defer workerWaitGroup.Done()
-	for serverEntry := range candidateServerEntries {
-		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
-		// select, since we want to prioritize receiving the stop signal
-		select {
-		case <-broadcastStopWorkers:
-			return
-		default:
-		}
-		tunnel, err := EstablishTunnel(config, sessionId, serverEntry, pendingConns)
-		if err != nil {
-			// TODO: distingush case where conn is interrupted?
-			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
-		} else {
-			Notice(NOTICE_INFO, "successfully connected to %s", serverEntry.IpAddress)
-			select {
-			case establishedTunnels <- tunnel:
-			default:
-				discardTunnel(tunnel)
-			}
-		}
-	}
-}
-
-// discardTunnel is used to dispose of a successful connection that is
-// no longer required (another tunnel has already been selected). Since
-// the connection was successful, the server entry is still promoted.
-func discardTunnel(tunnel *Tunnel) {
-	Notice(NOTICE_INFO, "discard connection to %s", tunnel.serverEntry.IpAddress)
-	PromoteServerEntry(tunnel.serverEntry.IpAddress)
-	tunnel.Close()
-}
-
-// establishTunnel coordinates a worker pool of goroutines to attempt several
-// tunnel connections in parallel, and this process is stopped once the first
-// tunnel is established.
-func establishTunnel(config *Config, sessionId string) (tunnel *Tunnel, err error) {
-	workerWaitGroup := new(sync.WaitGroup)
-	candidateServerEntries := make(chan *ServerEntry)
-	pendingConns := new(Conns)
-	establishedTunnels := make(chan *Tunnel, 1)
-	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
-	broadcastStopWorkers := make(chan struct{})
-	for i := 0; i < config.ConnectionWorkerPoolSize; i++ {
-		workerWaitGroup.Add(1)
-		go establishTunnelWorker(
-			config, sessionId,
-			workerWaitGroup, candidateServerEntries, broadcastStopWorkers,
-			pendingConns, establishedTunnels)
-	}
-	// TODO: add a throttle after each full cycle?
-	// Note: errors fall through to ensure worker and channel cleanup (is started, at least)
-	var selectedTunnel *Tunnel
-	cycler, err := NewServerEntryCycler(config.EgressRegion)
-	for selectedTunnel == nil && err == nil {
-		var serverEntry *ServerEntry
-		// Note: don't mask err here, we want to reference it after the loop
-		serverEntry, err = cycler.Next()
-		if err != nil {
-			break
-		}
-		select {
-		case candidateServerEntries <- serverEntry:
-		case selectedTunnel = <-establishedTunnels:
-			Notice(NOTICE_INFO, "selected connection to %s", selectedTunnel.serverEntry.IpAddress)
-		case <-timeout:
-			err = errors.New("timeout establishing tunnel")
-		}
-	}
-	if cycler != nil {
-		cycler.Close()
-	}
-	close(candidateServerEntries)
-	close(broadcastStopWorkers)
-	// Clean up is now asynchronous since Windows doesn't support interruptible connections
-	go func() {
-		// Interrupt any partial connections in progress, so that
-		// the worker will terminate immediately
-		pendingConns.CloseAll()
-		workerWaitGroup.Wait()
-		// Drain any excess tunnels
-		close(establishedTunnels)
-		for tunnel := range establishedTunnels {
-			discardTunnel(tunnel)
-		}
-		// Note: only call this PromoteServerEntry after all discards so the selected
-		// tunnel is the top ranked
-		if selectedTunnel != nil {
-			PromoteServerEntry(selectedTunnel.serverEntry.IpAddress)
-		}
-	}()
-	// Note: end of error fall through
-	if err != nil {
-		return nil, ContextError(err)
-	}
-	return selectedTunnel, nil
-}
-
-// runTunnel establishes a tunnel session and runs local proxies that make use of
-// that tunnel. The tunnel connection is monitored and this function returns an
-// error when the tunnel unexpectedly disconnects.
-func runTunnel(config *Config) error {
-	Notice(NOTICE_INFO, "establishing tunnel")
-	sessionId, err := MakeSessionId()
-	if err != nil {
-		return ContextError(err)
-	}
-	tunnel, err := establishTunnel(config, sessionId)
-	if err != nil {
-		return ContextError(err)
-	}
-	defer tunnel.Close()
-	// Tunnel connection and local proxies will send signals to this channel
-	// when they close or stop. Signal senders should not block. Allows at
-	// least one stop signal to be sent before there is a receiver.
-	stopTunnelSignal := make(chan struct{}, 1)
-	err = tunnel.conn.SetClosedSignal(stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("failed to set closed signal: %s", err)
-	}
-	socksProxy, err := NewSocksProxy(config.LocalSocksProxyPort, tunnel, stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
-	}
-	defer socksProxy.Close()
-	httpProxy, err := NewHttpProxy(config.LocalHttpProxyPort, tunnel, stopTunnelSignal)
-	if err != nil {
-		return fmt.Errorf("error initializing local HTTP proxy: %s", err)
-	}
-	defer httpProxy.Close()
-	Notice(NOTICE_INFO, "starting session")
-	localHttpProxyAddress := httpProxy.listener.Addr().String()
-	_, err = NewSession(config, tunnel, localHttpProxyAddress, sessionId)
-	if err != nil {
-		return fmt.Errorf("error starting session: %s", err)
-	}
-	Notice(NOTICE_TUNNEL, "tunnel started")
-	Notice(NOTICE_INFO, "monitoring tunnel")
-	<-stopTunnelSignal
-	Notice(NOTICE_TUNNEL, "tunnel stopped")
-	return nil
-}
-
-// RunTunnelForever executes the main loop of the Psiphon client. It establishes
-// a tunnel and reconnects when the tunnel unexpectedly disconnects.
-// FetchRemoteServerList is used to obtain a fresh list of servers to attempt
-// to connect to.
-func RunTunnelForever(config *Config) {
-	if config.LogFilename != "" {
-		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
-		if err != nil {
-			Fatal("error opening log file: %s", err)
-		}
-		defer logFile.Close()
-		log.SetOutput(logFile)
-	}
-	Notice(NOTICE_VERSION, VERSION)
-	// TODO: unlike existing Psiphon clients, this code
-	// always makes the fetch remote server list request
-	go func() {
-		for {
-			err := FetchRemoteServerList(config)
-			if err != nil {
-				Notice(NOTICE_ALERT, "failed to fetch remote server list: %s", err)
-				time.Sleep(FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT)
-			} else {
-				time.Sleep(FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT)
-			}
-		}
-	}()
-	for {
-		if HasServerEntries(config.EgressRegion) {
-			err := runTunnel(config)
-			if err != nil {
-				Notice(NOTICE_ALERT, "run tunnel error: %s", err)
-			}
-		}
-		time.Sleep(1 * time.Second)
-	}
-}

+ 27 - 39
psiphon/serverApi.go

@@ -25,6 +25,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"net/http"
 	"strconv"
 	"strconv"
 )
 )
@@ -32,12 +33,10 @@ import (
 // Session is a utility struct which holds all of the data associated
 // Session is a utility struct which holds all of the data associated
 // with a Psiphon session. In addition to the established tunnel, this
 // with a Psiphon session. In addition to the established tunnel, this
 // includes the session ID (used for Psiphon API requests) and a http
 // includes the session ID (used for Psiphon API requests) and a http
-// client configured to make tunnelled Psiphon API requests.
+// client configured to make tunneled Psiphon API requests.
 type Session struct {
 type Session struct {
-	sessionId          string
 	config             *Config
 	config             *Config
 	tunnel             *Tunnel
 	tunnel             *Tunnel
-	pendingConns       *Conns
 	psiphonHttpsClient *http.Client
 	psiphonHttpsClient *http.Client
 }
 }
 
 
@@ -45,21 +44,15 @@ type Session struct {
 // Psiphon server and returns a Session struct, initialized with the
 // Psiphon server and returns a Session struct, initialized with the
 // session ID, for use with subsequent Psiphon server API requests (e.g.,
 // session ID, for use with subsequent Psiphon server API requests (e.g.,
 // periodic status requests).
 // periodic status requests).
-func NewSession(
-	config *Config,
-	tunnel *Tunnel,
-	localHttpProxyAddress, sessionId string) (session *Session, err error) {
+func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 
 
-	pendingConns := new(Conns)
-	psiphonHttpsClient, err := makePsiphonHttpsClient(tunnel, pendingConns, localHttpProxyAddress)
+	psiphonHttpsClient, err := makePsiphonHttpsClient(tunnel)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	session = &Session{
 	session = &Session{
-		sessionId:          sessionId,
 		config:             config,
 		config:             config,
 		tunnel:             tunnel,
 		tunnel:             tunnel,
-		pendingConns:       pendingConns,
 		psiphonHttpsClient: psiphonHttpsClient,
 		psiphonHttpsClient: psiphonHttpsClient,
 	}
 	}
 	// Sending two seperate requests is a legacy from when the handshake was
 	// Sending two seperate requests is a legacy from when the handshake was
@@ -74,6 +67,7 @@ func NewSession(
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	return session, nil
 	return session, nil
 }
 }
 
 
@@ -148,12 +142,15 @@ func (session *Session) doHandshakeRequest() error {
 	if upgradeClientVersion > session.config.ClientVersion {
 	if upgradeClientVersion > session.config.ClientVersion {
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 	}
 	}
-	for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
-		Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
-	}
-	for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
-		Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
-	}
+	// TODO: remove regex notices -- regexes will be used internally
+	/*
+		for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
+			Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
+		}
+		for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
+			Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
+		}
+	*/
 	return nil
 	return nil
 }
 }
 
 
@@ -174,7 +171,7 @@ func (session *Session) doConnectedRequest() error {
 	}
 	}
 	url := session.buildRequestUrl(
 	url := session.buildRequestUrl(
 		"connected",
 		"connected",
-		&ExtraParam{"session_id", session.sessionId},
+		&ExtraParam{"session_id", session.tunnel.sessionId},
 		&ExtraParam{"last_connected", lastConnected})
 		&ExtraParam{"last_connected", lastConnected})
 	responseBody, err := session.doGetRequest(url)
 	responseBody, err := session.doGetRequest(url)
 	if err != nil {
 	if err != nil {
@@ -210,7 +207,7 @@ func (session *Session) buildRequestUrl(path string, extraParams ...*ExtraParam)
 	requestUrl.WriteString("/")
 	requestUrl.WriteString("/")
 	requestUrl.WriteString(path)
 	requestUrl.WriteString(path)
 	requestUrl.WriteString("?client_session_id=")
 	requestUrl.WriteString("?client_session_id=")
-	requestUrl.WriteString(session.sessionId)
+	requestUrl.WriteString(session.tunnel.sessionId)
 	requestUrl.WriteString("&server_secret=")
 	requestUrl.WriteString("&server_secret=")
 	requestUrl.WriteString(session.tunnel.serverEntry.WebServerSecret)
 	requestUrl.WriteString(session.tunnel.serverEntry.WebServerSecret)
 	requestUrl.WriteString("&propagation_channel_id=")
 	requestUrl.WriteString("&propagation_channel_id=")
@@ -253,36 +250,24 @@ func (session *Session) doGetRequest(requestUrl string) (responseBody []byte, er
 	return body, nil
 	return body, nil
 }
 }
 
 
-// makeHttpsClient creates a Psiphon HTTPS client that uses the local http proxy to tunnel
-// requests and which validates the web server using the Psiphon server entry web server certificate.
+// makeHttpsClient creates a Psiphon HTTPS client that tunnels requests and which validates
+// the web server using the Psiphon server entry web server certificate.
 // This is not a general purpose HTTPS client.
 // This is not a general purpose HTTPS client.
 // As the custom dialer makes an explicit TLS connection, URLs submitted to the returned
 // As the custom dialer makes an explicit TLS connection, URLs submitted to the returned
 // http.Client should use the "http://" scheme. Otherwise http.Transport will try to do another TLS
 // http.Client should use the "http://" scheme. Otherwise http.Transport will try to do another TLS
 // handshake inside the explicit TLS session.
 // handshake inside the explicit TLS session.
-func makePsiphonHttpsClient(
-	tunnel *Tunnel, pendingConns *Conns,
-	localHttpProxyAddress string) (httpsClient *http.Client, err error) {
-
+func makePsiphonHttpsClient(tunnel *Tunnel) (httpsClient *http.Client, err error) {
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	certificate, err := DecodeCertificate(tunnel.serverEntry.WebServerCertificate)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
-	// Note: This use of readTimeout will tear down persistent HTTP connections, which is not the
-	// intended purpose. The readTimeout is to abort NewSession when the Psiphon server responds to
-	// handshake/connected requests but fails to deliver the response body (e.g., ResponseHeaderTimeout
-	// is not sufficient to timeout this case).
-	tcpDialer := NewTCPDialer(
-		&DialConfig{
-			ConnectTimeout: PSIPHON_API_SERVER_TIMEOUT,
-			ReadTimeout:    PSIPHON_API_SERVER_TIMEOUT,
-			WriteTimeout:   PSIPHON_API_SERVER_TIMEOUT,
-			PendingConns:   pendingConns,
-		})
+	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
+		return tunnel.sshClient.Dial("tcp", addr)
+	}
 	dialer := NewCustomTLSDialer(
 	dialer := NewCustomTLSDialer(
 		&CustomTLSConfig{
 		&CustomTLSConfig{
-			Dial:                    tcpDialer,
+			Dial:                    tunneledDialer,
 			Timeout:                 PSIPHON_API_SERVER_TIMEOUT,
 			Timeout:                 PSIPHON_API_SERVER_TIMEOUT,
-			HttpProxyAddress:        localHttpProxyAddress,
 			SendServerName:          false,
 			SendServerName:          false,
 			VerifyLegacyCertificate: certificate,
 			VerifyLegacyCertificate: certificate,
 		})
 		})
@@ -290,5 +275,8 @@ func makePsiphonHttpsClient(
 		Dial: dialer,
 		Dial: dialer,
 		ResponseHeaderTimeout: PSIPHON_API_SERVER_TIMEOUT,
 		ResponseHeaderTimeout: PSIPHON_API_SERVER_TIMEOUT,
 	}
 	}
-	return &http.Client{Transport: transport}, nil
+	return &http.Client{
+		Transport: transport,
+		Timeout:   PSIPHON_API_SERVER_TIMEOUT,
+	}, nil
 }
 }

+ 27 - 52
psiphon/socksProxy.go

@@ -22,7 +22,6 @@ package psiphon
 import (
 import (
 	"fmt"
 	"fmt"
 	socks "github.com/Psiphon-Inc/goptlib"
 	socks "github.com/Psiphon-Inc/goptlib"
-	"io"
 	"net"
 	"net"
 	"sync"
 	"sync"
 )
 )
@@ -32,30 +31,29 @@ import (
 // the tunnel SSH client and relays traffic through the port
 // the tunnel SSH client and relays traffic through the port
 // forward.
 // forward.
 type SocksProxy struct {
 type SocksProxy struct {
-	tunnel        *Tunnel
-	stoppedSignal chan struct{}
-	listener      *socks.SocksListener
-	waitGroup     *sync.WaitGroup
-	openConns     *Conns
+	tunneler       Tunneler
+	listener       *socks.SocksListener
+	serveWaitGroup *sync.WaitGroup
+	openConns      *Conns
 }
 }
 
 
 // NewSocksProxy initializes a new SOCKS server. It begins listening for
 // NewSocksProxy initializes a new SOCKS server. It begins listening for
 // connections, starts a goroutine that runs an accept loop, and returns
 // connections, starts a goroutine that runs an accept loop, and returns
 // leaving the accept loop running.
 // leaving the accept loop running.
-func NewSocksProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{}) (proxy *SocksProxy, err error) {
-	listener, err := socks.ListenSocks("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
+func NewSocksProxy(config *Config, tunneler Tunneler) (proxy *SocksProxy, err error) {
+	listener, err := socks.ListenSocks(
+		"tcp", fmt.Sprintf("127.0.0.1:%d", config.LocalSocksProxyPort))
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	proxy = &SocksProxy{
 	proxy = &SocksProxy{
-		tunnel:        tunnel,
-		stoppedSignal: stoppedSignal,
-		listener:      listener,
-		waitGroup:     new(sync.WaitGroup),
-		openConns:     new(Conns),
+		tunneler:       tunneler,
+		listener:       listener,
+		serveWaitGroup: new(sync.WaitGroup),
+		openConns:      new(Conns),
 	}
 	}
-	proxy.waitGroup.Add(1)
-	go proxy.acceptSocksConnections()
+	proxy.serveWaitGroup.Add(1)
+	go proxy.serve()
 	Notice(NOTICE_SOCKS_PROXY, "local SOCKS proxy running at address %s", proxy.listener.Addr().String())
 	Notice(NOTICE_SOCKS_PROXY, "local SOCKS proxy running at address %s", proxy.listener.Addr().String())
 	return proxy, nil
 	return proxy, nil
 }
 }
@@ -64,60 +62,37 @@ func NewSocksProxy(listenPort int, tunnel *Tunnel, stoppedSignal chan struct{})
 // goroutine to complete.
 // goroutine to complete.
 func (proxy *SocksProxy) Close() {
 func (proxy *SocksProxy) Close() {
 	proxy.listener.Close()
 	proxy.listener.Close()
-	proxy.waitGroup.Wait()
+	proxy.serveWaitGroup.Wait()
 	proxy.openConns.CloseAll()
 	proxy.openConns.CloseAll()
 }
 }
 
 
-func (proxy *SocksProxy) socksConnectionHandler(tunnel *Tunnel, localSocksConn *socks.SocksConn) (err error) {
-	defer localSocksConn.Close()
-	defer proxy.openConns.Remove(localSocksConn)
-	proxy.openConns.Add(localSocksConn)
-	remoteSshForward, err := tunnel.sshClient.Dial("tcp", localSocksConn.Req.Target)
+func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err error) {
+	defer localConn.Close()
+	defer proxy.openConns.Remove(localConn)
+	proxy.openConns.Add(localConn)
+	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
-	defer remoteSshForward.Close()
-	err = localSocksConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
+	defer remoteConn.Close()
+	err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
-	relayPortForward(localSocksConn, remoteSshForward)
+	Relay(localConn, remoteConn)
 	return nil
 	return nil
 }
 }
 
 
-// relayPortForward is also used by HttpProxy
-func relayPortForward(local, remote net.Conn) {
-	// TODO: page view stats would be done here
-	// TODO: interrupt and stop on proxy.Close()
-	waitGroup := new(sync.WaitGroup)
-	waitGroup.Add(1)
-	go func() {
-		defer waitGroup.Done()
-		_, err := io.Copy(local, remote)
-		if err != nil {
-			Notice(NOTICE_ALERT, "%s", ContextError(err))
-		}
-	}()
-	_, err := io.Copy(remote, local)
-	if err != nil {
-		Notice(NOTICE_ALERT, "%s", ContextError(err))
-	}
-	waitGroup.Wait()
-}
-
-func (proxy *SocksProxy) acceptSocksConnections() {
+func (proxy *SocksProxy) serve() {
 	defer proxy.listener.Close()
 	defer proxy.listener.Close()
-	defer proxy.waitGroup.Done()
+	defer proxy.serveWaitGroup.Done()
 	for {
 	for {
 		// Note: will be interrupted by listener.Close() call made by proxy.Close()
 		// Note: will be interrupted by listener.Close() call made by proxy.Close()
 		socksConnection, err := proxy.listener.AcceptSocks()
 		socksConnection, err := proxy.listener.AcceptSocks()
 		if err != nil {
 		if err != nil {
 			Notice(NOTICE_ALERT, "SOCKS proxy accept error: %s", err)
 			Notice(NOTICE_ALERT, "SOCKS proxy accept error: %s", err)
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
-				select {
-				case proxy.stoppedSignal <- *new(struct{}):
-				default:
-				}
+				proxy.tunneler.SignalFailure()
 				// Fatal error, stop the proxy
 				// Fatal error, stop the proxy
 				break
 				break
 			}
 			}
@@ -125,7 +100,7 @@ func (proxy *SocksProxy) acceptSocksConnections() {
 			continue
 			continue
 		}
 		}
 		go func() {
 		go func() {
-			err := proxy.socksConnectionHandler(proxy.tunnel, socksConnection)
+			err := proxy.socksConnectionHandler(socksConnection)
 			if err != nil {
 			if err != nil {
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 				Notice(NOTICE_ALERT, "%s", ContextError(err))
 			}
 			}

+ 25 - 53
psiphon/tlsDialer.go

@@ -75,10 +75,7 @@ import (
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
 	"errors"
 	"errors"
-	"fmt"
-	"io"
 	"net"
 	"net"
-	"strings"
 	"time"
 	"time"
 )
 )
 
 
@@ -91,25 +88,28 @@ func (timeoutError) Temporary() bool { return true }
 // CustomTLSConfig contains parameters to determine the behavior
 // CustomTLSConfig contains parameters to determine the behavior
 // of CustomTLSDial.
 // of CustomTLSDial.
 type CustomTLSConfig struct {
 type CustomTLSConfig struct {
+
 	// Dial is the network connection dialer. TLS is layered on
 	// Dial is the network connection dialer. TLS is layered on
 	// top of a new network connection created with dialer.
 	// top of a new network connection created with dialer.
 	Dial Dialer
 	Dial Dialer
+
 	// Timeout is and optional timeout for combined network
 	// Timeout is and optional timeout for combined network
 	// connection dial and TLS handshake.
 	// connection dial and TLS handshake.
 	Timeout time.Duration
 	Timeout time.Duration
+
 	// FrontingAddr overrides the "addr" input to Dial when specified
 	// FrontingAddr overrides the "addr" input to Dial when specified
 	FrontingAddr string
 	FrontingAddr string
-	// HttpProxyAddress specifies an HTTP proxy to be used
-	// (with HTTP CONNECT).
-	HttpProxyAddress string
+
 	// SendServerName specifies whether to use SNI
 	// SendServerName specifies whether to use SNI
 	// (tlsdialer functionality)
 	// (tlsdialer functionality)
 	SendServerName bool
 	SendServerName bool
+
 	// VerifyLegacyCertificate is a special case self-signed server
 	// VerifyLegacyCertificate is a special case self-signed server
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate chain. Just checks that the server presented the
 	// certificate chain. Just checks that the server presented the
 	// specified certificate.
 	// specified certificate.
 	VerifyLegacyCertificate *x509.Certificate
 	VerifyLegacyCertificate *x509.Certificate
+
 	// TlsConfig is a tls.Config to use in the
 	// TlsConfig is a tls.Config to use in the
 	// non-verifyLegacyCertificate case.
 	// non-verifyLegacyCertificate case.
 	TlsConfig *tls.Config
 	TlsConfig *tls.Config
@@ -141,45 +141,36 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 	}
 	}
 
 
 	dialAddr := addr
 	dialAddr := addr
-	if config.HttpProxyAddress != "" {
-		dialAddr = config.HttpProxyAddress
-	} else if config.FrontingAddr != "" {
+	if config.FrontingAddr != "" {
 		dialAddr = config.FrontingAddr
 		dialAddr = config.FrontingAddr
 	}
 	}
 
 
 	rawConn, err := config.Dial(network, dialAddr)
 	rawConn, err := config.Dial(network, dialAddr)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 
 
-	targetAddr := addr
-	if config.FrontingAddr != "" {
-		targetAddr = config.FrontingAddr
-	}
-
-	colonPos := strings.LastIndex(targetAddr, ":")
-	if colonPos == -1 {
-		colonPos = len(targetAddr)
+	hostname, _, err := net.SplitHostPort(dialAddr)
+	if err != nil {
+		return nil, ContextError(err)
 	}
 	}
-	hostname := targetAddr[:colonPos]
 
 
 	tlsConfig := config.TlsConfig
 	tlsConfig := config.TlsConfig
 	if tlsConfig == nil {
 	if tlsConfig == nil {
 		tlsConfig = &tls.Config{}
 		tlsConfig = &tls.Config{}
 	}
 	}
 
 
-	serverName := tlsConfig.ServerName
+	// Copy config so we can tweak it
+	tlsConfigCopy := new(tls.Config)
+	*tlsConfigCopy = *tlsConfig
 
 
+	serverName := tlsConfig.ServerName
 	// If no ServerName is set, infer the ServerName
 	// If no ServerName is set, infer the ServerName
 	// from the hostname we're connecting to.
 	// from the hostname we're connecting to.
 	if serverName == "" {
 	if serverName == "" {
 		serverName = hostname
 		serverName = hostname
 	}
 	}
 
 
-	// Copy config so we can tweak it
-	tlsConfigCopy := new(tls.Config)
-	*tlsConfigCopy = *tlsConfig
-
 	if config.SendServerName {
 	if config.SendServerName {
 		// Set the ServerName and rely on the usual logic in
 		// Set the ServerName and rely on the usual logic in
 		// tls.Conn.Handshake() to do its verification
 		// tls.Conn.Handshake() to do its verification
@@ -192,34 +183,11 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 
 
 	conn := tls.Client(rawConn, tlsConfigCopy)
 	conn := tls.Client(rawConn, tlsConfigCopy)
 
 
-	establishConnection := func(rawConn net.Conn, conn *tls.Conn) error {
-		// TODO: use the proxy request/response code from net/http/transport.go
-		if config.HttpProxyAddress != "" {
-			connectRequest := fmt.Sprintf(
-				"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
-				targetAddr, hostname)
-			_, err := rawConn.Write([]byte(connectRequest))
-			if err != nil {
-				return err
-			}
-			expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
-			readBuffer := make([]byte, len(expectedResponse))
-			_, err = io.ReadFull(rawConn, readBuffer)
-			if err != nil {
-				return err
-			}
-			if !bytes.Equal(readBuffer, expectedResponse) {
-				return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
-			}
-		}
-		return conn.Handshake()
-	}
-
 	if config.Timeout == 0 {
 	if config.Timeout == 0 {
-		err = establishConnection(rawConn, conn)
+		err = conn.Handshake()
 	} else {
 	} else {
 		go func() {
 		go func() {
-			errChannel <- establishConnection(rawConn, conn)
+			errChannel <- conn.Handshake()
 		}()
 		}()
 		err = <-errChannel
 		err = <-errChannel
 	}
 	}
@@ -233,7 +201,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 
 
 	if err != nil {
 	if err != nil {
 		rawConn.Close()
 		rawConn.Close()
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 
 
 	return conn, nil
 	return conn, nil
@@ -242,10 +210,10 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 func verifyLegacyCertificate(conn *tls.Conn, expectedCertificate *x509.Certificate) error {
 func verifyLegacyCertificate(conn *tls.Conn, expectedCertificate *x509.Certificate) error {
 	certs := conn.ConnectionState().PeerCertificates
 	certs := conn.ConnectionState().PeerCertificates
 	if len(certs) < 1 {
 	if len(certs) < 1 {
-		return errors.New("no certificate to verify")
+		return ContextError(errors.New("no certificate to verify"))
 	}
 	}
 	if !bytes.Equal(certs[0].Raw, expectedCertificate.Raw) {
 	if !bytes.Equal(certs[0].Raw, expectedCertificate.Raw) {
-		return errors.New("unexpected certificate")
+		return ContextError(errors.New("unexpected certificate"))
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -266,6 +234,10 @@ func verifyServerCerts(conn *tls.Conn, serverName string, config *tls.Config) er
 		}
 		}
 		opts.Intermediates.AddCert(cert)
 		opts.Intermediates.AddCert(cert)
 	}
 	}
+
 	_, err := certs[0].Verify(opts)
 	_, err := certs[0].Verify(opts)
-	return err
+	if err != nil {
+		return ContextError(err)
+	}
+	return nil
 }
 }

+ 81 - 21
psiphon/tunnel.go

@@ -23,13 +23,24 @@ import (
 	"bytes"
 	"bytes"
 	"code.google.com/p/go.crypto/ssh"
 	"code.google.com/p/go.crypto/ssh"
 	"encoding/base64"
 	"encoding/base64"
+	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"time"
 	"time"
 )
 )
 
 
+// Tunneler specifies the interface required by components that use a tunnel.
+// Components which use this interface may be serviced by a single Tunnel instance,
+// or a Controller which manages a pool of tunnels, or any other object which
+// implements Tunneler.
+type Tunneler interface {
+	Dial(remoteAddr string) (conn net.Conn, err error)
+	SignalFailure()
+}
+
 const (
 const (
 	TUNNEL_PROTOCOL_SSH            = "SSH"
 	TUNNEL_PROTOCOL_SSH            = "SSH"
 	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
 	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
@@ -49,21 +60,15 @@ var SupportedTunnelProtocols = []string{
 // tunnel includes a network connection to the specified server
 // tunnel includes a network connection to the specified server
 // and an SSH session built on top of that transport.
 // and an SSH session built on top of that transport.
 type Tunnel struct {
 type Tunnel struct {
-	serverEntry      *ServerEntry
-	protocol         string
-	conn             Conn
-	sshClient        *ssh.Client
-	sshKeepAliveQuit chan struct{}
-}
-
-// Close terminates the tunnel.
-func (tunnel *Tunnel) Close() {
-	if tunnel.sshKeepAliveQuit != nil {
-		close(tunnel.sshKeepAliveQuit)
-	}
-	if tunnel.conn != nil {
-		tunnel.conn.Close()
-	}
+	serverEntry             *ServerEntry
+	sessionId               string
+	sessionStarted          int32
+	protocol                string
+	conn                    Conn
+	sshClient               *ssh.Client
+	sshKeepAliveQuit        chan struct{}
+	portForwardFailures     chan int
+	portForwardFailureTotal int
 }
 }
 
 
 // EstablishTunnel first makes a network transport connection to the
 // EstablishTunnel first makes a network transport connection to the
@@ -77,10 +82,8 @@ func (tunnel *Tunnel) Close() {
 // the first protocol in SupportedTunnelProtocols that's also in the
 // the first protocol in SupportedTunnelProtocols that's also in the
 // server capabilities is used.
 // server capabilities is used.
 func EstablishTunnel(
 func EstablishTunnel(
-	config *Config,
-	sessionId string,
-	serverEntry *ServerEntry,
-	pendingConns *Conns) (tunnel *Tunnel, err error) {
+	config *Config, pendingConns *Conns, serverEntry *ServerEntry) (tunnel *Tunnel, err error) {
+
 	// Select the protocol
 	// Select the protocol
 	var selectedProtocol string
 	var selectedProtocol string
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
@@ -106,6 +109,7 @@ func EstablishTunnel(
 	}
 	}
 	Notice(NOTICE_INFO, "connecting to %s in region %s using %s",
 	Notice(NOTICE_INFO, "connecting to %s in region %s using %s",
 		serverEntry.IpAddress, serverEntry.Region, selectedProtocol)
 		serverEntry.IpAddress, serverEntry.Region, selectedProtocol)
+
 	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
 	// The meek protocols tunnel obfuscated SSH. Obfuscated SSH is layered on top of SSH.
 	// So depending on which protocol is used, multiple layers are initialized.
 	// So depending on which protocol is used, multiple layers are initialized.
 	port := 0
 	port := 0
@@ -127,6 +131,15 @@ func EstablishTunnel(
 	case TUNNEL_PROTOCOL_SSH:
 	case TUNNEL_PROTOCOL_SSH:
 		port = serverEntry.SshPort
 		port = serverEntry.SshPort
 	}
 	}
+
+	// Generate a session Id for the Psiphon server API. This is generated now so
+	// that it can be sent with the SSH password payload, which helps the server
+	// associate client geo location, used in server API stats, with the session ID.
+	sessionId, err := MakeSessionId()
+	if err != nil {
+		return nil, ContextError(err)
+	}
+
 	// Create the base transport: meek or direct connection
 	// Create the base transport: meek or direct connection
 	dialConfig := &DialConfig{
 	dialConfig := &DialConfig{
 		ConnectTimeout:             TUNNEL_CONNECT_TIMEOUT,
 		ConnectTimeout:             TUNNEL_CONNECT_TIMEOUT,
@@ -157,6 +170,7 @@ func EstablishTunnel(
 			conn.Close()
 			conn.Close()
 		}
 		}
 	}()
 	}()
+
 	// Add obfuscated SSH layer
 	// Add obfuscated SSH layer
 	var sshConn net.Conn
 	var sshConn net.Conn
 	sshConn = conn
 	sshConn = conn
@@ -166,6 +180,7 @@ func EstablishTunnel(
 			return nil, ContextError(err)
 			return nil, ContextError(err)
 		}
 		}
 	}
 	}
+
 	// Now establish the SSH session over the sshConn transport
 	// Now establish the SSH session over the sshConn transport
 	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
 	if err != nil {
@@ -179,10 +194,18 @@ func EstablishTunnel(
 			return nil
 			return nil
 		},
 		},
 	}
 	}
+	sshPasswordPayload, err := json.Marshal(
+		struct {
+			SessionId   string `json:"SessionId"`
+			SshPassword string `json:"SshPassword"`
+		}{sessionId, serverEntry.SshPassword})
+	if err != nil {
+		return nil, ContextError(err)
+	}
 	sshClientConfig := &ssh.ClientConfig{
 	sshClientConfig := &ssh.ClientConfig{
 		User: serverEntry.SshUsername,
 		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
 		Auth: []ssh.AuthMethod{
-			ssh.Password(serverEntry.SshPassword),
+			ssh.Password(string(sshPasswordPayload)),
 		},
 		},
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 	}
 	}
@@ -194,6 +217,7 @@ func EstablishTunnel(
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	sshClient := ssh.NewClient(sshClientConn, sshChans, sshReqs)
 	sshClient := ssh.NewClient(sshClientConn, sshChans, sshReqs)
+
 	// Run a goroutine to periodically execute SSH keepalive
 	// Run a goroutine to periodically execute SSH keepalive
 	sshKeepAliveQuit := make(chan struct{})
 	sshKeepAliveQuit := make(chan struct{})
 	sshKeepAliveTicker := time.NewTicker(TUNNEL_SSH_KEEP_ALIVE_PERIOD)
 	sshKeepAliveTicker := time.NewTicker(TUNNEL_SSH_KEEP_ALIVE_PERIOD)
@@ -214,11 +238,47 @@ func EstablishTunnel(
 			}
 			}
 		}
 		}
 	}()
 	}()
+
 	return &Tunnel{
 	return &Tunnel{
 			serverEntry:      serverEntry,
 			serverEntry:      serverEntry,
+			sessionId:        sessionId,
 			protocol:         selectedProtocol,
 			protocol:         selectedProtocol,
 			conn:             conn,
 			conn:             conn,
 			sshClient:        sshClient,
 			sshClient:        sshClient,
-			sshKeepAliveQuit: sshKeepAliveQuit},
+			sshKeepAliveQuit: sshKeepAliveQuit,
+			// portForwardFailures buffer size is large enough to receive the thresold number
+			// of failure reports without blocking. Senders can drop failures without blocking.
+			portForwardFailures: make(chan int, config.PortForwardFailureThreshold)},
 		nil
 		nil
 }
 }
+
+// Close terminates the tunnel.
+func (tunnel *Tunnel) Close() {
+	if tunnel.sshKeepAliveQuit != nil {
+		close(tunnel.sshKeepAliveQuit)
+	}
+	if tunnel.conn != nil {
+		tunnel.conn.Close()
+	}
+}
+
+func (tunnel *Tunnel) IsSessionStarted() bool {
+	return atomic.LoadInt32(&tunnel.sessionStarted) == 1
+}
+
+func (tunnel *Tunnel) SetSessionStarted() {
+	atomic.StoreInt32(&tunnel.sessionStarted, 1)
+}
+
+// Dial establishes a port forward connection through the tunnel
+func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
+	// TODO: should this track port forward failures as in Controller.DialWithTunnel?
+	return tunnel.sshClient.Dial("tcp", remoteAddr)
+}
+
+// SignalFailure notifies the tunnel that an associated component has failed.
+// This will terminate the tunnel.
+func (tunnel *Tunnel) SignalFailure() {
+	Notice(NOTICE_ALERT, "tunnel received failure signal")
+	tunnel.Close()
+}

+ 30 - 1
psiphonTunnelCore.go → psiphonClient.go

@@ -23,9 +23,13 @@ import (
 	"flag"
 	"flag"
 	psiphon "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	psiphon "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"log"
 	"log"
+	"os"
+	"os/signal"
+	"sync"
 )
 )
 
 
 func main() {
 func main() {
+
 	var configFilename string
 	var configFilename string
 	flag.StringVar(&configFilename, "config", "", "configuration file")
 	flag.StringVar(&configFilename, "config", "", "configuration file")
 	flag.Parse()
 	flag.Parse()
@@ -36,5 +40,30 @@ func main() {
 	if err != nil {
 	if err != nil {
 		log.Fatalf("error loading configuration file: %s", err)
 		log.Fatalf("error loading configuration file: %s", err)
 	}
 	}
-	psiphon.RunTunnelForever(config)
+
+	if config.LogFilename != "" {
+		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+		if err != nil {
+			log.Fatalf("error opening log file: %s", err)
+		}
+		defer logFile.Close()
+		log.SetOutput(logFile)
+	}
+
+	controller := psiphon.NewController(config)
+	shutdownBroadcast := make(chan struct{})
+	controllerWaitGroup := new(sync.WaitGroup)
+	controllerWaitGroup.Add(1)
+	go func() {
+		defer controllerWaitGroup.Done()
+		controller.Run(shutdownBroadcast)
+	}()
+
+	systemStopSignal := make(chan os.Signal, 1)
+	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)
+	<-systemStopSignal
+
+	psiphon.Notice(psiphon.NOTICE_INFO, "shutdown by system")
+	close(shutdownBroadcast)
+	controllerWaitGroup.Wait()
 }
 }

+ 0 - 30
psiphonTunnelCore_test.go

@@ -1,30 +0,0 @@
-/*
- * Copyright (c) 2014, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program.  If not, see <http://www.gnu.org/licenses/>.
- *
- */
-
-package main
-
-import (
-	"psiphon"
-)
-
-func TestPsiphon(t *testing.T) {
-	var config psiphon.Config
-	// TODO: put a test config here
-	psiphon.RunTunnelForever(&config)
-}