Browse Source

Added Windows support: split out platform-specific network connection code; worker pool cleanup in runTunnel is now asynchronous

Rod Hynes 11 years ago
parent
commit
4f1be2bada
6 changed files with 216 additions and 106 deletions
  1. 1 1
      README.md
  2. 10 48
      psiphon/conn.go
  3. 86 0
      psiphon/conn_unix.go
  4. 51 0
      psiphon/conn_windows.go
  5. 68 54
      psiphon/runTunnel.go
  6. 0 3
      psiphon/tlsDialer.go

+ 1 - 1
README.md

@@ -15,9 +15,9 @@ This project is currently at the proof-of-concept stage. Current production Psip
 
 
 ### TODO (proof-of-concept)
 ### TODO (proof-of-concept)
 
 
+* investigate "psiphon.transactionWithRetry: database is locked" errors 
 * shutdown results in log noise: "use of closed network connection"
 * shutdown results in log noise: "use of closed network connection"
 * use ContextError in more places
 * use ContextError in more places
-* psiphon.Conn for Windows
 * build/test on Android and iOS
 * build/test on Android and iOS
 * integrate meek-client
 * integrate meek-client
 * disconnect all local proxy clients when tunnel disconnected
 * disconnect all local proxy clients when tunnel disconnected

+ 10 - 48
psiphon/conn.go

@@ -22,9 +22,7 @@ package psiphon
 import (
 import (
 	"errors"
 	"errors"
 	"net"
 	"net"
-	"os"
 	"sync"
 	"sync"
-	"syscall"
 	"time"
 	"time"
 )
 )
 
 
@@ -37,15 +35,15 @@ import (
 //   routing compatibility, for example).
 //   routing compatibility, for example).
 type Conn struct {
 type Conn struct {
 	net.Conn
 	net.Conn
-	mutex        sync.Mutex
-	socketFd     int
-	isClosed     bool
-	closedSignal chan bool
-	readTimeout  time.Duration
-	writeTimeout time.Duration
+	mutex         sync.Mutex
+	interruptible interruptibleConn
+	isClosed      bool
+	closedSignal  chan bool
+	readTimeout   time.Duration
+	writeTimeout  time.Duration
 }
 }
 
 
-// NewConn creates a new, connected Conn. The connection can be interrupted
+// Dial creates a new, connected Conn. The connection can be interrupted
 // using pendingConns.interrupt(): the new Conn is added to pendingConns
 // using pendingConns.interrupt(): the new Conn is added to pendingConns
 // before the socket connect beings. The caller is responsible for removing the
 // before the socket connect beings. The caller is responsible for removing the
 // returned Conn from pendingConns.
 // returned Conn from pendingConns.
@@ -57,45 +55,9 @@ func Dial(
 	readTimeout, writeTimeout time.Duration,
 	readTimeout, writeTimeout time.Duration,
 	pendingConns *PendingConns) (conn *Conn, err error) {
 	pendingConns *PendingConns) (conn *Conn, err error) {
 
 
-	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
+	conn, err = interruptibleDial(ipAddress, port, readTimeout, writeTimeout, pendingConns)
 	if err != nil {
 	if err != nil {
-		return nil, err
-	}
-	err = syscall.SetsockoptInt(socketFd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, TCP_KEEP_ALIVE_PERIOD_SECONDS)
-	if err != nil {
-		syscall.Close(socketFd)
-		return nil, err
-	}
-	/*
-		// TODO: requires root, which we won't have on Android in VpnService mode
-		//       an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
-		//       send the fd to the main Android process which receives the fd with
-		//       http://developer.android.com/reference/android/net/LocalSocket.html#getAncillaryFileDescriptors%28%29
-		//       and then calls
-		//       http://developer.android.com/reference/android/net/VpnService.html#protect%28int%29.
-		//       See, for example:
-		//       https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
-		const SO_BINDTODEVICE = 0x19 // only defined for Linux
-		err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
-	*/
-	conn = &Conn{
-		socketFd:     socketFd,
-		readTimeout:  readTimeout,
-		writeTimeout: writeTimeout}
-	pendingConns.Add(conn)
-	// TODO: domain name resolution (for meek)
-	var addr [4]byte
-	copy(addr[:], net.ParseIP(ipAddress).To4())
-	sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
-	err = syscall.Connect(conn.socketFd, &sockAddr)
-	if err != nil {
-		return nil, err
-	}
-	file := os.NewFile(uintptr(conn.socketFd), "")
-	defer file.Close()
-	conn.Conn, err = net.FileConn(file)
-	if err != nil {
-		return nil, err
+		return nil, ContextError(err)
 	}
 	}
 	return conn, nil
 	return conn, nil
 }
 }
