Sfoglia il codice sorgente

SSH UDP/TCP port forward fixes and cleanup
* fix: synchronize port forward shutdowns when connection closes; this ensures
that bytes-transferred data is counted before the final log message
* fix: close upstream TCP socket when peer SSH channel is closed to ensure
timely shutdown
* rearranged UDP code into smaller functions
* fix: max port forward TOCTOU flaw
* add timeout to TCP port forward dial

Rod Hynes 9 anni fa
parent
commit
922c7d2fd1
3 ha cambiato i file con 298 aggiunte e 193 eliminazioni
  1. 1 0
      psiphon/server/config.go
  2. 123 72
      psiphon/server/sshService.go
  3. 174 121
      psiphon/server/udpChannel.go

+ 1 - 0
psiphon/server/config.go

@@ -57,6 +57,7 @@ const (
 	DEFAULT_SSH_SERVER_PORT                = 2222
 	DEFAULT_SSH_SERVER_PORT                = 2222
 	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
 	SSH_HANDSHAKE_TIMEOUT                  = 30 * time.Second
 	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
 	SSH_CONNECTION_READ_DEADLINE           = 5 * time.Minute
+	SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT      = 30 * time.Second
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH         = 32
 	SSH_OBFUSCATED_KEY_BYTE_LENGTH         = 32
 	DEFAULT_OBFUSCATED_SSH_SERVER_PORT     = 3333
 	DEFAULT_OBFUSCATED_SSH_SERVER_PORT     = 3333
 	REDIS_POOL_MAX_IDLE                    = 50
 	REDIS_POOL_MAX_IDLE                    = 50

+ 123 - 72
psiphon/server/sshService.go

@@ -27,6 +27,7 @@ import (
 	"io"
 	"io"
 	"net"
 	"net"
 	"sync"
 	"sync"
+	"sync/atomic"
 	"time"
 	"time"
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
@@ -199,36 +200,10 @@ func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
 	sshServer.clientsMutex.Unlock()
 	sshServer.clientsMutex.Unlock()
 
 
 	if client != nil {
 	if client != nil {
-		sshServer.stopClient(client)
+		client.stop()
 	}
 	}
 }
 }
 
 
