فهرست منبع

refactored runTunnel and tunnel to use less shared memory and more resemble an idiomatic Go processing pipeline; refactored conn to support an explicit, synchronized interruption list and cleanup race conditions related to unsynced access to shared memory

Rod Hynes 11 سال پیش
والد
کامیت
9308d1ec87
5فایلهای تغییر یافته به همراه174 افزوده شده و 153 حذف شده
  1. 1 0
      README.md
  2. 93 55
      psiphon/conn.go
  3. 51 59
      psiphon/runTunnel.go
  4. 1 1
      psiphon/socksProxy.go
  5. 28 38
      psiphon/tunnel.go

+ 1 - 0
README.md

@@ -14,6 +14,7 @@ Status
 This project is currently at the proof-of-concept stage. Current production Psiphon client code is available at our [main repository](https://bitbucket.org/psiphon/psiphon-circumvention-system).
 This project is currently at the proof-of-concept stage. Current production Psiphon client code is available at our [main repository](https://bitbucket.org/psiphon/psiphon-circumvention-system).
 
 
 ### TODO
 ### TODO
+* psiphon.Conn for Windows
 * more test cases
 * more test cases
 * integrate meek-client
 * integrate meek-client
 * add config options
 * add config options

+ 93 - 55
psiphon/conn.go

@@ -23,6 +23,7 @@ import (
 	"errors"
 	"errors"
 	"net"
 	"net"
 	"os"
 	"os"
+	"sync"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
 )
 )
@@ -36,19 +37,26 @@ import (
 //   routing compatibility, for example).
 //   routing compatibility, for example).
 type Conn struct {
 type Conn struct {
 	net.Conn
 	net.Conn
-	socketFd            int
-	needCloseSocketFd   bool
-	isDisconnected      bool
-	disconnectionSignal chan bool
-	readTimeout         time.Duration
-	writeTimeout        time.Duration
+	mutex        sync.Mutex
+	socketFd     int
+	isClosed     bool
+	closedSignal chan bool
+	readTimeout  time.Duration
+	writeTimeout time.Duration
 }
 }
 
 
-// NewConn creates a new, configured Conn. Unlike standard Dial
-// functions, this does not return a connected net.Conn. Call the Connect function
-// to complete the connection establishment. To implement device binding and
-// interruptible connecting, the lower-level syscall APIs are used.
-func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn, error) {
+// NewConn creates a new, connected Conn. The connection can be interrupted
+// using pendingConns.interrupt(): the new Conn is added to pendingConns
+// before the socket connect beings. The caller is responsible for removing the
+// returned Conn from pendingConns.
+// To implement device binding and interruptible connecting, the lower-level
+// syscall APIs are used. The sequence of syscalls in this implementation are
+// taken from: https://code.google.com/p/go/issues/detail?id=6966
+func Dial(
+	ipAddress string, port int,
+	readTimeout, writeTimeout time.Duration,
+	pendingConns *PendingConns) (conn *Conn, err error) {
+
 	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
 	socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -58,7 +66,7 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
 		syscall.Close(socketFd)
 		syscall.Close(socketFd)
 		return nil, err
 		return nil, err
 	}
 	}
-	if deviceName != "" {
+	/*
 		// TODO: requires root, which we won't have on Android in VpnService mode
 		// TODO: requires root, which we won't have on Android in VpnService mode
 		//       an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
 		//       an alternative may be to use http://golang.org/pkg/syscall/#UnixRights to
 		//       send the fd to the main Android process which receives the fd with
 		//       send the fd to the main Android process which receives the fd with
@@ -69,74 +77,74 @@ func NewConn(readTimeout, writeTimeout time.Duration, deviceName string) (*Conn,
 		//       https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
 		//       https://code.google.com/p/ics-openvpn/source/browse/main/src/main/java/de/blinkt/openvpn/core/OpenVpnManagementThread.java#164
 		const SO_BINDTODEVICE = 0x19 // only defined for Linux
 		const SO_BINDTODEVICE = 0x19 // only defined for Linux
 		err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
 		err = syscall.SetsockoptString(socketFd, syscall.SOL_SOCKET, SO_BINDTODEVICE, deviceName)
-		return nil, err
-	}
-	return &Conn{
-		socketFd:          socketFd,
-		needCloseSocketFd: true,
-		readTimeout:       readTimeout,
-		writeTimeout:      writeTimeout}, nil
-}
-
-// Connect establishes a connection to the specified host. The sequence of
-// syscalls in this implementation are taken from: https://code.google.com/p/go/issues/detail?id=6966
-func (conn *Conn) Connect(ipAddress string, port int) (err error) {
+	*/
+	conn = &Conn{
+		socketFd:     socketFd,
+		readTimeout:  readTimeout,
+		writeTimeout: writeTimeout}
+	pendingConns.Add(conn)
 	// TODO: domain name resolution (for meek)
 	// TODO: domain name resolution (for meek)
 	var addr [4]byte
 	var addr [4]byte
 	copy(addr[:], net.ParseIP(ipAddress).To4())
 	copy(addr[:], net.ParseIP(ipAddress).To4())
 	sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
 	sockAddr := syscall.SockaddrInet4{Addr: addr, Port: port}
 	err = syscall.Connect(conn.socketFd, &sockAddr)
 	err = syscall.Connect(conn.socketFd, &sockAddr)
 	if err != nil {
 	if err != nil {
-		return err
+		return nil, err
 	}
 	}
 	file := os.NewFile(uintptr(conn.socketFd), "")
 	file := os.NewFile(uintptr(conn.socketFd), "")
 	defer file.Close()
 	defer file.Close()
-	fileConn, err := net.FileConn(file)
+	conn.Conn, err = net.FileConn(file)
 	if err != nil {
 	if err != nil {
-		return err
+		return nil, err
 	}
 	}
-	conn.Conn = fileConn
-	conn.needCloseSocketFd = false
-	return nil
+	return conn, nil
 }
 }
 
 