@@ -123,7 +85,7 @@ func (conn *Conn) Close() (err error) {
 	conn.mutex.Lock()
 	conn.mutex.Lock()
 	if !conn.isClosed {
 	if !conn.isClosed {
 		if conn.Conn == nil {
 		if conn.Conn == nil {
-			err = syscall.Close(conn.socketFd)
+			err = interruptibleClose(conn.interruptible)
 		} else {
 		} else {
 			err = conn.Conn.Close()
 			err = conn.Conn.Close()
 		}
 		}

+ 86 - 0
psiphon/conn_unix.go

@@ -0,0 +1,86 @@
+// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris
+
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"net"
+	"os"
+	"syscall"
+	"time"
+)
+
+type interruptibleConn struct {
+	socketFd int
+}
+
+func interruptibleDial(
+	ipAddress string, port int,
+	readTimeout, writeTimeout time.Duration,
+	pendingConns *PendingConns) (conn *Conn, err error) {
+
+	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
+	if err != nil {
+		return nil, err
+	}
+	err = syscall.SetsockoptInt(socketFd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, TCP_KEEP_ALIVE_PERIOD_SECONDS)
+	if err != nil {
+		syscall.Close(socketFd)
+		return nil, err
+	}
+	/*
+		// TODO: requires root, which we won't have on Android in VpnService mode
+		//       an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
+		//       send the fd to the main Android process which receives the fd with
+		//       http://developer.android.com/reference/android/net/LocalSocket.html#getAncillaryFileDescriptors%28%29
+		//       and then calls
+		//       http://developer.android.com/reference/android/net/VpnService.html#protect%28int%29.
+		//       See, for example:
+		//       https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
+		const SO_BINDTODEVICE = 0x19 // only defined for Linux
+		err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
+	*/
+	conn = &Conn{
+		interruptible: interruptibleConn{socketFd: socketFd},
+		readTimeout:   readTimeout,
+		writeTimeout:  writeTimeout}
+	// Note: syscall.Close(socketFd) not called on error after pendingConns.Add
+	pendingConns.Add(conn)
+	// TODO: domain name resolution (for meek)
+	var addr [4]byte
+	copy(addr[:], net.ParseIP(ipAddress).To4())
+	sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
+	err = syscall.Connect(conn.interruptible.socketFd, &sockAddr)
+	if err != nil {
+		return nil, err
+	}
+	file := os.NewFile(uintptr(conn.interruptible.socketFd), "")
+	defer file.Close()
+	conn.Conn, err = net.FileConn(file)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+func interruptibleClose(interruptible interruptibleConn) error {
+	return syscall.Close(interruptible.socketFd)
+}

+ 51 - 0
psiphon/conn_windows.go

@@ -0,0 +1,51 @@
+// +build windows
+
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"fmt"
+	"net"
+	"time"
+)
+
+type interruptibleConn struct {
+}
+
+func interruptibleDial(
+	ipAddress string, port int,
+	readTimeout, writeTimeout time.Duration,
+	pendingConns *PendingConns) (conn *Conn, err error) {
+	// Note: using net.Dial(); interruptible connections not supported on Windows
+	netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", ipAddress, port))
+	if err != nil {
+		return nil, err
+	}
+	conn = &Conn{
+		Conn:         netConn,
+		readTimeout:  readTimeout,
+		writeTimeout: writeTimeout}
+	return conn, nil
+}
+
+func interruptibleClose(interruptible interruptibleConn) error {
+	panic("interruptibleClose not supported on Windows")
+}

+ 68 - 54
psiphon/runTunnel.go

@@ -67,21 +67,19 @@ func establishTunnelWorker(
 	}
 	}
 }
 }
 
 
+// discardTunnel is used to dispose of a successful connection that is
+// no longer required (another tunnel has already been selected). Since
+// the connection was successful, the server entry is still promoted.
 func discardTunnel(tunnel *Tunnel) {
 func discardTunnel(tunnel *Tunnel) {
 	log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
 	log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
 	tunnel.Close()
 	tunnel.Close()
 }
 }
 
 
