Просмотр исходного кода

Fix: client/server upstream/downstream reversal

Rod Hynes 5 лет назад
Родитель
Сommit
d76f212b22
4 измененных файлов с 75 добавлено и 50 удалено
  1. 62 38
      psiphon/common/burst.go
  2. 9 8
      psiphon/common/burst_test.go
  3. 1 0
      psiphon/server/tunnelServer.go
  4. 3 4
      psiphon/tunnel.go

+ 62 - 38
psiphon/common/burst.go

@@ -46,35 +46,48 @@ import (
 // Overhead: BurstMonitoredConn adds mutexes but does not use timers.
 type BurstMonitoredConn struct {
 	net.Conn
-	upstreamDeadline         time.Duration
-	upstreamThresholdBytes   int64
-	downstreamDeadline       time.Duration
-	downstreamThresholdBytes int64
-
-	readMutex            sync.Mutex
-	currentUpstreamBurst burst
-	upstreamBursts       burstHistory
-
-	writeMutex             sync.Mutex
-	currentDownstreamBurst burst
-	downstreamBursts       burstHistory
+	isServer            bool
+	readDeadline        time.Duration
+	readThresholdBytes  int64
+	writeDeadline       time.Duration
+	writeThresholdBytes int64
+
+	readMutex        sync.Mutex
+	currentReadBurst burst
+	readBursts       burstHistory
+
+	writeMutex        sync.Mutex
+	currentWriteBurst burst
+	writeBursts       burstHistory
 }
 
 // NewBurstMonitoredConn creates a new BurstMonitoredConn.
 func NewBurstMonitoredConn(
 	conn net.Conn,
+	isServer bool,
 	upstreamDeadline time.Duration,
 	upstreamThresholdBytes int64,
 	downstreamDeadline time.Duration,
 	downstreamThresholdBytes int64) *BurstMonitoredConn {
 
-	return &BurstMonitoredConn{
-		Conn:                     conn,
-		upstreamDeadline:         upstreamDeadline,
-		upstreamThresholdBytes:   upstreamThresholdBytes,
-		downstreamDeadline:       downstreamDeadline,
-		downstreamThresholdBytes: downstreamThresholdBytes,
+	burstConn := &BurstMonitoredConn{
+		Conn:     conn,
+		isServer: isServer,
 	}
+
+	if isServer {
+		burstConn.readDeadline = upstreamDeadline
+		burstConn.readThresholdBytes = upstreamThresholdBytes
+		burstConn.writeDeadline = downstreamDeadline
+		burstConn.writeThresholdBytes = downstreamThresholdBytes
+	} else {
+		burstConn.readDeadline = downstreamDeadline
+		burstConn.readThresholdBytes = downstreamThresholdBytes
+		burstConn.writeDeadline = upstreamDeadline
+		burstConn.writeThresholdBytes = upstreamThresholdBytes
+	}
+
+	return burstConn
 }
 
 type burst struct {
@@ -118,7 +131,7 @@ type burstHistory struct {
 
 func (conn *BurstMonitoredConn) Read(buffer []byte) (int, error) {
 
-	if conn.upstreamDeadline <= 0 || conn.upstreamThresholdBytes <= 0 {
+	if conn.readDeadline <= 0 || conn.readThresholdBytes <= 0 {
 		return conn.Conn.Read(buffer)
 	}
 
@@ -132,10 +145,10 @@ func (conn *BurstMonitoredConn) Read(buffer []byte) (int, error) {
 			start,
 			end,
 			int64(n),
-			conn.upstreamDeadline,
-			conn.upstreamThresholdBytes,
-			&conn.currentUpstreamBurst,
-			&conn.upstreamBursts)
+			conn.readDeadline,
+			conn.readThresholdBytes,
+			&conn.currentReadBurst,
+			&conn.readBursts)
 		conn.readMutex.Unlock()
 	}
 
@@ -145,7 +158,7 @@ func (conn *BurstMonitoredConn) Read(buffer []byte) (int, error) {
 
 func (conn *BurstMonitoredConn) Write(buffer []byte) (int, error) {
 
-	if conn.downstreamDeadline <= 0 || conn.downstreamThresholdBytes <= 0 {
+	if conn.writeDeadline <= 0 || conn.writeThresholdBytes <= 0 {
 		return conn.Conn.Write(buffer)
 	}
 
@@ -159,10 +172,10 @@ func (conn *BurstMonitoredConn) Write(buffer []byte) (int, error) {
 			start,
 			end,
 			int64(n),
-			conn.downstreamDeadline,
-			conn.downstreamThresholdBytes,
-			&conn.currentDownstreamBurst,
-			&conn.downstreamBursts)
+			conn.writeDeadline,
+			conn.writeThresholdBytes,
+			&conn.currentWriteBurst,
+			&conn.writeBursts)
 		conn.writeMutex.Unlock()
 	}
 
@@ -173,21 +186,21 @@ func (conn *BurstMonitoredConn) Write(buffer []byte) (int, error) {
 func (conn *BurstMonitoredConn) Close() error {
 	err := conn.Conn.Close()
 
-	if conn.upstreamDeadline > 0 && conn.upstreamThresholdBytes > 0 {
+	if conn.readDeadline > 0 && conn.readThresholdBytes > 0 {
 		conn.readMutex.Lock()
 		conn.endBurst(
-			conn.upstreamThresholdBytes,
-			&conn.currentUpstreamBurst,
-			&conn.upstreamBursts)
+			conn.readThresholdBytes,
+			&conn.currentReadBurst,
+			&conn.readBursts)
 		conn.readMutex.Unlock()
 	}
 
-	if conn.downstreamDeadline > 0 && conn.downstreamThresholdBytes > 0 {
+	if conn.writeDeadline > 0 && conn.writeThresholdBytes > 0 {
 		conn.writeMutex.Lock()
 		conn.endBurst(
-			conn.downstreamThresholdBytes,
-			&conn.currentDownstreamBurst,
-			&conn.downstreamBursts)
+			conn.writeThresholdBytes,
+			&conn.currentWriteBurst,
+			&conn.writeBursts)
 		conn.writeMutex.Unlock()
 	}
 
@@ -228,8 +241,19 @@ func (conn *BurstMonitoredConn) GetMetrics(baseTime time.Time) LogFields {
 		addFields(prefix+"max_", &history.max)
 	}
 
-	addHistory("burst_upstream_", &conn.upstreamBursts)
-	addHistory("burst_downstream_", &conn.downstreamBursts)
+	var upstreamBursts *burstHistory
+	var downstreamBursts *burstHistory
+
+	if conn.isServer {
+		upstreamBursts = &conn.readBursts
+		downstreamBursts = &conn.writeBursts
+	} else {
+		upstreamBursts = &conn.writeBursts
+		downstreamBursts = &conn.readBursts
+	}
+
+	addHistory("burst_upstream_", upstreamBursts)
+	addHistory("burst_downstream_", downstreamBursts)
 
 	return logFields
 }

+ 9 - 8
psiphon/common/burst_test.go

@@ -36,6 +36,7 @@ func TestBurstMonitoredConn(t *testing.T) {
 
 	conn := NewBurstMonitoredConn(
 		dummy,
+		true,
 		burstDeadline,
 		upstreamThresholdBytes,
 		burstDeadline,
@@ -109,21 +110,21 @@ func TestBurstMonitoredConn(t *testing.T) {
 	conn.Close()
 
 	t.Logf("upstream first:    %d bytes in %s; %d bytes/s",
-		conn.upstreamBursts.first.bytes, conn.upstreamBursts.first.duration(), conn.upstreamBursts.first.rate())
+		conn.readBursts.first.bytes, conn.readBursts.first.duration(), conn.readBursts.first.rate())
 	t.Logf("upstream last:     %d bytes in %s; %d bytes/s",
-		conn.upstreamBursts.last.bytes, conn.upstreamBursts.last.duration(), conn.upstreamBursts.last.rate())
+		conn.readBursts.last.bytes, conn.readBursts.last.duration(), conn.readBursts.last.rate())
 	t.Logf("upstream min:      %d bytes in %s; %d bytes/s",
-		conn.upstreamBursts.min.bytes, conn.upstreamBursts.min.duration(), conn.upstreamBursts.min.rate())
+		conn.readBursts.min.bytes, conn.readBursts.min.duration(), conn.readBursts.min.rate())
 	t.Logf("upstream max:      %d bytes in %s; %d bytes/s",
-		conn.upstreamBursts.max.bytes, conn.upstreamBursts.max.duration(), conn.upstreamBursts.max.rate())
+		conn.readBursts.max.bytes, conn.readBursts.max.duration(), conn.readBursts.max.rate())
 	t.Logf("downstream first:  %d bytes in %s; %d bytes/s",
-		conn.downstreamBursts.first.bytes, conn.downstreamBursts.first.duration(), conn.downstreamBursts.first.rate())
+		conn.writeBursts.first.bytes, conn.writeBursts.first.duration(), conn.writeBursts.first.rate())
 	t.Logf("downstream last:   %d bytes in %s; %d bytes/s",
-		conn.downstreamBursts.last.bytes, conn.downstreamBursts.last.duration(), conn.downstreamBursts.last.rate())
+		conn.writeBursts.last.bytes, conn.writeBursts.last.duration(), conn.writeBursts.last.rate())
 	t.Logf("downstream min:    %d bytes in %s; %d bytes/s",
-		conn.downstreamBursts.min.bytes, conn.downstreamBursts.min.duration(), conn.downstreamBursts.min.rate())
+		conn.writeBursts.min.bytes, conn.writeBursts.min.duration(), conn.writeBursts.min.rate())
 	t.Logf("downstream max:    %d bytes in %s; %d bytes/s",
-		conn.downstreamBursts.max.bytes, conn.downstreamBursts.max.duration(), conn.downstreamBursts.max.rate())
+		conn.writeBursts.max.bytes, conn.writeBursts.max.duration(), conn.writeBursts.max.rate())
 
 	logFields := conn.GetMetrics(baseTime)
 

+ 1 - 0
psiphon/server/tunnelServer.go

@@ -1386,6 +1386,7 @@ func (sshClient *sshClient) run(
 
 			burstConn = common.NewBurstMonitoredConn(
 				conn,
+				true,
 				upstreamDeadline, upstreamThresholdBytes,
 				downstreamDeadline, downstreamThresholdBytes)
 			conn = burstConn

+ 3 - 4
psiphon/tunnel.go

@@ -790,10 +790,9 @@ func dialTunnel(
 	monitoringStartTime := time.Now()
 	monitoredConn := common.NewBurstMonitoredConn(
 		dialConn,
-		burstUpstreamDeadline,
-		burstUpstreamThresholdBytes,
-		burstDownstreamDeadline,
-		burstDownstreamThresholdBytes)
+		false,
+		burstUpstreamDeadline, burstUpstreamThresholdBytes,
+		burstDownstreamDeadline, burstDownstreamThresholdBytes)
 
 	// Apply throttling (if configured)
 	throttledConn := common.NewThrottledConn(