Ver Fonte

Continuation of refactor in 2cea38f: move TunneledConn to tunnel.go

Rod Hynes há 11 anos atrás
pai
commit
12fdcb4c3c
3 ficheiros alterados com 56 adições e 52 exclusões
  1. 2 3
      psiphon/TCPConn.go
  2. 4 47
      psiphon/controller.go
  3. 50 2
      psiphon/tunnel.go

+ 2 - 3
psiphon/TCPConn.go

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright (c) 2014, Psiphon Inc.
+ * Copyright (c) 2015, Psiphon Inc.
  * All rights reserved.
  * All rights reserved.
  *
  *
  * This program is free software: you can redistribute it and/or modify
  * This program is free software: you can redistribute it and/or modify
@@ -61,8 +61,7 @@ func DialTCP(addr string, config *DialConfig) (conn *TCPConn, err error) {
 	return conn, nil
 	return conn, nil
 }
 }
 
 
-// SetClosedSignal implements psiphon.Conn.SetClosedSignal. Returns true
-// if signal is successfully set, or false if the conn is already closed.
+// SetClosedSignal implements psiphon.Conn.SetClosedSignal.
 func (conn *TCPConn) SetClosedSignal(closedSignal chan struct{}) bool {
 func (conn *TCPConn) SetClosedSignal(closedSignal chan struct{}) bool {
 	conn.mutex.Lock()
 	conn.mutex.Lock()
 	defer conn.mutex.Unlock()
 	defer conn.mutex.Unlock()

+ 4 - 47
psiphon/controller.go

@@ -25,7 +25,6 @@ package psiphon
 
 
 import (
 import (
 	"errors"
 	"errors"
-	"io"
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -346,40 +345,6 @@ func (controller *Controller) isActiveTunnelServerEntry(serverEntry *ServerEntry
 	return false
 	return false
 }
 }
 
 
-// TunneledConn implements net.Conn and wraps a port foward connection.
-// It is used to hook into Read and Write to observe I/O errors and
-// report these errors back to the tunnel monitor as port forward failures.
-type TunneledConn struct {
-	net.Conn
-	tunnel *Tunnel
-}
-
-func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
-	n, err = conn.Conn.Read(buffer)
-	if err != nil && err != io.EOF {
-		// Report 1 new failure. Won't block; assumes the receiver
-		// has a sufficient buffer for the threshold number of reports.
-		// TODO: conditional on type of error or error message?
-		select {
-		case conn.tunnel.portForwardFailures <- 1:
-		default:
-		}
-	}
-	return
-}
-
-func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
-	n, err = conn.Conn.Write(buffer)
-	if err != nil && err != io.EOF {
-		// Same as TunneledConn.Read()
-		select {
-		case conn.tunnel.portForwardFailures <- 1:
-		default:
-		}
-	}
-	return
-}
-
 // Dial selects an active tunnel and establishes a port forward
 // Dial selects an active tunnel and establishes a port forward
 // connection through the selected tunnel. Failure to connect is considered
 // connection through the selected tunnel. Failure to connect is considered
 // a port foward failure, for the purpose of monitoring tunnel health.
 // a port foward failure, for the purpose of monitoring tunnel health.
@@ -388,24 +353,16 @@ func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error)
 	if tunnel == nil {
 	if tunnel == nil {
 		return nil, ContextError(errors.New("no active tunnels"))
 		return nil, ContextError(errors.New("no active tunnels"))
 	}
 	}
-	tunnelConn, err := tunnel.Dial(remoteAddr)
+
+	tunneledConn, err := tunnel.Dial(remoteAddr)
 	if err != nil {
 	if err != nil {
-		// TODO: conditional on type of error or error message?
-		select {
-		case tunnel.portForwardFailures <- 1:
-		default:
-		}
 		return nil, ContextError(err)
 		return nil, ContextError(err)
 	}
 	}
 
 
 	statsConn := NewStatsConn(
 	statsConn := NewStatsConn(
-		tunnelConn, tunnel.session.StatsServerID(), tunnel.session.StatsRegexps())
-
-	conn = &TunneledConn{
-		Conn:   statsConn,
-		tunnel: tunnel}
+		tunneledConn, tunnel.session.StatsServerID(), tunnel.session.StatsRegexps())
 
 
-	return
+	return statsConn, nil
 }
 }
 
 
 // startEstablishing creates a pool of worker goroutines which will
 // startEstablishing creates a pool of worker goroutines which will

+ 50 - 2
psiphon/tunnel.go

@@ -25,6 +25,7 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"net"
 	"net"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -181,8 +182,55 @@ func (tunnel *Tunnel) Dial(remoteAddr string) (conn net.Conn, err error) {
 	if isClosed {
 	if isClosed {
 		return nil, errors.New("tunnel is closed")
 		return nil, errors.New("tunnel is closed")
 	}
 	}
-	// TODO: should this track port forward failures as in Controller.DialWithTunnel?
-	return tunnel.sshClient.Dial("tcp", remoteAddr)
+
+	sshPortForwardConn, err := tunnel.sshClient.Dial("tcp", remoteAddr)
+	if err != nil {
+		// TODO: conditional on type of error or error message?
+		select {
+		case tunnel.portForwardFailures <- 1:
+		default:
+		}
+		return nil, ContextError(err)
+	}
+
+	return &TunneledConn{
+			Conn:   sshPortForwardConn,
+			tunnel: tunnel},
+		nil
+}
+
+// TunneledConn implements net.Conn and wraps a port foward connection.
+// It is used to hook into Read and Write to observe I/O errors and
+// report these errors back to the tunnel monitor as port forward failures.
+type TunneledConn struct {
+	net.Conn
+	tunnel *Tunnel
+}
+
+func (conn *TunneledConn) Read(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Read(buffer)
+	if err != nil && err != io.EOF {
+		// Report 1 new failure. Won't block; assumes the receiver
+		// has a sufficient buffer for the threshold number of reports.
+		// TODO: conditional on type of error or error message?
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
+}
+
+func (conn *TunneledConn) Write(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Write(buffer)
+	if err != nil && err != io.EOF {
+		// Same as TunneledConn.Read()
+		select {
+		case conn.tunnel.portForwardFailures <- 1:
+		default:
+		}
+	}
+	return
 }
 }
 
 
 // SignalComponentFailure notifies the tunnel that an associated component has failed.
 // SignalComponentFailure notifies the tunnel that an associated component has failed.