-// runTunnel establishes a tunnel session and runs local proxies that make use of
-// that tunnel. The tunnel connection is monitored and this function returns an
-// error when the tunnel unexpectedly disconnects.
-// fetchRemoteServerList is used to obtain a fresh list of servers to attempt
-// to connect to. A worker pool of goroutines is used to attempt several tunnel
-// connections in parallel, and this process is stopped once the first tunnel
-// is established.
-func runTunnel(config *Config) error {
-	log.Printf("establishing tunnel")
+// establishTunnel coordinates a worker pool of goroutines to attempt several
+// tunnel connections in parallel, and this process is stopped once the first
+// tunnel is established.
+func establishTunnel(config *Config) (tunnel *Tunnel, err error) {
 	waitGroup := new(sync.WaitGroup)
 	waitGroup := new(sync.WaitGroup)
 	candidateServerEntries := make(chan *ServerEntry)
 	candidateServerEntries := make(chan *ServerEntry)
 	pendingConns := new(PendingConns)
 	pendingConns := new(PendingConns)
@@ -95,7 +93,7 @@ func runTunnel(config *Config) error {
 			pendingConns, establishedTunnels)
 			pendingConns, establishedTunnels)
 	}
 	}
 	// TODO: add a throttle after each full cycle?
 	// TODO: add a throttle after each full cycle?
