Explorar o código

Add LRU logic for TCP; use same code to track LRU for UDP

Rod Hynes hai 9 anos
pai
achega
2ae1f4ce1c
Modificáronse 4 ficheiros con 328 adicións e 175 borrados
  1. 130 19
      psiphon/net.go
  2. 17 10
      psiphon/server/config.go
  3. 147 105
      psiphon/server/tunnelServer.go
  4. 34 41
      psiphon/server/udp.go

+ 130 - 19
psiphon/net.go

@@ -51,6 +51,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 package psiphon
 
 import (
+	"container/list"
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
@@ -221,6 +222,92 @@ func (conns *Conns) CloseAll() {
 	conns.conns = make(map[net.Conn]bool)
 }
 
+// LRUConns is a concurrency-safe list of net.Conns ordered
+// by recent activity. Its purpose is to facilitate closing
+// the oldest connection in a set of connections.
+//
+// New connections added are referenced by a LRUConnsEntry,
+// which is used to Touch() active connections, which
+// promotes them to the front of the order and to Remove()
+// connections that are no longer LRU candidates.
+//
+// CloseOldest() will remove the oldest connection from the
+// list and call net.Conn.Close() on the connection.
+//
+// After an entry has been removed, LRUConnsEntry Touch()
+// and Remove() will have no effect.
+type LRUConns struct {
+	mutex sync.Mutex // guards list
+	list  *list.List // front is most recently active; back is the LRU candidate
+}
+
+// NewLRUConns initializes a new LRUConns.
+// The zero value of LRUConns is not usable (its list is nil);
+// always construct with NewLRUConns.
+func NewLRUConns() *LRUConns {
+	return &LRUConns{list: list.New()}
+}
+
+// Add inserts a net.Conn as the freshest connection
+// in a LRUConns and returns an LRUConnsEntry to be
+// used to freshen the connection or remove the connection
+// from the LRU list.
+func (conns *LRUConns) Add(conn net.Conn) *LRUConnsEntry {
+	conns.mutex.Lock()
+	defer conns.mutex.Unlock()
+	// PushFront marks the new conn as most recently active.
+	return &LRUConnsEntry{
+		lruConns: conns,
+		element:  conns.list.PushFront(conn),
+	}
+}
+
+// CloseOldest closes the oldest connection in a
+// LRUConns. It calls net.Conn.Close() on the
+// connection.
+func (conns *LRUConns) CloseOldest() {
+	conns.mutex.Lock()
+	oldest := conns.list.Back()
+	conn, ok := oldest.Value.(net.Conn)
+	if oldest != nil {
+		conns.list.Remove(oldest)
+	}
+	// Release mutex before closing conn
+	conns.mutex.Unlock()
+	if ok {
+		conn.Close()
+	}
+}
+
+// LRUConnsEntry is an entry in a LRUConns list.
+type LRUConnsEntry struct {
+	lruConns *LRUConns     // owning list; nil for an uninitialized entry
+	element  *list.Element // position in lruConns.list
+}
+
+// Remove deletes the connection referenced by the
+// LRUConnsEntry from the associated LRUConns.
+// Has no effect if the entry was not initialized
+// or previously removed.
+func (entry *LRUConnsEntry) Remove() {
+	if entry.lruConns == nil || entry.element == nil {
+		return
+	}
+	entry.lruConns.mutex.Lock()
+	defer entry.lruConns.mutex.Unlock()
+	// container/list documents Remove as a no-op for an element
+	// not in the list, so removing twice is safe.
+	entry.lruConns.list.Remove(entry.element)
+}
+
+// Touch promotes the connection referenced by the
+// LRUConnsEntry to the front of the associated LRUConns.
+// Has no effect if the entry was not initialized
+// or previously removed.
+func (entry *LRUConnsEntry) Touch() {
+	if entry.lruConns == nil || entry.element == nil {
+		return
+	}
+	entry.lruConns.mutex.Lock()
+	defer entry.lruConns.mutex.Unlock()
+	// container/list documents MoveToFront as a no-op for an
+	// element not in the list, so touching a removed entry is safe.
+	entry.lruConns.list.MoveToFront(entry.element)
+}
+
 // LocalProxyRelay sends to remoteConn bytes received from localConn,
 // and sends to localConn bytes received from remoteConn.
 func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
@@ -647,39 +734,63 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
 	return tc, nil
 }
 
