Sfoglia il codice sorgente

Decouple Controller from HttpProxy and SocksProxy. Use a Tunneler interface so either a Controller (tunnel pool with lifecycle management) or Tunnel (single, specific tunnel) may be used as the tunneler for proxies. This is to support prospective test scripts that will test tunnels to specific servers.

Rod Hynes 11 anni fa
parent
commit
f8a0a79f10
4 ha cambiato i file con 52 aggiunte e 29 eliminazioni
  1. 9 8
      psiphon/controller.go
  2. 7 7
      psiphon/httpProxy.go
  3. 6 6
      psiphon/socksProxy.go
  4. 30 8
      psiphon/tunnel.go

+ 9 - 8
psiphon/controller.go

@@ -84,13 +84,13 @@ func NewController(config *Config) (controller *Controller) {
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
-	socksProxy, err := NewSocksProxy(controller)
+	socksProxy, err := NewSocksProxy(controller.config, controller)
 	if err != nil {
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
 		return
 		return
 	}
 	}
 	defer socksProxy.Close()
 	defer socksProxy.Close()
-	httpProxy, err := NewHttpProxy(controller)
+	httpProxy, err := NewHttpProxy(controller.config, controller)
 	if err != nil {
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
 		return
 		return
@@ -117,7 +117,7 @@ func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 	Notice(NOTICE_INFO, "exiting controller")
 	Notice(NOTICE_INFO, "exiting controller")
 }
 }
 
 
