Browse Source

Further refine TCP port forward handling

- Change base dial timeout to match client's
  port forward timeout.
- Deduct from dial timeout the time elapsed
  at each phase, so the total time elapsed
  should not exceed the base timeout.
- Stricter resource cap: move dialing limit
  logic out of handleTCPChannel and enforce
  limit in runTunnel before spawning new
  goroutines.
- More conservative port forward rejection:
  allow up to base dial timeout elapsed time
  waiting for a dialing slot to become
  available.
Rod Hynes 9 years ago
parent
commit
404e681e58
1 changed files with 270 additions and 142 deletions
  1. 270 142
      psiphon/server/tunnelServer.go

+ 270 - 142
psiphon/server/tunnelServer.go

@@ -26,7 +26,6 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"math/rand"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
@@ -42,16 +41,13 @@ import (
 )
 )
 
 
 const (
 const (
-	SSH_HANDSHAKE_TIMEOUT                          = 30 * time.Second
-	SSH_CONNECTION_READ_DEADLINE                   = 5 * time.Minute
-	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT         = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT              = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_LIMIT_RETRIES             = 5
-	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD = 10 * time.Millisecond
-	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD = 100 * time.Millisecond
-	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE          = 8192
-	SSH_SEND_OSL_INITIAL_RETRY_DELAY               = 30 * time.Second
-	SSH_SEND_OSL_RETRY_FACTOR                      = 2
+	SSH_HANDSHAKE_TIMEOUT                 = 30 * time.Second
+	SSH_CONNECTION_READ_DEADLINE          = 5 * time.Minute
+	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT     = 10 * time.Second
+	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
+	SSH_TCP_PORT_FORWARD_QUEUE_SIZE       = 1024
+	SSH_SEND_OSL_INITIAL_RETRY_DELAY      = 30 * time.Second
+	SSH_SEND_OSL_RETRY_FACTOR             = 2
 )
 )
 
 
 // TunnelServer is the main server that accepts Psiphon client
 // TunnelServer is the main server that accepts Psiphon client
@@ -672,25 +668,26 @@ func (sshServer *sshServer) monitorPortForwardDialError(err error) {
 
 
 type sshClient struct {
 type sshClient struct {
 	sync.Mutex
 	sync.Mutex
-	sshServer              *sshServer
-	tunnelProtocol         string
-	sshConn                ssh.Conn
-	activityConn           *common.ActivityMonitoredConn
-	throttledConn          *common.ThrottledConn
-	geoIPData              GeoIPData
-	sessionID              string
-	supportsServerRequests bool
-	handshakeState         handshakeState
-	udpChannel             ssh.Channel
-	trafficRules           TrafficRules
-	tcpTrafficState        trafficState
-	udpTrafficState        trafficState
-	qualityMetrics         qualityMetrics
-	tcpPortForwardLRU      *common.LRUConns
-	oslClientSeedState     *osl.ClientSeedState
-	signalIssueSLOKs       chan struct{}
-	runContext             context.Context
-	stopRunning            context.CancelFunc
+	sshServer                            *sshServer
+	tunnelProtocol                       string
+	sshConn                              ssh.Conn
+	activityConn                         *common.ActivityMonitoredConn
+	throttledConn                        *common.ThrottledConn
+	geoIPData                            GeoIPData
+	sessionID                            string
+	supportsServerRequests               bool
+	handshakeState                       handshakeState
+	udpChannel                           ssh.Channel
+	trafficRules                         TrafficRules
+	tcpTrafficState                      trafficState
+	udpTrafficState                      trafficState
+	qualityMetrics                       qualityMetrics
+	tcpPortForwardLRU                    *common.LRUConns
+	oslClientSeedState                   *osl.ClientSeedState
+	signalIssueSLOKs                     chan struct{}
+	runContext                           context.Context
+	stopRunning                          context.CancelFunc
+	tcpPortForwardDialingAvailableSignal context.CancelFunc
 }
 }
 
 
 type trafficState struct {
 type trafficState struct {
@@ -972,14 +969,16 @@ func (sshClient *sshClient) stop() {
 	sshClient.sshConn.Wait()
 	sshClient.sshConn.Wait()
 }
 }
 
 
-// runTunnel handles/dispatches new channel and new requests from the client.
+// runTunnel handles/dispatches new channels and new requests from the client.
 // When the SSH client connection closes, both the channels and requests channels
 // When the SSH client connection closes, both the channels and requests channels
-// will close and runClient will exit.
+// will close and runTunnel will exit.
 func (sshClient *sshClient) runTunnel(
 func (sshClient *sshClient) runTunnel(
 	channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
 	channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
 
 
 	waitGroup := new(sync.WaitGroup)
 	waitGroup := new(sync.WaitGroup)
 
 
+	// Start client SSH API request handler
+
 	waitGroup.Add(1)
 	waitGroup.Add(1)
 	go func() {
 	go func() {
 		defer waitGroup.Done()
 		defer waitGroup.Done()
@@ -1015,6 +1014,8 @@ func (sshClient *sshClient) runTunnel(
 		}
 		}
 	}()
 	}()
 
 
+	// Start OSL sender
+
 	if sshClient.supportsServerRequests {
 	if sshClient.supportsServerRequests {
 		waitGroup.Add(1)
 		waitGroup.Add(1)
 		go func() {
 		go func() {
@@ -1023,6 +1024,153 @@ func (sshClient *sshClient) runTunnel(
 		}()
 		}()
 	}
 	}
 
 
+	// Lifecycle of a TCP port forward:
+	//
+	// 1. A "direct-tcpip" SSH request is received from the client.
+	//
+	//    A new TCP port forward request is enqueued. The queue delivers TCP port
+	//    forward requests to the TCP port forward manager, which enforces the TCP
+	//    port forward dial limit.
+	//
+	//    Enqueuing new requests allows for reading further SSH requests from the
+	//    client without blocking when the dial limit is hit; this is to allow new
+	//    UDP/udpgw port forwards to be reestablished without delay. The maximum size
+	//    of the queue enforces a hard cap on resources consumed by a client in the
+	//    pre-dial phase. When the queue is full, new TCP port forwards are
+	//    immediately rejected.
+	//
+	// 2. The TCP port forward manager dequeues the request.
+	//
+	//    The manager calls dialingTCPPortForward(), which increments
+	//    concurrentDialingPortForwardCount, and calls
+	//    isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
+	//    count.
+	//
+	//    The manager enforces the concurrent TCP dial limit: when at the limit, the
+	//    manager blocks waiting for the number of dials to drop below the limit before
+	//    dispatching the request to handleTCPChannel(), which will run in its own
+	//    goroutine and will dial and relay the port forward.
+	//
+	//    The block delays the current request and also halts dequeuing of subsequent
+	//    requests and could ultimately cause requests to be immediately rejected if
+	//    the queue fills. These actions are intended to apply back pressure when
+	//    upstream network resources are impaired.
+	//
+	//    The time spent in the queue is deducted from the port forward's dial timeout.
+	//    The time spent blocking while at the dial limit is similarly deducted from
+	//    the dial timeout. If the dial timeout has expired before the dial begins, the
+	//    port forward is rejected and a stat is recorded.
+	//
+	// 3. handleTCPChannel() performs the port forward dial and relaying.
+	//
+	//     a. Dial the target, using the dial timeout remaining after queue and blocking
+	//        time is deducted.
+	//
+	//     b. If the dial fails, call failedTCPPortForward() to decrement
+	//        concurrentDialingPortForwardCount, freeing up a dial slot.
+	//
+	//     c. If the dial succeeds, call establishedPortForward(), which decrements
+	//        concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
+	//        the "established" port forward count.
+	//
+	//    d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
+	//       concurrentPortForwardCount, the number of _established_ TCP port forwards.
+	//       If the limit is exceeded, the LRU established TCP port forward is closed and
+	//       the newly established TCP port forward proceeds. This LRU logic allows some
+	//       dangling resource consumption (e.g., TIME_WAIT) while providing a better
+	//       experience for clients.
+	//
+	//    e. Relay data.
+	//
+	//    f. Call closedPortForward() which decrements concurrentPortForwardCount and
+	//       records bytes transferred.
+
+	// Start the TCP port forward manager
+
+	type newTCPPortForward struct {
+		enqueueTime   monotime.Time
+		hostToConnect string
+		portToConnect int
+		newChannel    ssh.NewChannel
+	}
+
+	// The queue size is set to the traffic rules MaxTCPPortForwardCount, which is a
+	// reasonable indication of resource limits per client; when that value is 0, a default
+	// is used.
+	// A limitation: this queue size is set once and doesn't change, for this client, when
+	// traffic rules are reloaded.
+	queueSize := sshClient.getTCPPortForwardLimit()
+	if queueSize == 0 {
+		queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
+	}
+	newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
+
+	waitGroup.Add(1)
+	go func() {
+		defer waitGroup.Done()
+		for newPortForward := range newTCPPortForwards {
+
+			remainingDialTimeout := SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT - monotime.Since(newPortForward.enqueueTime)
+
+			if remainingDialTimeout <= 0 {
+				sshClient.updateQualityMetricsWithRejectedDialingLimit()
+				sshClient.rejectNewChannel(
+					newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out in queue")
+				continue
+			}
+
+			// Reserve a TCP dialing slot.
+			//
+			// TOCTOU note: important to increment counts _before_ checking limits; otherwise,
+			// the client could potentially consume excess resources by initiating many port
+			// forwards concurrently.
+
+			sshClient.dialingTCPPortForward()
+
+			// When max dials are in progress, wait up to remainingDialTimeout for dialing
+			// to become available. This blocks all dequeuing.
+
+			if sshClient.isTCPDialingPortForwardLimitExceeded() {
+				blockStartTime := monotime.Now()
+				ctx, cancelFunc := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
+				sshClient.setTCPPortForwardDialingAvailableSignal(cancelFunc)
+				<-ctx.Done()
+				sshClient.setTCPPortForwardDialingAvailableSignal(nil)
+				remainingDialTimeout -= monotime.Since(blockStartTime)
+			}
+
+			if remainingDialTimeout <= 0 {
+				sshClient.updateQualityMetricsWithRejectedDialingLimit()
+				sshClient.rejectNewChannel(
+					newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out before dialing")
+				continue
+			}
+
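+			// Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
+			// handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); at
+			// this point remainingDialTimeout > 0, as handleTCPChannel assumes.
+			// NOTE(review): the timed-out rejection path just above continues without any visible
+			// release of the slot reserved by dialingTCPPortForward() -- confirm whether it leaks.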
+
+			waitGroup.Add(1)
+			go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
+				defer waitGroup.Done()
+				sshClient.handleTCPChannel(
+					remainingDialTimeout,
+					newPortForward.hostToConnect,
+					newPortForward.portToConnect,
+					newPortForward.newChannel)
+			}(remainingDialTimeout, newPortForward)
+		}
+	}()
+
+	// Handle new channel (port forward) requests from the client.
+	//
+	// udpgw client connections are dispatched immediately (clients use this for
+	// DNS, so it's essential to not block; and only one udpgw connection is
+	// retained at a time).
+	//
+	// All other TCP port forwards are dispatched via the TCP port forward
+	// manager queue.
+
 	for newChannel := range channels {
 	for newChannel := range channels {
 
 
 		if newChannel.ChannelType() != "direct-tcpip" {
 		if newChannel.ChannelType() != "direct-tcpip" {
@@ -1030,27 +1178,65 @@ func (sshClient *sshClient) runTunnel(
 			continue
 			continue
 		}
 		}
 
 
-		// Process each port forward concurrently
+		// http://tools.ietf.org/html/rfc4254#section-7.2
+		var directTcpipExtraData struct {
+			HostToConnect       string
+			PortToConnect       uint32
+			OriginatorIPAddress string
+			OriginatorPort      uint32
+		}
+
+		err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
+		if err != nil {
+			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
+			continue
+		}
+
+		// Intercept TCP port forwards to a specified udpgw server and handle directly.
+		// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
+		isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+			sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
+				net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
 
 
-		waitGroup.Add(1)
-		go func(channel ssh.NewChannel) {
-			defer waitGroup.Done()
-			sshClient.handleNewPortForwardChannel(channel)
-		}(newChannel)
+		if isUDPChannel {
+
+			// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
+			// own worker goroutine.
 
 
-		// Throttle accepting new channels when at the TCP dialing port forwards limit.
-		// This mitigates clients rapidly enqueuing many port forward handlers in a
-		// pre-dial state.
-		// TODO: block here until under the limit (excepting UDP)?
+			waitGroup.Add(1)
+			go func(channel ssh.NewChannel) {
+				defer waitGroup.Done()
+				sshClient.handleUDPChannel(channel)
+			}(newChannel)
 
 
-		if sshClient.isTCPDialingPortForwardLimitExceeded() {
-			sshClient.tpcDialingPortForwardLimitThrottle()
+		} else {
+
+			// Dispatch via TCP port forward manager. When the queue is full, the channel
+			// is immediately rejected.
+
+			tcpPortForward := &newTCPPortForward{
+				enqueueTime:   monotime.Now(),
+				hostToConnect: directTcpipExtraData.HostToConnect,
+				portToConnect: int(directTcpipExtraData.PortToConnect),
+				newChannel:    newChannel,
+			}
+
+			select {
+			case newTCPPortForwards <- tcpPortForward:
+			default:
+				sshClient.updateQualityMetricsWithRejectedDialingLimit()
+				sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "TCP port forward dial queue full")
+			}
 		}
 		}
 	}
 	}
 
 
 	// The channel loop is interrupted by a client
 	// The channel loop is interrupted by a client
 	// disconnect or by calling sshClient.stop().
 	// disconnect or by calling sshClient.stop().
 
 
+	// Stop the TCP port forward manager
+	close(newTCPPortForwards)
+
+	// Stop all other worker goroutines
 	sshClient.stopRunning()
 	sshClient.stopRunning()
 
 
 	waitGroup.Wait()
 	waitGroup.Wait()
@@ -1181,36 +1367,6 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 	newChannel.Reject(reason, reason.String())
 	newChannel.Reject(reason, reason.String())
 }
 }
 
 
-func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
-
-	// http://tools.ietf.org/html/rfc4254#section-7.2
-	var directTcpipExtraData struct {
-		HostToConnect       string
-		PortToConnect       uint32
-		OriginatorIPAddress string
-		OriginatorPort      uint32
-	}
-
-	err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
-	if err != nil {
-		sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
-		return
-	}
-
-	// Intercept TCP port forwards to a specified udpgw server and handle directly.
-	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
-	isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
-		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
-			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
-
-	if isUDPChannel {
-		sshClient.handleUDPChannel(newChannel)
-	} else {
-		sshClient.handleTCPChannel(
-			directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
-	}
-}
-
 // setHandshakeState records that a client has completed a handshake API request.
 // setHandshakeState records that a client has completed a handshake API request.
 // Some parameters from the handshake request may be used in future traffic rule
 // Some parameters from the handshake request may be used in future traffic rule
 // selection. Port forwards are disallowed until a handshake is complete. The
 // selection. Port forwards are disallowed until a handshake is complete. The
@@ -1332,13 +1488,19 @@ func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
 }
 }
 
 
 func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
 func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
-
 	sshClient.Lock()
 	sshClient.Lock()
 	defer sshClient.Unlock()
 	defer sshClient.Unlock()
 
 
 	return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
 	return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
 }
 }
 
 
+func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+	// Store the cancel func that establishedPortForward invokes when the dialing count is under the limit (or no limit is set); nil clears it.
+	sshClient.tcpPortForwardDialingAvailableSignal = signal
+}
+
 const (
 const (
 	portForwardTypeTCP = iota
 	portForwardTypeTCP = iota
 	portForwardTypeUDP
 	portForwardTypeUDP
@@ -1433,6 +1595,14 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
 	return false
 	return false
 }
 }
 
 
+func (sshClient *sshClient) getTCPPortForwardLimit() int {
+	// Reports the MaxTCPPortForwardCount traffic rule, read while holding the client mutex.
+	sshClient.Lock()
+	defer sshClient.Unlock()
+	// Dereferences the rule pointer directly; presumably always populated -- TODO confirm.
+	return *sshClient.trafficRules.MaxTCPPortForwardCount
+}
+
 func (sshClient *sshClient) dialingTCPPortForward() {
 func (sshClient *sshClient) dialingTCPPortForward() {
 
 
 	sshClient.Lock()
 	sshClient.Lock()
@@ -1467,6 +1637,14 @@ func (sshClient *sshClient) establishedPortForward(
 		// Assumes TCP port forwards called dialingTCPPortForward
 		// Assumes TCP port forwards called dialingTCPPortForward
 		state.concurrentDialingPortForwardCount -= 1
 		state.concurrentDialingPortForwardCount -= 1
 
 
+		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
+
+			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
+			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
+				sshClient.tcpPortForwardDialingAvailableSignal()
+			}
+		}
+
 	} else {
 	} else {
 		state = &sshClient.udpTrafficState
 		state = &sshClient.udpTrafficState
 	}
 	}
@@ -1520,59 +1698,16 @@ func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
 	sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
 	sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
 }
 }
 
 
-func (sshClient *sshClient) tpcDialingPortForwardLimitThrottle() {
-
-	duration :=
-		SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD +
-			time.Duration(
-				rand.Int63n(
-					int64(SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD-
-						SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD)+1))
-
-	ctx, _ := context.WithTimeout(sshClient.runContext, duration)
-	<-ctx.Done()
-}
-
 func (sshClient *sshClient) handleTCPChannel(
 func (sshClient *sshClient) handleTCPChannel(
+	remainingDialTimeout time.Duration,
 	hostToConnect string,
 	hostToConnect string,
 	portToConnect int,
 	portToConnect int,
 	newChannel ssh.NewChannel) {
 	newChannel ssh.NewChannel) {
 
 
-	// Lifecycle of a TCP port forward:
-	//
-	// 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
-	//
-	// 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
-	//    limit on concurrentDialingPortForwardCount. When the limit is exceeded, throttle
-	//    and recheck and then ultimately reject the port forward; this hard limit applies
-	//    back pressure when upstream network resources are impaired.
-	//
-	//        TOCTOU note: important to increment the port forward counts _before_
-	//        checking limits; otherwise, the client could potentially consume excess
-	//        resources by initiating many port forwards concurrently.
-	//
-	// 3. Dial the target.
-	//
-	// 4. If the dial fails, call failedTCPPortForward() to decrement
-	//    concurrentDialingPortForwardCount, freeing up a dial slot.
-	//
-	// 5. If the dial succeeds, call establishedPortForward(), which decrements
-	//    concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
-	//    the "established" port forward count.
-	//
-	// 6. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
-	//    concurrentPortForwardCount, the number of _established_ TCP port forwards.
-	//    If the limit is exceeded, the LRU established TCP port forward is closed and
-	//    the newly established TCP port forward proceeds. This LRU logic allows some
-	//    dangling resource consumption (e.g., TIME_WAIT) while providing a better
-	//    experience for clients.
-	//
-	// 7. Relay data.
-	//
-	// 8. Call closedPortForward() which decrements concurrentPortForwardCount and
-	//    records bytes transferred.
+	// Assumptions:
+	// - sshClient.dialingTCPPortForward() has been called
+	// - remainingDialTimeout > 0
 
 
-	sshClient.dialingTCPPortForward()
 	established := false
 	established := false
 	defer func() {
 	defer func() {
 		if !established {
 		if !established {
@@ -1580,23 +1715,6 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 		}
 	}()
 	}()
 
 
-	exceeded := true
-	for i := 0; i < SSH_TCP_PORT_FORWARD_LIMIT_RETRIES; i++ {
-		exceeded = sshClient.isTCPDialingPortForwardLimitExceeded()
-		if !exceeded {
-			break
-		}
-		sshClient.tpcDialingPortForwardLimitThrottle()
-	}
-	if exceeded {
-
-		sshClient.updateQualityMetricsWithRejectedDialingLimit()
-
-		sshClient.rejectNewChannel(
-			newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
-		return
-	}
-
 	// Transparently redirect web API request connections.
 	// Transparently redirect web API request connections.
 
 
 	isWebServerPortForward := false
 	isWebServerPortForward := false
@@ -1627,7 +1745,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 
 	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
 
 
-	ctx, _ := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
+	ctx, _ := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
 	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
 	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
 
 
 	// TODO: shuffle list to try other IPs?
 	// TODO: shuffle list to try other IPs?
@@ -1643,16 +1761,26 @@ func (sshClient *sshClient) handleTCPChannel(
 		err = errors.New("no IP address")
 		err = errors.New("no IP address")
 	}
 	}
 
 
+	resolveElapsedTime := monotime.Since(dialStartTime)
+
 	if err != nil {
 	if err != nil {
 
 
 		// Record a port forward failure
 		// Record a port forward failure
-		sshClient.updateQualityMetricsWithDialResult(true, monotime.Since(dialStartTime))
+		sshClient.updateQualityMetricsWithDialResult(true, resolveElapsedTime)
 
 
 		sshClient.rejectNewChannel(
 		sshClient.rejectNewChannel(
 			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
 			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
 		return
 		return
 	}
 	}
 
 
+	remainingDialTimeout -= resolveElapsedTime
+
+	if remainingDialTimeout <= 0 {
+		sshClient.rejectNewChannel(
+			newChannel, ssh.Prohibited, "TCP port forward timed out resolving")
+		return
+	}
+
 	// Enforce traffic rules, using the resolved IP address.
 	// Enforce traffic rules, using the resolved IP address.
 
 
 	if !isWebServerPortForward &&
 	if !isWebServerPortForward &&
@@ -1675,7 +1803,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
 
-	ctx, _ = context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
+	ctx, _ = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
 	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
 	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
 
 
 	// Record port forward success or failure
 	// Record port forward success or failure