-func (sshServer *sshServer) stopClient(client *sshClient) {
-
-	client.sshConn.Close()
-	client.sshConn.Wait()
-
-	client.Lock()
-	log.WithContextFields(
-		LogFields{
-			"startTime":                         client.startTime,
-			"duration":                          time.Now().Sub(client.startTime),
-			"psiphonSessionID":                  client.psiphonSessionID,
-			"country":                           client.geoIPData.Country,
-			"city":                              client.geoIPData.City,
-			"ISP":                               client.geoIPData.ISP,
-			"bytesUpTCP":                        client.tcpTrafficState.bytesUp,
-			"bytesDownTCP":                      client.tcpTrafficState.bytesDown,
-			"portForwardCountTCP":               client.tcpTrafficState.portForwardCount,
-			"peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
-			"bytesUpUDP":                        client.udpTrafficState.bytesUp,
-			"bytesDownUDP":                      client.udpTrafficState.bytesDown,
-			"portForwardCountUDP":               client.udpTrafficState.portForwardCount,
-			"peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
-		}).Info("tunnel closed")
-	client.Unlock()
-}
-
 func (sshServer *sshServer) stopClients() {
 func (sshServer *sshServer) stopClients() {
 
 
 	sshServer.clientsMutex.Lock()
 	sshServer.clientsMutex.Lock()
@@ -237,7 +212,7 @@ func (sshServer *sshServer) stopClients() {
 	sshServer.clientsMutex.Unlock()
 	sshServer.clientsMutex.Unlock()
 
 
 	for _, client := range sshServer.clients {
 	for _, client := range sshServer.clients {
-		sshServer.stopClient(client)
+		client.stop()
 	}
 	}
 }
 }
 
 
@@ -245,14 +220,8 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 
 	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
 	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
 
 
-	sshClient := &sshClient{
-		sshServer:       sshServer,
-		startTime:       time.Now(),
-		geoIPData:       geoIPData,
-		trafficRules:    sshServer.config.GetTrafficRules(geoIPData.Country),
-		tcpTrafficState: &trafficState{},
-		udpTrafficState: &trafficState{},
-	}
+	sshClient := newSshClient(
+		sshServer, geoIPData, sshServer.config.GetTrafficRules(geoIPData.Country))
 
 
 	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
 	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
 	// the connection if no data is received before the deadline. This timeout is
 	// the connection if no data is received before the deadline. This timeout is
@@ -351,15 +320,17 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 
 type sshClient struct {
 type sshClient struct {
 	sync.Mutex
 	sync.Mutex
-	sshServer        *sshServer
-	sshConn          ssh.Conn
-	startTime        time.Time
-	geoIPData        GeoIPData
-	psiphonSessionID string
-	udpChannel       ssh.Channel
-	trafficRules     TrafficRules
-	tcpTrafficState  *trafficState
-	udpTrafficState  *trafficState
+	sshServer               *sshServer
+	sshConn                 ssh.Conn
+	startTime               time.Time
+	geoIPData               GeoIPData
+	psiphonSessionID        string
+	udpChannel              ssh.Channel
+	trafficRules            TrafficRules
+	tcpTrafficState         *trafficState
+	udpTrafficState         *trafficState
+	channelHandlerWaitGroup *sync.WaitGroup
+	stopBroadcast           chan struct{}
 }
 }
 
 
 type trafficState struct {
 type trafficState struct {
@@ -370,15 +341,29 @@ type trafficState struct {
 	peakConcurrentPortForwardCount int64
 	peakConcurrentPortForwardCount int64
 }
 }
 
 
+func newSshClient(sshServer *sshServer, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
+	return &sshClient{
+		sshServer:               sshServer,
+		startTime:               time.Now(),
+		geoIPData:               geoIPData,
+		trafficRules:            trafficRules,
+		tcpTrafficState:         &trafficState{},
+		udpTrafficState:         &trafficState{},
+		channelHandlerWaitGroup: new(sync.WaitGroup),
+		stopBroadcast:           make(chan struct{}),
+	}
+}
+
 func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 	for newChannel := range channels {
 	for newChannel := range channels {
 
 
 		if newChannel.ChannelType() != "direct-tcpip" {
 		if newChannel.ChannelType() != "direct-tcpip" {
 			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
 			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
-			return
+			continue
 		}
 		}
 
 
 		// process each port forward concurrently
 		// process each port forward concurrently
+		sshClient.channelHandlerWaitGroup.Add(1)
 		go sshClient.handleNewPortForwardChannel(newChannel)
 		go sshClient.handleNewPortForwardChannel(newChannel)
 	}
 	}
 }
 }
@@ -395,6 +380,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
 }
 }
 
 
 func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
 func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
+	defer sshClient.channelHandlerWaitGroup.Done()
 
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
 	var directTcpipExtraData struct {
@@ -460,7 +446,7 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
 	return limitExceeded
 	return limitExceeded
 }
 }
 
 
-func (sshClient *sshClient) establishedPortForward(
+func (sshClient *sshClient) openedPortForward(
 	state *trafficState) {
 	state *trafficState) {
 
 
 	sshClient.Lock()
 	sshClient.Lock()
@@ -497,7 +483,17 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 		return
 	}
 	}
 
 
