Преглед изворни кода

refactored runTunnel and tunnel to use less shared memory and more resemble an idiomatic Go processing pipeline; refactored conn to support an explicit, synchronized interruption list and cleanup race conditions related to unsynced access to shared memory

Rod Hynes пре 11 година
родитељ
комит
9308d1ec87
5 измењених фајлова са 174 додато и 153 уклоњено
  1. 1 0
      README.md
  2. 93 55
      psiphon/conn.go
  3. 51 59
      psiphon/runTunnel.go
  4. 1 1
      psiphon/socksProxy.go
  5. 28 38
      psiphon/tunnel.go

+ 1 - 0
README.md

@@ -14,6 +14,7 @@ Status
 This project is currently at the proof-of-concept stage. Current production Psiphon client code is available at our [main repository](https://bitbucket.org/psiphon/psiphon-circumvention-system).
 
 ### TODO
+* psiphon.Conn for Windows
 * more test cases
 * integrate meek-client
 * add config options

+ 93 - 55
psiphon/conn.go

@@ -23,6 +23,7 @@ import (
 	"errors"
 	"net"
 	"os"
+	"sync"
 	"syscall"
 	"time"
 )
@@ -36,19 +37,26 @@ import (
 //   routing compatibility, for example).
 type Conn struct {
 	net.Conn
-	socketFd            int
-	needCloseSocketFd   bool
-	isDisconnected      bool
-	disconnectionSignal chan bool
-	readTimeout         time.Duration
-	writeTimeout        time.Duration
+	mutex        sync.Mutex
+	socketFd     int
+	isClosed     bool
+	closedSignal chan bool
+	readTimeout  time.Duration
+	writeTimeout time.Duration
 }
 
-// NewConn creates a new, configured Conn. Unlike standard Dial
-// functions, this does not return a connected net.Conn. Call the Connect function
-// to complete the connection establishment. To implement device binding and
-// interruptible connecting, the lower-level syscall APIs are used.
-func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn, error) {
+// NewConn creates a new, connected Conn. The connection can be interrupted
+// using pendingConns.interrupt(): the new Conn is added to pendingConns
+// before the socket connect beings. The caller is responsible for removing the
+// returned Conn from pendingConns.
+// To implement device binding and interruptible connecting, the lower-level
+// syscall APIs are used. The sequence of syscalls in this implementation are
+// taken from: https://code.google.com/p/go/issues/detail?id=6966
+func Dial(
+	ipAddress string, port int,
+	readTimeout, writeTimeout time.Duration,
+	pendingConns *PendingConns) (conn *Conn, err error) {
+
 	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
 	if err != nil {
 		return nil, err
@@ -58,7 +66,7 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
 		syscall.Close(socketFd)
 		return nil, err
 	}
-	if deviceName != "" {
+	/*
 		// TODO: requires root, which we won't have on Android in VpnService mode
 		//       an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
 		//       send the fd to the main Android process which receives the fd with
@@ -69,74 +77,74 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
 		//       https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
 		const SO_BINDTODEVICE = 0x19 // only defined for Linux
 		err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
-		return nil, err
-	}
-	return &Conn{
-		socketFd:          socketFd,
-		needCloseSocketFd: true,
-		readTimeout:       readTimeout,
-		writeTimeout:      writeTimeout}, nil
-}
-
-// Connect establishes a connection to the specified host. The sequence of
-// syscalls in this implementation are taken from: https://code.google.com/p/go/issues/detail?id=6966
-func (conn *Conn) Connect(ipAddress string, port int) (err error) {
+	*/
+	conn = &Conn{
+		socketFd:     socketFd,
+		readTimeout:  readTimeout,
+		writeTimeout: writeTimeout}
+	pendingConns.Add(conn)
 	// TODO: domain name resolution (for meek)
 	var addr [4]byte
 	copy(addr[:], net.ParseIP(ipAddress).To4())
 	sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
 	err = syscall.Connect(conn.socketFd, &sockAddr)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	file := os.NewFile(uintptr(conn.socketFd), "")
 	defer file.Close()
-	fileConn, err := net.FileConn(file)
+	conn.Conn, err = net.FileConn(file)
 	if err != nil {
-		return err
+		return nil, err
 	}
-	conn.Conn = fileConn
-	conn.needCloseSocketFd = false
-	return nil
+	return conn, nil
 }
 
-// SetDisconnectionSignal sets the channel which will be signaled
-// when the connection terminates. This function returns an error
-// if the connection is already disconnected (and would never send
+// SetClosedSignal sets the channel which will be signaled
+// when the connection is closed. This function returns an error
+// if the connection is already closed (and would never send
 // the signal).
-func (conn *Conn) SetDisconnectionSignal(disconnectionSignal chan bool) (err error) {
-	if conn.isDisconnected {
-		return errors.New("connection is already disconnected")
+func (conn *Conn) SetClosedSignal(closedSignal chan bool) (err error) {
+	// TEMP **** needs comments
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+	if conn.isClosed {
+		return errors.New("connection is already closed")
 	}
-	conn.disconnectionSignal = disconnectionSignal
+	conn.closedSignal = closedSignal
 	return nil
 }
 
-// Close terminates down an established (net.Conn) or establishing (socketFd) connection.
+// Close terminates a connected (net.Conn) or connecting (socketFd) Conn.
+// A mutex syncs access to conn struct, allowing Close() to be called
+// from a goroutine that wants to interrupt the primary goroutine using
+// the connection.
 func (conn *Conn) Close() (err error) {
-	if conn.needCloseSocketFd {
-		err = syscall.Close(conn.socketFd)
-		conn.needCloseSocketFd = false
-	}
-	if conn.Conn != nil {
-		err = conn.Conn.Close()
+	var closedSignal chan bool
+	conn.mutex.Lock()
+	if !conn.isClosed {
+		if conn.Conn == nil {
+			err = syscall.Close(conn.socketFd)
+		} else {
+			err = conn.Conn.Close()
+		}
+		closedSignal = conn.closedSignal
+		conn.isClosed = true
 	}
-	if conn.disconnectionSignal != nil {
+	conn.mutex.Unlock()
+	if closedSignal != nil {
 		select {
-		case conn.disconnectionSignal <- true:
+		case closedSignal <- true:
 		default:
 		}
 	}
-	conn.isDisconnected = true
 	return err
 }
 
 // Read wraps standard Read to add an idle timeout. The connection
-// is explicitly terminated on timeout.
+// is explicitly closed on timeout.
 func (conn *Conn) Read(buffer []byte) (n int, err error) {
-	if conn.Conn == nil {
-		return 0, errors.New("not connected")
-	}
+	// Note: no mutex on the conn.readTimeout access
 	if conn.readTimeout != 0 {
 		err = conn.Conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
 		if err != nil {
@@ -151,11 +159,9 @@ func (conn *Conn) Read(buffer []byte) (n int, err error) {
 }
 
 // Write wraps standard Write to add an idle timeout The connection
-// is explicitly terminated on timeout.
+// is explicitly closed on timeout.
 func (conn *Conn) Write(buffer []byte) (n int, err error) {
-	if conn.Conn == nil {
-		return 0, errors.New("not connected")
-	}
+	// Note: no mutex on the conn.writeTimeout access
 	if conn.writeTimeout != 0 {
 		err = conn.Conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
 		if err != nil {
@@ -168,3 +174,35 @@ func (conn *Conn) Write(buffer []byte) (n int, err error) {
 	}
 	return
 }
+
+// PendingConns is a synchronized list of Conns that's used to coordinate
+// interrupting a set of goroutines establishing connections.
+type PendingConns struct {
+	mutex sync.Mutex
+	conns []*Conn
+}
+
+func (pendingConns *PendingConns) Add(conn *Conn) {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	pendingConns.conns = append(pendingConns.conns, conn)
+}
+
+func (pendingConns *PendingConns) Remove(conn *Conn) {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	for index, pendingConn := range pendingConns.conns {
+		if conn == pendingConn {
+			pendingConns.conns = append(pendingConns.conns[:index], pendingConns.conns[index+1:]...)
+			break
+		}
+	}
+}
+
+func (pendingConns *PendingConns) Interrupt() {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	for _, conn := range pendingConns.conns {
+		conn.Close()
+	}
+}

+ 51 - 59
psiphon/runTunnel.go

@@ -37,40 +37,28 @@ import (
 // if there's not already an established tunnel. This function is to be used in a pool
 // of goroutines.
 func establishTunnelWorker(
-	waitGroup *sync.WaitGroup, candidateQueue chan *Tunnel,
-	broadcastStopWorkers chan bool, firstEstablishedTunnel chan *Tunnel) {
+	waitGroup *sync.WaitGroup,
+	candidateServerEntries chan *ServerEntry,
+	broadcastStopWorkers chan bool,
+	pendingConns *PendingConns,
+	establishedTunnels chan *Tunnel) {
+
 	defer waitGroup.Done()
-	for tunnel := range candidateQueue {
+	for serverEntry := range candidateServerEntries {
 		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
 		// select, since we want to prioritize receiving the stop signal
 		if IsSignalled(broadcastStopWorkers) {
 			return
 		}
-		log.Printf("connecting to %s", tunnel.serverEntry.IpAddress)
-		err := EstablishTunnel(tunnel)
+		log.Printf("connecting to %s", serverEntry.IpAddress)
+		tunnel, err := EstablishTunnel(serverEntry, pendingConns)
 		if err != nil {
-			if tunnel.isClosed {
-				log.Printf("cancelled connection to %s", tunnel.serverEntry.IpAddress)
-			} else {
-				log.Printf("failed to connect to %s: %s", tunnel.serverEntry.IpAddress, err)
-			}
+			// TODO: distingush case where conn is interrupted?
+			log.Printf("failed to connect to %s: %s", serverEntry.IpAddress, err)
 		} else {
-			// Need to re-check broadcastStopWorkers signal before sending
-			// in case firstEstablishedTunnel has been closed
-			// TODO: race condition? may panic if so
-			if !IsSignalled(broadcastStopWorkers) {
-				select {
-				case firstEstablishedTunnel <- tunnel:
-					log.Printf("selected connection to %s using %s",
-						tunnel.serverEntry.IpAddress, tunnel.protocol)
-					// Leave tunnel open
-					return
-				default:
-				}
-			}
-			log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
+			log.Printf("successfully connected to %s", serverEntry.IpAddress)
+			establishedTunnels <- tunnel
 		}
-		tunnel.Close()
 	}
 }
 
@@ -89,53 +77,57 @@ func runTunnel(config *Config) error {
 		return fmt.Errorf("failed to fetch remote server list: %s", err)
 	}
 	log.Printf("establishing tunnel")
-	candidateList := make([]*Tunnel, 0)
-	for _, serverEntry := range serverList {
-		candidateList = append(candidateList, &Tunnel{serverEntry: serverEntry})
-	}
 	waitGroup := new(sync.WaitGroup)
-	candidateQueue := make(chan *Tunnel, len(candidateList))
-	firstEstablishedTunnel := make(chan *Tunnel, 1)
+	candidateServerEntries := make(chan *ServerEntry)
+	pendingConns := new(PendingConns)
+	establishedTunnels := make(chan *Tunnel, len(serverList))
 	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
 	broadcastStopWorkers := make(chan bool)
 	for i := 0; i < CONNECTION_WORKER_POOL_SIZE; i++ {
 		waitGroup.Add(1)
-		go establishTunnelWorker(waitGroup, candidateQueue, broadcastStopWorkers, firstEstablishedTunnel)
-	}
-	for _, tunnel := range candidateList {
-		candidateQueue <- tunnel
+		go establishTunnelWorker(
+			waitGroup, candidateServerEntries, broadcastStopWorkers,
+			pendingConns, establishedTunnels)
 	}
-	close(candidateQueue)
-	var establishedTunnel *Tunnel
-	select {
-	case establishedTunnel = <-firstEstablishedTunnel:
-		defer establishedTunnel.Close()
-		close(firstEstablishedTunnel)
-	case <-timeout:
-		return errors.New("timeout establishing tunnel")
+	var selectedTunnel *Tunnel
+	for _, serverEntry := range serverList {
+		select {
+		case candidateServerEntries <- serverEntry:
+		case selectedTunnel = <-establishedTunnels:
+			defer selectedTunnel.Close()
+			log.Printf("selected connection to %s", selectedTunnel.serverEntry.IpAddress)
+		case <-timeout:
+			return errors.New("timeout establishing tunnel")
+		}
+		if selectedTunnel != nil {
+			break
+		}
 	}
-	log.Printf("stopping workers")
+	log.Printf("tunnel established")
+	close(candidateServerEntries)
 	close(broadcastStopWorkers)
-	for _, candidate := range candidateList {
-		if candidate != establishedTunnel {
-			// Interrupt any partial connections in progress, so that
-			// the worker will terminate immediately
-			candidate.Close()
-		}
+	// Interrupt any partial connections in progress, so that
+	// the worker will terminate immediately
+	pendingConns.Interrupt()
+	waitGroup.Wait()
+	// Drain any excess tunnels
+	close(establishedTunnels)
+	for tunnel := range establishedTunnels {
+		log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
+		tunnel.Close()
 	}
+	// Don't hold references to candidates while running tunnel
+	candidateServerEntries = nil
+	pendingConns = nil
 	// TODO: can start SOCKS before synchronizing work group
-	waitGroup.Wait()
-	if establishedTunnel != nil {
-		// Don't hold references to candidates while running tunnel
-		candidateList = nil
-		candidateQueue = nil
+	if selectedTunnel != nil {
 		stopTunnelSignal := make(chan bool)
-		err = establishedTunnel.conn.SetDisconnectionSignal(stopTunnelSignal)
+		err = selectedTunnel.conn.SetClosedSignal(stopTunnelSignal)
 		if err != nil {
-			return fmt.Errorf("failed to set disconnection signal: %s", err)
+			return fmt.Errorf("failed to set closed signal: %s", err)
 		}
 		log.Printf("starting local SOCKS proxy")
-		socksServer := NewSocksServer(establishedTunnel, stopTunnelSignal)
+		socksServer := NewSocksServer(selectedTunnel, stopTunnelSignal)
 		if err != nil {
 			return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
 		}
@@ -144,7 +136,7 @@ func runTunnel(config *Config) error {
 			return fmt.Errorf("error running local SOCKS proxy: %s", err)
 		}
 		defer socksServer.Close()
-		log.Printf("monitoring for failure")
+		log.Printf("monitoring tunnel")
 		<-stopTunnelSignal
 	}
 	return nil

+ 1 - 1
psiphon/socksProxy.go

@@ -83,7 +83,7 @@ func socksConnectionHandler(tunnel *Tunnel, localSocksConn *pt.SocksConn) (err e
 	waitGroup.Add(1)
 	go func() {
 		defer waitGroup.Done()
-		_, err = io.Copy(localSocksConn, remoteSshForward)
+		_, err := io.Copy(localSocksConn, remoteSshForward)
 		if err != nil {
 			log.Printf("ssh port forward downstream error: %s", err)
 		}

+ 28 - 38
psiphon/tunnel.go

@@ -43,19 +43,13 @@ type Tunnel struct {
 	protocol    string
 	conn        *Conn
 	sshClient   *ssh.Client
-	isClosed    bool
 }
 
-// Close terminates the tunnel SSH client session and the
-// underlying network transport.
+// Close terminates the tunnel.
 func (tunnel *Tunnel) Close() {
-	if tunnel.sshClient != nil {
-		tunnel.sshClient.Close()
-	}
 	if tunnel.conn != nil {
 		tunnel.conn.Close()
 	}
-	tunnel.isClosed = true
 }
 
 // EstablishTunnel first makes a network transport connection to the
@@ -65,44 +59,42 @@ func (tunnel *Tunnel) Close() {
 // Depending on the server's capabilities, the connection may use
 // plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
 // HTTP (meek protocol).
-func EstablishTunnel(tunnel *Tunnel) (err error) {
-	if tunnel.conn != nil {
-		return errors.New("tunnel already connected")
-	}
-	if tunnel.sshClient != nil {
-		return errors.New("ssh client already established")
-	}
+func EstablishTunnel(serverEntry *ServerEntry, pendingConns *PendingConns) (tunnel *Tunnel, err error) {
 	// First connect the transport
 	// TODO: meek
-	sshCapable := Contains(tunnel.serverEntry.Capabilities, PROTOCOL_SSH)
-	obfuscatedSshCapable := Contains(tunnel.serverEntry.Capabilities, PROTOCOL_OBFUSCATED_SSH)
+	sshCapable := Contains(serverEntry.Capabilities, PROTOCOL_SSH)
+	obfuscatedSshCapable := Contains(serverEntry.Capabilities, PROTOCOL_OBFUSCATED_SSH)
 	if !sshCapable && !obfuscatedSshCapable {
-		return fmt.Errorf("server does not have sufficient capabilities")
+		return nil, fmt.Errorf("server does not have sufficient capabilities")
+	}
+	selectedProtocol := PROTOCOL_SSH
+	port := serverEntry.SshPort
+	if obfuscatedSshCapable {
+		selectedProtocol = PROTOCOL_OBFUSCATED_SSH
+		port = serverEntry.SshObfuscatedPort
 	}
-	tunnel.protocol = PROTOCOL_SSH
-	port := tunnel.serverEntry.SshPort
-	conn, err := NewConn(0, CONNECTION_CANDIDATE_TIMEOUT, "")
+	conn, err := Dial(serverEntry.IpAddress, port, 0, CONNECTION_CANDIDATE_TIMEOUT, pendingConns)
 	if err != nil {
-		return err
+		return nil, err
 	}
+	defer func() {
+		pendingConns.Remove(conn)
+		if err != nil {
+			conn.Close()
+		}
+	}()
 	var netConn net.Conn
 	netConn = conn
 	if obfuscatedSshCapable {
-		tunnel.protocol = PROTOCOL_OBFUSCATED_SSH
-		port = tunnel.serverEntry.SshObfuscatedPort
-		netConn, err = NewObfuscatedSshConn(conn, tunnel.serverEntry.SshObfuscatedKey)
+		netConn, err = NewObfuscatedSshConn(conn, serverEntry.SshObfuscatedKey)
 		if err != nil {
-			return err
+			return nil, err
 		}
 	}
-	err = conn.Connect(tunnel.serverEntry.IpAddress, port)
-	if err != nil {
-		return err
-	}
 	// Now establish the SSH session
-	expectedPublicKey, err := base64.StdEncoding.DecodeString(tunnel.serverEntry.SshHostKey)
+	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	sshCertChecker := &ssh.CertChecker{
 		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
@@ -113,20 +105,18 @@ func EstablishTunnel(tunnel *Tunnel) (err error) {
 		},
 	}
 	sshClientConfig := &ssh.ClientConfig{
-		User: tunnel.serverEntry.SshUsername,
+		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
-			ssh.Password(tunnel.serverEntry.SshPassword),
+			ssh.Password(serverEntry.SshPassword),
 		},
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 	}
 	// The folowing is adapted from ssh.Dial(), here using a custom conn
-	sshAddress := strings.Join([]string{tunnel.serverEntry.IpAddress, ":", strconv.Itoa(tunnel.serverEntry.SshPort)}, "")
+	sshAddress := strings.Join([]string{serverEntry.IpAddress, ":", strconv.Itoa(serverEntry.SshPort)}, "")
 	sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, sshAddress, sshClientConfig)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
-	tunnel.conn = conn
-	tunnel.sshClient = sshClient
-	return nil
+	return &Tunnel{serverEntry, selectedProtocol, conn, sshClient}, nil
 }