-// SignalFailure notifies the controller than a component has failed.
+// SignalFailure notifies the controller that an associated component has failed.
 // This will terminate the controller.
 // This will terminate the controller.
 func (controller *Controller) SignalFailure() {
 func (controller *Controller) SignalFailure() {
 	select {
 	select {
@@ -441,15 +441,15 @@ func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-// dialWithTunnel selects an active tunnel and establishes a port forward
+// DialWithTunnel selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // connection through the selected tunnel. Failure to connect is considered
 // a port foward failure, for the purpose of monitoring tunnel health.
 // a port foward failure, for the purpose of monitoring tunnel health.
-func (controller *Controller) dialWithTunnel(remoteAddr string) (conn net.Conn, err error) {
+func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error) {
 	tunnel := controller.getNextActiveTunnel()
 	tunnel := controller.getNextActiveTunnel()
 	if tunnel == nil {
 	if tunnel == nil {
 		return nil, ContextError(errors.New("no active tunnels"))
 		return nil, ContextError(errors.New("no active tunnels"))
 	}
 	}
-	sshPortForward, err := tunnel.sshClient.Dial("tcp", remoteAddr)
+	tunnelConn, err := tunnel.Dial(remoteAddr)
 	if err != nil {
 	if err != nil {
 		// TODO: conditional on type of error or error message?
 		// TODO: conditional on type of error or error message?
 		select {
 		select {
@@ -459,7 +459,7 @@ func (controller *Controller) dialWithTunnel(remoteAddr string) (conn net.Conn,
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	return &TunneledConn{
 	return &TunneledConn{
-			Conn:   sshPortForward,
+			Conn:   tunnelConn,
 			tunnel: tunnel},
 			tunnel: tunnel},
 		nil
 		nil
 }
 }
@@ -576,7 +576,8 @@ func (controller *Controller) establishTunnelWorker() {
 			return
 			return
 		default:
 		default:
 		}
 		}
-		tunnel, err := EstablishTunnel(controller, serverEntry)
+		tunnel, err := EstablishTunnel(
+			controller.config, controller.pendingConns, serverEntry)
 		if err != nil {
 		if err != nil {
 			// TODO: distingush case where conn is interrupted?
 			// TODO: distingush case where conn is interrupted?
 			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)
 			Notice(NOTICE_INFO, "failed to connect to %s: %s", serverEntry.IpAddress, err)

+ 7 - 7
psiphon/httpProxy.go

@@ -31,7 +31,7 @@ import (
 // HttpProxy is a HTTP server that relays HTTP requests through
 // HttpProxy is a HTTP server that relays HTTP requests through
 // the tunnel SSH client.
 // the tunnel SSH client.
 type HttpProxy struct {
 type HttpProxy struct {
-	controller     *Controller
+	tunneler       Tunneler
 	listener       net.Listener
 	listener       net.Listener
 	serveWaitGroup *sync.WaitGroup
 	serveWaitGroup *sync.WaitGroup
 	httpRelay      *http.Transport
 	httpRelay      *http.Transport
@@ -39,15 +39,15 @@ type HttpProxy struct {
 }
 }
 
 
 // NewHttpProxy initializes and runs a new HTTP proxy server.
 // NewHttpProxy initializes and runs a new HTTP proxy server.
-func NewHttpProxy(controller *Controller) (proxy *HttpProxy, err error) {
+func NewHttpProxy(config *Config, tunneler Tunneler) (proxy *HttpProxy, err error) {
 	listener, err := net.Listen(
 	listener, err := net.Listen(
-		"tcp", fmt.Sprintf("127.0.0.1:%d", controller.config.LocalHttpProxyPort))
+		"tcp", fmt.Sprintf("127.0.0.1:%d", config.LocalHttpProxyPort))
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 		// TODO: connect timeout?
 		// TODO: connect timeout?
-		return controller.dialWithTunnel(addr)
+		return tunneler.Dial(addr)
 	}
 	}
 	// TODO: also use http.Client, with its Timeout field?
 	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
 	transport := &http.Transport{
@@ -56,7 +56,7 @@ func NewHttpProxy(controller *Controller) (proxy *HttpProxy, err error) {
 		ResponseHeaderTimeout: HTTP_PROXY_ORIGIN_SERVER_TIMEOUT,
 		ResponseHeaderTimeout: HTTP_PROXY_ORIGIN_SERVER_TIMEOUT,
 	}
 	}
 	proxy = &HttpProxy{
 	proxy = &HttpProxy{
-		controller:     controller,
+		tunneler:       tunneler,
 		listener:       listener,
 		listener:       listener,
 		serveWaitGroup: new(sync.WaitGroup),
 		serveWaitGroup: new(sync.WaitGroup),
 		httpRelay:      transport,
 		httpRelay:      transport,
@@ -187,7 +187,7 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 	defer localConn.Close()
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.controller.dialWithTunnel(target)
+	remoteConn, err := proxy.tunneler.Dial(target)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
@@ -227,7 +227,7 @@ func (proxy *HttpProxy) serve() {
 	// Note: will be interrupted by listener.Close() call made by proxy.Close()
 	// Note: will be interrupted by listener.Close() call made by proxy.Close()
 	err := httpServer.Serve(proxy.listener)
 	err := httpServer.Serve(proxy.listener)
 	if err != nil {
 	if err != nil {
-		proxy.controller.SignalFailure()
+		proxy.tunneler.SignalFailure()
 		Notice(NOTICE_ALERT, "%s", ContextError(err))
 		Notice(NOTICE_ALERT, "%s", ContextError(err))
 	}
 	}
 	Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")
 	Notice(NOTICE_HTTP_PROXY, "HTTP proxy stopped")

+ 6 - 6
psiphon/socksProxy.go

@@ -31,7 +31,7 @@ import (
 // the tunnel SSH client and relays traffic through the port
 // the tunnel SSH client and relays traffic through the port
 // forward.
 // forward.
 type SocksProxy struct {
 type SocksProxy struct {
-	controller     *Controller
+	tunneler       Tunneler
 	listener       *socks.SocksListener
 	listener       *socks.SocksListener
 	serveWaitGroup *sync.WaitGroup
 	serveWaitGroup *sync.WaitGroup
 	openConns      *Conns
 	openConns      *Conns
@@ -40,14 +40,14 @@ type SocksProxy struct {
 // NewSocksProxy initializes a new SOCKS server. It begins listening for
 // NewSocksProxy initializes a new SOCKS server. It begins listening for
 // connections, starts a goroutine that runs an accept loop, and returns
 // connections, starts a goroutine that runs an accept loop, and returns
 // leaving the accept loop running.
 // leaving the accept loop running.
-func NewSocksProxy(controller *Controller) (proxy *SocksProxy, err error) {
+func NewSocksProxy(config *Config, tunneler Tunneler) (proxy *SocksProxy, err error) {
 	listener, err := socks.ListenSocks(
 	listener, err := socks.ListenSocks(
-		"tcp", fmt.Sprintf("127.0.0.1:%d", controller.config.LocalSocksProxyPort))
+		"tcp", fmt.Sprintf("127.0.0.1:%d", config.LocalSocksProxyPort))
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	proxy = &SocksProxy{
 	proxy = &SocksProxy{
-		controller:     controller,
+		tunneler:       tunneler,
 		listener:       listener,
 		listener:       listener,
 		serveWaitGroup: new(sync.WaitGroup),
 		serveWaitGroup: new(sync.WaitGroup),
 		openConns:      new(Conns),
 		openConns:      new(Conns),
@@ -70,7 +70,7 @@ func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err
 	defer localConn.Close()
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.controller.dialWithTunnel(localConn.Req.Target)
+	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}
@@ -92,7 +92,7 @@ func (proxy *SocksProxy) serve() {
 		if err != nil {
 		if err != nil {
 			Notice(NOTICE_ALERT, "SOCKS proxy accept error: %s", err)
 			Notice(NOTICE_ALERT, "SOCKS proxy accept error: %s", err)
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
-				proxy.controller.SignalFailure()
+				proxy.tunneler.SignalFailure()
 				// Fatal error, stop the proxy
 				// Fatal error, stop the proxy
 				break
 				break
 			}
 			}

+ 30 - 8
psiphon/tunnel.go

@@ -31,6 +31,15 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+// Tunneler specifies the interface required by components that use a tunnel.
+// Components which use this interface may be services by a single Tunnel instance,
+// or a Controller which manages a pool of tunnels, or any other object which
+// implements Tunneler.
+type Tunneler interface {
+	Dial(remoteAddr string) (conn net.Conn, err error)
+	SignalFailure()
+}
+
 const (
 const (
 	TUNNEL_PROTOCOL_SSH            = "SSH"
 	TUNNEL_PROTOCOL_SSH            = "SSH"
 	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
 	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
@@ -81,18 +90,18 @@ func (tunnel *Tunnel) Close() {
 // the first protocol in SupportedTunnelProtocols that's also in the
 // the first protocol in SupportedTunnelProtocols that's also in the
 // server capabilities is used.
 // server capabilities is used.
 func EstablishTunnel(
 func EstablishTunnel(
-	controller *Controller, serverEntry *ServerEntry) (tunnel *Tunnel, err error) {
+	config *Config, pendingConns *Conns, serverEntry *ServerEntry) (tunnel *Tunnel, err error) {
 
 
 	// Select the protocol
 	// Select the protocol
 	var selectedProtocol string
 	var selectedProtocol string
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
-	if controller.config.TunnelProtocol != "" {
-		requiredCapability := strings.TrimSuffix(controller.config.TunnelProtocol, "-OSSH")
+	if config.TunnelProtocol != "" {
+		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
 		if !Contains(serverEntry.Capabilities, requiredCapability) {
 		if !Contains(serverEntry.Capabilities, requiredCapability) {
 			return nil, ContextError(fmt.Errorf("server does not have required capability"))
 			return nil, ContextError(fmt.Errorf("server does not have required capability"))
 		}
 		}
-		selectedProtocol = controller.config.TunnelProtocol
+		selectedProtocol = config.TunnelProtocol
 	} else {
 	} else {
 		// Order of SupportedTunnelProtocols is default preference order
 		// Order of SupportedTunnelProtocols is default preference order
 		for _, protocol := range SupportedTunnelProtocols {
 		for _, protocol := range SupportedTunnelProtocols {
@@ -144,9 +153,9 @@ func EstablishTunnel(
 		ConnectTimeout:             TUNNEL_CONNECT_TIMEOUT,
 		ConnectTimeout:             TUNNEL_CONNECT_TIMEOUT,
 		ReadTimeout:                TUNNEL_READ_TIMEOUT,
 		ReadTimeout:                TUNNEL_READ_TIMEOUT,
 		WriteTimeout:               TUNNEL_WRITE_TIMEOUT,
 		WriteTimeout:               TUNNEL_WRITE_TIMEOUT,
-		PendingConns:               controller.pendingConns,
-		BindToDeviceServiceAddress: controller.config.BindToDeviceServiceAddress,
-		BindToDeviceDnsServer:      controller.config.BindToDeviceDnsServer,
+		PendingConns:               pendingConns,
+		BindToDeviceServiceAddress: config.BindToDeviceServiceAddress,
+		BindToDeviceDnsServer:      config.BindToDeviceDnsServer,
 	}
 	}
 	var conn Conn
 	var conn Conn
 	if useMeek {
 	if useMeek {
@@ -247,6 +256,19 @@ func EstablishTunnel(
 			sshKeepAliveQuit: sshKeepAliveQuit,
 			sshKeepAliveQuit: sshKeepAliveQuit,
 			// portForwardFailures buffer size is large enough to receive the thresold number
 			// portForwardFailures buffer size is large enough to receive the thresold number
 			// of failure reports without blocking. Senders can drop failures without blocking.
 			// of failure reports without blocking. Senders can drop failures without blocking.
-			portForwardFailures: make(chan int, controller.config.PortForwardFailureThreshold)},
+			portForwardFailures: make(chan int, config.PortForwardFailureThreshold)},
 		nil
 		nil
 }
 }
+
+// Dial establishes a port forward connection through the tunnel
+func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
+	// TODO: should this track port forward failures as in Controller.DialWithTunnel?
+	return tunnel.sshClient.Dial("tcp", remoteAddr)
+}
+
+// SignalFailure notifies the tunnel that an associated component has failed.
+// This will terminate the tunnel.
+func (tunnel *Tunnel) SignalFailure() {
+	Notice(NOTICE_ALERT, "tunnel received failure signal")
+	tunnel.Close()
+}