Browse Source

Don't use a tunnel for port forwarding until it has started its server session (handshake, etc.)

Rod Hynes 11 years ago
parent
commit
635bde06c1
3 changed files with 49 additions and 25 deletions
  1. 26 15
      psiphon/controller.go
  2. 3 0
      psiphon/serverApi.go
  3. 20 10
      psiphon/tunnel.go

+ 26 - 15
psiphon/controller.go

@@ -203,9 +203,7 @@ loop:
 		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		case establishedTunnel := <-controller.establishedTunnels:
 		case establishedTunnel := <-controller.establishedTunnels:
 			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
-			// !TODO! design issue: activateTunnel makes tunnel avail for port forward *before* operates does handshake
-			// solution(?) distinguish between two stages or states: connected, and then active.
-			if controller.activateTunnel(establishedTunnel) {
+			if controller.registerTunnel(establishedTunnel) {
 				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 				controller.operateWaitGroup.Add(1)
 				controller.operateWaitGroup.Add(1)
 				go controller.operateTunnel(establishedTunnel)
 				go controller.operateTunnel(establishedTunnel)
@@ -247,16 +245,23 @@ func (controller *Controller) discardTunnel(tunnel *Tunnel) {
 	tunnel.Close()
 	tunnel.Close()
 }
 }
 
 
-// activateTunnel adds the connected tunnel to the pool of active tunnels
-// which are used for port forwarding. Returns true if the pool has an empty
-// slot and false if the pool is full (caller should discard the tunnel).
-func (controller *Controller) activateTunnel(tunnel *Tunnel) bool {
+// registerTunnel adds the connected tunnel to the pool of active tunnels
+// which are candidates for port forwarding. Returns true if the pool has an
+// empty slot and false if the pool is full (caller should discard the tunnel).
+func (controller *Controller) registerTunnel(tunnel *Tunnel) bool {
 	controller.tunnelMutex.Lock()
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	defer controller.tunnelMutex.Unlock()
-	// !TODO! double check not already a tunnel to this server
 	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
 	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
 		return false
 		return false
 	}
 	}
+	// Perform a fail-safe check just in case we've established
+	// a duplicate connection.
+	for _, activeTunnel := range controller.tunnels {
+		if activeTunnel.serverEntry.IpAddress == tunnel.serverEntry.IpAddress {
+			Notice(NOTICE_ALERT, "duplicate tunnel: %s", tunnel.serverEntry.IpAddress)
+			return false
+		}
+	}
 	controller.tunnels = append(controller.tunnels, tunnel)
 	controller.tunnels = append(controller.tunnels, tunnel)
 	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
 	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
 	return true
 	return true
@@ -310,13 +315,18 @@ func (controller *Controller) terminateAllTunnels() {
 func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
 func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
 	controller.tunnelMutex.Lock()
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	defer controller.tunnelMutex.Unlock()
-	if len(controller.tunnels) == 0 {
-		return nil
+	for i := len(controller.tunnels); i >= 0; i-- {
+		tunnel = controller.tunnels[controller.nextTunnel]
+		controller.nextTunnel =
+			(controller.nextTunnel + 1) % len(controller.tunnels)
+		// A tunnel must[*] have started its session (performed the server
+		// API handshake sequence) before it may be used for tunneling traffic
+		// [*]currently not enforced by the server, but may be in the future.
+		if tunnel.IsSessionStarted() {
+			return tunnel
+		}
 	}
 	}
-	tunnel = controller.tunnels[controller.nextTunnel]
-	controller.nextTunnel =
-		(controller.nextTunnel + 1) % len(controller.tunnels)
-	return tunnel
+	return nil
 }
 }
 
 
 // getActiveTunnelServerEntries lists the Server Entries for
 // getActiveTunnelServerEntries lists the Server Entries for
@@ -522,7 +532,8 @@ loop:
 		// Note: it's possible that an active tunnel in excludeServerEntries will
 		// Note: it's possible that an active tunnel in excludeServerEntries will
 		// fail during this iteration of server entries and in that case the
 		// fail during this iteration of server entries and in that case the
 		// cooresponding server will not be retried (within the same iteration).
 		// cooresponding server will not be retried (within the same iteration).
-		// !TODO! is there also a race that can result in multiple tunnels to the same server
+		// TODO: is there also a race that can result in multiple tunnels to the same
+		// server? (if there is, registerTunnel will reject the duplicate instance.)
 		excludeServerEntries := controller.getActiveTunnelServerEntries()
 		excludeServerEntries := controller.getActiveTunnelServerEntries()
 		iterator, err := NewServerEntryIterator(
 		iterator, err := NewServerEntryIterator(
 			controller.config.EgressRegion, controller.config.TunnelProtocol, excludeServerEntries)
 			controller.config.EgressRegion, controller.config.TunnelProtocol, excludeServerEntries)

+ 3 - 0
psiphon/serverApi.go

@@ -67,6 +67,9 @@ func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
+	tunnel.SetSessionStarted()
+
 	return session, nil
 	return session, nil
 }
 }
 
 

+ 20 - 10
psiphon/tunnel.go

@@ -28,6 +28,7 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"time"
 	"time"
 )
 )
 
 
@@ -61,6 +62,7 @@ var SupportedTunnelProtocols = []string{
 type Tunnel struct {
 type Tunnel struct {
 	serverEntry             *ServerEntry
 	serverEntry             *ServerEntry
 	sessionId               string
 	sessionId               string
+	sessionStarted          int32
 	protocol                string
 	protocol                string
 	conn                    Conn
 	conn                    Conn
 	sshClient               *ssh.Client
 	sshClient               *ssh.Client
@@ -69,16 +71,6 @@ type Tunnel struct {
 	portForwardFailureTotal int
 	portForwardFailureTotal int
 }
 }
 
 
-// Close terminates the tunnel.
-func (tunnel *Tunnel) Close() {
-	if tunnel.sshKeepAliveQuit != nil {
-		close(tunnel.sshKeepAliveQuit)
-	}
-	if tunnel.conn != nil {
-		tunnel.conn.Close()
-	}
-}
-
 // EstablishTunnel first makes a network transport connection to the
 // EstablishTunnel first makes a network transport connection to the
 // Psiphon server and then establishes an SSH client session on top of
 // Psiphon server and then establishes an SSH client session on top of
 // that transport. The SSH server is authenticated using the public
 // that transport. The SSH server is authenticated using the public
@@ -260,6 +252,24 @@ func EstablishTunnel(
 		nil
 		nil
 }
 }
 
 
+// Close terminates the tunnel.
+func (tunnel *Tunnel) Close() {
+	if tunnel.sshKeepAliveQuit != nil {
+		close(tunnel.sshKeepAliveQuit)
+	}
+	if tunnel.conn != nil {
+		tunnel.conn.Close()
+	}
+}
+
+func (tunnel *Tunnel) IsSessionStarted() bool {
+	return atomic.LoadInt32(&tunnel.sessionStarted) == 1
+}
+
+func (tunnel *Tunnel) SetSessionStarted() {
+	atomic.StoreInt32(&tunnel.sessionStarted, 1)
+}
+
 // Dial establishes a port forward connection through the tunnel
 // Dial establishes a port forward connection through the tunnel
 func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	// TODO: should this track port forward failures as in Controller.DialWithTunnel?
 	// TODO: should this track port forward failures as in Controller.DialWithTunnel?