Просмотр исходного кода

Report tunnel rate limits via handshake response

Rod Hynes 6 лет назад
Родитель
Сommit
ecaa03f131

+ 6 - 1
MobileLibrary/Android/PsiphonTunnel/PsiphonTunnel.java

@@ -96,6 +96,7 @@ public class PsiphonTunnel {
         default public void onStartedWaitingForNetworkConnectivity() {}
         default public void onStoppedWaitingForNetworkConnectivity() {}
         default public void onActiveAuthorizationIDs(List<String> authorizations) {}
+        default public void onTrafficRateLimits(long upstreamBytesPerSecond, long downstreamBytesPerSecond) {}
         default public void onApplicationParameter(String key, Object value) {}
         default public void onServerAlert(String reason, String subject) {}
         default public void onExiting() {}
@@ -760,13 +761,17 @@ public class PsiphonTunnel {
                 diagnostic = false;
                 JSONObject data = notice.getJSONObject("data");
                 mHostService.onBytesTransferred(data.getLong("sent"), data.getLong("received"));
-            }  else if (noticeType.equals("ActiveAuthorizationIDs")) {
+            } else if (noticeType.equals("ActiveAuthorizationIDs")) {
                 JSONArray activeAuthorizationIDs = notice.getJSONObject("data").getJSONArray("IDs");
                 ArrayList<String> authorizations = new ArrayList<String>();
                 for (int i=0; i<activeAuthorizationIDs.length(); i++) {
                     authorizations.add(activeAuthorizationIDs.getString(i));
                 }
                 mHostService.onActiveAuthorizationIDs(authorizations);
+            } else if (noticeType.equals("TrafficRateLimits")) {
+                JSONObject data = notice.getJSONObject("data");
+                mHostService.onTrafficRateLimits(
+                    data.getLong("upstreamBytesPerSecond"), data.getLong("downstreamBytesPerSecond"));
             } else if (noticeType.equals("Exiting")) {
                 mHostService.onExiting();
             } else if (noticeType.equals("ActiveTunnel")) {

+ 8 - 0
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.h

@@ -296,6 +296,14 @@ Swift: @code func onInternetReachabilityChanged(_ currentReachability: Reachabil
  */
 - (void)onActiveAuthorizationIDs:(NSArray * _Nonnull)authorizations;
 
+/*!
+ Called when tunnel-core receives traffic rate limit information in the handshake
+ @param upstreamBytesPerSecond  upstream rate limit; 0 for no limit
+ @param downstreamBytesPerSecond  downstream rate limit; 0 for no limit
+ Swift: @code func onTrafficRateLimits(_ upstreamBytesPerSecond: Int64, _ downstreamBytesPerSecond: Int64) @endcode
+ */
+- (void)onTrafficRateLimits:(int64_t)upstreamBytesPerSecond :(int64_t)downstreamBytesPerSecond;
+
 /*!
  Called when tunnel-core receives an alert from the server.
  @param reason The reason for the alert.

+ 15 - 1
MobileLibrary/iOS/PsiphonTunnel/PsiphonTunnel/PsiphonTunnel.m

@@ -968,7 +968,7 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
             [self logMessage:[NSString stringWithFormat: @"BytesTransferred notice missing data.sent or data.received: %@", noticeJSON]];
             return;
         }
-        
+
         if ([self.tunneledAppDelegate respondsToSelector:@selector(onBytesTransferred::)]) {
             dispatch_sync(self->callbackQueue, ^{
                 [self.tunneledAppDelegate onBytesTransferred:[sent longLongValue]:[received longLongValue]];
@@ -1001,6 +1001,20 @@ typedef NS_ERROR_ENUM(PsiphonTunnelErrorDomain, PsiphonTunnelErrorCode) {
             });
         }
     }
+    else if ([noticeType isEqualToString:@"TrafficRateLimits"]) {
+        id upstreamBytesPerSecond = [notice valueForKeyPath:@"data.upstreamBytesPerSecond"];
+        id downstreamBytesPerSecond = [notice valueForKeyPath:@"data.downstreamBytesPerSecond"];
+        if (![upstreamBytesPerSecond isKindOfClass:[NSNumber class]] || ![downstreamBytesPerSecond isKindOfClass:[NSNumber class]]) {
+            [self logMessage:[NSString stringWithFormat: @"TrafficRateLimits notice missing data.upstreamBytesPerSecond or data.downstreamBytesPerSecond: %@", noticeJSON]];
+            return;
+        }
+
+        if ([self.tunneledAppDelegate respondsToSelector:@selector(onTrafficRateLimits::)]) {
+            dispatch_sync(self->callbackQueue, ^{
+                [self.tunneledAppDelegate onTrafficRateLimits:[upstreamBytesPerSecond longLongValue]:[downstreamBytesPerSecond longLongValue]];
+            });
+        }
+    }
     else if ([noticeType isEqualToString:@"ServerAlert"]) {
         id reason = [notice valueForKeyPath:@"data.reason"];
         id subject = [notice valueForKeyPath:@"data.subject"];

+ 13 - 11
psiphon/common/protocol/protocol.go

@@ -403,17 +403,19 @@ func (labeledVersions LabeledQUICVersions) PruneInvalid() LabeledQUICVersions {
 }
 
 type HandshakeResponse struct {
-	SSHSessionID           string              `json:"ssh_session_id"`
-	Homepages              []string            `json:"homepages"`
-	UpgradeClientVersion   string              `json:"upgrade_client_version"`
-	PageViewRegexes        []map[string]string `json:"page_view_regexes"`
-	HttpsRequestRegexes    []map[string]string `json:"https_request_regexes"`
-	EncodedServerList      []string            `json:"encoded_server_list"`
-	ClientRegion           string              `json:"client_region"`
-	ServerTimestamp        string              `json:"server_timestamp"`
-	ActiveAuthorizationIDs []string            `json:"active_authorization_ids"`
-	TacticsPayload         json.RawMessage     `json:"tactics_payload"`
-	Padding                string              `json:"padding"`
+	SSHSessionID             string              `json:"ssh_session_id"`
+	Homepages                []string            `json:"homepages"`
+	UpgradeClientVersion     string              `json:"upgrade_client_version"`
+	PageViewRegexes          []map[string]string `json:"page_view_regexes"`
+	HttpsRequestRegexes      []map[string]string `json:"https_request_regexes"`
+	EncodedServerList        []string            `json:"encoded_server_list"`
+	ClientRegion             string              `json:"client_region"`
+	ServerTimestamp          string              `json:"server_timestamp"`
+	ActiveAuthorizationIDs   []string            `json:"active_authorization_ids"`
+	TacticsPayload           json.RawMessage     `json:"tactics_payload"`
+	UpstreamBytesPerSecond   int64               `json:"upstream_bytes_per_second"`
+	DownstreamBytesPerSecond int64               `json:"downstream_bytes_per_second"`
+	Padding                  string              `json:"padding"`
 }
 
 type ConnectedResponse struct {

+ 14 - 0
psiphon/notice.go

@@ -795,6 +795,20 @@ func NoticeActiveAuthorizationIDs(activeAuthorizationIDs []string) {
 		"IDs", activeAuthorizationIDs)
 }
 
+// NoticeTrafficRateLimits reports the tunnel traffic rate limits in place for
+// this client, as reported by the server at the start of the tunnel. Values
+// of 0 indicate no limit. Values of -1 indicate that the server did not
+// report rate limits.
+//
+// Limitation: any rate limit changes during the lifetime of the tunnel are
+// not reported.
+func NoticeTrafficRateLimits(upstreamBytesPerSecond, downstreamBytesPerSecond int64) {
+	singletonNoticeLogger.outputNotice(
+		"TrafficRateLimits", 0,
+		"upstreamBytesPerSecond", upstreamBytesPerSecond,
+		"downstreamBytesPerSecond", downstreamBytesPerSecond)
+}
+
 func NoticeBindToDevice(deviceInfo string) {
 	outputRepetitiveNotice(
 		"BindToDevice", deviceInfo, 0,

+ 16 - 14
psiphon/server/api.go

@@ -224,7 +224,7 @@ func handshakeAPIRequestHandler(
 	// TODO: in the case of SSH API requests, the actual sshClient could
 	// be passed in and used here. The session ID lookup is only strictly
 	// necessary to support web API requests.
-	activeAuthorizationIDs, authorizedAccessTypes, err := support.TunnelServer.SetClientHandshakeState(
+	clientHandshakeStateInfo, err := support.TunnelServer.SetClientHandshakeState(
 		sessionID,
 		handshakeState{
 			completed:         true,
@@ -261,7 +261,7 @@ func handshakeAPIRequestHandler(
 			logFields := getRequestLogFields(
 				tactics.TACTICS_METRIC_EVENT_NAME,
 				geoIPData,
-				authorizedAccessTypes,
+				clientHandshakeStateInfo.AuthorizedAccessTypes,
 				params,
 				handshakeRequestParams)
 
@@ -284,7 +284,7 @@ func handshakeAPIRequestHandler(
 		getRequestLogFields(
 			"",
 			geoIPData,
-			authorizedAccessTypes,
+			clientHandshakeStateInfo.AuthorizedAccessTypes,
 			params,
 			baseRequestParams)).Debug("handshake")
 
@@ -311,17 +311,19 @@ func handshakeAPIRequestHandler(
 	}
 
 	handshakeResponse := protocol.HandshakeResponse{
-		SSHSessionID:           sessionID,
-		Homepages:              db.GetRandomizedHomepages(sponsorID, geoIPData.Country, geoIPData.ASN, isMobile),
-		UpgradeClientVersion:   db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
-		PageViewRegexes:        make([]map[string]string, 0),
-		HttpsRequestRegexes:    httpsRequestRegexes,
-		EncodedServerList:      encodedServerList,
-		ClientRegion:           geoIPData.Country,
-		ServerTimestamp:        common.GetCurrentTimestamp(),
-		ActiveAuthorizationIDs: activeAuthorizationIDs,
-		TacticsPayload:         marshaledTacticsPayload,
-		Padding:                strings.Repeat(" ", pad_response),
+		SSHSessionID:             sessionID,
+		Homepages:                db.GetRandomizedHomepages(sponsorID, geoIPData.Country, geoIPData.ASN, isMobile),
+		UpgradeClientVersion:     db.GetUpgradeClientVersion(clientVersion, normalizedPlatform),
+		PageViewRegexes:          make([]map[string]string, 0),
+		HttpsRequestRegexes:      httpsRequestRegexes,
+		EncodedServerList:        encodedServerList,
+		ClientRegion:             geoIPData.Country,
+		ServerTimestamp:          common.GetCurrentTimestamp(),
+		ActiveAuthorizationIDs:   clientHandshakeStateInfo.ActiveAuthorizationIDs,
+		TacticsPayload:           marshaledTacticsPayload,
+		UpstreamBytesPerSecond:   clientHandshakeStateInfo.UpstreamBytesPerSecond,
+		DownstreamBytesPerSecond: clientHandshakeStateInfo.DownstreamBytesPerSecond,
+		Padding:                  strings.Repeat(" ", pad_response),
 	}
 
 	responsePayload, err := json.Marshal(handshakeResponse)

+ 31 - 13
psiphon/server/tunnelServer.go

@@ -284,6 +284,13 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
 	server.sshServer.resetAllClientOSLConfigs()
 }
 
+type ClientHandshakeStateInfo struct {
+	ActiveAuthorizationIDs   []string
+	AuthorizedAccessTypes    []string
+	UpstreamBytesPerSecond   int64
+	DownstreamBytesPerSecond int64
+}
+
 // SetClientHandshakeState sets the handshake state -- that it completed and
 // what parameters were passed -- in sshClient. This state is used for allowing
 // port forwards and for future traffic rule selection. SetClientHandshakeState
@@ -293,12 +300,14 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
 //
 // The authorizations received from the client handshake are verified and the
 // resulting list of authorized access types are applied to the client's tunnel
-// and traffic rules. A list of active authorization IDs and authorized access
-// types is returned for responding to the client and logging.
+// and traffic rules.
+//
+// A list of active authorization IDs, authorized access types, and traffic
+// rate limits are returned for responding to the client and logging.
 func (server *TunnelServer) SetClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*ClientHandshakeStateInfo, error) {
 
 	return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
 }
@@ -908,23 +917,23 @@ func (sshServer *sshServer) resetAllClientOSLConfigs() {
 func (sshServer *sshServer) setClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*ClientHandshakeStateInfo, error) {
 
 	sshServer.clientsMutex.Lock()
 	client := sshServer.clients[sessionID]
 	sshServer.clientsMutex.Unlock()
 
 	if client == nil {
-		return nil, nil, errors.TraceNew("unknown session ID")
+		return nil, errors.TraceNew("unknown session ID")
 	}
 
-	activeAuthorizationIDs, authorizedAccessTypes, err := client.setHandshakeState(
+	clientHandshakeStateInfo, err := client.setHandshakeState(
 		state, authorizations)
 	if err != nil {
-		return nil, nil, errors.Trace(err)
+		return nil, errors.Trace(err)
 	}
 
-	return activeAuthorizationIDs, authorizedAccessTypes, nil
+	return clientHandshakeStateInfo, nil
 }
 
 func (sshServer *sshServer) getClientHandshaked(
@@ -2508,7 +2517,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessa
 // sshClient.stop().
 func (sshClient *sshClient) setHandshakeState(
 	state handshakeState,
-	authorizations []string) ([]string, []string, error) {
+	authorizations []string) (*ClientHandshakeStateInfo, error) {
 
 	sshClient.Lock()
 	completed := sshClient.handshakeState.completed
@@ -2519,7 +2528,7 @@ func (sshClient *sshClient) setHandshakeState(
 
 	// Client must only perform one handshake
 	if completed {
-		return nil, nil, errors.TraceNew("handshake already completed")
+		return nil, errors.TraceNew("handshake already completed")
 	}
 
 	// Verify the authorizations submitted by the client. Verified, active
@@ -2653,10 +2662,16 @@ func (sshClient *sshClient) setHandshakeState(
 		sshClient.Unlock()
 	}
 
-	sshClient.setTrafficRules()
+	upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules()
+
 	sshClient.setOSLConfig()
 
-	return authorizationIDs, authorizedAccessTypes, nil
+	return &ClientHandshakeStateInfo{
+		ActiveAuthorizationIDs:   authorizationIDs,
+		AuthorizedAccessTypes:    authorizedAccessTypes,
+		UpstreamBytesPerSecond:   upstreamBytesPerSecond,
+		DownstreamBytesPerSecond: downstreamBytesPerSecond,
+	}, nil
 }
 
 // getHandshaked returns whether the client has completed a handshake API
@@ -2719,7 +2734,7 @@ func (sshClient *sshClient) expectDomainBytes() bool {
 // setTrafficRules resets the client's traffic rules based on the latest server config
 // and client properties. As sshClient.trafficRules may be reset by a concurrent
 // goroutine, trafficRules must only be accessed within the sshClient mutex.
-func (sshClient *sshClient) setTrafficRules() {
+func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
@@ -2734,6 +2749,9 @@ func (sshClient *sshClient) setTrafficRules() {
 		sshClient.throttledConn.SetLimits(
 			sshClient.trafficRules.RateLimits.CommonRateLimits())
 	}
+
+	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
+		*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
 }
 
 // setOSLConfig resets the client's OSL seed state based on the latest OSL config

+ 9 - 0
psiphon/serverApi.go

@@ -202,6 +202,12 @@ func (serverContext *ServerContext) doHandshakeRequest(
 	// - 'ssh_session_id' is ignored; client session ID is used instead
 
 	var handshakeResponse protocol.HandshakeResponse
+
+	// Initialize these fields to distinguish between psiphond omitting values in
+	// the response and the zero value, which means unlimited rate.
+	handshakeResponse.UpstreamBytesPerSecond = -1
+	handshakeResponse.DownstreamBytesPerSecond = -1
+
 	err := json.Unmarshal(response, &handshakeResponse)
 	if err != nil {
 		return errors.Trace(err)
@@ -282,6 +288,9 @@ func (serverContext *ServerContext) doHandshakeRequest(
 
 	NoticeActiveAuthorizationIDs(handshakeResponse.ActiveAuthorizationIDs)
 
+	NoticeTrafficRateLimits(
+		handshakeResponse.UpstreamBytesPerSecond, handshakeResponse.DownstreamBytesPerSecond)
+
 	if doTactics && handshakeResponse.TacticsPayload != nil &&
 		networkID == serverContext.tunnel.config.GetNetworkID() {