|
|
@@ -20,10 +20,10 @@
|
|
|
package common
|
|
|
|
|
|
import (
|
|
|
- "io"
|
|
|
"net"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
|
|
|
"github.com/juju/ratelimit"
|
|
|
@@ -72,15 +72,22 @@ type ThrottledConn struct {
|
|
|
writeBytesPerSecond int64
|
|
|
closeAfterExhausted int32
|
|
|
readLock sync.Mutex
|
|
|
- throttledReader io.Reader
|
|
|
+ readRateLimiter *ratelimit.Bucket
|
|
|
+ readDelayTimer *time.Timer
|
|
|
writeLock sync.Mutex
|
|
|
- throttledWriter io.Writer
|
|
|
+ writeRateLimiter *ratelimit.Bucket
|
|
|
+ writeDelayTimer *time.Timer
|
|
|
+ isClosed int32
|
|
|
+ stopBroadcast chan struct{}
|
|
|
net.Conn
|
|
|
}
|
|
|
|
|
|
// NewThrottledConn initializes a new ThrottledConn.
|
|
|
func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
|
|
|
- throttledConn := &ThrottledConn{Conn: conn}
|
|
|
+ throttledConn := &ThrottledConn{
|
|
|
+ Conn: conn,
|
|
|
+ stopBroadcast: make(chan struct{}),
|
|
|
+ }
|
|
|
throttledConn.SetLimits(limits)
|
|
|
return throttledConn
|
|
|
}
|
|
|
@@ -124,13 +131,18 @@ func (conn *ThrottledConn) SetLimits(limits RateLimits) {
|
|
|
|
|
|
func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
|
|
|
|
|
|
- // A mutex is used to ensure conformance with net.Conn
|
|
|
- // concurrency semantics. The atomic.SwapInt64 and
|
|
|
- // subsequent assignment of throttledReader could be
|
|
|
- // a race condition with concurrent reads.
|
|
|
+ // A mutex is used to ensure conformance with net.Conn concurrency semantics.
|
|
|
+ // The atomic.SwapInt64 and subsequent assignment of readRateLimiter or
|
|
|
+ // readDelayTimer could be a race condition with concurrent reads.
|
|
|
conn.readLock.Lock()
|
|
|
defer conn.readLock.Unlock()
|
|
|
|
|
|
+ select {
|
|
|
+ case <-conn.stopBroadcast:
|
|
|
+ return 0, errors.TraceNew("throttled conn closed")
|
|
|
+ default:
|
|
|
+ }
|
|
|
+
|
|
|
// Use the base conn until the unthrottled count is
|
|
|
// exhausted. This is only an approximate enforcement
|
|
|
// since this read, or concurrent reads, could exceed
|
|
|
@@ -156,15 +168,40 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
|
|
|
// so a pending I/O throttle sleep may be skipped when
|
|
|
// the old and new rate are similar.
|
|
|
if rate == 0 {
|
|
|
- conn.throttledReader = conn.Conn
|
|
|
+ conn.readRateLimiter = nil
|
|
|
} else {
|
|
|
- conn.throttledReader = ratelimit.Reader(
|
|
|
- conn.Conn,
|
|
|
- ratelimit.NewBucketWithRate(float64(rate), rate))
|
|
|
+ conn.readRateLimiter =
|
|
|
+ ratelimit.NewBucketWithRate(float64(rate), rate)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := conn.Conn.Read(buffer)
|
|
|
+
|
|
|
+ // Sleep to enforce the rate limit. This is the same logic as implemented in
|
|
|
+ // ratelimit.Reader, but using a timer and a close signal instead of an
|
|
|
+ // uninterruptible time.Sleep.
|
|
|
+ //
|
|
|
+ // The readDelayTimer is always expired/stopped and drained after this code
|
|
|
+ // block and is ready to be Reset on the next call.
|
|
|
+
|
|
|
+ if n >= 0 && conn.readRateLimiter != nil {
|
|
|
+ sleepDuration := conn.readRateLimiter.Take(int64(n))
|
|
|
+ if sleepDuration > 0 {
|
|
|
+ if conn.readDelayTimer == nil {
|
|
|
+ conn.readDelayTimer = time.NewTimer(sleepDuration)
|
|
|
+ } else {
|
|
|
+ conn.readDelayTimer.Reset(sleepDuration)
|
|
|
+ }
|
|
|
+ select {
|
|
|
+ case <-conn.readDelayTimer.C:
|
|
|
+ case <-conn.stopBroadcast:
|
|
|
+ if !conn.readDelayTimer.Stop() {
|
|
|
+ <-conn.readDelayTimer.C
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- n, err := conn.throttledReader.Read(buffer)
|
|
|
return n, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
@@ -175,6 +212,12 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
|
|
|
conn.writeLock.Lock()
|
|
|
defer conn.writeLock.Unlock()
|
|
|
|
|
|
+ select {
|
|
|
+ case <-conn.stopBroadcast:
|
|
|
+ return 0, errors.TraceNew("throttled conn closed")
|
|
|
+ default:
|
|
|
+ }
|
|
|
+
|
|
|
if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
|
|
|
n, err := conn.Conn.Write(buffer)
|
|
|
atomic.AddInt64(&conn.writeUnthrottledBytes, -int64(n))
|
|
|
@@ -190,14 +233,44 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
|
|
|
|
|
|
if rate != -1 {
|
|
|
if rate == 0 {
|
|
|
- conn.throttledWriter = conn.Conn
|
|
|
+ conn.writeRateLimiter = nil
|
|
|
} else {
|
|
|
- conn.throttledWriter = ratelimit.Writer(
|
|
|
- conn.Conn,
|
|
|
- ratelimit.NewBucketWithRate(float64(rate), rate))
|
|
|
+ conn.writeRateLimiter =
|
|
|
+ ratelimit.NewBucketWithRate(float64(rate), rate)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- n, err := conn.throttledWriter.Write(buffer)
|
|
|
+ if len(buffer) >= 0 && conn.writeRateLimiter != nil {
|
|
|
+ sleepDuration := conn.writeRateLimiter.Take(int64(len(buffer)))
|
|
|
+ if sleepDuration > 0 {
|
|
|
+ if conn.writeDelayTimer == nil {
|
|
|
+ conn.writeDelayTimer = time.NewTimer(sleepDuration)
|
|
|
+ } else {
|
|
|
+ conn.writeDelayTimer.Reset(sleepDuration)
|
|
|
+ }
|
|
|
+ select {
|
|
|
+ case <-conn.writeDelayTimer.C:
|
|
|
+ case <-conn.stopBroadcast:
|
|
|
+ if !conn.writeDelayTimer.Stop() {
|
|
|
+ <-conn.writeDelayTimer.C
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := conn.Conn.Write(buffer)
|
|
|
+
|
|
|
return n, errors.Trace(err)
|
|
|
}
|
|
|
+
|
|
|
+func (conn *ThrottledConn) Close() error {
|
|
|
+
|
|
|
+ // Ensure close channel only called once.
|
|
|
+ if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ close(conn.stopBroadcast)
|
|
|
+
|
|
|
+ return errors.Trace(conn.Conn.Close())
|
|
|
+}
|