|
|
@@ -225,6 +225,14 @@ func (server *TunnelServer) SetClientHandshakeState(
|
|
|
return server.sshServer.setClientHandshakeState(sessionID, state)
|
|
|
}
|
|
|
|
|
|
+// GetClientHandshaked indicates whether the client has completed a handshake
|
|
|
+// and whether its traffic rules are immediately exhausted.
|
|
|
+func (server *TunnelServer) GetClientHandshaked(
|
|
|
+ sessionID string) (bool, bool, error) {
|
|
|
+
|
|
|
+ return server.sshServer.getClientHandshaked(sessionID)
|
|
|
+}
|
|
|
+
|
|
|
// SetEstablishTunnels sets whether new tunnels may be established or not.
|
|
|
// When not establishing, incoming connections are immediately closed.
|
|
|
func (server *TunnelServer) SetEstablishTunnels(establish bool) {
|
|
|
@@ -660,6 +668,22 @@ func (sshServer *sshServer) setClientHandshakeState(
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) getClientHandshaked(
|
|
|
+ sessionID string) (bool, bool, error) {
|
|
|
+
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ client := sshServer.clients[sessionID]
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ return false, false, common.ContextError(errors.New("unknown session ID"))
|
|
|
+ }
|
|
|
+
|
|
|
+ completed, exhausted := client.getHandshaked()
|
|
|
+
|
|
|
+ return completed, exhausted, nil
|
|
|
+}
|
|
|
+
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
@@ -1497,6 +1521,40 @@ func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// getHandshaked returns whether the client has completed a handshake API
|
|
|
+// request and whether the traffic rules that were selected after the
|
|
|
+// handshake immediately exhaust the client.
|
|
|
+//
|
|
|
+// When the client is immediately exhausted it will be closed; but this
|
|
|
+// takes effect asynchronously. The "exhausted" return value is used to
|
|
|
+// prevent API requests by clients that will close.
|
|
|
+func (sshClient *sshClient) getHandshaked() (bool, bool) {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ completed := sshClient.handshakeState.completed
|
|
|
+
|
|
|
+ exhausted := false
|
|
|
+
|
|
|
+ // Notes:
|
|
|
+ // - "Immediately exhausted" is when CloseAfterExhausted is set and
|
|
|
+ // either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
|
|
|
+ // 0, so no bytes would be read or written. This check does not
|
|
|
+ // examine whether 0 bytes _remain_ in the ThrottledConn.
|
|
|
+ // - This check is made against the current traffic rules, which
|
|
|
+ // could have changed in a hot reload since the handshake.
|
|
|
+
|
|
|
+ if completed &&
|
|
|
+ *sshClient.trafficRules.RateLimits.CloseAfterExhausted == true &&
|
|
|
+ (*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
|
|
|
+ *sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
|
|
|
+
|
|
|
+ exhausted = true
|
|
|
+ }
|
|
|
+
|
|
|
+ return completed, exhausted
|
|
|
+}
|
|
|
+
|
|
|
// setTrafficRules resets the client's traffic rules based on the latest server config
|
|
|
// and client properties. As sshClient.trafficRules may be reset by a concurrent
|
|
|
// goroutine, trafficRules must only be accessed within the sshClient mutex.
|
|
|
@@ -1697,28 +1755,6 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) isAtPortForwardLimit(
|
|
|
- portForwardType int) bool {
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
- defer sshClient.Unlock()
|
|
|
-
|
|
|
- var max int
|
|
|
- var state *trafficState
|
|
|
- if portForwardType == portForwardTypeTCP {
|
|
|
- max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
- state = &sshClient.tcpTrafficState
|
|
|
- } else {
|
|
|
- max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
- state = &sshClient.udpTrafficState
|
|
|
- }
|
|
|
-
|
|
|
- if max > 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
- return true
|
|
|
- }
|
|
|
- return false
|
|
|
-}
|
|
|
-
|
|
|
func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -1757,6 +1793,55 @@ func (sshClient *sshClient) abortedTCPPortForward() {
|
|
|
sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Check if at port forward limit. The subsequent counter
|
|
|
+ // changes must be atomic with the limit check to ensure
|
|
|
+ // the counter never exceeds the limit in the case of
|
|
|
+ // concurrent allocations.
|
|
|
+
|
|
|
+ var max int
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ max = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ max = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
+
|
|
|
+ if max > 0 && state.concurrentPortForwardCount >= int64(max) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // Update port forward counters.
|
|
|
+
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+
|
|
|
+ // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
+ state.concurrentDialingPortForwardCount -= 1
|
|
|
+
|
|
|
+ if sshClient.tcpPortForwardDialingAvailableSignal != nil {
|
|
|
+
|
|
|
+ max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
+ if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
|
|
|
+ sshClient.tcpPortForwardDialingAvailableSignal()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ state.concurrentPortForwardCount += 1
|
|
|
+ if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
+ state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
+ }
|
|
|
+ state.totalPortForwardCount += 1
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
// establishedPortForward increments the concurrent port
|
|
|
// forward counter. closedPortForward decrements it, so it
|
|
|
// must always be called for each establishedPortForward
|
|
|
@@ -1774,6 +1859,8 @@ func (sshClient *sshClient) abortedTCPPortForward() {
|
|
|
func (sshClient *sshClient) establishedPortForward(
|
|
|
portForwardType int, portForwardLRU *common.LRUConns) {
|
|
|
|
|
|
+ // Do not lock sshClient here.
|
|
|
+
|
|
|
var state *trafficState
|
|
|
if portForwardType == portForwardTypeTCP {
|
|
|
state = &sshClient.tcpTrafficState
|
|
|
@@ -1791,44 +1878,33 @@ func (sshClient *sshClient) establishedPortForward(
|
|
|
// waits for a LRU handler to be interrupted and signal
|
|
|
// availability.
|
|
|
//
|
|
|
- // Note: the port forward limit can change via a traffic
|
|
|
- // rules hot reload; the condition variable handles this
|
|
|
- // case whereas a channel-based semaphore would not.
|
|
|
+ // Notes:
|
|
|
+ //
|
|
|
+ // - the port forward limit can change via a traffic
|
|
|
+ // rules hot reload; the condition variable handles
|
|
|
+ // this case whereas a channel-based semaphore would
|
|
|
+ // not.
|
|
|
+ //
|
|
|
+ // - if a number of goroutines exceeding the total limit
|
|
|
+ // arrive here all concurrently, some CloseOldest() calls
|
|
|
+ // will have no effect as there can be less existing port
|
|
|
+ // forwards than new ones. In this case, the new port
|
|
|
+ // forward will be delayed. This is highly unlikely in
|
|
|
+ // practise since UDP calls to establishedPortForward are
|
|
|
+ // serialized and TCP calls are limited by the dial
|
|
|
+ // queue/count.
|
|
|
+
|
|
|
+ if !sshClient.allocatePortForward(portForwardType) {
|
|
|
|
|
|
- if sshClient.isAtPortForwardLimit(portForwardType) {
|
|
|
portForwardLRU.CloseOldest()
|
|
|
log.WithContext().Debug("closed LRU port forward")
|
|
|
+
|
|
|
state.availablePortForwardCond.L.Lock()
|
|
|
- for sshClient.isAtPortForwardLimit(portForwardType) {
|
|
|
+ for !sshClient.allocatePortForward(portForwardType) {
|
|
|
state.availablePortForwardCond.Wait()
|
|
|
}
|
|
|
state.availablePortForwardCond.L.Unlock()
|
|
|
}
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
-
|
|
|
- if portForwardType == portForwardTypeTCP {
|
|
|
-
|
|
|
- // Assumes TCP port forwards called dialingTCPPortForward
|
|
|
- state.concurrentDialingPortForwardCount -= 1
|
|
|
-
|
|
|
- if sshClient.tcpPortForwardDialingAvailableSignal != nil {
|
|
|
-
|
|
|
- max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
|
|
|
- if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
|
|
|
- sshClient.tcpPortForwardDialingAvailableSignal()
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- state.concurrentPortForwardCount += 1
|
|
|
- if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
- state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
- }
|
|
|
- state.totalPortForwardCount += 1
|
|
|
-
|
|
|
- sshClient.Unlock()
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) closedPortForward(
|