Browse Source

Merge pull request #32 from rod-hynes/master

Bug fixes while testing controller
Rod Hynes 11 years ago
parent
commit
b4d64ed7ee
8 changed files with 114 additions and 80 deletions
  1. 2 2
      psiphon/TCPConn_unix.go
  2. 2 2
      psiphon/conn.go
  3. 45 46
      psiphon/controller.go
  4. 4 12
      psiphon/dataStore.go
  5. 1 1
      psiphon/meekConn.go
  6. 10 6
      psiphon/serverApi.go
  7. 20 10
      psiphon/tunnel.go
  8. 30 1
      psiphonClient.go

+ 2 - 2
psiphon/TCPConn_unix.go

@@ -106,11 +106,11 @@ func interruptibleTCPDial(addr string, config *DialConfig) (conn *TCPConn, err e
 			errChannel <- errors.New("connect timeout")
 			errChannel <- errors.New("connect timeout")
 		})
 		})
 		go func() {
 		go func() {
-			errChannel <- syscall.Connect(conn.interruptible.socketFd, &sockAddr)
+			errChannel <- syscall.Connect(socketFd, &sockAddr)
 		}()
 		}()
 		err = <-errChannel
 		err = <-errChannel
 	} else {
 	} else {
-		err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
+		err = syscall.Connect(socketFd, &sockAddr)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)

+ 2 - 2
psiphon/conn.go

@@ -118,12 +118,12 @@ func Relay(localConn, remoteConn net.Conn) {
 		defer copyWaitGroup.Done()
 		defer copyWaitGroup.Done()
 		_, err := io.Copy(localConn, remoteConn)
 		_, err := io.Copy(localConn, remoteConn)
 		if err != nil {
 		if err != nil {
-			Notice(NOTICE_ALERT, "%s", ContextError(err))
+			Notice(NOTICE_ALERT, "Relay failed: %s", ContextError(err))
 		}
 		}
 	}()
 	}()
 	_, err := io.Copy(remoteConn, localConn)
 	_, err := io.Copy(remoteConn, localConn)
 	if err != nil {
 	if err != nil {
-		Notice(NOTICE_ALERT, "%s", ContextError(err))
+		Notice(NOTICE_ALERT, "Relay failed: %s", ContextError(err))
 	}
 	}
 	copyWaitGroup.Wait()
 	copyWaitGroup.Wait()
 }
 }

+ 45 - 46
psiphon/controller.go

@@ -27,9 +27,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"log"
 	"net"
 	"net"