-// SetDisconnectionSignal sets the channel which will be signaled
-// when the connection terminates. This function returns an error
-// if the connection is already disconnected (and would never send
+// SetClosedSignal sets the channel which will be signaled
+// when the connection is closed. This function returns an error
+// if the connection is already closed (and would never send
 // the signal).
 // the signal).
-func (conn *Conn) SetDisconnectionSignal(disconnectionSignal chan bool) (err error) {
-	if conn.isDisconnected {
-		return errors.New("connection is already disconnected")
+func (conn *Conn) SetClosedSignal(closedSignal chan bool) (err error) {
+	// TEMP **** needs comments
+	conn.mutex.Lock()
+	defer conn.mutex.Unlock()
+	if conn.isClosed {
+		return errors.New("connection is already closed")
 	}
 	}
-	conn.disconnectionSignal = disconnectionSignal
+	conn.closedSignal = closedSignal
 	return nil
 	return nil
 }
 }
 
 
-// Close terminates down an established (net.Conn) or establishing (socketFd) connection.
+// Close terminates a connected (net.Conn) or connecting (socketFd) Conn.
+// A mutex syncs access to conn struct, allowing Close() to be called
+// from a goroutine that wants to interrupt the primary goroutine using
+// the connection.
 func (conn *Conn) Close() (err error) {
 func (conn *Conn) Close() (err error) {
-	if conn.needCloseSocketFd {
-		err = syscall.Close(conn.socketFd)
-		conn.needCloseSocketFd = false
-	}
-	if conn.Conn != nil {
-		err = conn.Conn.Close()
+	var closedSignal chan bool
+	conn.mutex.Lock()
+	if !conn.isClosed {
+		if conn.Conn == nil {
+			err = syscall.Close(conn.socketFd)
+		} else {
+			err = conn.Conn.Close()
+		}
+		closedSignal = conn.closedSignal
+		conn.isClosed = true
 	}
 	}
-	if conn.disconnectionSignal != nil {
+	conn.mutex.Unlock()
+	if closedSignal != nil {
 		select {
 		select {
-		case conn.disconnectionSignal <- true:
+		case closedSignal <- true:
 		default:
 		default:
 		}
 		}
 	}
 	}
-	conn.isDisconnected = true
 	return err
 	return err
 }
 }
 
 
 // Read wraps standard Read to add an idle timeout. The connection
 // Read wraps standard Read to add an idle timeout. The connection
