Эх сурвалжийг харах

Merge pull request #106 from rod-hynes/master

More changes related to Psiphon Android conversion to tunnel-core
Rod Hynes 10 жил өмнө
parent
commit
3f793ce6f4

+ 27 - 6
SampleApps/Psibot/app/src/main/java/ca/psiphon/PsiphonTunnel.java

@@ -67,9 +67,11 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         public void onConnecting();
         public void onConnecting();
         public void onConnected();
         public void onConnected();
         public void onHomepage(String url);
         public void onHomepage(String url);
+        public void onClientRegion(String region);
         public void onClientUpgradeDownloaded(String filename);
         public void onClientUpgradeDownloaded(String filename);
         public void onSplitTunnelRegion(String region);
         public void onSplitTunnelRegion(String region);
         public void onUntunneledAddress(String address);
         public void onUntunneledAddress(String address);
+        public void onBytesTransferred(long sent, long received);
     }
     }
 
 
     private final HostService mHostService;
     private final HostService mHostService;
@@ -268,12 +270,12 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         stopPsiphon();
         stopPsiphon();
         mHostService.onDiagnosticMessage("starting Psiphon library");
         mHostService.onDiagnosticMessage("starting Psiphon library");
         try {
         try {
-            boolean useDeviceBinder = (mTunFd != null);
+            boolean isVpnMode = (mTunFd != null);
             Psi.Start(
             Psi.Start(
-                loadPsiphonConfig(mHostService.getContext()),
+                loadPsiphonConfig(mHostService.getContext(), isVpnMode),
                 embeddedServerEntries,
                 embeddedServerEntries,
                 this,
                 this,
-                useDeviceBinder);
+                isVpnMode);
         } catch (java.lang.Exception e) {
         } catch (java.lang.Exception e) {
             throw new Exception("failed to start Psiphon library", e);
             throw new Exception("failed to start Psiphon library", e);
         }
         }
@@ -286,7 +288,7 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         mHostService.onDiagnosticMessage("Psiphon library stopped");
         mHostService.onDiagnosticMessage("Psiphon library stopped");
     }
     }
 
 
