Browse Source

Throttling changes in progress
- Support on-the-fly change to sshClient's traffic rules.
- Invoke when traffic rules config hot reloaded.
- Change sshClient lookup ID to be sessionID;
this is a precursor to enabling handshake to
lookup sshClient, store API params, and re-select
traffic rules.

Rod Hynes 9 years ago
parent
commit
7676581338
3 changed files with 196 additions and 89 deletions
  1. 3 0
      psiphon/server/services.go
  2. 185 74
      psiphon/server/tunnelServer.go
  3. 8 15
      psiphon/server/udp.go

+ 3 - 0
psiphon/server/services.go

@@ -130,6 +130,9 @@ loop:
 		select {
 		case <-reloadSupportServicesSignal:
 			supportServices.Reload()
+			// Reselect traffic rules for established clients to reflect reloaded config
+			// TODO: only update when traffic rules config has changed
+			tunnelServer.SelectAllClientTrafficRules()
 		case <-logServerLoadSignal:
 			logServerLoad(tunnelServer)
 		case <-systemStopSignal:

+ 185 - 74
psiphon/server/tunnelServer.go

@@ -81,14 +81,6 @@ func NewTunnelServer(
 	}, nil
 }
 
-// GetLoadStats returns load stats for the tunnel server. The stats are
-// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
-// include current connected client count, total number of current port
-// forwards.
-func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
-	return server.sshServer.getLoadStats()
-}
-
 // Run runs the tunnel server; this function blocks while running a selection of
 // listeners that handle connection using various obfuscation protocols.
 //
@@ -192,17 +184,34 @@ func (server *TunnelServer) Run() error {
 	return err
 }
 
-type sshClientID uint64
+// GetLoadStats returns load stats for the tunnel server. The stats are
+// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
+// include current connected client count, total number of current port
+// forwards.
+func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
+	return server.sshServer.getLoadStats()
+}
+
+// SelectAllClientTrafficRules resets all established client traffic rules
+// to use the latest server config and client state.
+func (server *TunnelServer) SelectAllClientTrafficRules() {
+	server.sshServer.selectAllClientTrafficRules()
+}
+
+// SelectClientTrafficRules resets a specified client's traffic rules
+// to use the latest server config and client state.
+func (server *TunnelServer) SelectClientTrafficRules(sessionID string) {
+	server.sshServer.selectClientTrafficRules(sessionID)
+}
 
 type sshServer struct {
 	support              *SupportServices
 	shutdownBroadcast    <-chan struct{}
 	sshHostKey           ssh.Signer
-	nextClientID         sshClientID
 	clientsMutex         sync.Mutex
 	stoppingClients      bool
 	acceptedClientCounts map[string]int64
-	clients              map[sshClientID]*sshClient
+	clients              map[string]*sshClient
 }
 
 func newSSHServer(
@@ -224,9 +233,8 @@ func newSSHServer(
 		support:              support,
 		shutdownBroadcast:    shutdownBroadcast,
 		sshHostKey:           signer,
-		nextClientID:         1,
 		acceptedClientCounts: make(map[string]int64),
-		clients:              make(map[sshClientID]*sshClient),
+		clients:              make(map[string]*sshClient),
 	}, nil
 }
 
