Просмотр исходного кода

Added support for upstream HTTP proxy; fix missing BindToDevice support for fetch remote server list

Rod Hynes 11 лет назад
Родитель
Сommit
ecfb147ba3
8 измененных файлов с 119 добавлено и 57 удалено
  1. 2 2
      README.md
  2. 19 2
      psiphon/TCPConn_unix.go
  3. 17 1
      psiphon/TCPConn_windows.go
  4. 1 0
      psiphon/config.go
  5. 40 0
      psiphon/conn.go
  6. 6 2
      psiphon/controller.go
  7. 18 2
      psiphon/remoteServerList.go
  8. 16 48
      psiphon/tlsDialer.go

+ 2 - 2
README.md

@@ -35,7 +35,8 @@ Setup
         "TunnelProtocol" : "",
         "ConnectionWorkerPoolSize" : 10,
         "TunnelPoolSize" : 1,
-        "PortForwardFailureThreshold" : 10
+        "PortForwardFailureThreshold" : 10,
+        "UpstreamHttpProxyAddress" : ""
     }
     ```
 
@@ -50,7 +51,6 @@ Roadmap
 * requirements for integrating with Windows client
   * split tunnel support
   * implement page view and bytes transferred stats
-  * upstream proxy support
   * resumable download of client upgrades
 * Android app
   * open home pages

+ 19 - 2
psiphon/TCPConn_unix.go

@@ -62,8 +62,15 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		}
 	}
 
+	// When using an upstream HTTP proxy, first connect to the proxy,
+	// then use HTTP CONNECT to connect to the original destination.
+	dialAddr := addr
+	if config.UpstreamHttpProxyAddress != "" {
+		dialAddr = config.UpstreamHttpProxyAddress
+	}
+
 	// Get the remote IP and port, resolving a domain name if necessary
-	host, strPort, err := net.SplitHostPort(addr)
+	host, strPort, err := net.SplitHostPort(dialAddr)
 	if err != nil {
 		return nil, ContextError(err)
 	}
@@ -88,6 +95,7 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 		readTimeout:   config.ReadTimeout,
 		writeTimeout:  config.WriteTimeout}
 	config.PendingConns.Add(conn)
+	defer config.PendingConns.Remove(conn)
 
 	// Connect the socket
 	// TODO: adjust the timeout to account for time spent resolving hostname
@@ -104,7 +112,6 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	} else {
 		err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
 	}
-	config.PendingConns.Remove(conn)
 	if err != nil {
 		return nil, ContextError(err)
 	}
@@ -116,6 +123,16 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	if err != nil {
 		return nil, ContextError(err)
 	}
+
+	// Going through upstream HTTP proxy
+	if config.UpstreamHttpProxyAddress != "" {
+		// This call can be interrupted by closing the pending conn
+		err := HttpProxyConnect(conn, addr)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+	}
+
 	return conn, nil
 }
 

+ 17 - 1
psiphon/TCPConn_windows.go

@@ -56,7 +56,23 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 	// Call the blocking Dial in a goroutine
 	results := conn.interruptible.results
 	go func() {
-		netConn, err := net.DialTimeout("tcp", addr, config.ConnectTimeout)
+
+		// When using an upstream HTTP proxy, first connect to the proxy,
+		// then use HTTP CONNECT to connect to the original destination.
+		dialAddr := addr
+		if config.UpstreamHttpProxyAddress != "" {
+			dialAddr = config.UpstreamHttpProxyAddress
+		}
+
+		netConn, err := net.DialTimeout("tcp", dialAddr, config.ConnectTimeout)
+
+		if config.UpstreamHttpProxyAddress != "" {
+			err := HttpProxyConnect(netConn, addr)
+			if err != nil {
+				return nil, ContextError(err)
+			}
+		}
+
 		results <- &interruptibleDialResult{netConn, err}
 	}()
 

+ 1 - 0
psiphon/config.go

@@ -43,6 +43,7 @@ type Config struct {
 	BindToDeviceDnsServer              string
 	TunnelPoolSize                     int
 	PortForwardFailureThreshold        int
+	UpstreamHttpProxyAddress           string
 }
 
 // LoadConfig reads, and parse, and validates a JSON format Psiphon config

+ 40 - 0
psiphon/conn.go

@@ -20,6 +20,8 @@
 package psiphon
 
 import (
+	"bytes"
+	"fmt"
 	"io"
 	"net"
 	"sync"
@@ -29,6 +31,12 @@ import (
 // DialConfig contains parameters to determine the behavior
 // of a Psiphon dialer (TCPDial, MeekDial, etc.)
 type DialConfig struct {
+
+	// UpstreamHttpProxyAddress specifies an HTTP proxy to connect through
+	// (the proxy must support HTTP CONNECT). The address may be a hostname
+	// or IP address and must include a port number.
+	UpstreamHttpProxyAddress string
+
 	ConnectTimeout time.Duration
 	ReadTimeout    time.Duration
 	WriteTimeout   time.Duration
@@ -119,3 +127,35 @@ func Relay(localConn, remoteConn net.Conn) {
 	}
 	copyWaitGroup.Wait()
 }
+
+// HttpProxyConnect establishes a HTTP CONNECT tunnel to addr through
+// an established network connection to an HTTP proxy. It is assumed that
+// no payload bytes have been sent through the connection to the proxy.
+func HttpProxyConnect(rawConn net.Conn, addr string) (err error) {
+	hostname, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	// TODO: use the proxy request/response code from net/http/transport.go?
+	connectRequest := fmt.Sprintf(
+		"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
+		addr, hostname)
+	_, err = rawConn.Write([]byte(connectRequest))
+	if err != nil {
+		return ContextError(err)
+	}
+
+	expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
+	readBuffer := make([]byte, len(expectedResponse))
+	_, err = io.ReadFull(rawConn, readBuffer)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	if !bytes.Equal(readBuffer, expectedResponse) {
+		return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
+	}
+
+	return nil
+}

+ 6 - 2
psiphon/controller.go

@@ -108,6 +108,9 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 		Notice(NOTICE_ALERT, "controller shutdown due to failure")
 	}
 
+	// Note: in addition to establish(), this pendingConns will interrupt
+	// FetchRemoteServerList
+	controller.pendingConns.CloseAll()
 	close(controller.shutdownBroadcast)
 	controller.runWaitGroup.Wait()
 
@@ -133,8 +136,9 @@ func (controller *Controller) remoteServerListFetcher() {
 	// always makes the fetch remote server list request
 loop:
 	for {
-		// TODO: FetchRemoteServerList should abort immediately on shutdownBroadcast
-		err := FetchRemoteServerList(controller.config)
+		// TODO: FetchRemoteServerList should have its own pendingConns,
+		// otherwise it may needlessly abort when establish is stopped.
+		err := FetchRemoteServerList(controller.config, controller.pendingConns)
 		var duration time.Duration
 		if err != nil {
 			Notice(NOTICE_ALERT, "failed to fetch remote server list: %s", err)

+ 18 - 2
psiphon/remoteServerList.go

@@ -45,20 +45,35 @@ type RemoteServerList struct {
 // config.RemoteServerListUrl; validates its digital signature using the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // data field into ServerEntry records.
-func FetchRemoteServerList(config *Config) (err error) {
+func FetchRemoteServerList(config *Config, pendingConns *Conns) (err error) {
 	Notice(NOTICE_INFO, "fetching remote server list")
+
+	// Note: pendingConns may be used to interrupt the fetch remote server list
+	// request. BindToDevice may be used to exclude requests from VPN routing.
+	dialConfig := &DialConfig{
+		PendingConns:               pendingConns,
+		BindToDeviceServiceAddress: config.BindToDeviceServiceAddress,
+		BindToDeviceDnsServer:      config.BindToDeviceDnsServer,
+	}
+	transport := &http.Transport{
+		Dial: NewTCPDialer(dialConfig),
+	}
 	httpClient := http.Client{
-		Timeout: FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Timeout:   FETCH_REMOTE_SERVER_LIST_TIMEOUT,
+		Transport: transport,
 	}
+
 	response, err := httpClient.Get(config.RemoteServerListUrl)
 	if err != nil {
 		return ContextError(err)
 	}
 	defer response.Body.Close()
+
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		return ContextError(err)
 	}
+
 	var remoteServerList *RemoteServerList
 	err = json.Unmarshal(body, &remoteServerList)
 	if err != nil {
@@ -68,6 +83,7 @@ func FetchRemoteServerList(config *Config) (err error) {
 	if err != nil {
 		return ContextError(err)
 	}
+
 	for _, encodedServerEntry := range strings.Split(remoteServerList.Data, "\n") {
 		serverEntry, err := DecodeServerEntry(encodedServerEntry)
 		if err != nil {

+ 16 - 48
psiphon/tlsDialer.go

@@ -75,10 +75,7 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
-	"fmt"
-	"io"
 	"net"
-	"strings"
 	"time"
 )
 
@@ -91,25 +88,28 @@ func (timeoutError) Temporary() bool { return true }
 // CustomTLSConfig contains parameters to determine the behavior
 // of CustomTLSDial.
 type CustomTLSConfig struct {
+
 	// Dial is the network connection dialer. TLS is layered on
 	// top of a new network connection created with dialer.
 	Dial Dialer
+
 	// Timeout is and optional timeout for combined network
 	// connection dial and TLS handshake.
 	Timeout time.Duration
+
 	// FrontingAddr overrides the "addr" input to Dial when specified
 	FrontingAddr string
-	// HttpProxyAddress specifies an HTTP proxy to be used
-	// (with HTTP CONNECT).
-	HttpProxyAddress string
+
 	// SendServerName specifies whether to use SNI
 	// (tlsdialer functionality)
 	SendServerName bool
+
 	// VerifyLegacyCertificate is a special case self-signed server
 	// certificate case. Ignores IP SANs and basic constraints. No
 	// certificate chain. Just checks that the server presented the
 	// specified certificate.
 	VerifyLegacyCertificate *x509.Certificate
+
 	// TlsConfig is a tls.Config to use in the
 	// non-verifyLegacyCertificate case.
 	TlsConfig *tls.Config
@@ -141,9 +141,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 	}
 
 	dialAddr := addr
-	if config.HttpProxyAddress != "" {
-		dialAddr = config.HttpProxyAddress
-	} else if config.FrontingAddr != "" {
+	if config.FrontingAddr != "" {
 		dialAddr = config.FrontingAddr
 	}
 
@@ -152,34 +150,27 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 		return nil, ContextError(err)
 	}
 
-	targetAddr := addr
-	if config.FrontingAddr != "" {
-		targetAddr = config.FrontingAddr
-	}
-
-	colonPos := strings.LastIndex(targetAddr, ":")
-	if colonPos == -1 {
-		colonPos = len(targetAddr)
+	hostname, _, err := net.SplitHostPort(dialAddr)
+	if err != nil {
+		return nil, ContextError(err)
 	}
-	hostname := targetAddr[:colonPos]
 
 	tlsConfig := config.TlsConfig
 	if tlsConfig == nil {
 		tlsConfig = &tls.Config{}
 	}
 
-	serverName := tlsConfig.ServerName
+	// Copy config so we can tweak it
+	tlsConfigCopy := new(tls.Config)
+	*tlsConfigCopy = *tlsConfig
 
+	serverName := tlsConfig.ServerName
 	// If no ServerName is set, infer the ServerName
 	// from the hostname we're connecting to.
 	if serverName == "" {
 		serverName = hostname
 	}
 
-	// Copy config so we can tweak it
-	tlsConfigCopy := new(tls.Config)
-	*tlsConfigCopy = *tlsConfig
-
 	if config.SendServerName {
 		// Set the ServerName and rely on the usual logic in
 		// tls.Conn.Handshake() to do its verification
@@ -192,34 +183,11 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
 
 	conn := tls.Client(rawConn, tlsConfigCopy)
 
-	establishConnection := func(rawConn net.Conn, conn *tls.Conn) error {
-		// TODO: use the proxy request/response code from net/http/transport.go
-		if config.HttpProxyAddress != "" {
-			connectRequest := fmt.Sprintf(
-				"CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
-				targetAddr, hostname)
-			_, err := rawConn.Write([]byte(connectRequest))
-			if err != nil {
-				return ContextError(err)
-			}
-			expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
-			readBuffer := make([]byte, len(expectedResponse))
-			_, err = io.ReadFull(rawConn, readBuffer)
-			if err != nil {
-				return ContextError(err)
-			}
-			if !bytes.Equal(readBuffer, expectedResponse) {
-				return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
-			}
-		}
-		return conn.Handshake()
-	}
-
 	if config.Timeout == 0 {
-		err = establishConnection(rawConn, conn)
+		err = conn.Handshake()
 	} else {
 		go func() {
-			errChannel <- establishConnection(rawConn, conn)
+			errChannel <- conn.Handshake()
 		}()
 		err = <-errChannel
 	}