Forráskód Böngészése

Merge pull request #70 from rod-hynes/master

Close downstream local proxy connection when upstream port forward closes
Rod Hynes 11 éve
szülő
commit
aa078a773a
4 módosított fájl, 43 hozzáadás és 16 törlés
  1. 2 2
      psiphon/controller.go
  2. 9 3
      psiphon/httpProxy.go
  3. 4 1
      psiphon/socksProxy.go
  4. 28 10
      psiphon/tunnel.go

+ 2 - 2
psiphon/controller.go

@@ -465,13 +465,13 @@ func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry
 // Dial selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // a port foward failure, for the purpose of monitoring tunnel health.
-func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error) {
+func (controller *Controller) Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error) {
 	tunnel := controller.getNextActiveTunnel()
 	if tunnel == nil {
 		return nil, ContextError(errors.New("no active tunnels"))
 	}
 
-	tunneledConn, err := tunnel.Dial(remoteAddr)
+	tunneledConn, err := tunnel.Dial(remoteAddr, downstreamConn)
 	if err != nil {
 		return nil, ContextError(err)
 	}

+ 9 - 3
psiphon/httpProxy.go

@@ -50,8 +50,11 @@ func NewHttpProxy(config *Config, tunneler Tunneler) (proxy *HttpProxy, err erro
 		return nil, ContextError(err)
 	}
 	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
+		// downstreamConn is not set in this case, as there is not a fixed
+		// association between a downstream client connection and a particular
+		// tunnel.
 		// TODO: connect timeout?
-		return tunneler.Dial(addr)
+		return tunneler.Dial(addr, nil)
 	}
 	// TODO: also use http.Client, with its Timeout field?
 	transport := &http.Transport{
@@ -174,7 +177,7 @@ func forceClose(responseWriter http.ResponseWriter) {
 	}
 }
 
-// From // https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
+// From https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
 // Hop-by-hop headers. These are removed when sent to the backend.
 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
 var hopHeaders = []string{
@@ -193,7 +196,10 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.tunneler.Dial(target)
+	// Setting downstreamConn so localConn.Close() will be called when remoteConn.Close() is called.
+	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
+	// open connection for data which will never arrive.
+	remoteConn, err := proxy.tunneler.Dial(target, localConn)
 	if err != nil {
 		return ContextError(err)
 	}

+ 4 - 1
psiphon/socksProxy.go

@@ -77,7 +77,10 @@ func (proxy *SocksProxy) socksConnectionHandler(localConn *socks.SocksConn) (err
 	defer localConn.Close()
 	defer proxy.openConns.Remove(localConn)
 	proxy.openConns.Add(localConn)
-	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target)
+	// Setting peerConn so localConn.Close() will be called when remoteConn.Close() is called.
+	// This ensures that the downstream client (e.g., web browser) doesn't keep waiting on the
+	// open connection for data which will never arrive.
+	remoteConn, err := proxy.tunneler.Dial(localConn.Req.Target, localConn)
 	if err != nil {
 		return ContextError(err)
 	}

+ 28 - 10
psiphon/tunnel.go

@@ -39,8 +39,12 @@ import (
 // Components which use this interface may be serviced by a single Tunnel instance,
 // or a Controller which manages a pool of tunnels, or any other object which
 // implements Tunneler.
+// downstreamConn is an optional parameter which specifies a connection to be
+// explictly closed when the Dialed connection is closed. For instance, this
+// is used to close downstreamConn App<->LocalProxy connections when the related
+// LocalProxy<->SshPortForward connections close.
 type Tunneler interface {
-	Dial(remoteAddr string) (conn net.Conn, err error)
+	Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error)
 	SignalComponentFailure()
 }
 
@@ -177,7 +181,7 @@ func (tunnel *Tunnel) Close() {
 }
 
 // Dial establishes a port forward connection through the tunnel
-func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
+func (tunnel *Tunnel) Dial(remoteAddr string, downstreamConn net.Conn) (conn net.Conn, err error) {
 	tunnel.mutex.Lock()
 	isClosed := tunnel.isClosed
 	tunnel.mutex.Unlock()
@@ -196,8 +200,9 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	}
 
 	conn = &TunneledConn{
-		Conn:   sshPortForwardConn,
-		tunnel: tunnel}
+		Conn:           sshPortForwardConn,
+		tunnel:         tunnel,
+		downstreamConn: downstreamConn}
 
 	// Tunnel does not have a session when DisableApi is set
 	if tunnel.session != nil {
@@ -208,12 +213,22 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	return conn, nil
 }
 
+// SignalComponentFailure notifies the tunnel that an associated component has failed.
+// This will terminate the tunnel.
+func (tunnel *Tunnel) SignalComponentFailure() {
+	NoticeAlert("tunnel received component failure signal")
+	tunnel.Close()
+}
+
 // TunneledConn implements net.Conn and wraps a port foward connection.
 // It is used to hook into Read and Write to observe I/O errors and
 // report these errors back to the tunnel monitor as port forward failures.
+// TunneledConn optionally tracks a peer connection to be explictly closed
+// when the TunneledConn is closed.
 type TunneledConn struct {
 	net.Conn
-	tunnel *Tunnel
+	tunnel         *Tunnel
+	downstreamConn net.Conn
 }
 
 func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
@@ -242,11 +257,14 @@ func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
 	return
 }
 
-// SignalComponentFailure notifies the tunnel that an associated component has failed.
-// This will terminate the tunnel.
-func (tunnel *Tunnel) SignalComponentFailure() {
-	NoticeAlert("tunnel received component failure signal")
-	tunnel.Close()
+func (conn *TunneledConn) Close() error {
+	if conn.downstreamConn != nil {
+		err := conn.downstreamConn.Close()
+		if err != nil {
+			NoticeAlert("downstreamConn.Close() error: %s", ContextError(err))
+		}
+	}
+	return conn.Conn.Close()
 }
 
 // selectProtocol is a helper that picks the tunnel protocol