Просмотр исходного кода

Merge pull request #177 from rod-hynes/master

Latest server changes
Rod Hynes 9 лет назад
Родитель
Commit
eb1a53da16

+ 8 - 8
Server/main.go

@@ -31,25 +31,25 @@ import (
 
 func main() {
 
-	var generateServerIPaddress, newConfigFilename, newServerEntryFilename string
-	var networkInterface string
+	var generateServerIPaddress, generateServerNetworkInterface string
+	var generateConfigFilename, generateServerEntryFilename string
 	var generateWebServerPort, generateSSHServerPort, generateObfuscatedSSHServerPort int
 	var runConfigFilenames stringListFlag
 
 	flag.StringVar(
-		&newConfigFilename,
+		&generateConfigFilename,
 		"newConfig",
 		server.SERVER_CONFIG_FILENAME,
 		"generate new config with this `filename`")
 
 	flag.StringVar(
-		&newServerEntryFilename,
+		&generateServerEntryFilename,
 		"newServerEntry",
 		server.SERVER_ENTRY_FILENAME,
 		"generate new server entry with this `filename`")
 
 	flag.StringVar(
-		&networkInterface,
+		&generateServerNetworkInterface,
 		"interface",
 		"",
 		"generate server entry with this `network-interface`")
@@ -104,7 +104,7 @@ func main() {
 		configFileContents, serverEntryFileContents, err := server.GenerateConfig(
 			&server.GenerateConfigParams{
 				ServerIPAddress:         generateServerIPaddress,
-				ServerNetworkInterface:  networkInterface,
+				ServerNetworkInterface:  generateServerNetworkInterface,
 				WebServerPort:           generateWebServerPort,
 				SSHServerPort:           generateSSHServerPort,
 				ObfuscatedSSHServerPort: generateObfuscatedSSHServerPort,
@@ -114,13 +114,13 @@ func main() {
 			fmt.Errorf("generate failed: %s", err)
 			os.Exit(1)
 		}
-		err = ioutil.WriteFile(newConfigFilename, configFileContents, 0600)
+		err = ioutil.WriteFile(generateConfigFilename, configFileContents, 0600)
 		if err != nil {
 			fmt.Errorf("error writing configuration file: %s", err)
 			os.Exit(1)
 		}
 
-		err = ioutil.WriteFile(newServerEntryFilename, serverEntryFileContents, 0600)
+		err = ioutil.WriteFile(generateServerEntryFilename, serverEntryFileContents, 0600)
 		if err != nil {
 			fmt.Errorf("error writing server entry file: %s", err)
 			os.Exit(1)

+ 9 - 0
psiphon/controller_test.go

@@ -387,6 +387,15 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 		config.HostNameTransformer = &TestHostNameTransformer{}
 	}
 
+	// Override client retry throttle values to speed up automated
+	// tests and ensure tests complete within fixed deadlines.
+	fetchRemoteServerListRetryPeriodSeconds := 0
+	config.FetchRemoteServerListRetryPeriodSeconds = &fetchRemoteServerListRetryPeriodSeconds
+	downloadUpgradeRetryPeriodSeconds := 0
+	config.DownloadUpgradeRetryPeriodSeconds = &downloadUpgradeRetryPeriodSeconds
+	establishTunnelPausePeriodSeconds := 1
+	config.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
+
 	os.Remove(config.UpgradeDownloadFilename)
 
 	config.TunnelProtocol = runConfig.protocol

+ 3 - 0
psiphon/networkInterface.go

@@ -55,6 +55,9 @@ func GetInterfaceIPAddress(listenInterface string) (string, error) {
 			}
 			// TODO: IPv6 support
 			ip = iptype.IP.To4()
+			if ip == nil {
+				continue
+			}
 			return ip.String(), nil
 		}
 	}

+ 24 - 13
psiphon/server/config.go

@@ -57,6 +57,7 @@ const (
 	DEFAULT_SSH_SERVER_PORT                = 2222
 	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
 	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
+	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT      = 30 * time.Second
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH         = 32
 	DEFAULT_OBFUSCATED_SSH_SERVER_PORT     = 3333
 	REDIS_POOL_MAX_IDLE                    = 50
@@ -181,6 +182,12 @@ type Config struct {
 	// by this server, which parses the SSH channel using the udpgw
 	// protocol.
 	UdpgwServerAddress string
+
+	// LoadMonitorPeriodSeconds indicates how frequently to log server
+	// load information (number of connected clients per tunnel protocol,
+	// number of running goroutines, amount of memory allocated).
+	// The default, 0, disables load logging.
+	LoadMonitorPeriodSeconds int
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -250,6 +257,11 @@ func (config *Config) RunObfuscatedSSHServer() bool {
 	return config.ObfuscatedSSHServerPort > 0
 }
 
+// RunLoadMonitor indicates whether to monitor and log server load.
+func (config *Config) RunLoadMonitor() bool {
+	return config.LoadMonitorPeriodSeconds > 0
+}
+
 // UseRedis indicates whether to store per-session GeoIP information in
 // redis. This is for integration with the legacy psi_web component.
 func (config *Config) UseRedis() bool {
@@ -355,8 +367,9 @@ type GenerateConfigParams struct {
 	// ServerIPAddress is the public IP address of the server.
 	ServerIPAddress string
 
-	// ServerNetworkInterface is the (optional) nic to expose the server on
-	// when running in unprivileged mode but want to allow external clients to connect.
+	// ServerNetworkInterface specifies a network interface to
+	// use to determine the ServerIPAddress automatically. When
+	// set, ServerIPAddress is ignored.
 	ServerNetworkInterface string
 
 	// WebServerPort is the listening port of the web server.
@@ -378,13 +391,19 @@ type GenerateConfigParams struct {
 // the config file and server entry as necessary.
 func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
-	// TODO: support disabling web server or a subset of protocols
-
 	serverIPaddress := params.ServerIPAddress
 	if serverIPaddress == "" {
 		serverIPaddress = DEFAULT_SERVER_IP_ADDRESS
 	}
 
+	if params.ServerNetworkInterface != "" {
+		var err error
+		serverIPaddress, err = psiphon.GetInterfaceIPAddress(params.ServerNetworkInterface)
+		if err != nil {
+			return nil, nil, psiphon.ContextError(err)
+		}
+	}
+
 	// Web server config
 
 	webServerPort := params.WebServerPort
@@ -457,14 +476,6 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 		return nil, nil, psiphon.ContextError(err)
 	}
 
-	// Find IP address of the network interface (if not loopback)
-	serverNetworkInterface := params.ServerNetworkInterface
-	serverNetworkInterfaceIP, err := psiphon.GetInterfaceIPAddress(serverNetworkInterface)
-	if err != nil {
-		serverNetworkInterfaceIP = serverIPaddress
-		fmt.Printf("Could not find IP address of nic.  Falling back to %s\n", serverIPaddress)
-	}
-
 	// Assemble config and server entry
 
 	config := &Config{
@@ -504,7 +515,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 	}
 
 	serverEntry := &psiphon.ServerEntry{
-		IpAddress:            serverNetworkInterfaceIP,
+		IpAddress:            serverIPaddress,
 		WebServerPort:        fmt.Sprintf("%d", webServerPort),
 		WebServerSecret:      webServerSecret,
 		WebServerCertificate: strippedWebServerCertificate,

+ 1 - 13
psiphon/server/services.go

@@ -79,7 +79,7 @@ func RunServices(encodedConfigs [][]byte) error {
 		}()
 	}
 
-	if config.RunSSHServer() {
+	if config.RunSSHServer() || config.RunObfuscatedSSHServer() {
 		waitGroup.Add(1)
 		go func() {
 			defer waitGroup.Done()
@@ -91,18 +91,6 @@ func RunServices(encodedConfigs [][]byte) error {
 		}()
 	}
 
-	if config.RunObfuscatedSSHServer() {
-		waitGroup.Add(1)
-		go func() {
-			defer waitGroup.Done()
-			err := RunObfuscatedSSHServer(config, shutdownBroadcast)
-			select {
-			case errors <- err:
-			default:
-			}
-		}()
-	}
-
 	// An OS signal triggers an orderly shutdown
 	systemStopSignal := make(chan os.Signal, 1)
 	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)

+ 303 - 171
psiphon/server/sshService.go

@@ -26,28 +26,21 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"runtime"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"golang.org/x/crypto/ssh"
 )
 
-// RunSSHServer runs an ssh server with plain SSH protocol.
-func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
-	return runSSHServer(config, false, shutdownBroadcast)
-}
-
-// RunSSHServer runs an ssh server with Obfuscated SSH protocol.
-func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
-	return runSSHServer(config, true, shutdownBroadcast)
-}
-
-// runSSHServer runs an SSH or Obfuscated SSH server. In the Obfuscated SSH case, an
-// ObfuscatedSSHConn is layered in front of the client TCP connection; otherwise, both
-// modes are identical.
+// RunSSHServer runs an SSH server, the core tunneling component of the Psiphon
+// server. The SSH server runs a selection of listeners that handle connections
+// using various, optional obfuscation protocols layered on top of SSH.
+// (Currently, just Obfuscated SSH).
 //
-// runSSHServer listens on the designated port and spawns new goroutines to handle
+// RunSSHServer listens on the designated port(s) and spawns new goroutines to handle
 // each client connection. It halts when shutdownBroadcast is signaled. A list of active
 // clients is maintained, and when halting all clients are first shutdown.
 //
@@ -55,11 +48,12 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
 // authentication, and then looping on client new channel requests. At this time, only
 // "direct-tcpip" channels, dynamic port fowards, are expected and supported.
 //
-// A new goroutine is spawned to handle each port forward. Each port forward tracks its
-// bytes transferred. Overall per-client stats for connection duration, GeoIP, number of
-// port forwards, and bytes transferred are tracked and logged when the client shuts down.
-func runSSHServer(
-	config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
+// A new goroutine is spawned to handle each port forward for each client. Each port
+// forward tracks its bytes transferred. Overall per-client stats for connection duration,
+// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
+// client shuts down.
+func RunSSHServer(
+	config *Config, shutdownBroadcast <-chan struct{}) error {
 
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
 	if err != nil {
@@ -74,89 +68,94 @@ func runSSHServer(
 
 	sshServer := &sshServer{
 		config:            config,
-		useObfuscation:    useObfuscation,
+		runWaitGroup:      new(sync.WaitGroup),
+		listenerError:     make(chan error),
 		shutdownBroadcast: shutdownBroadcast,
 		sshHostKey:        signer,
 		nextClientID:      1,
 		clients:           make(map[sshClientID]*sshClient),
 	}
 
-	var serverPort int
-	if useObfuscation {
-		serverPort = config.ObfuscatedSSHServerPort
-	} else {
-		serverPort = config.SSHServerPort
+	type sshListener struct {
+		net.Listener
+		localAddress   string
+		tunnelProtocol string
 	}
 
-	listener, err := net.Listen(
-		"tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
-	if err != nil {
-		return psiphon.ContextError(err)
-	}
+	var listeners []*sshListener
 
-	log.WithContextFields(
-		LogFields{
-			"useObfuscation": useObfuscation,
-			"port":           serverPort,
-		}).Info("starting")
-
-	err = nil
-	errors := make(chan error)
-	waitGroup := new(sync.WaitGroup)
+	if config.RunSSHServer() {
+		listeners = append(listeners, &sshListener{
+			localAddress: fmt.Sprintf(
+				"%s:%d", config.ServerIPAddress, config.SSHServerPort),
+			tunnelProtocol: psiphon.TUNNEL_PROTOCOL_SSH,
+		})
+	}
 
-	waitGroup.Add(1)
-	go func() {
-		defer waitGroup.Done()
+	if config.RunObfuscatedSSHServer() {
+		listeners = append(listeners, &sshListener{
+			localAddress: fmt.Sprintf(
+				"%s:%d", config.ServerIPAddress, config.ObfuscatedSSHServerPort),
+			tunnelProtocol: psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+		})
+	}
 
-	loop:
-		for {
-			conn, err := listener.Accept()
+	// TODO: add additional protocol listeners here (e.g, meek)
 
-			select {
-			case <-shutdownBroadcast:
-				if err == nil {
-					conn.Close()
-				}
-				break loop
-			default:
+	for i, listener := range listeners {
+		var err error
+		listener.Listener, err = net.Listen("tcp", listener.localAddress)
+		if err != nil {
+			for j := 0; j < i; j++ {
+				listener.Listener.Close()
 			}
+			return psiphon.ContextError(err)
+		}
+		log.WithContextFields(
+			LogFields{
+				"localAddress":   listener.localAddress,
+				"tunnelProtocol": listener.tunnelProtocol,
+			}).Info("listening")
+	}
 
-			if err != nil {
-				if e, ok := err.(net.Error); ok && e.Temporary() {
-					log.WithContextFields(LogFields{"error": err}).Error("accept failed")
-					// Temporary error, keep running
-					continue
-				}
-
-				select {
-				case errors <- psiphon.ContextError(err):
-				default:
-				}
+	for _, listener := range listeners {
+		sshServer.runWaitGroup.Add(1)
+		go func(listener *sshListener) {
+			defer sshServer.runWaitGroup.Done()
 
-				break loop
-			}
+			sshServer.runListener(
+				listener.Listener, listener.tunnelProtocol)
 
-			// process each client connection concurrently
-			go sshServer.handleClient(conn.(*net.TCPConn))
-		}
+			log.WithContextFields(
+				LogFields{
+					"localAddress":   listener.localAddress,
+					"tunnelProtocol": listener.tunnelProtocol,
+				}).Info("stopping")
 
-		sshServer.stopClients()
+		}(listener)
+	}
 
-		log.WithContextFields(
-			LogFields{"useObfuscation": useObfuscation}).Info("stopped")
-	}()
+	if config.RunLoadMonitor() {
+		sshServer.runWaitGroup.Add(1)
+		go func() {
+			defer sshServer.runWaitGroup.Done()
+			sshServer.runLoadMonitor()
+		}()
+	}
 
+	err = nil
 	select {
-	case <-shutdownBroadcast:
-	case err = <-errors:
+	case <-sshServer.shutdownBroadcast:
+	case err = <-sshServer.listenerError:
 	}
 
-	listener.Close()
-
-	waitGroup.Wait()
+	for _, listener := range listeners {
+		listener.Close()
+	}
+	sshServer.stopClients()
+	sshServer.runWaitGroup.Wait()
 
-	log.WithContextFields(
-		LogFields{"useObfuscation": useObfuscation}).Info("exiting")
+	log.WithContext().Info("stopped")
 
 	return err
 }
@@ -165,7 +164,8 @@ type sshClientID uint64
 
 type sshServer struct {
 	config            *Config
-	useObfuscation    bool
+	runWaitGroup      *sync.WaitGroup
+	listenerError     chan error
 	shutdownBroadcast <-chan struct{}
 	sshHostKey        ssh.Signer
 	nextClientID      sshClientID
@@ -174,6 +174,73 @@ type sshServer struct {
 	clients           map[sshClientID]*sshClient
 }
 
+func (sshServer *sshServer) runListener(
+	listener net.Listener, tunnelProtocol string) {
+
+	for {
+		conn, err := listener.Accept()
+
+		if err == nil && tunnelProtocol == psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
+			conn, err = psiphon.NewObfuscatedSshConn(
+				psiphon.OBFUSCATION_CONN_MODE_SERVER,
+				conn,
+				sshServer.config.ObfuscatedSSHKey)
+		}
+
+		select {
+		case <-sshServer.shutdownBroadcast:
+			if err == nil {
+				conn.Close()
+			}
+			return
+		default:
+		}
+
+		if err != nil {
+			if e, ok := err.(net.Error); ok && e.Temporary() {
+				log.WithContextFields(LogFields{"error": err}).Error("accept failed")
+				// Temporary error, keep running
+				continue
+			}
+
+			select {
+			case sshServer.listenerError <- psiphon.ContextError(err):
+			default:
+			}
+
+			return
+		}
+
+		// process each client connection concurrently
+		go sshServer.handleClient(tunnelProtocol, conn)
+	}
+}
+
+func (sshServer *sshServer) runLoadMonitor() {
+	ticker := time.NewTicker(
+		time.Duration(sshServer.config.LoadMonitorPeriodSeconds) * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-sshServer.shutdownBroadcast:
+			return
+		case <-ticker.C:
+			var memStats runtime.MemStats
+			runtime.ReadMemStats(&memStats)
+			fields := LogFields{
+				"goroutines":    runtime.NumGoroutine(),
+				"memAlloc":      memStats.Alloc,
+				"memTotalAlloc": memStats.TotalAlloc,
+				"memSysAlloc":   memStats.Sys,
+			}
+			for tunnelProtocol, count := range sshServer.countClients() {
+				fields[tunnelProtocol] = count
+			}
+			log.WithContextFields(fields).Info("load")
+		}
+	}
+}
+
 func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
 
 	sshServer.clientsMutex.Lock()
@@ -199,34 +266,20 @@ func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
 	sshServer.clientsMutex.Unlock()
 
 	if client != nil {
-		sshServer.stopClient(client)
+		client.stop()
 	}
 }
 
-func (sshServer *sshServer) stopClient(client *sshClient) {
+func (sshServer *sshServer) countClients() map[string]int {
 
-	client.sshConn.Close()
-	client.sshConn.Wait()
+	sshServer.clientsMutex.Lock()
+	defer sshServer.clientsMutex.Unlock()
 
-	client.Lock()
-	log.WithContextFields(
-		LogFields{
-			"startTime":                         client.startTime,
-			"duration":                          time.Now().Sub(client.startTime),
-			"psiphonSessionID":                  client.psiphonSessionID,
-			"country":                           client.geoIPData.Country,
-			"city":                              client.geoIPData.City,
-			"ISP":                               client.geoIPData.ISP,
-			"bytesUpTCP":                        client.tcpTrafficState.bytesUp,
-			"bytesDownTCP":                      client.tcpTrafficState.bytesDown,
-			"portForwardCountTCP":               client.tcpTrafficState.portForwardCount,
-			"peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
-			"bytesUpUDP":                        client.udpTrafficState.bytesUp,
-			"bytesDownUDP":                      client.udpTrafficState.bytesDown,
-			"portForwardCountUDP":               client.udpTrafficState.portForwardCount,
-			"peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
-		}).Info("tunnel closed")
-	client.Unlock()
+	counts := make(map[string]int)
+	for _, client := range sshServer.clients {
+		counts[client.tunnelProtocol] += 1
+	}
+	return counts
 }
 
 func (sshServer *sshServer) stopClients() {
@@ -237,24 +290,21 @@ func (sshServer *sshServer) stopClients() {
 	sshServer.clientsMutex.Unlock()
 
 	for _, client := range sshServer.clients {
-		sshServer.stopClient(client)
+		client.stop()
 	}
 }
 
-func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
+func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
 
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
+	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
 
-	sshClient := &sshClient{
-		sshServer:       sshServer,
-		startTime:       time.Now(),
-		geoIPData:       geoIPData,
-		trafficRules:    sshServer.config.GetTrafficRules(geoIPData.Country),
-		tcpTrafficState: &trafficState{},
-		udpTrafficState: &trafficState{},
-	}
+	sshClient := newSshClient(
+		sshServer,
+		tunnelProtocol,
+		geoIPData,
+		sshServer.config.GetTrafficRules(geoIPData.Country))
 
-	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
+	// Wrap the base client connection with an IdleTimeoutConn which will terminate
 	// the connection if no data is received before the deadline. This timeout is
 	// in effect for the entire duration of the SSH connection. Clients must actively
 	// use the connection or send SSH keep alive requests to keep the connection
@@ -262,7 +312,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 	var conn net.Conn
 
-	conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
+	conn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
@@ -292,29 +342,25 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 		})
 	}
 
-	go func() {
-
-		result := &sshNewServerConnResult{}
-		if sshServer.useObfuscation {
-			result.conn, result.err = psiphon.NewObfuscatedSshConn(
-				psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
-		} else {
-			result.conn = conn
+	go func(conn net.Conn) {
+		sshServerConfig := &ssh.ServerConfig{
+			PasswordCallback: sshClient.passwordCallback,
+			AuthLogCallback:  sshClient.authLogCallback,
+			ServerVersion:    sshServer.config.SSHServerVersion,
 		}
-		if result.err == nil {
+		sshServerConfig.AddHostKey(sshServer.sshHostKey)
 
-			sshServerConfig := &ssh.ServerConfig{
-				PasswordCallback: sshClient.passwordCallback,
-				AuthLogCallback:  sshClient.authLogCallback,
-				ServerVersion:    sshServer.config.SSHServerVersion,
-			}
-			sshServerConfig.AddHostKey(sshServer.sshHostKey)
+		sshConn, channels, requests, err :=
+			ssh.NewServerConn(conn, sshServerConfig)
 
-			result.sshConn, result.channels, result.requests, result.err =
-				ssh.NewServerConn(result.conn, sshServerConfig)
+		resultChannel <- &sshNewServerConnResult{
+			conn:     conn,
+			sshConn:  sshConn,
+			channels: channels,
+			requests: requests,
+			err:      err,
 		}
-		resultChannel <- result
-	}()
+	}(conn)
 
 	var result *sshNewServerConnResult
 	select {
@@ -351,15 +397,18 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 type sshClient struct {
 	sync.Mutex
-	sshServer        *sshServer
-	sshConn          ssh.Conn
-	startTime        time.Time
-	geoIPData        GeoIPData
-	psiphonSessionID string
-	udpChannel       ssh.Channel
-	trafficRules     TrafficRules
-	tcpTrafficState  *trafficState
-	udpTrafficState  *trafficState
+	sshServer               *sshServer
+	tunnelProtocol          string
+	sshConn                 ssh.Conn
+	startTime               time.Time
+	geoIPData               GeoIPData
+	psiphonSessionID        string
+	udpChannel              ssh.Channel
+	trafficRules            TrafficRules
+	tcpTrafficState         *trafficState
+	udpTrafficState         *trafficState
+	channelHandlerWaitGroup *sync.WaitGroup
+	stopBroadcast           chan struct{}
 }
 
 type trafficState struct {
@@ -370,15 +419,31 @@ type trafficState struct {
 	peakConcurrentPortForwardCount int64
 }
 
+func newSshClient(
+	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
+	return &sshClient{
+		sshServer:               sshServer,
+		tunnelProtocol:          tunnelProtocol,
+		startTime:               time.Now(),
+		geoIPData:               geoIPData,
+		trafficRules:            trafficRules,
+		tcpTrafficState:         &trafficState{},
+		udpTrafficState:         &trafficState{},
+		channelHandlerWaitGroup: new(sync.WaitGroup),
+		stopBroadcast:           make(chan struct{}),
+	}
+}
+
 func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 	for newChannel := range channels {
 
 		if newChannel.ChannelType() != "direct-tcpip" {
 			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
-			return
+			continue
 		}
 
 		// process each port forward concurrently
+		sshClient.channelHandlerWaitGroup.Add(1)
 		go sshClient.handleNewPortForwardChannel(newChannel)
 	}
 }
@@ -395,6 +460,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 }
 
 func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
+	defer sshClient.channelHandlerWaitGroup.Done()
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
@@ -460,7 +526,7 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
 	return limitExceeded
 }
 
-func (sshClient *sshClient) establishedPortForward(
+func (sshClient *sshClient) openedPortForward(
 	state *trafficState) {
 
 	sshClient.Lock()
@@ -497,7 +563,17 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 	}
 
-	// TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
+	var bytesUp, bytesDown int64
+	sshClient.openedPortForward(sshClient.tcpTrafficState)
+	defer sshClient.closedPortForward(
+		sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
+
+	// TOCTOU note: important to increment the port forward count (via
+	// openPortForward) _before_ checking isPortForwardLimitExceeded
+	// otherwise, the client could potentially consume excess resources
+	// by initiating many port forwards concurrently.
+	// TODO: close LRU connection (after successful Dial) instead of
+	// rejecting new connection?
 	if sshClient.isPortForwardLimitExceeded(
 		sshClient.tcpTrafficState,
 		sshClient.trafficRules.MaxTCPPortForwardCount) {
@@ -507,18 +583,39 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 	}
 
-	targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
+	remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
 
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
-	// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
-	// TODO: port forward dial timeout
-	// TODO: IPv6 support
-	fwdConn, err := net.Dial("tcp4", targetAddr)
-	if err != nil {
-		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
+	type dialTcpResult struct {
+		conn net.Conn
+		err  error
+	}
+
+	resultChannel := make(chan *dialTcpResult, 1)
+
+	go func() {
+		// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
+		// TODO: IPv6 support
+		conn, err := net.DialTimeout(
+			"tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
+		resultChannel <- &dialTcpResult{conn, err}
+	}()
+
+	var result *dialTcpResult
+	select {
+	case result = <-resultChannel:
+	case <-sshClient.stopBroadcast:
+		// Note: may leave dial in progress
+		return
+	}
+
+	if result.err != nil {
+		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
 		return
 	}
+
+	fwdConn := result.conn
 	defer fwdConn.Close()
 
 	fwdChannel, requests, err := newChannel.Accept()
@@ -529,9 +626,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	go ssh.DiscardRequests(requests)
 	defer fwdChannel.Close()
 
-	sshClient.establishedPortForward(sshClient.tcpTrafficState)
-
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
 
 	// When idle port forward traffic rules are in place, wrap fwdConn
 	// in an IdleTimeoutConn configured to reset idle on writes as well
@@ -549,28 +644,36 @@ func (sshClient *sshClient) handleTCPChannel(
 	// TODO: relay errors to fwdChannel.Stderr()?
 	// TODO: use a low-memory io.Copy?
 
-	var bytesUp, bytesDown int64
-
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup.Add(1)
 	go func() {
 		defer relayWaitGroup.Done()
-		var err error
-		bytesUp, err = io.Copy(fwdConn, fwdChannel)
-		if err != nil {
-			log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
+		bytes, err := io.Copy(fwdChannel, fwdConn)
+		atomic.AddInt64(&bytesDown, bytes)
+		if err != nil && err != io.EOF {
+			// Debug since errors such as "connection reset by peer" occur during normal operation
+			log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
 		}
 	}()
-	bytesDown, err = io.Copy(fwdChannel, fwdConn)
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
+	bytes, err := io.Copy(fwdConn, fwdChannel)
+	atomic.AddInt64(&bytesUp, bytes)
+	if err != nil && err != io.EOF {
+		log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
 	}
-	fwdChannel.CloseWrite()
-	relayWaitGroup.Wait()
 
-	sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
+	// Shutdown special case: fwdChannel will be closed and return EOF when
+	// the SSH connection is closed, but we need to explicitly close fwdConn
+	// to interrupt the downstream io.Copy, which may be blocked on a
+	// fwdConn.Read().
+	fwdConn.Close()
 
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
+	relayWaitGroup.Wait()
+
+	log.WithContextFields(
+		LogFields{
+			"remoteAddr": remoteAddr,
+			"bytesUp":    atomic.LoadInt64(&bytesUp),
+			"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
 }
 
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
@@ -626,3 +729,32 @@ func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string
 		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
 	}
 }
+
+func (sshClient *sshClient) stop() {
+
+	sshClient.sshConn.Close()
+	sshClient.sshConn.Wait()
+
+	close(sshClient.stopBroadcast)
+	sshClient.channelHandlerWaitGroup.Wait()
+
+	sshClient.Lock()
+	log.WithContextFields(
+		LogFields{
+			"startTime":                         sshClient.startTime,
+			"duration":                          time.Now().Sub(sshClient.startTime),
+			"psiphonSessionID":                  sshClient.psiphonSessionID,
+			"country":                           sshClient.geoIPData.Country,
+			"city":                              sshClient.geoIPData.City,
+			"ISP":                               sshClient.geoIPData.ISP,
+			"bytesUpTCP":                        sshClient.tcpTrafficState.bytesUp,
+			"bytesDownTCP":                      sshClient.tcpTrafficState.bytesDown,
+			"portForwardCountTCP":               sshClient.tcpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
+			"bytesUpUDP":                        sshClient.udpTrafficState.bytesUp,
+			"bytesDownUDP":                      sshClient.udpTrafficState.bytesDown,
+			"portForwardCountUDP":               sshClient.udpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
+		}).Info("tunnel closed")
+	sshClient.Unlock()
+}

+ 177 - 122
psiphon/server/udpChannel.go

@@ -61,15 +61,34 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	// Accept this channel immediately. This channel will replace any
 	// previously existing UDP channel for this client.
 
-	fwdChannel, requests, err := newChannel.Accept()
+	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
 		return
 	}
 	go ssh.DiscardRequests(requests)
-	defer fwdChannel.Close()
+	defer sshChannel.Close()
 
-	sshClient.setUDPChannel(fwdChannel)
+	sshClient.setUDPChannel(sshChannel)
+
+	multiplexer := &udpPortForwardMultiplexer{
+		sshClient:      sshClient,
+		sshChannel:     sshChannel,
+		portForwards:   make(map[uint16]*udpPortForward),
+		relayWaitGroup: new(sync.WaitGroup),
+	}
+	multiplexer.run()
+}
+
+type udpPortForwardMultiplexer struct {
+	sshClient         *sshClient
+	sshChannel        ssh.Channel
+	portForwardsMutex sync.Mutex
+	portForwards      map[uint16]*udpPortForward
+	relayWaitGroup    *sync.WaitGroup
+}
+
+func (mux *udpPortForwardMultiplexer) run() {
 
 	// In a loop, read udpgw messages from the client to this channel. Each message is
 	// a UDP packet to send upstream either via a new port forward, or on an existing
@@ -81,26 +100,11 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	// When the client disconnects or the server shuts down, the channel will close and
 	// readUdpgwMessage will exit with EOF.
 
-	type udpPortForward struct {
-		connID       uint16
-		preambleSize int
-		remoteIP     []byte
-		remotePort   uint16
-		conn         *net.UDPConn
-		lastActivity int64
-		bytesUp      int64
-		bytesDown    int64
-	}
-
-	var portForwardsMutex sync.Mutex
-	portForwards := make(map[uint16]*udpPortForward)
-	relayWaitGroup := new(sync.WaitGroup)
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
-
 	for {
 		// Note: message.packet points to the reusable memory in "buffer".
 		// Each readUdpgwMessage call will overwrite the last message.packet.
-		message, err := readUdpgwMessage(fwdChannel, buffer)
+		message, err := readUdpgwMessage(mux.sshChannel, buffer)
 		if err != nil {
 			if err != io.EOF {
 				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
@@ -108,9 +112,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 			break
 		}
 
-		portForwardsMutex.Lock()
-		portForward := portForwards[message.connID]
-		portForwardsMutex.Unlock()
+		mux.portForwardsMutex.Lock()
+		portForward := mux.portForwards[message.connID]
+		mux.portForwardsMutex.Unlock()
 
 		if portForward != nil && message.discardExistingConn {
 			// The port forward's goroutine will complete cleanup, including
@@ -136,55 +140,48 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 
 			// Create a new port forward
 
-			if !sshClient.isPortForwardPermitted(
+			if !mux.sshClient.isPortForwardPermitted(
 				int(message.remotePort),
-				sshClient.trafficRules.AllowUDPPorts,
-				sshClient.trafficRules.DenyUDPPorts) {
+				mux.sshClient.trafficRules.AllowUDPPorts,
+				mux.sshClient.trafficRules.DenyUDPPorts) {
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				continue
 			}
 
-			if sshClient.isPortForwardLimitExceeded(
-				sshClient.tcpTrafficState,
-				sshClient.trafficRules.MaxUDPPortForwardCount) {
+			mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState)
+			// Note: can't defer sshClient.closedPortForward() here
+
+			// TOCTOU note: important to increment the port forward count (via
+			// openPortForward) _before_ checking isPortForwardLimitExceeded
+			// NOTE(review): the check below passes tcpTrafficState together with
+			// MaxUDPPortForwardCount; udpTrafficState looks intended (it is what
+			// openedPortForward/closedPortForward use for UDP here) — confirm.
+			if mux.sshClient.isPortForwardLimitExceeded(
+				mux.sshClient.tcpTrafficState,
+				mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
 
 				// When the UDP port forward limit is exceeded, we
-				// select the least recently used (red from or written
+				// select the least recently used (read from or written
 				// to) port forward and discard it.
-
-				// TODO: use "container/list" and avoid a linear scan?
-				portForwardsMutex.Lock()
-				oldestActivity := int64(math.MaxInt64)
-				var oldestPortForward *udpPortForward
-				for _, nextPortForward := range portForwards {
-					if nextPortForward.lastActivity < oldestActivity {
-						oldestPortForward = nextPortForward
-					}
-				}
-				if oldestPortForward != nil {
-					// The port forward's goroutine will complete cleanup
-					oldestPortForward.conn.Close()
-				}
-				portForwardsMutex.Unlock()
+				mux.closeLeastRecentlyUsedPortForward()
 			}
 
-			dialIP := message.remoteIP
+			dialIP := net.IP(message.remoteIP)
 			dialPort := int(message.remotePort)
 
 			// Transparent DNS forwarding
-			if message.forwardDNS && sshClient.sshServer.config.DNSServerAddress != "" {
-				// Note: DNSServerAddress is validated in LoadConfig
-				host, portStr, _ := net.SplitHostPort(
-					sshClient.sshServer.config.DNSServerAddress)
-				dialIP = net.ParseIP(host)
-				dialPort, _ = strconv.Atoi(portStr)
+			if message.forwardDNS {
+				dialIP, dialPort = mux.transparentDNSAddress(dialIP, dialPort)
 			}
 
+			log.WithContextFields(
+				LogFields{
+					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
+					"connID":     message.connID}).Debug("dialing")
+
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 			updConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
+				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				continue
 			}
@@ -198,82 +195,24 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 				lastActivity: time.Now().UnixNano(),
 				bytesUp:      0,
 				bytesDown:    0,
+				mux:          mux,
 			}
-			portForwardsMutex.Lock()
-			portForwards[portForward.connID] = portForward
-			portForwardsMutex.Unlock()
+			mux.portForwardsMutex.Lock()
+			mux.portForwards[portForward.connID] = portForward
+			mux.portForwardsMutex.Unlock()
 
 			// TODO: timeout inactive UDP port forwards
 
-			sshClient.establishedPortForward(sshClient.udpTrafficState)
-
-			relayWaitGroup.Add(1)
-			go func(portForward *udpPortForward) {
-				defer relayWaitGroup.Done()
-
-				// Downstream UDP packets are read into the reusable memory
-				// in "buffer" starting at the offset past the udpgw message
-				// header and address, leaving enough space to write the udpgw
-				// values into the same buffer and use for writing to the ssh
-				// channel.
-				//
-				// Note: there is one downstream buffer per UDP port forward,
-				// while for upstream there is one buffer per client.
-				// TODO: is the buffer size larger than necessary?
-				buffer := make([]byte, udpgwProtocolMaxMessageSize)
-				packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
-				for {
-					// TODO: if read buffer is too small, excess bytes are discarded?
-					packetSize, err := portForward.conn.Read(packetBuffer)
-					if packetSize > udpgwProtocolMaxPayloadSize {
-						err = fmt.Errorf("unexpected packet size: %d", packetSize)
-					}
-					if err != nil {
-						if err != io.EOF {
-							log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
-						}
-						break
-					}
-
-					err = writeUdpgwPreamble(
-						portForward.preambleSize,
-						portForward.connID,
-						portForward.remoteIP,
-						portForward.remotePort,
-						uint16(packetSize),
-						buffer)
-					if err == nil {
-						_, err = fwdChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
-					}
-
-					if err != nil {
-						// Close the channel, which will interrupt the main loop.
-						fwdChannel.Close()
-						log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
-						break
-					}
-
-					atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
-					atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
-				}
-
-				portForwardsMutex.Lock()
-				delete(portForwards, portForward.connID)
-				portForwardsMutex.Unlock()
-
-				portForward.conn.Close()
-
-				bytesUp := atomic.LoadInt64(&portForward.bytesUp)
-				bytesDown := atomic.LoadInt64(&portForward.bytesDown)
-				sshClient.closedPortForward(sshClient.udpTrafficState, bytesUp, bytesDown)
-
-			}(portForward)
+			// relayDownstream will call sshClient.closedPortForward()
+			mux.relayWaitGroup.Add(1)
+			go portForward.relayDownstream()
 		}
 
 		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
 		_, err = portForward.conn.Write(message.packet)
 		if err != nil {
-			log.WithContextFields(LogFields{"error": err}).Warning("upstream UDP relay failed")
+			// Debug since errors such as "write: operation not permitted" occur during normal operation
+			log.WithContextFields(LogFields{"error": err}).Debug("upstream UDP relay failed")
 			// The port forward's goroutine will complete cleanup
 			portForward.conn.Close()
 		}
@@ -283,14 +222,130 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 
 	// Cleanup all UDP port forward workers when exiting
 
-	portForwardsMutex.Lock()
-	for _, portForward := range portForwards {
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
 		// The port forward's goroutine will complete cleanup
 		portForward.conn.Close()
 	}
-	portForwardsMutex.Unlock()
+	mux.portForwardsMutex.Unlock()
+
+	mux.relayWaitGroup.Wait()
+}
+
+func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
+	// TODO: use "container/list" and avoid a linear scan?
+	mux.portForwardsMutex.Lock()
+	oldestActivity := int64(math.MaxInt64)
+	var oldestPortForward *udpPortForward
+	for _, nextPortForward := range mux.portForwards {
+		if nextPortForward.lastActivity < oldestActivity {
+			oldestPortForward = nextPortForward
+		}
+	}
+	if oldestPortForward != nil {
+		// The port forward's goroutine will complete cleanup
+		oldestPortForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+}
+
+func (mux *udpPortForwardMultiplexer) transparentDNSAddress(
+	dialIP net.IP, dialPort int) (net.IP, int) {
+
+	if mux.sshClient.sshServer.config.DNSServerAddress != "" {
+		// Note: DNSServerAddress is validated in LoadConfig
+		host, portStr, _ := net.SplitHostPort(
+			mux.sshClient.sshServer.config.DNSServerAddress)
+		dialIP = net.ParseIP(host)
+		dialPort, _ = strconv.Atoi(portStr)
+	}
+	return dialIP, dialPort
+}
+
+func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+	mux.portForwardsMutex.Lock()
+	delete(mux.portForwards, connID)
+	mux.portForwardsMutex.Unlock()
+}
+
// udpPortForward tracks a single UDP port forward: its udpgw connection
// ID, remote address, upstream UDP socket, and activity/traffic
// counters. lastActivity, bytesUp, and bytesDown are accessed with
// sync/atomic since they are shared between the multiplexer loop and
// the relayDownstream goroutine.
type udpPortForward struct {
	// connID is the udpgw connection ID for this port forward.
	connID uint16
	// preambleSize is the length of the udpgw header written before
	// each downstream packet (see writeUdpgwPreamble usage).
	preambleSize int
	// remoteIP and remotePort identify the upstream destination;
	// remoteIP is kept as raw bytes and converted via net.IP as needed.
	remoteIP   []byte
	remotePort uint16
	// conn is the upstream UDP socket; closing it unblocks and
	// terminates relayDownstream, which completes cleanup.
	conn *net.UDPConn
	// lastActivity is a UnixNano timestamp used to select the least
	// recently used port forward for discard.
	lastActivity int64
	// bytesUp and bytesDown accumulate relayed traffic for reporting
	// via closedPortForward.
	bytesUp   int64
	bytesDown int64
	// mux is the owning multiplexer.
	mux *udpPortForwardMultiplexer
}
+
+func (portForward *udpPortForward) relayDownstream() {
+	defer portForward.mux.relayWaitGroup.Done()
+
+	// Downstream UDP packets are read into the reusable memory
+	// in "buffer" starting at the offset past the udpgw message
+	// header and address, leaving enough space to write the udpgw
+	// values into the same buffer and use for writing to the ssh
+	// channel.
+	//
+	// Note: there is one downstream buffer per UDP port forward,
+	// while for upstream there is one buffer per client.
+	// TODO: is the buffer size larger than necessary?
+	buffer := make([]byte, udpgwProtocolMaxMessageSize)
+	packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
+	for {
+		// TODO: if read buffer is too small, excess bytes are discarded?
+		packetSize, err := portForward.conn.Read(packetBuffer)
+		if packetSize > udpgwProtocolMaxPayloadSize {
+			err = fmt.Errorf("unexpected packet size: %d", packetSize)
+		}
+		if err != nil {
+			if err != io.EOF {
+				// Debug since errors such as "use of closed network connection" occur during normal operation
+				log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+			}
+			break
+		}
+
+		err = writeUdpgwPreamble(
+			portForward.preambleSize,
+			portForward.connID,
+			portForward.remoteIP,
+			portForward.remotePort,
+			uint16(packetSize),
+			buffer)
+		if err == nil {
+			_, err = portForward.mux.sshChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
+		}
+
+		if err != nil {
+			// Close the channel, which will interrupt the main loop.
+			portForward.mux.sshChannel.Close()
+			log.WithContextFields(LogFields{"error": err}).Debug("downstream UDP relay failed")
+			break
+		}
+
+		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+		atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
+	}
+
+	portForward.mux.removePortForward(portForward.connID)
+
+	portForward.conn.Close()
+
+	bytesUp := atomic.LoadInt64(&portForward.bytesUp)
+	bytesDown := atomic.LoadInt64(&portForward.bytesDown)
+	portForward.mux.sshClient.closedPortForward(
+		portForward.mux.sshClient.udpTrafficState, bytesUp, bytesDown)
 
-	relayWaitGroup.Wait()
+	log.WithContextFields(
+		LogFields{
+			"remoteAddr": fmt.Sprintf("%s:%d",
+				net.IP(portForward.remoteIP).String(), portForward.remotePort),
+			"bytesUp":   bytesUp,
+			"bytesDown": bytesDown,
+			"connID":    portForward.connID}).Debug("exiting")
 }
 
 // TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU?