Rod Hynes 7 лет назад
Родитель
Сommit
a4086731cf
3 измененных файлов с 45 добавлено и 4 удалено
  1. 7 0
      psiphon/server/geoip.go
  2. 22 1
      psiphon/server/trafficRules.go
  3. 16 3
      psiphon/server/tunnelServer.go

+ 7 - 0
psiphon/server/geoip.go

@@ -209,6 +209,13 @@ func (geoIP *GeoIPService) GetSessionCache(sessionID string) GeoIPData {
 	return geoIPData.(GeoIPData)
 }
 
+// InSessionCache returns whether the session ID is present
+// in the session cache.
+func (geoIP *GeoIPService) InSessionCache(sessionID string) bool {
+	_, found := geoIP.sessionCache.Get(sessionID)
+	return found
+}
+
 // calculateDiscoveryValue derives a value from the client IP address to be
 // used as input in the server discovery algorithm. Since we do not explicitly
 // store the client IP address, we must derive the value here and store it for

+ 22 - 1
psiphon/server/trafficRules.go

@@ -179,6 +179,11 @@ type RateLimits struct {
 	WriteUnthrottledBytes *int64
 	WriteBytesPerSecond   *int64
 	CloseAfterExhausted   *bool
+
+	// UnthrottleFirstTunnelOnly specifies whether any
+	// ReadUnthrottledBytes/WriteUnthrottledBytes apply
+	// only to the first tunnel in a session.
+	UnthrottleFirstTunnelOnly *bool
 }
 
 // CommonRateLimits converts a RateLimits to a common.RateLimits.
@@ -272,7 +277,10 @@ func (set *TrafficRulesSet) Validate() error {
 // For the return value TrafficRules, all pointer and slice fields are initialized,
 // so nil checks are not required. The caller must not modify the returned TrafficRules.
 func (set *TrafficRulesSet) GetTrafficRules(
-	tunnelProtocol string, geoIPData GeoIPData, state handshakeState) TrafficRules {
+	isFirstTunnelInSession bool,
+	tunnelProtocol string,
+	geoIPData GeoIPData,
+	state handshakeState) TrafficRules {
 
 	set.ReloadableFile.RLock()
 	defer set.ReloadableFile.RUnlock()
@@ -315,6 +323,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
 	}
 
+	if trafficRules.RateLimits.UnthrottleFirstTunnelOnly == nil {
+		trafficRules.RateLimits.UnthrottleFirstTunnelOnly = new(bool)
+	}
+
 	intPtr := func(i int) *int {
 		return &i
 	}
@@ -448,6 +460,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
 		}
 
+		if filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly != nil {
+			trafficRules.RateLimits.UnthrottleFirstTunnelOnly = filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly
+		}
+
 		if filteredRules.Rules.DialTCPPortForwardTimeoutMilliseconds != nil {
 			trafficRules.DialTCPPortForwardTimeoutMilliseconds = filteredRules.Rules.DialTCPPortForwardTimeoutMilliseconds
 		}
@@ -483,6 +499,11 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		break
 	}
 
+	if *trafficRules.RateLimits.UnthrottleFirstTunnelOnly && !isFirstTunnelInSession {
+		*trafficRules.RateLimits.ReadUnthrottledBytes = 0
+		*trafficRules.RateLimits.WriteUnthrottledBytes = 0
+	}
+
 	log.WithContextFields(LogFields{"trafficRules": trafficRules}).Debug("selected traffic rules")
 
 	return trafficRules

+ 16 - 3
psiphon/server/tunnelServer.go

@@ -888,6 +888,7 @@ type sshClient struct {
 	throttledConn                        *common.ThrottledConn
 	geoIPData                            GeoIPData
 	sessionID                            string
+	isFirstTunnelInSession               bool
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
 	udpChannel                           ssh.Channel
@@ -1171,17 +1172,26 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 
 	sessionID := sshPasswordPayload.SessionId
 
+	// The GeoIP session cache will be populated if there was a previous tunnel
+	// with this session ID. This will be true up to GEOIP_SESSION_CACHE_TTL, which
+	// is currently much longer than the OSL session cache, another option to use if
+	// the GeoIP session cache is retired (the GeoIP session cache currently only
+	// supports legacy use cases).
+	isFirstTunnelInSession := sshClient.sshServer.support.GeoIPService.InSessionCache(sessionID)
+
 	supportsServerRequests := common.Contains(
 		sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
 
 	sshClient.Lock()
 
-	// After this point, sshClient.sessionID is read-only as it will be read
+	// After this point, these values are read-only as they are read
 	// without obtaining sshClient.Lock.
 	sshClient.sessionID = sessionID
-
+	sshClient.isFirstTunnelInSession = isFirstTunnelInSession
 	sshClient.supportsServerRequests = supportsServerRequests
+
 	geoIPData := sshClient.geoIPData
+
 	sshClient.Unlock()
 
 	// Store the GeoIP data associated with the session ID. This makes
@@ -2010,7 +2020,10 @@ func (sshClient *sshClient) setTrafficRules() {
 	defer sshClient.Unlock()
 
 	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
-		sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
+		sshClient.isFirstTunnelInSession,
+		sshClient.tunnelProtocol,
+		sshClient.geoIPData,
+		sshClient.handshakeState)
 
 	if sshClient.throttledConn != nil {
 		// Any existing throttling state is reset.