Преглед изворни кода

Modifications to ThrottledConn
- support changing rate limits on-the-fly.
- add ability to disconnect after initial unthrottled period.
- change upstream/downstream naming to read/write since
client and server both use ThrottledConn and have opposite
traffic flow perspectives.

Rod Hynes пре 9 година
родитељ
комит
02e9bd9a75
3 измењених фајлова са 156 додато и 83 уклоњено
  1. 123 57
      psiphon/common/throttled.go
  2. 29 22
      psiphon/common/throttled_test.go
  3. 4 4
      psiphon/server/config.go

+ 123 - 57
psiphon/common/throttled.go

@@ -20,8 +20,10 @@
 package common
 package common
 
 
 import (
 import (
+	"errors"
 	"io"
 	"io"
 	"net"
 	"net"
+	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
 	"github.com/Psiphon-Inc/ratelimit"
 	"github.com/Psiphon-Inc/ratelimit"
@@ -30,23 +32,27 @@ import (
 // RateLimits specify the rate limits for a ThrottledConn.
 // RateLimits specify the rate limits for a ThrottledConn.
 type RateLimits struct {
 type RateLimits struct {
 
 
-	// DownstreamUnlimitedBytes specifies the number of downstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	DownstreamUnlimitedBytes int64
+	// ReadUnthrottledBytes specifies the number of bytes to
+	// read, approximately, before starting rate limiting.
+	ReadUnthrottledBytes int64
 
 
-	// DownstreamBytesPerSecond specifies a rate limit for downstream
+	// ReadBytesPerSecond specifies a rate limit for read
 	// data transfer. The default, 0, is no limit.
 	// data transfer. The default, 0, is no limit.
-	DownstreamBytesPerSecond int64
+	ReadBytesPerSecond int64
 
 
-	// UpstreamUnlimitedBytes specifies the number of upstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	UpstreamUnlimitedBytes int64
+	// WriteUnthrottledBytes specifies the number of bytes to
+	// write, approximately, before starting rate limiting.
+	WriteUnthrottledBytes int64
 
 
-	// UpstreamBytesPerSecond specifies a rate limit for upstream
+	// WriteBytesPerSecond specifies a rate limit for write
 	// data transfer. The default, 0, is no limit.
 	// data transfer. The default, 0, is no limit.
-	UpstreamBytesPerSecond int64
+	WriteBytesPerSecond int64
+
+	// CloseAfterExhausted indicates that the underlying
+	// net.Conn should be closed once either the read or
+	// write unthrottled bytes have been exhausted. In this
+	// case, throttling is never applied.
+	CloseAfterExhausted bool
 }
 }
 
 
 // ThrottledConn wraps a net.Conn with read and write rate limiters.
 // ThrottledConn wraps a net.Conn with read and write rate limiters.
@@ -60,76 +66,136 @@ type ThrottledConn struct {
 	// Note: 64-bit ints used with atomic operations are at placed
 	// Note: 64-bit ints used with atomic operations are at placed
 	// at the start of struct to ensure 64-bit alignment.
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
-	unlimitedReadBytes  int64
-	unlimitedWriteBytes int64
-	limitingReads       int32
-	limitingWrites      int32
-	limitedReader       io.Reader
-	limitedWriter       io.Writer
+	readUnthrottledBytes  int64
+	readBytesPerSecond    int64
+	writeUnthrottledBytes int64
+	writeBytesPerSecond   int64
+	closeAfterExhausted   int32
+	readLock              sync.Mutex
+	throttledReader       io.Reader
+	writeLock             sync.Mutex
+	throttledWriter       io.Writer
 	net.Conn
 	net.Conn
 }
 }
 
 
 // NewThrottledConn initializes a new ThrottledConn.
 // NewThrottledConn initializes a new ThrottledConn.
 func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
 func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
+	throttledConn := &ThrottledConn{Conn: conn}
+	throttledConn.SetLimits(limits)
+	return throttledConn
+}
 
 
-	// When no limit is specified, the rate limited reader/writer
-	// is simply the base reader/writer.
-
-	var reader io.Reader
-	if limits.DownstreamBytesPerSecond == 0 {
-		reader = conn
-	} else {
-		reader = ratelimit.Reader(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limits.DownstreamBytesPerSecond),
-				limits.DownstreamBytesPerSecond))
+// SetLimits modifies the rate limits of an existing
+// ThrottledConn. It is safe to call SetLimits while
+// other goroutines are calling Read/Write. This function
+// will not block, and the new rate limits will be
+// applied within Read/Write, but not necessarily until
+// some futher I/O at previous rates.
+func (conn *ThrottledConn) SetLimits(limits RateLimits) {
+
+	// Using atomic instead of mutex to avoid blocking
+	// this function on throttled I/O in an ongoing
+	// read or write. Precise synchronized application
+	// of the rate limit values is not required.
+
+	// Negative rates are invalid and -1 is a special
+	// value to used to signal throttling initialized
+	// state. Silently normalize negative values to 0.
+	rate := limits.ReadBytesPerSecond
+	if rate < 0 {
+		rate = 0
 	}
 	}
