Browse Source

Merge pull request #547 from rod-hynes/master

Traffic rule enhancements
Rod Hynes 6 years ago
parent
commit
6255a60a4c

+ 6 - 1
MobileLibrary/Android/PsiphonTunnel/PsiphonTunnel.java

@@ -96,6 +96,7 @@ public class PsiphonTunnel {
         default public void onStartedWaitingForNetworkConnectivity() {}
         default public void onStoppedWaitingForNetworkConnectivity() {}
         default public void onActiveAuthorizationIDs(List<String> authorizations) {}
+        default public void onTrafficRateLimits(long upstreamBytesPerSecond, long downstreamBytesPerSecond) {}
         default public void onApplicationParameter(String key, Object value) {}
         default public void onServerAlert(String reason, String subject) {}
         default public void onExiting() {}
@@ -760,13 +761,17 @@ public class PsiphonTunnel {
                 diagnostic = false;
                 JSONObject data = notice.getJSONObject("data");
                 mHostService.onBytesTransferred(data.getLong("sent"), data.getLong("received"));
-            }  else if (noticeType.equals("ActiveAuthorizationIDs")) {
+            } else if (noticeType.equals("ActiveAuthorizationIDs")) {
                 JSONArray activeAuthorizationIDs = notice.getJSONObject("data").getJSONArray("IDs");
                 ArrayList<String> authorizations = new ArrayList<String>();
                 for (int i=0; i<activeAuthorizationIDs.length(); i++) {
                     authorizations.add(activeAuthorizationIDs.getString(i));
                 }
                 mHostService.onActiveAuthorizationIDs(authorizations);
+            } else if (noticeType.equals("TrafficRateLimits")) {
+                JSONObject data = notice.getJSONObject("data");
+                mHostService.onTrafficRateLimits(
+                    data.getLong("upstreamBytesPerSecond"), data.getLong("downstreamBytesPerSecond"));
             } else if (noticeType.equals("Exiting")) {
                 mHostService.onExiting();
             } else if (noticeType.equals("ActiveTunnel")) {

+ 8 - 0
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.h

@@ -296,6 +296,14 @@ Swift: @code func onInternetReachabilityChanged(_ currentReachability: Reachabil
  */
 - (void)onActiveAuthorizationIDs:(NSArray * _Nonnull)authorizations;
 
+/*!
+ Called when tunnel-core receives traffic rate limit information in the handshake
+ @param upstreamBytesPerSecond  upstream rate limit; 0 for no limit
+ @param downstreamBytesPerSecond  downstream rate limit; 0 for no limit
+ Swift: @code func onTrafficRateLimits(_ upstreamBytesPerSecond: Int64, _ downstreamBytesPerSecond: Int64) @endcode
+ */
+- (void)onTrafficRateLimits:(int64_t)upstreamBytesPerSecond :(int64_t)downstreamBytesPerSecond;
+
 /*!
  Called when tunnel-core receives an alert from the server.
  @param reason The reason for the alert.

+ 15 - 1
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.m

@@ -968,7 +968,7 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
             [self logMessage:[NSString stringWithFormat: @"BytesTransferred notice missing data.sent or data.received: %@", noticeJSON]];
             return;
         }
-        
+
         if ([self.tunneledAppDelegate respondsToSelector:@selector(onBytesTransferred::)]) {
             dispatch_sync(self->callbackQueue, ^{
                 [self.tunneledAppDelegate onBytesTransferred:[sent longLongValue]:[received longLongValue]];
@@ -1001,6 +1001,20 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
             });
         }
     }
+    else if ([noticeType isEqualToString:@"TrafficRateLimits"]) {
+        id upstreamBytesPerSecond = [notice valueForKeyPath:@"data.upstreamBytesPerSecond"];
+        id downstreamBytesPerSecond = [notice valueForKeyPath:@"data.downstreamBytesPerSecond"];
+        if (![upstreamBytesPerSecond isKindOfClass:[NSNumber class]] || ![downstreamBytesPerSecond isKindOfClass:[NSNumber class]]) {
+            [self logMessage:[NSString stringWithFormat: @"TrafficRateLimits notice missing data.upstreamBytesPerSecond or data.downstreamBytesPerSecond: %@", noticeJSON]];
+            return;
+        }
+
+        if ([self.tunneledAppDelegate respondsToSelector:@selector(onTrafficRateLimits::)]) {
+            dispatch_sync(self->callbackQueue, ^{
+                [self.tunneledAppDelegate onTrafficRateLimits:[upstreamBytesPerSecond longLongValue]:[downstreamBytesPerSecond longLongValue]];
+            });
+        }
+    }
     else if ([noticeType isEqualToString:@"ServerAlert"]) {
         id reason = [notice valueForKeyPath:@"data.reason"];
         id subject = [notice valueForKeyPath:@"data.subject"];

+ 13 - 11
psiphon/common/protocol/protocol.go

@@ -403,17 +403,19 @@ func (labeledVersions LabeledQUICVersions) PruneInvalid() LabeledQUICVersions {
 }
 
 type HandshakeResponse struct {
-	SSHSessionID           string              `json:"ssh_session_id"`
-	Homepages              []string            `json:"homepages"`
-	UpgradeClientVersion   string              `json:"upgrade_client_version"`
-	PageViewRegexes        []map[string]string `json:"page_view_regexes"`
-	HttpsRequestRegexes    []map[string]string `json:"https_request_regexes"`
-	EncodedServerList      []string            `json:"encoded_server_list"`
-	ClientRegion           string              `json:"client_region"`
-	ServerTimestamp        string              `json:"server_timestamp"`
-	ActiveAuthorizationIDs []string            `json:"active_authorization_ids"`
-	TacticsPayload         json.RawMessage     `json:"tactics_payload"`
-	Padding                string              `json:"padding"`
+	SSHSessionID             string              `json:"ssh_session_id"`
+	Homepages                []string            `json:"homepages"`
+	UpgradeClientVersion     string              `json:"upgrade_client_version"`
+	PageViewRegexes          []map[string]string `json:"page_view_regexes"`
+	HttpsRequestRegexes      []map[string]string `json:"https_request_regexes"`
+	EncodedServerList        []string            `json:"encoded_server_list"`
+	ClientRegion             string              `json:"client_region"`
+	ServerTimestamp          string              `json:"server_timestamp"`
+	ActiveAuthorizationIDs   []string            `json:"active_authorization_ids"`
+	TacticsPayload           json.RawMessage     `json:"tactics_payload"`
+	UpstreamBytesPerSecond   int64               `json:"upstream_bytes_per_second"`
+	DownstreamBytesPerSecond int64               `json:"downstream_bytes_per_second"`
+	Padding                  string              `json:"padding"`
 }
 
 type ConnectedResponse struct {

+ 7 - 1
psiphon/controller.go

@@ -30,6 +30,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -67,6 +68,7 @@ type Controller struct {
 	establishCtx                            context.Context
 	stopEstablish                           context.CancelFunc
 	establishWaitGroup                      *sync.WaitGroup
+	establishedTunnelsCount                 int32
 	candidateServerEntries                  chan *candidateServerEntry
 	untunneledDialConfig                    *DialConfig
 	splitTunnelClassifier                   *SplitTunnelClassifier
@@ -731,6 +733,8 @@ loop:
 				break
 			}
 
+			atomic.AddInt32(&controller.establishedTunnelsCount, 1)
+
 			NoticeActiveTunnel(
 				connectedTunnel.dialParams.ServerEntry.GetDiagnosticID(),
 				connectedTunnel.dialParams.TunnelProtocol,
@@ -1556,6 +1560,7 @@ func (controller *Controller) doFetchTactics(
 		selectProtocol,
 		serverEntry,
 		true,
+		0,
 		0)
 	if dialParams == nil {
 		// MakeDialParameters may return nil, nil when the server entry can't
@@ -1924,7 +1929,8 @@ loop:
 			selectProtocol,
 			candidateServerEntry.serverEntry,
 			false,
-			controller.establishConnectTunnelCount)
+			controller.establishConnectTunnelCount,
+			int(atomic.LoadInt32(&controller.establishedTunnelsCount)))
 		if dialParams == nil || err != nil {
 
 			controller.concurrentEstablishTunnelsMutex.Unlock()

+ 8 - 5
psiphon/dialParameters.go

@@ -58,10 +58,11 @@ import (
 //
 // DialParameters is not safe for concurrent use.
 type DialParameters struct {
-	ServerEntry     *protocol.ServerEntry `json:"-"`
-	NetworkID       string                `json:"-"`
-	IsReplay        bool                  `json:"-"`
-	CandidateNumber int                   `json:"-"`
+	ServerEntry             *protocol.ServerEntry `json:"-"`
+	NetworkID               string                `json:"-"`
+	IsReplay                bool                  `json:"-"`
+	CandidateNumber         int                   `json:"-"`
+	EstablishedTunnelsCount int                   `json:"-"`
 
 	IsExchanged bool
 
@@ -149,7 +150,8 @@ func MakeDialParameters(
 	selectProtocol func(serverEntry *protocol.ServerEntry) (string, bool),
 	serverEntry *protocol.ServerEntry,
 	isTactics bool,
-	candidateNumber int) (*DialParameters, error) {
+	candidateNumber int,
+	establishedTunnelsCount int) (*DialParameters, error) {
 
 	networkID := config.GetNetworkID()
 
@@ -267,6 +269,7 @@ func MakeDialParameters(
 	dialParams.NetworkID = networkID
 	dialParams.IsReplay = isReplay
 	dialParams.CandidateNumber = candidateNumber
+	dialParams.EstablishedTunnelsCount = establishedTunnelsCount
 
 	// Even when replaying, LastUsedTimestamp is updated to extend the TTL of
 	// replayed dial parameters which will be updated in the datastore upon

+ 12 - 12
psiphon/dialParameters_test.go

@@ -108,7 +108,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	// Test: expected dial parameter fields set
 
-	dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -205,7 +205,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	dialParams.Failed(clientConfig)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -220,7 +220,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	testNetworkID = prng.HexString(8)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -237,7 +237,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	dialParams.Succeeded()
 
-	replayDialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	replayDialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -323,7 +323,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetClientParameters failed: %s", err)
 	}
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -338,7 +338,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	time.Sleep(1 * time.Second)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -353,7 +353,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	serverEntries[0].ConfigurationVersion += 1
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -377,14 +377,14 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetClientParameters failed: %s", err)
 	}
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
 
 	dialParams.Succeeded()
 
-	replayDialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	replayDialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -432,7 +432,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 		if i%10 == 0 {
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}
@@ -461,7 +461,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 				t.Fatalf("ServerEntryIterator.Next failed: %s", err)
 			}
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}
@@ -483,7 +483,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 				t.Fatalf("ServerEntryIterator.Next failed: %s", err)
 			}
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}

+ 1 - 0
psiphon/exchange_test.go

@@ -184,6 +184,7 @@ func TestServerEntryExchange(t *testing.T) {
 				selectProtocol,
 				serverEntry,
 				false,
+				0,
 				0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)

+ 15 - 0
psiphon/notice.go

@@ -433,6 +433,7 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters) {
 		"protocol", dialParams.TunnelProtocol,
 		"isReplay", dialParams.IsReplay,
 		"candidateNumber", dialParams.CandidateNumber,
+		"establishedTunnelsCount", dialParams.EstablishedTunnelsCount,
 		"networkType", dialParams.GetNetworkType(),
 	}
 
@@ -795,6 +796,20 @@ func NoticeActiveAuthorizationIDs(activeAuthorizationIDs []string) {
 		"IDs", activeAuthorizationIDs)
 }
 
+// NoticeTrafficRateLimits reports the tunnel traffic rate limits in place for
+// this client, as reported by the server at the start of the tunnel. Values
+// of 0 indicate no limit. Values of -1 indicate that the server did not
+// report rate limits.
+//
+// Limitation: any rate limit changes during the lifetime of the tunnel are
+// not reported.
+func NoticeTrafficRateLimits(upstreamBytesPerSecond, downstreamBytesPerSecond int64) {
+	singletonNoticeLogger.outputNotice(
+		"TrafficRateLimits", 0,
+		"upstreamBytesPerSecond", upstreamBytesPerSecond,
+		"downstreamBytesPerSecond", downstreamBytesPerSecond)
+}
+
 func NoticeBindToDevice(deviceInfo string) {
 	outputRepetitiveNotice(
 		"BindToDevice", deviceInfo, 0,

+ 34 - 22
psiphon/server/api.go

@@ -205,6 +205,10 @@ func handshakeAPIRequestHandler(
 	isMobile := isMobileClientPlatform(clientPlatform)
 	normalizedPlatform := normalizeClientPlatform(clientPlatform)
 
+	// establishedTunnelsCount is used in traffic rule selection. When omitted by
+	// the client, a value of 0 will be used.
+	establishedTunnelsCount, _ := getIntStringRequestParam(params, "established_tunnels_count")
+
 	var authorizations []string
 	if params[protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS] != nil {
 		authorizations, err = getStringArrayRequestParam(params, protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS)
@@ -224,13 +228,14 @@ func handshakeAPIRequestHandler(
 	// TODO: in the case of SSH API requests, the actual sshClient could
 	// be passed in and used here. The session ID lookup is only strictly
 	// necessary to support web API requests.
-	activeAuthorizationIDs, authorizedAccessTypes, err := support.TunnelServer.SetClientHandshakeState(
+	handshakeStateInfo, err := support.TunnelServer.SetClientHandshakeState(
 		sessionID,
 		handshakeState{
-			completed:         true,
-			apiProtocol:       apiProtocol,
-			apiParams:         copyBaseRequestParams(params),
-			expectDomainBytes: len(httpsRequestRegexes) > 0,
+			completed:               true,
+			apiProtocol:             apiProtocol,
+			apiParams:               copyBaseRequestParams(params),
+			expectDomainBytes:       len(httpsRequestRegexes) > 0,
+			establishedTunnelsCount: establishedTunnelsCount,
 		},
 		authorizations)
 	if err != nil {
@@ -261,7 +266,7 @@ func handshakeAPIRequestHandler(
 			logFields := getRequestLogFields(
 				tactics.TACTICS_METRIC_EVENT_NAME,
 				geoIPData,
-				authorizedAccessTypes,
+				handshakeStateInfo.authorizedAccessTypes,
 				params,
 				handshakeRequestParams)
 
@@ -284,7 +289,7 @@ func handshakeAPIRequestHandler(
 		getRequestLogFields(
 			"",
 			geoIPData,
-			authorizedAccessTypes,
+			handshakeStateInfo.authorizedAccessTypes,
 			params,
 			baseRequestParams)).Debug("handshake")
 
@@ -311,17 +316,19 @@ func handshakeAPIRequestHandler(
 	}
 
 	handshakeResponse := protocol.HandshakeResponse{
-		SSHSessionID:           sessionID,
-		Homepages:              db.GetRandomizedHomepages(sponsorID, geoIPData.Country, geoIPData.ASN, isMobile),
-		UpgradeClientVersion:   db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
-		PageViewRegexes:        make([]map[string]string, 0),
-		HttpsRequestRegexes:    httpsRequestRegexes,
-		EncodedServerList:      encodedServerList,
-		ClientRegion:           geoIPData.Country,
-		ServerTimestamp:        common.GetCurrentTimestamp(),
-		ActiveAuthorizationIDs: activeAuthorizationIDs,
-		TacticsPayload:         marshaledTacticsPayload,
-		Padding:                strings.Repeat(" ", pad_response),
+		SSHSessionID:             sessionID,
+		Homepages:                db.GetRandomizedHomepages(sponsorID, geoIPData.Country, geoIPData.ASN, isMobile),
+		UpgradeClientVersion:     db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
+		PageViewRegexes:          make([]map[string]string, 0),
+		HttpsRequestRegexes:      httpsRequestRegexes,
+		EncodedServerList:        encodedServerList,
+		ClientRegion:             geoIPData.Country,
+		ServerTimestamp:          common.GetCurrentTimestamp(),
+		ActiveAuthorizationIDs:   handshakeStateInfo.activeAuthorizationIDs,
+		TacticsPayload:           marshaledTacticsPayload,
+		UpstreamBytesPerSecond:   handshakeStateInfo.upstreamBytesPerSecond,
+		DownstreamBytesPerSecond: handshakeStateInfo.downstreamBytesPerSecond,
+		Padding:                  strings.Repeat(" ", pad_response),
 	}
 
 	responsePayload, err := json.Marshal(handshakeResponse)
@@ -768,6 +775,7 @@ var baseRequestParams = []requestParamSpec{
 	{"egress_region", isRegionCode, requestParamOptional},
 	{"dial_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"candidate_number", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"established_tunnels_count", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"upstream_ossh_padding", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"meek_cookie_size", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"meek_limit_request", isIntString, requestParamOptional | requestParamLogStringAsInt},
@@ -1099,19 +1107,23 @@ func getStringRequestParam(params common.APIParameters, name string) (string, er
 	return value, nil
 }
 
-func getInt64RequestParam(params common.APIParameters, name string) (int64, error) {
+func getIntStringRequestParam(params common.APIParameters, name string) (int, error) {
 	if params[name] == nil {
 		return 0, errors.Tracef("missing param: %s", name)
 	}
-	value, ok := params[name].(float64)
+	valueStr, ok := params[name].(string)
 	if !ok {
 		return 0, errors.Tracef("invalid param: %s", name)
 	}
-	return int64(value), nil
+	value, err := strconv.Atoi(valueStr)
+	if !ok {
+		return 0, errors.Trace(err)
+	}
+	return value, nil
 }
 
 func getPaddingSizeRequestParam(params common.APIParameters, name string) (int, error) {
-	value, err := getInt64RequestParam(params, name)
+	value, err := getIntStringRequestParam(params, name)
 	if err != nil {
 		return 0, errors.Trace(err)
 	}

+ 2 - 0
psiphon/server/server_test.go

@@ -1185,6 +1185,7 @@ func checkExpectedLogFields(
 		"is_replay",
 		"dial_duration",
 		"candidate_number",
+		"established_tunnels_count",
 		"network_latency_multiplier",
 		"network_type",
 		"client_app_id",
@@ -2121,6 +2122,7 @@ func storePruneServerEntriesTest(
 			},
 			serverEntry,
 			false,
+			0,
 			0)
 		if err != nil {
 			t.Fatalf("MakeDialParameters failed: %s", err)

+ 2 - 1
psiphon/server/sessionID_test.go

@@ -166,7 +166,8 @@ func TestDuplicateSessionID(t *testing.T) {
 			func(_ *protocol.ServerEntry) (string, bool) { return "OSSH", true },
 			serverEntry,
 			false,
-			1)
+			0,
+			0)
 		if err != nil {
 			t.Fatalf("MakeDialParameters failed: %s", err)
 		}

+ 42 - 20
psiphon/server/tunnelServer.go

@@ -293,12 +293,14 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
 //
 // The authorizations received from the client handshake are verified and the
 // resulting list of authorized access types are applied to the client's tunnel
-// and traffic rules. A list of active authorization IDs and authorized access
-// types is returned for responding to the client and logging.
+// and traffic rules.
+//
+// A list of active authorization IDs, authorized access types, and traffic
+// rate limits are returned for responding to the client and logging.
 func (server *TunnelServer) SetClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
 }
@@ -908,23 +910,23 @@ func (sshServer *sshServer) resetAllClientOSLConfigs() {
 func (sshServer *sshServer) setClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	sshServer.clientsMutex.Lock()
 	client := sshServer.clients[sessionID]
 	sshServer.clientsMutex.Unlock()
 
 	if client == nil {
-		return nil, nil, errors.TraceNew("unknown session ID")
+		return nil, errors.TraceNew("unknown session ID")
 	}
 
-	activeAuthorizationIDs, authorizedAccessTypes, err := client.setHandshakeState(
+	handshakeStateInfo, err := client.setHandshakeState(
 		state, authorizations)
 	if err != nil {
-		return nil, nil, errors.Trace(err)
+		return nil, errors.Trace(err)
 	}
 
-	return activeAuthorizationIDs, authorizedAccessTypes, nil
+	return handshakeStateInfo, nil
 }
 
 func (sshServer *sshServer) getClientHandshaked(
@@ -1223,12 +1225,20 @@ type qualityMetrics struct {
 }
 
 type handshakeState struct {
-	completed             bool
-	apiProtocol           string
-	apiParams             common.APIParameters
-	authorizedAccessTypes []string
-	authorizationsRevoked bool
-	expectDomainBytes     bool
+	completed               bool
+	apiProtocol             string
+	apiParams               common.APIParameters
+	authorizedAccessTypes   []string
+	authorizationsRevoked   bool
+	expectDomainBytes       bool
+	establishedTunnelsCount int
+}
+
+type handshakeStateInfo struct {
+	activeAuthorizationIDs   []string
+	authorizedAccessTypes    []string
+	upstreamBytesPerSecond   int64
+	downstreamBytesPerSecond int64
 }
 
 func newSshClient(
@@ -2508,7 +2518,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessa
 // sshClient.stop().
 func (sshClient *sshClient) setHandshakeState(
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	sshClient.Lock()
 	completed := sshClient.handshakeState.completed
@@ -2519,7 +2529,7 @@ func (sshClient *sshClient) setHandshakeState(
 
 	// Client must only perform one handshake
 	if completed {
-		return nil, nil, errors.TraceNew("handshake already completed")
+		return nil, errors.TraceNew("handshake already completed")
 	}
 
 	// Verify the authorizations submitted by the client. Verified, active
@@ -2653,10 +2663,16 @@ func (sshClient *sshClient) setHandshakeState(
 		sshClient.Unlock()
 	}
 
-	sshClient.setTrafficRules()
+	upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules()
+
 	sshClient.setOSLConfig()
 
-	return authorizationIDs, authorizedAccessTypes, nil
+	return &handshakeStateInfo{
+		activeAuthorizationIDs:   authorizationIDs,
+		authorizedAccessTypes:    authorizedAccessTypes,
+		upstreamBytesPerSecond:   upstreamBytesPerSecond,
+		downstreamBytesPerSecond: downstreamBytesPerSecond,
+	}, nil
 }
 
 // getHandshaked returns whether the client has completed a handshake API
@@ -2719,12 +2735,15 @@ func (sshClient *sshClient) expectDomainBytes() bool {
 // setTrafficRules resets the client's traffic rules based on the latest server config
 // and client properties. As sshClient.trafficRules may be reset by a concurrent
 // goroutine, trafficRules must only be accessed within the sshClient mutex.
-func (sshClient *sshClient) setTrafficRules() {
+func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
+	isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
+		sshClient.handshakeState.establishedTunnelsCount == 0
+
 	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
-		sshClient.isFirstTunnelInSession,
+		isFirstTunnelInSession,
 		sshClient.tunnelProtocol,
 		sshClient.geoIPData,
 		sshClient.handshakeState)
@@ -2734,6 +2753,9 @@ func (sshClient *sshClient) setTrafficRules() {
 		sshClient.throttledConn.SetLimits(
 			sshClient.trafficRules.RateLimits.CommonRateLimits())
 	}
+
+	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
+		*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
 }
 
 // setOSLConfig resets the client's OSL seed state based on the latest OSL config

+ 11 - 0
psiphon/serverApi.go

@@ -202,6 +202,12 @@ func (serverContext *ServerContext) doHandshakeRequest(
 	// - 'ssh_session_id' is ignored; client session ID is used instead
 
 	var handshakeResponse protocol.HandshakeResponse
+
+	// Initialize these fields to distinguish between psiphond omitting values in
+	// the response and the zero value, which means unlimited rate.
+	handshakeResponse.UpstreamBytesPerSecond = -1
+	handshakeResponse.DownstreamBytesPerSecond = -1
+
 	err := json.Unmarshal(response, &handshakeResponse)
 	if err != nil {
 		return errors.Trace(err)
@@ -282,6 +288,9 @@ func (serverContext *ServerContext) doHandshakeRequest(
 
 	NoticeActiveAuthorizationIDs(handshakeResponse.ActiveAuthorizationIDs)
 
+	NoticeTrafficRateLimits(
+		handshakeResponse.UpstreamBytesPerSecond, handshakeResponse.DownstreamBytesPerSecond)
+
 	if doTactics && handshakeResponse.TacticsPayload != nil &&
 		networkID == serverContext.tunnel.config.GetNetworkID() {
 
@@ -933,6 +942,8 @@ func getBaseAPIParameters(
 
 	params["candidate_number"] = strconv.Itoa(dialParams.CandidateNumber)
 
+	params["established_tunnels_count"] = strconv.Itoa(dialParams.EstablishedTunnelsCount)
+
 	if dialParams.NetworkLatencyMultiplier != 0.0 {
 		params["network_latency_multiplier"] =
 			fmt.Sprintf("%f", dialParams.NetworkLatencyMultiplier)