-	// TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
+	var bytesUp, bytesDown int64
+	sshClient.openedPortForward(sshClient.tcpTrafficState)
+	defer sshClient.closedPortForward(
+		sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
+
+	// TOCTOU note: important to increment the port forward count (via
+	// openPortForward) _before_ checking isPortForwardLimitExceeded
+	// otherwise, the client could potentially consume excess resources
+	// by initiating many port forwards concurrently.
+	// TODO: close LRU connection (after successful Dial) instead of
+	// rejecting new connection?
 	if sshClient.isPortForwardLimitExceeded(
 	if sshClient.isPortForwardLimitExceeded(
 		sshClient.tcpTrafficState,
 		sshClient.tcpTrafficState,
 		sshClient.trafficRules.MaxTCPPortForwardCount) {
 		sshClient.trafficRules.MaxTCPPortForwardCount) {
@@ -507,18 +503,39 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 		return
 	}
 	}
 
 
-	targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
+	remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
 
 
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
 
 
-	// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
-	// TODO: port forward dial timeout
-	// TODO: IPv6 support
-	fwdConn, err := net.Dial("tcp4", targetAddr)
-	if err != nil {
-		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
+	type dialTcpResult struct {
+		conn net.Conn
+		err  error
+	}
+
+	resultChannel := make(chan *dialTcpResult, 1)
+
+	go func() {
+		// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
+		// TODO: IPv6 support
+		conn, err := net.DialTimeout(
+			"tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
+		resultChannel <- &dialTcpResult{conn, err}
+	}()
+
+	var result *dialTcpResult
+	select {
+	case result = <-resultChannel:
+	case <-sshClient.stopBroadcast:
+		// Note: may leave dial in progress
+		return
+	}
+
+	if result.err != nil {
+		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
 		return
 		return
 	}
 	}
+
+	fwdConn := result.conn
 	defer fwdConn.Close()
 	defer fwdConn.Close()
 
 
 	fwdChannel, requests, err := newChannel.Accept()
 	fwdChannel, requests, err := newChannel.Accept()
@@ -529,9 +546,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	go ssh.DiscardRequests(requests)
 	go ssh.DiscardRequests(requests)
 	defer fwdChannel.Close()
 	defer fwdChannel.Close()
 
 
-	sshClient.establishedPortForward(sshClient.tcpTrafficState)
-
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
+	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
 
 
 	// When idle port forward traffic rules are in place, wrap fwdConn
 	// When idle port forward traffic rules are in place, wrap fwdConn
 	// in an IdleTimeoutConn configured to reset idle on writes as well
 	// in an IdleTimeoutConn configured to reset idle on writes as well
@@ -549,28 +564,35 @@ func (sshClient *sshClient) handleTCPChannel(
 	// TODO: relay errors to fwdChannel.Stderr()?
 	// TODO: relay errors to fwdChannel.Stderr()?
 	// TODO: use a low-memory io.Copy?
 	// TODO: use a low-memory io.Copy?
 
 
-	var bytesUp, bytesDown int64
-
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup.Add(1)
 	relayWaitGroup.Add(1)
 	go func() {
 	go func() {
 		defer relayWaitGroup.Done()
 		defer relayWaitGroup.Done()
-		var err error
-		bytesUp, err = io.Copy(fwdConn, fwdChannel)
-		if err != nil {
-			log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
+		bytes, err := io.Copy(fwdChannel, fwdConn)
+		atomic.AddInt64(&bytesDown, bytes)
+		if err != nil && err != io.EOF {
+			log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
 		}
 		}
 	}()
 	}()
-	bytesDown, err = io.Copy(fwdChannel, fwdConn)
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
+	bytes, err := io.Copy(fwdConn, fwdChannel)
+	atomic.AddInt64(&bytesUp, bytes)
+	if err != nil && err != io.EOF {
+		log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
 	}
 	}
-	fwdChannel.CloseWrite()
-	relayWaitGroup.Wait()
 
 
-	sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
+	// Shutdown special case: fwdChannel will be closed and return EOF when
+	// the SSH connection is closed, but we need to explicitly close fwdConn
+	// to interrupt the downstream io.Copy, which may be blocked on a
+	// fwdConn.Read().
+	fwdConn.Close()
 
 
-	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
+	relayWaitGroup.Wait()
+
+	log.WithContextFields(
+		LogFields{
+			"remoteAddr": remoteAddr,
+			"bytesUp":    atomic.LoadInt64(&bytesUp),
+			"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
 }
 }
 
 
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
@@ -626,3 +648,32 @@ func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string
 		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
 		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
 	}
 	}
 }
 }