+	atomic.StoreInt64(&conn.readBytesPerSecond, rate)
+	atomic.StoreInt64(&conn.readUnthrottledBytes, limits.ReadUnthrottledBytes)
 
 
-	var writer io.Writer
-	if limits.UpstreamBytesPerSecond == 0 {
-		writer = conn
-	} else {
-		writer = ratelimit.Writer(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limits.UpstreamBytesPerSecond),
-				limits.UpstreamBytesPerSecond))
+	rate = limits.WriteBytesPerSecond
+	if rate < 0 {
+		rate = 0
 	}
 	}
+	atomic.StoreInt64(&conn.writeBytesPerSecond, rate)
+	atomic.StoreInt64(&conn.writeUnthrottledBytes, limits.WriteUnthrottledBytes)
 
 
-	return &ThrottledConn{
-		Conn:                conn,
-		unlimitedReadBytes:  limits.DownstreamUnlimitedBytes,
-		limitingReads:       0,
-		limitedReader:       reader,
-		unlimitedWriteBytes: limits.UpstreamUnlimitedBytes,
-		limitingWrites:      0,
-		limitedWriter:       writer,
+	closeAfterExhausted := int32(0)
+	if limits.CloseAfterExhausted {
+		closeAfterExhausted = 1
 	}
 	}
+	atomic.StoreInt32(&conn.closeAfterExhausted, closeAfterExhausted)
 }
 }
 
 
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 
 
-	// Use the base reader until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingReads) == 0 {
+	// A mutex is used to ensure conformance with net.Conn
+	// concurrency semantics. The atomic.SwapInt64 and
+	// subsequent assignment of throttledReader could be
+	// a race condition with concurrent reads.
+	conn.readLock.Lock()
+	defer conn.readLock.Unlock()
+
+	// Use the base conn until the unthrottled count is
+	// exhausted. This is only an approximate enforcement
+	// since this read, or concurrent reads, could exceed
+	// the remaining count.
+	if atomic.LoadInt64(&conn.readUnthrottledBytes) > 0 {
 		n, err := conn.Conn.Read(buffer)
 		n, err := conn.Conn.Read(buffer)
-		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(n)) <= 0 {
-			atomic.StoreInt32(&conn.limitingReads, 1)
-		}
+		atomic.AddInt64(&conn.readUnthrottledBytes, -int64(n))
 		return n, err
 		return n, err
 	}
 	}
 
 
-	return conn.limitedReader.Read(buffer)
+	if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
+		conn.Conn.Close()
+		return 0, errors.New("throttled conn exhausted")
+	}
+
+	rate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
+
+	if rate != -1 {
+		// SetLimits has been called and a new rate limiter
+		// must be initialized. When no limit is specified,
+		// the reader/writer is simply the base conn.
+		// No state is retained from the previous rate limiter,
+		// so a pending I/O throttle sleep may be skipped when
+		// the old and new rate are similar.
+		if rate == 0 {
+			conn.throttledReader = conn.Conn
+		} else {
+			conn.throttledReader = ratelimit.Reader(
+				conn.Conn,
+				ratelimit.NewBucketWithRate(float64(rate), rate))
+		}
+	}
+
+	return conn.throttledReader.Read(buffer)
 }
 }
 
 
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 
 
-	// Use the base writer until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
+	// See comments in Read.
+
+	conn.writeLock.Lock()
+	defer conn.writeLock.Unlock()
+
+	if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
 		n, err := conn.Conn.Write(buffer)
 		n, err := conn.Conn.Write(buffer)
-		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(n)) <= 0 {
-			atomic.StoreInt32(&conn.limitingWrites, 1)
-		}
+		atomic.AddInt64(&conn.writeUnthrottledBytes, -int64(n))
 		return n, err
 		return n, err
 	}
 	}
 
 