-// is explicitly terminated on timeout.
+// is explicitly closed on timeout.
 func (conn *Conn) Read(buffer []byte) (n int, err error) {
 func (conn *Conn) Read(buffer []byte) (n int, err error) {
-	if conn.Conn == nil {
-		return 0, errors.New("not connected")
-	}
+	// Note: no mutex on the conn.readTimeout access
 	if conn.readTimeout != 0 {
 	if conn.readTimeout != 0 {
 		err = conn.Conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
 		err = conn.Conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
 		if err != nil {
 		if err != nil {
@@ -151,11 +159,9 @@ func (conn *Conn) Read(buffer []byte) (n int, err error) {
 }
 }
 
 
 // Write wraps standard Write to add an idle timeout The connection
 // Write wraps standard Write to add an idle timeout The connection
-// is explicitly terminated on timeout.
+// is explicitly closed on timeout.
 func (conn *Conn) Write(buffer []byte) (n int, err error) {
 func (conn *Conn) Write(buffer []byte) (n int, err error) {
-	if conn.Conn == nil {
-		return 0, errors.New("not connected")
-	}
+	// Note: no mutex on the conn.writeTimeout access
 	if conn.writeTimeout != 0 {
 	if conn.writeTimeout != 0 {
 		err = conn.Conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
 		err = conn.Conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
 		if err != nil {
 		if err != nil {
@@ -168,3 +174,35 @@ func (conn *Conn) Write(buffer []byte) (n int, err error) {
 	}
 	}
 	return
 	return
 }
 }
+
+// PendingConns is a synchronized list of Conns that's used to coordinate
+// interrupting a set of goroutines establishing connections.
+type PendingConns struct {
+	mutex sync.Mutex
+	conns []*Conn
+}
+
+func (pendingConns *PendingConns) Add(conn *Conn) {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	pendingConns.conns = append(pendingConns.conns, conn)
+}
+
+func (pendingConns *PendingConns) Remove(conn *Conn) {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	for index, pendingConn := range pendingConns.conns {
+		if conn == pendingConn {
+			pendingConns.conns = append(pendingConns.conns[:index], pendingConns.conns[index+1:]...)
+			break
+		}
+	}
+}
+
+func (pendingConns *PendingConns) Interrupt() {
+	pendingConns.mutex.Lock()
+	defer pendingConns.mutex.Unlock()
+	for _, conn := range pendingConns.conns {
+		conn.Close()
+	}
+}

+ 51 - 59
psiphon/runTunnel.go

@@ -37,40 +37,28 @@ import (
 // if there's not already an established tunnel. This function is to be used in a pool
 // if there's not already an established tunnel. This function is to be used in a pool
 // of goroutines.
 // of goroutines.
 func establishTunnelWorker(
 func establishTunnelWorker(
-	waitGroup *sync.WaitGroup, candidateQueue chan *Tunnel,
-	broadcastStopWorkers chan bool, firstEstablishedTunnel chan *Tunnel) {
+	waitGroup *sync.WaitGroup,
+	candidateServerEntries chan *ServerEntry,
+	broadcastStopWorkers chan bool,
+	pendingConns *PendingConns,
+	establishedTunnels chan *Tunnel) {
+
 	defer waitGroup.Done()
 	defer waitGroup.Done()
-	for tunnel := range candidateQueue {
+	for serverEntry := range candidateServerEntries {
 		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
 		// Note: don't receive from candidateQueue and broadcastStopWorkers in the same
 		// select, since we want to prioritize receiving the stop signal
 		// select, since we want to prioritize receiving the stop signal
 		if IsSignalled(broadcastStopWorkers) {
 		if IsSignalled(broadcastStopWorkers) {
 			return
 			return
 		}
 		}
-		log.Printf("connecting to %s", tunnel.serverEntry.IpAddress)
-		err := EstablishTunnel(tunnel)
+		log.Printf("connecting to %s", serverEntry.IpAddress)
+		tunnel, err := EstablishTunnel(serverEntry, pendingConns)
 		if err != nil {
 		if err != nil {
-			if tunnel.isClosed {
-				log.Printf("cancelled connection to %s", tunnel.serverEntry.IpAddress)
-			} else {
-				log.Printf("failed to connect to %s: %s", tunnel.serverEntry.IpAddress, err)
-			}
+			// TODO: distingush case where conn is interrupted?
+			log.Printf("failed to connect to %s: %s", serverEntry.IpAddress, err)
 		} else {
 		} else {
-			// Need to re-check broadcastStopWorkers signal before sending
-			// in case firstEstablishedTunnel has been closed
-			// TODO: race condition? may panic if so
-			if !IsSignalled(broadcastStopWorkers) {
-				select {
-				case firstEstablishedTunnel <- tunnel:
-					log.Printf("selected connection to %s using %s",
-						tunnel.serverEntry.IpAddress, tunnel.protocol)
-					// Leave tunnel open
-					return
-				default:
-				}
-			}
-			log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
+			log.Printf("successfully connected to %s", serverEntry.IpAddress)
+			establishedTunnels <- tunnel
 		}
 		}
-		tunnel.Close()
 	}
 	}
 }
 }
 
 
@@ -89,53 +77,57 @@ func runTunnel(config *Config) error {
 		return fmt.Errorf("failed to fetch remote server list: %s", err)
 		return fmt.Errorf("failed to fetch remote server list: %s", err)
 	}
 	}
 	log.Printf("establishing tunnel")
 	log.Printf("establishing tunnel")
-	candidateList := make([]*Tunnel, 0)
-	for _, serverEntry := range serverList {
-		candidateList = append(candidateList, &Tunnel{serverEntry: serverEntry})
-	}
 	waitGroup := new(sync.WaitGroup)
 	waitGroup := new(sync.WaitGroup)
-	candidateQueue := make(chan *Tunnel, len(candidateList))
-	firstEstablishedTunnel := make(chan *Tunnel, 1)
+	candidateServerEntries := make(chan *ServerEntry)
+	pendingConns := new(PendingConns)
+	establishedTunnels := make(chan *Tunnel, len(serverList))
 	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
 	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
 	broadcastStopWorkers := make(chan bool)
 	broadcastStopWorkers := make(chan bool)
 	for i := 0; i < CONNECTION_WORKER_POOL_SIZE; i++ {
 	for i := 0; i < CONNECTION_WORKER_POOL_SIZE; i++ {
 		waitGroup.Add(1)
 		waitGroup.Add(1)
-		go establishTunnelWorker(waitGroup, candidateQueue, broadcastStopWorkers, firstEstablishedTunnel)
-	}
-	for _, tunnel := range candidateList {
-		candidateQueue <- tunnel
+		go establishTunnelWorker(
+			waitGroup, candidateServerEntries, broadcastStopWorkers,
+			pendingConns, establishedTunnels)
 	}
 	}
-	close(candidateQueue)
-	var establishedTunnel *Tunnel
-	select {
-	case establishedTunnel = <-firstEstablishedTunnel:
-		defer establishedTunnel.Close()
-		close(firstEstablishedTunnel)
-	case <-timeout:
-		return errors.New("timeout establishing tunnel")
+	var selectedTunnel *Tunnel
+	for _, serverEntry := range serverList {
+		select {
+		case candidateServerEntries <- serverEntry:
+		case selectedTunnel = <-establishedTunnels:
+			defer selectedTunnel.Close()
+			log.Printf("selected connection to %s", selectedTunnel.serverEntry.IpAddress)
+		case <-timeout:
+			return errors.New("timeout establishing tunnel")
+		}
+		if selectedTunnel != nil {
+			break
+		}
 	}
 	}
-	log.Printf("stopping workers")
+	log.Printf("tunnel established")
+	close(candidateServerEntries)
 	close(broadcastStopWorkers)
 	close(broadcastStopWorkers)
-	for _, candidate := range candidateList {
-		if candidate != establishedTunnel {
-			// Interrupt any partial connections in progress, so that
-			// the worker will terminate immediately
-			candidate.Close()
-		}
+	// Interrupt any partial connections in progress, so that
+	// the worker will terminate immediately
+	pendingConns.Interrupt()
+	waitGroup.Wait()
+	// Drain any excess tunnels
+	close(establishedTunnels)
+	for tunnel := range establishedTunnels {
+		log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
+		tunnel.Close()
 	}
 	}
+	// Don't hold references to candidates while running tunnel
+	candidateServerEntries = nil
+	pendingConns = nil
 	// TODO: can start SOCKS before synchronizing work group
 	// TODO: can start SOCKS before synchronizing work group
-	waitGroup.Wait()
-	if establishedTunnel != nil {
-		// Don't hold references to candidates while running tunnel
-		candidateList = nil
-		candidateQueue = nil
+	if selectedTunnel != nil {
 		stopTunnelSignal := make(chan bool)
 		stopTunnelSignal := make(chan bool)
-		err = establishedTunnel.conn.SetDisconnectionSignal(stopTunnelSignal)
+		err = selectedTunnel.conn.SetClosedSignal(stopTunnelSignal)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("failed to set disconnection signal: %s", err)
+			return fmt.Errorf("failed to set closed signal: %s", err)
 		}
 		}
 		log.Printf("starting local SOCKS proxy")
 		log.Printf("starting local SOCKS proxy")
-		socksServer := NewSocksServer(establishedTunnel, stopTunnelSignal)
+		socksServer := NewSocksServer(selectedTunnel, stopTunnelSignal)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
 			return fmt.Errorf("error initializing local SOCKS proxy: %s", err)
 		}
 		}
@@ -144,7 +136,7 @@ func runTunnel(config *Config) error {
 			return fmt.Errorf("error running local SOCKS proxy: %s", err)
 			return fmt.Errorf("error running local SOCKS proxy: %s", err)
 		}
 		}
 		defer socksServer.Close()
 		defer socksServer.Close()
-		log.Printf("monitoring for failure")
+		log.Printf("monitoring tunnel")
 		<-stopTunnelSignal
 		<-stopTunnelSignal
 	}
 	}
 	return nil
 	return nil

+ 1 - 1
psiphon/socksProxy.go

@@ -83,7 +83,7 @@ func socksConnectionHandler(tunnel *Tunnel, localSocksConn *pt.SocksConn) (err e
 	waitGroup.Add(1)
 	waitGroup.Add(1)
 	go func() {
 	go func() {
 		defer waitGroup.Done()
 		defer waitGroup.Done()
-		_, err = io.Copy(localSocksConn, remoteSshForward)
+		_, err := io.Copy(localSocksConn, remoteSshForward)
 		if err != nil {
 		if err != nil {
 			log.Printf("ssh port forward downstream error: %s", err)
 			log.Printf("ssh port forward downstream error: %s", err)
 		}
 		}

+ 28 - 38
psiphon/tunnel.go

@@ -43,19 +43,13 @@ type Tunnel struct {
 	protocol    string
 	protocol    string
 	conn        *Conn
 	conn        *Conn
 	sshClient   *ssh.Client
 	sshClient   *ssh.Client
-	isClosed    bool
 }
 }
 
 
-// Close terminates the tunnel SSH client session and the
-// underlying network transport.
+// Close terminates the tunnel.
 func (tunnel *Tunnel) Close() {
 func (tunnel *Tunnel) Close() {
-	if tunnel.sshClient != nil {
-		tunnel.sshClient.Close()
-	}
 	if tunnel.conn != nil {
 	if tunnel.conn != nil {
 		tunnel.conn.Close()
 		tunnel.conn.Close()
 	}
 	}
-	tunnel.isClosed = true
 }
 }
 
 
 // EstablishTunnel first makes a network transport connection to the
 // EstablishTunnel first makes a network transport connection to the
@@ -65,44 +59,42 @@ func (tunnel *Tunnel) Close() {
 // Depending on the server's capabilities, the connection may use
 // Depending on the server's capabilities, the connection may use
 // plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
 // plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
 // HTTP (meek protocol).
 // HTTP (meek protocol).
-func EstablishTunnel(tunnel *Tunnel) (err error) {
-	if tunnel.conn != nil {
-		return errors.New("tunnel already connected")
-	}
-	if tunnel.sshClient != nil {
-		return errors.New("ssh client already established")
-	}
+func EstablishTunnel(serverEntry *ServerEntry, pendingConns *PendingConns) (tunnel *Tunnel, err error) {
 	// First connect the transport
 	// First connect the transport
 	// TODO: meek
 	// TODO: meek
-	sshCapable := Contains(tunnel.serverEntry.Capabilities, PROTOCOL_SSH)
-	obfuscatedSshCapable := Contains(tunnel.serverEntry.Capabilities, PROTOCOL_OBFUSCATED_SSH)
+	sshCapable := Contains(serverEntry.Capabilities, PROTOCOL_SSH)
+	obfuscatedSshCapable := Contains(serverEntry.Capabilities, PROTOCOL_OBFUSCATED_SSH)
 	if !sshCapable && !obfuscatedSshCapable {
 	if !sshCapable && !obfuscatedSshCapable {
-		return fmt.Errorf("server does not have sufficient capabilities")
+		return nil, fmt.Errorf("server does not have sufficient capabilities")
+	}
+	selectedProtocol := PROTOCOL_SSH
+	port := serverEntry.SshPort
+	if obfuscatedSshCapable {
+		selectedProtocol = PROTOCOL_OBFUSCATED_SSH
+		port = serverEntry.SshObfuscatedPort
 	}
 	}
-	tunnel.protocol = PROTOCOL_SSH
-	port := tunnel.serverEntry.SshPort
-	conn, err := NewConn(0, CONNECTION_CANDIDATE_TIMEOUT, "")
+	conn, err := Dial(serverEntry.IpAddress, port, 0, CONNECTION_CANDIDATE_TIMEOUT, pendingConns)
 	if err != nil {
 	if err != nil {
-		return err
+		return nil, err
 	}
 	}
+	defer func() {
+		pendingConns.Remove(conn)
+		if err != nil {
+			conn.Close()
+		}
+	}()
 	var netConn net.Conn
 	var netConn net.Conn
 	netConn = conn
 	netConn = conn
 	if obfuscatedSshCapable {
 	if obfuscatedSshCapable {
-		tunnel.protocol = PROTOCOL_OBFUSCATED_SSH
-		port = tunnel.serverEntry.SshObfuscatedPort
-		netConn, err = NewObfuscatedSshConn(conn, tunnel.serverEntry.SshObfuscatedKey)
+		netConn, err = NewObfuscatedSshConn(conn, serverEntry.SshObfuscatedKey)
 		if err != nil {
 		if err != nil {
-			return err
+			return nil, err
 		}
 		}
 	}
 	}
-	err = conn.Connect(tunnel.serverEntry.IpAddress, port)
-	if err != nil {
-		return err
-	}
 	// Now establish the SSH session
 	// Now establish the SSH session
-	expectedPublicKey, err := base64.StdEncoding.DecodeString(tunnel.serverEntry.SshHostKey)
+	expectedPublicKey, err := base64.StdEncoding.DecodeString(serverEntry.SshHostKey)
 	if err != nil {
 	if err != nil {
-		return err
+		return nil, err
 	}
 	}
 	sshCertChecker := &ssh.CertChecker{
 	sshCertChecker := &ssh.CertChecker{
 		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
 		HostKeyFallback: func(addr string, remote net.Addr, publicKey ssh.PublicKey) error {
@@ -113,20 +105,18 @@ func EstablishTunnel(tunnel *Tunnel) (err error) {
 		},
 		},
 	}
 	}
 	sshClientConfig := &ssh.ClientConfig{
 	sshClientConfig := &ssh.ClientConfig{
-		User: tunnel.serverEntry.SshUsername,
+		User: serverEntry.SshUsername,
 		Auth: []ssh.AuthMethod{
 		Auth: []ssh.AuthMethod{
-			ssh.Password(tunnel.serverEntry.SshPassword),
+			ssh.Password(serverEntry.SshPassword),
 		},
 		},
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 		HostKeyCallback: sshCertChecker.CheckHostKey,
 	}
 	}
 	// The folowing is adapted from ssh.Dial(), here using a custom conn
 	// The folowing is adapted from ssh.Dial(), here using a custom conn
-	sshAddress := strings.Join([]string{tunnel.serverEntry.IpAddress, ":", strconv.Itoa(tunnel.serverEntry.SshPort)}, "")
+	sshAddress := strings.Join([]string{serverEntry.IpAddress, ":", strconv.Itoa(serverEntry.SshPort)}, "")
 	sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, sshAddress, sshClientConfig)
 	sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, sshAddress, sshClientConfig)
 	if err != nil {
 	if err != nil {
-		return err
+		return nil, err
 	}
 	}
 	sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
 	sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
-	tunnel.conn = conn
-	tunnel.sshClient = sshClient
-	return nil
+	return &Tunnel{serverEntry, selectedProtocol, conn, sshClient}, nil
 }
 }