+
+func (sshClient *sshClient) stop() {
+
+	sshClient.sshConn.Close()
+	sshClient.sshConn.Wait()
+
+	close(sshClient.stopBroadcast)
+	sshClient.channelHandlerWaitGroup.Wait()
+
+	sshClient.Lock()
+	log.WithContextFields(
+		LogFields{
+			"startTime":                         sshClient.startTime,
+			"duration":                          time.Now().Sub(sshClient.startTime),
+			"psiphonSessionID":                  sshClient.psiphonSessionID,
+			"country":                           sshClient.geoIPData.Country,
+			"city":                              sshClient.geoIPData.City,
+			"ISP":                               sshClient.geoIPData.ISP,
+			"bytesUpTCP":                        sshClient.tcpTrafficState.bytesUp,
+			"bytesDownTCP":                      sshClient.tcpTrafficState.bytesDown,
+			"portForwardCountTCP":               sshClient.tcpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
+			"bytesUpUDP":                        sshClient.udpTrafficState.bytesUp,
+			"bytesDownUDP":                      sshClient.udpTrafficState.bytesDown,
+			"portForwardCountUDP":               sshClient.udpTrafficState.portForwardCount,
+			"peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
+		}).Info("tunnel closed")
+	sshClient.Unlock()
+}

+ 174 - 121
psiphon/server/udpChannel.go

@@ -61,15 +61,34 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	// Accept this channel immediately. This channel will replace any
 	// Accept this channel immediately. This channel will replace any
 	// previously existing UDP channel for this client.
 	// previously existing UDP channel for this client.
 
 
-	fwdChannel, requests, err := newChannel.Accept()
+	sshChannel, requests, err := newChannel.Accept()
 	if err != nil {
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
 		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
 		return
 		return
 	}
 	}
 	go ssh.DiscardRequests(requests)
 	go ssh.DiscardRequests(requests)