-	// Note: errors fall through to ensure worker and channel cleanup
+	// Note: errors fall through to ensure worker and channel cleanup (is started, at least)
 	var selectedTunnel *Tunnel
 	var selectedTunnel *Tunnel
 	cycler, err := NewServerEntryCycler(config.EgressRegion)
 	cycler, err := NewServerEntryCycler(config.EgressRegion)
 	for selectedTunnel == nil && err == nil {
 	for selectedTunnel == nil && err == nil {
@@ -107,7 +105,6 @@ func runTunnel(config *Config) error {
 		select {
 		select {
 		case candidateServerEntries <- serverEntry:
 		case candidateServerEntries <- serverEntry:
 		case selectedTunnel = <-establishedTunnels:
 		case selectedTunnel = <-establishedTunnels:
-			defer selectedTunnel.Close()
 			log.Printf("selected connection to %s", selectedTunnel.serverEntry.IpAddress)
 			log.Printf("selected connection to %s", selectedTunnel.serverEntry.IpAddress)
 		case <-timeout:
 		case <-timeout:
 			err = errors.New("timeout establishing tunnel")
 			err = errors.New("timeout establishing tunnel")
@@ -116,57 +113,74 @@ func runTunnel(config *Config) error {
 	cycler.Close()
 	cycler.Close()
 	close(candidateServerEntries)
 	close(candidateServerEntries)
 	close(broadcastStopWorkers)
 	close(broadcastStopWorkers)
-	// Interrupt any partial connections in progress, so that
-	// the worker will terminate immediately
-	pendingConns.Interrupt()
-	waitGroup.Wait()
-	// Drain any excess tunnels
-	close(establishedTunnels)
-	for tunnel := range establishedTunnels {
-		discardTunnel(tunnel)
-	}
+	// Clean up is now asynchronous since Windows doesn't support interruptible connections
+	go func() {
+		// Interrupt any partial connections in progress, so that
+		// the worker will terminate immediately
+		pendingConns.Interrupt()
+		waitGroup.Wait()
+		// Drain any excess tunnels
+		close(establishedTunnels)
+		for tunnel := range establishedTunnels {
+			discardTunnel(tunnel)
+		}
+		// Note: only call this PromoteServerEntry after all discards so the selected
+		// tunnel is the top ranked
+		if selectedTunnel != nil {
+			PromoteServerEntry(selectedTunnel.serverEntry.IpAddress)
+		}
+	}()
 	// Note: end of error fall through
 	// Note: end of error fall through
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("failed to establish tunnel: %s", err)
+		return nil, ContextError(err)
 	}
 	}
-	// Don't hold references to candidates while running tunnel
-	candidateServerEntries = nil
-	pendingConns = nil
-	// TODO: could start local proxies, etc., before synchronizing work group
-	if selectedTunnel != nil {
-		log.Printf("tunnel established")
-		PromoteServerEntry(selectedTunnel.serverEntry.IpAddress)
-		stopTunnelSignal := make(chan bool)
-		err = selectedTunnel.conn.SetClosedSignal(stopTunnelSignal)
-		if err != nil {
-			return fmt.Errorf("failed to set closed signal: %s", err)
-		}
-		log.Printf("starting local SOCKS proxy")
-		socksProxy, err := NewSocksProxy(selectedTunnel, stopTunnelSignal)
-		if err != nil {
-			return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
-		}
-		defer socksProxy.Close()
-		log.Printf("starting local HTTP proxy")
-		httpProxy, err := NewHttpProxy(selectedTunnel, stopTunnelSignal)
-		if err != nil {
-			return fmt.Errorf("error initializing local HTTP proxy: %s", err)
-		}
-		defer httpProxy.Close()
-		log.Printf("starting session")
-		localHttpProxyAddress := httpProxy.listener.Addr().String()
-		_, err = NewSession(config, selectedTunnel, localHttpProxyAddress)
-		if err != nil {
-			return fmt.Errorf("error starting session: %s", err)
-		}
-		log.Printf("monitoring tunnel")
-		<-stopTunnelSignal
+	return selectedTunnel, nil
+}
+
+// runTunnel establishes a tunnel session and runs local proxies that make use of
+// that tunnel. The tunnel connection is monitored and this function returns an
+// error when the tunnel unexpectedly disconnects.
+func runTunnel(config *Config) error {
+	log.Printf("establishing tunnel")
+	tunnel, err := establishTunnel(config)
+	if err != nil {
+		return ContextError(err)
+	}
+	defer tunnel.Close()
+	// TODO: could start local proxies, etc., before synchronizing work group is establishTunnel
+	log.Printf("running tunnel")
+	stopTunnelSignal := make(chan bool)
+	err = tunnel.conn.SetClosedSignal(stopTunnelSignal)
+	if err != nil {
+		return fmt.Errorf("failed to set closed signal: %s", err)
+	}
+	log.Printf("starting local SOCKS proxy")
+	socksProxy, err := NewSocksProxy(tunnel, stopTunnelSignal)
+	if err != nil {
+		return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
+	}
+	defer socksProxy.Close()
+	log.Printf("starting local HTTP proxy")
+	httpProxy, err := NewHttpProxy(tunnel, stopTunnelSignal)
+	if err != nil {
+		return fmt.Errorf("error initializing local HTTP proxy: %s", err)
+	}
+	defer httpProxy.Close()
+	log.Printf("starting session")
+	localHttpProxyAddress := httpProxy.listener.Addr().String()
+	_, err = NewSession(config, tunnel, localHttpProxyAddress)
+	if err != nil {
+		return fmt.Errorf("error starting session: %s", err)
 	}
 	}
-	return err
+	log.Printf("monitoring tunnel")
+	<-stopTunnelSignal
+	return nil
 }
 }
 
 
 // RunTunnelForever executes the main loop of the Psiphon client. It establishes
 // RunTunnelForever executes the main loop of the Psiphon client. It establishes
 // a tunnel and reconnects when the tunnel unexpectedly disconnects.
 // a tunnel and reconnects when the tunnel unexpectedly disconnects.
+// FetchRemoteServerList is used to obtain a fresh list of servers to attempt
+// to connect to.
 func RunTunnelForever(config *Config) {
 func RunTunnelForever(config *Config) {
 	if config.LogFilename != "" {
 	if config.LogFilename != "" {
 		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
 		logFile, err := os.OpenFile(config.LogFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)

+ 0 - 3
psiphon/tlsDialer.go

@@ -77,7 +77,6 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"log"
 	"net"
 	"net"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -221,8 +220,6 @@ func CustomTLSDialWithDialer(dialer *net.Dialer, network, addr string, config *C
 		err = <-errChannel
 		err = <-errChannel
 	}
 	}
 
 
-	log.Printf("TEMP tlsDialer establishConnection done: %+v", conn.ConnectionState())
-
 	if err == nil && config.verifyLegacyCertificate != nil {
 	if err == nil && config.verifyLegacyCertificate != nil {
 		err = verifyLegacyCertificate(conn, config.verifyLegacyCertificate)
 		err = verifyLegacyCertificate(conn, config.verifyLegacyCertificate)
 	} else if err == nil && !config.sendServerName && !tlsConfig.InsecureSkipVerify {
 	} else if err == nil && !config.sendServerName && !tlsConfig.InsecureSkipVerify {