|
|
@@ -81,14 +81,6 @@ func NewTunnelServer(
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
-// GetLoadStats returns load stats for the tunnel server. The stats are
|
|
|
-// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
|
|
|
-// include current connected client count, total number of current port
|
|
|
-// forwards.
|
|
|
-func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
|
|
|
- return server.sshServer.getLoadStats()
|
|
|
-}
|
|
|
-
|
|
|
// Run runs the tunnel server; this function blocks while running a selection of
|
|
|
// listeners that handle connection using various obfuscation protocols.
|
|
|
//
|
|
|
@@ -192,17 +184,40 @@ func (server *TunnelServer) Run() error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-type sshClientID uint64
|
|
|
+// GetLoadStats returns load stats for the tunnel server. The stats are
|
|
|
+// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
|
|
|
+// include current connected client count, total number of current port
|
|
|
+// forwards.
|
|
|
+func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
|
|
|
+ return server.sshServer.getLoadStats()
|
|
|
+}
|
|
|
+
|
|
|
+// ResetAllClientTrafficRules resets all established client traffic rules
|
|
|
+// to use the latest server config and client state.
|
|
|
+func (server *TunnelServer) ResetAllClientTrafficRules() {
|
|
|
+ server.sshServer.resetAllClientTrafficRules()
|
|
|
+}
|
|
|
+
|
|
|
+// SetClientHandshakeState sets the handshake state -- that it completed and
|
|
|
+// what paramaters were passed -- in sshClient. This state is used for allowing
|
|
|
+// port forwards and for future traffic rule selection. SetClientHandshakeState
|
|
|
+// also triggers an immediate traffic rule re-selection, as the rules selected
|
|
|
+// upon tunnel establishment may no longer apply now that handshake values are
|
|
|
+// set.
|
|
|
+func (server *TunnelServer) SetClientHandshakeState(
|
|
|
+ sessionID string, state handshakeState) error {
|
|
|
+
|
|
|
+ return server.sshServer.setClientHandshakeState(sessionID, state)
|
|
|
+}
|
|
|
|
|
|
type sshServer struct {
|
|
|
support *SupportServices
|
|
|
shutdownBroadcast <-chan struct{}
|
|
|
sshHostKey ssh.Signer
|
|
|
- nextClientID sshClientID
|
|
|
clientsMutex sync.Mutex
|
|
|
stoppingClients bool
|
|
|
acceptedClientCounts map[string]int64
|
|
|
- clients map[sshClientID]*sshClient
|
|
|
+ clients map[string]*sshClient
|
|
|
}
|
|
|
|
|
|
func newSSHServer(
|
|
|
@@ -224,9 +239,8 @@ func newSSHServer(
|
|
|
support: support,
|
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
|
sshHostKey: signer,
|
|
|
- nextClientID: 1,
|
|
|
acceptedClientCounts: make(map[string]int64),
|
|
|
- clients: make(map[sshClientID]*sshClient),
|
|
|
+ clients: make(map[string]*sshClient),
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -321,28 +335,38 @@ func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol string) {
|
|
|
// An established client has completed its SSH handshake and has a ssh.Conn. Registration is
|
|
|
// for tracking the number of fully established clients and for maintaining a list of running
|
|
|
// clients (for stopping at shutdown time).
|
|
|
-func (sshServer *sshServer) registerEstablishedClient(client *sshClient) (sshClientID, bool) {
|
|
|
+func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
defer sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
if sshServer.stoppingClients {
|
|
|
- return 0, false
|
|
|
+ return false
|
|
|
}
|
|
|
|
|
|
- clientID := sshServer.nextClientID
|
|
|
- sshServer.nextClientID += 1
|
|
|
+ // In the case of a duplicate client sessionID, the previous client is closed.
|
|
|
+ // - Well-behaved clients generate pick a random sessionID that should be
|
|
|
+ // unique (won't accidentally conflict) and hard to guess (can't be targetted
|
|
|
+ // by a malicious client).
|
|
|
+ // - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
|
|
|
+ // and resestablished. In this case, when the same server is selected, this logic
|
|
|
+ // will be hit; closing the old, dangling client is desirable.
|
|
|
+ // - Multi-tunnel clients should not normally use one server for multiple tunnels.
|
|
|
+ existingClient := sshServer.clients[client.sessionID]
|
|
|
+ if existingClient != nil {
|
|
|
+ existingClient.stop()
|
|
|
+ }
|
|
|
|
|
|
- sshServer.clients[clientID] = client
|
|
|
+ sshServer.clients[client.sessionID] = client
|
|
|
|
|
|
- return clientID, true
|
|
|
+ return true
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) unregisterEstablishedClient(clientID sshClientID) {
|
|
|
+func (sshServer *sshServer) unregisterEstablishedClient(sessionID string) {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
- client := sshServer.clients[clientID]
|
|
|
- delete(sshServer.clients, clientID)
|
|
|
+ client := sshServer.clients[sessionID]
|
|
|
+ delete(sshServer.clients, sessionID)
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
if client != nil {
|
|
|
@@ -400,12 +424,47 @@ func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
|
|
|
return loadStats
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) resetAllClientTrafficRules() {
|
|
|
+
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ clients := make(map[string]*sshClient)
|
|
|
+ for sessionID, client := range sshServer.clients {
|
|
|
+ clients[sessionID] = client
|
|
|
+ }
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
+ for _, client := range clients {
|
|
|
+ client.setTrafficRules()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (sshServer *sshServer) setClientHandshakeState(
|
|
|
+ sessionID string, state handshakeState) error {
|
|
|
+
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ client := sshServer.clients[sessionID]
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ return common.ContextError(errors.New("unknown session ID"))
|
|
|
+ }
|
|
|
+
|
|
|
+ err := client.setHandshakeState(state)
|
|
|
+ if err != nil {
|
|
|
+ return common.ContextError(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ client.setTrafficRules()
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
sshServer.stoppingClients = true
|
|
|
clients := sshServer.clients
|
|
|
- sshServer.clients = make(map[sshClientID]*sshClient)
|
|
|
+ sshServer.clients = make(map[string]*sshClient)
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
for _, client := range clients {
|
|
|
@@ -421,13 +480,10 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
geoIPData := sshServer.support.GeoIPService.Lookup(
|
|
|
common.IPAddressFromAddr(clientConn.RemoteAddr()))
|
|
|
|
|
|
- // TODO: apply reload of TrafficRulesSet to existing clients
|
|
|
+ sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
|
|
|
|
|
|
- sshClient := newSshClient(
|
|
|
- sshServer,
|
|
|
- tunnelProtocol,
|
|
|
- geoIPData,
|
|
|
- sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country))
|
|
|
+ // Set initial traffic rules, pre-handshake, based on currently known info.
|
|
|
+ sshClient.setTrafficRules()
|
|
|
|
|
|
// Wrap the base client connection with an ActivityMonitoredConn which will
|
|
|
// terminate the connection if no data is received before the deadline. This
|
|
|
@@ -450,8 +506,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
|
|
|
// Further wrap the connection in a rate limiting ThrottledConn.
|
|
|
|
|
|
- clientConn = common.NewThrottledConn(
|
|
|
- clientConn, sshClient.trafficRules.GetRateLimits(tunnelProtocol))
|
|
|
+ throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
|
|
|
+ clientConn = throttledConn
|
|
|
|
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
|
@@ -529,15 +585,15 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
sshClient.Lock()
|
|
|
sshClient.sshConn = result.sshConn
|
|
|
sshClient.activityConn = activityConn
|
|
|
+ sshClient.throttledConn = throttledConn
|
|
|
sshClient.Unlock()
|
|
|
|
|
|
- clientID, ok := sshServer.registerEstablishedClient(sshClient)
|
|
|
- if !ok {
|
|
|
+ if !sshServer.registerEstablishedClient(sshClient) {
|
|
|
clientConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
|
}
|
|
|
- defer sshServer.unregisterEstablishedClient(clientID)
|
|
|
+ defer sshServer.unregisterEstablishedClient(sshClient.sessionID)
|
|
|
|
|
|
sshClient.runClient(result.channels, result.requests)
|
|
|
|
|
|
@@ -551,12 +607,14 @@ type sshClient struct {
|
|
|
tunnelProtocol string
|
|
|
sshConn ssh.Conn
|
|
|
activityConn *common.ActivityMonitoredConn
|
|
|
+ throttledConn *common.ThrottledConn
|
|
|
geoIPData GeoIPData
|
|
|
- psiphonSessionID string
|
|
|
+ sessionID string
|
|
|
+ handshakeState handshakeState
|
|
|
udpChannel ssh.Channel
|
|
|
trafficRules TrafficRules
|
|
|
- tcpTrafficState *trafficState
|
|
|
- udpTrafficState *trafficState
|
|
|
+ tcpTrafficState trafficState
|
|
|
+ udpTrafficState trafficState
|
|
|
channelHandlerWaitGroup *sync.WaitGroup
|
|
|
tcpPortForwardLRU *common.LRUConns
|
|
|
stopBroadcast chan struct{}
|
|
|
@@ -573,15 +631,18 @@ type trafficState struct {
|
|
|
totalPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
+type handshakeState struct {
|
|
|
+ completed bool
|
|
|
+ apiProtocol string
|
|
|
+ apiParams requestJSONObject
|
|
|
+}
|
|
|
+
|
|
|
func newSshClient(
|
|
|
- sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
|
|
|
+ sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
return &sshClient{
|
|
|
sshServer: sshServer,
|
|
|
tunnelProtocol: tunnelProtocol,
|
|
|
geoIPData: geoIPData,
|
|
|
- trafficRules: trafficRules,
|
|
|
- tcpTrafficState: &trafficState{},
|
|
|
- udpTrafficState: &trafficState{},
|
|
|
channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
tcpPortForwardLRU: common.NewLRUConns(),
|
|
|
stopBroadcast: make(chan struct{}),
|
|
|
@@ -590,6 +651,9 @@ func newSshClient(
|
|
|
|
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
|
+ expectedSessionIDLength := 2 * common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
|
|
|
+ expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
|
|
|
+
|
|
|
var sshPasswordPayload struct {
|
|
|
SessionId string `json:"SessionId"`
|
|
|
SshPassword string `json:"SshPassword"`
|
|
|
@@ -601,15 +665,16 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
|
|
|
// send the hex encoded session ID prepended to the SSH password.
|
|
|
// Note: there's an even older case where clients don't send any session ID,
|
|
|
// but that's no longer supported.
|
|
|
- if len(password) == 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
|
|
|
- sshPasswordPayload.SessionId = string(password[0 : 2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
|
|
|
- sshPasswordPayload.SshPassword = string(password[2*common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
|
|
|
+ if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
|
|
|
+ sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
|
|
|
+ sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:len(password)])
|
|
|
} else {
|
|
|
return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
|
|
|
+ if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
|
|
|
+ len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
|
|
|
return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
|
|
|
}
|
|
|
|
|
|
@@ -623,17 +688,18 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
|
|
|
return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
}
|
|
|
|
|
|
- psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
+ sessionID := sshPasswordPayload.SessionId
|
|
|
|
|
|
sshClient.Lock()
|
|
|
- sshClient.psiphonSessionID = psiphonSessionID
|
|
|
+ sshClient.sessionID = sessionID
|
|
|
geoIPData := sshClient.geoIPData
|
|
|
sshClient.Unlock()
|
|
|
|
|
|
// Store the GeoIP data associated with the session ID. This makes the GeoIP data
|
|
|
- // available to the web server for web transport Psiphon API requests.
|
|
|
- sshClient.sshServer.support.GeoIPService.SetSessionCache(
|
|
|
- psiphonSessionID, geoIPData)
|
|
|
+ // available to the web server for web transport Psiphon API requests. To allow for
|
|
|
+ // post-tunnel final status requests, the lifetime of cached GeoIP records exceeds
|
|
|
+ // the lifetime of the sshClient, and that's why this distinct session cache exists.
|
|
|
+ sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
|
|
|
|
|
|
return nil, nil
|
|
|
}
|
|
|
@@ -693,24 +759,30 @@ func (sshClient *sshClient) stop() {
|
|
|
// request with an EOF flag set.)
|
|
|
|
|
|
sshClient.Lock()
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "startTime": sshClient.activityConn.GetStartTime(),
|
|
|
- "duration": sshClient.activityConn.GetActiveDuration(),
|
|
|
- "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
- "country": sshClient.geoIPData.Country,
|
|
|
- "city": sshClient.geoIPData.City,
|
|
|
- "ISP": sshClient.geoIPData.ISP,
|
|
|
- "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
- "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
- "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
|
|
|
- "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
- "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
- "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
|
|
|
- }).Info("tunnel closed")
|
|
|
+
|
|
|
+ logFields := getRequestLogFields(
|
|
|
+ sshClient.sshServer.support,
|
|
|
+ "server_tunnel",
|
|
|
+ sshClient.geoIPData,
|
|
|
+ sshClient.handshakeState.apiParams,
|
|
|
+ baseRequestParams)
|
|
|
+
|
|
|
+ // TODO: match legacy log field naming convention?
|
|
|
+ logFields["HandshakeCompleted"] = sshClient.handshakeState.completed
|
|
|
+ logFields["startTime"] = sshClient.activityConn.GetStartTime()
|
|
|
+ logFields["Duration"] = sshClient.activityConn.GetActiveDuration()
|
|
|
+ logFields["BytesUpTCP"] = sshClient.tcpTrafficState.bytesUp
|
|
|
+ logFields["BytesDownTCP"] = sshClient.tcpTrafficState.bytesDown
|
|
|
+ logFields["PeakConcurrentPortForwardCountTCP"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
|
|
|
+ logFields["TotalPortForwardCountTCP"] = sshClient.tcpTrafficState.totalPortForwardCount
|
|
|
+ logFields["BytesUpUDP"] = sshClient.udpTrafficState.bytesUp
|
|
|
+ logFields["BytesDownUDP"] = sshClient.udpTrafficState.bytesDown
|
|
|
+ logFields["PeakConcurrentPortForwardCountUDP"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
|
|
|
+ logFields["TotalPortForwardCountUDP"] = sshClient.udpTrafficState.totalPortForwardCount
|
|
|
+
|
|
|
sshClient.Unlock()
|
|
|
+
|
|
|
+ log.LogRawFieldsWithTimestamp(logFields)
|
|
|
}
|
|
|
|
|
|
// runClient handles/dispatches new channel and new requests from the client.
|
|
|
@@ -812,13 +884,87 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// setHandshakeState records that a client has completed a handshake API request.
|
|
|
+// Some parameters from the handshake request may be used in future traffic rule
|
|
|
+// selection. Port forwards are disallowed until a handshake is complete. The
|
|
|
+// handshake parameters are included in the session summary log recorded in
|
|
|
+// sshClient.stop().
|
|
|
+func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Client must only perform one handshake
|
|
|
+ if sshClient.handshakeState.completed {
|
|
|
+ return common.ContextError(errors.New("handshake already completed"))
|
|
|
+ }
|
|
|
+
|
|
|
+ sshClient.handshakeState = state
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// setTrafficRules resets the client's traffic rules based on the latest server config
|
|
|
+// and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
|
|
|
+// trafficRules must only be accessed within the sshClient mutex.
|
|
|
+func (sshClient *sshClient) setTrafficRules() {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
|
|
|
+ sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) rateLimits() common.RateLimits {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ return sshClient.trafficRules.RateLimits.CommonRateLimits()
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
|
|
|
+}
|
|
|
+
|
|
|
+const (
|
|
|
+ portForwardTypeTCP = iota
|
|
|
+ portForwardTypeUDP
|
|
|
+)
|
|
|
+
|
|
|
func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
- host string, port int, allowPorts []int, denyPorts []int) bool {
|
|
|
+ portForwardType int, host string, port int) bool {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ if !sshClient.handshakeState.completed {
|
|
|
+ return false
|
|
|
+ }
|
|
|
|
|
|
if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+ var allowPorts, denyPorts []int
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ allowPorts = sshClient.trafficRules.AllowTCPPorts
|
|
|
+ denyPorts = sshClient.trafficRules.AllowTCPPorts
|
|
|
+ } else {
|
|
|
+ allowPorts = sshClient.trafficRules.AllowUDPPorts
|
|
|
+ denyPorts = sshClient.trafficRules.AllowUDPPorts
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
// TODO: faster lookup?
|
|
|
if len(allowPorts) > 0 {
|
|
|
for _, allowPort := range allowPorts {
|
|
|
@@ -841,37 +987,63 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
- state *trafficState, maxPortForwardCount int) bool {
|
|
|
+ portForwardType int) (int, bool) {
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ var maxPortForwardCount int
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ maxPortForwardCount = *sshClient.trafficRules.MaxTCPPortForwardCount
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ maxPortForwardCount = *sshClient.trafficRules.MaxUDPPortForwardCount
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
|
|
|
- limitExceeded := false
|
|
|
- if maxPortForwardCount > 0 {
|
|
|
- sshClient.Lock()
|
|
|
- limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
|
|
|
- sshClient.Unlock()
|
|
|
+ if maxPortForwardCount > 0 && state.concurrentPortForwardCount >= int64(maxPortForwardCount) {
|
|
|
+ return maxPortForwardCount, true
|
|
|
}
|
|
|
- return limitExceeded
|
|
|
+ return maxPortForwardCount, false
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) openedPortForward(
|
|
|
- state *trafficState) {
|
|
|
+ portForwardType int) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
+
|
|
|
state.concurrentPortForwardCount += 1
|
|
|
if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
}
|
|
|
state.totalPortForwardCount += 1
|
|
|
- sshClient.Unlock()
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) closedPortForward(
|
|
|
- state *trafficState, bytesUp, bytesDown int64) {
|
|
|
+ portForwardType int, bytesUp, bytesDown int64) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ var state *trafficState
|
|
|
+ if portForwardType == portForwardTypeTCP {
|
|
|
+ state = &sshClient.tcpTrafficState
|
|
|
+ } else {
|
|
|
+ state = &sshClient.udpTrafficState
|
|
|
+ }
|
|
|
+
|
|
|
state.concurrentPortForwardCount -= 1
|
|
|
state.bytesUp += bytesUp
|
|
|
state.bytesDown += bytesDown
|
|
|
- sshClient.Unlock()
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) handleTCPChannel(
|
|
|
@@ -879,37 +1051,35 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
portToConnect int,
|
|
|
newChannel ssh.NewChannel) {
|
|
|
|
|
|
- if !sshClient.isPortForwardPermitted(
|
|
|
- hostToConnect,
|
|
|
- portToConnect,
|
|
|
- sshClient.trafficRules.AllowTCPPorts,
|
|
|
- sshClient.trafficRules.DenyTCPPorts) {
|
|
|
+ isWebServerPortForward := false
|
|
|
+ config := sshClient.sshServer.support.Config
|
|
|
+ if config.WebServerPortForwardAddress != "" {
|
|
|
+ destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
|
|
|
+ if destination == config.WebServerPortForwardAddress {
|
|
|
+ isWebServerPortForward = true
|
|
|
+ if config.WebServerPortForwardRedirectAddress != "" {
|
|
|
+ // Note: redirect format is validated when config is loaded
|
|
|
+ host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
|
|
|
+ port, _ := strconv.Atoi(portStr)
|
|
|
+ hostToConnect = host
|
|
|
+ portToConnect = port
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if !isWebServerPortForward && !sshClient.isPortForwardPermitted(
|
|
|
+ portForwardTypeTCP, hostToConnect, portToConnect) {
|
|
|
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.Prohibited, "port forward not permitted")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // Note: redirects are applied *after* isPortForwardPermitted allows the original destination
|
|
|
- if sshClient.sshServer.support.Config.TCPPortForwardRedirects != nil {
|
|
|
- destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
|
|
|
- if redirect, ok := sshClient.sshServer.support.Config.TCPPortForwardRedirects[destination]; ok {
|
|
|
- // Note: redirect format is validated when config is loaded
|
|
|
- host, portStr, _ := net.SplitHostPort(redirect)
|
|
|
- port, _ := strconv.Atoi(portStr)
|
|
|
- hostToConnect = host
|
|
|
- portToConnect = port
|
|
|
- log.WithContextFields(LogFields{"destination": destination, "redirect": redirect}).Debug("port forward redirect")
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
var bytesUp, bytesDown int64
|
|
|
- sshClient.openedPortForward(sshClient.tcpTrafficState)
|
|
|
+ sshClient.openedPortForward(portForwardTypeTCP)
|
|
|
defer func() {
|
|
|
sshClient.closedPortForward(
|
|
|
- sshClient.tcpTrafficState,
|
|
|
- atomic.LoadInt64(&bytesUp),
|
|
|
- atomic.LoadInt64(&bytesDown))
|
|
|
+ portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
|
|
|
}()
|
|
|
|
|
|
// TOCTOU note: important to increment the port forward count (via
|
|
|
@@ -918,9 +1088,7 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// by initiating many port forwards concurrently.
|
|
|
// TODO: close LRU connection (after successful Dial) instead of
|
|
|
// rejecting new connection?
|
|
|
- if sshClient.isPortForwardLimitExceeded(
|
|
|
- sshClient.tcpTrafficState,
|
|
|
- sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
|
+ if maxCount, exceeded := sshClient.isPortForwardLimitExceeded(portForwardTypeTCP); exceeded {
|
|
|
|
|
|
// Close the oldest TCP port forward. CloseOldest() closes
|
|
|
// the conn and the port forward's goroutine will complete
|
|
|
@@ -952,7 +1120,7 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
|
- "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
|
|
|
+ "maxCount": maxCount,
|
|
|
}).Debug("closed LRU TCP port forward")
|
|
|
}
|
|
|
|
|
|
@@ -1015,7 +1183,7 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
fwdConn, err = common.NewActivityMonitoredConn(
|
|
|
fwdConn,
|
|
|
- time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
|
|
|
+ sshClient.idleTCPPortForwardTimeout(),
|
|
|
true,
|
|
|
lruEntry)
|
|
|
if result.err != nil {
|