Просмотр исходного кода

Fix: ThrottledConn Read/Write block shutdown

- Large sleeps of many seconds are possible with
  very low rate limits

- Change sleeps to timers with a select that
  also monitors a close signal
Rod Hynes 6 лет назад
Родитель
Сommit
9bf6b0a5e3
2 измененных файлов с 185 добавлено и 26 удалено
  1. 91 18
      psiphon/common/throttled.go
  2. 94 8
      psiphon/common/throttled_test.go

+ 91 - 18
psiphon/common/throttled.go

@@ -20,10 +20,10 @@
 package common
 
 import (
-	"io"
 	"net"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/juju/ratelimit"
@@ -72,15 +72,22 @@ type ThrottledConn struct {
 	writeBytesPerSecond   int64
 	closeAfterExhausted   int32
 	readLock              sync.Mutex
-	throttledReader       io.Reader
+	readRateLimiter       *ratelimit.Bucket
+	readDelayTimer        *time.Timer
 	writeLock             sync.Mutex
-	throttledWriter       io.Writer
+	writeRateLimiter      *ratelimit.Bucket
+	writeDelayTimer       *time.Timer
+	isClosed              int32
+	stopBroadcast         chan struct{}
 	net.Conn
 }
 
 // NewThrottledConn initializes a new ThrottledConn.
 func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
-	throttledConn := &ThrottledConn{Conn: conn}
+	throttledConn := &ThrottledConn{
+		Conn:          conn,
+		stopBroadcast: make(chan struct{}),
+	}
 	throttledConn.SetLimits(limits)
 	return throttledConn
 }
@@ -124,13 +131,18 @@ func (conn *ThrottledConn) SetLimits(limits RateLimits) {
 
 func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 
-	// A mutex is used to ensure conformance with net.Conn
-	// concurrency semantics. The atomic.SwapInt64 and
-	// subsequent assignment of throttledReader could be
-	// a race condition with concurrent reads.
+	// A mutex is used to ensure conformance with net.Conn concurrency semantics.
+	// The atomic.SwapInt64 and subsequent assignment of readRateLimiter or
+	// readDelayTimer could be a race condition with concurrent reads.
 	conn.readLock.Lock()
 	defer conn.readLock.Unlock()
 
+	select {
+	case <-conn.stopBroadcast:
+		return 0, errors.TraceNew("throttled conn closed")
+	default:
+	}
+
 	// Use the base conn until the unthrottled count is
 	// exhausted. This is only an approximate enforcement
 	// since this read, or concurrent reads, could exceed
@@ -156,15 +168,40 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 		// so a pending I/O throttle sleep may be skipped when
 		// the old and new rate are similar.
 		if rate == 0 {
-			conn.throttledReader = conn.Conn
+			conn.readRateLimiter = nil
 		} else {
-			conn.throttledReader = ratelimit.Reader(
-				conn.Conn,
-				ratelimit.NewBucketWithRate(float64(rate), rate))
+			conn.readRateLimiter =
+				ratelimit.NewBucketWithRate(float64(rate), rate)
+		}
+	}
+
+	n, err := conn.Conn.Read(buffer)
+
+	// Sleep to enforce the rate limit. This is the same logic as implemented in
+	// ratelimit.Reader, but using a timer and a close signal instead of an
+	// uninterruptible time.Sleep.
+	//
+	// The readDelayTimer is always expired/stopped and drained after this code
+	// block and is ready to be Reset on the next call.
+
+	if n >= 0 && conn.readRateLimiter != nil {
+		sleepDuration := conn.readRateLimiter.Take(int64(n))
+		if sleepDuration > 0 {
+			if conn.readDelayTimer == nil {
+				conn.readDelayTimer = time.NewTimer(sleepDuration)
+			} else {
+				conn.readDelayTimer.Reset(sleepDuration)
+			}
+			select {
+			case <-conn.readDelayTimer.C:
+			case <-conn.stopBroadcast:
+				if !conn.readDelayTimer.Stop() {
+					<-conn.readDelayTimer.C
+				}
+			}
 		}
 	}
 
-	n, err := conn.throttledReader.Read(buffer)
 	return n, errors.Trace(err)
 }
 
@@ -175,6 +212,12 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 	conn.writeLock.Lock()
 	defer conn.writeLock.Unlock()
 
+	select {
+	case <-conn.stopBroadcast:
+		return 0, errors.TraceNew("throttled conn closed")
+	default:
+	}
+
 	if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
 		n, err := conn.Conn.Write(buffer)
 		atomic.AddInt64(&conn.writeUnthrottledBytes, -int64(n))
@@ -190,14 +233,44 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 
 	if rate != -1 {
 		if rate == 0 {
-			conn.throttledWriter = conn.Conn
+			conn.writeRateLimiter = nil
 		} else {
-			conn.throttledWriter = ratelimit.Writer(
-				conn.Conn,
-				ratelimit.NewBucketWithRate(float64(rate), rate))
+			conn.writeRateLimiter =
+				ratelimit.NewBucketWithRate(float64(rate), rate)
 		}
 	}
 
-	n, err := conn.throttledWriter.Write(buffer)
+	if len(buffer) >= 0 && conn.writeRateLimiter != nil {
+		sleepDuration := conn.writeRateLimiter.Take(int64(len(buffer)))
+		if sleepDuration > 0 {
+			if conn.writeDelayTimer == nil {
+				conn.writeDelayTimer = time.NewTimer(sleepDuration)
+			} else {
+				conn.writeDelayTimer.Reset(sleepDuration)
+			}
+			select {
+			case <-conn.writeDelayTimer.C:
+			case <-conn.stopBroadcast:
+				if !conn.writeDelayTimer.Stop() {
+					<-conn.writeDelayTimer.C
+				}
+			}
+		}
+	}
+
+	n, err := conn.Conn.Write(buffer)
+
 	return n, errors.Trace(err)
 }
