|
|
@@ -26,6 +26,7 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "math/rand"
|
|
|
"net"
|
|
|
"strconv"
|
|
|
"sync"
|
|
|
@@ -41,13 +42,16 @@ import (
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
- SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
- SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
|
|
|
- SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 30 * time.Second
|
|
|
- SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
- SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
|
|
|
- SSH_SEND_OSL_RETRY_FACTOR = 2
|
|
|
+ SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
+ SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_TCP_PORT_FORWARD_LIMIT_RETRIES = 5
|
|
|
+ SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD = 10 * time.Millisecond
|
|
|
+ SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD = 100 * time.Millisecond
|
|
|
+ SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
+ SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
|
|
|
+ SSH_SEND_OSL_RETRY_FACTOR = 2
|
|
|
)
|
|
|
|
|
|
// TunnelServer is the main server that accepts Psiphon client
|
|
|
@@ -524,10 +528,12 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
aggregatedQualityMetrics.tcpPortForwardFailedCount += client.qualityMetrics.tcpPortForwardFailedCount
|
|
|
aggregatedQualityMetrics.tcpPortForwardFailedDuration +=
|
|
|
client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond
|
|
|
+ aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount += client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
|
|
|
client.qualityMetrics.tcpPortForwardDialedCount = 0
|
|
|
client.qualityMetrics.tcpPortForwardDialedDuration = 0
|
|
|
client.qualityMetrics.tcpPortForwardFailedCount = 0
|
|
|
client.qualityMetrics.tcpPortForwardFailedDuration = 0
|
|
|
+ client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
|
|
|
|
|
|
client.Unlock()
|
|
|
}
|
|
|
@@ -547,6 +553,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
allProtocolsStats["tcp_port_forward_dialed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardDialedDuration)
|
|
|
allProtocolsStats["tcp_port_forward_failed_count"] = aggregatedQualityMetrics.tcpPortForwardFailedCount
|
|
|
allProtocolsStats["tcp_port_forward_failed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardFailedDuration)
|
|
|
+ allProtocolsStats["tcp_port_forward_rejected_dialing_limit_count"] = aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount
|
|
|
|
|
|
for _, stats := range protocolStats {
|
|
|
for name, value := range stats["ALL"] {
|
|
|
@@ -702,10 +709,11 @@ type trafficState struct {
|
|
|
// upstream link. These stats are recorded by each sshClient
|
|
|
// and then reported and reset in sshServer.getLoadStats().
|
|
|
type qualityMetrics struct {
|
|
|
- tcpPortForwardDialedCount int64
|
|
|
- tcpPortForwardDialedDuration time.Duration
|
|
|
- tcpPortForwardFailedCount int64
|
|
|
- tcpPortForwardFailedDuration time.Duration
|
|
|
+ tcpPortForwardDialedCount int64
|
|
|
+ tcpPortForwardDialedDuration time.Duration
|
|
|
+ tcpPortForwardFailedCount int64
|
|
|
+ tcpPortForwardFailedDuration time.Duration
|
|
|
+ tcpPortForwardRejectedDialingLimitCount int64
|
|
|
}
|
|
|
|
|
|
type handshakeState struct {
|
|
|
@@ -1022,12 +1030,22 @@ func (sshClient *sshClient) runTunnel(
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // process each port forward concurrently
|
|
|
+ // Process each port forward concurrently
|
|
|
+
|
|
|
waitGroup.Add(1)
|
|
|
go func(channel ssh.NewChannel) {
|
|
|
defer waitGroup.Done()
|
|
|
sshClient.handleNewPortForwardChannel(channel)
|
|
|
}(newChannel)
|
|
|
+
|
|
|
+ // Throttle accepting new channels when at the TCP dialing port forwards limit.
|
|
|
+ // This mitigates clients rapidly enqueuing many port forward handlers in a
|
|
|
+ // pre-dial state.
|
|
|
+ // TODO: block here until under the limit (excepting UDP)?
|
|
|
+
|
|
|
+ if sshClient.isTCPDialingPortForwardLimitExceeded() {
|
|
|
+ sshClient.tpcDialingPortForwardLimitThrottle()
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// The channel loop is interrupted by a client
|
|
|
@@ -1478,7 +1496,7 @@ func (sshClient *sshClient) closedPortForward(
|
|
|
state.bytesDown += bytesDown
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) updateQualityMetrics(
|
|
|
+func (sshClient *sshClient) updateQualityMetricsWithDialResult(
|
|
|
tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1494,6 +1512,27 @@ func (sshClient *sshClient) updateQualityMetrics(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) tpcDialingPortForwardLimitThrottle() {
|
|
|
+
|
|
|
+ duration :=
|
|
|
+ SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD +
|
|
|
+ time.Duration(
|
|
|
+ rand.Int63n(
|
|
|
+ int64(SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD-
|
|
|
+ SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD)+1))
|
|
|
+
|
|
|
+ ctx, _ := context.WithTimeout(sshClient.runContext, duration)
|
|
|
+ <-ctx.Done()
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) handleTCPChannel(
|
|
|
hostToConnect string,
|
|
|
portToConnect int,
|
|
|
@@ -1504,9 +1543,9 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
|
|
|
//
|
|
|
// 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
|
|
|
- // limit on concurrentDialingPortForwardCount. If the limit is exceeded, the port
|
|
|
- // forward is _immediately rejected_; this hard limit applies back pressure when
|
|
|
- // upstream network resources are impaired.
|
|
|
+ // limit on concurrentDialingPortForwardCount. When the limit is exceeded, throttle
|
|
|
+ // and recheck and then ultimately reject the port forward; this hard limit applies
|
|
|
+ // back pressure when upstream network resources are impaired.
|
|
|
//
|
|
|
// TOCTOU note: important to increment the port forward counts _before_
|
|
|
// checking limits; otherwise, the client could potentially consume excess
|
|
|
@@ -1541,7 +1580,17 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- if exceeded := sshClient.isTCPDialingPortForwardLimitExceeded(); exceeded {
|
|
|
+ exceeded := true
|
|
|
+ for i := 0; i < SSH_TCP_PORT_FORWARD_LIMIT_RETRIES; i++ {
|
|
|
+ exceeded = sshClient.isTCPDialingPortForwardLimitExceeded()
|
|
|
+ if !exceeded {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ sshClient.tpcDialingPortForwardLimitThrottle()
|
|
|
+ }
|
|
|
+ if exceeded {
|
|
|
+
|
|
|
+ sshClient.updateQualityMetricsWithRejectedDialingLimit()
|
|
|
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
|
|
|
@@ -1579,15 +1628,15 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
|
|
|
|
|
|
ctx, _ := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
|
|
|
- IPs, err := (&net.Resolver{}).LookupIPAddr(
|
|
|
- ctx, hostToConnect)
|
|
|
+ IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
|
|
|
|
|
|
- // TODO: shuffle list to try other IPs
|
|
|
+ // TODO: shuffle list to try other IPs?
|
|
|
// TODO: IPv6 support
|
|
|
var IP net.IP
|
|
|
for _, ip := range IPs {
|
|
|
if ip.IP.To4() != nil {
|
|
|
IP = ip.IP
|
|
|
+ break
|
|
|
}
|
|
|
}
|
|
|
if err == nil && IP == nil {
|
|
|
@@ -1597,7 +1646,7 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
if err != nil {
|
|
|
|
|
|
// Record a port forward failure
|
|
|
- sshClient.updateQualityMetrics(true, monotime.Since(dialStartTime))
|
|
|
+ sshClient.updateQualityMetricsWithDialResult(true, monotime.Since(dialStartTime))
|
|
|
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
|
|
|
@@ -1627,11 +1676,10 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
|
|
|
ctx, _ = context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
- fwdConn, err := (&net.Dialer{}).DialContext(
|
|
|
- ctx, "tcp", remoteAddr)
|
|
|
+ fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
|
|
|
|
|
|
// Record port forward success or failure
|
|
|
- sshClient.updateQualityMetrics(err == nil, monotime.Since(dialStartTime))
|
|
|
+ sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
|
|
|
|
|
|
if err != nil {
|
|
|
|