Просмотр исходного кода

Fix: Concurrent port forward counts can exceed traffic rule limits

- Metrics show concurrent counts exceeding limits
- CloseOldest may have been keeping the number of open
  sockets at the limit, but port forward handler goroutines
  would remain running until their blocked Reads were
  eventually interrupted
- With the fix, closing the LRU now also waits for the the
  LRU handler to be interrupted
Rod Hynes 8 лет назад
Родитель
Сommit
23b8f96ad2
2 измененных файлов с 78 добавлено и 47 удалено
  1. 73 34
      psiphon/server/tunnelServer.go
  2. 5 13
      psiphon/server/udp.go

+ 73 - 34
psiphon/server/tunnelServer.go

@@ -745,6 +745,7 @@ type trafficState struct {
 	concurrentPortForwardCount            int64
 	peakConcurrentPortForwardCount        int64
 	totalPortForwardCount                 int64
+	availablePortForwardCond              *sync.Cond
 }
 
 // qualityMetrics records upstream TCP dial attempts and
@@ -771,7 +772,7 @@ func newSshClient(
 
 	runContext, stopRunning := context.WithCancel(context.Background())
 
-	return &sshClient{
+	client := &sshClient{
 		sshServer:         sshServer,
 		tunnelProtocol:    tunnelProtocol,
 		geoIPData:         geoIPData,
@@ -780,6 +781,11 @@ func newSshClient(
 		runContext:        runContext,
 		stopRunning:       stopRunning,
 	}
+
+	client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
+	client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
+
+	return client
 }
 
 func (sshClient *sshClient) run(clientConn net.Conn) {
@@ -1674,7 +1680,7 @@ func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
 	return false
 }
 
-func (sshClient *sshClient) isPortForwardLimitExceeded(
+func (sshClient *sshClient) isAtPortForwardLimit(
 	portForwardType int) bool {
 
 	sshClient.Lock()
@@ -1734,15 +1740,57 @@ func (sshClient *sshClient) abortedTCPPortForward() {
 	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
 }
 
+// establishedPortForward increments the concurrent port
+// forward counter. closedPortForward decrements it, so it
+// must always be called for each establishedPortForward
+// call.
+//
+// When at the limit of established port forwards, the LRU
+// existing port forward is closed to make way for the newly
+// established one. There can be a minor delay as, in addition
+// to calling Close() on the port forward net.Conn,
+// establishedPortForward waits for the LRU's closedPortForward()
+// call which will decrement the concurrent counter. This
+// ensures all resources associated with the LRU (socket,
+// goroutine) are released or will very soon be released before
+// proceeding.
 func (sshClient *sshClient) establishedPortForward(
-	portForwardType int) {
-
-	sshClient.Lock()
-	defer sshClient.Unlock()
+	portForwardType int, portForwardLRU *common.LRUConns) {
 
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
+	// When the maximum number of port forwards is already
+	// established, close the LRU. CloseOldest will call
+	// Close on the port forward net.Conn. Both TCP and
+	// UDP port forwards have handler goroutines that may
+	// be blocked calling Read on the net.Conn. Close will
+	// eventually interrupt the Read and cause the handlers
+	// to exit, but not immediately. So the following logic
+	// waits for a LRU handler to be interrupted and signal
+	// availability.
+	//
+	// Note: the port forward limit can change via a traffic
+	// rules hot reload; the condition variable handles this
+	// case whereas a channel-based semaphore would not.
+
+	if sshClient.isAtPortForwardLimit(portForwardType) {
+		portForwardLRU.CloseOldest()
+		log.WithContext().Debug("closed LRU port forward")
+		state.availablePortForwardCond.L.Lock()
+		for sshClient.isAtPortForwardLimit(portForwardType) {
+			state.availablePortForwardCond.Wait()
+		}
+		state.availablePortForwardCond.L.Unlock()
+	}
+
+	sshClient.Lock()
+
+	if portForwardType == portForwardTypeTCP {
 
 		// Assumes TCP port forwards called dialingTCPPortForward
 		state.concurrentDialingPortForwardCount -= 1
@@ -1755,8 +1803,6 @@ func (sshClient *sshClient) establishedPortForward(
 			}
 		}
 
-	} else {
-		state = &sshClient.udpTrafficState
 	}
 
 	state.concurrentPortForwardCount += 1
@@ -1764,13 +1810,14 @@ func (sshClient *sshClient) establishedPortForward(
 		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
 	}
 	state.totalPortForwardCount += 1
+
+	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(
 	portForwardType int, bytesUp, bytesDown int64) {
 
 	sshClient.Lock()
-	defer sshClient.Unlock()
 
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
@@ -1782,6 +1829,12 @@ func (sshClient *sshClient) closedPortForward(
 	state.concurrentPortForwardCount -= 1
 	state.bytesUp += bytesUp
 	state.bytesDown += bytesDown
+
+	sshClient.Unlock()
+
+	// Signal any goroutine waiting in establishedPortForward
+	// that an established port forward slot is available.
+	state.availablePortForwardCond.Signal()
 }
 
 func (sshClient *sshClient) updateQualityMetricsWithDialResult(
@@ -1945,10 +1998,20 @@ func (sshClient *sshClient) handleTCPChannel(
 	defer fwdChannel.Close()
 
 	// Release the dialing slot and acquire an established slot.
+	//
+	// establishedPortForward increments the concurrent TCP port
+	// forward counter and closes the LRU existing TCP port forward
+	// when already at the limit.
+	//
+	// Known limitations:
+	//
+	// - Closed LRU TCP sockets will enter the TIME_WAIT state,
+	//   continuing to consume some resources.
+
+	sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
 
 	// "established = true" cancels the deferred abortedTCPPortForward()
 	established = true
-	sshClient.establishedPortForward(portForwardTypeTCP)
 
 	// TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
 	var bytesUp, bytesDown int64
@@ -1957,30 +2020,6 @@ func (sshClient *sshClient) handleTCPChannel(
 			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
-	if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
-
-		// Close the oldest TCP port forward. CloseOldest() closes
-		// the conn and the port forward's goroutines will complete
-		// the cleanup asynchronously.
-		//
-		// Some known limitations:
-		//
-		// - Since CloseOldest() closes the upstream socket but does not
-		//   clean up all resources associated with the port forward. These
-		//   include the goroutine(s) relaying traffic as well as the SSH
-		//   channel. Closing the socket will interrupt the goroutines which
-		//   will then complete the cleanup. But, since the full cleanup is
-		//   asynchronous, there exists a possibility that a client can consume
-		//   more than max port forward resources -- just not upstream sockets.
-		//
-		// - Closed sockets will enter the TIME_WAIT state, consuming some
-		//   resources.
-
-		sshClient.tcpPortForwardLRU.CloseOldest()
-
-		log.WithContext().Debug("closed LRU TCP port forward")
-	}
-
 	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
 	defer lruEntry.Remove()
 

+ 5 - 13
psiphon/server/udp.go

@@ -171,22 +171,14 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 			// Note: UDP port forward counting has no dialing phase
 
-			mux.sshClient.establishedPortForward(portForwardTypeUDP)
+			// establishedPortForward increments the concurrent UDP port
+			// forward counter and closes the LRU existing UDP port forward
+			// when already at the limit.
+
+			mux.sshClient.establishedPortForward(portForwardTypeUDP, mux.portForwardLRU)
 			// Can't defer sshClient.closedPortForward() here;
 			// relayDownstream will call sshClient.closedPortForward()
 
-			// TOCTOU note: important to increment the port forward count (via
-			// openPortForward) _before_ checking isPortForwardLimitExceeded
-			if exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
-
-				// Close the oldest UDP port forward. CloseOldest() closes
-				// the conn and the port forward's goroutine will complete
-				// the cleanup asynchronously.
-				mux.portForwardLRU.CloseOldest()
-
-				log.WithContext().Debug("closed LRU UDP port forward")
-			}
-
 			log.WithContextFields(
 				LogFields{
 					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),