Просмотр исходного кода

Add refinements to e10ce6a

- throttle accepting new channels when at
  the dialing limit
- before rejecting a new TCP port forward, due
  to the dialing limit, throttle and retry
- add tcp_port_forward_rejected_dialing_limit_count
  server_load stat for monitoring dialing
  limit impact
- fix: select _first_ acceptable IP returned from
  resolver
Rod Hynes 9 лет назад
Родитель
Сommit
d136f8de35
3 измененных файлов с 74 добавлено и 26 удалено
  1. 1 1
      psiphon/common/utils.go
  2. 1 1
      psiphon/server/trafficRules.go
  3. 72 24
      psiphon/server/tunnelServer.go

+ 1 - 1
psiphon/common/utils.go

@@ -105,7 +105,7 @@ func MakeSecureRandomPadding(minLength, maxLength int) ([]byte, error) {
 }
 
 // MakeRandomPeriod returns a random duration, within a given range.
-// In the unlikely case where an  underlying MakeRandom functions fails,
+// In the unlikely case where an underlying MakeRandom functions fails,
 // the period is the minimum.
 func MakeRandomPeriod(min, max time.Duration) (time.Duration, error) {
 	period, err := MakeSecureRandomInt64(max.Nanoseconds() - min.Nanoseconds())

+ 1 - 1
psiphon/server/trafficRules.go

@@ -109,7 +109,7 @@ type TrafficRules struct {
 
 	// MaxTCPDialingPortForwardCount is the maximum number of dialing
 	// TCP port forwards each client may have open concurrently. When
-	// at the limit, new TCP port forwards are rejected.
+	// persistently at the limit, new TCP port forwards are rejected.
 	// A value of 0 specifies no maximum. When omitted in
 	// DefaultRules, DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT is used.
 	MaxTCPDialingPortForwardCount *int

+ 72 - 24
psiphon/server/tunnelServer.go

@@ -26,6 +26,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math/rand"
 	"net"
 	"strconv"
 	"sync"
@@ -41,13 +42,16 @@ import (
 )
 
 const (
-	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
-	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
-	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT      = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE  = 8192
-	SSH_SEND_OSL_INITIAL_RETRY_DELAY       = 30 * time.Second
-	SSH_SEND_OSL_RETRY_FACTOR              = 2
+	SSH_HANDSHAKE_TIMEOUT                          = 30 * time.Second
+	SSH_CONNECTION_READ_DEADLINE                   = 5 * time.Minute
+	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT         = 30 * time.Second
+	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT              = 30 * time.Second
+	SSH_TCP_PORT_FORWARD_LIMIT_RETRIES             = 5
+	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD = 10 * time.Millisecond
+	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD = 100 * time.Millisecond
+	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE          = 8192
+	SSH_SEND_OSL_INITIAL_RETRY_DELAY               = 30 * time.Second
+	SSH_SEND_OSL_RETRY_FACTOR                      = 2
 )
 
 // TunnelServer is the main server that accepts Psiphon client
@@ -524,10 +528,12 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 		aggregatedQualityMetrics.tcpPortForwardFailedCount += client.qualityMetrics.tcpPortForwardFailedCount
 		aggregatedQualityMetrics.tcpPortForwardFailedDuration +=
 			client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond
+		aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount += client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
 		client.qualityMetrics.tcpPortForwardDialedCount = 0
 		client.qualityMetrics.tcpPortForwardDialedDuration = 0
 		client.qualityMetrics.tcpPortForwardFailedCount = 0
 		client.qualityMetrics.tcpPortForwardFailedDuration = 0
+		client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
 
 		client.Unlock()
 	}
@@ -547,6 +553,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 	allProtocolsStats["tcp_port_forward_dialed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardDialedDuration)
 	allProtocolsStats["tcp_port_forward_failed_count"] = aggregatedQualityMetrics.tcpPortForwardFailedCount
 	allProtocolsStats["tcp_port_forward_failed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardFailedDuration)
+	allProtocolsStats["tcp_port_forward_rejected_dialing_limit_count"] = aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount
 
 	for _, stats := range protocolStats {
 		for name, value := range stats["ALL"] {
@@ -702,10 +709,11 @@ type trafficState struct {
 // upstream link. These stats are recorded by each sshClient
 // and then reported and reset in sshServer.getLoadStats().
 type qualityMetrics struct {
-	tcpPortForwardDialedCount    int64
-	tcpPortForwardDialedDuration time.Duration
-	tcpPortForwardFailedCount    int64
-	tcpPortForwardFailedDuration time.Duration
+	tcpPortForwardDialedCount               int64
+	tcpPortForwardDialedDuration            time.Duration
+	tcpPortForwardFailedCount               int64
+	tcpPortForwardFailedDuration            time.Duration
+	tcpPortForwardRejectedDialingLimitCount int64
 }
 
 type handshakeState struct {
@@ -1022,12 +1030,22 @@ func (sshClient *sshClient) runTunnel(
 			continue
 		}
 
-		// process each port forward concurrently
+		// Process each port forward concurrently
+
 		waitGroup.Add(1)
 		go func(channel ssh.NewChannel) {
 			defer waitGroup.Done()
 			sshClient.handleNewPortForwardChannel(channel)
 		}(newChannel)
+
+		// Throttle accepting new channels when at the TCP dialing port forwards limit.
+		// This mitigates clients rapidly enqueuing many port forward handlers in a
+		// pre-dial state.
+		// TODO: block here until under the limit (excepting UDP)?
+
+		if sshClient.isTCPDialingPortForwardLimitExceeded() {
+			sshClient.tpcDialingPortForwardLimitThrottle()
+		}
 	}
 
 	// The channel loop is interrupted by a client
@@ -1478,7 +1496,7 @@ func (sshClient *sshClient) closedPortForward(
 	state.bytesDown += bytesDown
 }
 
-func (sshClient *sshClient) updateQualityMetrics(
+func (sshClient *sshClient) updateQualityMetricsWithDialResult(
 	tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
 
 	sshClient.Lock()
@@ -1494,6 +1512,27 @@ func (sshClient *sshClient) updateQualityMetrics(
 	}
 }
 
+func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
+}
+
+func (sshClient *sshClient) tpcDialingPortForwardLimitThrottle() {
+
+	duration :=
+		SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD +
+			time.Duration(
+				rand.Int63n(
+					int64(SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD-
+						SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD)+1))
+
+	ctx, _ := context.WithTimeout(sshClient.runContext, duration)
+	<-ctx.Done()
+}
+
 func (sshClient *sshClient) handleTCPChannel(
 	hostToConnect string,
 	portToConnect int,
@@ -1504,9 +1543,9 @@ func (sshClient *sshClient) handleTCPChannel(
 	// 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
 	//
 	// 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
-	//    limit on concurrentDialingPortForwardCount. If the limit is exceeded, the port
-	//    forward is _immediately rejected_; this hard limit applies back pressure when
-	//    upstream network resources are impaired.
+	//    limit on concurrentDialingPortForwardCount. When the limit is exceeded, throttle
+	//    and recheck and then ultimately reject the port forward; this hard limit applies
+	//    back pressure when upstream network resources are impaired.
 	//
 	//        TOCTOU note: important to increment the port forward counts _before_
 	//        checking limits; otherwise, the client could potentially consume excess
@@ -1541,7 +1580,17 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 	}()
 
-	if exceeded := sshClient.isTCPDialingPortForwardLimitExceeded(); exceeded {
+	exceeded := true
+	for i := 0; i < SSH_TCP_PORT_FORWARD_LIMIT_RETRIES; i++ {
+		exceeded = sshClient.isTCPDialingPortForwardLimitExceeded()
+		if !exceeded {
+			break
+		}
+		sshClient.tpcDialingPortForwardLimitThrottle()
+	}
+	if exceeded {
+
+		sshClient.updateQualityMetricsWithRejectedDialingLimit()
 
 		sshClient.rejectNewChannel(
 			newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
@@ -1579,15 +1628,15 @@ func (sshClient *sshClient) handleTCPChannel(
 	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 
 	ctx, _ := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
-	IPs, err := (&net.Resolver{}).LookupIPAddr(
-		ctx, hostToConnect)
+	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
 
-	// TODO: shuffle list to try other IPs
+	// TODO: shuffle list to try other IPs?
 	// TODO: IPv6 support
 	var IP net.IP
 	for _, ip := range IPs {
 		if ip.IP.To4() != nil {
 			IP = ip.IP
+			break
 		}
 	}
 	if err == nil && IP == nil {
@@ -1597,7 +1646,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	if err != nil {
 
 		// Record a port forward failure
-		sshClient.updateQualityMetrics(true, monotime.Since(dialStartTime))
+		sshClient.updateQualityMetricsWithDialResult(true, monotime.Since(dialStartTime))
 
 		sshClient.rejectNewChannel(
 			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
@@ -1627,11 +1676,10 @@ func (sshClient *sshClient) handleTCPChannel(
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
 	ctx, _ = context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
-	fwdConn, err := (&net.Dialer{}).DialContext(
-		ctx, "tcp", remoteAddr)
+	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
 
 	// Record port forward success or failure
-	sshClient.updateQualityMetrics(err == nil, monotime.Since(dialStartTime))
+	sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
 
 	if err != nil {