-    private String loadPsiphonConfig(Context context)
+    private String loadPsiphonConfig(Context context, boolean isVpnMode)
             throws IOException, JSONException {
             throws IOException, JSONException {
 
 
         // Load settings from the raw resource JSON config file and
         // Load settings from the raw resource JSON config file and
@@ -305,6 +307,18 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
         // Continue to run indefinitely until connected
         // Continue to run indefinitely until connected
         json.put("EstablishTunnelTimeoutSeconds", 0);
         json.put("EstablishTunnelTimeoutSeconds", 0);
 
 
+        // Enable tunnel auto-reconnect after a threshold number of port
+        // forward failures. By default, this mechanism is disabled in
+        // tunnel-core due to the chance of false positives due to
+        // bad user input. Since VpnService mode resolves domain names
+        // differently (udpgw), invalid domain name user input won't result
+        // in SSH port forward failures.
+        if (isVpnMode) {
+            json.put("PortForwardFailureThreshold", 10);
+        }
+
+        json.put("EmitBytesTransferred", true);
+
         if (mLocalSocksProxyPort != 0) {
         if (mLocalSocksProxyPort != 0) {
             // When mLocalSocksProxyPort is set, tun2socks is already configured
             // When mLocalSocksProxyPort is set, tun2socks is already configured
             // to use that port value. So we force use of the same port.
             // to use that port value. So we force use of the same port.
@@ -368,11 +382,18 @@ public class PsiphonTunnel extends Psi.PsiphonProvider.Stub {
             } else if (noticeType.equals("Homepage")) {
             } else if (noticeType.equals("Homepage")) {
                 mHostService.onHomepage(notice.getJSONObject("data").getString("url"));
                 mHostService.onHomepage(notice.getJSONObject("data").getString("url"));
 
 
+            } else if (noticeType.equals("ClientRegion")) {
+                mHostService.onClientRegion(notice.getJSONObject("data").getString("region"));
+
             } else if (noticeType.equals("SplitTunnelRegion")) {
             } else if (noticeType.equals("SplitTunnelRegion")) {
-                mHostService.onHomepage(notice.getJSONObject("data").getString("region"));
+                mHostService.onSplitTunnelRegion(notice.getJSONObject("data").getString("region"));
 
 
             } else if (noticeType.equals("UntunneledAddress")) {
             } else if (noticeType.equals("UntunneledAddress")) {
-                mHostService.onHomepage(notice.getJSONObject("data").getString("address"));
+                mHostService.onUntunneledAddress(notice.getJSONObject("data").getString("address"));
+
+            } else if (noticeType.equals("BytesTransferred")) {
+                JSONObject data = notice.getJSONObject("data");
+                mHostService.onBytesTransferred(data.getLong("sent"), data.getLong("received"));
             }
             }
 
 
             if (diagnostic) {
             if (diagnostic) {

+ 9 - 0
SampleApps/Psibot/app/src/main/java/ca/psiphon/psibot/Service.java

@@ -224,6 +224,15 @@ public class Service extends VpnService
         Log.addEntry("untunneled address: " + address);
         Log.addEntry("untunneled address: " + address);
     }
     }
 
 
+    @Override
+    public void onBytesTransferred(long sent, long received) {
+    }
+
+    @Override
+    public void onClientRegion(String region) {
+        Log.addEntry("client region: " + region);
+    }
+
     private static String readInputStreamToString(InputStream inputStream) throws IOException {
     private static String readInputStreamToString(InputStream inputStream) throws IOException {
         return new String(readInputStreamToBytes(inputStream), "UTF-8");
         return new String(readInputStreamToBytes(inputStream), "UTF-8");
     }
     }

+ 3 - 0
psiphon/config.go

@@ -59,6 +59,8 @@ const (
 	FETCH_ROUTES_TIMEOUT                         = 1 * time.Minute
 	FETCH_ROUTES_TIMEOUT                         = 1 * time.Minute
 	DOWNLOAD_UPGRADE_TIMEOUT                     = 15 * time.Minute
 	DOWNLOAD_UPGRADE_TIMEOUT                     = 15 * time.Minute
 	DOWNLOAD_UPGRADE_RETRY_PAUSE_PERIOD          = 5 * time.Second
 	DOWNLOAD_UPGRADE_RETRY_PAUSE_PERIOD          = 5 * time.Second
+	IMPAIRED_PROTOCOL_CLASSIFICATION_DURATION    = 2 * time.Minute
+	IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD   = 3
 )
 )
 
 
 // To distinguish omitted timeout params from explicit 0 value timeout
 // To distinguish omitted timeout params from explicit 0 value timeout
@@ -96,6 +98,7 @@ type Config struct {
 	SplitTunnelDnsServer                string
 	SplitTunnelDnsServer                string
 	UpgradeDownloadUrl                  string
 	UpgradeDownloadUrl                  string
 	UpgradeDownloadFilename             string
 	UpgradeDownloadFilename             string
+	EmitBytesTransferred                bool
 }
 }
 
 
 // LoadConfig parses and validates a JSON format Psiphon config JSON
 // LoadConfig parses and validates a JSON format Psiphon config JSON

+ 103 - 33
psiphon/controller.go

@@ -34,28 +34,29 @@ import (
 // connect to; establishes and monitors tunnels; and runs local proxies which
 // connect to; establishes and monitors tunnels; and runs local proxies which
 // route traffic through the tunnels.
 // route traffic through the tunnels.
 type Controller struct {
 type Controller struct {
-	config                      *Config
-	sessionId                   string
-	componentFailureSignal      chan struct{}
-	shutdownBroadcast           chan struct{}
-	runWaitGroup                *sync.WaitGroup
-	establishedTunnels          chan *Tunnel
-	failedTunnels               chan *Tunnel
-	tunnelMutex                 sync.Mutex
-	establishedOnce             bool
-	tunnels                     []*Tunnel
-	nextTunnel                  int
-	startedConnectedReporter    bool
-	startedUpgradeDownloader    bool
-	isEstablishing              bool
-	establishWaitGroup          *sync.WaitGroup
-	stopEstablishingBroadcast   chan struct{}
-	candidateServerEntries      chan *ServerEntry
-	establishPendingConns       *Conns
-	untunneledPendingConns      *Conns
-	untunneledDialConfig        *DialConfig
-	splitTunnelClassifier       *SplitTunnelClassifier
-	signalFetchRemoteServerList chan struct{}
+	config                         *Config
+	sessionId                      string
+	componentFailureSignal         chan struct{}
+	shutdownBroadcast              chan struct{}
+	runWaitGroup                   *sync.WaitGroup
+	establishedTunnels             chan *Tunnel
+	failedTunnels                  chan *Tunnel
+	tunnelMutex                    sync.Mutex
+	establishedOnce                bool
+	tunnels                        []*Tunnel
+	nextTunnel                     int
+	startedConnectedReporter       bool
+	startedUpgradeDownloader       bool
+	isEstablishing                 bool
+	establishWaitGroup             *sync.WaitGroup
+	stopEstablishingBroadcast      chan struct{}
+	candidateServerEntries         chan *ServerEntry
+	establishPendingConns          *Conns
+	untunneledPendingConns         *Conns
+	untunneledDialConfig           *DialConfig
+	splitTunnelClassifier          *SplitTunnelClassifier
+	signalFetchRemoteServerList    chan struct{}
+	impairedProtocolClassification map[string]int
 }
 }
 
 
 // NewController initializes a new controller.
 // NewController initializes a new controller.
@@ -102,7 +103,8 @@ func NewController(config *Config) (controller *Controller, err error) {
 		untunneledDialConfig:     untunneledDialConfig,
 		untunneledDialConfig:     untunneledDialConfig,
 		// A buffer allows at least one signal to be sent even when the receiver is
 		// A buffer allows at least one signal to be sent even when the receiver is
 		// not listening. Senders should not block.
 		// not listening. Senders should not block.
-		signalFetchRemoteServerList: make(chan struct{}, 1),
+		signalFetchRemoteServerList:    make(chan struct{}, 1),
+		impairedProtocolClassification: make(map[string]int),
 	}
 	}
 
 
 	controller.splitTunnelClassifier = NewSplitTunnelClassifier(config, controller)
 	controller.splitTunnelClassifier = NewSplitTunnelClassifier(config, controller)
@@ -411,7 +413,7 @@ loop:
 			// establishPendingConns.Reset() which clears the closed flag in
 			// establishPendingConns.Reset() which clears the closed flag in
 			// establishPendingConns; this causes the pendingConns.Add() within
 			// establishPendingConns; this causes the pendingConns.Add() within
 			// interruptibleTCPDial to succeed instead of aborting, and the result
 			// interruptibleTCPDial to succeed instead of aborting, and the result
-			// is that it's possible for extablish goroutines to run all the way through
+			// is that it's possible for establish goroutines to run all the way through
 			// NewSession before being discarded... delaying shutdown.
 			// NewSession before being discarded... delaying shutdown.
 			select {
 			select {
 			case <-controller.shutdownBroadcast:
 			case <-controller.shutdownBroadcast:
@@ -419,6 +421,8 @@ loop:
 			default:
 			default:
 			}
 			}
 
 
+			controller.classifyImpairedProtocol(failedTunnel)
+
 			// Concurrency note: only this goroutine may call startEstablishing/stopEstablishing
 			// Concurrency note: only this goroutine may call startEstablishing/stopEstablishing
 			// and access isEstablishing.
 			// and access isEstablishing.
 			if !controller.isEstablishing {
 			if !controller.isEstablishing {
@@ -462,6 +466,50 @@ loop:
 	NoticeInfo("exiting run tunnels")
 	NoticeInfo("exiting run tunnels")
 }
 }
 
 
+// classifyImpairedProtocol tracks "impaired" protocol classifications for failed
+// tunnels. A protocol is classified as impaired if a tunnel using that protocol
+// fails, repeatedly, shortly after the start of the session. During tunnel
+// establishment, impaired protocols are briefly skipped.
+//
+// One purpose of this measure is to defend against an attack where the adversary,
+// for example, tags an OSSH TCP connection as an "unidentified" protocol; allows
+// it to connect; but then kills the underlying TCP connection after a short time.
+// Since OSSH has less latency than other protocols that may bypass an "unidentified"
+// filter, these other protocols might never be selected for use.
+//
+// Concurrency note: only the runTunnels() goroutine may call classifyImpairedProtocol
+func (controller *Controller) classifyImpairedProtocol(failedTunnel *Tunnel) {
+	if failedTunnel.sessionStartTime.Add(IMPAIRED_PROTOCOL_CLASSIFICATION_DURATION).After(time.Now()) {
+		controller.impairedProtocolClassification[failedTunnel.protocol] += 1
+	} else {
+		controller.impairedProtocolClassification[failedTunnel.protocol] = 0
+	}
+	if len(controller.getImpairedProtocols()) == len(SupportedTunnelProtocols) {
+		// Reset classification if all protocols are classified as impaired as
+		// the network situation (or attack) may not be protocol-specific.
+		// TODO: compare against count of distinct supported protocols for
+		// current known server entries.
+		controller.impairedProtocolClassification = make(map[string]int)
+	}
+}
+
+// getImpairedProtocols returns a list of protocols that have sufficient
+// classifications to be considered impaired protocols.
+//
+// Concurrency note: only the runTunnels() goroutine may call getImpairedProtocols
+func (controller *Controller) getImpairedProtocols() []string {
+	if len(controller.impairedProtocolClassification) > 0 {
+		NoticeInfo("impaired protocols: %+v", controller.impairedProtocolClassification)
+	}
+	impairedProtocols := make([]string, 0)
+	for protocol, count := range controller.impairedProtocolClassification {
+		if count >= IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD {
+			impairedProtocols = append(impairedProtocols, protocol)
+		}
+	}
+	return impairedProtocols
+}
+
 // SignalTunnelFailure implements the TunnelOwner interface. This function
 // SignalTunnelFailure implements the TunnelOwner interface. This function
 // is called by Tunnel.operateTunnel when the tunnel has detected that it
 // is called by Tunnel.operateTunnel when the tunnel has detected that it
 // has failed. The Controller will signal runTunnels to create a new
 // has failed. The Controller will signal runTunnels to create a new
@@ -676,7 +724,8 @@ func (controller *Controller) startEstablishing() {
 	}
 	}
 
 
 	controller.establishWaitGroup.Add(1)
 	controller.establishWaitGroup.Add(1)
