|
|
@@ -35,20 +35,18 @@ import (
|
|
|
"github.com/Psiphon-Inc/goarista/monotime"
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
|
|
|
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
|
|
|
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
- SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
- SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 30 * time.Second
|
|
|
- SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
+ SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
|
|
|
+ SSH_TCP_PORT_FORWARD_IP_LOOKUP_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT = 30 * time.Second
|
|
|
+ SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
|
|
|
)
|
|
|
|
|
|
-// Disallowed port forward hosts is a failsafe. The server should
|
|
|
-// be run on a host with correctly configured firewall rules, or
|
|
|
-// containerization, or both.
|
|
|
-var SSH_DISALLOWED_PORT_FORWARD_HOSTS = []string{"localhost", "127.0.0.1"}
|
|
|
-
|
|
|
// TunnelServer is the main server that accepts Psiphon client
|
|
|
// connections, via various obfuscation protocols, and provides
|
|
|
// port forwarding (TCP and UDP) services to the Psiphon client.
|
|
|
@@ -194,11 +192,19 @@ func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
|
|
|
}
|
|
|
|
|
|
// ResetAllClientTrafficRules resets all established client traffic rules
|
|
|
-// to use the latest server config and client state.
|
|
|
+// to use the latest config and client properties. Any existing traffic
|
|
|
+// rule state is lost, including throttling state.
|
|
|
func (server *TunnelServer) ResetAllClientTrafficRules() {
|
|
|
server.sshServer.resetAllClientTrafficRules()
|
|
|
}
|
|
|
|
|
|
+// ResetAllClientOSLConfigs resets all established client OSL state to use
|
|
|
+// the latest OSL config. Any existing OSL state is lost, including partial
|
|
|
+// progress towards SLOKs.
|
|
|
+func (server *TunnelServer) ResetAllClientOSLConfigs() {
|
|
|
+ server.sshServer.resetAllClientOSLConfigs()
|
|
|
+}
|
|
|
+
|
|
|
// SetClientHandshakeState sets the handshake state -- that it completed and
|
|
|
// what paramaters were passed -- in sshClient. This state is used for allowing
|
|
|
// port forwards and for future traffic rule selection. SetClientHandshakeState
|
|
|
@@ -211,6 +217,14 @@ func (server *TunnelServer) SetClientHandshakeState(
|
|
|
return server.sshServer.setClientHandshakeState(sessionID, state)
|
|
|
}
|
|
|
|
|
|
+// GetClientSeedPayload gets the current OSL seed payload for the specified
|
|
|
+// client session. Any seeded SLOKs are issued and included in the payload.
|
|
|
+func (server *TunnelServer) GetClientSeedPayload(
|
|
|
+ sessionID string) (*osl.SeedPayload, error) {
|
|
|
+
|
|
|
+ return server.sshServer.getClientSeedPayload(sessionID)
|
|
|
+}
|
|
|
+
|
|
|
// SetEstablishTunnels sets whether new tunnels may be established or not.
|
|
|
// When not establishing, incoming connections are immediately closed.
|
|
|
func (server *TunnelServer) SetEstablishTunnels(establish bool) {
|
|
|
@@ -310,13 +324,13 @@ func (sshServer *sshServer) runListener(
|
|
|
// TunnelServer.Run will properly shut down instead of remaining
|
|
|
// running.
|
|
|
|
|
|
- if common.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
|
|
|
- common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
|
|
|
+ if protocol.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
|
|
|
+ protocol.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
|
|
|
|
|
|
meekServer, err := NewMeekServer(
|
|
|
sshServer.support,
|
|
|
listener,
|
|
|
- common.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
|
|
|
+ protocol.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
|
|
|
handleClient,
|
|
|
sshServer.shutdownBroadcast)
|
|
|
if err != nil {
|
|
|
@@ -511,6 +525,20 @@ func (sshServer *sshServer) resetAllClientTrafficRules() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) resetAllClientOSLConfigs() {
|
|
|
+
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ clients := make(map[string]*sshClient)
|
|
|
+ for sessionID, client := range sshServer.clients {
|
|
|
+ clients[sessionID] = client
|
|
|
+ }
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
+ for _, client := range clients {
|
|
|
+ client.setOSLConfig()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (sshServer *sshServer) setClientHandshakeState(
|
|
|
sessionID string, state handshakeState) error {
|
|
|
|
|
|
@@ -527,11 +555,23 @@ func (sshServer *sshServer) setClientHandshakeState(
|
|
|
return common.ContextError(err)
|
|
|
}
|
|
|
|
|
|
- client.setTrafficRules()
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) getClientSeedPayload(
|
|
|
+ sessionID string) (*osl.SeedPayload, error) {
|
|
|
+
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ client := sshServer.clients[sessionID]
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ return nil, common.ContextError(errors.New("unknown session ID"))
|
|
|
+ }
|
|
|
+
|
|
|
+ return client.getClientSeedPayload(), nil
|
|
|
+}
|
|
|
+
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
@@ -555,6 +595,70 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
|
|
|
sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
|
|
|
|
|
|
+ sshClient.run(clientConn)
|
|
|
+}
|
|
|
+
|
|
|
+type sshClient struct {
|
|
|
+ sync.Mutex
|
|
|
+ sshServer *sshServer
|
|
|
+ tunnelProtocol string
|
|
|
+ sshConn ssh.Conn
|
|
|
+ activityConn *common.ActivityMonitoredConn
|
|
|
+ throttledConn *common.ThrottledConn
|
|
|
+ geoIPData GeoIPData
|
|
|
+ sessionID string
|
|
|
+ handshakeState handshakeState
|
|
|
+ udpChannel ssh.Channel
|
|
|
+ trafficRules TrafficRules
|
|
|
+ tcpTrafficState trafficState
|
|
|
+ udpTrafficState trafficState
|
|
|
+ qualityMetrics qualityMetrics
|
|
|
+ channelHandlerWaitGroup *sync.WaitGroup
|
|
|
+ tcpPortForwardLRU *common.LRUConns
|
|
|
+ oslClientSeedState *osl.ClientSeedState
|
|
|
+ stopBroadcast chan struct{}
|
|
|
+}
|
|
|
+
|
|
|
+type trafficState struct {
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ concurrentPortForwardCount int64
|
|
|
+ peakConcurrentPortForwardCount int64
|
|
|
+ totalPortForwardCount int64
|
|
|
+}
|
|
|
+
|
|
|
+// qualityMetrics records upstream TCP dial attempts and
|
|
|
+// elapsed time. Elapsed time includes the full TCP handshake
|
|
|
+// and, in aggregate, is a measure of the quality of the
|
|
|
+// upstream link. These stats are recorded by each sshClient
|
|
|
+// and then reported and reset in sshServer.getLoadStats().
|
|
|
+type qualityMetrics struct {
|
|
|
+ tcpPortForwardDialedCount int64
|
|
|
+ tcpPortForwardDialedDuration time.Duration
|
|
|
+ tcpPortForwardFailedCount int64
|
|
|
+ tcpPortForwardFailedDuration time.Duration
|
|
|
+}
|
|
|
+
|
|
|
+type handshakeState struct {
|
|
|
+ completed bool
|
|
|
+ apiProtocol string
|
|
|
+ apiParams requestJSONObject
|
|
|
+}
|
|
|
+
|
|
|
+func newSshClient(
|
|
|
+ sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
+ return &sshClient{
|
|
|
+ sshServer: sshServer,
|
|
|
+ tunnelProtocol: tunnelProtocol,
|
|
|
+ geoIPData: geoIPData,
|
|
|
+ channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
+ tcpPortForwardLRU: common.NewLRUConns(),
|
|
|
+ stopBroadcast: make(chan struct{}),
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) run(clientConn net.Conn) {
|
|
|
+
|
|
|
// Set initial traffic rules, pre-handshake, based on currently known info.
|
|
|
sshClient.setTrafficRules()
|
|
|
|
|
|
@@ -569,6 +673,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
clientConn,
|
|
|
SSH_CONNECTION_READ_DEADLINE,
|
|
|
false,
|
|
|
+ nil,
|
|
|
nil)
|
|
|
if err != nil {
|
|
|
clientConn.Close()
|
|
|
@@ -607,21 +712,21 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
sshServerConfig := &ssh.ServerConfig{
|
|
|
PasswordCallback: sshClient.passwordCallback,
|
|
|
AuthLogCallback: sshClient.authLogCallback,
|
|
|
- ServerVersion: sshServer.support.Config.SSHServerVersion,
|
|
|
+ ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion,
|
|
|
}
|
|
|
- sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
|
+ sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
|
|
|
|
|
|
result := &sshNewServerConnResult{}
|
|
|
|
|
|
// Wrap the connection in an SSH deobfuscator when required.
|
|
|
|
|
|
- if common.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
|
|
|
+ if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
|
|
|
// Note: NewObfuscatedSshConn blocks on network I/O
|
|
|
// TODO: ensure this won't block shutdown
|
|
|
conn, result.err = psiphon.NewObfuscatedSshConn(
|
|
|
psiphon.OBFUSCATION_CONN_MODE_SERVER,
|
|
|
conn,
|
|
|
- sshServer.support.Config.ObfuscatedSSHKey)
|
|
|
+ sshClient.sshServer.support.Config.ObfuscatedSSHKey)
|
|
|
if result.err != nil {
|
|
|
result.err = common.ContextError(result.err)
|
|
|
}
|
|
|
@@ -639,7 +744,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
var result *sshNewServerConnResult
|
|
|
select {
|
|
|
case result = <-resultChannel:
|
|
|
- case <-sshServer.shutdownBroadcast:
|
|
|
+ case <-sshClient.sshServer.shutdownBroadcast:
|
|
|
// Close() will interrupt an ongoing handshake
|
|
|
// TODO: wait for goroutine to exit before returning?
|
|
|
clientConn.Close()
|
|
|
@@ -661,80 +766,22 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
sshClient.throttledConn = throttledConn
|
|
|
sshClient.Unlock()
|
|
|
|
|
|
- if !sshServer.registerEstablishedClient(sshClient) {
|
|
|
+ if !sshClient.sshServer.registerEstablishedClient(sshClient) {
|
|
|
clientConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
|
}
|
|
|
- defer sshServer.unregisterEstablishedClient(sshClient.sessionID)
|
|
|
+ defer sshClient.sshServer.unregisterEstablishedClient(sshClient.sessionID)
|
|
|
|
|
|
- sshClient.runClient(result.channels, result.requests)
|
|
|
+ sshClient.runTunnel(result.channels, result.requests)
|
|
|
|
|
|
- // Note: sshServer.unregisterClient calls sshClient.Close(),
|
|
|
+ // Note: sshServer.unregisterEstablishedClient calls sshClient.Close(),
|
|
|
// which also closes underlying transport Conn.
|
|
|
}
|
|
|
|
|
|
-type sshClient struct {
|
|
|
- sync.Mutex
|
|
|
- sshServer *sshServer
|
|
|
- tunnelProtocol string
|
|
|
- sshConn ssh.Conn
|
|
|
- activityConn *common.ActivityMonitoredConn
|
|
|
- throttledConn *common.ThrottledConn
|
|
|
- geoIPData GeoIPData
|
|
|
- sessionID string
|
|
|
- handshakeState handshakeState
|
|
|
- udpChannel ssh.Channel
|
|
|
- trafficRules TrafficRules
|
|
|
- tcpTrafficState trafficState
|
|
|
- udpTrafficState trafficState
|
|
|
- qualityMetrics qualityMetrics
|
|
|
- channelHandlerWaitGroup *sync.WaitGroup
|
|
|
- tcpPortForwardLRU *common.LRUConns
|
|
|
- stopBroadcast chan struct{}
|
|
|
-}
|
|
|
-
|
|
|
-type trafficState struct {
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- concurrentPortForwardCount int64
|
|
|
- peakConcurrentPortForwardCount int64
|
|
|
- totalPortForwardCount int64
|
|
|
-}
|
|
|
-
|
|
|
-// qualityMetrics records upstream TCP dial attempts and
|
|
|
-// elapsed time. Elapsed time includes the full TCP handshake
|
|
|
-// and, in aggregate, is a measure of the quality of the
|
|
|
-// upstream link. These stats are recorded by each sshClient
|
|
|
-// and then reported and reset in sshServer.getLoadStats().
|
|
|
-type qualityMetrics struct {
|
|
|
- tcpPortForwardDialedCount int64
|
|
|
- tcpPortForwardDialedDuration time.Duration
|
|
|
- tcpPortForwardFailedCount int64
|
|
|
- tcpPortForwardFailedDuration time.Duration
|
|
|
-}
|
|
|
-
|
|
|
-type handshakeState struct {
|
|
|
- completed bool
|
|
|
- apiProtocol string
|
|
|
- apiParams requestJSONObject
|
|
|
-}
|
|
|
-
|
|
|
-func newSshClient(
|
|
|
- sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
|
|
|
- return &sshClient{
|
|
|
- sshServer: sshServer,
|
|
|
- tunnelProtocol: tunnelProtocol,
|
|
|
- geoIPData: geoIPData,
|
|
|
- channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
- tcpPortForwardLRU: common.NewLRUConns(),
|
|
|
- stopBroadcast: make(chan struct{}),
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
|
- expectedSessionIDLength := 2 * common.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
|
|
|
+ expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
|
|
|
expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
|
|
|
|
|
|
var sshPasswordPayload struct {
|
|
|
@@ -867,10 +914,10 @@ func (sshClient *sshClient) stop() {
|
|
|
log.LogRawFieldsWithTimestamp(logFields)
|
|
|
}
|
|
|
|
|
|
-// runClient handles/dispatches new channel and new requests from the client.
|
|
|
+// runTunnel handles/dispatches new channel and new requests from the client.
|
|
|
// When the SSH client connection closes, both the channels and requests channels
|
|
|
// will close and runClient will exit.
|
|
|
-func (sshClient *sshClient) runClient(
|
|
|
+func (sshClient *sshClient) runTunnel(
|
|
|
channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
|
|
|
|
|
|
requestsWaitGroup := new(sync.WaitGroup)
|
|
|
@@ -975,22 +1022,28 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
|
|
|
// handshake parameters are included in the session summary log recorded in
|
|
|
// sshClient.stop().
|
|
|
func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
|
|
|
+
|
|
|
sshClient.Lock()
|
|
|
- defer sshClient.Unlock()
|
|
|
+ completed := sshClient.handshakeState.completed
|
|
|
+ if !completed {
|
|
|
+ sshClient.handshakeState = state
|
|
|
+ }
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
|
// Client must only perform one handshake
|
|
|
- if sshClient.handshakeState.completed {
|
|
|
+ if completed {
|
|
|
return common.ContextError(errors.New("handshake already completed"))
|
|
|
}
|
|
|
|
|
|
- sshClient.handshakeState = state
|
|
|
+ sshClient.setTrafficRules()
|
|
|
+ sshClient.setOSLConfig()
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// setTrafficRules resets the client's traffic rules based on the latest server config
|
|
|
-// and client state. As sshClient.trafficRules may be reset by a concurrent goroutine,
|
|
|
-// trafficRules must only be accessed within the sshClient mutex.
|
|
|
+// and client properties. As sshClient.trafficRules may be reset by a concurrent
|
|
|
+// goroutine, trafficRules must only be accessed within the sshClient mutex.
|
|
|
func (sshClient *sshClient) setTrafficRules() {
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
@@ -999,11 +1052,68 @@ func (sshClient *sshClient) setTrafficRules() {
|
|
|
sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
|
|
|
|
|
|
if sshClient.throttledConn != nil {
|
|
|
+ // Any existing throttling state is reset.
|
|
|
sshClient.throttledConn.SetLimits(
|
|
|
sshClient.trafficRules.RateLimits.CommonRateLimits())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// setOSLConfig resets the client's OSL seed state based on the latest OSL config
|
|
|
+// As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
|
|
|
+// oslClientSeedState must only be accessed within the sshClient mutex.
|
|
|
+func (sshClient *sshClient) setOSLConfig() {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ propagationChannelID, err := getStringRequestParam(
|
|
|
+ sshClient.handshakeState.apiParams, "propagation_channel_id")
|
|
|
+ if err != nil {
|
|
|
+ // This should not fail as long as client has sent valid handshake
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Two limitations when setOSLConfig() is invoked due to an
|
|
|
+ // OSL config hot reload:
|
|
|
+ //
|
|
|
+ // 1. any partial progress towards SLOKs is lost.
|
|
|
+ //
|
|
|
+ // 2. all existing osl.ClientSeedPortForwards for existing
|
|
|
+ // port forwards will not send progress to the new client
|
|
|
+ // seed state.
|
|
|
+
|
|
|
+ sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
|
|
|
+ sshClient.geoIPData.Country,
|
|
|
+ propagationChannelID)
|
|
|
+}
|
|
|
+
|
|
|
+// newClientSeedPortForward will return nil when no seeding is
|
|
|
+// associated with the specified ipAddress.
|
|
|
+func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Will not be initialized before handshake.
|
|
|
+ if sshClient.oslClientSeedState == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
|
|
|
+}
|
|
|
+
|
|
|
+// getClientSeedPayload returns a payload containing all seeded SLOKs for
|
|
|
+// this client's session.
|
|
|
+func (sshClient *sshClient) getClientSeedPayload() *osl.SeedPayload {
|
|
|
+ sshClient.Lock()
|
|
|
+ defer sshClient.Unlock()
|
|
|
+
|
|
|
+ // Will not be initialized before handshake.
|
|
|
+ if sshClient.oslClientSeedState == nil {
|
|
|
+ return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
|
|
|
+ }
|
|
|
+
|
|
|
+ return sshClient.oslClientSeedState.GetSeedPayload()
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) rateLimits() common.RateLimits {
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
@@ -1032,7 +1142,7 @@ const (
|
|
|
)
|
|
|
|
|
|
func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
- portForwardType int, host string, port int) bool {
|
|
|
+ portForwardType int, remoteIP net.IP, port int) bool {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
defer sshClient.Unlock()
|
|
|
@@ -1041,7 +1151,9 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
- if common.Contains(SSH_DISALLOWED_PORT_FORWARD_HOSTS, host) {
|
|
|
+ // Disallow connection to loopback. This is a failsafe. The server
|
|
|
+ // should be run on a host with correctly configured firewall rules.
|
|
|
+ if remoteIP.IsLoopback() {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
@@ -1065,17 +1177,11 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // TODO: AllowSubnets won't match when host is a domain.
|
|
|
- // Callers should resolve domain host before checking
|
|
|
- // isPortForwardPermitted.
|
|
|
-
|
|
|
- if ip := net.ParseIP(host); ip != nil {
|
|
|
- for _, subnet := range sshClient.trafficRules.AllowSubnets {
|
|
|
- // Note: ignoring error as config has been validated
|
|
|
- _, network, _ := net.ParseCIDR(subnet)
|
|
|
- if network.Contains(ip) {
|
|
|
- return true
|
|
|
- }
|
|
|
+ for _, subnet := range sshClient.trafficRules.AllowSubnets {
|
|
|
+ // Note: ignoring error as config has been validated
|
|
|
+ _, network, _ := net.ParseCIDR(subnet)
|
|
|
+ if network.Contains(remoteIP) {
|
|
|
+ return true
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1179,8 +1285,48 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if !isWebServerPortForward && !sshClient.isPortForwardPermitted(
|
|
|
- portForwardTypeTCP, hostToConnect, portToConnect) {
|
|
|
+ type lookupIPResult struct {
|
|
|
+ IP net.IP
|
|
|
+ err error
|
|
|
+ }
|
|
|
+ lookupResultChannel := make(chan *lookupIPResult, 1)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ // TODO: explicit timeout for DNS resolution?
|
|
|
+ IPs, err := net.LookupIP(hostToConnect)
|
|
|
+ // TODO: shuffle list to try other IPs
|
|
|
+ // TODO: IPv6 support
|
|
|
+ var IP net.IP
|
|
|
+ for _, ip := range IPs {
|
|
|
+ if ip.To4() != nil {
|
|
|
+ IP = ip
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if err == nil && IP == nil {
|
|
|
+ err = errors.New("no IP address")
|
|
|
+ }
|
|
|
+ lookupResultChannel <- &lookupIPResult{IP, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ var lookupResult *lookupIPResult
|
|
|
+ select {
|
|
|
+ case lookupResult = <-lookupResultChannel:
|
|
|
+ case <-sshClient.stopBroadcast:
|
|
|
+ // Note: may leave LookupIP in progress
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if lookupResult.err != nil {
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", lookupResult.err))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !isWebServerPortForward &&
|
|
|
+ !sshClient.isPortForwardPermitted(
|
|
|
+ portForwardTypeTCP,
|
|
|
+ lookupResult.IP,
|
|
|
+ portToConnect) {
|
|
|
|
|
|
sshClient.rejectNewChannel(
|
|
|
newChannel, ssh.Prohibited, "port forward not permitted")
|
|
|
@@ -1239,46 +1385,47 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// Dial the target remote address. This is done in a goroutine to
|
|
|
// ensure the shutdown signal is handled immediately.
|
|
|
|
|
|
- remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
+ remoteAddr := net.JoinHostPort(lookupResult.IP.String(), strconv.Itoa(portToConnect))
|
|
|
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
|
|
|
- type dialTcpResult struct {
|
|
|
+ type dialTCPResult struct {
|
|
|
conn net.Conn
|
|
|
err error
|
|
|
}
|
|
|
+ dialResultChannel := make(chan *dialTCPResult, 1)
|
|
|
|
|
|
- resultChannel := make(chan *dialTcpResult, 1)
|
|
|
dialStartTime := monotime.Now()
|
|
|
|
|
|
go func() {
|
|
|
// TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
- // TODO: IPv6 support
|
|
|
conn, err := net.DialTimeout(
|
|
|
- "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
- resultChannel <- &dialTcpResult{conn, err}
|
|
|
+ "tcp", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
+ dialResultChannel <- &dialTCPResult{conn, err}
|
|
|
}()
|
|
|
|
|
|
- var result *dialTcpResult
|
|
|
+ var dialResult *dialTCPResult
|
|
|
select {
|
|
|
- case result = <-resultChannel:
|
|
|
+ case dialResult = <-dialResultChannel:
|
|
|
case <-sshClient.stopBroadcast:
|
|
|
- // Note: may leave dial in progress (TODO: use DialContext to cancel)
|
|
|
+ // Note: may leave Dial in progress
|
|
|
+ // TODO: use net.Dialer.DialContext to be able to cancel
|
|
|
return
|
|
|
}
|
|
|
|
|
|
sshClient.updateQualityMetrics(
|
|
|
- result.err == nil, monotime.Since(dialStartTime))
|
|
|
+ dialResult.err == nil, monotime.Since(dialStartTime))
|
|
|
|
|
|
- if result.err != nil {
|
|
|
- sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
|
|
|
+ if dialResult.err != nil {
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", dialResult.err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// The upstream TCP port forward connection has been established. Schedule
|
|
|
// some cleanup and notify the SSH client that the channel is accepted.
|
|
|
|
|
|
- fwdConn := result.conn
|
|
|
+ fwdConn := dialResult.conn
|
|
|
defer fwdConn.Close()
|
|
|
|
|
|
fwdChannel, requests, err := newChannel.Accept()
|
|
|
@@ -1297,12 +1444,20 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
defer lruEntry.Remove()
|
|
|
|
|
|
+ // Ensure nil interface if newClientSeedPortForward returns nil
|
|
|
+ var updater common.ActivityUpdater
|
|
|
+ seedUpdater := sshClient.newClientSeedPortForward(lookupResult.IP)
|
|
|
+ if seedUpdater != nil {
|
|
|
+ updater = seedUpdater
|
|
|
+ }
|
|
|
+
|
|
|
fwdConn, err = common.NewActivityMonitoredConn(
|
|
|
fwdConn,
|
|
|
sshClient.idleTCPPortForwardTimeout(),
|
|
|
true,
|
|
|
+ updater,
|
|
|
lruEntry)
|
|
|
- if result.err != nil {
|
|
|
+ if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
|
|
|
return
|
|
|
}
|