|
|
@@ -20,6 +20,7 @@
|
|
|
package server
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"crypto/subtle"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
@@ -456,6 +457,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[tunnelProtocol]["ALL"] = make(map[string]int64)
|
|
|
protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["established_clients"] = 0
|
|
|
+ protocolStats[tunnelProtocol]["ALL"]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol]["ALL"]["udp_port_forwards"] = 0
|
|
|
@@ -472,6 +474,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[tunnelProtocol][region] = make(map[string]int64)
|
|
|
protocolStats[tunnelProtocol][region]["accepted_clients"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["established_clients"] = 0
|
|
|
+ protocolStats[tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[tunnelProtocol][region]["udp_port_forwards"] = 0
|
|
|
@@ -496,6 +499,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
protocolStats[client.tunnelProtocol][region] = make(map[string]int64)
|
|
|
protocolStats[client.tunnelProtocol][region]["accepted_clients"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["established_clients"] = 0
|
|
|
+ protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] = 0
|
|
|
protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] = 0
|
|
|
@@ -505,8 +509,10 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
|
|
|
protocolStats[client.tunnelProtocol][region]["established_clients"] += 1
|
|
|
|
|
|
+ protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
|
|
|
+ // client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
|
|
|
protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
|
|
|
protocolStats[client.tunnelProtocol][region]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
|
|
|
|
|
|
@@ -532,6 +538,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
|
|
|
allProtocolsStats := make(map[string]int64)
|
|
|
allProtocolsStats["accepted_clients"] = 0
|
|
|
allProtocolsStats["established_clients"] = 0
|
|
|
+ allProtocolsStats["dialing_tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["total_tcp_port_forwards"] = 0
|
|
|
allProtocolsStats["udp_port_forwards"] = 0
|
|
|
@@ -629,7 +636,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
sshClient.run(clientConn)
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) handlePortForwardDialError(err error) {
|
|
|
+func (sshServer *sshServer) monitorPortForwardDialError(err error) {
|
|
|
|
|
|
// "err" is the error returned from a failed TCP or UDP port
|
|
|
// forward dial. Certain system error codes indicate low resource
|
|
|
@@ -675,15 +682,18 @@ type sshClient struct {
|
|
|
tcpPortForwardLRU *common.LRUConns
|
|
|
oslClientSeedState *osl.ClientSeedState
|
|
|
signalIssueSLOKs chan struct{}
|
|
|
- stopBroadcast chan struct{}
|
|
|
+ runContext context.Context
|
|
|
+ stopRunning context.CancelFunc
|
|
|
}
|
|
|
|
|
|
type trafficState struct {
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- concurrentPortForwardCount int64
|
|
|
- peakConcurrentPortForwardCount int64
|
|
|
- totalPortForwardCount int64
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ concurrentDialingPortForwardCount int64
|
|
|
+ peakConcurrentDialingPortForwardCount int64
|
|
|
+ concurrentPortForwardCount int64
|
|
|
+ peakConcurrentPortForwardCount int64
|
|
|
+ totalPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
// qualityMetrics records upstream TCP dial attempts and
|
|
|
@@ -706,13 +716,17 @@ type handshakeState struct {
|
|
|
|
|
|
func newSshClient(
|
|
|
sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
+
|
|
|
+ runContext, stopRunning := context.WithCancel(context.Background())
|
|
|
+
|
|
|
return &sshClient{
|
|
|
sshServer: sshServer,
|
|
|
tunnelProtocol: tunnelProtocol,
|
|
|
geoIPData: geoIPData,
|
|
|
tcpPortForwardLRU: common.NewLRUConns(),
|
|
|
signalIssueSLOKs: make(chan struct{}, 1),
|
|
|
- stopBroadcast: make(chan struct{}),
|
|
|
+ runContext: runContext,
|
|
|
+ stopRunning: stopRunning,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1019,7 +1033,7 @@ func (sshClient *sshClient) runTunnel(
|
|
|
// The channel loop is interrupted by a client
|
|
|
// disconnect or by calling sshClient.stop().
|
|
|
|
|
|
- close(sshClient.stopBroadcast)
|
|
|
+ sshClient.stopRunning()
|
|
|
|
|
|
waitGroup.Wait()
|
|
|
}
|
|
|
@@ -1049,10 +1063,12 @@ func (sshClient *sshClient) logTunnel() {
|
|
|
logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
|
|
|
logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
|
|
|
logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
|
|
|
+ logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
|
|
|
logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
|
|
|
logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
|
|
|
logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
|
|
|
logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
|
|
|
+ // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
|
|
|
logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
|
|
|
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
|
|
|
|
|
|
@@ -1068,7 +1084,7 @@ func (sshClient *sshClient) runOSLSender() {
|
|
|
// TODO: use reflect.SelectCase, and optionally await timer here?
|
|
|
select {
|
|
|
case <-sshClient.signalIssueSLOKs:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
+ case <-sshClient.runContext.Done():
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -1086,7 +1102,7 @@ func (sshClient *sshClient) runOSLSender() {
|
|
|
select {
|
|
|
case <-retryTimer.C:
|
|
|
case <-sshClient.signalIssueSLOKs:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
+ case <-sshClient.runContext.Done():
|
|
|
retryTimer.Stop()
|
|
|
return
|
|
|
}
|
|
|
@@ -1363,29 +1379,64 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ state := &sshClient.tcpTrafficState
|
|
|
+ max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
+
|
|
|
+ if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
- portForwardType int) (int, bool) {
|
|
|
+ portForwardType int) bool {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
- var maxPortForwardCount int
|
|
|
+ var max int
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
- maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+ max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
} else {
|
|
|
- maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
+ max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
state = &sshClient.udpTrafficState
|
|
|
}
|
|
|
|
|
|
- if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
|
|
|
- return maxPortForwardCount, true
|
|
|
+ if max > 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) dialingTCPPortForward() {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ state := &sshClient.tcpTrafficState
|
|
|
+
|
|
|
+ state.concurrentDialingPortForwardCount += 1
|
|
|
+ if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
|
|
|
+ state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
|
|
|
}
|
|
|
- return maxPortForwardCount, false
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) openedPortForward(
|
|
|
+func (sshClient *sshClient) failedTCPPortForward() {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) establishedPortForward(
|
|
|
portForwardType int) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1394,6 +1445,10 @@ func (sshClient *sshClient) openedPortForward(
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
+
|
|
|
+ // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
+ state.concurrentDialingPortForwardCount -= 1
|
|
|
+
|
|
|
} else {
|
|
|
state = &sshClient.udpTrafficState
|
|
|
}
|
|
|
@@ -1405,22 +1460,6 @@ func (sshClient *sshClient) openedPortForward(
|
|
|
state.totalPortForwardCount += 1
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) updateQualityMetrics(
|
|
|
- tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
- defer sshClient.Unlock()
|
|
|
-
|
|
|
- if tcpPortForwardDialSuccess {
|
|
|
- sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
|
|
|
- sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
|
|
|
-
|
|
|
- } else {
|
|
|
- sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
|
|
|
- sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func (sshClient *sshClient) closedPortForward(
|
|
|
portForwardType int, bytesUp, bytesDown int64) {
|
|
|
|
|
|
@@ -1439,11 +1478,78 @@ func (sshClient *sshClient) closedPortForward(
|
|
|
state.bytesDown += bytesDown
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) updateQualityMetrics(
|
|
|
+ tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ if tcpPortForwardDialSuccess {
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
|
|
|
+
|
|
|
+ } else {
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
|
|
|
+ sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) handleTCPChannel(
|
|
|
hostToConnect string,
|
|
|
portToConnect int,
|
|
|
newChannel ssh.NewChannel) {
|
|
|
|
|
|
+ // Lifecycle of a TCP port forward:
|
|
|
+ //
|
|
|
+ // 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
|
|
|
+ //
|
|
|
+ // 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
|
|
|
+ // limit on concurrentDialingPortForwardCount. If the limit is exceeded, the port
|
|
|
+ // forward is _immediately rejected_; this hard limit applies back pressure when
|
|
|
+ // upstream network resources are impaired.
|
|
|
+ //
|
|
|
+ // TOCTOU note: important to increment the port forward counts _before_
|
|
|
+ // checking limits; otherwise, the client could potentially consume excess
|
|
|
+ // resources by initiating many port forwards concurrently.
|
|
|
+ //
|
|
|
+ // 3. Dial the target.
|
|
|
+ //
|
|
|
+ // 4. If the dial fails, call failedTCPPortForward() to decrement
|
|
|
+ // concurrentDialingPortForwardCount, freeing up a dial slot.
|
|
|
+ //
|
|
|
+ // 5. If the dial succeeds, call establishedPortForward(), which decrements
|
|
|
+ // concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
|
|
|
+ // the "established" port forward count.
|
|
|
+ //
|
|
|
+ // 6. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
|
|
|
+ // concurrentPortForwardCount, the number of _established_ TCP port forwards.
|
|
|
+ // If the limit is exceeded, the LRU established TCP port forward is closed and
|
|
|
+ // the newly established TCP port forward proceeds. This LRU logic allows some
|
|
|
+ // dangling resource consumption (e.g., TIME_WAIT) while providing a better
|
|
|
+ // experience for clients.
|
|
|
+ //
|
|
|
+ // 7. Relay data.
|
|
|
+ //
|
|
|
+ // 8. Call closedPortForward() which decrements concurrentPortForwardCount and
|
|
|
+ // records bytes transferred.
|
|
|
+
|
|
|
+ sshClient.dialingTCPPortForward()
|
|
|
+ established := false
|
|
|
+ defer func() {
|
|
|
+ if !established {
|
|
|
+ sshClient.failedTCPPortForward()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ if exceeded := sshClient.isTCPDialingPortForwardLimitExceeded(); exceeded {
|
|
|
+
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Transparently redirect web API request connections.
|
|
|
+
|
|
|
isWebServerPortForward := false
|
|
|
config := sshClient.sshServer.support.Config
|
|
|
if config.WebServerPortForwardAddress != "" {
|
|
|
@@ -1460,70 +1566,112 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- type lookupIPResult struct {
|
|
|
- IP net.IP
|
|
|
- err error
|
|
|
- }
|
|
|
- lookupResultChannel := make(chan *lookupIPResult, 1)
|
|
|
+ // Dial the remote address.
|
|
|
+ //
|
|
|
+ // Hostname resolution is performed explicitly, as a seperate step, as the target IP
|
|
|
+ // address is used for traffic rules (AllowSubnets) and OSL seed progress.
|
|
|
+ //
|
|
|
+ // Contexts are used for cancellation (via sshClient.runContext, which is cancelled
|
|
|
+ // when the client is stopping) and timeouts.
|
|
|
|
|
|
- go func() {
|
|
|
- // TODO: explicit timeout for DNS resolution?
|
|
|
- IPs, err := net.LookupIP(hostToConnect)
|
|
|
- // TODO: shuffle list to try other IPs
|
|
|
- // TODO: IPv6 support
|
|
|
- var IP net.IP
|
|
|
- for _, ip := range IPs {
|
|
|
- if ip.To4() != nil {
|
|
|
- IP = ip
|
|
|
- }
|
|
|
- }
|
|
|
- if err == nil && IP == nil {
|
|
|
- err = errors.New("no IP address")
|
|
|
- }
|
|
|
- lookupResultChannel <- &lookupIPResult{IP, err}
|
|
|
- }()
|
|
|
+ dialStartTime := monotime.Now()
|
|
|
|
|
|
- var lookupResult *lookupIPResult
|
|
|
- select {
|
|
|
- case lookupResult = <-lookupResultChannel:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
- // Note: may leave LookupIP in progress
|
|
|
- return
|
|
|
+ log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
|
|
|
+
|
|
|
+ ctx, _ := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
|
|
|
+ IPs, err := (&net.Resolver{}).LookupIPAddr(
|
|
|
+ ctx, hostToConnect)
|
|
|
+
|
|
|
+ // TODO: shuffle list to try other IPs
|
|
|
+ // TODO: IPv6 support
|
|
|
+ var IP net.IP
|
|
|
+ for _, ip := range IPs {
|
|
|
+ if ip.IP.To4() != nil {
|
|
|
+ IP = ip.IP
|
|
|
+ }
|
|
|
}
|
|
|
+ if err == nil && IP == nil {
|
|
|
+ err = errors.New("no IP address")
|
|
|
+ }
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ // Record a port forward failure
|
|
|
+ sshClient.updateQualityMetrics(true, monotime.Since(dialStartTime))
|
|
|
|
|
|
- if lookupResult.err != nil {
|
|
|
sshClient.rejectNewChannel(
|
|
|
- newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", lookupResult.err))
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // Enforce traffic rules, using the resolved IP address.
|
|
|
+
|
|
|
if !isWebServerPortForward &&
|
|
|
!sshClient.isPortForwardPermitted(
|
|
|
portForwardTypeTCP,
|
|
|
false,
|
|
|
- lookupResult.IP,
|
|
|
+ IP,
|
|
|
portToConnect) {
|
|
|
|
|
|
+ // Note: not recording a port forward failure in this case
|
|
|
+
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.Prohibited, "port forward not permitted")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // TCP dial.
|
|
|
+
|
|
|
+ remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
|
|
|
+
|
|
|
+ log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
+
|
|
|
+ ctx, _ = context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
+ fwdConn, err := (&net.Dialer{}).DialContext(
|
|
|
+ ctx, "tcp", remoteAddr)
|
|
|
+
|
|
|
+ // Record port forward success or failure
|
|
|
+ sshClient.updateQualityMetrics(err == nil, monotime.Since(dialStartTime))
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ // Monitor for low resource error conditions
|
|
|
+ sshClient.sshServer.monitorPortForwardDialError(err)
|
|
|
+
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // The upstream TCP port forward connection has been established. Schedule
|
|
|
+ // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
+
|
|
|
+ defer fwdConn.Close()
|
|
|
+
|
|
|
+ fwdChannel, requests, err := newChannel.Accept()
|
|
|
+ if err != nil {
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ go ssh.DiscardRequests(requests)
|
|
|
+ defer fwdChannel.Close()
|
|
|
+
|
|
|
+ // Release the dialing slot and acquire an established slot.
|
|
|
+
|
|
|
+ // "established = true" cancels the deferred failedTCPPortForward()
|
|
|
+ established = true
|
|
|
+ sshClient.establishedPortForward(portForwardTypeTCP)
|
|
|
+
|
|
|
var bytesUp, bytesDown int64
|
|
|
- sshClient.openedPortForward(portForwardTypeTCP)
|
|
|
defer func() {
|
|
|
sshClient.closedPortForward(
|
|
|
portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
|
|
|
}()
|
|
|
|
|
|
- // TOCTOU note: important to increment the port forward count (via
|
|
|
- // openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
|
- // otherwise, the client could potentially consume excess resources
|
|
|
- // by initiating many port forwards concurrently.
|
|
|
- if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
|
|
|
+ if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
|
|
|
|
|
|
// Close the oldest TCP port forward. CloseOldest() closes
|
|
|
- // the conn and the port forward's goroutine will complete
|
|
|
+ // the conn and the port forward's goroutines will complete
|
|
|
// the cleanup asynchronously.
|
|
|
//
|
|
|
// Some known limitations:
|
|
|
@@ -1536,91 +1684,25 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// asynchronous, there exists a possibility that a client can consume
|
|
|
// more than max port forward resources -- just not upstream sockets.
|
|
|
//
|
|
|
- // - An LRU list entry for this port forward is not added until
|
|
|
- // after the dial completes, but the port forward is counted
|
|
|
- // towards max limits. This means many dials in progress will
|
|
|
- // put established connections in jeopardy.
|
|
|
- //
|
|
|
- // - We're closing the oldest open connection _before_ successfully
|
|
|
- // dialing the new port forward. This means we are potentially
|
|
|
- // discarding a good connection to make way for a failed connection.
|
|
|
- // We cannot simply dial first and still maintain a limit on
|
|
|
- // resources used, so to address this we'd need to add some
|
|
|
- // accounting for connections still establishing.
|
|
|
+ // - Closed sockets will enter the TIME_WAIT state, consuming some
|
|
|
+ // resources.
|
|
|
|
|
|
sshClient.tcpPortForwardLRU.CloseOldest()
|
|
|
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "maxCount": maxCount,
|
|
|
- }).Debug("closed LRU TCP port forward")
|
|
|
- }
|
|
|
-
|
|
|
- // Dial the target remote address. This is done in a goroutine to
|
|
|
- // ensure the shutdown signal is handled immediately.
|
|
|
-
|
|
|
- remoteAddr := net.JoinHostPort(lookupResult.IP.String(), strconv.Itoa(portToConnect))
|
|
|
-
|
|
|
- log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
-
|
|
|
- type dialTCPResult struct {
|
|
|
- conn net.Conn
|
|
|
- err error
|
|
|
- }
|
|
|
- dialResultChannel := make(chan *dialTCPResult, 1)
|
|
|
-
|
|
|
- dialStartTime := monotime.Now()
|
|
|
-
|
|
|
- go func() {
|
|
|
- conn, err := net.DialTimeout(
|
|
|
- "tcp", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
- dialResultChannel <- &dialTCPResult{conn, err}
|
|
|
- }()
|
|
|
-
|
|
|
- var dialResult *dialTCPResult
|
|
|
- select {
|
|
|
- case dialResult = <-dialResultChannel:
|
|
|
- case <-sshClient.stopBroadcast:
|
|
|
- // Note: may leave Dial in progress
|
|
|
- // TODO: use net.Dialer.DialContext to be able to cancel
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- sshClient.updateQualityMetrics(
|
|
|
- dialResult.err == nil, monotime.Since(dialStartTime))
|
|
|
-
|
|
|
- if dialResult.err != nil {
|
|
|
- sshClient.sshServer.handlePortForwardDialError(dialResult.err)
|
|
|
- sshClient.rejectNewChannel(
|
|
|
- newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", dialResult.err))
|
|
|
- return
|
|
|
+ log.WithContext().Debug("closed LRU TCP port forward")
|
|
|
}
|
|
|
|
|
|
- // The upstream TCP port forward connection has been established. Schedule
|
|
|
- // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
-
|
|
|
- fwdConn := dialResult.conn
|
|
|
- defer fwdConn.Close()
|
|
|
-
|
|
|
- fwdChannel, requests, err := newChannel.Accept()
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
- return
|
|
|
- }
|
|
|
- go ssh.DiscardRequests(requests)
|
|
|
- defer fwdChannel.Close()
|
|
|
+ lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
+ defer lruEntry.Remove()
|
|
|
|
|
|
// ActivityMonitoredConn monitors the TCP port forward I/O and updates
|
|
|
// its LRU status. ActivityMonitoredConn also times out I/O on the port
|
|
|
// forward if both reads and writes have been idle for the specified
|
|
|
// duration.
|
|
|
|
|
|
- lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
- defer lruEntry.Remove()
|
|
|
-
|
|
|
// Ensure nil interface if newClientSeedPortForward returns nil
|
|
|
var updater common.ActivityUpdater
|
|
|
- seedUpdater := sshClient.newClientSeedPortForward(lookupResult.IP)
|
|
|
+ seedUpdater := sshClient.newClientSeedPortForward(IP)
|
|
|
if seedUpdater != nil {
|
|
|
updater = seedUpdater
|
|
|
}
|