Browse Source

Add tunnel data transfer burst monitoring

- Also change the condition for tunnel/connection
  activity (via ActivityMonitoredConn) so that activity
  is recorded when a read/write fails but still transfers
  some bytes.
Rod Hynes 5 years ago
parent
commit
dd1c703d7a

+ 287 - 28
psiphon/common/net.go

@@ -261,22 +261,25 @@ func (entry *LRUConnsEntry) Touch() {
 	entry.lruConns.list.MoveToFront(entry.element)
 }
 
-// ActivityMonitoredConn wraps a net.Conn, adding logic to deal with
-// events triggered by I/O activity.
+// ActivityMonitoredConn wraps a net.Conn, adding logic to deal with events
+// triggered by I/O activity.
 //
-// When an inactivity timeout is specified, the network I/O will
-// timeout after the specified period of read inactivity. Optionally,
-// for the purpose of inactivity only, ActivityMonitoredConn will also
-// consider the connection active when data is written to it.
+// ActivityMonitoredConn uses lock-free concurrency synchronization, avoiding an
+// additional mutex resource, making it suitable for wrapping many net.Conns
+// (e.g., each Psiphon port forward).
 //
-// When a LRUConnsEntry is specified, then the LRU entry is promoted on
-// either a successful read or write.
+// When an inactivity timeout is specified, the network I/O will timeout after
+// the specified period of read inactivity. Optionally, for the purpose of
+// inactivity only, ActivityMonitoredConn will also consider the connection
+// active when data is written to it.
 //
-// When an ActivityUpdater is set, then its UpdateActivity method is
-// called on each read and write with the number of bytes transferred.
-// The durationNanoseconds, which is the time since the last read, is
-// reported only on reads.
+// When a LRUConnsEntry is specified, then the LRU entry is promoted on either
+// a successful read or write.
 //
+// When an ActivityUpdater is set, then its UpdateProgress method is called on
+// each read and write with the number of bytes transferred. The
+// durationNanoseconds, which is the time since the last read, is reported
+// only on reads.
 type ActivityMonitoredConn struct {
 	// Note: 64-bit ints used with atomic operations are placed
 	// at the start of struct to ensure 64-bit alignment.
@@ -292,8 +295,8 @@ type ActivityMonitoredConn struct {
 }
 
 // ActivityUpdater defines an interface for receiving updates for
-// ActivityMonitoredConn activity. Values passed to UpdateProgress are
-// bytes transferred and conn duration since the previous UpdateProgress.
+// ActivityMonitoredConn activity. Values passed to UpdateProgress are bytes
+// transferred and conn duration since the previous UpdateProgress.
 type ActivityUpdater interface {
 	UpdateProgress(bytesRead, bytesWritten int64, durationNanoseconds int64)
 }
@@ -313,6 +316,9 @@ func NewActivityMonitoredConn(
 		}
 	}
 
+	// The "monotime" package is still used here as its time value is an int64,
+	// which is compatible with atomic operations.
+
 	now := int64(monotime.Now())
 
 	return &ActivityMonitoredConn{
@@ -327,27 +333,22 @@ func NewActivityMonitoredConn(
 	}, nil
 }
 
-// GetStartTime gets the time when the ActivityMonitoredConn was
-// initialized. Reported time is UTC.
+// GetStartTime gets the time when the ActivityMonitoredConn was initialized.
+// Reported time is UTC.
 func (conn *ActivityMonitoredConn) GetStartTime() time.Time {
 	return conn.realStartTime.UTC()
 }
 
-// GetActiveDuration returns the time elapsed between the initialization
-// of the ActivityMonitoredConn and the last Read. Only reads are used
-// for this calculation since writes may succeed locally due to buffering.
+// GetActiveDuration returns the time elapsed between the initialization of
+// the ActivityMonitoredConn and the last Read. Only reads are used for this
+// calculation since writes may succeed locally due to buffering.
 func (conn *ActivityMonitoredConn) GetActiveDuration() time.Duration {
 	return time.Duration(atomic.LoadInt64(&conn.lastReadActivityTime) - conn.monotonicStartTime)
 }
 
-// GetLastActivityMonotime returns the arbitrary monotonic time of the last Read.
-func (conn *ActivityMonitoredConn) GetLastActivityMonotime() monotime.Time {
-	return monotime.Time(atomic.LoadInt64(&conn.lastReadActivityTime))
-}
-
 func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 	n, err := conn.Conn.Read(buffer)
