Просмотр исходного кода

Fix: make limit check and counter increment atomic

Without this, it was possible for many goroutines
to concurrently arrive at and pass the port forward
limit check and all proceed even if only a single
slot remained
Rod Hynes 8 лет назад
Родитель
Сommit
d5223dcadc
1 измененных файлов с 71 добавлено и 52 удалено
  1. 71 52
      psiphon/server/tunnelServer.go

+ 71 - 52
psiphon/server/tunnelServer.go

@@ -1697,28 +1697,6 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 	return false
 }
 
-func (sshClient *sshClient) isAtPortForwardLimit(
-	portForwardType int) bool {
-
-	sshClient.Lock()
-	defer sshClient.Unlock()
-
-	var max int
-	var state *trafficState
-	if portForwardType == portForwardTypeTCP {
-		max = *sshClient.trafficRules.MaxTCPPortForwardCount
-		state = &sshClient.tcpTrafficState
-	} else {
-		max = *sshClient.trafficRules.MaxUDPPortForwardCount
-		state = &sshClient.udpTrafficState
-	}
-
-	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
-		return true
-	}
-	return false
-}
-
 func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
 
 	sshClient.Lock()
@@ -1757,6 +1735,56 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
 }
 
+func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	// Check if at port forward limit. The subsequent counter
+	// changes must be atomic with the limit check to ensure
+	// the counter never exceeds the limit in the case of
+	// concurrent allocations.
+
+	var max int
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		max = *sshClient.trafficRules.MaxTCPPortForwardCount
+		state = &sshClient.tcpTrafficState
+	} else {
+		max = *sshClient.trafficRules.MaxUDPPortForwardCount
+		state = &sshClient.udpTrafficState
+	}
+
+	if max == 0 && state.concurrentPortForwardCount >= int64(max) {
+		return false
+	}
+
+	// Update port forward counters.
+
+	if portForwardType == portForwardTypeTCP {
+
+		// Assumes TCP port forwards called dialingTCPPortForward
+		state.concurrentDialingPortForwardCount -= 1
+
+		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
+
+			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
+			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
+				sshClient.tcpPortForwardDialingAvailableSignal()
+			}
+		}
+
+	}
+
+	state.concurrentPortForwardCount += 1
+	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
+		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
+	}
+	state.totalPortForwardCount += 1
+
+	return true
+}
+
 // establishedPortForward increments the concurrent port
 // forward counter. closedPortForward decrements it, so it
 // must always be called for each establishedPortForward
@@ -1774,6 +1802,8 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 func (sshClient *sshClient) establishedPortForward(
 	portForwardType int, portForwardLRU *common.LRUConns) {
 
+	// Do not lock sshClient here.
+
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
@@ -1791,44 +1821,33 @@ func (sshClient *sshClient) establishedPortForward(
 	// waits for a LRU handler to be interrupted and signal
 	// availability.
 	//
-	// Note: the port forward limit can change via a traffic
-	// rules hot reload; the condition variable handles this
-	// case whereas a channel-based semaphore would not.
+	// Notes:
+	//
+	// - the port forward limit can change via a traffic
+	//   rules hot reload; the condition variable handles
+	//   this case whereas a channel-based semaphore would
+	//   not.
+	//
+	// - if a number of goroutines exceeding the total limit
+	//   arrive here all concurrently, some CloseOldest() calls
+	//   will have no effect as there can be less existing port
+	//   forwards than new ones. In this case, the new port
+	//   forward will be delayed. This is highly unlikely in
+	//   practise since UDP calls to establishedPortForward are
+	//   serialized and TCP calls are limited by the dial
+	//   queue/count.
+
+	if !sshClient.allocatePortForward(portForwardType) {
 
-	if sshClient.isAtPortForwardLimit(portForwardType) {
 		portForwardLRU.CloseOldest()
 		log.WithContext().Debug("closed LRU port forward")
+
 		state.availablePortForwardCond.L.Lock()
-		for sshClient.isAtPortForwardLimit(portForwardType) {
+		for !sshClient.allocatePortForward(portForwardType) {
 			state.availablePortForwardCond.Wait()
 		}
 		state.availablePortForwardCond.L.Unlock()
 	}
-
-	sshClient.Lock()
-
-	if portForwardType == portForwardTypeTCP {
-
-		// Assumes TCP port forwards called dialingTCPPortForward
-		state.concurrentDialingPortForwardCount -= 1
-
-		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
-
-			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
-			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
-				sshClient.tcpPortForwardDialingAvailableSignal()
-			}
-		}
-
-	}
-
-	state.concurrentPortForwardCount += 1
-	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
-		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
-	}
-	state.totalPortForwardCount += 1
-
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(