-	defer fwdChannel.Close()
+	defer sshChannel.Close()
 
 
-	sshClient.setUDPChannel(fwdChannel)
+	sshClient.setUDPChannel(sshChannel)
+
+	multiplexer := &udpPortForwardMultiplexer{
+		sshClient:      sshClient,
+		sshChannel:     sshChannel,
+		portForwards:   make(map[uint16]*udpPortForward),
+		relayWaitGroup: new(sync.WaitGroup),
+	}
+	multiplexer.run()
+}
+
+type udpPortForwardMultiplexer struct {
+	sshClient         *sshClient
+	sshChannel        ssh.Channel
+	portForwardsMutex sync.Mutex
+	portForwards      map[uint16]*udpPortForward
+	relayWaitGroup    *sync.WaitGroup
+}
+
+func (mux *udpPortForwardMultiplexer) run() {
 
 
 	// In a loop, read udpgw messages from the client to this channel. Each message is
 	// In a loop, read udpgw messages from the client to this channel. Each message is
 	// a UDP packet to send upstream either via a new port forward, or on an existing
 	// a UDP packet to send upstream either via a new port forward, or on an existing
@@ -81,26 +100,11 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 	// When the client disconnects or the server shuts down, the channel will close and
 	// When the client disconnects or the server shuts down, the channel will close and
 	// readUdpgwMessage will exit with EOF.
 	// readUdpgwMessage will exit with EOF.
 
 
-	type udpPortForward struct {
-		connID       uint16
-		preambleSize int
-		remoteIP     []byte
-		remotePort   uint16
-		conn         *net.UDPConn
-		lastActivity int64
-		bytesUp      int64
-		bytesDown    int64
-	}
-
-	var portForwardsMutex sync.Mutex
-	portForwards := make(map[uint16]*udpPortForward)
-	relayWaitGroup := new(sync.WaitGroup)
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
 	buffer := make([]byte, udpgwProtocolMaxMessageSize)
-
 	for {
 	for {
 		// Note: message.packet points to the reusable memory in "buffer".
 		// Note: message.packet points to the reusable memory in "buffer".
 		// Each readUdpgwMessage call will overwrite the last message.packet.
 		// Each readUdpgwMessage call will overwrite the last message.packet.
-		message, err := readUdpgwMessage(fwdChannel, buffer)
+		message, err := readUdpgwMessage(mux.sshChannel, buffer)
 		if err != nil {
 		if err != nil {
 			if err != io.EOF {
 			if err != io.EOF {
 				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
 				log.WithContextFields(LogFields{"error": err}).Warning("readUpdgwMessage failed")
@@ -108,9 +112,9 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 			break
 			break
 		}
 		}
 
 
-		portForwardsMutex.Lock()
-		portForward := portForwards[message.connID]
-		portForwardsMutex.Unlock()
+		mux.portForwardsMutex.Lock()
+		portForward := mux.portForwards[message.connID]
+		mux.portForwardsMutex.Unlock()
 
 
 		if portForward != nil && message.discardExistingConn {
 		if portForward != nil && message.discardExistingConn {
 			// The port forward's goroutine will complete cleanup, including
 			// The port forward's goroutine will complete cleanup, including
@@ -136,55 +140,48 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 
 
 			// Create a new port forward
 			// Create a new port forward
 
 
-			if !sshClient.isPortForwardPermitted(
+			if !mux.sshClient.isPortForwardPermitted(
 				int(message.remotePort),
 				int(message.remotePort),
-				sshClient.trafficRules.AllowUDPPorts,
-				sshClient.trafficRules.DenyUDPPorts) {
+				mux.sshClient.trafficRules.AllowUDPPorts,
+				mux.sshClient.trafficRules.DenyUDPPorts) {
 				// The udpgw protocol has no error response, so
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				// we just discard the message and read another.
 				continue
 				continue
 			}
 			}
 
 
-			if sshClient.isPortForwardLimitExceeded(
-				sshClient.tcpTrafficState,
-				sshClient.trafficRules.MaxUDPPortForwardCount) {
+			mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState)
+			// Note: can't defer sshClient.closedPortForward() here
+
+			// TOCTOU note: important to increment the port forward count (via
+			// openPortForward) _before_ checking isPortForwardLimitExceeded
+			if mux.sshClient.isPortForwardLimitExceeded(
+				mux.sshClient.tcpTrafficState,
+				mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
 
 
 				// When the UDP port forward limit is exceeded, we
 				// When the UDP port forward limit is exceeded, we
-				// select the least recently used (red from or written
+				// select the least recently used (read from or written
 				// to) port forward and discard it.
 				// to) port forward and discard it.
-
-				// TODO: use "container/list" and avoid a linear scan?
-				portForwardsMutex.Lock()
-				oldestActivity := int64(math.MaxInt64)
-				var oldestPortForward *udpPortForward
-				for _, nextPortForward := range portForwards {
-					if nextPortForward.lastActivity < oldestActivity {
-						oldestPortForward = nextPortForward
-					}
-				}
-				if oldestPortForward != nil {
-					// The port forward's goroutine will complete cleanup
-					oldestPortForward.conn.Close()
-				}
-				portForwardsMutex.Unlock()
+				mux.closeLeastRecentlyUsedPortForward()
 			}
 			}
 
 
-			dialIP := message.remoteIP
+			dialIP := net.IP(message.remoteIP)
 			dialPort := int(message.remotePort)
 			dialPort := int(message.remotePort)
 
 
 			// Transparent DNS forwarding
 			// Transparent DNS forwarding
-			if message.forwardDNS && sshClient.sshServer.config.DNSServerAddress != "" {
-				// Note: DNSServerAddress is validated in LoadConfig
-				host, portStr, _ := net.SplitHostPort(
-					sshClient.sshServer.config.DNSServerAddress)
-				dialIP = net.ParseIP(host)
-				dialPort, _ = strconv.Atoi(portStr)
+			if message.forwardDNS {
+				dialIP, dialPort = mux.transparentDNSAddress(dialIP, dialPort)
 			}
 			}
 
 
+			log.WithContextFields(
+				LogFields{
+					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
+					"connID":     message.connID}).Debug("dialing")
+
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 			// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
 			updConn, err := net.DialUDP(
 			updConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
 			if err != nil {
+				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				continue
 				continue
 			}
 			}
@@ -198,76 +195,17 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 				lastActivity: time.Now().UnixNano(),
 				lastActivity: time.Now().UnixNano(),
 				bytesUp:      0,
 				bytesUp:      0,
 				bytesDown:    0,
 				bytesDown:    0,
+				mux:          mux,
 			}
 			}
-			portForwardsMutex.Lock()
-			portForwards[portForward.connID] = portForward
-			portForwardsMutex.Unlock()
+			mux.portForwardsMutex.Lock()
+			mux.portForwards[portForward.connID] = portForward
+			mux.portForwardsMutex.Unlock()
 
 
 			// TODO: timeout inactive UDP port forwards
 			// TODO: timeout inactive UDP port forwards
 
 
-			sshClient.establishedPortForward(sshClient.udpTrafficState)
-
-			relayWaitGroup.Add(1)
-			go func(portForward *udpPortForward) {
-				defer relayWaitGroup.Done()
-
-				// Downstream UDP packets are read into the reusable memory
-				// in "buffer" starting at the offset past the udpgw message
-				// header and address, leaving enough space to write the udpgw
-				// values into the same buffer and use for writing to the ssh
-				// channel.
-				//
-				// Note: there is one downstream buffer per UDP port forward,
-				// while for upstream there is one buffer per client.
-				// TODO: is the buffer size larger than necessary?
-				buffer := make([]byte, udpgwProtocolMaxMessageSize)
-				packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
-				for {
-					// TODO: if read buffer is too small, excess bytes are discarded?
-					packetSize, err := portForward.conn.Read(packetBuffer)
-					if packetSize > udpgwProtocolMaxPayloadSize {
-						err = fmt.Errorf("unexpected packet size: %d", packetSize)
-					}
-					if err != nil {
-						if err != io.EOF {
-							log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
-						}
-						break
-					}
-
-					err = writeUdpgwPreamble(
-						portForward.preambleSize,
-						portForward.connID,
-						portForward.remoteIP,
-						portForward.remotePort,
-						uint16(packetSize),
-						buffer)
-					if err == nil {
-						_, err = fwdChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
-					}
-
-					if err != nil {
-						// Close the channel, which will interrupt the main loop.
-						fwdChannel.Close()
-						log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
-						break
-					}
-
-					atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
-					atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
-				}
-
-				portForwardsMutex.Lock()
-				delete(portForwards, portForward.connID)
-				portForwardsMutex.Unlock()
-
-				portForward.conn.Close()
-
-				bytesUp := atomic.LoadInt64(&portForward.bytesUp)
-				bytesDown := atomic.LoadInt64(&portForward.bytesDown)
-				sshClient.closedPortForward(sshClient.udpTrafficState, bytesUp, bytesDown)
-
-			}(portForward)
+			// relayDownstream will call sshClient.closedPortForward()
+			mux.relayWaitGroup.Add(1)
+			go portForward.relayDownstream()
 		}
 		}
 
 
 		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
 		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
@@ -283,14 +221,129 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 
 
 	// Cleanup all UDP port forward workers when exiting
 	// Cleanup all UDP port forward workers when exiting
 
 
-	portForwardsMutex.Lock()
-	for _, portForward := range portForwards {
+	mux.portForwardsMutex.Lock()
+	for _, portForward := range mux.portForwards {
 		// The port forward's goroutine will complete cleanup
 		// The port forward's goroutine will complete cleanup
 		portForward.conn.Close()
 		portForward.conn.Close()
 	}
 	}
-	portForwardsMutex.Unlock()
+	mux.portForwardsMutex.Unlock()
+
+	mux.relayWaitGroup.Wait()
+}
+
+func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
+	// TODO: use "container/list" and avoid a linear scan?
+	mux.portForwardsMutex.Lock()
+	oldestActivity := int64(math.MaxInt64)
+	var oldestPortForward *udpPortForward
+	for _, nextPortForward := range mux.portForwards {
+		if nextPortForward.lastActivity < oldestActivity {
+			oldestPortForward = nextPortForward
+		}
+	}
+	if oldestPortForward != nil {
+		// The port forward's goroutine will complete cleanup
+		oldestPortForward.conn.Close()
+	}
+	mux.portForwardsMutex.Unlock()
+}
+
+func (mux *udpPortForwardMultiplexer) transparentDNSAddress(
+	dialIP net.IP, dialPort int) (net.IP, int) {
+
+	if mux.sshClient.sshServer.config.DNSServerAddress != "" {
+		// Note: DNSServerAddress is validated in LoadConfig
+		host, portStr, _ := net.SplitHostPort(
+			mux.sshClient.sshServer.config.DNSServerAddress)
+		dialIP = net.ParseIP(host)
+		dialPort, _ = strconv.Atoi(portStr)
+	}
+	return dialIP, dialPort
+}
+
+func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+	mux.portForwardsMutex.Lock()
+	delete(mux.portForwards, connID)
+	mux.portForwardsMutex.Unlock()
+}
+
+type udpPortForward struct {
+	connID       uint16
+	preambleSize int
+	remoteIP     []byte
+	remotePort   uint16
+	conn         *net.UDPConn
+	lastActivity int64
+	bytesUp      int64
+	bytesDown    int64
+	mux          *udpPortForwardMultiplexer
+}
+
+func (portForward *udpPortForward) relayDownstream() {
+	defer portForward.mux.relayWaitGroup.Done()
+
+	// Downstream UDP packets are read into the reusable memory
+	// in "buffer" starting at the offset past the udpgw message
+	// header and address, leaving enough space to write the udpgw
+	// values into the same buffer and use for writing to the ssh
+	// channel.
+	//
+	// Note: there is one downstream buffer per UDP port forward,
+	// while for upstream there is one buffer per client.
+	// TODO: is the buffer size larger than necessary?
+	buffer := make([]byte, udpgwProtocolMaxMessageSize)
+	packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize]
+	for {
+		// TODO: if read buffer is too small, excess bytes are discarded?
+		packetSize, err := portForward.conn.Read(packetBuffer)
+		if packetSize > udpgwProtocolMaxPayloadSize {
+			err = fmt.Errorf("unexpected packet size: %d", packetSize)
+		}
+		if err != nil {
+			if err != io.EOF {
+				log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+			}
+			break
+		}
+
+		err = writeUdpgwPreamble(
+			portForward.preambleSize,
+			portForward.connID,
+			portForward.remoteIP,
+			portForward.remotePort,
+			uint16(packetSize),
+			buffer)
+		if err == nil {
+			_, err = portForward.mux.sshChannel.Write(buffer[0 : portForward.preambleSize+packetSize])
+		}
+
+		if err != nil {
+			// Close the channel, which will interrupt the main loop.
+			portForward.mux.sshChannel.Close()
+			log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+			break
+		}
+
+		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+		atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
+	}
+
+	portForward.mux.removePortForward(portForward.connID)
+
+	portForward.conn.Close()
+
+	bytesUp := atomic.LoadInt64(&portForward.bytesUp)
+	bytesDown := atomic.LoadInt64(&portForward.bytesDown)
+	portForward.mux.sshClient.closedPortForward(
+		portForward.mux.sshClient.udpTrafficState, bytesUp, bytesDown)
 
 
-	relayWaitGroup.Wait()
+	log.WithContextFields(
+		LogFields{
+			"remoteAddr": fmt.Sprintf("%s:%d",
+				net.IP(portForward.remoteIP).String(), portForward.remotePort),
+			"bytesUp":   bytesUp,
+			"bytesDown": bytesDown,
+			"connID":    portForward.connID}).Debug("exiting")
 }
 }
 
 
 // TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU?
 // TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU?