Эх сурвалжийг харах

Added rate limiting features
* Per-protocol rate limits
* Unlimited allowance before rate limiting activated

Rod Hynes 9 жил өмнө
parent
commit
6d23de7e6e

+ 50 - 11
psiphon/net.go

@@ -63,6 +63,7 @@ import (
 	"os"
 	"os"
 	"reflect"
 	"reflect"
 	"sync"
 	"sync"
+	"sync/atomic"
 	"time"
 	"time"
 
 
 	"github.com/Psiphon-Inc/dns"
 	"github.com/Psiphon-Inc/dns"
@@ -684,18 +685,30 @@ func (conn *IdleTimeoutConn) Write(buffer []byte) (int, error) {
 }
 }
 
 
 // ThrottledConn wraps a net.Conn with read and write rate limiters.
 // ThrottledConn wraps a net.Conn with read and write rate limiters.
-// Rates are specified as bytes per second. The underlying rate limiter
-// uses the token bucket algorithm to calculate delay times for read
-// and write operations. Specify limit values of 0 set no limit.
+// Rates are specified as bytes per second. Optional unlimited byte
+// counts allow for a number of bytes to read or write before
+// applying rate limiting. Specify limit values of 0 to set no rate
+// limit (unlimited counts are ignored in this case).
+// The underlying rate limiter uses the token bucket algorithm to
+// calculate delay times for read and write operations.
 type ThrottledConn struct {
 type ThrottledConn struct {
 	net.Conn
 	net.Conn
-	reader io.Reader
-	writer io.Writer
+	unlimitedReadBytes  int64
+	limitingReads       int32
+	limitedReader       io.Reader
+	unlimitedWriteBytes int64
+	limitingWrites      int32
+	limitedWriter       io.Writer
 }
 }
 
 
+// NewThrottledConn initializes a new ThrottledConn.
 func NewThrottledConn(
 func NewThrottledConn(
 	conn net.Conn,
 	conn net.Conn,
-	limitReadBytesPerSecond, limitWriteBytesPerSecond int64) *ThrottledConn {
+	unlimitedReadBytes, limitReadBytesPerSecond,
+	unlimitedWriteBytes, limitWriteBytesPerSecond int64) *ThrottledConn {
+
+	// When no limit is specified, the rate limited reader/writer
+	// is simply the base reader/writer.
 
 
 	var reader io.Reader
 	var reader io.Reader
 	if limitReadBytesPerSecond == 0 {
 	if limitReadBytesPerSecond == 0 {
@@ -705,6 +718,7 @@ func NewThrottledConn(
 			ratelimit.NewBucketWithRate(
 			ratelimit.NewBucketWithRate(
 				float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
 				float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
 	}
 	}
+
 	var writer io.Writer
 	var writer io.Writer
 	if limitWriteBytesPerSecond == 0 {
 	if limitWriteBytesPerSecond == 0 {
 		writer = conn
 		writer = conn
@@ -713,17 +727,42 @@ func NewThrottledConn(
 			ratelimit.NewBucketWithRate(
 			ratelimit.NewBucketWithRate(
 				float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
 				float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
 	}
 	}
+
 	return &ThrottledConn{
 	return &ThrottledConn{
-		Conn:   conn,
-		reader: reader,
-		writer: writer,
+		Conn:                conn,
+		unlimitedReadBytes:  unlimitedReadBytes,
+		limitingReads:       0,
+		limitedReader:       reader,
+		unlimitedWriteBytes: unlimitedWriteBytes,
+		limitingWrites:      0,
+		limitedWriter:       writer,
 	}
 	}
 }
 }
 
 
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
-	return conn.reader.Read(buffer)
+
+	// Use the base reader until the unlimited count is exhausted.
+	if atomic.LoadInt32(&conn.limitingReads) == 0 {
+		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
+			atomic.StoreInt32(&conn.limitingReads, 1)
+		} else {
+			return conn.Read(buffer)
+		}
+	}
+
+	return conn.limitedReader.Read(buffer)
 }
 }
 
 
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
-	return conn.writer.Write(buffer)
+
+	// Use the base writer until the unlimited count is exhausted.
+	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
+		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
+			atomic.StoreInt32(&conn.limitingWrites, 1)
+		} else {
+			return conn.Write(buffer)
+		}
+	}
+
+	return conn.limitedWriter.Write(buffer)
 }
 }

+ 55 - 16
psiphon/server/config.go

@@ -217,20 +217,41 @@ type Config struct {
 	LoadMonitorPeriodSeconds int
 	LoadMonitorPeriodSeconds int
 }
 }
 
 
