Просмотр исходного кода

Add distinct rate limits for tunnel establishment

- Applies between initial network connection and
  completion of API handshake.

- Indirectly fixes an issue with establishment,
  including any liveness testing, being throttled
  when not first tunnel in session, potentially
  causing clients to favor switching servers upon
  reconnect.
Rod Hynes 4 лет назад
Родитель
Сommit
1232dd25e1
2 измененных файлов с 38 добавлено и 7 удалено
  1. 31 2
      psiphon/server/trafficRules.go
  2. 7 5
      psiphon/server/tunnelServer.go

+ 31 - 2
psiphon/server/trafficRules.go

@@ -266,6 +266,12 @@ type RateLimits struct {
 	WriteBytesPerSecond   *int64
 	WriteBytesPerSecond   *int64
 	CloseAfterExhausted   *bool
 	CloseAfterExhausted   *bool
 
 
+	// EstablishmentRead/WriteBytesPerSecond are used in place of
+	// Read/WriteBytesPerSecond for tunnels in the establishment phase, from the
+	// initial network connection up to the completion of the API handshake.
+	EstablishmentReadBytesPerSecond  *int64
+	EstablishmentWriteBytesPerSecond *int64
+
 	// UnthrottleFirstTunnelOnly specifies whether any
 	// UnthrottleFirstTunnelOnly specifies whether any
 	// ReadUnthrottledBytes/WriteUnthrottledBytes apply
 	// ReadUnthrottledBytes/WriteUnthrottledBytes apply
 	// only to the first tunnel in a session.
 	// only to the first tunnel in a session.
@@ -273,14 +279,19 @@ type RateLimits struct {
 }
 }
 
 
 // CommonRateLimits converts a RateLimits to a common.RateLimits.
 // CommonRateLimits converts a RateLimits to a common.RateLimits.
-func (rateLimits *RateLimits) CommonRateLimits() common.RateLimits {
-	return common.RateLimits{
+func (rateLimits *RateLimits) CommonRateLimits(handshaked bool) common.RateLimits {
+	r := common.RateLimits{
 		ReadUnthrottledBytes:  *rateLimits.ReadUnthrottledBytes,
 		ReadUnthrottledBytes:  *rateLimits.ReadUnthrottledBytes,
 		ReadBytesPerSecond:    *rateLimits.ReadBytesPerSecond,
 		ReadBytesPerSecond:    *rateLimits.ReadBytesPerSecond,
 		WriteUnthrottledBytes: *rateLimits.WriteUnthrottledBytes,
 		WriteUnthrottledBytes: *rateLimits.WriteUnthrottledBytes,
 		WriteBytesPerSecond:   *rateLimits.WriteBytesPerSecond,
 		WriteBytesPerSecond:   *rateLimits.WriteBytesPerSecond,
 		CloseAfterExhausted:   *rateLimits.CloseAfterExhausted,
 		CloseAfterExhausted:   *rateLimits.CloseAfterExhausted,
 	}
 	}
+	if !handshaked {
+		r.ReadBytesPerSecond = *rateLimits.EstablishmentReadBytesPerSecond
+		r.WriteBytesPerSecond = *rateLimits.EstablishmentWriteBytesPerSecond
+	}
+	return r
 }
 }
 
 
 // NewTrafficRulesSet initializes a TrafficRulesSet with
 // NewTrafficRulesSet initializes a TrafficRulesSet with
@@ -349,6 +360,8 @@ func (set *TrafficRulesSet) Validate() error {
 			(rules.RateLimits.ReadBytesPerSecond != nil && *rules.RateLimits.ReadBytesPerSecond < 0) ||
 			(rules.RateLimits.ReadBytesPerSecond != nil && *rules.RateLimits.ReadBytesPerSecond < 0) ||
 			(rules.RateLimits.WriteUnthrottledBytes != nil && *rules.RateLimits.WriteUnthrottledBytes < 0) ||
 			(rules.RateLimits.WriteUnthrottledBytes != nil && *rules.RateLimits.WriteUnthrottledBytes < 0) ||
 			(rules.RateLimits.WriteBytesPerSecond != nil && *rules.RateLimits.WriteBytesPerSecond < 0) ||
 			(rules.RateLimits.WriteBytesPerSecond != nil && *rules.RateLimits.WriteBytesPerSecond < 0) ||
+			(rules.RateLimits.EstablishmentReadBytesPerSecond != nil && *rules.RateLimits.EstablishmentReadBytesPerSecond < 0) ||
+			(rules.RateLimits.EstablishmentWriteBytesPerSecond != nil && *rules.RateLimits.EstablishmentWriteBytesPerSecond < 0) ||
 			(rules.DialTCPPortForwardTimeoutMilliseconds != nil && *rules.DialTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.DialTCPPortForwardTimeoutMilliseconds != nil && *rules.DialTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleTCPPortForwardTimeoutMilliseconds != nil && *rules.IdleTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleTCPPortForwardTimeoutMilliseconds != nil && *rules.IdleTCPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleUDPPortForwardTimeoutMilliseconds != nil && *rules.IdleUDPPortForwardTimeoutMilliseconds < 0) ||
 			(rules.IdleUDPPortForwardTimeoutMilliseconds != nil && *rules.IdleUDPPortForwardTimeoutMilliseconds < 0) ||
@@ -527,6 +540,14 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
 		trafficRules.RateLimits.CloseAfterExhausted = new(bool)
 	}
 	}
 
 