-	return conn.limitedWriter.Write(buffer)
+	if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
+		conn.Conn.Close()
+		return 0, errors.New("throttled conn exhausted")
+	}
+
+	rate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
+
+	if rate != -1 {
+		if rate == 0 {
+			conn.throttledWriter = conn.Conn
+		} else {
+			conn.throttledWriter = ratelimit.Writer(
+				conn.Conn,
+				ratelimit.NewBucketWithRate(float64(rate), rate))
+		}
+	}
+
+	return conn.throttledWriter.Write(buffer)
 }
 }

+ 29 - 22
psiphon/common/throttled_test.go

@@ -40,40 +40,47 @@ const (
 func TestThrottledConn(t *testing.T) {
 func TestThrottledConn(t *testing.T) {
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 0,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    0,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   0,
 	})
 	})
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 5 * 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   5 * 1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    5 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   5 * 1024 * 1024,
 	})
 	})
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 2 * 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   2 * 1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    5 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   1024 * 1024,
 	})
 	})
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
-		DownstreamUnlimitedBytes: 0,
-		DownstreamBytesPerSecond: 1024 * 1024,
-		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   1024 * 1024,
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    2 * 1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   2 * 1024 * 1024,
+	})
+
+	run(t, RateLimits{
+		ReadUnthrottledBytes:  0,
+		ReadBytesPerSecond:    1024 * 1024,
+		WriteUnthrottledBytes: 0,
+		WriteBytesPerSecond:   1024 * 1024,
 	})
 	})
 
 
 	// This test takes > 1 min to run, so disabled for now
 	// This test takes > 1 min to run, so disabled for now
 	/*
 	/*
 		run(t, RateLimits{
 		run(t, RateLimits{
-			DownstreamUnlimitedBytes: 0,
-			DownstreamBytesPerSecond: 1024 * 1024 / 8,
-			UpstreamUnlimitedBytes:   0,
-			UpstreamBytesPerSecond:   1024 * 1024 / 8,
+			ReadUnthrottledBytes: 0,
+			ReadBytesPerSecond: 1024 * 1024 / 8,
+			WriteUnthrottledBytes:   0,
+			WriteBytesPerSecond:   1024 * 1024 / 8,
 		})
 		})
 	*/
 	*/
 }
 }
@@ -136,7 +143,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 
 
 	// Test: elapsed upload time must reflect rate limit
 	// Test: elapsed upload time must reflect rate limit
 
 
-	checkElapsedTime(t, testDataSize, rateLimits.UpstreamBytesPerSecond, monotime.Since(startTime))
+	checkElapsedTime(t, testDataSize, rateLimits.WriteBytesPerSecond, monotime.Since(startTime))
 
 
 	startTime = monotime.Now()
 	startTime = monotime.Now()
 
 
@@ -150,7 +157,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 
 
 	// Test: elapsed download time must reflect rate limit
 	// Test: elapsed download time must reflect rate limit
 
 
-	checkElapsedTime(t, testDataSize, rateLimits.DownstreamBytesPerSecond, monotime.Since(startTime))
+	checkElapsedTime(t, testDataSize, rateLimits.ReadBytesPerSecond, monotime.Since(startTime))
 }
 }
 
 
 func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time.Duration) {
 func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time.Duration) {

+ 4 - 4
psiphon/server/config.go

@@ -507,10 +507,10 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error
 	trafficRulesSet := &TrafficRulesSet{
 	trafficRulesSet := &TrafficRulesSet{
 		DefaultRules: TrafficRules{
 		DefaultRules: TrafficRules{
 			DefaultLimits: common.RateLimits{
 			DefaultLimits: common.RateLimits{
-				DownstreamUnlimitedBytes: 0,
-				DownstreamBytesPerSecond: 0,
-				UpstreamUnlimitedBytes:   0,
-				UpstreamBytesPerSecond:   0,
+				ReadUnthrottledBytes:  0,
+				ReadBytesPerSecond:    0,
+				WriteUnthrottledBytes: 0,
+				WriteBytesPerSecond:   0,
 			},
 			},
 			IdleTCPPortForwardTimeoutMilliseconds: 30000,
 			IdleTCPPortForwardTimeoutMilliseconds: 30000,
 			IdleUDPPortForwardTimeoutMilliseconds: 30000,
 			IdleUDPPortForwardTimeoutMilliseconds: 30000,