@@ -321,28 +329,38 @@ func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol string) {
 // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
 // for tracking the number of fully established clients and for maintaining a list of running
 // clients (for stopping at shutdown time).
-func (sshServer *sshServer) registerEstablishedClient(client *sshClient) (sshClientID, bool) {
+func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
 
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 
 	if sshServer.stoppingClients {
-		return 0, false
+		return false
 	}
 
-	clientID := sshServer.nextClientID
-	sshServer.nextClientID += 1
+	// In the case of a duplicate client sessionID, the previous client is closed.
+	// - Well-behaved clients generate pick a random sessionID that should be
+	//   unique (won't accidentally conflict) and hard to guess (can't be targetted
+	//   by a malicious client).
+	// - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
+	//   and resestablished. In this case, when the same server is selected, this logic
+	//   will be hit; closing the old, dangling client is desirable.
+	// - Multi-tunnel clients should not normally use one server for multiple tunnels.
+	existingClient := sshServer.clients[client.sessionID]
+	if existingClient != nil {
+		existingClient.stop()
+	}
 
-	sshServer.clients[clientID] = client
+	sshServer.clients[client.sessionID] = client
 
-	return clientID, true
+	return true
 }
 
-func (sshServer *sshServer) unregisterEstablishedClient(clientID sshClientID) {
+func (sshServer *sshServer) unregisterEstablishedClient(sessionID string) {
 
 	sshServer.clientsMutex.Lock()
-	client := sshServer.clients[clientID]
-	delete(sshServer.clients, clientID)
+	client := sshServer.clients[sessionID]
+	delete(sshServer.clients, sessionID)
 	sshServer.clientsMutex.Unlock()
 
 	if client != nil {
@@ -400,12 +418,37 @@ func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
 	return loadStats
 }
 
+func (sshServer *sshServer) selectAllClientTrafficRules() {
+
+	sshServer.clientsMutex.Lock()
+	clients := make(map[string]*sshClient)
+	for sessionID, client := range sshServer.clients {
+		clients[sessionID] = client
+	}
+	sshServer.clientsMutex.Unlock()
+
+	for _, client := range clients {
+		client.selectTrafficRules()
+	}
+}
+
+func (sshServer *sshServer) selectClientTrafficRules(sessionID string) {
+
+	sshServer.clientsMutex.Lock()
+	client := sshServer.clients[sessionID]
+	sshServer.clientsMutex.Unlock()
+
+	if client != nil {
+		client.selectTrafficRules()
+	}
+}
+
 func (sshServer *sshServer) stopClients() {
 
 	sshServer.clientsMutex.Lock()
 	sshServer.stoppingClients = true
 	clients := sshServer.clients
-	sshServer.clients = make(map[sshClientID]*sshClient)
+	sshServer.clients = make(map[string]*sshClient)
 	sshServer.clientsMutex.Unlock()
 
 	for _, client := range clients {
@@ -421,13 +464,9 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 		common.IPAddressFromAddr(clientConn.RemoteAddr()))
 
-	// TODO: apply reload of TrafficRulesSet to existing clients
+	sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
 
-	sshClient := newSshClient(
-		sshServer,
-		tunnelProtocol,
-		geoIPData,
-		sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country))
+	sshClient.selectTrafficRules()
 
 	// Wrap the base client connection with an ActivityMonitoredConn which will
 	// terminate the connection if no data is received before the deadline. This
@@ -450,8 +489,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
-	clientConn = common.NewThrottledConn(
-		clientConn, sshClient.trafficRules.GetRateLimits(tunnelProtocol))
+	throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
+	clientConn = throttledConn
 
 	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
 	// respect shutdownBroadcast and implement a specific handshake timeout.
@@ -529,15 +568,15 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	sshClient.Lock()
 	sshClient.sshConn = result.sshConn
 	sshClient.activityConn = activityConn
+	sshClient.throttledConn = throttledConn
 	sshClient.Unlock()
 
-	clientID, ok := sshServer.registerEstablishedClient(sshClient)
-	if !ok {
+	if !sshServer.registerEstablishedClient(sshClient) {
 		clientConn.Close()
 		log.WithContext().Warning("register failed")
 		return
 	}
-	defer sshServer.unregisterEstablishedClient(clientID)
+	defer sshServer.unregisterEstablishedClient(sshClient.sessionID)
 
 	sshClient.runClient(result.channels, result.requests)
 
@@ -551,12 +590,13 @@ type sshClient struct {
 	tunnelProtocol          string
 	sshConn                 ssh.Conn
 	activityConn            *common.ActivityMonitoredConn
+	throttledConn           *common.ThrottledConn
 	geoIPData               GeoIPData
-	psiphonSessionID        string
+	sessionID               string
 	udpChannel              ssh.Channel
 	trafficRules            TrafficRules
-	tcpTrafficState         *trafficState
-	udpTrafficState         *trafficState
+	tcpTrafficState         trafficState
+	udpTrafficState         trafficState
 	channelHandlerWaitGroup *sync.WaitGroup
 	tcpPortForwardLRU       *common.LRUConns
 	stopBroadcast           chan struct{}
@@ -574,14 +614,11 @@ type trafficState struct {
 }
 
 func newSshClient(
-	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
+	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
 	return &sshClient{
 		sshServer:               sshServer,
 		tunnelProtocol:          tunnelProtocol,
 		geoIPData:               geoIPData,
-		trafficRules:            trafficRules,
-		tcpTrafficState:         &trafficState{},
-		udpTrafficState:         &trafficState{},
 		channelHandlerWaitGroup: new(sync.WaitGroup),
 		tcpPortForwardLRU:       common.NewLRUConns(),
 		stopBroadcast:           make(chan struct{}),
@@ -590,6 +627,9 @@ func newSshClient(
 
 func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
 
+	expectedSessionIDLength := 2 * common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
+	expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
+
 	var sshPasswordPayload struct {
 		SessionId   string `json:"SessionId"`
 		SshPassword string `json:"SshPassword"`
@@ -601,15 +641,16 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		// send the hex encoded session ID prepended to the SSH password.
 		// Note: there's an even older case where clients don't send any session ID,
 		// but that's no longer supported.
-		if len(password) == 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
-			sshPasswordPayload.SessionId = string(password[0 : 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
-			sshPasswordPayload.SshPassword = string(password[2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
+		if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
+			sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
+			sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:len(password)])
 		} else {
 			return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
 		}
 	}
 
-	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
+	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
+		len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
 		return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
 	}
 
@@ -623,17 +664,17 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
 	}
 
-	psiphonSessionID := sshPasswordPayload.SessionId
+	sessionID := sshPasswordPayload.SessionId
 
 	sshClient.Lock()
-	sshClient.psiphonSessionID = psiphonSessionID
+	sshClient.sessionID = sessionID
 	geoIPData := sshClient.geoIPData
 	sshClient.Unlock()
 
 	// Store the GeoIP data associated with the session ID. This makes the GeoIP data
 	// available to the web server for web transport Psiphon API requests.
 	sshClient.sshServer.support.GeoIPService.SetSessionCache(
-		psiphonSessionID, geoIPData)
+		sessionID, geoIPData)
 
 	return nil, nil
 }
@@ -697,7 +738,7 @@ func (sshClient *sshClient) stop() {
 		LogFields{
 			"startTime":                         sshClient.activityConn.GetStartTime(),
 			"duration":                          sshClient.activityConn.GetActiveDuration(),
-			"psiphonSessionID":                  sshClient.psiphonSessionID,
+			"sessionID":                         sshClient.sessionID,
 			"country":                           sshClient.geoIPData.Country,
 			"city":                              sshClient.geoIPData.City,
 			"ISP":                               sshClient.geoIPData.ISP,
@@ -812,13 +853,64 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
 	}
 }
 
+// selectTrafficRules resets the client's traffic rules based on the latest server config
+// and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
+// trafficRules must only be accessed within the sshClient mutex.
+func (sshClient *sshClient) selectTrafficRules() {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(sshClient.geoIPData.Country)
+}
+
+func (sshClient *sshClient) rateLimits() common.RateLimits {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return sshClient.trafficRules.GetRateLimits(sshClient.tunnelProtocol)
+}
+
+func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
+}
+
+func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	return time.Duration(sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
+}
+
+const (
+	portForwardTypeTCP = iota
+	portForwardTypeUDP
+)
+
 func (sshClient *sshClient) isPortForwardPermitted(
-	host string, port int, allowPorts []int, denyPorts []int) bool {
+	portForwardType int, host string, port int) bool {
+
+	// Mutex required for accessing sshClient.trafficRules
+	sshClient.Lock()
+	defer sshClient.Unlock()
 
 	if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
 		return false
 	}
 
+	var allowPorts, denyPorts []int
+	if portForwardType == portForwardTypeTCP {
+		allowPorts = sshClient.trafficRules.AllowTCPPorts
+		denyPorts = sshClient.trafficRules.AllowTCPPorts
+	} else {
+		allowPorts = sshClient.trafficRules.AllowUDPPorts
+		denyPorts = sshClient.trafficRules.AllowUDPPorts
+
+	}
+
 	// TODO: faster lookup?
 	if len(allowPorts) > 0 {
 		for _, allowPort := range allowPorts {
@@ -841,37 +933,63 @@ func (sshClient *sshClient) isPortForwardPermitted(
 }
 
 func (sshClient *sshClient) isPortForwardLimitExceeded(
-	state *trafficState, maxPortForwardCount int) bool {
+	portForwardType int) (int, bool) {
+
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var maxPortForwardCount int
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		maxPortForwardCount = sshClient.trafficRules.MaxTCPPortForwardCount
+		state = &sshClient.tcpTrafficState
+	} else {
+		maxPortForwardCount = sshClient.trafficRules.MaxUDPPortForwardCount
+		state = &sshClient.udpTrafficState
+	}
 
-	limitExceeded := false
-	if maxPortForwardCount > 0 {
-		sshClient.Lock()
-		limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
-		sshClient.Unlock()
+	if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
+		return maxPortForwardCount, true
 	}
-	return limitExceeded
+	return maxPortForwardCount, false
 }
 
 func (sshClient *sshClient) openedPortForward(
-	state *trafficState) {
+	portForwardType int) {
 
 	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
 	state.concurrentPortForwardCount += 1
 	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
 		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
 	}
 	state.totalPortForwardCount += 1
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) closedPortForward(
-	state *trafficState, bytesUp, bytesDown int64) {
+	portForwardType int, bytesUp, bytesDown int64) {
 
 	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	var state *trafficState
+	if portForwardType == portForwardTypeTCP {
+		state = &sshClient.tcpTrafficState
+	} else {
+		state = &sshClient.udpTrafficState
+	}
+
 	state.concurrentPortForwardCount -= 1
 	state.bytesUp += bytesUp
 	state.bytesDown += bytesDown
-	sshClient.Unlock()
 }
 
 func (sshClient *sshClient) handleTCPChannel(
@@ -880,10 +998,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	newChannel ssh.NewChannel) {
 
 	if !sshClient.isPortForwardPermitted(
-		hostToConnect,
-		portToConnect,
-		sshClient.trafficRules.AllowTCPPorts,
-		sshClient.trafficRules.DenyTCPPorts) {
+		portForwardTypeTCP, hostToConnect, portToConnect) {
 
 		sshClient.rejectNewChannel(
 			newChannel, ssh.Prohibited, "port forward not permitted")
@@ -904,12 +1019,10 @@ func (sshClient *sshClient) handleTCPChannel(
 	}
 
 	var bytesUp, bytesDown int64
-	sshClient.openedPortForward(sshClient.tcpTrafficState)
+	sshClient.openedPortForward(portForwardTypeTCP)
 	defer func() {
 		sshClient.closedPortForward(
-			sshClient.tcpTrafficState,
-			atomic.LoadInt64(&bytesUp),
-			atomic.LoadInt64(&bytesDown))
+			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
 	}()
 
 	// TOCTOU note: important to increment the port forward count (via
@@ -918,9 +1031,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	// by initiating many port forwards concurrently.
 	// TODO: close LRU connection (after successful Dial) instead of
 	// rejecting new connection?
-	if sshClient.isPortForwardLimitExceeded(
-		sshClient.tcpTrafficState,
-		sshClient.trafficRules.MaxTCPPortForwardCount) {
+	if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
 
 		// Close the oldest TCP port forward. CloseOldest() closes
 		// the conn and the port forward's goroutine will complete
@@ -952,7 +1063,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 		log.WithContextFields(
 			LogFields{
-				"maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
+				"maxCount": maxCount,
 			}).Debug("closed LRU TCP port forward")
 	}
 
@@ -1015,7 +1126,7 @@ func (sshClient *sshClient) handleTCPChannel(
 
 	fwdConn, err = common.NewActivityMonitoredConn(
 		fwdConn,
-		time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
+		sshClient.idleTCPPortForwardTimeout(),
 		true,
 		lruEntry)
 	if result.err != nil {

+ 8 - 15
psiphon/server/udp.go

@@ -28,7 +28,6 @@ import (
 	"runtime/debug"
 	"sync"
 	"sync/atomic"
-	"time"
 
 	"github.com/Psiphon-Inc/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -162,23 +161,18 @@ func (mux *udpPortForwardMultiplexer) run() {
 			}
 
 			if !mux.sshClient.isPortForwardPermitted(
-				dialIP.String(),
-				int(message.remotePort),
-				mux.sshClient.trafficRules.AllowUDPPorts,
-				mux.sshClient.trafficRules.DenyUDPPorts) {
+				portForwardTypeUDP, dialIP.String(), int(message.remotePort)) {
 				// The udpgw protocol has no error response, so
 				// we just discard the message and read another.
 				continue
 			}
 
-			mux.sshClient.openedPortForward(mux.sshClient.udpTrafficState)
+			mux.sshClient.openedPortForward(portForwardTypeUDP)
 			// Note: can't defer sshClient.closedPortForward() here
 
 			// TOCTOU note: important to increment the port forward count (via
 			// openPortForward) _before_ checking isPortForwardLimitExceeded
-			if mux.sshClient.isPortForwardLimitExceeded(
-				mux.sshClient.tcpTrafficState,
-				mux.sshClient.trafficRules.MaxUDPPortForwardCount) {
+			if maxCount, exceeded := mux.sshClient.isPortForwardLimitExceeded(portForwardTypeUDP); exceeded {
 
 				// Close the oldest UDP port forward. CloseOldest() closes
 				// the conn and the port forward's goroutine will complete
@@ -190,7 +184,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 				log.WithContextFields(
 					LogFields{
-						"maxCount": mux.sshClient.trafficRules.MaxUDPPortForwardCount,
+						"maxCount": maxCount,
 					}).Debug("closed LRU UDP port forward")
 			}
 
@@ -203,7 +197,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 			udpConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
 			if err != nil {
-				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
+				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Warning("DialUDP failed")
 				continue
 			}
@@ -217,12 +211,12 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 			conn, err := common.NewActivityMonitoredConn(
 				udpConn,
-				time.Duration(mux.sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds)*time.Millisecond,
+				mux.sshClient.idleUDPPortForwardTimeout(),
 				true,
 				lruEntry)
 			if err != nil {
 				lruEntry.Remove()
-				mux.sshClient.closedPortForward(mux.sshClient.udpTrafficState, 0, 0)
+				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)
 				log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
 				continue
 			}
@@ -354,8 +348,7 @@ func (portForward *udpPortForward) relayDownstream() {
 
 	bytesUp := atomic.LoadInt64(&portForward.bytesUp)
 	bytesDown := atomic.LoadInt64(&portForward.bytesDown)
-	portForward.mux.sshClient.closedPortForward(
-		portForward.mux.sshClient.udpTrafficState, bytesUp, bytesDown)
+	portForward.mux.sshClient.closedPortForward(portForwardTypeUDP, bytesUp, bytesDown)
 
 	log.WithContextFields(
 		LogFields{