Просмотр исходного кода

Improve TCP port forward handling

- add distinct "dialing" phase for TCP port forwards
- add traffic rule limit for "dialing" TCP port forwards
- when "dialing" limit is exceeded, port forward is
  rejected; "established" phase retains LRU logic
- report stats for "dialing" TCP port forwards
- explicit timeout for resolving port forward hostnames
- use context.Context to simplify cancellable dial code
- resolve/dial no longer left dangling when client stops
Rod Hynes 9 лет назад
Родитель
Сommit
e10ce6a784
3 измененных файлов с 265 добавлено и 161 удалено
  1. 24 3
      psiphon/server/trafficRules.go
  2. 226 144
      psiphon/server/tunnelServer.go
  3. 15 14
      psiphon/server/udp.go

+ 24 - 3
psiphon/server/trafficRules.go

@@ -30,6 +30,7 @@ import (
 const (
 	DEFAULT_IDLE_TCP_PORT_FORWARD_TIMEOUT_MILLISECONDS = 30000
 	DEFAULT_IDLE_UDP_PORT_FORWARD_TIMEOUT_MILLISECONDS = 30000
+	DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT         = 64
 	DEFAULT_MAX_TCP_PORT_FORWARD_COUNT                 = 512
 	DEFAULT_MAX_UDP_PORT_FORWARD_COUNT                 = 32
 )
@@ -106,14 +107,25 @@ type TrafficRules struct {
 	// is used.
 	IdleUDPPortForwardTimeoutMilliseconds *int
 
-	// MaxTCPPortForwardCount is the maximum number of TCP port
-	// forwards each client may have open concurrently.
+	// MaxTCPDialingPortForwardCount is the maximum number of dialing
+	// TCP port forwards each client may have open concurrently. When
+	// at the limit, new TCP port forwards are rejected.
+	// A value of 0 specifies no maximum. When omitted in
+	// DefaultRules, DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT is used.
+	MaxTCPDialingPortForwardCount *int
+
+	// MaxTCPPortForwardCount is the maximum number of established TCP
+	// port forwards each client may have open concurrently. If at the
+	// limit when a new TCP port forward is established, the LRU
+	// established TCP port forward is closed.
 	// A value of 0 specifies no maximum. When omitted in
 	// DefaultRules, DEFAULT_MAX_TCP_PORT_FORWARD_COUNT is used.
 	MaxTCPPortForwardCount *int
 
 	// MaxUDPPortForwardCount is the maximum number of UDP port
-	// forwards each client may have open concurrently.
+	// forwards each client may have open concurrently. If at the
+	// limit when a new UDP port forward is created, the LRU
+	// UDP port forward is closed.
 	// A value of 0 specifies no maximum. When omitted in
 	// DefaultRules, DEFAULT_MAX_UDP_PORT_FORWARD_COUNT is used.
 	MaxUDPPortForwardCount *int
@@ -299,6 +311,11 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_IDLE_UDP_PORT_FORWARD_TIMEOUT_MILLISECONDS)
 	}
 
+	if trafficRules.MaxTCPDialingPortForwardCount == nil {
+		trafficRules.MaxTCPDialingPortForwardCount =
+			intPtr(DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT)
+	}
+
 	if trafficRules.MaxTCPPortForwardCount == nil {
 		trafficRules.MaxTCPPortForwardCount =
 			intPtr(DEFAULT_MAX_TCP_PORT_FORWARD_COUNT)
@@ -393,6 +410,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.IdleUDPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds
 		}
 
+		if filteredRules.Rules.MaxTCPDialingPortForwardCount != nil {
+			trafficRules.MaxTCPDialingPortForwardCount = filteredRules.Rules.MaxTCPDialingPortForwardCount
+		}
+
 		if filteredRules.Rules.MaxTCPPortForwardCount != nil {
 			trafficRules.MaxTCPPortForwardCount = filteredRules.Rules.MaxTCPPortForwardCount
 		}

+ 226 - 144
psiphon/server/tunnelServer.go