+
+func (conn *ThrottledConn) Close() error {
+
+	// Ensure close channel only called once.
+	if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
+		return nil
+	}
+
+	close(conn.stopBroadcast)
+
+	return errors.Trace(conn.Conn.Close())
+}

+ 94 - 8
psiphon/common/throttled_test.go

@@ -35,37 +35,37 @@ const (
 	testDataSize  = 10 * 1024 * 1024 // 10 MB
 )
 
-func TestThrottledConn(t *testing.T) {
+func TestThrottledConnRates(t *testing.T) {
 
-	run(t, RateLimits{
+	runRateLimitsTest(t, RateLimits{
 		ReadUnthrottledBytes:  0,
 		ReadBytesPerSecond:    0,
 		WriteUnthrottledBytes: 0,
 		WriteBytesPerSecond:   0,
 	})
 
-	run(t, RateLimits{
+	runRateLimitsTest(t, RateLimits{
 		ReadUnthrottledBytes:  0,
 		ReadBytesPerSecond:    5 * 1024 * 1024,
 		WriteUnthrottledBytes: 0,
 		WriteBytesPerSecond:   5 * 1024 * 1024,
 	})
 
-	run(t, RateLimits{
+	runRateLimitsTest(t, RateLimits{
 		ReadUnthrottledBytes:  0,
 		ReadBytesPerSecond:    5 * 1024 * 1024,
 		WriteUnthrottledBytes: 0,
 		WriteBytesPerSecond:   1024 * 1024,
 	})
 
-	run(t, RateLimits{
+	runRateLimitsTest(t, RateLimits{
 		ReadUnthrottledBytes:  0,
 		ReadBytesPerSecond:    2 * 1024 * 1024,
 		WriteUnthrottledBytes: 0,
 		WriteBytesPerSecond:   2 * 1024 * 1024,
 	})
 
-	run(t, RateLimits{
+	runRateLimitsTest(t, RateLimits{
 		ReadUnthrottledBytes:  0,
 		ReadBytesPerSecond:    1024 * 1024,
 		WriteUnthrottledBytes: 0,
@@ -74,7 +74,7 @@ func TestThrottledConn(t *testing.T) {
 
 	// This test takes > 1 min to run, so disabled for now
 	/*
-		run(t, RateLimits{
+		runRateLimitsTest(t, RateLimits{
 			ReadUnthrottledBytes: 0,
 			ReadBytesPerSecond: 1024 * 1024 / 8,
 			WriteUnthrottledBytes:   0,
@@ -83,7 +83,7 @@ func TestThrottledConn(t *testing.T) {
 	*/
 }
 
-func run(t *testing.T, rateLimits RateLimits) {
+func runRateLimitsTest(t *testing.T, rateLimits RateLimits) {
 
 	// Run a local HTTP server which serves large chunks of data
 
@@ -193,3 +193,89 @@ func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time
 		t.Errorf("unexpected duration: %s > %s", duration, ceilingElapsedTime)
 	}
 }
+
+func TestThrottledConnClose(t *testing.T) {
+
+	rateLimits := RateLimits{
+		ReadBytesPerSecond:  1,
+		WriteBytesPerSecond: 1,
+	}
+
+	n := 4
+	b := make([]byte, n+1)
+
+	throttledConn := NewThrottledConn(&testConn{}, rateLimits)
+
+	now := time.Now()
+	_, err := throttledConn.Read(b)
+	elapsed := time.Since(now)
+	if err != nil || elapsed < time.Duration(n)*time.Second {
+		t.Errorf("unexpected interrupted read: %s, %v", elapsed, err)
+	}
+
+	now = time.Now()
+	go func() {
+		time.Sleep(500 * time.Millisecond)
+		throttledConn.Close()
+	}()
+	_, err = throttledConn.Read(b)
+	elapsed = time.Since(now)
+	if elapsed > 1*time.Second {
+		t.Errorf("unexpected uninterrupted read: %s, %v", elapsed, err)
+	}
+
+	throttledConn = NewThrottledConn(&testConn{}, rateLimits)
+
+	now = time.Now()
+	_, err = throttledConn.Write(b)
+	elapsed = time.Since(now)
+	if err != nil || elapsed < time.Duration(n)*time.Second {
+		t.Errorf("unexpected interrupted write: %s, %v", elapsed, err)
+	}
+
+	now = time.Now()
+	go func() {
+		time.Sleep(500 * time.Millisecond)
+		throttledConn.Close()
+	}()
+	_, err = throttledConn.Write(b)
+	elapsed = time.Since(now)
+	if elapsed > 1*time.Second {
+		t.Errorf("unexpected uninterrupted write: %s, %v", elapsed, err)
+	}
+}
+
+type testConn struct {
+}
+
+func (conn *testConn) Read(b []byte) (n int, err error) {
+	return len(b), nil
+}
+
+func (conn *testConn) Write(b []byte) (n int, err error) {
+	return len(b), nil
+}
+
+func (conn *testConn) Close() error {
+	return nil
+}
+
+func (conn *testConn) LocalAddr() net.Addr {
+	return nil
+}
+
+func (conn *testConn) RemoteAddr() net.Addr {
+	return nil
+}
+
+func (conn *testConn) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (conn *testConn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (conn *testConn) SetWriteDeadline(t time.Time) error {
+	return nil
+}