+// RateLimits specify the rate limits for tunneled data transfer
+// between an individual client and the server.
+type RateLimits struct {
+
+	// DownstreamUnlimitedBytes specifies the number of downstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	DownstreamUnlimitedBytes int64
+
+	// DownstreamBytesPerSecond specifies a rate limit for downstream
+	// data transfer. The default, 0, is no limit.
+	DownstreamBytesPerSecond int
+
+	// UpstreamUnlimitedBytes specifies the number of upstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	UpstreamUnlimitedBytes int64
+
+	// UpstreamBytesPerSecond specifies a rate limit for upstream
+	// data transfer. The default, 0, is no limit.
+	UpstreamBytesPerSecond int
+}
+
 // TrafficRules specify the limits placed on client traffic.
 // TrafficRules specify the limits placed on client traffic.
 type TrafficRules struct {
 type TrafficRules struct {
+	// DefaultRateLimitsare the rate limits to be applied when
+	// no protocol-specific rates are set.
+	DefaultRateLimits RateLimits
 
 
-	// LimitDownstreamBytesPerSecond specifies a rate limit for
-	// downstream data transfer between a single client and the
-	// server.
-	// The default, 0, is no rate limit.
-	LimitDownstreamBytesPerSecond int
-
-	// LimitUpstreamBytesPerSecond specifies a rate limit for
-	// upstream data transfer between a single client and the
-	// server.
-	// The default, 0, is no rate limit.
-	LimitUpstreamBytesPerSecond int
+	// ProtocolRateLimits specifies the rate limits for particular
+	// tunnel protocols. The key for each rate limit entry is one
+	// or more space delimited Psiphon tunnel protocol names. Valid
+	// tunnel protocols includes the same list as for
+	// TunnelProtocolPorts.
+	ProtocolRateLimits map[string]RateLimits
 
 
 	// IdlePortForwardTimeoutMilliseconds is the timeout period
 	// IdlePortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
 	// after which idle (no bytes flowing in either direction)
@@ -292,12 +313,12 @@ func (config *Config) UseFail2Ban() bool {
 }
 }
 
 
 // GetTrafficRules looks up the traffic rules for the specified country. If there
 // GetTrafficRules looks up the traffic rules for the specified country. If there
-// are no RegionalTrafficRules for the country, DefaultTrafficRules are returned.
-func (config *Config) GetTrafficRules(targetCountryCode string) TrafficRules {
+// are no RegionalTrafficRules for the country, DefaultTrafficRules are used.
+func (config *Config) GetTrafficRules(clientCountryCode string) TrafficRules {
 	// TODO: faster lookup?
 	// TODO: faster lookup?
 	for countryCodes, trafficRules := range config.RegionalTrafficRules {
 	for countryCodes, trafficRules := range config.RegionalTrafficRules {
 		for _, countryCode := range strings.Split(countryCodes, " ") {
 		for _, countryCode := range strings.Split(countryCodes, " ") {
-			if countryCode == targetCountryCode {
+			if countryCode == clientCountryCode {
 				return trafficRules
 				return trafficRules
 			}
 			}
 		}
 		}
@@ -305,6 +326,20 @@ func (config *Config) GetTrafficRules(targetCountryCode string) TrafficRules {
 	return config.DefaultTrafficRules
 	return config.DefaultTrafficRules
 }
 }
 
 
+// GetRateLimits looks up the rate limits for the specified tunnel protocol.
+// If there are no ProtocolRateLimits for the protocol, DefaultRateLimits are used.
+func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) RateLimits {
+	// TODO: faster lookup?
+	for tunnelProtocols, rateLimits := range rules.ProtocolRateLimits {
+		for _, tunnelProtocol := range strings.Split(tunnelProtocols, " ") {
+			if tunnelProtocol == clientTunnelProtocol {
+				return rateLimits
+			}
+		}
+	}
+	return rules.DefaultRateLimits
+}
+
 // LoadConfig loads and validates a JSON encoded server config. If more than one
 // LoadConfig loads and validates a JSON encoded server config. If more than one
 // JSON config is specified, then all are loaded and values are merged together,
 // JSON config is specified, then all are loaded and values are merged together,
 // in order. Multiple configs allows for use cases like storing static, server-specific
 // in order. Multiple configs allows for use cases like storing static, server-specific
@@ -525,8 +560,12 @@ func GenerateConfig(serverIPaddress string) ([]byte, []byte, error) {
 		MeekProhibitedHeaders:          nil,
 		MeekProhibitedHeaders:          nil,
 		MeekProxyForwardedForHeaders:   []string{"X-Forwarded-For"},
 		MeekProxyForwardedForHeaders:   []string{"X-Forwarded-For"},
 		DefaultTrafficRules: TrafficRules{
 		DefaultTrafficRules: TrafficRules{
-			LimitDownstreamBytesPerSecond:      0,
-			LimitUpstreamBytesPerSecond:        0,
+			DefaultRateLimits: RateLimits{
+				DownstreamUnlimitedBytes: 0,
+				DownstreamBytesPerSecond: 0,
+				UpstreamUnlimitedBytes:   0,
+				UpstreamBytesPerSecond:   0,
+			},
 			IdlePortForwardTimeoutMilliseconds: 0,
 			IdlePortForwardTimeoutMilliseconds: 0,
 			MaxTCPPortForwardCount:             256,
 			MaxTCPPortForwardCount:             256,
 			MaxUDPPortForwardCount:             32,
 			MaxUDPPortForwardCount:             32,

+ 5 - 2
psiphon/server/tunnelServer.go

@@ -365,10 +365,13 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
 
+	rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
 	clientConn = psiphon.NewThrottledConn(
 	clientConn = psiphon.NewThrottledConn(
 		clientConn,
 		clientConn,
-		int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
-		int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
+		rateLimits.DownstreamUnlimitedBytes,
+		int64(rateLimits.DownstreamBytesPerSecond),
+		rateLimits.UpstreamUnlimitedBytes,
+		int64(rateLimits.UpstreamBytesPerSecond))
 
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
 	// respect shutdownBroadcast and implement a specific handshake timeout.