|
|
@@ -1697,28 +1697,6 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) isAtPortForwardLimit(
|
|
|
- portForwardType int) bool {
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
- defer sshClient.Unlock()
|
|
|
-
|
|
|
- var max int
|
|
|
- var state *trafficState
|
|
|
- if portForwardType == portForwardTypeTCP {
|
|
|
- max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
- state = &sshClient.tcpTrafficState
|
|
|
- } else {
|
|
|
- max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
- state = &sshClient.udpTrafficState
|
|
|
- }
|
|
|
-
|
|
|
- if max > 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
- return true
|
|
|
- }
|
|
|
- return false
|
|
|
-}
|
|
|
-
|
|
|
func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1757,6 +1735,56 @@ func (sshClient *sshClient) abortedTCPPortForward() {
|
|
|
sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Check if at port forward limit. The subsequent counter
|
|
|
+ // changes must be atomic with the limit check to ensure
|
|
|
+ // the counter never exceeds the limit in the case of
|
|
|
+ // concurrent allocations.
|
|
|
+
|
|
|
+ var max int
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
+
|
|
|
+ if max == 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update port forward counters.
|
|
|
+
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+
|
|
|
+ // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
+ state.concurrentDialingPortForwardCount -= 1
|
|
|
+
|
|
|
+ if sshClient.tcpPortForwardDialingAvailableSignal != nil {
|
|
|
+
|
|
|
+ max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
+ if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
|
|
|
+ sshClient.tcpPortForwardDialingAvailableSignal()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ state.concurrentPortForwardCount += 1
|
|
|
+ if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
+ state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
+ }
|
|
|
+ state.totalPortForwardCount += 1
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
// establishedPortForward increments the concurrent port
|
|
|
// forward counter. closedPortForward decrements it, so it
|
|
|
// must always be called for each establishedPortForward
|
|
|
@@ -1774,6 +1802,8 @@ func (sshClient *sshClient) abortedTCPPortForward() {
|
|
|
func (sshClient *sshClient) establishedPortForward(
|
|
|
portForwardType int, portForwardLRU *common.LRUConns) {
|
|
|
|
|
|
+ // Do not lock sshClient here.
|
|
|
+
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
@@ -1791,44 +1821,33 @@ func (sshClient *sshClient) establishedPortForward(
|
|
|
// waits for a LRU handler to be interrupted and signal
|
|
|
// availability.
|
|
|
//
|
|
|
- // Note: the port forward limit can change via a traffic
|
|
|
- // rules hot reload; the condition variable handles this
|
|
|
- // case whereas a channel-based semaphore would not.
|
|
|
+ // Notes:
|
|
|
+ //
|
|
|
+ // - the port forward limit can change via a traffic
|
|
|
+ // rules hot reload; the condition variable handles
|
|
|
+ // this case whereas a channel-based semaphore would
|
|
|
+ // not.
|
|
|
+ //
|
|
|
+ // - if a number of goroutines exceeding the total limit
|
|
|
+ // arrive here all concurrently, some CloseOldest() calls
|
|
|
+ // will have no effect as there can be less existing port
|
|
|
+ // forwards than new ones. In this case, the new port
|
|
|
+ // forward will be delayed. This is highly unlikely in
|
|
|
+ // practise since UDP calls to establishedPortForward are
|
|
|
+ // serialized and TCP calls are limited by the dial
|
|
|
+ // queue/count.
|
|
|
+
|
|
|
+ if !sshClient.allocatePortForward(portForwardType) {
|
|
|
|
|
|
- if sshClient.isAtPortForwardLimit(portForwardType) {
|
|
|
portForwardLRU.CloseOldest()
|
|
|
log.WithContext().Debug("closed LRU port forward")
|
|
|
+
|
|
|
state.availablePortForwardCond.L.Lock()
|
|
|
- for sshClient.isAtPortForwardLimit(portForwardType) {
|
|
|
+ for !sshClient.allocatePortForward(portForwardType) {
|
|
|
state.availablePortForwardCond.Wait()
|
|
|
}
|
|
|
state.availablePortForwardCond.L.Unlock()
|
|
|
}
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
-
|
|
|
- if portForwardType == portForwardTypeTCP {
|
|
|
-
|
|
|
- // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
- state.concurrentDialingPortForwardCount -= 1
|
|
|
-
|
|
|
- if sshClient.tcpPortForwardDialingAvailableSignal != nil {
|
|
|
-
|
|
|
- max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
- if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
|
|
|
- sshClient.tcpPortForwardDialingAvailableSignal()
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- state.concurrentPortForwardCount += 1
|
|
|
- if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
- state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
- }
|
|
|
- state.totalPortForwardCount += 1
|
|
|
-
|
|
|
- sshClient.Unlock()
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) closedPortForward(
|