-// IdleTimeoutConn wraps a net.Conn and sets an initial ReadDeadline. The
-// deadline is extended whenever data is received from the connection.
-// Optionally, IdleTimeoutConn will also extend the deadline when data is
-// written to the connection.
-type IdleTimeoutConn struct {
+// ActivityMonitoredConn wraps a net.Conn, adding logic to deal with
+// events triggered by I/O activity.
+//
+// When an inactivity timeout is specified, the net.Conn Read() will
+// timeout after the specified period of read inactivity. Optionally,
+// ActivityMonitoredConn will also consider the connection active when
+// data is written to it.
+//
+// When a LRUConnsEntry is specified, then the LRU entry is promoted on
+// either a successful read or write.
+//
+type ActivityMonitoredConn struct {
 	net.Conn
-	deadline     time.Duration
-	resetOnWrite bool
+	inactivityTimeout time.Duration
+	activeOnWrite     bool
+	lruEntry          *LRUConnsEntry
 }
 
-func NewIdleTimeoutConn(
-	conn net.Conn, deadline time.Duration, resetOnWrite bool) *IdleTimeoutConn {
+func NewActivityMonitoredConn(
+	conn net.Conn,
+	inactivityTimeout time.Duration,
+	activeOnWrite bool,
+	lruEntry *LRUConnsEntry) *ActivityMonitoredConn {
 
-	conn.SetReadDeadline(time.Now().Add(deadline))
-	return &IdleTimeoutConn{
-		Conn:         conn,
-		deadline:     deadline,
-		resetOnWrite: resetOnWrite,
+	if inactivityTimeout > 0 {
+		conn.SetReadDeadline(time.Now().Add(inactivityTimeout))
+	}
+	return &ActivityMonitoredConn{
+		Conn:              conn,
+		inactivityTimeout: inactivityTimeout,
+		activeOnWrite:     activeOnWrite,
+		lruEntry:          lruEntry,
 	}
 }
 
-func (conn *IdleTimeoutConn) Read(buffer []byte) (int, error) {
+func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 	n, err := conn.Conn.Read(buffer)
 	if err == nil {
-		conn.Conn.SetReadDeadline(time.Now().Add(conn.deadline))
+		if conn.inactivityTimeout > 0 {
+			conn.Conn.SetReadDeadline(time.Now().Add(conn.inactivityTimeout))
+		}
+		if conn.lruEntry != nil {
+			conn.lruEntry.Touch()
+		}
 	}
 	return n, err
 }
 
-func (conn *IdleTimeoutConn) Write(buffer []byte) (int, error) {
+func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 	n, err := conn.Conn.Write(buffer)
-	if err == nil && conn.resetOnWrite {
-		conn.Conn.SetReadDeadline(time.Now().Add(conn.deadline))
+	if err == nil {
+		if conn.inactivityTimeout > 0 && conn.activeOnWrite {
+			conn.Conn.SetReadDeadline(time.Now().Add(conn.inactivityTimeout))
+		}
+		if conn.lruEntry != nil {
+			conn.lruEntry.Touch()
+		}
 	}
 	return n, err
 }

+ 17 - 10
psiphon/server/config.go

@@ -254,11 +254,17 @@ type TrafficRules struct {
 	// TunnelProtocolPorts.
 	ProtocolRateLimits map[string]RateLimits
 
-	// IdlePortForwardTimeoutMilliseconds is the timeout period
+	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
 	// after which idle (no bytes flowing in either direction)
-	// SSH client port forwards are preemptively closed.
+	// client TCP port forwards are preemptively closed.
 	// The default, 0, is no idle timeout.
-	IdlePortForwardTimeoutMilliseconds int
+	IdleTCPPortForwardTimeoutMilliseconds int
+
+	// IdleUDPPortForwardTimeoutMilliseconds is the timeout period
+	// after which idle (no bytes flowing in either direction)
+	// client UDP port forwards are preemptively closed.
+	// The default, 0, is no idle timeout.
+	IdleUDPPortForwardTimeoutMilliseconds int
 
 	// MaxTCPPortForwardCount is the maximum number of TCP port
 	// forwards each client may have open concurrently.
@@ -567,13 +573,14 @@ func GenerateConfig(serverIPaddress string) ([]byte, []byte, error) {
 				UpstreamUnlimitedBytes:   0,
 				UpstreamBytesPerSecond:   0,
 			},
-			IdlePortForwardTimeoutMilliseconds: 30000,
-			MaxTCPPortForwardCount:             1024,
-			MaxUDPPortForwardCount:             32,
-			AllowTCPPorts:                      nil,
-			AllowUDPPorts:                      nil,
-			DenyTCPPorts:                       nil,
-			DenyUDPPorts:                       nil,
+			IdleTCPPortForwardTimeoutMilliseconds: 30000,
+			IdleUDPPortForwardTimeoutMilliseconds: 30000,
+			MaxTCPPortForwardCount:                1024,
+			MaxUDPPortForwardCount:                32,
+			AllowTCPPorts:                         nil,
+			AllowUDPPorts:                         nil,
+			DenyTCPPorts:                          nil,
+			DenyUDPPorts:                          nil,
 		},
 		LoadMonitorPeriodSeconds: 300,
 	}

+ 147 - 105
psiphon/server/tunnelServer.go

@@ -355,13 +355,17 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 		geoIPData,
 		sshServer.config.GetTrafficRules(geoIPData.Country))
 
-	// Wrap the base client connection with an IdleTimeoutConn which will terminate
-	// the connection if no data is received before the deadline. This timeout is
-	// in effect for the entire duration of the SSH connection. Clients must actively
-	// use the connection or send SSH keep alive requests to keep the connection
-	// active.
+	// Wrap the base client connection with an ActivityMonitoredConn which will
+	// terminate the connection if no data is received before the deadline. This
+	// timeout is in effect for the entire duration of the SSH connection. Clients
+	// must actively use the connection or send SSH keep alive requests to keep
+	// the connection active.
 
-	clientConn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
+	clientConn = psiphon.NewActivityMonitoredConn(
+		clientConn,
+		SSH_CONNECTION_READ_DEADLINE,
+		false,
+		nil)
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
@@ -478,6 +482,7 @@ type sshClient struct {
 	tcpTrafficState         *trafficState
 	udpTrafficState         *trafficState
 	channelHandlerWaitGroup *sync.WaitGroup
+	tcpPortForwardLRU       *psiphon.LRUConns
 	stopBroadcast           chan struct{}
 }
 
@@ -500,10 +505,94 @@ func newSshClient(
 		tcpTrafficState:         &trafficState{},
 		udpTrafficState:         &trafficState{},
 		channelHandlerWaitGroup: new(sync.WaitGroup),
+		tcpPortForwardLRU:       psiphon.NewLRUConns(),
 		stopBroadcast:           make(chan struct{}),
 	}
 }
 
+// passwordCallback authenticates an SSH client. The SSH password
+// field carries a JSON payload containing a Psiphon session ID and
+// the actual SSH password. User name and password are compared in
+// constant time to avoid leaking timing information. On success,
+// the session ID is recorded and, when Redis is configured, passed
+// to legacy psi_web along with GeoIP data.
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
+	var sshPasswordPayload struct {
+		SessionId   string `json:"SessionId"`
+		SshPassword string `json:"SshPassword"`
+	}
+	err := json.Unmarshal(password, &sshPasswordPayload)
+	if err != nil {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+	}
+
+	userOk := (subtle.ConstantTimeCompare(
+		[]byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
+
+	passwordOk := (subtle.ConstantTimeCompare(
+		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
+
+	if !userOk || !passwordOk {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+	}
+
+	psiphonSessionID := sshPasswordPayload.SessionId
+
+	sshClient.Lock()
+	sshClient.psiphonSessionID = psiphonSessionID
+	geoIPData := sshClient.geoIPData
+	sshClient.Unlock()
+
+	if sshClient.sshServer.config.UseRedis() {
+		err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
+		if err != nil {
+			log.WithContextFields(LogFields{
+				"psiphonSessionID": psiphonSessionID,
+				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
+			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
+		}
+	}
+
+	return nil, nil
+}
+
+// authLogCallback logs the outcome of an SSH authentication attempt.
+// When fail2ban integration is enabled, failed attempts are reported
+// by client IP address so repeated offenders can be banned.
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
+	if err != nil {
+		if sshClient.sshServer.config.UseFail2Ban() {
+			clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
+			if clientIPAddress != "" {
+				LogFail2Ban(clientIPAddress)
+			}
+		}
+		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
+	} else {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
+	}
+}
+
+// stop tears down the client tunnel: it closes the SSH connection,
+// signals all channel handlers to stop and waits for them to finish,
+// then logs final per-tunnel traffic metrics (TCP and UDP byte counts
+// and port forward counts).
+func (sshClient *sshClient) stop() {
+
+	sshClient.sshConn.Close()
+	sshClient.sshConn.Wait()
+
+	// Broadcast shutdown to channel handlers and wait for them to exit
+	// before reading final traffic state.
+	close(sshClient.stopBroadcast)
+	sshClient.channelHandlerWaitGroup.Wait()
+
+	sshClient.Lock()
+	log.WithContextFields(
+		LogFields{
+			"startTime":                         sshClient.startTime,
+			"duration":                          time.Now().Sub(sshClient.startTime),
+			"psiphonSessionID":                  sshClient.psiphonSessionID,
+			"country":                           sshClient.geoIPData.Country,
+			"city":                              sshClient.geoIPData.City,
+			"ISP":                               sshClient.geoIPData.ISP,
+			"bytesUpTCP":                        sshClient.tcpTrafficState.bytesUp,
+			"bytesDownTCP":                      sshClient.tcpTrafficState.bytesDown,
+			"peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
+			"totalPortForwardCountTCP":          sshClient.tcpTrafficState.totalPortForwardCount,
+			"bytesUpUDP":                        sshClient.udpTrafficState.bytesUp,
+			"bytesDownUDP":                      sshClient.udpTrafficState.bytesDown,
+			"peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
+			"totalPortForwardCountUDP":          sshClient.udpTrafficState.totalPortForwardCount,
+		}).Info("tunnel closed")
+	sshClient.Unlock()
+}
+
 func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
 	for newChannel := range channels {
 
@@ -652,11 +741,43 @@ func (sshClient *sshClient) handleTCPChannel(
 		sshClient.tcpTrafficState,
 		sshClient.trafficRules.MaxTCPPortForwardCount) {
 
-		sshClient.rejectNewChannel(
-			newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
-		return
+		// Close the oldest TCP port forward. CloseOldest() closes
+		// the conn and the port forward's goroutine will complete
+		// the cleanup asynchronously.
+		//
+		// Some known limitations:
+		//
+		// - CloseOldest() closes the upstream socket but does not
+		//   clean up all resources associated with the port forward. These
+		//   include the goroutine(s) relaying traffic as well as the SSH
+		//   channel. Closing the socket will interrupt the goroutines which
+		//   will then complete the cleanup. But, since the full cleanup is
+		//   asynchronous, there exists a possibility that a client can consume
+		//   more than max port forward resources -- just not upstream sockets.
+		//
+		// - An LRU list entry for this port forward is not added until
+		//   after the dial completes, but the port forward is counted
+		//   towards max limits. This means many dials in progress will
+		//   put established connections in jeopardy.
+		//
+		// - We're closing the oldest open connection _before_ successfully
+		//   dialing the new port forward. This means we are potentially
+		//   discarding a good connection to make way for a failed connection.
+		//   We cannot simply dial first and still maintain a limit on
+		//   resources used, so to address this we'd need to add some
+		//   accounting for connections still establishing.
+
+		sshClient.tcpPortForwardLRU.CloseOldest()
+
+		log.WithContextFields(
+			LogFields{
+				"maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
+			}).Debug("closed LRU TCP port forward")
 	}
 
+	// Dial the target remote address. This is done in a goroutine to
+	// ensure the shutdown signal is handled immediately.
+
 	remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
 
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
@@ -689,9 +810,25 @@ func (sshClient *sshClient) handleTCPChannel(
 		return
 	}
 
+	// The upstream TCP port forward connection has been established. Schedule
+	// some cleanup and notify the SSH client that the channel is accepted.
+
 	fwdConn := result.conn
 	defer fwdConn.Close()
 
+	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
+	defer lruEntry.Remove()
+
+	// ActivityMonitoredConn monitors the TCP port forward I/O and updates
+	// its LRU status. ActivityMonitoredConn also times out read on the port
+	// forward if both reads and writes have been idle for the specified
+	// duration.
+	fwdConn = psiphon.NewActivityMonitoredConn(
+		fwdConn,
+		time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
+		true,
+		lruEntry)
+
 	fwdChannel, requests, err := newChannel.Accept()
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
@@ -702,21 +839,9 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
 
-	// When idle port forward traffic rules are in place, wrap fwdConn
-	// in an IdleTimeoutConn configured to reset idle on writes as well
-	// as read. This ensures the port forward idle timeout only happens
-	// when both upstream and downstream directions are are idle.
-
-	if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
-		fwdConn = psiphon.NewIdleTimeoutConn(
-			fwdConn,
-			time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
-			true)
-	}
-
 	// Relay channel to forwarded connection.
-	// TODO: relay errors to fwdChannel.Stderr()?
 
+	// TODO: relay errors to fwdChannel.Stderr()?
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup.Add(1)
 	go func() {
@@ -757,86 +882,3 @@ func (sshClient *sshClient) handleTCPChannel(
 			"bytesUp":    atomic.LoadInt64(&bytesUp),
 			"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
 }
-
-func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
-	var sshPasswordPayload struct {
-		SessionId   string `json:"SessionId"`
-		SshPassword string `json:"SshPassword"`
-	}
-	err := json.Unmarshal(password, &sshPasswordPayload)
-	if err != nil {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
-	}
-
-	userOk := (subtle.ConstantTimeCompare(
-		[]byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
-
-	passwordOk := (subtle.ConstantTimeCompare(
-		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
-
-	if !userOk || !passwordOk {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
-	}
-
-	psiphonSessionID := sshPasswordPayload.SessionId
-
-	sshClient.Lock()
-	sshClient.psiphonSessionID = psiphonSessionID
-	geoIPData := sshClient.geoIPData
-	sshClient.Unlock()
-
-	if sshClient.sshServer.config.UseRedis() {
-		err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
-		if err != nil {
-			log.WithContextFields(LogFields{
-				"psiphonSessionID": psiphonSessionID,
-				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
-			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
-		}
-	}
-
-	return nil, nil
-}
-
-func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
-	if err != nil {
-		if sshClient.sshServer.config.UseFail2Ban() {
-			clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
-			if clientIPAddress != "" {
-				LogFail2Ban(clientIPAddress)
-			}
-		}
-		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
-	} else {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
-	}
-}
-
-func (sshClient *sshClient) stop() {
-
-	sshClient.sshConn.Close()
-	sshClient.sshConn.Wait()
-
-	close(sshClient.stopBroadcast)
-	sshClient.channelHandlerWaitGroup.Wait()
-
-	sshClient.Lock()
-	log.WithContextFields(
-		LogFields{
-			"startTime":                         sshClient.startTime,
-			"duration":                          time.Now().Sub(sshClient.startTime),
-			"psiphonSessionID":                  sshClient.psiphonSessionID,
-			"country":                           sshClient.geoIPData.Country,
-			"city":                              sshClient.geoIPData.City,
-			"ISP":                               sshClient.geoIPData.ISP,
-			"bytesUpTCP":                        sshClient.tcpTrafficState.bytesUp,
-			"bytesDownTCP":                      sshClient.tcpTrafficState.bytesDown,
-			"peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
-			"totalPortForwardCountTCP":          sshClient.tcpTrafficState.totalPortForwardCount,
-			"bytesUpUDP":                        sshClient.udpTrafficState.bytesUp,
-			"bytesDownUDP":                      sshClient.udpTrafficState.bytesDown,
-			"peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
-			"totalPortForwardCountUDP":          sshClient.udpTrafficState.totalPortForwardCount,
-		}).Info("tunnel closed")
-	sshClient.Unlock()
-}

+ 34 - 41
psiphon/server/udp.go

@@ -24,7 +24,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"math"
 	"net"
 	"strconv"
 	"sync"
@@ -75,6 +74,7 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
 		portForwards:   make(map[uint16]*udpPortForward),
+		portForwardLRU: psiphon.NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
 	}
 	multiplexer.run()
@@ -86,6 +86,7 @@ type udpPortForwardMultiplexer struct {
 	portForwardsMutex sync.Mutex
 	portForwards      map[uint16]*udpPortForward
 	relayWaitGroup    *sync.WaitGroup
+	portForwardLRU    *psiphon.LRUConns
 }
 
 func (mux *udpPortForwardMultiplexer) run() {
@@ -158,10 +159,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 				mux.sshClient.tcpTrafficState,
 				mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
 
-				// When the UDP port forward limit is exceeded, we
-				// select the least recently used (read from or written
-				// to) port forward and discard it.
-				mux.closeLeastRecentlyUsedPortForward()
+				// Close the oldest UDP port forward. CloseOldest() closes
+				// the conn and the port forward's goroutine will complete
+				// the cleanup asynchronously.
+				//
+				// See LRU comment in handleTCPChannel() for known
+				// limitations regarding CloseOldest().
+				mux.portForwardLRU.CloseOldest()
+
+				log.WithContextFields(
+					LogFields{
+						"maxCount": mux.sshClient.trafficRules.MaxUDPPortForwardCount,
+					}).Debug("closed LRU UDP port forward")
 			}
 
 			dialIP := net.IP(message.remoteIP)
@@ -186,20 +195,17 @@ func (mux *udpPortForwardMultiplexer) run() {
 				continue
 			}
 
-			// When idle port forward traffic rules are in place, wrap updConn
-			// in an IdleTimeoutConn configured to reset idle on writes as well
-			// as reads. This ensures the port forward idle timeout only happens
-			// when both upstream and downstream directions are are idle.
-
-			var conn net.Conn
-			if mux.sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
-				conn = psiphon.NewIdleTimeoutConn(
-					udpConn,
-					time.Duration(mux.sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
-					true)
-			} else {
-				conn = udpConn
-			}
+			lruEntry := mux.portForwardLRU.Add(udpConn)
+
+			// ActivityMonitoredConn monitors the UDP port forward I/O and updates
+			// its LRU status. ActivityMonitoredConn also times out read on the port
+			// forward if both reads and writes have been idle for the specified
+			// duration.
+			conn := psiphon.NewActivityMonitoredConn(
+				udpConn,
+				time.Duration(mux.sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds)*time.Millisecond,
+				true,
+				lruEntry)
 
 			portForward = &udpPortForward{
 				connID:       message.connID,
@@ -207,7 +213,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 				remoteIP:     message.remoteIP,
 				remotePort:   message.remotePort,
 				conn:         conn,
-				lastActivity: time.Now().UnixNano(),
+				lruEntry:     lruEntry,
 				bytesUp:      0,
 				bytesDown:    0,
 				mux:          mux,
@@ -229,7 +235,9 @@ func (mux *udpPortForwardMultiplexer) run() {
 			// The port forward's goroutine will complete cleanup
 			portForward.conn.Close()
 		}
-		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+
+		portForward.lruEntry.Touch()
+
 		atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
 	}
 
@@ -245,24 +253,6 @@ func (mux *udpPortForwardMultiplexer) run() {
 	mux.relayWaitGroup.Wait()
 }
 
-func (mux *udpPortForwardMultiplexer) closeLeastRecentlyUsedPortForward() {
-	// TODO: use "container/list" and avoid a linear scan? However,
-	// move-to-front on each read/write would require mutex lock?
-	mux.portForwardsMutex.Lock()
-	oldestActivity := int64(math.MaxInt64)
-	var oldestPortForward *udpPortForward
-	for _, nextPortForward := range mux.portForwards {
-		if nextPortForward.lastActivity < oldestActivity {
-			oldestPortForward = nextPortForward
-		}
-	}
-	if oldestPortForward != nil {
-		// The port forward's goroutine will complete cleanup
-		oldestPortForward.conn.Close()
-	}
-	mux.portForwardsMutex.Unlock()
-}
-
 func (mux *udpPortForwardMultiplexer) transparentDNSAddress(
 	dialIP net.IP, dialPort int) (net.IP, int) {
 
@@ -288,7 +278,7 @@ type udpPortForward struct {
 	remoteIP     []byte
 	remotePort   uint16
 	conn         net.Conn
-	lastActivity int64
+	lruEntry     *psiphon.LRUConnsEntry
 	bytesUp      int64
 	bytesDown    int64
 	mux          *udpPortForwardMultiplexer
@@ -340,12 +330,15 @@ func (portForward *udpPortForward) relayDownstream() {
 			break
 		}
 
-		atomic.StoreInt64(&portForward.lastActivity, time.Now().UnixNano())
+		portForward.lruEntry.Touch()
+
 		atomic.AddInt64(&portForward.bytesDown, int64(packetSize))
 	}
 
 	portForward.mux.removePortForward(portForward.connID)
 
+	portForward.lruEntry.Remove()
+
 	portForward.conn.Close()
 
 	bytesUp := atomic.LoadInt64(&portForward.bytesUp)