|
|
@@ -284,6 +284,13 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
|
|
|
server.sshServer.resetAllClientOSLConfigs()
|
|
|
}
|
|
|
|
|
|
+type ClientHandshakeStateInfo struct {
|
|
|
+ ActiveAuthorizationIDs []string
|
|
|
+ AuthorizedAccessTypes []string
|
|
|
+ UpstreamBytesPerSecond int64
|
|
|
+ DownstreamBytesPerSecond int64
|
|
|
+}
|
|
|
+
|
|
|
// SetClientHandshakeState sets the handshake state -- that it completed and
|
|
|
// what parameters were passed -- in sshClient. This state is used for allowing
|
|
|
// port forwards and for future traffic rule selection. SetClientHandshakeState
|
|
|
@@ -293,12 +300,14 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
|
|
|
//
|
|
|
// The authorizations received from the client handshake are verified and the
|
|
|
// resulting list of authorized access types are applied to the client's tunnel
|
|
|
-// and traffic rules. A list of active authorization IDs and authorized access
|
|
|
-// types is returned for responding to the client and logging.
|
|
|
+// and traffic rules.
|
|
|
+//
|
|
|
+// A list of active authorization IDs, authorized access types, and traffic
|
|
|
+// rate limits are returned for responding to the client and logging.
|
|
|
func (server *TunnelServer) SetClientHandshakeState(
|
|
|
sessionID string,
|
|
|
state handshakeState,
|
|
|
- authorizations []string) ([]string, []string, error) {
|
|
|
+ authorizations []string) (*ClientHandshakeStateInfo, error) {
|
|
|
|
|
|
return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
|
|
|
}
|
|
|
@@ -908,23 +917,23 @@ func (sshServer *sshServer) resetAllClientOSLConfigs() {
|
|
|
func (sshServer *sshServer) setClientHandshakeState(
|
|
|
sessionID string,
|
|
|
state handshakeState,
|
|
|
- authorizations []string) ([]string, []string, error) {
|
|
|
+ authorizations []string) (*ClientHandshakeStateInfo, error) {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
client := sshServer.clients[sessionID]
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
if client == nil {
|
|
|
- return nil, nil, errors.TraceNew("unknown session ID")
|
|
|
+ return nil, errors.TraceNew("unknown session ID")
|
|
|
}
|
|
|
|
|
|
- activeAuthorizationIDs, authorizedAccessTypes, err := client.setHandshakeState(
|
|
|
+ clientHandshakeStateInfo, err := client.setHandshakeState(
|
|
|
state, authorizations)
|
|
|
if err != nil {
|
|
|
- return nil, nil, errors.Trace(err)
|
|
|
+ return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
- return activeAuthorizationIDs, authorizedAccessTypes, nil
|
|
|
+ return clientHandshakeStateInfo, nil
|
|
|
}
|
|
|
|
|
|
func (sshServer *sshServer) getClientHandshaked(
|
|
|
@@ -2508,7 +2517,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessa
|
|
|
// sshClient.stop().
|
|
|
func (sshClient *sshClient) setHandshakeState(
|
|
|
state handshakeState,
|
|
|
- authorizations []string) ([]string, []string, error) {
|
|
|
+ authorizations []string) (*ClientHandshakeStateInfo, error) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
completed := sshClient.handshakeState.completed
|
|
|
@@ -2519,7 +2528,7 @@ func (sshClient *sshClient) setHandshakeState(
|
|
|
|
|
|
// Client must only perform one handshake
|
|
|
if completed {
|
|
|
- return nil, nil, errors.TraceNew("handshake already completed")
|
|
|
+ return nil, errors.TraceNew("handshake already completed")
|
|
|
}
|
|
|
|
|
|
// Verify the authorizations submitted by the client. Verified, active
|
|
|
@@ -2653,10 +2662,16 @@ func (sshClient *sshClient) setHandshakeState(
|
|
|
sshClient.Unlock()
|
|
|
}
|
|
|
|
|
|
- sshClient.setTrafficRules()
|
|
|
+ upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules()
|
|
|
+
|
|
|
sshClient.setOSLConfig()
|
|
|
|
|
|
- return authorizationIDs, authorizedAccessTypes, nil
|
|
|
+ return &ClientHandshakeStateInfo{
|
|
|
+ ActiveAuthorizationIDs: authorizationIDs,
|
|
|
+ AuthorizedAccessTypes: authorizedAccessTypes,
|
|
|
+ UpstreamBytesPerSecond: upstreamBytesPerSecond,
|
|
|
+ DownstreamBytesPerSecond: downstreamBytesPerSecond,
|
|
|
+ }, nil
|
|
|
}
|
|
|
|
|
|
// getHandshaked returns whether the client has completed a handshake API
|
|
|
@@ -2719,7 +2734,7 @@ func (sshClient *sshClient) expectDomainBytes() bool {
|
|
|
// setTrafficRules resets the client's traffic rules based on the latest server config
|
|
|
// and client properties. As sshClient.trafficRules may be reset by a concurrent
|
|
|
// goroutine, trafficRules must only be accessed within the sshClient mutex.
|
|
|
-func (sshClient *sshClient) setTrafficRules() {
|
|
|
+func (sshClient *sshClient) setTrafficRules() (int64, int64) {
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
|
|
|
@@ -2734,6 +2749,9 @@ func (sshClient *sshClient) setTrafficRules() {
|
|
|
sshClient.throttledConn.SetLimits(
|
|
|
sshClient.trafficRules.RateLimits.CommonRateLimits())
|
|
|
}
|
|
|
+
|
|
|
+ return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
|
|
|
+ *sshClient.trafficRules.RateLimits.WriteBytesPerSecond
|
|
|
}
|
|
|
|
|
|
// setOSLConfig resets the client's OSL seed state based on the latest OSL config
|