@@ -20,6 +20,7 @@
 package server
 
 import (
+	"context"
 	"crypto/subtle"
 	"encoding/json"
 	"errors"
@@ -456,6 +457,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 		protocolStats[tunnelProtocol]["ALL"] = make(map[string]int64)
 		protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["established_clients"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["dialing_tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["total_tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["udp_port_forwards"] = 0
@@ -472,6 +474,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 					protocolStats[tunnelProtocol][region] = make(map[string]int64)
 					protocolStats[tunnelProtocol][region]["accepted_clients"] = 0
 					protocolStats[tunnelProtocol][region]["established_clients"] = 0
+					protocolStats[tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["total_tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["udp_port_forwards"] = 0
@@ -496,6 +499,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 				protocolStats[client.tunnelProtocol][region] = make(map[string]int64)
 				protocolStats[client.tunnelProtocol][region]["accepted_clients"] = 0
 				protocolStats[client.tunnelProtocol][region]["established_clients"] = 0
+				protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] = 0
@@ -505,8 +509,10 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 			// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
 			protocolStats[client.tunnelProtocol][region]["established_clients"] += 1
 
+			protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
+			// client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
 			protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
 
@@ -532,6 +538,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 	allProtocolsStats := make(map[string]int64)
 	allProtocolsStats["accepted_clients"] = 0
 	allProtocolsStats["established_clients"] = 0
+	allProtocolsStats["dialing_tcp_port_forwards"] = 0
 	allProtocolsStats["tcp_port_forwards"] = 0
 	allProtocolsStats["total_tcp_port_forwards"] = 0
 	allProtocolsStats["udp_port_forwards"] = 0
@@ -629,7 +636,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	sshClient.run(clientConn)
 }
 
-func (sshServer *sshServer) handlePortForwardDialError(err error) {
+func (sshServer *sshServer) monitorPortForwardDialError(err error) {
 
 	// "err" is the error returned from a failed TCP or UDP port
 	// forward dial. Certain system error codes indicate low resource
@@ -675,15 +682,18 @@ type sshClient struct {
 	tcpPortForwardLRU      *common.LRUConns
 	oslClientSeedState     *osl.ClientSeedState
 	signalIssueSLOKs       chan struct{}
-	stopBroadcast          chan struct{}
+	runContext             context.Context
+	stopRunning            context.CancelFunc
 }
 
+// trafficState holds the byte totals and port forward counters tracked for
+// one port forward type; sshClient keeps separate TCP and UDP instances.
+// The "dialing" counters are only meaningful for TCP port forwards.
 type trafficState struct {
-	bytesUp                        int64
-	bytesDown                      int64
-	concurrentPortForwardCount     int64
-	peakConcurrentPortForwardCount int64
-	totalPortForwardCount          int64
+	bytesUp                               int64
+	bytesDown                             int64
+	concurrentDialingPortForwardCount     int64
+	peakConcurrentDialingPortForwardCount int64
+	concurrentPortForwardCount            int64
+	peakConcurrentPortForwardCount        int64
+	totalPortForwardCount                 int64
 }
 
 // qualityMetrics records upstream TCP dial attempts and
@@ -706,13 +716,17 @@ type handshakeState struct {
 
+// newSshClient creates the per-client state for a newly accepted client,
+// including a cancellable run context.
 func newSshClient(
 	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
+
+	// runContext is cancelled by calling stopRunning; goroutines and
+	// context-aware operations select on runContext.Done() to stop.
+	runContext, stopRunning := context.WithCancel(context.Background())
+
 	return &sshClient{
 		sshServer:         sshServer,
 		tunnelProtocol:    tunnelProtocol,
 		geoIPData:         geoIPData,
 		tcpPortForwardLRU: common.NewLRUConns(),
 		signalIssueSLOKs:  make(chan struct{}, 1),
-		stopBroadcast:     make(chan struct{}),
+		runContext:        runContext,
+		stopRunning:       stopRunning,
 	}
 }
 
@@ -1019,7 +1033,7 @@ func (sshClient *sshClient) runTunnel(
 	// The channel loop is interrupted by a client
 	// disconnect or by calling sshClient.stop().
 
-	close(sshClient.stopBroadcast)
+	sshClient.stopRunning()
 
 	waitGroup.Wait()
 }
@@ -1049,10 +1063,12 @@ func (sshClient *sshClient) logTunnel() {
 	logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
 	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
 	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
+	logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
 	logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
 	logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
 	logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
+	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
 
@@ -1068,7 +1084,7 @@ func (sshClient *sshClient) runOSLSender() {
 		// TODO: use reflect.SelectCase, and optionally await timer here?
 		select {
 		case <-sshClient.signalIssueSLOKs:
-		case <-sshClient.stopBroadcast:
+		case <-sshClient.runContext.Done():
 			return
 		}
 
@@ -1086,7 +1102,7 @@ func (sshClient *sshClient) runOSLSender() {
 			select {
 			case <-retryTimer.C:
 			case <-sshClient.signalIssueSLOKs:
-			case <-sshClient.stopBroadcast:
+			case <-sshClient.runContext.Done():
 				retryTimer.Stop()
 				return
 			}
@@ -1363,29 +1379,64 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	return false
 }
 
+// isTCPDialingPortForwardLimitExceeded reports whether the number of TCP
+// port forwards currently in the dialing phase has reached the
+// MaxTCPDialingPortForwardCount traffic rule. A rule value of 0 means no
+// limit is enforced. The client lock is held while the count is read.
+func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	limit := int64(*sshClient.trafficRules.MaxTCPDialingPortForwardCount)
+	dialing := sshClient.tcpTrafficState.concurrentDialingPortForwardCount
+
+	return limit > 0 && dialing >= limit
+}
+
+// isPortForwardLimitExceeded reports whether the count of concurrent port
+// forwards of the given type (TCP or UDP) has reached the corresponding
+// MaxTCPPortForwardCount/MaxUDPPortForwardCount traffic rule. A rule value
+// of 0 means no limit is enforced.
 func (sshClient *sshClient) isPortForwardLimitExceeded(
-	portForwardType int) (int, bool) {
+	portForwardType int) bool {
 
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	var maxPortForwardCount int
+	var max int
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
-		maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
+		max = *sshClient.trafficRules.MaxTCPPortForwardCount
 		state = &sshClient.tcpTrafficState
 	} else {
-		maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
+		max = *sshClient.trafficRules.MaxUDPPortForwardCount
 		state = &sshClient.udpTrafficState
 	}
 
-	if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
-		return maxPortForwardCount, true
+	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
+		return true
+	}
+	return false
+}
+
+// dialingTCPPortForward marks the start of a TCP port forward dial: it
+// increments the concurrent dialing count and raises the recorded peak
+// when the new count exceeds it. The client lock is held for the update.
+func (sshClient *sshClient) dialingTCPPortForward() {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	state := &sshClient.tcpTrafficState
+
+	state.concurrentDialingPortForwardCount += 1
+	if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
+		state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
 	}
-	return maxPortForwardCount, false
 }
 
-func (sshClient *sshClient) openedPortForward(
+// failedTCPPortForward releases one dialing slot by decrementing the
+// concurrent dialing TCP port forward count under the client lock.
+func (sshClient *sshClient) failedTCPPortForward() {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	state := &sshClient.tcpTrafficState
+	state.concurrentDialingPortForwardCount--
+}
+
+func (sshClient *sshClient) establishedPortForward(
 	portForwardType int) {
 
 	sshClient.Lock()
@@ -1394,6 +1445,10 @@ func (sshClient *sshClient) openedPortForward(
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
+
+		// Assumes TCP port forwards called dialingTCPPortForward
+		state.concurrentDialingPortForwardCount -= 1
+
 	} else {
 		state = &sshClient.udpTrafficState
 	}
@@ -1405,22 +1460,6 @@ func (sshClient *sshClient) openedPortForward(
 	state.totalPortForwardCount += 1
 }
 
-func (sshClient *sshClient) updateQualityMetrics(
-	tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
-
-	sshClient.Lock()
-	defer sshClient.Unlock()
-
-	if tcpPortForwardDialSuccess {
-		sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
-		sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
-
-	} else {
-		sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
-		sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
-	}
-}
-
 func (sshClient *sshClient) closedPortForward(
 	portForwardType int, bytesUp, bytesDown int64) {
 
@@ -1439,11 +1478,78 @@ func (sshClient *sshClient) closedPortForward(
 	state.bytesDown += bytesDown
 }
 
+// updateQualityMetrics records the outcome of one TCP port forward dial
+// attempt. A successful dial adds to the dialed count and duration; a
+// failed dial adds to the failed count and duration. The client lock is
+// held for the update.
+func (sshClient *sshClient) updateQualityMetrics(
+	tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	if !tcpPortForwardDialSuccess {
+		sshClient.qualityMetrics.tcpPortForwardFailedCount++
+		sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
+		return
+	}
+
+	sshClient.qualityMetrics.tcpPortForwardDialedCount++
+	sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
+}
+
 func (sshClient *sshClient) handleTCPChannel(
 	hostToConnect string,
 	portToConnect int,
 	newChannel ssh.NewChannel) {
 
+	// Lifecycle of a TCP port forward:
+	//
+	// 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
+	//
+	// 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
+	//    limit on concurrentDialingPortForwardCount. If the limit is exceeded, the port
+	//    forward is _immediately rejected_; this hard limit applies back pressure when
+	//    upstream network resources are impaired.
+	//
+	//        TOCTOU note: important to increment the port forward counts _before_
+	//        checking limits; otherwise, the client could potentially consume excess
+	//        resources by initiating many port forwards concurrently.
+	//
+	// 3. Dial the target.
+	//
+	// 4. If the dial fails, call failedTCPPortForward() to decrement
+	//    concurrentDialingPortForwardCount, freeing up a dial slot.
+	//
+	// 5. If the dial succeeds, call establishedPortForward(), which decrements
+	//    concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
+	//    the "established" port forward count.
+	//
+	// 6. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
+	//    concurrentPortForwardCount, the number of _established_ TCP port forwards.
+	//    If the limit is exceeded, the LRU established TCP port forward is closed and
+	//    the newly established TCP port forward proceeds. This LRU logic allows some
+	//    dangling resource consumption (e.g., TIME_WAIT) while providing a better
+	//    experience for clients.
+	//
+	// 7. Relay data.
+	//
+	// 8. Call closedPortForward() which decrements concurrentPortForwardCount and
+	//    records bytes transferred.
+
+	sshClient.dialingTCPPortForward()
+	established := false
+	defer func() {
+		if !established {
+			sshClient.failedTCPPortForward()
+		}
+	}()
+
+	if exceeded := sshClient.isTCPDialingPortForwardLimitExceeded(); exceeded {
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
+		return
+	}
+
+	// Transparently redirect web API request connections.
+
 	isWebServerPortForward := false
 	config := sshClient.sshServer.support.Config
 	if config.WebServerPortForwardAddress != "" {
@@ -1460,70 +1566,112 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 	}
 
-	type lookupIPResult struct {
-		IP  net.IP
-		err error
-	}
-	lookupResultChannel := make(chan *lookupIPResult, 1)
+	// Dial the remote address.
+	//
+	// Hostname resolution is performed explicitly, as a separate step, as the target IP
+	// address is used for traffic rules (AllowSubnets) and OSL seed progress.
+	//
+	// Contexts are used for cancellation (via sshClient.runContext, which is cancelled
+	// when the client is stopping) and timeouts.
 
-	go func() {
-		// TODO: explicit timeout for DNS resolution?
-		IPs, err := net.LookupIP(hostToConnect)
-		// TODO: shuffle list to try other IPs
-		// TODO: IPv6 support
-		var IP net.IP
-		for _, ip := range IPs {
-			if ip.To4() != nil {
-				IP = ip
-			}
-		}
-		if err == nil && IP == nil {
-			err = errors.New("no IP address")
-		}
-		lookupResultChannel <- &lookupIPResult{IP, err}
-	}()
+	dialStartTime := monotime.Now()
 
-	var lookupResult *lookupIPResult
-	select {
-	case lookupResult = <-lookupResultChannel:
-	case <-sshClient.stopBroadcast:
-		// Note: may leave LookupIP in progress
-		return
+	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
+
+	// The CancelFunc returned by WithTimeout must always be called;
+	// otherwise the context's timer and its registration with runContext
+	// are retained until the timeout fires. The lookup is synchronous, so
+	// the context is released as soon as it returns.
+	ctx, cancelLookup := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
+	IPs, err := (&net.Resolver{}).LookupIPAddr(
+		ctx, hostToConnect)
+	cancelLookup()
+
+	// TODO: shuffle list to try other IPs
+	// TODO: IPv6 support
+	var IP net.IP
+	for _, ip := range IPs {
+		if ip.IP.To4() != nil {
+			IP = ip.IP
+		}
 	}
+	if err == nil && IP == nil {
+		err = errors.New("no IP address")
+	}
+
+	if err != nil {
+
+		// Record a port forward failure: a failed resolve counts as a
+		// failed dial, so the success flag must be false.
+		sshClient.updateQualityMetrics(false, monotime.Since(dialStartTime))
 
-	if lookupResult.err != nil {
 		sshClient.rejectNewChannel(
-			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", lookupResult.err))
+			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
 		return
 	}
 
+	// Enforce traffic rules, using the resolved IP address.
+
 	if !isWebServerPortForward &&
 		!sshClient.isPortForwardPermitted(
 			portForwardTypeTCP,
 			false,
-			lookupResult.IP,
+			IP,
 			portToConnect) {
 
+		// Note: not recording a port forward failure in this case
+
 		sshClient.rejectNewChannel(
 			newChannel, ssh.Prohibited, "port forward not permitted")
 		return
 	}
 
+	// TCP dial.
+
+	remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
+
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
+
+	// The CancelFunc returned by WithTimeout must always be called to
+	// release the context's timer. Cancelling after DialContext returns
+	// does not affect a connection that was successfully established.
+	ctx, cancelDial := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
+	fwdConn, err := (&net.Dialer{}).DialContext(
+		ctx, "tcp", remoteAddr)
+	cancelDial()
+
+	// Record port forward success or failure
+	sshClient.updateQualityMetrics(err == nil, monotime.Since(dialStartTime))
+
+	if err != nil {
+
+		// Monitor for low resource error conditions
+		sshClient.sshServer.monitorPortForwardDialError(err)
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
+		return
+	}
+
+	// The upstream TCP port forward connection has been established. Schedule
+	// some cleanup and notify the SSH client that the channel is accepted.
+
+	defer fwdConn.Close()
+
+	fwdChannel, requests, err := newChannel.Accept()
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
+		return
+	}
+	go ssh.DiscardRequests(requests)
+	defer fwdChannel.Close()
+
+	// Release the dialing slot and acquire an established slot.
+
+	// "established = true" cancels the deferred failedTCPPortForward()
+	established = true
+	sshClient.establishedPortForward(portForwardTypeTCP)
+
 	var bytesUp, bytesDown int64
-	sshClient.openedPortForward(portForwardTypeTCP)
 	defer func() {
 		sshClient.closedPortForward(
 			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
-	// TOCTOU note: important to increment the port forward count (via
-	// openPortForward) _before_ checking isPortForwardLimitExceeded
-	// otherwise, the client could potentially consume excess resources
-	// by initiating many port forwards concurrently.
-	if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
+	if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
 
 		// Close the oldest TCP port forward. CloseOldest() closes
-		// the conn and the port forward's goroutine will complete
+		// the conn and the port forward's goroutines will complete
 		// the cleanup asynchronously.
 		//
 		// Some known limitations:
@@ -1536,91 +1684,25 @@ func (sshClient *sshClient) handleTCPChannel(
 		//   asynchronous, there exists a possibility that a client can consume
 		//   more than max port forward resources -- just not upstream sockets.
 		//
-		// - An LRU list entry for this port forward is not added until
-		//   after the dial completes, but the port forward is counted
-		//   towards max limits. This means many dials in progress will
-		//   put established connections in jeopardy.
-		//
-		// - We're closing the oldest open connection _before_ successfully
-		//   dialing the new port forward. This means we are potentially
-		//   discarding a good connection to make way for a failed connection.
-		//   We cannot simply dial first and still maintain a limit on
-		//   resources used, so to address this we'd need to add some
-		//   accounting for connections still establishing.
+		// - Closed sockets will enter the TIME_WAIT state, consuming some
+		//   resources.
 
 		sshClient.tcpPortForwardLRU.CloseOldest()
 
-		log.WithContextFields(
-			LogFields{
-				"maxCount": maxCount,
-			}).Debug("closed LRU TCP port forward")
-	}
-
-	// Dial the target remote address. This is done in a goroutine to
-	// ensure the shutdown signal is handled immediately.
-
-	remoteAddr := net.JoinHostPort(lookupResult.IP.String(), strconv.Itoa(portToConnect))
-
-	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
-
-	type dialTCPResult struct {
-		conn net.Conn
-		err  error
-	}
-	dialResultChannel := make(chan *dialTCPResult, 1)
-
-	dialStartTime := monotime.Now()
-
-	go func() {
-		conn, err := net.DialTimeout(
-			"tcp", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
-		dialResultChannel <- &dialTCPResult{conn, err}
-	}()
-
-	var dialResult *dialTCPResult
-	select {
-	case dialResult = <-dialResultChannel:
-	case <-sshClient.stopBroadcast:
-		// Note: may leave Dial in progress
-		// TODO: use net.Dialer.DialContext to be able to cancel
-		return
-	}
-
-	sshClient.updateQualityMetrics(
-		dialResult.err == nil, monotime.Since(dialStartTime))
-
-	if dialResult.err != nil {
-		sshClient.sshServer.handlePortForwardDialError(dialResult.err)
-		sshClient.rejectNewChannel(
-			newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", dialResult.err))
-		return
+		log.WithContext().Debug("closed LRU TCP port forward")
 	}
 
-	// The upstream TCP port forward connection has been established. Schedule
-	// some cleanup and notify the SSH client that the channel is accepted.
-
-	fwdConn := dialResult.conn
-	defer fwdConn.Close()
-
-	fwdChannel, requests, err := newChannel.Accept()
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
-		return
-	}
-	go ssh.DiscardRequests(requests)
-	defer fwdChannel.Close()
+	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
+	defer lruEntry.Remove()
 
 	// ActivityMonitoredConn monitors the TCP port forward I/O and updates
 	// its LRU status. ActivityMonitoredConn also times out I/O on the port
 	// forward if both reads and writes have been idle for the specified
 	// duration.
 
-	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
-	defer lruEntry.Remove()
-
 	// Ensure nil interface if newClientSeedPortForward returns nil
 	var updater common.ActivityUpdater
-	seedUpdater := sshClient.newClientSeedPortForward(lookupResult.IP)
+	seedUpdater := sshClient.newClientSeedPortForward(IP)
 	if seedUpdater != nil {
 		updater = seedUpdater
 	}

+ 15 - 14
psiphon/server/udp.go

@@ -169,25 +169,22 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 			}
 
-			mux.sshClient.openedPortForward(portForwardTypeUDP)
-			// Note: can't defer sshClient.closedPortForward() here
+			// Note: UDP port forward counting has no dialing phase
+
+			mux.sshClient.establishedPortForward(portForwardTypeUDP)
+			// Can't defer sshClient.closedPortForward() here;
+			// relayDownstream will call sshClient.closedPortForward()
 
 			// TOCTOU note: important to increment the port forward count (via
 			// establishedPortForward) _before_ checking isPortForwardLimitExceeded
-			if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
+			if exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
 
 				// Close the oldest UDP port forward. CloseOldest() closes
 				// the conn and the port forward's goroutine will complete
 				// the cleanup asynchronously.
-				//
-				// See LRU comment in handleTCPChannel() for a known
-				// limitations regarding CloseOldest().
 				mux.portForwardLRU.CloseOldest()
 
-				log.WithContextFields(
-					LogFields{
-						"maxCount": maxCount,
-					}).Debug("closed LRU UDP port forward")
+				log.WithContext().Debug("closed LRU UDP port forward")
 			}
 
 			log.WithContextFields(
@@ -199,19 +196,24 @@ func (mux *udpPortForwardMultiplexer) run() {
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
 				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
-				mux.sshClient.sshServer.handlePortForwardDialError(err)
+
+				// Monitor for low resource error conditions
+				mux.sshClient.sshServer.monitorPortForwardDialError(err)
+
 				// Note: Debug level, as logMessage may contain user traffic destination address information
 				log.WithContextFields(LogFields{"error": err}).Debug("DialUDP failed")
 				continue
 			}
 
+			lruEntry := mux.portForwardLRU.Add(udpConn)
+			// Can't defer lruEntry.Remove() here;
+			// relayDownstream will call lruEntry.Remove()
+
 			// ActivityMonitoredConn monitors the UDP port forward I/O and updates
 			// its LRU status. ActivityMonitoredConn also times out I/O on the port
 			// forward if both reads and writes have been idle for the specified
 			// duration.
 
-			lruEntry := mux.portForwardLRU.Add(udpConn)
-
 			// Ensure nil interface if newClientSeedPortForward returns nil
 			var updater common.ActivityUpdater
 			seedUpdater := mux.sshClient.newClientSeedPortForward(dialIP)
@@ -247,7 +249,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 
-			// relayDownstream will call sshClient.closedPortForward()
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 		}