Browse Source

Close downstream local proxy connection when upstream port forward closes

Improves responsiveness in tunneled apps (e.g., browser) when SOCKS
or HTTPS CONNECT proxy connections are used and when an underlying
tunnel disconnects and is replaced.

Previously, the app could potentially wait for a read timeout on an
open-but-orphaned App<->LocalProxy connection. Now when the upstream
LocalProxy<->SshPortForward connection is closes, the associated
downstream connection is explicitly closed at the same time.
Rod Hynes 11 years ago
parent
commit
705b1cc40c
4 changed files with 43 additions and 16 deletions
  1. 2 2
      psiphon/controller.go
  2. 9 3
      psiphon/httpProxy.go
  3. 4 1
      psiphon/socksProxy.go
  4. 28 10
      psiphon/tunnel.go

+ 2 - 2
psiphon/controller.go

@@ -465,13 +465,13 @@ func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry
 // Dial selects an active tunnel and establishes a port forward
 // Dial selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // connection through the selected tunnel. Failure to connect is considered
 // a port foward failure, for the purpose of monitoring tunnel health.
 // a port foward failure, for the purpose of monitoring tunnel health.
-func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error) {
+func (controller *Controller) Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error) {
 	tunnel := controller.getNextActiveTunnel()
 	tunnel := controller.getNextActiveTunnel()
 	if tunnel == nil {
 	if tunnel == nil {
 		return nil, ContextError(errors.New("no active tunnels"))
 		return nil, ContextError(errors.New("no active tunnels"))
 	}
 	}
 
 
-	tunneledConn, err := tunnel.Dial(remoteAddr)
+	tunneledConn, err := tunnel.Dial(remoteAddr, downstreamConn)
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}

+ 9 - 3
psiphon/httpProxy.go

@@ -50,8 +50,11 @@ func NewHttpProxy(config *Config, tunneler Tunneler) (proxy *HttpProxy, err erro
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
+		// downstreamConn is not set in this case, as there is not a fixed
+		// association between a downstream client connection and a particular
+		// tunnel.
 		// TODO: connect timeout?
 		// TODO: connect timeout?
-		return tunneler.Dial(addr)
+		return tunneler.Dial(addr, nil)
 	}
 	}
 	// TODO: also use http.Client, with its Timeout field?
 	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
 	transport := &http.Transport{
@@ -174,7 +177,7 @@ func forceClose(responseWriter http.ResponseWriter) {
 	}
 	}
 }
 }
 
 
-// From // https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
+// From https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
 // Hop-by-hop headers. These are removed when sent to the backend.
 // Hop-by-hop headers. These are removed when sent to the backend.
 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
 var hopHeaders = []string{
 var hopHeaders = []string{
@@ -193,7 +196,10 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 	defer localConn.Close()
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.tunneler.Dial(target)
+	// Setting downstreamConn so localConn.Close() will be called when remoteConn.Close() is called.
+	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
+	// open connection for data which will never arrive.
+	remoteConn, err := proxy.tunneler.Dial(target, localConn)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}

+ 4 - 1
psiphon/socksProxy.go

@@ -77,7 +77,10 @@ func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err
 	defer localConn.Close()
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target)
+	// Setting peerConn so localConn.Close() will be called when remoteConn.Close() is called.
+	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
+	// open connection for data which will never arrive.
+	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target, localConn)
 	if err != nil {
 	if err != nil {
 		return ContextError(err)
 		return ContextError(err)
 	}
 	}

+ 28 - 10
psiphon/tunnel.go

@@ -39,8 +39,12 @@ import (
 // Components which use this interface may be serviced by a single Tunnel instance,
 // Components which use this interface may be serviced by a single Tunnel instance,
 // or a Controller which manages a pool of tunnels, or any other object which
 // or a Controller which manages a pool of tunnels, or any other object which
 // implements Tunneler.
 // implements Tunneler.
+// downstreamConn is an optional parameter which specifies a connection to be
+// explictly closed when the Dialed connection is closed. For instance, this
+// is used to close downstreamConn App<->LocalProxy connections when the related
+// LocalProxy<->SshPortForward connections close.
 type Tunneler interface {
 type Tunneler interface {
-	Dial(remoteAddr string) (conn net.Conn, err error)
+	Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error)
 	SignalComponentFailure()
 	SignalComponentFailure()
 }
 }
 
 
@@ -177,7 +181,7 @@ func (tunnel *Tunnel) Close() {
 }
 }
 
 
 // Dial establishes a port forward connection through the tunnel
 // Dial establishes a port forward connection through the tunnel
-func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
+func (tunnel *Tunnel) Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error) {
 	tunnel.mutex.Lock()
 	tunnel.mutex.Lock()
 	isClosed := tunnel.isClosed
 	isClosed := tunnel.isClosed
 	tunnel.mutex.Unlock()
 	tunnel.mutex.Unlock()
@@ -196,8 +200,9 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	}
 	}
 
 
 	conn = &TunneledConn{
 	conn = &TunneledConn{
-		Conn:   sshPortForwardConn,
-		tunnel: tunnel}
+		Conn:           sshPortForwardConn,
+		tunnel:         tunnel,
+		downstreamConn: downstreamConn}
 
 
 	// Tunnel does not have a session when DisableApi is set
 	// Tunnel does not have a session when DisableApi is set
 	if tunnel.session != nil {
 	if tunnel.session != nil {
@@ -208,12 +213,22 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	return conn, nil
 	return conn, nil
 }
 }
 
 
+// SignalComponentFailure notifies the tunnel that an associated component has failed.
+// This will terminate the tunnel.
+func (tunnel *Tunnel) SignalComponentFailure() {
+	NoticeAlert("tunnel received component failure signal")
+	tunnel.Close()
+}
+
 // TunneledConn implements net.Conn and wraps a port foward connection.
 // TunneledConn implements net.Conn and wraps a port foward connection.
 // It is used to hook into Read and Write to observe I/O errors and
 // It is used to hook into Read and Write to observe I/O errors and
 // report these errors back to the tunnel monitor as port forward failures.
 // report these errors back to the tunnel monitor as port forward failures.
+// TunneledConn optionally tracks a peer connection to be explictly closed
+// when the TunneledConn is closed.
 type TunneledConn struct {
 type TunneledConn struct {
 	net.Conn
 	net.Conn
-	tunnel *Tunnel
+	tunnel         *Tunnel
+	downstreamConn net.Conn
 }
 }
 
 
 func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
 func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
@@ -242,11 +257,14 @@ func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
-// SignalComponentFailure notifies the tunnel that an associated component has failed.
-// This will terminate the tunnel.
-func (tunnel *Tunnel) SignalComponentFailure() {
-	NoticeAlert("tunnel received component failure signal")
-	tunnel.Close()
+func (conn *TunneledConn) Close() error {
+	if conn.downstreamConn != nil {
+		err := conn.downstreamConn.Close()
+		if err != nil {
+			NoticeAlert("downstreamConn.Close() error: %s", ContextError(err))
+		}
+	}
+	return conn.Conn.Close()
 }
 }
 
 
 // selectProtocol is a helper that picks the tunnel protocol
 // selectProtocol is a helper that picks the tunnel protocol