-	"os"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -85,6 +83,9 @@ func NewController(config *Config) (controller *Controller) {
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
+
+	Notice(NOTICE_VERSION, VERSION)
+
 	socksProxy, err := NewSocksProxy(controller.config, controller)
 	socksProxy, err := NewSocksProxy(controller.config, controller)
 	if err != nil {
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
@@ -203,9 +204,7 @@ loop:
 		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		// solution(?) target MIN(CountServerEntries(region, protocol), TunnelPoolSize)
 		case establishedTunnel := <-controller.establishedTunnels:
 		case establishedTunnel := <-controller.establishedTunnels:
 			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 			Notice(NOTICE_INFO, "established tunnel: %s", establishedTunnel.serverEntry.IpAddress)
-			// !TODO! design issue: activateTunnel makes tunnel avail for port forward *before* operates does handshake
-			// solution(?) distinguish between two stages or states: connected, and then active.
-			if controller.activateTunnel(establishedTunnel) {
+			if controller.registerTunnel(establishedTunnel) {
 				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 				Notice(NOTICE_INFO, "active tunnel: %s", establishedTunnel.serverEntry.IpAddress)
 				controller.operateWaitGroup.Add(1)
 				controller.operateWaitGroup.Add(1)
 				go controller.operateTunnel(establishedTunnel)
 				go controller.operateTunnel(establishedTunnel)
@@ -247,16 +246,23 @@ func (controller *Controller) discardTunnel(tunnel *Tunnel) {
 	tunnel.Close()
 	tunnel.Close()
 }
 }
 
 
-// activateTunnel adds the connected tunnel to the pool of active tunnels
-// which are used for port forwarding. Returns true if the pool has an empty
-// slot and false if the pool is full (caller should discard the tunnel).
-func (controller *Controller) activateTunnel(tunnel *Tunnel) bool {
+// registerTunnel adds the connected tunnel to the pool of active tunnels
+// which are candidates for port forwarding. Returns true if the pool has an
+// empty slot and false if the pool is full (caller should discard the tunnel).
+func (controller *Controller) registerTunnel(tunnel *Tunnel) bool {
 	controller.tunnelMutex.Lock()
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	defer controller.tunnelMutex.Unlock()
-	// !TODO! double check not already a tunnel to this server
 	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
 	if len(controller.tunnels) >= controller.config.TunnelPoolSize {
 		return false
 		return false
 	}
 	}
+	// Perform a final check just in case we've established
+	// a duplicate connection.
+	for _, activeTunnel := range controller.tunnels {
+		if activeTunnel.serverEntry.IpAddress == tunnel.serverEntry.IpAddress {
+			Notice(NOTICE_ALERT, "duplicate tunnel: %s", tunnel.serverEntry.IpAddress)
+			return false
+		}
+	}
 	controller.tunnels = append(controller.tunnels, tunnel)
 	controller.tunnels = append(controller.tunnels, tunnel)
 	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
 	Notice(NOTICE_TUNNEL, "%d tunnels", len(controller.tunnels))
 	return true
 	return true
@@ -310,26 +316,31 @@ func (controller *Controller) terminateAllTunnels() {
 func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
 func (controller *Controller) getNextActiveTunnel() (tunnel *Tunnel) {
 	controller.tunnelMutex.Lock()
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	defer controller.tunnelMutex.Unlock()
-	if len(controller.tunnels) == 0 {
-		return nil
+	for i := len(controller.tunnels); i > 0; i-- {
+		tunnel = controller.tunnels[controller.nextTunnel]
+		controller.nextTunnel =
+			(controller.nextTunnel + 1) % len(controller.tunnels)
+		// A tunnel must[*] have started its session (performed the server
+		// API handshake sequence) before it may be used for tunneling traffic
+		// [*]currently not enforced by the server, but may be in the future.
+		if tunnel.IsSessionStarted() {
+			return tunnel
+		}
 	}
 	}
-	tunnel = controller.tunnels[controller.nextTunnel]
-	controller.nextTunnel =
-		(controller.nextTunnel + 1) % len(controller.tunnels)
-	return tunnel
+	return nil
 }
 }
 
 
-// getActiveTunnelServerEntries lists the Server Entries for
-// all the active tunnels. This is used to exclude those servers
-// from the set of candidates to establish connections to.
-func (controller *Controller) getActiveTunnelServerEntries() (serverEntries []*ServerEntry) {
+// isActiveTunnelServerEntries is used to check if there's already
+// an existing tunnel to a candidate server.
+func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry) bool {
 	controller.tunnelMutex.Lock()
 	controller.tunnelMutex.Lock()
 	defer controller.tunnelMutex.Unlock()
 	defer controller.tunnelMutex.Unlock()
-	serverEntries = make([]*ServerEntry, 0)
 	for _, activeTunnel := range controller.tunnels {
 	for _, activeTunnel := range controller.tunnels {
-		serverEntries = append(serverEntries, activeTunnel.serverEntry)
+		if activeTunnel.serverEntry.IpAddress == serverEntry.IpAddress {
+			return true
+		}
 	}
 	}
-	return serverEntries
+	return false
 }
 }
 
 
 // operateTunnel starts a Psiphon session (handshake, etc.) on a newly
 // operateTunnel starts a Psiphon session (handshake, etc.) on a newly
@@ -372,6 +383,9 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
 		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
 	}
 	}
 
 
+	// Tunnel may now be used for port forwarding
+	tunnel.SetSessionStarted()
+
 	// Promote this successful tunnel to first rank so it's one
 	// Promote this successful tunnel to first rank so it's one
 	// of the first candidates next time establish runs.
 	// of the first candidates next time establish runs.
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
@@ -380,6 +394,9 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 		select {
 		select {
 		case failures := <-tunnel.portForwardFailures:
 		case failures := <-tunnel.portForwardFailures:
 			tunnel.portForwardFailureTotal += failures
 			tunnel.portForwardFailureTotal += failures
+			Notice(
+				NOTICE_INFO, "port forward failures for %s: %d",
+				tunnel.serverEntry.IpAddress, tunnel.portForwardFailureTotal)
 			if tunnel.portForwardFailureTotal > controller.config.PortForwardFailureThreshold {
 			if tunnel.portForwardFailureTotal > controller.config.PortForwardFailureThreshold {
 				err = errors.New("tunnel exceeded port forward failure threshold")
 				err = errors.New("tunnel exceeded port forward failure threshold")
 			}
 			}
@@ -519,10 +536,8 @@ loop:
 		// Note: it's possible that an active tunnel in excludeServerEntries will
 		// Note: it's possible that an active tunnel in excludeServerEntries will
 		// fail during this iteration of server entries and in that case the
 		// fail during this iteration of server entries and in that case the
 		// cooresponding server will not be retried (within the same iteration).
 		// cooresponding server will not be retried (within the same iteration).
-		// !TODO! is there also a race that can result in multiple tunnels to the same server
-		excludeServerEntries := controller.getActiveTunnelServerEntries()
 		iterator, err := NewServerEntryIterator(
 		iterator, err := NewServerEntryIterator(
-			controller.config.EgressRegion, controller.config.TunnelProtocol, excludeServerEntries)
+			controller.config.EgressRegion, controller.config.TunnelProtocol)
 		if err != nil {
 		if err != nil {
 			Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
 			Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
 			controller.SignalFailure()
 			controller.SignalFailure()
@@ -577,6 +592,10 @@ func (controller *Controller) establishTunnelWorker() {
 			return
 			return
 		default:
 		default:
 		}
 		}
+		// There may already be a tunnel to this candidate. If so, skip it.
+		if controller.isActiveTunnelServerEntry(serverEntry) {
+			continue
+		}
 		tunnel, err := EstablishTunnel(
 		tunnel, err := EstablishTunnel(
 			controller.config, controller.pendingConns, serverEntry)
 			controller.config, controller.pendingConns, serverEntry)
 		if err != nil {
 		if err != nil {
@@ -595,23 +614,3 @@ func (controller *Controller) establishTunnelWorker() {
 	}
 	}
 	Notice(NOTICE_INFO, "stopped establish worker")
 	Notice(NOTICE_INFO, "stopped establish worker")
 }
 }
-
-// RunForever executes the main loop of the Psiphon client. It launches
-// the controller with a shutdown that it never signaled.
-func RunForever(config *Config) {
-
-	if config.LogFilename != "" {
-		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
-		if err != nil {
-			Fatal("error opening log file: %s", err)
-		}
-		defer logFile.Close()
-		log.SetOutput(logFile)
-	}
-
-	Notice(NOTICE_VERSION, VERSION)
-
-	controller := NewController(config)
-	shutdownBroadcast := make(chan struct{})
-	controller.Run(shutdownBroadcast)
-}

+ 4 - 12
psiphon/dataStore.go

@@ -204,19 +204,11 @@ type ServerEntryIterator struct {
 }
 }
 
 
 // NewServerEntryIterator creates a new NewServerEntryIterator
 // NewServerEntryIterator creates a new NewServerEntryIterator
-func NewServerEntryIterator(
-	region, protocol string,
-	excludeServerEntries []*ServerEntry) (iterator *ServerEntryIterator, err error) {
-
+func NewServerEntryIterator(region, protocol string) (iterator *ServerEntryIterator, err error) {
 	initDataStore()
 	initDataStore()
-	excludeIds := make([]string, len(excludeServerEntries))
-	for index, serverEntry := range excludeServerEntries {
-		excludeIds[index] = serverEntry.IpAddress
-	}
 	iterator = &ServerEntryIterator{
 	iterator = &ServerEntryIterator{
-		region:     region,
-		protocol:   protocol,
-		excludeIds: excludeIds,
+		region:   region,
+		protocol: protocol,
 	}
 	}
 	err = iterator.Reset()
 	err = iterator.Reset()
 	if err != nil {
 	if err != nil {
@@ -235,7 +227,7 @@ func (iterator *ServerEntryIterator) Reset() error {
 	}
 	}
 	var cursor *sql.Rows
 	var cursor *sql.Rows
 	whereClause, whereParams := makeServerEntryWhereClause(
 	whereClause, whereParams := makeServerEntryWhereClause(
-		iterator.region, iterator.protocol, iterator.excludeIds)
+		iterator.region, iterator.protocol, nil)
 	query := "select data from serverEntry" + whereClause + " order by rank desc;"
 	query := "select data from serverEntry" + whereClause + " order by rank desc;"
 	cursor, err = transaction.Query(query, whereParams...)
 	cursor, err = transaction.Query(query, whereParams...)
 	if err != nil {
 	if err != nil {

+ 1 - 1
psiphon/meekConn.go

@@ -326,7 +326,7 @@ func (meek *MeekConn) relay() {
 	defer meek.relayWaitGroup.Done()
 	defer meek.relayWaitGroup.Done()
 	interval := MIN_POLL_INTERVAL
 	interval := MIN_POLL_INTERVAL
 	timeout := time.NewTimer(interval)
 	timeout := time.NewTimer(interval)
-	var sendPayload = make([]byte, MAX_SEND_PAYLOAD_LENGTH)
+	sendPayload := make([]byte, MAX_SEND_PAYLOAD_LENGTH)
 	for {
 	for {
 		timeout.Reset(interval)
 		timeout.Reset(interval)
 		// Block until there is payload to send or it is time to poll
 		// Block until there is payload to send or it is time to poll

+ 10 - 6
psiphon/serverApi.go

@@ -67,6 +67,7 @@ func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
+
 	return session, nil
 	return session, nil
 }
 }
 
 
@@ -141,12 +142,15 @@ func (session *Session) doHandshakeRequest() error {
 	if upgradeClientVersion > session.config.ClientVersion {
 	if upgradeClientVersion > session.config.ClientVersion {
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 	}
 	}
-	for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
-		Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
-	}
-	for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
-		Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
-	}
+	// TODO: remove regex notices -- regexes will be used internally
+	/*
+		for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
+			Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
+		}
+		for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
+			Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
+		}
+	*/
 	return nil
 	return nil
 }
 }
 
 

+ 20 - 10
psiphon/tunnel.go

@@ -28,6 +28,7 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"time"
 	"time"
 )
 )
 
 
@@ -61,6 +62,7 @@ var SupportedTunnelProtocols = []string{
 type Tunnel struct {
 type Tunnel struct {
 	serverEntry             *ServerEntry
 	serverEntry             *ServerEntry
 	sessionId               string
 	sessionId               string
+	sessionStarted          int32
 	protocol                string
 	protocol                string
 	conn                    Conn
 	conn                    Conn
 	sshClient               *ssh.Client
 	sshClient               *ssh.Client
@@ -69,16 +71,6 @@ type Tunnel struct {
 	portForwardFailureTotal int
 	portForwardFailureTotal int
 }
 }
 
 
-// Close terminates the tunnel.
-func (tunnel *Tunnel) Close() {
-	if tunnel.sshKeepAliveQuit != nil {
-		close(tunnel.sshKeepAliveQuit)
-	}
-	if tunnel.conn != nil {
-		tunnel.conn.Close()
-	}
-}
-
 // EstablishTunnel first makes a network transport connection to the
 // EstablishTunnel first makes a network transport connection to the
 // Psiphon server and then establishes an SSH client session on top of
 // Psiphon server and then establishes an SSH client session on top of
 // that transport. The SSH server is authenticated using the public
 // that transport. The SSH server is authenticated using the public
@@ -260,6 +252,24 @@ func EstablishTunnel(
 		nil
 		nil
 }
 }
 
 
+// Close terminates the tunnel.
+func (tunnel *Tunnel) Close() {
+	if tunnel.sshKeepAliveQuit != nil {
+		close(tunnel.sshKeepAliveQuit)
+	}
+	if tunnel.conn != nil {
+		tunnel.conn.Close()
+	}
+}
+
+func (tunnel *Tunnel) IsSessionStarted() bool {
+	return atomic.LoadInt32(&tunnel.sessionStarted) == 1
+}
+
+func (tunnel *Tunnel) SetSessionStarted() {
+	atomic.StoreInt32(&tunnel.sessionStarted, 1)
+}
+
 // Dial establishes a port forward connection through the tunnel
 // Dial establishes a port forward connection through the tunnel
 func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	// TODO: should this track port forward failures as in Controller.DialWithTunnel?
 	// TODO: should this track port forward failures as in Controller.DialWithTunnel?

+ 30 - 1
psiphonClient.go

@@ -23,9 +23,13 @@ import (
 	"flag"
 	"flag"
 	psiphon "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	psiphon "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"log"
 	"log"
+	"os"
+	"os/signal"
+	"sync"
 )
 )
 
 
 func main() {
 func main() {
+
 	var configFilename string
 	var configFilename string
 	flag.StringVar(&configFilename, "config", "", "configuration file")
 	flag.StringVar(&configFilename, "config", "", "configuration file")
 	flag.Parse()
 	flag.Parse()
@@ -36,5 +40,30 @@ func main() {
 	if err != nil {
 	if err != nil {
 		log.Fatalf("error loading configuration file: %s", err)
 		log.Fatalf("error loading configuration file: %s", err)
 	}
 	}
-	psiphon.RunForever(config)
+
+	if config.LogFilename != "" {
+		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+		if err != nil {
+			log.Fatalf("error opening log file: %s", err)
+		}
+		defer logFile.Close()
+		log.SetOutput(logFile)
+	}
+
+	controller := psiphon.NewController(config)
+	shutdownBroadcast := make(chan struct{})
+	controllerWaitGroup := new(sync.WaitGroup)
+	controllerWaitGroup.Add(1)
+	go func() {
+		defer controllerWaitGroup.Done()
+		controller.Run(shutdownBroadcast)
+	}()
+
+	systemStopSignal := make(chan os.Signal, 1)
+	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)
+	<-systemStopSignal
+
+	psiphon.Notice(psiphon.NOTICE_INFO, "shutdown by system")
+	close(shutdownBroadcast)
+	controllerWaitGroup.Wait()
 }
 }