-	go controller.establishCandidateGenerator()
+	go controller.establishCandidateGenerator(
+		controller.getImpairedProtocols())
 }
 }
 
 
 // stopEstablishing signals the establish goroutines to stop and waits
 // stopEstablishing signals the establish goroutines to stop and waits
@@ -704,7 +753,7 @@ func (controller *Controller) stopEstablishing() {
 // establishCandidateGenerator populates the candidate queue with server entries
 // establishCandidateGenerator populates the candidate queue with server entries
 // from the data store. Server entries are iterated in rank order, so that promoted
 // from the data store. Server entries are iterated in rank order, so that promoted
 // servers with higher rank are priority candidates.
 // servers with higher rank are priority candidates.
-func (controller *Controller) establishCandidateGenerator() {
+func (controller *Controller) establishCandidateGenerator(impairedProtocols []string) {
 	defer controller.establishWaitGroup.Done()
 	defer controller.establishWaitGroup.Done()
 	defer close(controller.candidateServerEntries)
 	defer close(controller.candidateServerEntries)
 
 
@@ -720,9 +769,16 @@ loop:
 	// Repeat until stopped
 	// Repeat until stopped
 	for {
 	for {
 
 
+		if !WaitForNetworkConnectivity(
+			controller.config.NetworkConnectivityChecker,
+			controller.stopEstablishingBroadcast,
+			controller.shutdownBroadcast) {
+			break loop
+		}
+
 		// Send each iterator server entry to the establish workers
 		// Send each iterator server entry to the establish workers
 		startTime := time.Now()
 		startTime := time.Now()
-		for {
+		for i := 0; ; i++ {
 			serverEntry, err := iterator.Next()
 			serverEntry, err := iterator.Next()
 			if err != nil {
 			if err != nil {
 				NoticeAlert("failed to get next candidate: %s", err)
 				NoticeAlert("failed to get next candidate: %s", err)
@@ -734,6 +790,26 @@ loop:
 				break
 				break
 			}
 			}
 
 
+			// Disable impaired protocols. This is only done for the
+			// first iteration of the ESTABLISH_TUNNEL_WORK_TIME_SECONDS
+			// loop since (a) one iteration should be sufficient to
+			// evade the attack; (b) there's a good chance of false
+			// positives (such as short session durations due to network
+			// hopping on a mobile device).
+			// Impaired protocols logic is not applied when
+			// config.TunnelProtocol is specified.
+			// The edited serverEntry is temporary copy which is not
+			// stored or reused.
+			if i == 0 && controller.config.TunnelProtocol == "" {
+				serverEntry.DisableImpairedProtocols(impairedProtocols)
+				if len(serverEntry.GetSupportedProtocols()) == 0 {
+					// Skip this server entry, as it has no supported
+					// protocols after disabling the impaired ones
+					// TODO: modify ServerEntryIterator to skip these?
+					continue
+				}
+			}
+
 			// TODO: here we could generate multiple candidates from the
 			// TODO: here we could generate multiple candidates from the
 			// server entry when there are many MeekFrontingAddresses.
 			// server entry when there are many MeekFrontingAddresses.
 
 
@@ -804,12 +880,6 @@ loop:
 			continue
 			continue
 		}
 		}
 
 
-		if !WaitForNetworkConnectivity(
-			controller.config.NetworkConnectivityChecker,
-			controller.stopEstablishingBroadcast) {
-			break loop
-		}
-
 		tunnel, err := EstablishTunnel(
 		tunnel, err := EstablishTunnel(
 			controller.config,
 			controller.config,
 			controller.sessionId,
 			controller.sessionId,

+ 19 - 8
psiphon/net.go

@@ -22,6 +22,7 @@ package psiphon
 import (
 import (
 	"io"
 	"io"
 	"net"
 	"net"
+	"reflect"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
@@ -161,11 +162,11 @@ func Relay(localConn, remoteConn net.Conn) {
 // WaitForNetworkConnectivity uses a NetworkConnectivityChecker to
 // WaitForNetworkConnectivity uses a NetworkConnectivityChecker to
 // periodically check for network connectivity. It returns true if
 // periodically check for network connectivity. It returns true if
 // no NetworkConnectivityChecker is provided (waiting is disabled)
 // no NetworkConnectivityChecker is provided (waiting is disabled)
-// or if NetworkConnectivityChecker.HasNetworkConnectivity() indicates
-// connectivity. It polls the checker once a second. If a stop is
-// broadcast, false is returned.
+// or when NetworkConnectivityChecker.HasNetworkConnectivity()
+// indicates connectivity. It waits and polls the checker once a second.
+// If any stop is broadcast, false is returned immediately.
 func WaitForNetworkConnectivity(
 func WaitForNetworkConnectivity(
-	connectivityChecker NetworkConnectivityChecker, stopBroadcast <-chan struct{}) bool {
+	connectivityChecker NetworkConnectivityChecker, stopBroadcasts ...<-chan struct{}) bool {
 	if connectivityChecker == nil || 1 == connectivityChecker.HasNetworkConnectivity() {
 	if connectivityChecker == nil || 1 == connectivityChecker.HasNetworkConnectivity() {
 		return true
 		return true
 	}
 	}
@@ -175,10 +176,20 @@ func WaitForNetworkConnectivity(
 		if 1 == connectivityChecker.HasNetworkConnectivity() {
 		if 1 == connectivityChecker.HasNetworkConnectivity() {
 			return true
 			return true
 		}
 		}
-		select {
-		case <-ticker.C:
-			// Check again
-		case <-stopBroadcast:
+
+		selectCases := make([]reflect.SelectCase, 1+len(stopBroadcasts))
+		selectCases[0] = reflect.SelectCase{
+			Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ticker.C)}
+		for i, stopBroadcast := range stopBroadcasts {
+			selectCases[i+1] = reflect.SelectCase{
+				Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stopBroadcast)}
+		}
+
+		chosen, _, ok := reflect.Select(selectCases)
+		if chosen == 0 && ok {
+			// Ticker case, so check again
+		} else {
+			// Stop case
 			return false
 			return false
 		}
 		}
 	}
 	}

+ 12 - 0
psiphon/notice.go

@@ -157,6 +157,12 @@ func NoticeHomepage(url string) {
 	outputNotice("Homepage", false, "url", url)
 	outputNotice("Homepage", false, "url", url)
 }
 }
 
 
+// NoticeClientRegion is the client's region, as determined by the server and
+// reported to the client in the handshake.
+func NoticeClientRegion(region string) {
+	outputNotice("ClientRegion", false, "region", region)
+}
+
 // NoticeTunnels is how many active tunnels are available. The client should use this to
 // NoticeTunnels is how many active tunnels are available. The client should use this to
 // determine connecting/unexpected disconnect state transitions. When count is 0, the core is
 // determine connecting/unexpected disconnect state transitions. When count is 0, the core is
 // disconnected; when count > 1, the core is connected.
 // disconnected; when count > 1, the core is connected.
@@ -191,6 +197,12 @@ func NoticeClientUpgradeDownloaded(filename string) {
 	outputNotice("ClientUpgradeDownloaded", false, "filename", filename)
 	outputNotice("ClientUpgradeDownloaded", false, "filename", filename)
 }
 }
 
 
+// NoticeBytesTransferred reports how many tunneled bytes have been
+// transferred since the last NoticeBytesTransferred.
+func NoticeBytesTransferred(sent, received int64) {
+	outputNotice("BytesTransferred", false, "sent", sent, "received", received)
+}
+
 type noticeObject struct {
 type noticeObject struct {
 	NoticeType string          `json:"noticeType"`
 	NoticeType string          `json:"noticeType"`
 	Data       json.RawMessage `json:"data"`
 	Data       json.RawMessage `json:"data"`

+ 3 - 1
psiphon/serverApi.go

@@ -26,12 +26,13 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"strconv"
 	"strconv"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 )
 )
 
 
 // Session is a utility struct which holds all of the data associated
 // Session is a utility struct which holds all of the data associated
@@ -216,6 +217,7 @@ func (session *Session) doHandshakeRequest() error {
 	}
 	}
 
 
 	session.clientRegion = handshakeConfig.ClientRegion
 	session.clientRegion = handshakeConfig.ClientRegion
+	NoticeClientRegion(session.clientRegion)
 
 
 	var decodedServerEntries []*ServerEntry
 	var decodedServerEntries []*ServerEntry
 
 

+ 54 - 0
psiphon/serverEntry.go

@@ -29,6 +29,20 @@ import (
 	"strings"
 	"strings"
 )
 )
 
 
+const (
+	TUNNEL_PROTOCOL_SSH            = "SSH"
+	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK = "UNFRONTED-MEEK-OSSH"
+	TUNNEL_PROTOCOL_FRONTED_MEEK   = "FRONTED-MEEK-OSSH"
+)
+
+var SupportedTunnelProtocols = []string{
+	TUNNEL_PROTOCOL_FRONTED_MEEK,
+	TUNNEL_PROTOCOL_UNFRONTED_MEEK,
+	TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+	TUNNEL_PROTOCOL_SSH,
+}
+
 // ServerEntry represents a Psiphon server. It contains information
 // ServerEntry represents a Psiphon server. It contains information
 // about how to estalish a tunnel connection to the server through
 // about how to estalish a tunnel connection to the server through
 // several protocols. ServerEntry are JSON records downloaded from
 // several protocols. ServerEntry are JSON records downloaded from
@@ -54,6 +68,46 @@ type ServerEntry struct {
 	MeekFrontingAddresses         []string `json:"meekFrontingAddresses"`
 	MeekFrontingAddresses         []string `json:"meekFrontingAddresses"`
 }
 }
 
 
+// SupportsProtocol returns true if and only if the ServerEntry has
+// the necessary capability to support the specified tunnel protocol.
+func (serverEntry *ServerEntry) SupportsProtocol(protocol string) bool {
+	requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
+	return Contains(serverEntry.Capabilities, requiredCapability)
+}
+
+// GetSupportedProtocols returns a list of tunnel protocols supported
+// by the ServerEntry's capabilities.
+func (serverEntry *ServerEntry) GetSupportedProtocols() []string {
+	supportedProtocols := make([]string, 0)
+	for _, protocol := range SupportedTunnelProtocols {
+		if serverEntry.SupportsProtocol(protocol) {
+			supportedProtocols = append(supportedProtocols, protocol)
+		}
+	}
+	return supportedProtocols
+}
+
+// DisableImpairedProtocols modifies the ServerEntry to disable
+// the specified protocols.
+// Note: this assumes that protocol capabilities are 1-to-1.
+func (serverEntry *ServerEntry) DisableImpairedProtocols(impairedProtocols []string) {
+	capabilities := make([]string, 0)
+	for _, capability := range serverEntry.Capabilities {
+		omit := false
+		for _, protocol := range impairedProtocols {
+			requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
+			if capability == requiredCapability {
+				omit = true
+				break
+			}
+		}
+		if !omit {
+			capabilities = append(capabilities, capability)
+		}
+	}
+	serverEntry.Capabilities = capabilities
+}
+
 // DecodeServerEntry extracts server entries from the encoding
 // DecodeServerEntry extracts server entries from the encoding
 // used by remote server lists and Psiphon server handshake requests.
 // used by remote server lists and Psiphon server handshake requests.
 func DecodeServerEntry(encodedServerEntry string) (serverEntry *ServerEntry, err error) {
 func DecodeServerEntry(encodedServerEntry string) (serverEntry *ServerEntry, err error) {

+ 27 - 1
psiphon/transferstats/collector.go

@@ -46,7 +46,9 @@ func newHostStats() *hostStats {
 
 
 // serverStats holds per-server stats.
 // serverStats holds per-server stats.
 type serverStats struct {
 type serverStats struct {
-	hostnameToStats map[string]*hostStats
+	hostnameToStats    map[string]*hostStats
+	totalBytesSent     int64
+	totalBytesReceived int64
 }
 }
 
 
 func newServerStats() *serverStats {
 func newServerStats() *serverStats {
@@ -94,6 +96,9 @@ func recordStat(stat *statsUpdate) {
 		storedServerStats.hostnameToStats[stat.hostname] = storedHostStats
 		storedServerStats.hostnameToStats[stat.hostname] = storedHostStats
 	}
 	}
 
 
+	storedServerStats.totalBytesSent += stat.numBytesSent
+	storedServerStats.totalBytesReceived += stat.numBytesReceived
+
 	storedHostStats.numBytesSent += stat.numBytesSent
 	storedHostStats.numBytesSent += stat.numBytesSent
 	storedHostStats.numBytesReceived += stat.numBytesReceived
 	storedHostStats.numBytesReceived += stat.numBytesReceived
 
 
@@ -123,6 +128,27 @@ func (ss serverStats) MarshalJSON() ([]byte, error) {
 	return json.Marshal(out)
 	return json.Marshal(out)
 }
 }
 
 
+// GetBytesTransferredForServer returns total bytes sent and received since
+// the last call to GetBytesTransferredForServer.
+func GetBytesTransferredForServer(serverID string) (sent, received int64) {
+	allStats.statsMutex.Lock()
+	defer allStats.statsMutex.Unlock()
+
+	stats := allStats.serverIDtoStats[serverID]
+
+	if stats == nil {
+		return
+	}
+
+	sent = stats.totalBytesSent
+	received = stats.totalBytesReceived
+
+	stats.totalBytesSent = 0
+	stats.totalBytesReceived = 0
+
+	return
+}
+
 // GetForServer returns the json-able stats package for the given server.
 // GetForServer returns the json-able stats package for the given server.
 // If there are no stats, nil will be returned.
 // If there are no stats, nil will be returned.
 func GetForServer(serverID string) (payload *serverStats) {
 func GetForServer(serverID string) (payload *serverStats) {

+ 23 - 38
psiphon/tunnel.go

@@ -27,7 +27,6 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
-	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
@@ -58,21 +57,6 @@ type TunnelOwner interface {
 	SignalTunnelFailure(tunnel *Tunnel)
 	SignalTunnelFailure(tunnel *Tunnel)
 }
 }
 
 
-const (
-	TUNNEL_PROTOCOL_SSH            = "SSH"
-	TUNNEL_PROTOCOL_OBFUSCATED_SSH = "OSSH"
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK = "UNFRONTED-MEEK-OSSH"
-	TUNNEL_PROTOCOL_FRONTED_MEEK   = "FRONTED-MEEK-OSSH"
-)
-
-// This is a list of supported tunnel protocols, in default preference order
-var SupportedTunnelProtocols = []string{
-	TUNNEL_PROTOCOL_FRONTED_MEEK,
-	TUNNEL_PROTOCOL_UNFRONTED_MEEK,
-	TUNNEL_PROTOCOL_OBFUSCATED_SSH,
-	TUNNEL_PROTOCOL_SSH,
-}
-
 // Tunnel is a connection to a Psiphon server. An established
 // Tunnel is a connection to a Psiphon server. An established
 // tunnel includes a network connection to the specified server
 // tunnel includes a network connection to the specified server
 // and an SSH session built on top of that transport.
 // and an SSH session built on top of that transport.
@@ -88,6 +72,7 @@ type Tunnel struct {
 	shutdownOperateBroadcast chan struct{}
 	shutdownOperateBroadcast chan struct{}
 	portForwardFailures      chan int
 	portForwardFailures      chan int
 	portForwardFailureTotal  int
 	portForwardFailureTotal  int
+	sessionStartTime         time.Time
 }
 }
 
 
 // EstablishTunnel first makes a network transport connection to the
 // EstablishTunnel first makes a network transport connection to the
@@ -98,8 +83,7 @@ type Tunnel struct {
 // plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
 // plain SSH over TCP, obfuscated SSH over TCP, or obfuscated SSH over
 // HTTP (meek protocol).
 // HTTP (meek protocol).
 // When requiredProtocol is not blank, that protocol is used. Otherwise,
 // When requiredProtocol is not blank, that protocol is used. Otherwise,
-// the first protocol in SupportedTunnelProtocols that's also in the
-// server capabilities is used.
+// the a random supported protocol is used.
 func EstablishTunnel(
 func EstablishTunnel(
 	config *Config,
 	config *Config,
 	sessionId string,
 	sessionId string,
@@ -155,6 +139,8 @@ func EstablishTunnel(
 		}
 		}
 	}
 	}
 
 
+	tunnel.sessionStartTime = time.Now()
+
 	// Now that network operations are complete, cancel interruptibility
 	// Now that network operations are complete, cancel interruptibility
 	pendingConns.Remove(conn)
 	pendingConns.Remove(conn)
 
 
@@ -306,8 +292,7 @@ func selectProtocol(config *Config, serverEntry *ServerEntry) (selectedProtocol
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// TODO: properly handle protocols (e.g. FRONTED-MEEK-OSSH) vs. capabilities (e.g., {FRONTED-MEEK, OSSH})
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
 	// for now, the code is simply assuming that MEEK capabilities imply OSSH capability.
 	if config.TunnelProtocol != "" {
 	if config.TunnelProtocol != "" {
-		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
-		if !Contains(serverEntry.Capabilities, requiredCapability) {
+		if !serverEntry.SupportsProtocol(config.TunnelProtocol) {
 			return "", ContextError(fmt.Errorf("server does not have required capability"))
 			return "", ContextError(fmt.Errorf("server does not have required capability"))
 		}
 		}
 		selectedProtocol = config.TunnelProtocol
 		selectedProtocol = config.TunnelProtocol
@@ -315,26 +300,10 @@ func selectProtocol(config *Config, serverEntry *ServerEntry) (selectedProtocol
 		// Pick at random from the supported protocols. This ensures that we'll eventually
 		// Pick at random from the supported protocols. This ensures that we'll eventually
 		// try all possible protocols. Depending on network configuration, it may be the
 		// try all possible protocols. Depending on network configuration, it may be the
 		// case that some protocol is only available through multi-capability servers,
 		// case that some protocol is only available through multi-capability servers,
-		// and a simplr ranked preference of protocols could lead to that protocol never
+		// and a simpler ranked preference of protocols could lead to that protocol never
 		// being selected.
 		// being selected.
 
 
-		// TODO: this is a good spot to apply protocol selection weightings. This would be
-		// to defend against an attack where the adversary, for example, classifies OSSH as
-		// an "unidentified" protocol; allows it to connect; but then kills the underlying
-		// TCP connection after a short time. Since OSSH has less latency than other protocols
-		// that may bypass an "unidentified" filter, other protocols which would be otherwise
-		// classified and not killed might never be selected for use.
-		// So one proposed defense is to add negative selection weights to the protocol
-		// associated with failed tunnels (controller.failedTunnels) with short session
-		// durations.
-
-		candidateProtocols := make([]string, 0)
-		for _, protocol := range SupportedTunnelProtocols {
-			requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
-			if Contains(serverEntry.Capabilities, requiredCapability) {
-				candidateProtocols = append(candidateProtocols, protocol)
-			}
-		}
+		candidateProtocols := serverEntry.GetSupportedProtocols()
 		if len(candidateProtocols) == 0 {
 		if len(candidateProtocols) == 0 {
 			return "", ContextError(fmt.Errorf("server does not have any supported capabilities"))
 			return "", ContextError(fmt.Errorf("server does not have any supported capabilities"))
 		}
 		}
@@ -560,6 +529,14 @@ func (tunnel *Tunnel) operateTunnel(config *Config, tunnelOwner TunnelOwner) {
 			TUNNEL_SSH_KEEP_ALIVE_PERIOD_MAX)
 			TUNNEL_SSH_KEEP_ALIVE_PERIOD_MAX)
 	}
 	}
 
 
+	// TODO: don't initialize if !config.EmitBytesTransferred
+	noticeBytesTransferredTicker := time.NewTicker(1 * time.Second)
+	if !config.EmitBytesTransferred {
+		noticeBytesTransferredTicker.Stop()
+	} else {
+		defer noticeBytesTransferredTicker.Stop()
+	}
+
 	statsTimer := time.NewTimer(nextStatusRequestPeriod())
 	statsTimer := time.NewTimer(nextStatusRequestPeriod())
 	defer statsTimer.Stop()
 	defer statsTimer.Stop()
 
 
@@ -569,6 +546,14 @@ func (tunnel *Tunnel) operateTunnel(config *Config, tunnelOwner TunnelOwner) {
 	var err error
 	var err error
 	for err == nil {
 	for err == nil {
 		select {
 		select {
+		case <-noticeBytesTransferredTicker.C:
+			sent, received := transferstats.GetBytesTransferredForServer(
+				tunnel.serverEntry.IpAddress)
+			// Only emit notice when tunnel is not idle.
+			if sent > 0 || received > 0 {
+				NoticeBytesTransferred(sent, received)
+			}
+
 		case <-statsTimer.C:
 		case <-statsTimer.C:
 			sendStats(tunnel)
 			sendStats(tunnel)
 			statsTimer.Reset(nextStatusRequestPeriod())
 			statsTimer.Reset(nextStatusRequestPeriod())