+	if trafficRules.RateLimits.EstablishmentReadBytesPerSecond == nil {
+		trafficRules.RateLimits.EstablishmentReadBytesPerSecond = new(int64)
+	}
+
+	if trafficRules.RateLimits.EstablishmentWriteBytesPerSecond == nil {
+		trafficRules.RateLimits.EstablishmentWriteBytesPerSecond = new(int64)
+	}
+
 	if trafficRules.RateLimits.UnthrottleFirstTunnelOnly == nil {
 	if trafficRules.RateLimits.UnthrottleFirstTunnelOnly == nil {
 		trafficRules.RateLimits.UnthrottleFirstTunnelOnly = new(bool)
 		trafficRules.RateLimits.UnthrottleFirstTunnelOnly = new(bool)
 	}
 	}
@@ -727,6 +748,14 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
 			trafficRules.RateLimits.CloseAfterExhausted = filteredRules.Rules.RateLimits.CloseAfterExhausted
 		}
 		}
 
 
+		if filteredRules.Rules.RateLimits.EstablishmentReadBytesPerSecond != nil {
+			trafficRules.RateLimits.EstablishmentReadBytesPerSecond = filteredRules.Rules.RateLimits.EstablishmentReadBytesPerSecond
+		}
+
+		if filteredRules.Rules.RateLimits.EstablishmentWriteBytesPerSecond != nil {
+			trafficRules.RateLimits.EstablishmentWriteBytesPerSecond = filteredRules.Rules.RateLimits.EstablishmentWriteBytesPerSecond
+		}
+
 		if filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly != nil {
 		if filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly != nil {
 			trafficRules.RateLimits.UnthrottleFirstTunnelOnly = filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly
 			trafficRules.RateLimits.UnthrottleFirstTunnelOnly = filteredRules.Rules.RateLimits.UnthrottleFirstTunnelOnly
 		}
 		}

+ 7 - 5
psiphon/server/tunnelServer.go

@@ -2280,9 +2280,9 @@ func (sshClient *sshClient) handleNewRandomStreamChannel(
 	// is available pre-handshake, albeit with additional restrictions.
 	// is available pre-handshake, albeit with additional restrictions.
 	//
 	//
 	// The random stream is subject to throttling in traffic rules; for
 	// The random stream is subject to throttling in traffic rules; for
-	// unthrottled liveness tests, set initial Read/WriteUnthrottledBytes as
-	// required. The random stream maximum count and response size cap
-	// mitigate clients abusing the facility to waste server resources.
+	// unthrottled liveness tests, set EstablishmentRead/WriteBytesPerSecond as
+	// required. The random stream maximum count and response size cap mitigate
+	// clients abusing the facility to waste server resources.
 	//
 	//
 	// Like all other channels, this channel type is handled asynchronously,
 	// Like all other channels, this channel type is handled asynchronously,
 	// so it's possible to run at any point in the tunnel lifecycle.
 	// so it's possible to run at any point in the tunnel lifecycle.
@@ -3130,7 +3130,8 @@ func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	if sshClient.throttledConn != nil {
 	if sshClient.throttledConn != nil {
 		// Any existing throttling state is reset.
 		// Any existing throttling state is reset.
 		sshClient.throttledConn.SetLimits(
 		sshClient.throttledConn.SetLimits(
-			sshClient.trafficRules.RateLimits.CommonRateLimits())
+			sshClient.trafficRules.RateLimits.CommonRateLimits(
+				sshClient.handshakeState.completed))
 	}
 	}
 
 
 	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
 	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
@@ -3224,7 +3225,8 @@ func (sshClient *sshClient) rateLimits() common.RateLimits {
 	sshClient.Lock()
 	sshClient.Lock()
 	defer sshClient.Unlock()
 	defer sshClient.Unlock()
 
 
-	return sshClient.trafficRules.RateLimits.CommonRateLimits()
+	return sshClient.trafficRules.RateLimits.CommonRateLimits(
+		sshClient.handshakeState.completed)
 }
 }
 
 
 func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
 func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {