Browse Source

Merge pull request #355 from rod-hynes/master

psiphond updates
Rod Hynes 9 years ago
parent
commit
6e07bbff0c

+ 1 - 1
psiphon/common/utils.go

@@ -105,7 +105,7 @@ func MakeSecureRandomPadding(minLength, maxLength int) ([]byte, error) {
 }
 
 // MakeRandomPeriod returns a random duration, within a given range.
-// In the unlikely case where an  underlying MakeRandom functions fails,
+// In the unlikely case where an underlying MakeRandom functions fails,
 // the period is the minimum.
 func MakeRandomPeriod(min, max time.Duration) (time.Duration, error) {
 	period, err := MakeSecureRandomInt64(max.Nanoseconds() - min.Nanoseconds())

+ 30 - 19
psiphon/server/dns.go

@@ -23,6 +23,7 @@ import (
 	"bufio"
 	"bytes"
 	"errors"
+	"math/rand"
 	"net"
 	"strings"
 	"sync/atomic"
@@ -38,7 +39,7 @@ const (
 	DNS_RESOLVER_PORT               = 53
 )
 
-// DNSResolver maintains a fresh DNS resolver value, monitoring
+// DNSResolver maintains fresh DNS resolver values, monitoring
 // "/etc/resolv.conf" on platforms where it is available; and
 // otherwise using a default value.
 type DNSResolver struct {
@@ -48,12 +49,12 @@ type DNSResolver struct {
 	lastReloadTime int64
 	common.ReloadableFile
 	isReloading int32
-	resolver    net.IP
+	resolvers   []net.IP
 }
 
 // NewDNSResolver initializes a new DNSResolver, loading it with
-// a fresh resolver value. The load must succeed, so either
-// "/etc/resolv.conf" must contain a valid "nameserver" line with
+// fresh resolver values. The load must succeed, so either
+// "/etc/resolv.conf" must contain valid "nameserver" lines with
 // a DNS server IP address, or a valid "defaultResolver" default
 // value must be provided.
 // On systems without "/etc/resolv.conf", "defaultResolver" is
@@ -79,18 +80,18 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 		DNS_SYSTEM_CONFIG_FILENAME,
 		func(fileContent []byte) error {
 
-			resolver, err := parseResolveConf(fileContent)
+			resolvers, err := parseResolveConf(fileContent)
 			if err != nil {
 				// On error, state remains the same
 				return common.ContextError(err)
 			}
 
-			dns.resolver = resolver
+			dns.resolvers = resolvers
 
 			log.WithContextFields(
 				LogFields{
-					"resolver": resolver.String(),
-				}).Debug("loaded system DNS resolver")
+					"resolvers": resolvers,
+				}).Debug("loaded system DNS resolvers")
 
 			return nil
 		})
@@ -110,15 +111,19 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 			return nil, common.ContextError(err)
 		}
 
-		dns.resolver = resolver
+		dns.resolvers = []net.IP{resolver}
 	}
 
 	return dns, nil
 }
 
-// Get returns the cached resolver, first updating the cached
-// value if it's stale. If reloading fails, the previous value
-// is used.
+// Get returns one of the cached resolvers, selected at random,
+// after first updating the cached values if they're stale. If
+// reloading fails, the previous values are used.
+//
+// Randomly selecting any one of the configured resolvers is
+// expected to be more resilient to failure; e.g., if one of
+// the resolvers becomes unavailable.
 func (dns *DNSResolver) Get() net.IP {
 
 	// Every UDP DNS port forward frequently calls Get(), so this code
@@ -158,13 +163,15 @@ func (dns *DNSResolver) Get() net.IP {
 	dns.ReloadableFile.RLock()
 	defer dns.ReloadableFile.RUnlock()
 
-	return dns.resolver
+	return dns.resolvers[rand.Intn(len(dns.resolvers))]
 }
 
-func parseResolveConf(fileContent []byte) (net.IP, error) {
+func parseResolveConf(fileContent []byte) ([]net.IP, error) {
 
 	scanner := bufio.NewScanner(bytes.NewReader(fileContent))
 
+	var resolvers []net.IP
+
 	for scanner.Scan() {
 		line := scanner.Text()
 		if strings.HasPrefix(line, ";") || strings.HasPrefix(line, "#") {
@@ -172,10 +179,10 @@ func parseResolveConf(fileContent []byte) (net.IP, error) {
 		}
 		fields := strings.Fields(line)
 		if len(fields) == 2 && fields[0] == "nameserver" {
-			// TODO: parseResolverAddress will fail when the nameserver
-			// is not an IP address. It may be a domain name. To support
-			// this case, should proceed to the next "nameserver" line.
-			return parseResolver(fields[1])
+			resolver, err := parseResolver(fields[1])
+			if err == nil {
+				resolvers = append(resolvers, resolver)
+			}
 		}
 	}
 
@@ -183,7 +190,11 @@ func parseResolveConf(fileContent []byte) (net.IP, error) {
 		return nil, common.ContextError(err)
 	}
 
-	return nil, common.ContextError(errors.New("nameserver not found"))
+	if len(resolvers) == 0 {
+		return nil, common.ContextError(errors.New("no nameservers found"))
+	}
+
+	return resolvers, nil
 }
 
 func parseResolver(resolver string) (net.IP, error) {

+ 6 - 2
psiphon/server/meek.go

@@ -243,7 +243,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 	err = session.clientConn.pumpReads(request.Body)
 	if err != nil {
 		if err != io.EOF {
-			log.WithContextFields(LogFields{"error": err}).Warning("pump reads failed")
+			// Debug since errors such as "i/o timeout" occur during normal operation;
+			// also, golang network error messages may contain client IP.
+			log.WithContextFields(LogFields{"error": err}).Debug("pump reads failed")
 		}
 		server.terminateConnection(responseWriter, request)
 		server.closeSession(sessionID)
@@ -266,7 +268,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 	err = session.clientConn.pumpWrites(responseWriter)
 	if err != nil {
 		if err != io.EOF {
-			log.WithContextFields(LogFields{"error": err}).Warning("pump writes failed")
+			// Debug since errors such as "i/o timeout" occur during normal operation;
+			// also, golang network error messages may contain client IP.
+			log.WithContextFields(LogFields{"error": err}).Debug("pump writes failed")
 		}
 		server.terminateConnection(responseWriter, request)
 		server.closeSession(sessionID)

+ 4 - 1
psiphon/server/server_test.go

@@ -603,10 +603,13 @@ func makeTunneledNTPRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAd
 	return err
 }
 
+var nextUDPProxyPort = 7300
+
 func makeTunneledNTPRequestAttempt(
 	t *testing.T, testHostname string, timeout time.Duration, localSOCKSProxyPort int, udpgwServerAddress string) error {
 
-	localUDPProxyAddress, err := net.ResolveUDPAddr("udp", "127.0.0.1:7301")
+	nextUDPProxyPort++
+	localUDPProxyAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", nextUDPProxyPort))
 	if err != nil {
 		return fmt.Errorf("ResolveUDPAddr failed: %s", err)
 	}

+ 24 - 3
psiphon/server/trafficRules.go

@@ -30,6 +30,7 @@ import (
 const (
 	DEFAULT_IDLE_TCP_PORT_FORWARD_TIMEOUT_MILLISECONDS = 30000
 	DEFAULT_IDLE_UDP_PORT_FORWARD_TIMEOUT_MILLISECONDS = 30000
+	DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT         = 64
 	DEFAULT_MAX_TCP_PORT_FORWARD_COUNT                 = 512
 	DEFAULT_MAX_UDP_PORT_FORWARD_COUNT                 = 32
 )
@@ -106,14 +107,25 @@ type TrafficRules struct {
 	// is used.
 	IdleUDPPortForwardTimeoutMilliseconds *int
 
-	// MaxTCPPortForwardCount is the maximum number of TCP port
-	// forwards each client may have open concurrently.
+	// MaxTCPDialingPortForwardCount is the maximum number of dialing
+	// TCP port forwards each client may have open concurrently. When
+	// persistently at the limit, new TCP port forwards are rejected.
+	// A value of 0 specifies no maximum. When omitted in
+	// DefaultRules, DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT is used.
+	MaxTCPDialingPortForwardCount *int
+
+	// MaxTCPPortForwardCount is the maximum number of established TCP
+	// port forwards each client may have open concurrently. If at the
+	// limit when a new TCP port forward is established, the LRU
+	// established TCP port forward is closed.
 	// A value of 0 specifies no maximum. When omitted in
 	// DefaultRules, DEFAULT_MAX_TCP_PORT_FORWARD_COUNT is used.
 	MaxTCPPortForwardCount *int
 
 	// MaxUDPPortForwardCount is the maximum number of UDP port
-	// forwards each client may have open concurrently.
+	// forwards each client may have open concurrently. If at the
+	// limit when a new UDP port forward is created, the LRU
+	// UDP port forward is closed.
 	// A value of 0 specifies no maximum. When omitted in
 	// DefaultRules, DEFAULT_MAX_UDP_PORT_FORWARD_COUNT is used.
 	MaxUDPPortForwardCount *int
@@ -299,6 +311,11 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			intPtr(DEFAULT_IDLE_UDP_PORT_FORWARD_TIMEOUT_MILLISECONDS)
 	}
 
+	if trafficRules.MaxTCPDialingPortForwardCount == nil {
+		trafficRules.MaxTCPDialingPortForwardCount =
+			intPtr(DEFAULT_MAX_TCP_DIALING_PORT_FORWARD_COUNT)
+	}
+
 	if trafficRules.MaxTCPPortForwardCount == nil {
 		trafficRules.MaxTCPPortForwardCount =
 			intPtr(DEFAULT_MAX_TCP_PORT_FORWARD_COUNT)
@@ -393,6 +410,10 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			trafficRules.IdleUDPPortForwardTimeoutMilliseconds = filteredRules.Rules.IdleUDPPortForwardTimeoutMilliseconds
 		}
 
+		if filteredRules.Rules.MaxTCPDialingPortForwardCount != nil {
+			trafficRules.MaxTCPDialingPortForwardCount = filteredRules.Rules.MaxTCPDialingPortForwardCount
+		}
+
 		if filteredRules.Rules.MaxTCPPortForwardCount != nil {
 			trafficRules.MaxTCPPortForwardCount = filteredRules.Rules.MaxTCPPortForwardCount
 		}

+ 309 - 153
psiphon/server/tunnelServer.go

@@ -20,15 +20,18 @@
 package server
 
 import (
+	"context"
 	"crypto/subtle"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
+	"math/rand"
 	"net"
 	"strconv"
 	"sync"
 	"sync/atomic"
+	"syscall"
 	"time"
 
 	"github.com/Psiphon-Inc/crypto/ssh"
@@ -39,13 +42,16 @@ import (
 )
 
 const (
-	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
-	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
-	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT      = 30 * time.Second
-	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE  = 8192
-	SSH_SEND_OSL_INITIAL_RETRY_DELAY       = 30 * time.Second
-	SSH_SEND_OSL_RETRY_FACTOR              = 2
+	SSH_HANDSHAKE_TIMEOUT                          = 30 * time.Second
+	SSH_CONNECTION_READ_DEADLINE                   = 5 * time.Minute
+	SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT         = 30 * time.Second
+	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT              = 30 * time.Second
+	SSH_TCP_PORT_FORWARD_LIMIT_RETRIES             = 5
+	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD = 10 * time.Millisecond
+	SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD = 100 * time.Millisecond
+	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE          = 8192
+	SSH_SEND_OSL_INITIAL_RETRY_DELAY               = 30 * time.Second
+	SSH_SEND_OSL_RETRY_FACTOR                      = 2
 )
 
 // TunnelServer is the main server that accepts Psiphon client
@@ -455,6 +461,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 		protocolStats[tunnelProtocol]["ALL"] = make(map[string]int64)
 		protocolStats[tunnelProtocol]["ALL"]["accepted_clients"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["established_clients"] = 0
+		protocolStats[tunnelProtocol]["ALL"]["dialing_tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["total_tcp_port_forwards"] = 0
 		protocolStats[tunnelProtocol]["ALL"]["udp_port_forwards"] = 0
@@ -471,6 +478,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 					protocolStats[tunnelProtocol][region] = make(map[string]int64)
 					protocolStats[tunnelProtocol][region]["accepted_clients"] = 0
 					protocolStats[tunnelProtocol][region]["established_clients"] = 0
+					protocolStats[tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["total_tcp_port_forwards"] = 0
 					protocolStats[tunnelProtocol][region]["udp_port_forwards"] = 0
@@ -495,6 +503,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 				protocolStats[client.tunnelProtocol][region] = make(map[string]int64)
 				protocolStats[client.tunnelProtocol][region]["accepted_clients"] = 0
 				protocolStats[client.tunnelProtocol][region]["established_clients"] = 0
+				protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] = 0
 				protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] = 0
@@ -504,8 +513,10 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 			// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
 			protocolStats[client.tunnelProtocol][region]["established_clients"] += 1
 
+			protocolStats[client.tunnelProtocol][region]["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
+			// client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
 			protocolStats[client.tunnelProtocol][region]["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
 			protocolStats[client.tunnelProtocol][region]["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
 
@@ -517,10 +528,12 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 		aggregatedQualityMetrics.tcpPortForwardFailedCount += client.qualityMetrics.tcpPortForwardFailedCount
 		aggregatedQualityMetrics.tcpPortForwardFailedDuration +=
 			client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond
+		aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount += client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
 		client.qualityMetrics.tcpPortForwardDialedCount = 0
 		client.qualityMetrics.tcpPortForwardDialedDuration = 0
 		client.qualityMetrics.tcpPortForwardFailedCount = 0
 		client.qualityMetrics.tcpPortForwardFailedDuration = 0
+		client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
 
 		client.Unlock()
 	}
@@ -531,6 +544,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 	allProtocolsStats := make(map[string]int64)
 	allProtocolsStats["accepted_clients"] = 0
 	allProtocolsStats["established_clients"] = 0
+	allProtocolsStats["dialing_tcp_port_forwards"] = 0
 	allProtocolsStats["tcp_port_forwards"] = 0
 	allProtocolsStats["total_tcp_port_forwards"] = 0
 	allProtocolsStats["udp_port_forwards"] = 0
@@ -539,6 +553,7 @@ func (sshServer *sshServer) getLoadStats() map[string]interface{} {
 	allProtocolsStats["tcp_port_forward_dialed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardDialedDuration)
 	allProtocolsStats["tcp_port_forward_failed_count"] = aggregatedQualityMetrics.tcpPortForwardFailedCount
 	allProtocolsStats["tcp_port_forward_failed_duration"] = int64(aggregatedQualityMetrics.tcpPortForwardFailedDuration)
+	allProtocolsStats["tcp_port_forward_rejected_dialing_limit_count"] = aggregatedQualityMetrics.tcpPortForwardRejectedDialingLimitCount
 
 	for _, stats := range protocolStats {
 		for name, value := range stats["ALL"] {
@@ -628,6 +643,33 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	sshClient.run(clientConn)
 }
 
+func (sshServer *sshServer) monitorPortForwardDialError(err error) {
+
+	// "err" is the error returned from a failed TCP or UDP port
+	// forward dial. Certain system error codes indicate low resource
+	// conditions: insufficient file descriptors, ephemeral ports, or
+	// memory. For these cases, log an alert.
+
+	// TODO: also temporarily suspend new clients
+
+	// Note: don't log net.OpError.Error() as the full error string
+	// may contain client destination addresses.
+
+	opErr, ok := err.(*net.OpError)
+	if ok {
+		if opErr.Err == syscall.EADDRNOTAVAIL ||
+			opErr.Err == syscall.EAGAIN ||
+			opErr.Err == syscall.ENOMEM ||
+			opErr.Err == syscall.EMFILE ||
+			opErr.Err == syscall.ENFILE {
+
+			log.WithContextFields(
+				LogFields{"error": opErr.Err}).Error(
+				"port forward dial failed due to unavailable resource")
+		}
+	}
+}
+
 type sshClient struct {
 	sync.Mutex
 	sshServer              *sshServer
@@ -647,15 +689,18 @@ type sshClient struct {
 	tcpPortForwardLRU      *common.LRUConns
 	oslClientSeedState     *osl.ClientSeedState
 	signalIssueSLOKs       chan struct{}
-	stopBroadcast          chan struct{}
+	runContext             context.Context
+	stopRunning            context.CancelFunc
 }
 
 type trafficState struct {
-	bytesUp                        int64
-	bytesDown                      int64
-	concurrentPortForwardCount     int64
-	peakConcurrentPortForwardCount int64
-	totalPortForwardCount          int64
+	bytesUp                               int64
+	bytesDown                             int64
+	concurrentDialingPortForwardCount     int64
+	peakConcurrentDialingPortForwardCount int64
+	concurrentPortForwardCount            int64
+	peakConcurrentPortForwardCount        int64
+	totalPortForwardCount                 int64
 }
 
 // qualityMetrics records upstream TCP dial attempts and
@@ -664,10 +709,11 @@ type trafficState struct {
 // upstream link. These stats are recorded by each sshClient
 // and then reported and reset in sshServer.getLoadStats().
 type qualityMetrics struct {
-	tcpPortForwardDialedCount    int64
-	tcpPortForwardDialedDuration time.Duration
-	tcpPortForwardFailedCount    int64
-	tcpPortForwardFailedDuration time.Duration
+	tcpPortForwardDialedCount               int64
+	tcpPortForwardDialedDuration            time.Duration
+	tcpPortForwardFailedCount               int64
+	tcpPortForwardFailedDuration            time.Duration
+	tcpPortForwardRejectedDialingLimitCount int64
 }
 
 type handshakeState struct {
@@ -678,13 +724,17 @@ type handshakeState struct {
 
 func newSshClient(
 	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
+
+	runContext, stopRunning := context.WithCancel(context.Background())
+
 	return &sshClient{
 		sshServer:         sshServer,
 		tunnelProtocol:    tunnelProtocol,
 		geoIPData:         geoIPData,
 		tcpPortForwardLRU: common.NewLRUConns(),
 		signalIssueSLOKs:  make(chan struct{}, 1),
-		stopBroadcast:     make(chan struct{}),
+		runContext:        runContext,
+		stopRunning:       stopRunning,
 	}
 }
 
@@ -980,18 +1030,28 @@ func (sshClient *sshClient) runTunnel(
 			continue
 		}
 
-		// process each port forward concurrently
+		// Process each port forward concurrently
+
 		waitGroup.Add(1)
 		go func(channel ssh.NewChannel) {
 			defer waitGroup.Done()
 			sshClient.handleNewPortForwardChannel(channel)
 		}(newChannel)
+
+		// Throttle accepting new channels when at the TCP dialing port forwards limit.
+		// This mitigates clients rapidly enqueuing many port forward handlers in a
+		// pre-dial state.
+		// TODO: block here until under the limit (excepting UDP)?
+
+		if sshClient.isTCPDialingPortForwardLimitExceeded() {
+			sshClient.tpcDialingPortForwardLimitThrottle()
+		}
 	}
 
 	// The channel loop is interrupted by a client
 	// disconnect or by calling sshClient.stop().
 
-	close(sshClient.stopBroadcast)
+	sshClient.stopRunning()
 
 	waitGroup.Wait()
 }
@@ -1021,10 +1081,12 @@ func (sshClient *sshClient) logTunnel() {
 	logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
 	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
 	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
+	logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
 	logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
 	logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
 	logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
+	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
 	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
 	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
 
@@ -1040,7 +1102,7 @@ func (sshClient *sshClient) runOSLSender() {
 		// TODO: use reflect.SelectCase, and optionally await timer here?
 		select {
 		case <-sshClient.signalIssueSLOKs:
-		case <-sshClient.stopBroadcast:
+		case <-sshClient.runContext.Done():
 			return
 		}
 
@@ -1058,7 +1120,7 @@ func (sshClient *sshClient) runOSLSender() {
 			select {
 			case <-retryTimer.C:
 			case <-sshClient.signalIssueSLOKs:
-			case <-sshClient.stopBroadcast:
+			case <-sshClient.runContext.Done():
 				retryTimer.Stop()
 				return
 			}
@@ -1335,29 +1397,64 @@ func (sshClient *sshClient) isPortForwardPermitted(
 	return false
 }
 
+func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	state := &sshClient.tcpTrafficState
+	max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
+
+	if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
+		return true
+	}
+	return false
+}
+
 func (sshClient *sshClient) isPortForwardLimitExceeded(
-	portForwardType int) (int, bool) {
+	portForwardType int) bool {
 
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	var maxPortForwardCount int
+	var max int
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
-		maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
+		max = *sshClient.trafficRules.MaxTCPPortForwardCount
 		state = &sshClient.tcpTrafficState
 	} else {
-		maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
+		max = *sshClient.trafficRules.MaxUDPPortForwardCount
 		state = &sshClient.udpTrafficState
 	}
 
-	if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
-		return maxPortForwardCount, true
+	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
+		return true
+	}
+	return false
+}
+
+func (sshClient *sshClient) dialingTCPPortForward() {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	state := &sshClient.tcpTrafficState
+
+	state.concurrentDialingPortForwardCount += 1
+	if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
+		state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
 	}
-	return maxPortForwardCount, false
 }
 
-func (sshClient *sshClient) openedPortForward(
+func (sshClient *sshClient) failedTCPPortForward() {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
+}
+
+func (sshClient *sshClient) establishedPortForward(
 	portForwardType int) {
 
 	sshClient.Lock()
@@ -1366,6 +1463,10 @@ func (sshClient *sshClient) openedPortForward(
 	var state *trafficState
 	if portForwardType == portForwardTypeTCP {
 		state = &sshClient.tcpTrafficState
+
+		// Assumes TCP port forwards called dialingTCPPortForward
+		state.concurrentDialingPortForwardCount -= 1
+
 	} else {
 		state = &sshClient.udpTrafficState
 	}
@@ -1377,7 +1478,25 @@ func (sshClient *sshClient) openedPortForward(
 	state.totalPortForwardCount += 1
 }
 
-func (sshClient *sshClient) updateQualityMetrics(
+func (sshClient *sshClient) closedPortForward(
+	portForwardType int, bytesUp, bytesDown int64) {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
+	state.concurrentPortForwardCount -= 1
+	state.bytesUp += bytesUp
+	state.bytesDown += bytesDown
+}
+
+func (sshClient *sshClient) updateQualityMetricsWithDialResult(
 	tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
 
 	sshClient.Lock()
@@ -1393,22 +1512,25 @@ func (sshClient *sshClient) updateQualityMetrics(
 	}
 }
 
-func (sshClient *sshClient) closedPortForward(
-	portForwardType int, bytesUp, bytesDown int64) {
+func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
 
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
-	var state *trafficState
-	if portForwardType == portForwardTypeTCP {
-		state = &sshClient.tcpTrafficState
-	} else {
-		state = &sshClient.udpTrafficState
-	}
+	sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
+}
 
-	state.concurrentPortForwardCount -= 1
-	state.bytesUp += bytesUp
-	state.bytesDown += bytesDown
+func (sshClient *sshClient) tpcDialingPortForwardLimitThrottle() {
+
+	duration :=
+		SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD +
+			time.Duration(
+				rand.Int63n(
+					int64(SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MAX_PERIOD-
+						SSH_TCP_PORT_FORWARD_LIMIT_THROTTLE_MIN_PERIOD)+1))
+
+	ctx, _ := context.WithTimeout(sshClient.runContext, duration)
+	<-ctx.Done()
 }
 
 func (sshClient *sshClient) handleTCPChannel(
@@ -1416,6 +1538,67 @@ func (sshClient *sshClient) handleTCPChannel(
 	portToConnect int,
 	newChannel ssh.NewChannel) {
 
+	// Lifecycle of a TCP port forward:
+	//
+	// 1. Call dialingTCPPortForward(), which increments concurrentDialingPortForwardCount
+	//
+	// 2. Check isTCPDialingPortForwardLimitExceeded(), which enforces the configurable
+	//    limit on concurrentDialingPortForwardCount. When the limit is exceeded, throttle
+	//    and recheck and then ultimately reject the port forward; this hard limit applies
+	//    back pressure when upstream network resources are impaired.
+	//
+	//        TOCTOU note: important to increment the port forward counts _before_
+	//        checking limits; otherwise, the client could potentially consume excess
+	//        resources by initiating many port forwards concurrently.
+	//
+	// 3. Dial the target.
+	//
+	// 4. If the dial fails, call failedTCPPortForward() to decrement
+	//    concurrentDialingPortForwardCount, freeing up a dial slot.
+	//
+	// 5. If the dial succeeds, call establishedPortForward(), which decrements
+	//    concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
+	//    the "established" port forward count.
+	//
+	// 6. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
+	//    concurrentPortForwardCount, the number of _established_ TCP port forwards.
+	//    If the limit is exceeded, the LRU established TCP port forward is closed and
+	//    the newly established TCP port forward proceeds. This LRU logic allows some
+	//    dangling resource consumption (e.g., TIME_WAIT) while providing a better
+	//    experience for clients.
+	//
+	// 7. Relay data.
+	//
+	// 8. Call closedPortForward() which decrements concurrentPortForwardCount and
+	//    records bytes transferred.
+
+	sshClient.dialingTCPPortForward()
+	established := false
+	defer func() {
+		if !established {
+			sshClient.failedTCPPortForward()
+		}
+	}()
+
+	exceeded := true
+	for i := 0; i < SSH_TCP_PORT_FORWARD_LIMIT_RETRIES; i++ {
+		exceeded = sshClient.isTCPDialingPortForwardLimitExceeded()
+		if !exceeded {
+			break
+		}
+		sshClient.tpcDialingPortForwardLimitThrottle()
+	}
+	if exceeded {
+
+		sshClient.updateQualityMetricsWithRejectedDialingLimit()
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.Prohibited, "dialing port forward limit exceeded")
+		return
+	}
+
+	// Transparently redirect web API request connections.
+
 	isWebServerPortForward := false
 	config := sshClient.sshServer.support.Config
 	if config.WebServerPortForwardAddress != "" {
@@ -1432,72 +1615,111 @@ func (sshClient *sshClient) handleTCPChannel(
 		}
 	}
 
-	type lookupIPResult struct {
-		IP  net.IP
-		err error
-	}
-	lookupResultChannel := make(chan *lookupIPResult, 1)
+	// Dial the remote address.
+	//
+	// Hostname resolution is performed explicitly, as a separate step, as the target IP
+	// address is used for traffic rules (AllowSubnets) and OSL seed progress.
+	//
+	// Contexts are used for cancellation (via sshClient.runContext, which is cancelled
+	// when the client is stopping) and timeouts.
 
-	go func() {
-		// TODO: explicit timeout for DNS resolution?
-		IPs, err := net.LookupIP(hostToConnect)
-		// TODO: shuffle list to try other IPs
-		// TODO: IPv6 support
-		var IP net.IP
-		for _, ip := range IPs {
-			if ip.To4() != nil {
-				IP = ip
-			}
-		}
-		if err == nil && IP == nil {
-			err = errors.New("no IP address")
-		}
-		lookupResultChannel <- &lookupIPResult{IP, err}
-	}()
+	dialStartTime := monotime.Now()
 
-	var lookupResult *lookupIPResult
-	select {
-	case lookupResult = <-lookupResultChannel:
-	case <-sshClient.stopBroadcast:
-		// Note: may leave LookupIP in progress
-		return
+	log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
+
+	ctx, _ := context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT)
+	IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
+
+	// TODO: shuffle list to try other IPs?
+	// TODO: IPv6 support
+	var IP net.IP
+	for _, ip := range IPs {
+		if ip.IP.To4() != nil {
+			IP = ip.IP
+			break
+		}
+	}
+	if err == nil && IP == nil {
+		err = errors.New("no IP address")
 	}
 
-	if lookupResult.err != nil {
+	if err != nil {
+
+		// Record a port forward failure
+		sshClient.updateQualityMetricsWithDialResult(true, monotime.Since(dialStartTime))
+
 		sshClient.rejectNewChannel(
-			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", lookupResult.err))
+			newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
 		return
 	}
 
+	// Enforce traffic rules, using the resolved IP address.
+
 	if !isWebServerPortForward &&
 		!sshClient.isPortForwardPermitted(
 			portForwardTypeTCP,
 			false,
-			lookupResult.IP,
+			IP,
 			portToConnect) {
 
+		// Note: not recording a port forward failure in this case
+
 		sshClient.rejectNewChannel(
 			newChannel, ssh.Prohibited, "port forward not permitted")
 		return
 	}
 
+	// TCP dial.
+
+	remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
+
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
+
+	ctx, _ = context.WithTimeout(sshClient.runContext, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
+	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
+
+	// Record port forward success or failure
+	sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
+
+	if err != nil {
+
+		// Monitor for low resource error conditions
+		sshClient.sshServer.monitorPortForwardDialError(err)
+
+		sshClient.rejectNewChannel(
+			newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
+		return
+	}
+
+	// The upstream TCP port forward connection has been established. Schedule
+	// some cleanup and notify the SSH client that the channel is accepted.
+
+	defer fwdConn.Close()
+
+	fwdChannel, requests, err := newChannel.Accept()
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
+		return
+	}
+	go ssh.DiscardRequests(requests)
+	defer fwdChannel.Close()
+
+	// Release the dialing slot and acquire an established slot.
+
+	// "established = true" cancels the deferred failedTCPPortForward()
+	established = true
+	sshClient.establishedPortForward(portForwardTypeTCP)
+
 	var bytesUp, bytesDown int64
-	sshClient.openedPortForward(portForwardTypeTCP)
 	defer func() {
 		sshClient.closedPortForward(
 			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
-	// TOCTOU note: important to increment the port forward count (via
-	// openPortForward) _before_ checking isPortForwardLimitExceeded
-	// otherwise, the client could potentially consume excess resources
-	// by initiating many port forwards concurrently.
-	// TODO: close LRU connection (after successful Dial) instead of
-	// rejecting new connection?
-	if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
+	if exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
 
 		// Close the oldest TCP port forward. CloseOldest() closes
-		// the conn and the port forward's goroutine will complete
+		// the conn and the port forward's goroutines will complete
 		// the cleanup asynchronously.
 		//
 		// Some known limitations:
@@ -1510,91 +1732,25 @@ func (sshClient *sshClient) handleTCPChannel(
 		//   asynchronous, there exists a possibility that a client can consume
 		//   more than max port forward resources -- just not upstream sockets.
 		//
-		// - An LRU list entry for this port forward is not added until
-		//   after the dial completes, but the port forward is counted
-		//   towards max limits. This means many dials in progress will
-		//   put established connections in jeopardy.
-		//
-		// - We're closing the oldest open connection _before_ successfully
-		//   dialing the new port forward. This means we are potentially
-		//   discarding a good connection to make way for a failed connection.
-		//   We cannot simply dial first and still maintain a limit on
-		//   resources used, so to address this we'd need to add some
-		//   accounting for connections still establishing.
+		// - Closed sockets will enter the TIME_WAIT state, consuming some
+		//   resources.
 
 		sshClient.tcpPortForwardLRU.CloseOldest()
 
-		log.WithContextFields(
-			LogFields{
-				"maxCount": maxCount,
-			}).Debug("closed LRU TCP port forward")
-	}
-
-	// Dial the target remote address. This is done in a goroutine to
-	// ensure the shutdown signal is handled immediately.
-
-	remoteAddr := net.JoinHostPort(lookupResult.IP.String(), strconv.Itoa(portToConnect))
-
-	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
-
-	type dialTCPResult struct {
-		conn net.Conn
-		err  error
-	}
-	dialResultChannel := make(chan *dialTCPResult, 1)
-
-	dialStartTime := monotime.Now()
-
-	go func() {
-		// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
-		conn, err := net.DialTimeout(
-			"tcp", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
-		dialResultChannel <- &dialTCPResult{conn, err}
-	}()
-
-	var dialResult *dialTCPResult
-	select {
-	case dialResult = <-dialResultChannel:
-	case <-sshClient.stopBroadcast:
-		// Note: may leave Dial in progress
-		// TODO: use net.Dialer.DialContext to be able to cancel
-		return
-	}
-
-	sshClient.updateQualityMetrics(
-		dialResult.err == nil, monotime.Since(dialStartTime))
-
-	if dialResult.err != nil {
-		sshClient.rejectNewChannel(
-			newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", dialResult.err))
-		return
+		log.WithContext().Debug("closed LRU TCP port forward")
 	}
 
-	// The upstream TCP port forward connection has been established. Schedule
-	// some cleanup and notify the SSH client that the channel is accepted.
-
-	fwdConn := dialResult.conn
-	defer fwdConn.Close()
-
-	fwdChannel, requests, err := newChannel.Accept()
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
-		return
-	}
-	go ssh.DiscardRequests(requests)
-	defer fwdChannel.Close()
+	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
+	defer lruEntry.Remove()
 
 	// ActivityMonitoredConn monitors the TCP port forward I/O and updates
 	// its LRU status. ActivityMonitoredConn also times out I/O on the port
 	// forward if both reads and writes have been idle for the specified
 	// duration.
 
-	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
-	defer lruEntry.Remove()
-
 	// Ensure nil interface if newClientSeedPortForward returns nil
 	var updater common.ActivityUpdater
-	seedUpdater := sshClient.newClientSeedPortForward(lookupResult.IP)
+	seedUpdater := sshClient.newClientSeedPortForward(IP)
 	if seedUpdater != nil {
 		updater = seedUpdater
 	}

+ 17 - 15
psiphon/server/udp.go

@@ -169,25 +169,22 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 			}
 
-			mux.sshClient.openedPortForward(portForwardTypeUDP)
-			// Note: can't defer sshClient.closedPortForward() here
+			// Note: UDP port forward counting has no dialing phase
+
+			mux.sshClient.establishedPortForward(portForwardTypeUDP)
+			// Can't defer sshClient.closedPortForward() here;
+			// relayDownstream will call sshClient.closedPortForward()
 
 			// TOCTOU note: important to increment the port forward count (via
 			// establishedPortForward) _before_ checking isPortForwardLimitExceeded
-			if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
+			if exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
 
 				// Close the oldest UDP port forward. CloseOldest() closes
 				// the conn and the port forward's goroutine will complete
 				// the cleanup asynchronously.
-				//
-				// See LRU comment in handleTCPChannel() for a known
-				// limitations regarding CloseOldest().
 				mux.portForwardLRU.CloseOldest()
 
-				log.WithContextFields(
-					LogFields{
-						"maxCount": maxCount,
-					}).Debug("closed LRU UDP port forward")
+				log.WithContext().Debug("closed LRU UDP port forward")
 			}
 
 			log.WithContextFields(
@@ -195,22 +192,28 @@ func (mux *udpPortForwardMultiplexer) run() {
 					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
 					"connID":     message.connID}).Debug("dialing")
 
-			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 			udpConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
 				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
-				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
+
+				// Monitor for low resource error conditions
+				mux.sshClient.sshServer.monitorPortForwardDialError(err)
+
+				// Note: Debug level, as logMessage may contain user traffic destination address information
+				log.WithContextFields(LogFields{"error": err}).Debug("DialUDP failed")
 				continue
 			}
 
+			lruEntry := mux.portForwardLRU.Add(udpConn)
+			// Can't defer lruEntry.Remove() here;
+			// relayDownstream will call lruEntry.Remove()
+
 			// ActivityMonitoredConn monitors the UDP port forward I/O and updates
 			// its LRU status. ActivityMonitoredConn also times out I/O on the port
 			// forward if both reads and writes have been idle for the specified
 			// duration.
 
-			lruEntry := mux.portForwardLRU.Add(udpConn)
-
 			// Ensure nil interface if newClientSeedPortForward returns nil
 			var updater common.ActivityUpdater
 			seedUpdater := mux.sshClient.newClientSeedPortForward(dialIP)
@@ -246,7 +249,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 			mux.portForwards[portForward.connID] = portForward
 			mux.portForwardsMutex.Unlock()
 
-			// relayDownstream will call sshClient.closedPortForward()
 			mux.relayWaitGroup.Add(1)
 			go portForward.relayDownstream()
 		}

+ 5 - 0
psiphon/tunnel.go

@@ -235,6 +235,11 @@ func (tunnel *Tunnel) Close(isDiscarded bool) {
 		tunnel.sshClient.Close()
 		// tunnel.conn.Close() may get called multiple times, which is allowed.
 		tunnel.conn.Close()
+
+		err := tunnel.sshClient.Wait()
+		if err != nil {
+			NoticeAlert("close tunnel ssh error: %s", err)
+		}
 	}
 }