-	if err == nil {
+	if n > 0 {
 
 		if conn.inactivityTimeout > 0 {
 			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
@@ -376,7 +377,7 @@ func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 
 func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 	n, err := conn.Conn.Write(buffer)
-	if err == nil && conn.activeOnWrite {
+	if n > 0 && conn.activeOnWrite {
 
 		if conn.inactivityTimeout > 0 {
 			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
@@ -398,8 +399,8 @@ func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 	return n, err
 }
 
-// IsClosed implements the Closer iterface. The return value
-// indicates whether the underlying conn has been closed.
+// IsClosed implements the Closer interface. The return value indicates whether
+// the underlying conn has been closed.
 func (conn *ActivityMonitoredConn) IsClosed() bool {
 	closer, ok := conn.Conn.(Closer)
 	if !ok {
@@ -408,6 +409,264 @@ func (conn *ActivityMonitoredConn) IsClosed() bool {
 	return closer.IsClosed()
 }
 
+// BurstMonitoredConn wraps a net.Conn and monitors for data transfer bursts.
+// Upstream (read) and downstream (write) bursts are tracked independently.
+//
+// A burst is defined as a transfer of at least "threshold" bytes, across
+// multiple I/O operations where the delay between operations does not exceed
+// "deadline". Both a non-zero deadline and threshold must be set to enable
+// monitoring. Four bursts are reported: the first, the last, the min (by
+// rate) and max.
+//
+// The reported rates will be more accurate for larger data transfers,
+// especially for higher transfer rates. Tune the deadline/threshold as
+// required. The threshold should be set to account for buffering (e.g., the
+// local host socket send/receive buffer) but this is not enforced by
+// BurstMonitoredConn.
+//
+// Close must be called to complete any outstanding bursts. For complete
+// results, call GetMetrics only after Close is called.
+//
+// Overhead: BurstMonitoredConn adds mutexes but does not use timers.
+type BurstMonitoredConn struct {
+	net.Conn
+	upstreamDeadline         time.Duration
+	upstreamThresholdBytes   int64
+	downstreamDeadline       time.Duration
+	downstreamThresholdBytes int64
+
+	readMutex            sync.Mutex
+	currentUpstreamBurst burst
+	upstreamBursts       burstHistory
+
+	writeMutex             sync.Mutex
+	currentDownstreamBurst burst
+	downstreamBursts       burstHistory
+}
+
+// NewBurstMonitoredConn creates a new BurstMonitoredConn.
+func NewBurstMonitoredConn(
+	conn net.Conn,
+	upstreamDeadline time.Duration,
+	upstreamThresholdBytes int64,
+	downstreamDeadline time.Duration,
+	downstreamThresholdBytes int64) *BurstMonitoredConn {
+
+	return &BurstMonitoredConn{
+		Conn:                     conn,
+		upstreamDeadline:         upstreamDeadline,
+		upstreamThresholdBytes:   upstreamThresholdBytes,
+		downstreamDeadline:       downstreamDeadline,
+		downstreamThresholdBytes: downstreamThresholdBytes,
+	}
+}
+
+type burst struct {
+	startTime    time.Time
+	lastByteTime time.Time
+	bytes        int64
+}
+
+func (b *burst) isZero() bool {
+	return b.startTime.IsZero()
+}
+
+func (b *burst) offset(baseTime time.Time) time.Duration {
+	offset := b.startTime.Sub(baseTime)
+	if offset <= 0 {
+		return 0
+	}
+	return offset
+}
+
+func (b *burst) duration() time.Duration {
+	duration := b.lastByteTime.Sub(b.startTime)
+	if duration <= 0 {
+		return 0
+	}
+	return duration
+}
+
+func (b *burst) rate() int64 {
+	return int64(
+		(float64(b.bytes) * float64(time.Second)) /
+			float64(b.duration()))
+}
+
+type burstHistory struct {
+	first burst
+	last  burst
+	min   burst
+	max   burst
+}
+
+func (conn *BurstMonitoredConn) Read(buffer []byte) (int, error) {
+
+	start := time.Now()
+	n, err := conn.Conn.Read(buffer)
+	end := time.Now()
+
+	if n > 0 &&
+		conn.upstreamDeadline > 0 && conn.upstreamThresholdBytes > 0 {
+
+		conn.readMutex.Lock()
+		conn.updateBurst(
+			start,
+			end,
+			int64(n),
+			conn.upstreamDeadline,
+			conn.upstreamThresholdBytes,
+			&conn.currentUpstreamBurst,
+			&conn.upstreamBursts)
+		conn.readMutex.Unlock()
+	}
+
+	// Note: no context error to preserve error type
+	return n, err
+}
+
+func (conn *BurstMonitoredConn) Write(buffer []byte) (int, error) {
+
+	start := time.Now()
+	n, err := conn.Conn.Write(buffer)
+	end := time.Now()
+
+	if n > 0 &&
+		conn.downstreamDeadline > 0 && conn.downstreamThresholdBytes > 0 {
+
+		conn.writeMutex.Lock()
+		conn.updateBurst(
+			start,
+			end,
+			int64(n),
+			conn.downstreamDeadline,
+			conn.downstreamThresholdBytes,
+			&conn.currentDownstreamBurst,
+			&conn.downstreamBursts)
+		conn.writeMutex.Unlock()
+	}
+
+	// Note: no context error to preserve error type
+	return n, err
+}
+
+func (conn *BurstMonitoredConn) Close() error {
+	err := conn.Conn.Close()
+
+	conn.readMutex.Lock()
+	conn.endBurst(
+		conn.upstreamThresholdBytes,
+		&conn.currentUpstreamBurst,
+		&conn.upstreamBursts)
+	conn.readMutex.Unlock()
+
+	conn.writeMutex.Lock()
+	conn.endBurst(
+		conn.downstreamThresholdBytes,
+		&conn.currentDownstreamBurst,
+		&conn.downstreamBursts)
+	conn.writeMutex.Unlock()
+
+	// Note: no context error to preserve error type
+	return err
+}
+
+// GetMetrics returns log fields with burst metrics for the first, last, min
+// (by rate), and max bursts for this conn. Time/duration values are reported
+// in milliseconds.
+func (conn *BurstMonitoredConn) GetMetrics(baseTime time.Time) LogFields {
+	logFields := make(LogFields)
+
+	addFields := func(prefix string, burst *burst) {
+		if burst.isZero() {
+			return
+		}
+		logFields[prefix+"offset"] = int64(burst.offset(baseTime) / time.Millisecond)
+		logFields[prefix+"duration"] = int64(burst.duration() / time.Millisecond)
+		logFields[prefix+"bytes"] = burst.bytes
+		logFields[prefix+"rate"] = burst.rate()
+	}
+
+	addHistory := func(prefix string, history *burstHistory) {
+		addFields(prefix+"first_", &history.first)
+		addFields(prefix+"last_", &history.last)
+		addFields(prefix+"min_", &history.min)
+		addFields(prefix+"max_", &history.max)
+	}
+
+	addHistory("burst_upstream_", &conn.upstreamBursts)
+	addHistory("burst_downstream_", &conn.downstreamBursts)
+
+	return logFields
+}
+
+func (conn *BurstMonitoredConn) updateBurst(
+	operationStart time.Time,
+	operationEnd time.Time,
+	operationBytes int64,
+	deadline time.Duration,
+	thresholdBytes int64,
+	currentBurst *burst,
+	history *burstHistory) {
+
+	// Assumes the associated mutex is locked.
+
+	if currentBurst.isZero() {
+		currentBurst.startTime = operationStart
+		currentBurst.lastByteTime = operationEnd
+		currentBurst.bytes = operationBytes
+
+	} else {
+
+		if operationStart.Sub(currentBurst.lastByteTime) >
+			deadline {
+
+			conn.endBurst(thresholdBytes, currentBurst, history)
+			currentBurst.startTime = operationStart
+		}
+
+		currentBurst.lastByteTime = operationEnd
+		currentBurst.bytes += operationBytes
+	}
+
+}
+
+func (conn *BurstMonitoredConn) endBurst(
+	thresholdBytes int64,
+	currentBurst *burst,
+	history *burstHistory) {
+
+	// Assumes the associated mutex is locked.
+
+	if currentBurst.isZero() {
+		return
+	}
+
+	burst := *currentBurst
+
+	currentBurst.startTime = time.Time{}
+	currentBurst.lastByteTime = time.Time{}
+	currentBurst.bytes = 0
+
+	if burst.bytes < thresholdBytes {
+		return
+	}
+
+	if history.first.isZero() {
+		history.first = burst
+	}
+
+	history.last = burst
+
+	if history.min.isZero() || history.min.rate() > burst.rate() {
+		history.min = burst
+	}
+
+	if history.max.isZero() || history.max.rate() < burst.rate() {
+		history.max = burst
+	}
+}
+
 // IsBogon checks if the specified IP is a bogon (loopback, private addresses,
 // link-local addresses, etc.)
 func IsBogon(IP net.IP) bool {

+ 246 - 75
psiphon/common/net_test.go

@@ -30,76 +30,6 @@ import (
 	"github.com/miekg/dns"
 )
 
-type dummyConn struct {
-	t        *testing.T
-	timeout  *time.Timer
-	isClosed int32
-}
-
-func (c *dummyConn) Read(b []byte) (n int, err error) {
-	if c.timeout != nil {
-		select {
-		case <-c.timeout.C:
-			return 0, iotest.ErrTimeout
-		default:
-		}
-	}
-	return len(b), nil
-}
-
-func (c *dummyConn) Write(b []byte) (n int, err error) {
-	if c.timeout != nil {
-		select {
-		case <-c.timeout.C:
-			return 0, iotest.ErrTimeout
-		default:
-		}
-	}
-	return len(b), nil
-}
-
-func (c *dummyConn) Close() error {
-	atomic.StoreInt32(&c.isClosed, 1)
-	return nil
-}
-
-func (c *dummyConn) IsClosed() bool {
-	return atomic.LoadInt32(&c.isClosed) == 1
-}
-
-func (c *dummyConn) LocalAddr() net.Addr {
-	c.t.Fatal("LocalAddr not implemented")
-	return nil
-}
-
-func (c *dummyConn) RemoteAddr() net.Addr {
-	c.t.Fatal("RemoteAddr not implemented")
-	return nil
-}
-
-func (c *dummyConn) SetDeadline(t time.Time) error {
-	duration := time.Until(t)
-	if c.timeout == nil {
-		c.timeout = time.NewTimer(duration)
-	} else {
-		if !c.timeout.Stop() {
-			<-c.timeout.C
-		}
-		c.timeout.Reset(duration)
-	}
-	return nil
-}
-
-func (c *dummyConn) SetReadDeadline(t time.Time) error {
-	c.t.Fatal("SetReadDeadline not implemented")
-	return nil
-}
-
-func (c *dummyConn) SetWriteDeadline(t time.Time) error {
-	c.t.Fatal("SetWriteDeadline not implemented")
-	return nil
-}
-
 func TestActivityMonitoredConn(t *testing.T) {
 	buffer := make([]byte, 1024)
 
@@ -165,11 +95,6 @@ func TestActivityMonitoredConn(t *testing.T) {
 		t.Fatalf("unexpected GetStartTime")
 	}
 
-	if int64(lastSuccessfulReadTime)/int64(time.Millisecond) !=
-		int64(conn.GetLastActivityMonotime())/int64(time.Millisecond) {
-		t.Fatalf("unexpected GetLastActivityTime")
-	}
-
 	diff := lastSuccessfulReadTime.Sub(monotonicStartTime).Nanoseconds() - conn.GetActiveDuration().Nanoseconds()
 	if diff < 0 {
 		diff = -diff
@@ -275,6 +200,167 @@ func TestLRUConns(t *testing.T) {
 	}
 }
 
+func TestBurstMonitoredConn(t *testing.T) {
+
+	burstDeadline := 100 * time.Millisecond
+	upstreamThresholdBytes := int64(100000)
+	downstreamThresholdBytes := int64(1000000)
+
+	baseTime := time.Now()
+
+	dummy := &dummyConn{}
+
+	conn := NewBurstMonitoredConn(
+		dummy,
+		burstDeadline,
+		upstreamThresholdBytes,
+		burstDeadline,
+		downstreamThresholdBytes)
+
+	// Simulate 128KB/s up, 1MB/s down; transmit >= min bytes in segments; sets "first" and "min"
+
+	dummy.SetRateLimits(131072, 1048576)
+
+	segments := 10
+
+	b := make([]byte, int(upstreamThresholdBytes)/segments)
+	firstReadStart := time.Now()
+	for i := 0; i < segments; i++ {
+		conn.Read(b)
+	}
+	firstReadEnd := time.Now()
+
+	b = make([]byte, int(downstreamThresholdBytes)/segments)
+	firstWriteStart := time.Now()
+	for i := 0; i < segments; i++ {
+		conn.Write(b)
+	}
+	firstWriteEnd := time.Now()
+
+	time.Sleep(burstDeadline * 2)
+
+	// Simulate 1MB/s up, 10MB/s down; repeatedly transmit < min bytes before deadline; ignored
+
+	dummy.SetRateLimits(1048576, 10485760)
+
+	b = make([]byte, 1)
+	segments = 1000
+	for i := 0; i < segments; i++ {
+		conn.Read(b)
+	}
+	for i := 0; i < segments; i++ {
+		conn.Write(b)
+	}
+
+	time.Sleep(burstDeadline * 2)
+
+	// Simulate 512KB/s up, 5MB/s down; transmit >= min bytes; sets "max"
+
+	dummy.SetRateLimits(524288, 5242880)
+
+	maxReadStart := time.Now()
+	conn.Read(make([]byte, upstreamThresholdBytes))
+	maxReadEnd := time.Now()
+
+	maxWriteStart := time.Now()
+	conn.Write(make([]byte, downstreamThresholdBytes))
+	maxWriteEnd := time.Now()
+
+	time.Sleep(burstDeadline * 2)
+
+	// Simulate 256KB/s up, 2MB/s down; transmit >= min bytes; sets "last"
+
+	dummy.SetRateLimits(262144, 2097152)
+
+	lastReadStart := time.Now()
+	conn.Read(make([]byte, upstreamThresholdBytes))
+	lastReadEnd := time.Now()
+
+	lastWriteStart := time.Now()
+	conn.Write(make([]byte, downstreamThresholdBytes))
+	lastWriteEnd := time.Now()
+
+	time.Sleep(burstDeadline * 2)
+
+	conn.Close()
+
+	t.Logf("upstream first:    %d bytes in %s; %d bytes/s",
+		conn.upstreamBursts.first.bytes, conn.upstreamBursts.first.duration(), conn.upstreamBursts.first.rate())
+	t.Logf("upstream last:     %d bytes in %s; %d bytes/s",
+		conn.upstreamBursts.last.bytes, conn.upstreamBursts.last.duration(), conn.upstreamBursts.last.rate())
+	t.Logf("upstream min:      %d bytes in %s; %d bytes/s",
+		conn.upstreamBursts.min.bytes, conn.upstreamBursts.min.duration(), conn.upstreamBursts.min.rate())
+	t.Logf("upstream max:      %d bytes in %s; %d bytes/s",
+		conn.upstreamBursts.max.bytes, conn.upstreamBursts.max.duration(), conn.upstreamBursts.max.rate())
+	t.Logf("downstream first:  %d bytes in %s; %d bytes/s",
+		conn.downstreamBursts.first.bytes, conn.downstreamBursts.first.duration(), conn.downstreamBursts.first.rate())
+	t.Logf("downstream last:   %d bytes in %s; %d bytes/s",
+		conn.downstreamBursts.last.bytes, conn.downstreamBursts.last.duration(), conn.downstreamBursts.last.rate())
+	t.Logf("downstream min:    %d bytes in %s; %d bytes/s",
+		conn.downstreamBursts.min.bytes, conn.downstreamBursts.min.duration(), conn.downstreamBursts.min.rate())
+	t.Logf("downstream max:    %d bytes in %s; %d bytes/s",
+		conn.downstreamBursts.max.bytes, conn.downstreamBursts.max.duration(), conn.downstreamBursts.max.rate())
+
+	logFields := conn.GetMetrics(baseTime)
+
+	if len(logFields) != 32 {
+		t.Errorf("unexpected metric count: %d", len(logFields))
+	}
+
+	for name, expectedValue := range map[string]int64{
+		"burst_upstream_first_offset":     int64(firstReadStart.Sub(baseTime) / time.Millisecond),
+		"burst_upstream_first_duration":   int64(firstReadEnd.Sub(firstReadStart) / time.Millisecond),
+		"burst_upstream_first_bytes":      upstreamThresholdBytes,
+		"burst_upstream_first_rate":       131072,
+		"burst_upstream_last_offset":      int64(lastReadStart.Sub(baseTime) / time.Millisecond),
+		"burst_upstream_last_duration":    int64(lastReadEnd.Sub(lastReadStart) / time.Millisecond),
+		"burst_upstream_last_bytes":       upstreamThresholdBytes,
+		"burst_upstream_last_rate":        262144,
+		"burst_upstream_min_offset":       int64(firstReadStart.Sub(baseTime) / time.Millisecond),
+		"burst_upstream_min_duration":     int64(firstReadEnd.Sub(firstReadStart) / time.Millisecond),
+		"burst_upstream_min_bytes":        upstreamThresholdBytes,
+		"burst_upstream_min_rate":         131072,
+		"burst_upstream_max_offset":       int64(maxReadStart.Sub(baseTime) / time.Millisecond),
+		"burst_upstream_max_duration":     int64(maxReadEnd.Sub(maxReadStart) / time.Millisecond),
+		"burst_upstream_max_bytes":        upstreamThresholdBytes,
+		"burst_upstream_max_rate":         524288,
+		"burst_downstream_first_offset":   int64(firstWriteStart.Sub(baseTime) / time.Millisecond),
+		"burst_downstream_first_duration": int64(firstWriteEnd.Sub(firstWriteStart) / time.Millisecond),
+		"burst_downstream_first_bytes":    downstreamThresholdBytes,
+		"burst_downstream_first_rate":     1048576,
+		"burst_downstream_last_offset":    int64(lastWriteStart.Sub(baseTime) / time.Millisecond),
+		"burst_downstream_last_duration":  int64(lastWriteEnd.Sub(lastWriteStart) / time.Millisecond),
+		"burst_downstream_last_bytes":     downstreamThresholdBytes,
+		"burst_downstream_last_rate":      2097152,
+		"burst_downstream_min_offset":     int64(firstWriteStart.Sub(baseTime) / time.Millisecond),
+		"burst_downstream_min_duration":   int64(firstWriteEnd.Sub(firstWriteStart) / time.Millisecond),
+		"burst_downstream_min_bytes":      downstreamThresholdBytes,
+		"burst_downstream_min_rate":       1048576,
+		"burst_downstream_max_offset":     int64(maxWriteStart.Sub(baseTime) / time.Millisecond),
+		"burst_downstream_max_duration":   int64(maxWriteEnd.Sub(maxWriteStart) / time.Millisecond),
+		"burst_downstream_max_bytes":      downstreamThresholdBytes,
+		"burst_downstream_max_rate":       5242880,
+	} {
+		value, ok := logFields[name]
+		if !ok {
+			t.Errorf("missing expected metric: %s", name)
+			continue
+		}
+		valueInt64, ok := value.(int64)
+		if !ok {
+			t.Errorf("missing expected metric type: %s (%T)", name, value)
+			continue
+		}
+		minAcceptable := int64(float64(expectedValue) * 0.95)
+		maxAcceptable := int64(float64(expectedValue) * 1.05)
+		if valueInt64 < minAcceptable || valueInt64 > maxAcceptable {
+			t.Errorf("unexpected metric value: %s (%v <= %v <= %v)",
+				name, minAcceptable, valueInt64, maxAcceptable)
+			continue
+		}
+	}
+}
+
 func TestIsBogon(t *testing.T) {
 	if IsBogon(net.ParseIP("8.8.8.8")) {
 		t.Errorf("unexpected bogon")
@@ -340,3 +426,88 @@ func BenchmarkParseDNSQuestion(b *testing.B) {
 		ParseDNSQuestion(msg)
 	}
 }
+
+type dummyConn struct {
+	t                   *testing.T
+	timeout             *time.Timer
+	readBytesPerSecond  int64
+	writeBytesPerSecond int64
+	isClosed            int32
+}
+
+func (c *dummyConn) Read(b []byte) (n int, err error) {
+	if c.readBytesPerSecond > 0 {
+		sleep := time.Duration(float64(int64(len(b))*int64(time.Second)) / float64(c.readBytesPerSecond))
+		time.Sleep(sleep)
+	}
+	if c.timeout != nil {
+		select {
+		case <-c.timeout.C:
+			return 0, iotest.ErrTimeout
+		default:
+		}
+	}
+	return len(b), nil
+}
+
+func (c *dummyConn) Write(b []byte) (n int, err error) {
+	if c.writeBytesPerSecond > 0 {
+		sleep := time.Duration(float64(int64(len(b))*int64(time.Second)) / float64(c.writeBytesPerSecond))
+		time.Sleep(sleep)
+	}
+	if c.timeout != nil {
+		select {
+		case <-c.timeout.C:
+			return 0, iotest.ErrTimeout
+		default:
+		}
+	}
+	return len(b), nil
+}
+
+func (c *dummyConn) Close() error {
+	atomic.StoreInt32(&c.isClosed, 1)
+	return nil
+}
+
+func (c *dummyConn) IsClosed() bool {
+	return atomic.LoadInt32(&c.isClosed) == 1
+}
+
+func (c *dummyConn) LocalAddr() net.Addr {
+	c.t.Fatal("LocalAddr not implemented")
+	return nil
+}
+
+func (c *dummyConn) RemoteAddr() net.Addr {
+	c.t.Fatal("RemoteAddr not implemented")
+	return nil
+}
+
+func (c *dummyConn) SetDeadline(t time.Time) error {
+	duration := time.Until(t)
+	if c.timeout == nil {
+		c.timeout = time.NewTimer(duration)
+	} else {
+		if !c.timeout.Stop() {
+			<-c.timeout.C
+		}
+		c.timeout.Reset(duration)
+	}
+	return nil
+}
+
+func (c *dummyConn) SetReadDeadline(t time.Time) error {
+	c.t.Fatal("SetReadDeadline not implemented")
+	return nil
+}
+
+func (c *dummyConn) SetWriteDeadline(t time.Time) error {
+	c.t.Fatal("SetWriteDeadline not implemented")
+	return nil
+}
+
+func (c *dummyConn) SetRateLimits(readBytesPerSecond, writeBytesPerSecond int64) {
+	c.readBytesPerSecond = readBytesPerSecond
+	c.writeBytesPerSecond = writeBytesPerSecond
+}

+ 9 - 0
psiphon/common/parameters/parameters.go

@@ -258,6 +258,10 @@ const (
 	ServerReplayTargetUpstreamBytes                  = "ServerReplayTargetUpstreamBytes"
 	ServerReplayTargetDownstreamBytes                = "ServerReplayTargetDownstreamBytes"
 	ServerReplayFailedCountThreshold                 = "ServerReplayFailedCountThreshold"
+	ServerBurstUpstreamDeadline                      = "ServerBurstUpstreamDeadline"
+	ServerBurstUpstreamThresholdBytes                = "ServerBurstUpstreamThresholdBytes"
+	ServerBurstDownstreamDeadline                    = "ServerBurstDownstreamDeadline"
+	ServerBurstDownstreamThresholdBytes              = "ServerBurstDownstreamThresholdBytes"
 )
 
 const (
@@ -535,6 +539,11 @@ var defaultParameters = map[string]struct {
 	ServerReplayTargetUpstreamBytes:   {value: 0, minimum: 0, flags: serverSideOnly},
 	ServerReplayTargetDownstreamBytes: {value: 0, minimum: 0, flags: serverSideOnly},
 	ServerReplayFailedCountThreshold:  {value: 0, minimum: 0, flags: serverSideOnly},
+
+	ServerBurstUpstreamDeadline:         {value: time.Duration(0), minimum: time.Duration(0), flags: serverSideOnly},
+	ServerBurstUpstreamThresholdBytes:   {value: 0, minimum: 0, flags: serverSideOnly},
+	ServerBurstDownstreamDeadline:       {value: time.Duration(0), minimum: time.Duration(0), flags: serverSideOnly},
+	ServerBurstDownstreamThresholdBytes: {value: 0, minimum: 0, flags: serverSideOnly},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used

+ 84 - 3
psiphon/server/server_test.go

@@ -132,6 +132,7 @@ func TestSSH(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -152,6 +153,7 @@ func TestOSSH(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -172,6 +174,7 @@ func TestFragmentedOSSH(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -192,6 +195,7 @@ func TestUnfrontedMeek(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -213,6 +217,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -234,6 +239,7 @@ func TestUnfrontedMeekHTTPSTLS13(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -255,6 +261,7 @@ func TestUnfrontedMeekSessionTicket(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -276,6 +283,7 @@ func TestUnfrontedMeekSessionTicketTLS13(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    true,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -299,6 +307,7 @@ func TestQUICOSSH(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -322,6 +331,7 @@ func TestMarionetteOSSH(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -342,6 +352,7 @@ func TestWebTransportAPIRequests(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -362,6 +373,7 @@ func TestHotReload(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -382,6 +394,7 @@ func TestDefaultSponsorID(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -402,6 +415,7 @@ func TestDenyTrafficRules(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -422,6 +436,7 @@ func TestOmitAuthorization(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -442,6 +457,7 @@ func TestNoAuthorization(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -462,6 +478,7 @@ func TestUnusedAuthorization(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -482,6 +499,7 @@ func TestTCPOnlySLOK(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -502,6 +520,7 @@ func TestUDPOnlySLOK(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -522,6 +541,7 @@ func TestLivenessTest(t *testing.T) {
 			doPruneServerEntries: false,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
 		})
 }
 
@@ -542,6 +562,28 @@ func TestPruneServerEntries(t *testing.T) {
 			doPruneServerEntries: true,
 			doDanglingTCPConn:    false,
 			doPacketManipulation: false,
+			doBurstMonitor:       false,
+		})
+}
+
+func TestBurstMonitor(t *testing.T) {
+	runServer(t,
+		&runServerConfig{
+			tunnelProtocol:       "OSSH",
+			enableSSHAPIRequests: true,
+			doHotReload:          false,
+			doDefaultSponsorID:   false,
+			denyTrafficRules:     false,
+			requireAuthorization: true,
+			omitAuthorization:    false,
+			doTunneledWebRequest: true,
+			doTunneledNTPRequest: true,
+			forceFragmenting:     false,
+			forceLivenessTest:    false,
+			doPruneServerEntries: false,
+			doDanglingTCPConn:    true,
+			doPacketManipulation: false,
+			doBurstMonitor:       true,
 		})
 }
 
@@ -561,6 +603,7 @@ type runServerConfig struct {
 	doPruneServerEntries bool
 	doDanglingTCPConn    bool
 	doPacketManipulation bool
+	doBurstMonitor       bool
 }
 
 var (
@@ -607,7 +650,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	// establish.
 
 	doClientTactics := protocol.TunnelProtocolUsesMeek(runConfig.tunnelProtocol)
-	doServerTactics := doClientTactics || runConfig.forceFragmenting
+	doServerTactics := doClientTactics || runConfig.forceFragmenting || runConfig.doBurstMonitor
 
 	// All servers require a tactics config with valid keys.
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
@@ -696,7 +739,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			tacticsRequestObfuscatedKey,
 			runConfig.tunnelProtocol,
 			propagationChannelID,
-			livenessTestSize)
+			livenessTestSize,
+			runConfig.doBurstMonitor)
 	}
 
 	blocklistFilename := filepath.Join(testDataDirName, "blocklist.csv")
@@ -1196,6 +1240,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	expectClientBPFField := psiphon.ClientBPFEnabled() && doClientTactics
 	expectServerBPFField := ServerBPFEnabled() && doServerTactics
 	expectServerPacketManipulationField := runConfig.doPacketManipulation
+	expectBurstFields := runConfig.doBurstMonitor
 
 	select {
 	case logFields := <-serverTunnelLog:
@@ -1204,6 +1249,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectClientBPFField,
 			expectServerBPFField,
 			expectServerPacketManipulationField,
+			expectBurstFields,
 			logFields)
 		if err != nil {
 			t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1241,6 +1287,7 @@ func checkExpectedServerTunnelLogFields(
 	expectClientBPFField bool,
 	expectServerBPFField bool,
 	expectServerPacketManipulationField bool,
+	expectBurstFields bool,
 	fields map[string]interface{}) error {
 
 	// Limitations:
@@ -1250,6 +1297,8 @@ func checkExpectedServerTunnelLogFields(
 	// - fronting_provider_id/meek_dial_ip_address/meek_resolved_ip_address only logged for FRONTED meek protocols
 
 	for _, name := range []string{
+		"start_time",
+		"duration",
 		"session_id",
 		"last_connected",
 		"establishment_duration",
@@ -1426,6 +1475,25 @@ func checkExpectedServerTunnelLogFields(
 		}
 	}
 
+	if expectBurstFields {
+
+		// common.TestBurstMonitoredConn covers inclusion of additional fields.
+		for _, name := range []string{
+			"burst_upstream_first_rate",
+			"burst_upstream_last_rate",
+			"burst_upstream_min_rate",
+			"burst_upstream_max_rate",
+			"burst_downstream_first_rate",
+			"burst_downstream_last_rate",
+			"burst_downstream_min_rate",
+			"burst_downstream_max_rate",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+	}
+
 	if fields["network_type"].(string) != testNetworkType {
 		return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
 	}
@@ -1947,7 +2015,8 @@ func paveTacticsConfigFile(
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey string,
 	tunnelProtocol string,
 	propagationChannelID string,
-	livenessTestSize int) {
+	livenessTestSize int,
+	doBurstMonitor bool) {
 
 	// Setting LimitTunnelProtocols passively exercises the
 	// server-side LimitTunnelProtocols enforcement.
@@ -1961,6 +2030,7 @@ func paveTacticsConfigFile(
         "TTL" : "60s",
         "Probability" : 1.0,
         "Parameters" : {
+          %s
           "LimitTunnelProtocols" : ["%s"],
           "FragmentorLimitProtocols" : ["%s"],
           "FragmentorProbability" : 1.0,
@@ -2024,9 +2094,20 @@ func paveTacticsConfigFile(
     }
     `
 
+	burstParameters := ""
+	if doBurstMonitor {
+		burstParameters = `
+          "ServerBurstUpstreamDeadline" : "100ms",
+          "ServerBurstUpstreamThresholdBytes" : 1000,
+          "ServerBurstDownstreamDeadline" : "100ms",
+          "ServerBurstDownstreamThresholdBytes" : 100000,
+	`
+	}
+
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+		burstParameters,
 		tunnelProtocol,
 		tunnelProtocol,
 		tunnelProtocol,

+ 62 - 15
psiphon/server/tunnelServer.go

@@ -44,6 +44,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/marionette"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
@@ -1193,7 +1194,6 @@ type sshClient struct {
 	sshListener                          *sshListener
 	tunnelProtocol                       string
 	sshConn                              ssh.Conn
-	activityConn                         *common.ActivityMonitoredConn
 	throttledConn                        *common.ThrottledConn
 	serverPacketManipulation             string
 	replayedServerPacketManipulation     bool
@@ -1361,6 +1361,40 @@ func (sshClient *sshClient) run(
 	}
 	conn = activityConn
 
+	// Further wrap the connection with burst monitoring, when enabled.
+	//
+	// Limitation: burst parameters are fixed for the duration of the tunnel
+	// and do not change after a tactics hot reload.
+
+	var burstConn *common.BurstMonitoredConn
+
+	p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.geoIPData)
+	if err != nil {
+		log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Warning(
+			"ServerTacticsParametersCache.Get failed")
+		return
+	}
+
+	if !p.IsNil() {
+		// Read the upstream and downstream burst parameters independently;
+		// each direction has its own deadline and threshold tactics
+		// parameter.
+		upstreamDeadline := p.Duration(parameters.ServerBurstUpstreamDeadline)
+		upstreamThresholdBytes := int64(p.Int(parameters.ServerBurstUpstreamThresholdBytes))
+		downstreamDeadline := p.Duration(parameters.ServerBurstDownstreamDeadline)
+		downstreamThresholdBytes := int64(p.Int(parameters.ServerBurstDownstreamThresholdBytes))
+
+		// Only add the burst monitoring wrapper when at least one direction
+		// has both a non-zero deadline and a non-zero threshold.
+		if (upstreamDeadline != 0 && upstreamThresholdBytes != 0) ||
+			(downstreamDeadline != 0 && downstreamThresholdBytes != 0) {
+
+			burstConn = common.NewBurstMonitoredConn(
+				conn,
+				upstreamDeadline, upstreamThresholdBytes,
+				downstreamDeadline, downstreamThresholdBytes)
+			conn = burstConn
+		}
+	}
+
+	// Allow garbage collection.
+	p.Close()
+
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
 	throttledConn := common.NewThrottledConn(conn, sshClient.rateLimits())
@@ -1595,7 +1629,6 @@ func (sshClient *sshClient) run(
 
 	sshClient.Lock()
 	sshClient.sshConn = result.sshConn
-	sshClient.activityConn = activityConn
 	sshClient.throttledConn = throttledConn
 	sshClient.Unlock()
 
@@ -1612,10 +1645,35 @@ func (sshClient *sshClient) run(
 
 	sshClient.sshServer.unregisterEstablishedClient(sshClient)
 
+	// Log tunnel metrics.
+
+	var additionalMetrics []LogFields
+
+	// Add activity and burst metrics.
+	//
+	// The reported duration is based on last confirmed data transfer, which for
+	// activityConn.GetActiveDuration() is the time of the last read byte and
+	// not conn close time. This is important for protocols such as meek. For
+	// meek, the connection remains open until the HTTP session expires, which
+	// may be some time after the tunnel has closed. (The meek protocol has no
+	// allowance for signalling payload EOF, and even if it did the client may
+	// not have the opportunity to send a final request with an EOF flag set.)
+
+	activityMetrics := make(LogFields)
+	activityMetrics["start_time"] = activityConn.GetStartTime()
+	activityMetrics["duration"] = int64(activityConn.GetActiveDuration() / time.Millisecond)
+	additionalMetrics = append(additionalMetrics, activityMetrics)
+
+	if burstConn != nil {
+		// Any outstanding burst should be recorded by burstConn.Close which should
+		// be called by unregisterEstablishedClient.
+		additionalMetrics = append(
+			additionalMetrics, LogFields(burstConn.GetMetrics(activityConn.GetStartTime())))
+	}
+
 	// Some conns report additional metrics. Meek conns report resiliency
 	// metrics and fragmentor.Conns report fragmentor configs.
 
-	var additionalMetrics []LogFields
 	if metricsSource, ok := baseConn.(common.MetricsSource); ok {
 		additionalMetrics = append(
 			additionalMetrics, LogFields(metricsSource.GetMetrics()))
@@ -1625,7 +1683,7 @@ func (sshClient *sshClient) run(
 			additionalMetrics, LogFields(result.obfuscatedSSHConn.GetMetrics()))
 	}
 
-	// Record server-replay metrics.
+	// Add server-replay metrics.
 
 	replayMetrics := make(LogFields)
 	replayedFragmentation := false
@@ -2376,15 +2434,6 @@ var serverTunnelStatParams = append(
 
 func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 
-	// Note: reporting duration based on last confirmed data transfer, which
-	// is reads for sshClient.activityConn.GetActiveDuration(), and not
-	// connection closing is important for protocols such as meek. For
-	// meek, the connection remains open until the HTTP session expires,
-	// which may be some time after the tunnel has closed. (The meek
-	// protocol has no allowance for signalling payload EOF, and even if
-	// it did the client may not have the opportunity to send a final
-	// request with an EOF flag set.)
-
 	sshClient.Lock()
 
 	logFields := getRequestLogFields(
@@ -2408,8 +2457,6 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	}
 	logFields["session_id"] = sshClient.sessionID
 	logFields["handshake_completed"] = sshClient.handshakeState.completed
-	logFields["start_time"] = sshClient.activityConn.GetStartTime()
-	logFields["duration"] = int64(sshClient.activityConn.GetActiveDuration() / time.Millisecond)
 	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
 	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
 	logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount