|
|
@@ -198,10 +198,16 @@ func (server *TunnelServer) SelectAllClientTrafficRules() {
|
|
|
server.sshServer.selectAllClientTrafficRules()
|
|
|
}
|
|
|
|
|
|
-// SelectClientTrafficRules resets a specified client's traffic rules
|
|
|
-// to use the latest server config and client state.
|
|
|
-func (server *TunnelServer) SelectClientTrafficRules(sessionID string) {
|
|
|
- server.sshServer.selectClientTrafficRules(sessionID)
|
|
|
+// SetClientHandshakeState sets the handshake state -- that it completed and
|
|
|
+// what paramaters were passed -- in sshClient. This state is used for allowing
|
|
|
+// port forwards and for future traffic rule selection. SetClientHandshakeState
|
|
|
+// also triggers an immediate traffic rule re-selection, as the rules selected
|
|
|
+// upon tunnel establishment may no longer apply now that handshake values are
|
|
|
+// set.
|
|
|
+func (server *TunnelServer) SetClientHandshakeState(
|
|
|
+ sessionID string, state handshakeState) error {
|
|
|
+
|
|
|
+ return server.sshServer.setClientHandshakeState(sessionID, state)
|
|
|
}
|
|
|
|
|
|
type sshServer struct {
|
|
|
@@ -432,15 +438,25 @@ func (sshServer *sshServer) selectAllClientTrafficRules() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) selectClientTrafficRules(sessionID string) {
|
|
|
+func (sshServer *sshServer) setClientHandshakeState(
|
|
|
+ sessionID string, state handshakeState) error {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
client := sshServer.clients[sessionID]
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
- if client != nil {
|
|
|
- client.selectTrafficRules()
|
|
|
+ if client == nil {
|
|
|
+ return common.ContextError(errors.New("unknown session ID"))
|
|
|
+ }
|
|
|
+
|
|
|
+ err := client.setHandshakeState(state)
|
|
|
+ if err != nil {
|
|
|
+ return common.ContextError(err)
|
|
|
}
|
|
|
+
|
|
|
+ client.selectTrafficRules()
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
@@ -593,6 +609,7 @@ type sshClient struct {
|
|
|
throttledConn *common.ThrottledConn
|
|
|
geoIPData GeoIPData
|
|
|
sessionID string
|
|
|
+ handshakeState handshakeState
|
|
|
udpChannel ssh.Channel
|
|
|
trafficRules TrafficRules
|
|
|
tcpTrafficState trafficState
|
|
|
@@ -613,6 +630,14 @@ type trafficState struct {
|
|
|
totalPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
+type handshakeState struct {
|
|
|
+ completed bool
|
|
|
+ propagationChannelID string
|
|
|
+ sponsorID string
|
|
|
+ clientVersion int
|
|
|
+ clientPlatform string
|
|
|
+}
|
|
|
+
|
|
|
func newSshClient(
|
|
|
sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
return &sshClient{
|
|
|
@@ -734,9 +759,17 @@ func (sshClient *sshClient) stop() {
|
|
|
// it did the client may not have the opportunity to send a final
|
|
|
// request with an EOF flag set.)
|
|
|
|
|
|
+ // TODO: match legacy log field naming convention?
|
|
|
+ // TODO: log all handshake common inputs?
|
|
|
+
|
|
|
sshClient.Lock()
|
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
|
+ "handshakeCompleted": sshClient.handshakeState.completed,
|
|
|
+ "propagationChannelID": sshClient.handshakeState.propagationChannelID,
|
|
|
+ "sponsorID": sshClient.handshakeState.sponsorID,
|
|
|
+ "clientVersion": sshClient.handshakeState.clientVersion,
|
|
|
+ "clientPlatform": sshClient.handshakeState.clientPlatform,
|
|
|
"startTime": sshClient.activityConn.GetStartTime(),
|
|
|
"duration": sshClient.activityConn.GetActiveDuration(),
|
|
|
"sessionID": sshClient.sessionID,
|
|
|
@@ -854,6 +887,25 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// setHandshakeState records that a client has completed a handshake API request.
|
|
|
+// Some parameters from the handshake request may be used in future traffic rule
|
|
|
+// selection. Port forwards are disallowed until a handshake is complete. The
|
|
|
+// handshake parameters are included in the session summary log recorded in
|
|
|
+// sshClient.stop().
|
|
|
+func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Client must only perform one handshake
|
|
|
+ if sshClient.handshakeState.completed {
|
|
|
+ return common.ContextError(errors.New("handshake already completed"))
|
|
|
+ }
|
|
|
+
|
|
|
+ sshClient.handshakeState = state
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
// selectTrafficRules resets the client's traffic rules based on the latest server config
|
|
|
// and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
|
|
|
// trafficRules must only be accessed within the sshClient mutex.
|
|
|
@@ -861,7 +913,9 @@ func (sshClient *sshClient) selectTrafficRules() {
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
- sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(sshClient.geoIPData.Country)
|
|
|
+ sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
|
|
|
+ // TODO: sshClient.geoIPData, sshClient.handshakeState)
|
|
|
+ sshClient.geoIPData.Country)
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) rateLimits() common.RateLimits {
|
|
|
@@ -894,10 +948,13 @@ const (
|
|
|
func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
portForwardType int, host string, port int) bool {
|
|
|
|
|
|
- // Mutex required for accessing sshClient.trafficRules
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
+ if !sshClient.handshakeState.completed {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
|
|
|
return false
|
|
|
}
|
|
|
@@ -998,7 +1055,23 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
portToConnect int,
|
|
|
newChannel ssh.NewChannel) {
|
|
|
|
|
|
- if !sshClient.isPortForwardPermitted(
|
|
|
+ isWebServerPortForward := false
|
|
|
+ config := sshClient.sshServer.support.Config
|
|
|
+ if config.WebServerPortForwardAddress != "" {
|
|
|
+ destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
|
|
|
+ if destination == config.WebServerPortForwardAddress {
|
|
|
+ isWebServerPortForward = true
|
|
|
+ if config.WebServerPortForwardRedirectAddress != "" {
|
|
|
+ // Note: redirect format is validated when config is loaded
|
|
|
+ host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
|
|
|
+ port, _ := strconv.Atoi(portStr)
|
|
|
+ hostToConnect = host
|
|
|
+ portToConnect = port
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if !isWebServerPortForward && !sshClient.isPortForwardPermitted(
|
|
|
portForwardTypeTCP, hostToConnect, portToConnect) {
|
|
|
|
|
|
sshClient.rejectNewChannel(
|
|
|
@@ -1006,19 +1079,6 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // Note: redirects are applied *after* isPortForwardPermitted allows the original destination
|
|
|
- if sshClient.sshServer.support.Config.TCPPortForwardRedirects != nil {
|
|
|
- destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
|
|
|
- if redirect, ok := sshClient.sshServer.support.Config.TCPPortForwardRedirects[destination]; ok {
|
|
|
- // Note: redirect format is validated when config is loaded
|
|
|
- host, portStr, _ := net.SplitHostPort(redirect)
|
|
|
- port, _ := strconv.Atoi(portStr)
|
|
|
- hostToConnect = host
|
|
|
- portToConnect = port
|
|
|
- log.WithContextFields(LogFields{"destination": destination, "redirect": redirect}).Debug("port forward redirect")
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
var bytesUp, bytesDown int64
|
|
|
sshClient.openedPortForward(portForwardTypeTCP)
|
|
|
defer func() {
|