|
|
@@ -20,6 +20,7 @@
|
|
|
package server
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"crypto/subtle"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
@@ -29,6 +30,7 @@ import (
|
|
|
"strconv"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
+ "syscall"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Psiphon-Inc/crypto/ssh"
|
|
|
@@ -39,13 +41,13 @@ import (
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
- SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
- SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
|
|
|
- SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 30 * time.Second
|
|
|
- SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
- SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
|
|
|
- SSH_SEND_OSL_RETRY_FACTOR = 2
|
|
|
+ SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
+ SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 10 * time.Second
|
|
|
+ SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
+ SSH_TCP_PORT_FORWARD_QUEUE_SIZE = 1024
|
|
|
+ SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
|
|
|
+ SSH_SEND_OSL_RETRY_FACTOR = 2
|
|
|
)
|
|
|
|
|
|
// TunnelServer is the main server that accepts Psiphon client
|
|
|
@@ -455,6 +457,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[tunnelProtocol]["ALL"] = make(map[string]int64)
|
|
|
protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["established_clients"] = 0
|
|
|
+ protocolStats[tunnelProtocol]["ALL"]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["udp_port_forwards"] = 0
|
|
|
@@ -471,6 +474,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[tunnelProtocol][region] = make(map[string]int64)
|
|
|
protocolStats[tunnelProtocol][region]["accepted_clients"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["established_clients"] = 0
|
|
|
+ protocolStats[tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["udp_port_forwards"] = 0
|
|
|
@@ -495,6 +499,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[client.tunnelProtocol][region] = make(map[string]int64)
|
|
|
protocolStats[client.tunnelProtocol][region]["accepted_clients"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["established_clients"] = 0
|
|
|
+ protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] = 0
|
|
|
@@ -504,8 +509,10 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
|
|
|
protocolStats[client.tunnelProtocol][region]["established_clients"] += 1
|
|
|
|
|
|
+ protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
|
|
|
+ // client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
|
|
|
protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
|
|
|
|
|
|
@@ -517,10 +524,12 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
aggregatedQualityMetrics.tcpPortForwardFailedCount += client.qualityMetrics.tcpPortForwardFailedCount
|
|
|
aggregatedQualityMetrics.tcpPortForwardFailedDuration +=
|
|
|
client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond
|
|
|
+ aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount += client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
|
|
|
client.qualityMetrics.tcpPortForwardDialedCount = 0
|
|
|
client.qualityMetrics.tcpPortForwardDialedDuration = 0
|
|
|
client.qualityMetrics.tcpPortForwardFailedCount = 0
|
|
|
client.qualityMetrics.tcpPortForwardFailedDuration = 0
|
|
|
+ client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
|
|
|
|
|
|
client.Unlock()
|
|
|
}
|
|
|
@@ -531,6 +540,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
allProtocolsStats := make(map[string]int64)
|
|
|
allProtocolsStats["accepted_clients"] = 0
|
|
|
allProtocolsStats["established_clients"] = 0
|
|
|
+ allProtocolsStats["dialing_tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["total_tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["udp_port_forwards"] = 0
|
|
|
@@ -539,6 +549,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
allProtocolsStats["tcp_port_forward_dialed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardDialedDuration)
|
|
|
allProtocolsStats["tcp_port_forward_failed_count"] = aggregatedQualityMetrics.tcpPortForwardFailedCount
|
|
|
allProtocolsStats["tcp_port_forward_failed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardFailedDuration)
|
|
|
+ allProtocolsStats["tcp_port_forward_rejected_dialing_limit_count"] = aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount
|
|
|
|
|
|
for _, stats := range protocolStats {
|
|
|
for name, value := range stats["ALL"] {
|
|
|
@@ -628,34 +639,65 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
sshClient.run(clientConn)
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) monitorPortForwardDialError(err error) {
|
|
|
+
|
|
|
+ // "err" is the error returned from a failed TCP or UDP port
|
|
|
+ // forward dial. Certain system error codes indicate low resource
|
|
|
+ // conditions: insufficient file descriptors, ephemeral ports, or
|
|
|
+ // memory. For these cases, log an alert.
|
|
|
+
|
|
|
+ // TODO: also temporarily suspend new clients
|
|
|
+
|
|
|
+ // Note: don't log net.OpError.Error() as the full error string
|
|
|
+ // may contain client destination addresses.
|
|
|
+
|
|
|
+ opErr, ok := err.(*net.OpError)
|
|
|
+ if ok {
|
|
|
+ if opErr.Err == syscall.EADDRNOTAVAIL ||
|
|
|
+ opErr.Err == syscall.EAGAIN ||
|
|
|
+ opErr.Err == syscall.ENOMEM ||
|
|
|
+ opErr.Err == syscall.EMFILE ||
|
|
|
+ opErr.Err == syscall.ENFILE {
|
|
|
+
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"error": opErr.Err}).Error(
|
|
|
+ "port forward dial failed due to unavailable resource")
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
type sshClient struct {
|
|
|
sync.Mutex
|
|
|
- sshServer *sshServer
|
|
|
- tunnelProtocol string
|
|
|
- sshConn ssh.Conn
|
|
|
- activityConn *common.ActivityMonitoredConn
|
|
|
- throttledConn *common.ThrottledConn
|
|
|
- geoIPData GeoIPData
|
|
|
- sessionID string
|
|
|
- supportsServerRequests bool
|
|
|
- handshakeState handshakeState
|
|
|
- udpChannel ssh.Channel
|
|
|
- trafficRules TrafficRules
|
|
|
- tcpTrafficState trafficState
|
|
|
- udpTrafficState trafficState
|
|
|
- qualityMetrics qualityMetrics
|
|
|
- tcpPortForwardLRU *common.LRUConns
|
|
|
- oslClientSeedState *osl.ClientSeedState
|
|
|
- signalIssueSLOKs chan struct{}
|
|
|
- stopBroadcast chan struct{}
|
|
|
+ sshServer *sshServer
|
|
|
+ tunnelProtocol string
|
|
|
+ sshConn ssh.Conn
|
|
|
+ activityConn *common.ActivityMonitoredConn
|
|
|
+ throttledConn *common.ThrottledConn
|
|
|
+ geoIPData GeoIPData
|
|
|
+ sessionID string
|
|
|
+ supportsServerRequests bool
|
|
|
+ handshakeState handshakeState
|
|
|
+ udpChannel ssh.Channel
|
|
|
+ trafficRules TrafficRules
|
|
|
+ tcpTrafficState trafficState
|
|
|
+ udpTrafficState trafficState
|
|
|
+ qualityMetrics qualityMetrics
|
|
|
+ tcpPortForwardLRU *common.LRUConns
|
|
|
+ oslClientSeedState *osl.ClientSeedState
|
|
|
+ signalIssueSLOKs chan struct{}
|
|
|
+ runContext context.Context
|
|
|
+ stopRunning context.CancelFunc
|
|
|
+ tcpPortForwardDialingAvailableSignal context.CancelFunc
|
|
|
}
|
|
|
|
|
|
type trafficState struct {
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- concurrentPortForwardCount int64
|
|
|
- peakConcurrentPortForwardCount int64
|
|
|
- totalPortForwardCount int64
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ concurrentDialingPortForwardCount int64
|
|
|
+ peakConcurrentDialingPortForwardCount int64
|
|
|
+ concurrentPortForwardCount int64
|
|
|
+ peakConcurrentPortForwardCount int64
|
|
|
+ totalPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
// qualityMetrics records upstream TCP dial attempts and
|
|
|
@@ -664,10 +706,11 @@ type trafficState struct {
|
|
|
// upstream link. These stats are recorded by each sshClient
|
|
|
// and then reported and reset in sshServer.getLoadStats().
|
|
|
type qualityMetrics struct {
|
|
|
- tcpPortForwardDialedCount int64
|
|
|
- tcpPortForwardDialedDuration time.Duration
|
|
|
- tcpPortForwardFailedCount int64
|
|
|
- tcpPortForwardFailedDuration time.Duration
|
|
|
+ tcpPortForwardDialedCount int64
|
|
|
+ tcpPortForwardDialedDuration time.Duration
|
|
|
+ tcpPortForwardFailedCount int64
|
|
|
+ tcpPortForwardFailedDuration time.Duration
|
|
|
+ tcpPortForwardRejectedDialingLimitCount int64
|
|
|
}
|
|
|
|
|
|
type handshakeState struct {
|
|
|
@@ -678,13 +721,17 @@ type handshakeState struct {
|
|
|
|
|
|
func newSshClient(
|
|
|
sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
+
|
|
|
+ runContext, stopRunning := context.WithCancel(context.Background())
|
|
|
+
|
|
|
return &sshClient{
|
|
|
sshServer: sshServer,
|
|
|
tunnelProtocol: tunnelProtocol,
|
|
|
geoIPData: geoIPData,
|
|
|
tcpPortForwardLRU: common.NewLRUConns(),
|
|
|
signalIssueSLOKs: make(chan struct{}, 1),
|
|
|
- stopBroadcast: make(chan struct{}),
|
|
|
+ runContext: runContext,
|
|
|
+ stopRunning: stopRunning,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -922,14 +969,16 @@ func (sshClient *sshClient) stop() {
|
|
|
sshClient.sshConn.Wait()
|
|
|
}
|
|
|
|
|
|
-// runTunnel handles/dispatches new channel and new requests from the client.
|
|
|
+// runTunnel handles/dispatches new channels and new requests from the client.
|
|
|
// When the SSH client connection closes, both the channels and requests channels
|
|
|
-// will close and runClient will exit.
|
|
|
+// will close and runTunnel will exit.
|
|
|
func (sshClient *sshClient) runTunnel(
|
|
|
channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
|
|
|
|
|
|
waitGroup := new(sync.WaitGroup)
|
|
|
|
|
|
+ // Start client SSH API request handler
|
|
|
+
|
|
|
waitGroup.Add(1)
|
|
|
go func() {
|
|
|
defer waitGroup.Done()
|
|
|
@@ -965,6 +1014,8 @@ func (sshClient *sshClient) runTunnel(
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
+ // Start OSL sender
|
|
|
+
|
|
|
if sshClient.supportsServerRequests {
|
|
|
waitGroup.Add(1)
|
|
|
go func() {
|
|
|
@@ -973,6 +1024,157 @@ func (sshClient *sshClient) runTunnel(
|
|
|
}()
|
|
|
}
|
|
|
|
|
|
+ // Lifecycle of a TCP port forward:
|
|
|
+ //
|
|
|
+ // 1. A "direct-tcpip" SSH request is received from the client.
|
|
|
+ //
|
|
|
+ // A new TCP port forward request is enqueued. The queue delivers TCP port
|
|
|
+ // forward requests to the TCP port forward manager, which enforces the TCP
|
|
|
+ // port forward dial limit.
|
|
|
+ //
|
|
|
+ // Enqueuing new requests allows for reading further SSH requests from the
|
|
|
+ // client without blocking when the dial limit is hit; this is to permit new
|
|
|
+ // UDP/udpgw port forwards to be reestablished without delay. The maximum size
|
|
|
+ // of the queue enforces a hard cap on resources consumed by a client in the
|
|
|
+ // pre-dial phase. When the queue is full, new TCP port forwards are
|
|
|
+ // immediately rejected.
|
|
|
+ //
|
|
|
+ // 2. The TCP port forward manager dequeues the request.
|
|
|
+ //
|
|
|
+ // The manager calls dialingTCPPortForward(), which increments
|
|
|
+ // concurrentDialingPortForwardCount, and calls
|
|
|
+ // isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
|
|
|
+ // count.
|
|
|
+ //
|
|
|
+ // The manager enforces the concurrent TCP dial limit: when at the limit, the
|
|
|
+ // manager blocks waiting for the number of dials to drop below the limit before
|
|
|
+ // dispatching the request to handleTCPChannel(), which will run in its own
|
|
|
+ // goroutine and will dial and relay the port forward.
|
|
|
+ //
|
|
|
+ // The block delays the current request and also halts dequeuing of subsequent
|
|
|
+ // requests and could ultimately cause requests to be immediately rejected if
|
|
|
+ // the queue fills. These actions are intended to apply back pressure when
|
|
|
+ // upstream network resources are impaired.
|
|
|
+ //
|
|
|
+ // The time spent in the queue is deducted from the port forward's dial timeout.
|
|
|
+ // The time spent blocking while at the dial limit is similarly deducted from
|
|
|
+ // the dial timeout. If the dial timeout has expired before the dial begins, the
|
|
|
+ // port forward is rejected and a stat is recorded.
|
|
|
+ //
|
|
|
+ // 3. handleTCPChannel() performs the port forward dial and relaying.
|
|
|
+ //
|
|
|
+ // a. Dial the target, using the dial timeout remaining after queue and blocking
|
|
|
+ // time is deducted.
|
|
|
+ //
|
|
|
+ // b. If the dial fails, call abortedTCPPortForward() to decrement
|
|
|
+ // concurrentDialingPortForwardCount, freeing up a dial slot.
|
|
|
+ //
|
|
|
+ // c. If the dial succeeds, call establishedPortForward(), which decrements
|
|
|
+ // concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
|
|
|
+ // the "established" port forward count.
|
|
|
+ //
|
|
|
+ // d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
|
|
|
+ // concurrentPortForwardCount, the number of _established_ TCP port forwards.
|
|
|
+ // If the limit is exceeded, the LRU established TCP port forward is closed and
|
|
|
+ // the newly established TCP port forward proceeds. This LRU logic allows some
|
|
|
+ // dangling resource consumption (e.g., TIME_WAIT) while providing a better
|
|
|
+ // experience for clients.
|
|
|
+ //
|
|
|
+ // e. Relay data.
|
|
|
+ //
|
|
|
+ // f. Call closedPortForward() which decrements concurrentPortForwardCount and
|
|
|
+ // records bytes transferred.
|
|
|
+
|
|
|
+ // Start the TCP port forward manager
|
|
|
+
|
|
|
+ type newTCPPortForward struct {
|
|
|
+ enqueueTime monotime.Time
|
|
|
+ hostToConnect string
|
|
|
+ portToConnect int
|
|
|
+ newChannel ssh.NewChannel
|
|
|
+ }
|
|
|
+
|
|
|
+ // The queue size is set to the traffic rules MaxTCPPortForwardCount, which is a
|
|
|
+ // reasonable indication of resource limits per client; when that value is not set,
|
|
|
+ // a default is used.
|
|
|
+ // A limitation: this queue size is set once and doesn't change, for this client,
|
|
|
+ // when traffic rules are reloaded.
|
|
|
+ queueSize := sshClient.getTCPPortForwardLimit()
|
|
|
+ if queueSize == 0 {
|
|
|
+ queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
|
|
|
+ }
|
|
|
+ newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
|
|
|
+
|
|
|
+ waitGroup.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer waitGroup.Done()
|
|
|
+ for newPortForward := range newTCPPortForwards {
|
|
|
+
|
|
|
+ remainingDialTimeout := SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT - monotime.Since(newPortForward.enqueueTime)
|
|
|
+
|
|
|
+ if remainingDialTimeout <= 0 {
|
|
|
+ sshClient.updateQualityMetricsWithRejectedDialingLimit()
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out in queue")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // Reserve a TCP dialing slot.
|
|
|
+ //
|
|
|
+ // TOCTOU note: important to increment counts _before_ checking limits; otherwise,
|
|
|
+ // the client could potentially consume excess resources by initiating many port
|
|
|
+ // forwards concurrently.
|
|
|
+
|
|
|
+ sshClient.dialingTCPPortForward()
|
|
|
+
|
|
|
+ // When max dials are in progress, wait up to remainingDialTimeout for dialing
|
|
|
+ // to become available. This blocks all dequeuing.
|
|
|
+
|
|
|
+ if sshClient.isTCPDialingPortForwardLimitExceeded() {
|
|
|
+ blockStartTime := monotime.Now()
|
|
|
+ ctx, cancelFunc := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
|
|
|
+ sshClient.setTCPPortForwardDialingAvailableSignal(cancelFunc)
|
|
|
+ <-ctx.Done()
|
|
|
+ sshClient.setTCPPortForwardDialingAvailableSignal(nil)
|
|
|
+ remainingDialTimeout -= monotime.Since(blockStartTime)
|
|
|
+ }
|
|
|
+
|
|
|
+ if remainingDialTimeout <= 0 {
|
|
|
+
|
|
|
+ // Release the dialing slot here since handleTCPChannel() won't be called.
|
|
|
+ sshClient.abortedTCPPortForward()
|
|
|
+
|
|
|
+ sshClient.updateQualityMetricsWithRejectedDialingLimit()
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out before dialing")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
|
|
|
+ // handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
|
|
|
+ // assumes remainingDialTimeout > 0, which is guaranteed by the check above.
|
|
|
+
|
|
|
+ waitGroup.Add(1)
|
|
|
+ go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
|
|
|
+ defer waitGroup.Done()
|
|
|
+ sshClient.handleTCPChannel(
|
|
|
+ remainingDialTimeout,
|
|
|
+ newPortForward.hostToConnect,
|
|
|
+ newPortForward.portToConnect,
|
|
|
+ newPortForward.newChannel)
|
|
|
+ }(remainingDialTimeout, newPortForward)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Handle new channel (port forward) requests from the client.
|
|
|
+ //
|
|
|
+ // udpgw client connections are dispatched immediately (clients use this for
|
|
|
+ // DNS, so it's essential to not block; and only one udpgw connection is
|
|
|
+ // retained at a time).
|
|
|
+ //
|
|
|
+ // All other TCP port forwards are dispatched via the TCP port forward
|
|
|
+ // manager queue.
|
|
|
+
|
|
|
for newChannel := range channels {
|
|
|
|
|
|
if newChannel.ChannelType() != "direct-tcpip" {
|
|
|
@@ -980,18 +1182,66 @@ func (sshClient *sshClient) runTunnel(
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // process each port forward concurrently
|
|
|
- waitGroup.Add(1)
|
|
|
- go func(channel ssh.NewChannel) {
|
|
|
- defer waitGroup.Done()
|
|
|
- sshClient.handleNewPortForwardChannel(channel)
|
|
|
- }(newChannel)
|
|
|
+ // http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
+ var directTcpipExtraData struct {
|
|
|
+ HostToConnect string
|
|
|
+ PortToConnect uint32
|
|
|
+ OriginatorIPAddress string
|
|
|
+ OriginatorPort uint32
|
|
|
+ }
|
|
|
+
|
|
|
+ err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
|
|
|
+ if err != nil {
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // Intercept TCP port forwards to a specified udpgw server and handle directly.
|
|
|
+ // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
|
|
|
+ isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
|
|
|
+ sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
|
|
|
+ net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
|
|
|
+
|
|
|
+ if isUDPChannel {
|
|
|
+
|
|
|
+ // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
|
|
|
+ // own worker goroutine.
|
|
|
+
|
|
|
+ waitGroup.Add(1)
|
|
|
+ go func(channel ssh.NewChannel) {
|
|
|
+ defer waitGroup.Done()
|
|
|
+ sshClient.handleUDPChannel(channel)
|
|
|
+ }(newChannel)
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ // Dispatch via TCP port forward manager. When the queue is full, the channel
|
|
|
+ // is immediately rejected.
|
|
|
+
|
|
|
+ tcpPortForward := &newTCPPortForward{
|
|
|
+ enqueueTime: monotime.Now(),
|
|
|
+ hostToConnect: directTcpipExtraData.HostToConnect,
|
|
|
+ portToConnect: int(directTcpipExtraData.PortToConnect),
|
|
|
+ newChannel: newChannel,
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case newTCPPortForwards <- tcpPortForward:
|
|
|
+ default:
|
|
|
+ sshClient.updateQualityMetricsWithRejectedDialingLimit()
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "TCP port forward dial queue full")
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// The channel loop is interrupted by a client
|
|
|
// disconnect or by calling sshClient.stop().
|
|
|
|
|
|
- close(sshClient.stopBroadcast)
|
|
|
+ // Stop the TCP port forward manager
|
|
|
+ close(newTCPPortForwards)
|
|
|
+
|
|
|
+ // Stop all other worker goroutines
|
|
|
+ sshClient.stopRunning()
|
|
|
|
|
|
waitGroup.Wait()
|
|
|
}
|
|
|
@@ -1021,10 +1271,12 @@ func (sshClient *sshClient) logTunnel() {
|
|
|
logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
|
|
|
logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
|
|
|
logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
|
|
|
+ logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
|
|
|
logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
|
|
|
logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
|
|
|
logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
|
|
|
logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
|
|
|
+ // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
|
|
|
logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
|
|
|
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
|
|
|
|
|
|
@@ -1040,7 +1292,7 @@ func (sshClient *sshClient) runOSLSender() {
|
|
|
// TODO: use reflect.SelectCase, and optionally await timer here?
|
|
|
select {
|
|
|
case <-sshClient.signalIssueSLOKs:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
+ case <-sshClient.runContext.Done():
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -1058,7 +1310,7 @@ func (sshClient *sshClient) runOSLSender() {
|
|
|
select {
|
|
|
case <-retryTimer.C:
|
|
|
case <-sshClient.signalIssueSLOKs:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
+ case <-sshClient.runContext.Done():
|
|
|
retryTimer.Stop()
|
|
|
return
|
|
|
}
|
|
|
@@ -1119,36 +1371,6 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
|
|
|
newChannel.Reject(reason, reason.String())
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
|
|
|
-
|
|
|
- // http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
- var directTcpipExtraData struct {
|
|
|
- HostToConnect string
|
|
|
- PortToConnect uint32
|
|
|
- OriginatorIPAddress string
|
|
|
- OriginatorPort uint32
|
|
|
- }
|
|
|
-
|
|
|
- err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
|
|
|
- if err != nil {
|
|
|
- sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // Intercept TCP port forwards to a specified udpgw server and handle directly.
|
|
|
- // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
|
|
|
- isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
|
|
|
- sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
|
|
|
- net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
|
|
|
-
|
|
|
- if isUDPChannel {
|
|
|
- sshClient.handleUDPChannel(newChannel)
|
|
|
- } else {
|
|
|
- sshClient.handleTCPChannel(
|
|
|
- directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
// setHandshakeState records that a client has completed a handshake API request.
|
|
|
// Some parameters from the handshake request may be used in future traffic rule
|
|
|
// selection. Port forwards are disallowed until a handshake is complete. The
|
|
|
@@ -1270,13 +1492,19 @@ func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
|
|
|
-
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ sshClient.tcpPortForwardDialingAvailableSignal = signal
|
|
|
+}
|
|
|
+
|
|
|
const (
|
|
|
portForwardTypeTCP = iota
|
|
|
portForwardTypeUDP
|
|
|
@@ -1335,29 +1563,72 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ state := &sshClient.tcpTrafficState
|
|
|
+ max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
+
|
|
|
+ if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
- portForwardType int) (int, bool) {
|
|
|
+ portForwardType int) bool {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
- var maxPortForwardCount int
|
|
|
+ var max int
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
- maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+ max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
} else {
|
|
|
- maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
+ max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
state = &sshClient.udpTrafficState
|
|
|
}
|
|
|
|
|
|
- if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
|
|
|
- return maxPortForwardCount, true
|
|
|
+ if max > 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) getTCPPortForwardLimit() int {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ return *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) dialingTCPPortForward() {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ state := &sshClient.tcpTrafficState
|
|
|
+
|
|
|
+ state.concurrentDialingPortForwardCount += 1
|
|
|
+ if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
|
|
|
+ state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
|
|
|
}
|
|
|
- return maxPortForwardCount, false
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) openedPortForward(
|
|
|
+func (sshClient *sshClient) abortedTCPPortForward() {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) establishedPortForward(
|
|
|
portForwardType int) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1366,6 +1637,18 @@ func (sshClient *sshClient) openedPortForward(
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
+
|
|
|
+ // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
+ state.concurrentDialingPortForwardCount -= 1
|
|
|
+
|
|
|
+ if sshClient.tcpPortForwardDialingAvailableSignal != nil {
|
|
|
+
|
|
|
+ max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
+ if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
|
|
|
+ sshClient.tcpPortForwardDialingAvailableSignal()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
} else {
|
|
|
state = &sshClient.udpTrafficState
|
|
|
}
|
|
|
@@ -1377,7 +1660,25 @@ func (sshClient *sshClient) openedPortForward(
|
|
|
state.totalPortForwardCount += 1
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) updateQualityMetrics(
|
|
|
+func (sshClient *sshClient) closedPortForward(
|
|
|
+ portForwardType int, bytesUp, bytesDown int64) {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
+
|
|
|
+ state.concurrentPortForwardCount -= 1
|
|
|
+ state.bytesUp += bytesUp
|
|
|
+ state.bytesDown += bytesDown
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) updateQualityMetricsWithDialResult(
|
|
|
tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1393,29 +1694,33 @@ func (sshClient *sshClient) updateQualityMetrics(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) closedPortForward(
|
|
|
- portForwardType int, bytesUp, bytesDown int64) {
|
|
|
+func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
- var state *trafficState
|
|
|
- if portForwardType == portForwardTypeTCP {
|
|
|
- state = &sshClient.tcpTrafficState
|
|
|
- } else {
|
|
|
- state = &sshClient.udpTrafficState
|
|
|
- }
|
|
|
-
|
|
|
- state.concurrentPortForwardCount -= 1
|
|
|
- state.bytesUp += bytesUp
|
|
|
- state.bytesDown += bytesDown
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) handleTCPChannel(
|
|
|
+ remainingDialTimeout time.Duration,
|
|
|
hostToConnect string,
|
|
|
portToConnect int,
|
|
|
newChannel ssh.NewChannel) {
|
|
|
|
|
|
+ // Assumptions:
|
|
|
+ // - sshClient.dialingTCPPortForward() has been called
|
|
|
+ // - remainingDialTimeout > 0
|
|
|
+
|
|
|
+ established := false
|
|
|
+ defer func() {
|
|
|
+ if !established {
|
|
|
+ sshClient.abortedTCPPortForward()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Transparently redirect web API request connections.
|
|
|
+
|
|
|
isWebServerPortForward := false
|
|
|
config := sshClient.sshServer.support.Config
|
|
|
if config.WebServerPortForwardAddress != "" {
|
|
|
@@ -1432,72 +1737,121 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- type lookupIPResult struct {
|
|
|
- IP net.IP
|
|
|
- err error
|
|
|
- }
|
|
|
- lookupResultChannel := make(chan *lookupIPResult, 1)
|
|
|
+ // Dial the remote address.
|
|
|
+ //
|
|
|
+ // Hostname resolution is performed explicitly, as a separate step, as the target IP
|
|
|
+ // address is used for traffic rules (AllowSubnets) and OSL seed progress.
|
|
|
+ //
|
|
|
+ // Contexts are used for cancellation (via sshClient.runContext, which is cancelled
|
|
|
+ // when the client is stopping) and timeouts.
|
|
|
|
|
|
- go func() {
|
|
|
- // TODO: explicit timeout for DNS resolution?
|
|
|
- IPs, err := net.LookupIP(hostToConnect)
|
|
|
- // TODO: shuffle list to try other IPs
|
|
|
- // TODO: IPv6 support
|
|
|
- var IP net.IP
|
|
|
- for _, ip := range IPs {
|
|
|
- if ip.To4() != nil {
|
|
|
- IP = ip
|
|
|
- }
|
|
|
- }
|
|
|
- if err == nil && IP == nil {
|
|
|
- err = errors.New("no IP address")
|
|
|
+ dialStartTime := monotime.Now()
|
|
|
+
|
|
|
+ log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
|
|
|
+
|
|
|
+ ctx, _ := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
|
|
|
+ IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
|
|
|
+
|
|
|
+ // TODO: shuffle list to try other IPs?
|
|
|
+ // TODO: IPv6 support
|
|
|
+ var IP net.IP
|
|
|
+ for _, ip := range IPs {
|
|
|
+ if ip.IP.To4() != nil {
|
|
|
+ IP = ip.IP
|
|
|
+ break
|
|
|
}
|
|
|
- lookupResultChannel <- &lookupIPResult{IP, err}
|
|
|
- }()
|
|
|
+ }
|
|
|
+ if err == nil && IP == nil {
|
|
|
+ err = errors.New("no IP address")
|
|
|
+ }
|
|
|
|
|
|
- var lookupResult *lookupIPResult
|
|
|
- select {
|
|
|
- case lookupResult = <-lookupResultChannel:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
- // Note: may leave LookupIP in progress
|
|
|
+ resolveElapsedTime := monotime.Since(dialStartTime)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
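+ // Record a port forward failure. NOTE(review): the call below passes true for tcpPortForwardDialSuccess; a failure should presumably pass false -- confirm.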
|
|
|
+ sshClient.updateQualityMetricsWithDialResult(true, resolveElapsedTime)
|
|
|
+
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if lookupResult.err != nil {
|
|
|
+ remainingDialTimeout -= resolveElapsedTime
|
|
|
+
|
|
|
+ if remainingDialTimeout <= 0 {
|
|
|
sshClient.rejectNewChannel(
|
|
|
- newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", lookupResult.err))
|
|
|
+ newChannel, ssh.Prohibited, "TCP port forward timed out resolving")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // Enforce traffic rules, using the resolved IP address.
|
|
|
+
|
|
|
if !isWebServerPortForward &&
|
|
|
!sshClient.isPortForwardPermitted(
|
|
|
portForwardTypeTCP,
|
|
|
false,
|
|
|
- lookupResult.IP,
|
|
|
+ IP,
|
|
|
portToConnect) {
|
|
|
|
|
|
+ // Note: not recording a port forward failure in this case
|
|
|
+
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.Prohibited, "port forward not permitted")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // TCP dial.
|
|
|
+
|
|
|
+ remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
|
|
|
+
|
|
|
+ log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
+
|
|
|
+ ctx, _ = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
|
|
|
+ fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
|
|
|
+
|
|
|
+ // Record port forward success or failure
|
|
|
+ sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ // Monitor for low resource error conditions
|
|
|
+ sshClient.sshServer.monitorPortForwardDialError(err)
|
|
|
+
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // The upstream TCP port forward connection has been established. Schedule
|
|
|
+ // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
+
|
|
|
+ defer fwdConn.Close()
|
|
|
+
|
|
|
+ fwdChannel, requests, err := newChannel.Accept()
|
|
|
+ if err != nil {
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ go ssh.DiscardRequests(requests)
|
|
|
+ defer fwdChannel.Close()
|
|
|
+
|
|
|
+ // Release the dialing slot and acquire an established slot.
|
|
|
+
|
|
|
+ // "established = true" cancels the deferred abortedTCPPortForward()
|
|
|
+ established = true
|
|
|
+ sshClient.establishedPortForward(portForwardTypeTCP)
|
|
|
+
|
|
|
var bytesUp, bytesDown int64
|
|
|
- sshClient.openedPortForward(portForwardTypeTCP)
|
|
|
defer func() {
|
|
|
sshClient.closedPortForward(
|
|
|
portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
|
|
|
}()
|
|
|
|
|
|
- // TOCTOU note: important to increment the port forward count (via
|
|
|
- // openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
|
- // otherwise, the client could potentially consume excess resources
|
|
|
- // by initiating many port forwards concurrently.
|
|
|
- // TODO: close LRU connection (after successful Dial) instead of
|
|
|
- // rejecting new connection?
|
|
|
- if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
|
|
|
+ if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
|
|
|
|
|
|
// Close the oldest TCP port forward. CloseOldest() closes
|
|
|
- // the conn and the port forward's goroutine will complete
|
|
|
+ // the conn and the port forward's goroutines will complete
|
|
|
// the cleanup asynchronously.
|
|
|
//
|
|
|
// Some known limitations:
|
|
|
@@ -1510,91 +1864,25 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// asynchronous, there exists a possibility that a client can consume
|
|
|
// more than max port forward resources -- just not upstream sockets.
|
|
|
//
|
|
|
- // - An LRU list entry for this port forward is not added until
|
|
|
- // after the dial completes, but the port forward is counted
|
|
|
- // towards max limits. This means many dials in progress will
|
|
|
- // put established connections in jeopardy.
|
|
|
- //
|
|
|
- // - We're closing the oldest open connection _before_ successfully
|
|
|
- // dialing the new port forward. This means we are potentially
|
|
|
- // discarding a good connection to make way for a failed connection.
|
|
|
- // We cannot simply dial first and still maintain a limit on
|
|
|
- // resources used, so to address this we'd need to add some
|
|
|
- // accounting for connections still establishing.
|
|
|
+ // - Closed sockets will enter the TIME_WAIT state, consuming some
|
|
|
+ // resources.
|
|
|
|
|
|
sshClient.tcpPortForwardLRU.CloseOldest()
|
|
|
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "maxCount": maxCount,
|
|
|
- }).Debug("closed LRU TCP port forward")
|
|
|
+ log.WithContext().Debug("closed LRU TCP port forward")
|
|
|
}
|
|
|
|
|
|
- // Dial the target remote address. This is done in a goroutine to
|
|
|
- // ensure the shutdown signal is handled immediately.
|
|
|
-
|
|
|
- remoteAddr := net.JoinHostPort(lookupResult.IP.String(), strconv.Itoa(portToConnect))
|
|
|
-
|
|
|
- log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
-
|
|
|
- type dialTCPResult struct {
|
|
|
- conn net.Conn
|
|
|
- err error
|
|
|
- }
|
|
|
- dialResultChannel := make(chan *dialTCPResult, 1)
|
|
|
-
|
|
|
- dialStartTime := monotime.Now()
|
|
|
-
|
|
|
- go func() {
|
|
|
- // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
- conn, err := net.DialTimeout(
|
|
|
- "tcp", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
- dialResultChannel <- &dialTCPResult{conn, err}
|
|
|
- }()
|
|
|
-
|
|
|
- var dialResult *dialTCPResult
|
|
|
- select {
|
|
|
- case dialResult = <-dialResultChannel:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
- // Note: may leave Dial in progress
|
|
|
- // TODO: use net.Dialer.DialContext to be able to cancel
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- sshClient.updateQualityMetrics(
|
|
|
- dialResult.err == nil, monotime.Since(dialStartTime))
|
|
|
-
|
|
|
- if dialResult.err != nil {
|
|
|
- sshClient.rejectNewChannel(
|
|
|
- newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", dialResult.err))
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // The upstream TCP port forward connection has been established. Schedule
|
|
|
- // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
-
|
|
|
- fwdConn := dialResult.conn
|
|
|
- defer fwdConn.Close()
|
|
|
-
|
|
|
- fwdChannel, requests, err := newChannel.Accept()
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
- return
|
|
|
- }
|
|
|
- go ssh.DiscardRequests(requests)
|
|
|
- defer fwdChannel.Close()
|
|
|
+ lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
+ defer lruEntry.Remove()
|
|
|
|
|
|
// ActivityMonitoredConn monitors the TCP port forward I/O and updates
|
|
|
// its LRU status. ActivityMonitoredConn also times out I/O on the port
|
|
|
// forward if both reads and writes have been idle for the specified
|
|
|
// duration.
|
|
|
|
|
|
- lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
- defer lruEntry.Remove()
|
|
|
-
|
|
|
// Ensure nil interface if newClientSeedPortForward returns nil
|
|
|
var updater common.ActivityUpdater
|
|
|
- seedUpdater := sshClient.newClientSeedPortForward(lookupResult.IP)
|
|
|
+ seedUpdater := sshClient.newClientSeedPortForward(IP)
|
|
|
if seedUpdater != nil {
|
|
|
updater = seedUpdater
|
|
|
}
|