/*
* Copyright (c) 2016, Psiphon Inc.
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*
*/
package server
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/json"
std_errors "errors"
"fmt"
"io"
"io/ioutil"
"net"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/Psiphon-Labs/goarista/monotime"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/marionette"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tapdance"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
"github.com/marusama/semaphore"
cache "github.com/patrickmn/go-cache"
)
const (
SSH_AUTH_LOG_PERIOD = 30 * time.Minute
SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
SSH_BEGIN_HANDSHAKE_TIMEOUT = 1 * time.Second
SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
SSH_TCP_PORT_FORWARD_QUEUE_SIZE = 1024
SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES = 0
SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES = 256
SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
SSH_SEND_OSL_RETRY_FACTOR = 2
OSL_SESSION_CACHE_TTL = 5 * time.Minute
MAX_AUTHORIZATIONS = 16
PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT = 1
RANDOM_STREAM_MAX_BYTES = 10485760
ALERT_REQUEST_QUEUE_BUFFER_SIZE = 16
)
// TunnelServer is the main server that accepts Psiphon client
// connections, via various obfuscation protocols, and provides
// port forwarding (TCP and UDP) services to the Psiphon client.
// At its core, TunnelServer is an SSH server. SSH is the base
// protocol that provides port forward multiplexing, and transport
// security. Layered on top of SSH, optionally, is Obfuscated SSH
// and meek protocols, which provide further circumvention
// capabilities.
type TunnelServer struct {
runWaitGroup *sync.WaitGroup
listenerError chan error
shutdownBroadcast <-chan struct{}
sshServer *sshServer
}
type sshListener struct {
net.Listener
localAddress string
tunnelProtocol string
port int
BPFProgramName string
}
// NewTunnelServer initializes a new tunnel server.
func NewTunnelServer(
support *SupportServices,
shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
sshServer, err := newSSHServer(support, shutdownBroadcast)
if err != nil {
return nil, errors.Trace(err)
}
return &TunnelServer{
runWaitGroup: new(sync.WaitGroup),
listenerError: make(chan error),
shutdownBroadcast: shutdownBroadcast,
sshServer: sshServer,
}, nil
}
// Run runs the tunnel server; this function blocks while running a selection of
// listeners that handle connection using various obfuscation protocols.
//
// Run listens on each designated tunnel port and spawns new goroutines to handle
// each client connection. It halts when shutdownBroadcast is signaled. A list of active
// clients is maintained, and when halting all clients are cleanly shutdown.
//
// Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
// authentication, and then looping on client new channel requests. "direct-tcpip"
// channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
// config parameter is configured, UDP port forwards over a TCP stream, following
// the udpgw protocol, are handled.
//
// A new goroutine is spawned to handle each port forward for each client. Each port
// forward tracks its bytes transferred. Overall per-client stats for connection duration,
// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
// client shuts down.
//
// Note: client handler goroutines may still be shutting down after Run() returns. See
// comment in sshClient.stop(). TODO: fully synchronized shutdown.
func (server *TunnelServer) Run() error {
// TODO: should TunnelServer hold its own support pointer?
support := server.sshServer.support
// First bind all listeners; once all are successful,
// start accepting connections on each.
var listeners []*sshListener
for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
localAddress := fmt.Sprintf(
"%s:%d", support.Config.ServerIPAddress, listenPort)
var listener net.Listener
var BPFProgramName string
var err error
if protocol.TunnelProtocolUsesFrontedMeekQUIC(tunnelProtocol) {
// For FRONTED-MEEK-QUIC-OSSH, no listener implemented. The edge-to-server
// hop uses HTTPS and the client tunnel protocol is distinguished using
// protocol.MeekCookieData.ClientTunnelProtocol.
continue
} else if protocol.TunnelProtocolUsesQUIC(tunnelProtocol) {
listener, err = quic.Listen(
CommonLogger(log),
localAddress,
support.Config.ObfuscatedSSHKey)
} else if protocol.TunnelProtocolUsesMarionette(tunnelProtocol) {
listener, err = marionette.Listen(
support.Config.ServerIPAddress,
support.Config.MarionetteFormat)
} else {
listener, BPFProgramName, err = newTCPListenerWithBPF(support, localAddress)
if protocol.TunnelProtocolUsesTapdance(tunnelProtocol) {
listener, err = tapdance.Listen(listener)
}
}
if err != nil {
for _, existingListener := range listeners {
existingListener.Listener.Close()
}
return errors.Trace(err)
}
tacticsListener := tactics.NewListener(
listener,
support.TacticsServer,
tunnelProtocol,
func(IPAddress string) common.GeoIPData {
return common.GeoIPData(support.GeoIPService.Lookup(IPAddress))
})
log.WithTraceFields(
LogFields{
"localAddress": localAddress,
"tunnelProtocol": tunnelProtocol,
"BPFProgramName": BPFProgramName,
}).Info("listening")
listeners = append(
listeners,
&sshListener{
Listener: tacticsListener,
localAddress: localAddress,
port: listenPort,
tunnelProtocol: tunnelProtocol,
BPFProgramName: BPFProgramName,
})
}
for _, listener := range listeners {
server.runWaitGroup.Add(1)
go func(listener *sshListener) {
defer server.runWaitGroup.Done()
log.WithTraceFields(
LogFields{
"localAddress": listener.localAddress,
"tunnelProtocol": listener.tunnelProtocol,
}).Info("running")
server.sshServer.runListener(
listener,
server.listenerError)
log.WithTraceFields(
LogFields{
"localAddress": listener.localAddress,
"tunnelProtocol": listener.tunnelProtocol,
}).Info("stopped")
}(listener)
}
var err error
select {
case <-server.shutdownBroadcast:
case err = <-server.listenerError:
}
for _, listener := range listeners {
listener.Close()
}
server.sshServer.stopClients()
server.runWaitGroup.Wait()
log.WithTrace().Info("stopped")
return err
}
// GetLoadStats returns load stats for the tunnel server. The stats are
// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
// include current connected client count, total number of current port
// forwards.
func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
return server.sshServer.getLoadStats()
}
// GetEstablishedClientCount returns the number of currently established
// clients.
func (server *TunnelServer) GetEstablishedClientCount() int {
return server.sshServer.getEstablishedClientCount()
}
// ResetAllClientTrafficRules resets all established client traffic rules
// to use the latest config and client properties. Any existing traffic
// rule state is lost, including throttling state.
func (server *TunnelServer) ResetAllClientTrafficRules() {
server.sshServer.resetAllClientTrafficRules()
}
// ResetAllClientOSLConfigs resets all established client OSL state to use
// the latest OSL config. Any existing OSL state is lost, including partial
// progress towards SLOKs.
func (server *TunnelServer) ResetAllClientOSLConfigs() {
server.sshServer.resetAllClientOSLConfigs()
}
// SetClientHandshakeState sets the handshake state -- that it completed and
// what parameters were passed -- in sshClient. This state is used for allowing
// port forwards and for future traffic rule selection. SetClientHandshakeState
// also triggers an immediate traffic rule re-selection, as the rules selected
// upon tunnel establishment may no longer apply now that handshake values are
// set.
//
// The authorizations received from the client handshake are verified and the
// resulting list of authorized access types are applied to the client's tunnel
// and traffic rules.
//
// A list of active authorization IDs, authorized access types, and traffic
// rate limits are returned for responding to the client and logging.
func (server *TunnelServer) SetClientHandshakeState(
sessionID string,
state handshakeState,
authorizations []string) (*handshakeStateInfo, error) {
return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
}
// GetClientHandshaked indicates whether the client has completed a handshake
// and whether its traffic rules are immediately exhausted.
func (server *TunnelServer) GetClientHandshaked(
sessionID string) (bool, bool, error) {
return server.sshServer.getClientHandshaked(sessionID)
}
// UpdateClientAPIParameters updates the recorded handshake API parameters for
// the client corresponding to sessionID.
func (server *TunnelServer) UpdateClientAPIParameters(
sessionID string,
apiParams common.APIParameters) error {
return server.sshServer.updateClientAPIParameters(sessionID, apiParams)
}
// ExpectClientDomainBytes indicates whether the client was configured to report
// domain bytes in its handshake response.
func (server *TunnelServer) ExpectClientDomainBytes(
sessionID string) (bool, error) {
return server.sshServer.expectClientDomainBytes(sessionID)
}
// SetEstablishTunnels sets whether new tunnels may be established or not.
// When not establishing, incoming connections are immediately closed.
func (server *TunnelServer) SetEstablishTunnels(establish bool) {
server.sshServer.setEstablishTunnels(establish)
}
// CheckEstablishTunnels returns whether new tunnels may be established or
// not, and increments a metrics counter when establishment is disallowed.
func (server *TunnelServer) CheckEstablishTunnels() bool {
return server.sshServer.checkEstablishTunnels()
}
// GetEstablishTunnelsMetrics returns whether tunnel establishment is
// currently allowed and the number of tunnels rejected since due to not
// establishing since the last GetEstablishTunnelsMetrics call.
func (server *TunnelServer) GetEstablishTunnelsMetrics() (bool, int64) {
return server.sshServer.getEstablishTunnelsMetrics()
}
type sshServer struct {
// Note: 64-bit ints used with atomic operations are placed
// at the start of struct to ensure 64-bit alignment.
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
lastAuthLog int64
authFailedCount int64
establishLimitedCount int64
support *SupportServices
establishTunnels int32
concurrentSSHHandshakes semaphore.Semaphore
shutdownBroadcast <-chan struct{}
sshHostKey ssh.Signer
clientsMutex sync.Mutex
stoppingClients bool
acceptedClientCounts map[string]map[string]int64
clients map[string]*sshClient
oslSessionCacheMutex sync.Mutex
oslSessionCache *cache.Cache
authorizationSessionIDsMutex sync.Mutex
authorizationSessionIDs map[string]string
obfuscatorSeedHistory *obfuscator.SeedHistory
}
func newSSHServer(
support *SupportServices,
shutdownBroadcast <-chan struct{}) (*sshServer, error) {
privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
if err != nil {
return nil, errors.Trace(err)
}
// TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
signer, err := ssh.NewSignerFromKey(privateKey)
if err != nil {
return nil, errors.Trace(err)
}
var concurrentSSHHandshakes semaphore.Semaphore
if support.Config.MaxConcurrentSSHHandshakes > 0 {
concurrentSSHHandshakes = semaphore.New(support.Config.MaxConcurrentSSHHandshakes)
}
// The OSL session cache temporarily retains OSL seed state
// progress for disconnected clients. This enables clients
// that disconnect and immediately reconnect to the same
// server to resume their OSL progress. Cached progress
// is referenced by session ID and is retained for
// OSL_SESSION_CACHE_TTL after disconnect.
//
// Note: session IDs are assumed to be unpredictable. If a
// rogue client could guess the session ID of another client,
// it could resume its OSL progress and, if the OSL config
// were known, infer some activity.
oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
return &sshServer{
support: support,
establishTunnels: 1,
concurrentSSHHandshakes: concurrentSSHHandshakes,
shutdownBroadcast: shutdownBroadcast,
sshHostKey: signer,
acceptedClientCounts: make(map[string]map[string]int64),
clients: make(map[string]*sshClient),
oslSessionCache: oslSessionCache,
authorizationSessionIDs: make(map[string]string),
obfuscatorSeedHistory: obfuscator.NewSeedHistory(nil),
}, nil
}
func (sshServer *sshServer) setEstablishTunnels(establish bool) {
// Do nothing when the setting is already correct. This avoids
// spurious log messages when setEstablishTunnels is called
// periodically with the same setting.
if establish == (atomic.LoadInt32(&sshServer.establishTunnels) == 1) {
return
}
establishFlag := int32(1)
if !establish {
establishFlag = 0
}
atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
log.WithTraceFields(
LogFields{"establish": establish}).Info("establishing tunnels")
}
func (sshServer *sshServer) checkEstablishTunnels() bool {
establishTunnels := atomic.LoadInt32(&sshServer.establishTunnels) == 1
if !establishTunnels {
atomic.AddInt64(&sshServer.establishLimitedCount, 1)
}
return establishTunnels
}
func (sshServer *sshServer) getEstablishTunnelsMetrics() (bool, int64) {
return atomic.LoadInt32(&sshServer.establishTunnels) == 1,
atomic.SwapInt64(&sshServer.establishLimitedCount, 0)
}
// runListener is intended to run an a goroutine; it blocks
// running a particular listener. If an unrecoverable error
// occurs, it will send the error to the listenerError channel.
func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError chan<- error) {
runningProtocols := make([]string, 0)
for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
runningProtocols = append(runningProtocols, tunnelProtocol)
}
handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
// Note: establish tunnel limiter cannot simply stop TCP
// listeners in all cases (e.g., meek) since SSH tunnels can
// span multiple TCP connections.
if !sshServer.checkEstablishTunnels() {
log.WithTrace().Debug("not establishing tunnels")
clientConn.Close()
return
}
// The tunnelProtocol passed to handleClient is used for stats,
// throttling, etc. When the tunnel protocol can be determined
// unambiguously from the listening port, use that protocol and
// don't use any client-declared value. Only use the client's
// value, if present, in special cases where the listening port
// cannot distinguish the protocol.
tunnelProtocol := sshListener.tunnelProtocol
if clientTunnelProtocol != "" {
if !common.Contains(runningProtocols, clientTunnelProtocol) {
log.WithTraceFields(
LogFields{
"clientTunnelProtocol": clientTunnelProtocol}).
Warning("invalid client tunnel protocol")
clientConn.Close()
return
}
if protocol.UseClientTunnelProtocol(
clientTunnelProtocol, runningProtocols) {
tunnelProtocol = clientTunnelProtocol
}
}
// sshListener.tunnelProtocol indictes the tunnel protocol run by the
// listener. For direct protocols, this is also the client tunnel protocol.
// For fronted protocols, the client may use a different protocol to connect
// to the front and then only the front-to-Psiphon server will use the
// listener protocol.
//
// A fronted meek client, for example, reports its first hop protocol in
// protocol.MeekCookieData.ClientTunnelProtocol. Most metrics record this
// value as relay_protocol, since the first hop is the one subject to
// adversarial conditions. In some cases, such as irregular tunnels, there
// is no ClientTunnelProtocol value available and the listener tunnel
// protocol will be logged.
//
// Similarly, listenerPort indicates the listening port, which is the dialed
// port number for direct protocols; while, for fronted protocols, the
// client may dial a different port for its first hop.
// Process each client connection concurrently.
go sshServer.handleClient(sshListener, tunnelProtocol, clientConn)
}
// Note: when exiting due to a unrecoverable error, be sure
// to try to send the error to listenerError so that the outer
// TunnelServer.Run will properly shut down instead of remaining
// running.
if protocol.TunnelProtocolUsesMeekHTTP(sshListener.tunnelProtocol) ||
protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol) {
meekServer, err := NewMeekServer(
sshServer.support,
sshListener.Listener,
sshListener.tunnelProtocol,
sshListener.port,
protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol),
protocol.TunnelProtocolUsesFrontedMeek(sshListener.tunnelProtocol),
protocol.TunnelProtocolUsesObfuscatedSessionTickets(sshListener.tunnelProtocol),
handleClient,
sshServer.shutdownBroadcast)
if err == nil {
err = meekServer.Run()
}
if err != nil {
select {
case listenerError <- errors.Trace(err):
default:
}
return
}
} else {
for {
conn, err := sshListener.Listener.Accept()
select {
case <-sshServer.shutdownBroadcast:
if err == nil {
conn.Close()
}
return
default:
}
if err != nil {
if e, ok := err.(net.Error); ok && e.Temporary() {
log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
// Temporary error, keep running
continue
}
select {
case listenerError <- errors.Trace(err):
default:
}
return
}
handleClient("", conn)
}
}
}
// An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
// is for tracking the number of connections.
func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
}
sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
}
func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
}
// An established client has completed its SSH handshake and has a ssh.Conn. Registration is
// for tracking the number of fully established clients and for maintaining a list of running
// clients (for stopping at shutdown time).
func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
sshServer.clientsMutex.Lock()
if sshServer.stoppingClients {
sshServer.clientsMutex.Unlock()
return false
}
// In the case of a duplicate client sessionID, the previous client is closed.
// - Well-behaved clients generate a random sessionID that should be unique (won't
// accidentally conflict) and hard to guess (can't be targeted by a malicious
// client).
// - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
// and reestablished. In this case, when the same server is selected, this logic
// will be hit; closing the old, dangling client is desirable.
// - Multi-tunnel clients should not normally use one server for multiple tunnels.
existingClient := sshServer.clients[client.sessionID]
sshServer.clientsMutex.Unlock()
if existingClient != nil {
// This case is expected to be common, and so logged at the lowest severity
// level.
log.WithTrace().Debug(
"stopping existing client with duplicate session ID")
existingClient.stop()
// Block until the existingClient is fully terminated. This is necessary to
// avoid this scenario:
// - existingClient is invoking handshakeAPIRequestHandler
// - sshServer.clients[client.sessionID] is updated to point to new client
// - existingClient's handshakeAPIRequestHandler invokes
// SetClientHandshakeState but sets the handshake parameters for new
// client
// - as a result, the new client handshake will fail (only a single handshake
// is permitted) and the new client server_tunnel log will contain an
// invalid mix of existing/new client fields
//
// Once existingClient.awaitStopped returns, all existingClient port
// forwards and request handlers have terminated, so no API handler, either
// tunneled web API or SSH API, will remain and it is safe to point
// sshServer.clients[client.sessionID] to the new client.
// Limitation: this scenario remains possible with _untunneled_ web API
// requests.
//
// Blocking also ensures existingClient.releaseAuthorizations is invoked before
// the new client attempts to submit the same authorizations.
//
// Perform blocking awaitStopped operation outside the
// sshServer.clientsMutex mutex to avoid blocking all other clients for the
// duration. We still expect and require that the stop process completes
// rapidly, e.g., does not block on network I/O, allowing the new client
// connection to proceed without delay.
//
// In addition, operations triggered by stop, and which must complete before
// awaitStopped returns, will attempt to lock sshServer.clientsMutex,
// including unregisterEstablishedClient.
existingClient.awaitStopped()
}
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
// existingClient's stop will have removed it from sshServer.clients via
// unregisterEstablishedClient, so sshServer.clients[client.sessionID] should
// be nil -- unless yet another client instance using the same sessionID has
// connected in the meantime while awaiting existingClient stop. In this
// case, it's not clear which is the most recent connection from the client,
// so instead of this connection terminating more peers, it aborts.
if sshServer.clients[client.sessionID] != nil {
// As this is expected to be rare case, it's logged at a higher severity
// level.
log.WithTrace().Warning(
"aborting new client with duplicate session ID")
return false
}
sshServer.clients[client.sessionID] = client
return true
}
func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
sshServer.clientsMutex.Lock()
registeredClient := sshServer.clients[client.sessionID]
// registeredClient will differ from client when client is the existingClient
// terminated in registerEstablishedClient. In that case, registeredClient
// remains connected, and the sshServer.clients entry should be retained.
if registeredClient == client {
delete(sshServer.clients, client.sessionID)
}
sshServer.clientsMutex.Unlock()
client.stop()
}
type ProtocolStats map[string]map[string]int64
type RegionStats map[string]map[string]map[string]int64
func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
// Explicitly populate with zeros to ensure 0 counts in log messages
zeroStats := func() map[string]int64 {
stats := make(map[string]int64)
stats["accepted_clients"] = 0
stats["established_clients"] = 0
stats["dialing_tcp_port_forwards"] = 0
stats["tcp_port_forwards"] = 0
stats["total_tcp_port_forwards"] = 0
stats["udp_port_forwards"] = 0
stats["total_udp_port_forwards"] = 0
stats["tcp_port_forward_dialed_count"] = 0
stats["tcp_port_forward_dialed_duration"] = 0
stats["tcp_port_forward_failed_count"] = 0
stats["tcp_port_forward_failed_duration"] = 0
stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
stats["tcp_port_forward_rejected_disallowed_count"] = 0
stats["udp_port_forward_rejected_disallowed_count"] = 0
stats["tcp_ipv4_port_forward_dialed_count"] = 0
stats["tcp_ipv4_port_forward_dialed_duration"] = 0
stats["tcp_ipv4_port_forward_failed_count"] = 0
stats["tcp_ipv4_port_forward_failed_duration"] = 0
stats["tcp_ipv6_port_forward_dialed_count"] = 0
stats["tcp_ipv6_port_forward_dialed_duration"] = 0
stats["tcp_ipv6_port_forward_failed_count"] = 0
stats["tcp_ipv6_port_forward_failed_duration"] = 0
return stats
}
zeroProtocolStats := func() map[string]map[string]int64 {
stats := make(map[string]map[string]int64)
stats["ALL"] = zeroStats()
for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
stats[tunnelProtocol] = zeroStats()
}
return stats
}
// [][] -> count
protocolStats := zeroProtocolStats()
// [][] -> count
regionStats := make(RegionStats)
// Note: as currently tracked/counted, each established client is also an accepted client
for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
for region, acceptedClientCount := range regionAcceptedClientCounts {
if acceptedClientCount > 0 {
if regionStats[region] == nil {
regionStats[region] = zeroProtocolStats()
}
protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
}
}
}
for _, client := range sshServer.clients {
client.Lock()
tunnelProtocol := client.tunnelProtocol
region := client.geoIPData.Country
if regionStats[region] == nil {
regionStats[region] = zeroProtocolStats()
}
stats := []map[string]int64{
protocolStats["ALL"],
protocolStats[tunnelProtocol],
regionStats[region]["ALL"],
regionStats[region][tunnelProtocol]}
for _, stat := range stats {
stat["established_clients"] += 1
// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
// client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.TCPPortForwardDialedCount
stat["tcp_port_forward_dialed_duration"] +=
int64(client.qualityMetrics.TCPPortForwardDialedDuration / time.Millisecond)
stat["tcp_port_forward_failed_count"] += client.qualityMetrics.TCPPortForwardFailedCount
stat["tcp_port_forward_failed_duration"] +=
int64(client.qualityMetrics.TCPPortForwardFailedDuration / time.Millisecond)
stat["tcp_port_forward_rejected_dialing_limit_count"] +=
client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount
stat["tcp_port_forward_rejected_disallowed_count"] +=
client.qualityMetrics.TCPPortForwardRejectedDisallowedCount
stat["udp_port_forward_rejected_disallowed_count"] +=
client.qualityMetrics.UDPPortForwardRejectedDisallowedCount
stat["tcp_ipv4_port_forward_dialed_count"] += client.qualityMetrics.TCPIPv4PortForwardDialedCount
stat["tcp_ipv4_port_forward_dialed_duration"] +=
int64(client.qualityMetrics.TCPIPv4PortForwardDialedDuration / time.Millisecond)
stat["tcp_ipv4_port_forward_failed_count"] += client.qualityMetrics.TCPIPv4PortForwardFailedCount
stat["tcp_ipv4_port_forward_failed_duration"] +=
int64(client.qualityMetrics.TCPIPv4PortForwardFailedDuration / time.Millisecond)
stat["tcp_ipv6_port_forward_dialed_count"] += client.qualityMetrics.TCPIPv6PortForwardDialedCount
stat["tcp_ipv6_port_forward_dialed_duration"] +=
int64(client.qualityMetrics.TCPIPv6PortForwardDialedDuration / time.Millisecond)
stat["tcp_ipv6_port_forward_failed_count"] += client.qualityMetrics.TCPIPv6PortForwardFailedCount
stat["tcp_ipv6_port_forward_failed_duration"] +=
int64(client.qualityMetrics.TCPIPv6PortForwardFailedDuration / time.Millisecond)
}
client.qualityMetrics.TCPPortForwardDialedCount = 0
client.qualityMetrics.TCPPortForwardDialedDuration = 0
client.qualityMetrics.TCPPortForwardFailedCount = 0
client.qualityMetrics.TCPPortForwardFailedDuration = 0
client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount = 0
client.qualityMetrics.TCPPortForwardRejectedDisallowedCount = 0
client.qualityMetrics.UDPPortForwardRejectedDisallowedCount = 0
client.qualityMetrics.TCPIPv4PortForwardDialedCount = 0
client.qualityMetrics.TCPIPv4PortForwardDialedDuration = 0
client.qualityMetrics.TCPIPv4PortForwardFailedCount = 0
client.qualityMetrics.TCPIPv4PortForwardFailedDuration = 0
client.qualityMetrics.TCPIPv6PortForwardDialedCount = 0
client.qualityMetrics.TCPIPv6PortForwardDialedDuration = 0
client.qualityMetrics.TCPIPv6PortForwardFailedCount = 0
client.qualityMetrics.TCPIPv6PortForwardFailedDuration = 0
client.Unlock()
}
return protocolStats, regionStats
}
func (sshServer *sshServer) getEstablishedClientCount() int {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
establishedClients := len(sshServer.clients)
return establishedClients
}
func (sshServer *sshServer) resetAllClientTrafficRules() {
sshServer.clientsMutex.Lock()
clients := make(map[string]*sshClient)
for sessionID, client := range sshServer.clients {
clients[sessionID] = client
}
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.setTrafficRules()
}
}
func (sshServer *sshServer) resetAllClientOSLConfigs() {
// Flush cached seed state. This has the same effect
// and same limitations as calling setOSLConfig for
// currently connected clients -- all progress is lost.
sshServer.oslSessionCacheMutex.Lock()
sshServer.oslSessionCache.Flush()
sshServer.oslSessionCacheMutex.Unlock()
sshServer.clientsMutex.Lock()
clients := make(map[string]*sshClient)
for sessionID, client := range sshServer.clients {
clients[sessionID] = client
}
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.setOSLConfig()
}
}
func (sshServer *sshServer) setClientHandshakeState(
sessionID string,
state handshakeState,
authorizations []string) (*handshakeStateInfo, error) {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return nil, errors.TraceNew("unknown session ID")
}
handshakeStateInfo, err := client.setHandshakeState(
state, authorizations)
if err != nil {
return nil, errors.Trace(err)
}
return handshakeStateInfo, nil
}
func (sshServer *sshServer) getClientHandshaked(
sessionID string) (bool, bool, error) {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return false, false, errors.TraceNew("unknown session ID")
}
completed, exhausted := client.getHandshaked()
return completed, exhausted, nil
}
func (sshServer *sshServer) updateClientAPIParameters(
sessionID string,
apiParams common.APIParameters) error {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return errors.TraceNew("unknown session ID")
}
client.updateAPIParameters(apiParams)
return nil
}
func (sshServer *sshServer) revokeClientAuthorizations(sessionID string) {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return
}
// sshClient.handshakeState.authorizedAccessTypes is not cleared. Clearing
// authorizedAccessTypes may cause sshClient.logTunnel to fail to log
// access types. As the revocation may be due to legitimate use of an
// authorization in multiple sessions by a single client, useful metrics
// would be lost.
client.Lock()
client.handshakeState.authorizationsRevoked = true
client.Unlock()
// Select and apply new traffic rules, as filtered by the client's new
// authorization state.
client.setTrafficRules()
}
func (sshServer *sshServer) expectClientDomainBytes(
sessionID string) (bool, error) {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return false, errors.TraceNew("unknown session ID")
}
return client.expectDomainBytes(), nil
}
func (sshServer *sshServer) stopClients() {
sshServer.clientsMutex.Lock()
sshServer.stoppingClients = true
clients := sshServer.clients
sshServer.clients = make(map[string]*sshClient)
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.stop()
}
}
func (sshServer *sshServer) handleClient(
sshListener *sshListener, tunnelProtocol string, clientConn net.Conn) {
// Calling clientConn.RemoteAddr at this point, before any Read calls,
// satisfies the constraint documented in tapdance.Listen.
clientAddr := clientConn.RemoteAddr()
// Check if there were irregularities during the network connection
// establishment. When present, log and then behave as Obfuscated SSH does
// when the client fails to provide a valid seed message.
//
// One concrete irregular case is failure to send a PROXY protocol header for
// TAPDANCE-OSSH.
if indicator, ok := clientConn.(common.IrregularIndicator); ok {
tunnelErr := indicator.IrregularTunnelError()
if tunnelErr != nil {
logIrregularTunnel(
sshServer.support,
sshListener.tunnelProtocol,
sshListener.port,
common.IPAddressFromAddr(clientAddr),
errors.Trace(tunnelErr),
nil)
var afterFunc *time.Timer
if sshServer.support.Config.sshHandshakeTimeout > 0 {
afterFunc = time.AfterFunc(sshServer.support.Config.sshHandshakeTimeout, func() {
clientConn.Close()
})
}
io.Copy(ioutil.Discard, clientConn)
clientConn.Close()
afterFunc.Stop()
return
}
}
serverPacketManipulation := ""
if sshServer.support.Config.RunPacketManipulator &&
protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) {
// A meekConn has synthetic address values, including the original client
// address in cases where the client uses an upstream proxy to connect to
// Psiphon. For meekConn, and any other conn implementing
// UnderlyingTCPAddrSource, get the underlying TCP connection addresses.
//
// Limitation: a meek tunnel may consist of several TCP connections. The
// server_packet_manipulation metric will reflect the packet manipulation
// applied to the _first_ TCP connection only.
var localAddr, remoteAddr *net.TCPAddr
var ok bool
underlying, ok := clientConn.(UnderlyingTCPAddrSource)
if ok {
localAddr, remoteAddr, ok = underlying.GetUnderlyingTCPAddrs()
} else {
localAddr, ok = clientConn.LocalAddr().(*net.TCPAddr)
if ok {
remoteAddr, ok = clientConn.RemoteAddr().(*net.TCPAddr)
}
}
if ok {
specName, err := sshServer.support.PacketManipulator.
GetAppliedSpecName(localAddr, remoteAddr)
if err == nil {
serverPacketManipulation = specName
}
}
}
geoIPData := sshServer.support.GeoIPService.Lookup(
common.IPAddressFromAddr(clientAddr))
sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
// When configured, enforce a cap on the number of concurrent SSH
// handshakes. This limits load spikes on busy servers when many clients
// attempt to connect at once. Wait a short time, SSH_BEGIN_HANDSHAKE_TIMEOUT,
// to acquire; waiting will avoid immediately creating more load on another
// server in the network when the client tries a new candidate. Disconnect the
// client when that wait time is exceeded.
//
// This mechanism limits memory allocations and CPU usage associated with the
// SSH handshake. At this point, new direct TCP connections or new meek
// connections, with associated resource usage, are already established. Those
// connections are expected to be rate or load limited using other mechanisms.
//
// TODO:
//
// - deduct time spent acquiring the semaphore from SSH_HANDSHAKE_TIMEOUT in
// sshClient.run, since the client is also applying an SSH handshake timeout
// and won't exclude time spent waiting.
// - each call to sshServer.handleClient (in sshServer.runListener) is invoked
// in its own goroutine, but shutdown doesn't synchronously await these
// goroutnes. Once this is synchronizes, the following context.WithTimeout
// should use an sshServer parent context to ensure blocking acquires
// interrupt immediately upon shutdown.
var onSSHHandshakeFinished func()
if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 {
ctx, cancelFunc := context.WithTimeout(
context.Background(),
sshServer.support.Config.sshBeginHandshakeTimeout)
defer cancelFunc()
err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1)
if err != nil {
clientConn.Close()
// This is a debug log as the only possible error is context timeout.
log.WithTraceFields(LogFields{"error": err}).Debug(
"acquire SSH handshake semaphore failed")
return
}
onSSHHandshakeFinished = func() {
sshServer.concurrentSSHHandshakes.Release(1)
}
}
sshClient := newSshClient(
sshServer,
sshListener,
tunnelProtocol,
serverPacketManipulation,
geoIPData)
// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
// in any error case; or, as soon as the SSH handshake phase has successfully
// completed.
sshClient.run(clientConn, onSSHHandshakeFinished)
}
func (sshServer *sshServer) monitorPortForwardDialError(err error) {
// "err" is the error returned from a failed TCP or UDP port
// forward dial. Certain system error codes indicate low resource
// conditions: insufficient file descriptors, ephemeral ports, or
// memory. For these cases, log an alert.
// TODO: also temporarily suspend new clients
// Note: don't log net.OpError.Error() as the full error string
// may contain client destination addresses.
opErr, ok := err.(*net.OpError)
if ok {
if opErr.Err == syscall.EADDRNOTAVAIL ||
opErr.Err == syscall.EAGAIN ||
opErr.Err == syscall.ENOMEM ||
opErr.Err == syscall.EMFILE ||
opErr.Err == syscall.ENFILE {
log.WithTraceFields(
LogFields{"error": opErr.Err}).Error(
"port forward dial failed due to unavailable resource")
}
}
}
type sshClient struct {
sync.Mutex
sshServer *sshServer
sshListener *sshListener
tunnelProtocol string
sshConn ssh.Conn
activityConn *common.ActivityMonitoredConn
throttledConn *common.ThrottledConn
serverPacketManipulation string
geoIPData GeoIPData
sessionID string
isFirstTunnelInSession bool
supportsServerRequests bool
handshakeState handshakeState
udpChannel ssh.Channel
packetTunnelChannel ssh.Channel
trafficRules TrafficRules
tcpTrafficState trafficState
udpTrafficState trafficState
qualityMetrics qualityMetrics
tcpPortForwardLRU *common.LRUConns
oslClientSeedState *osl.ClientSeedState
signalIssueSLOKs chan struct{}
runCtx context.Context
stopRunning context.CancelFunc
stopped chan struct{}
tcpPortForwardDialingAvailableSignal context.CancelFunc
releaseAuthorizations func()
stopTimer *time.Timer
preHandshakeRandomStreamMetrics randomStreamMetrics
postHandshakeRandomStreamMetrics randomStreamMetrics
sendAlertRequests chan protocol.AlertRequest
sentAlertRequests map[protocol.AlertRequest]bool
}
type trafficState struct {
bytesUp int64
bytesDown int64
concurrentDialingPortForwardCount int64
peakConcurrentDialingPortForwardCount int64
concurrentPortForwardCount int64
peakConcurrentPortForwardCount int64
totalPortForwardCount int64
availablePortForwardCond *sync.Cond
}
type randomStreamMetrics struct {
count int
upstreamBytes int
receivedUpstreamBytes int
downstreamBytes int
sentDownstreamBytes int
}
// qualityMetrics records upstream TCP dial attempts and
// elapsed time. Elapsed time includes the full TCP handshake
// and, in aggregate, is a measure of the quality of the
// upstream link. These stats are recorded by each sshClient
// and then reported and reset in sshServer.getLoadStats().
type qualityMetrics struct {
TCPPortForwardDialedCount int64
TCPPortForwardDialedDuration time.Duration
TCPPortForwardFailedCount int64
TCPPortForwardFailedDuration time.Duration
TCPPortForwardRejectedDialingLimitCount int64
TCPPortForwardRejectedDisallowedCount int64
UDPPortForwardRejectedDisallowedCount int64
TCPIPv4PortForwardDialedCount int64
TCPIPv4PortForwardDialedDuration time.Duration
TCPIPv4PortForwardFailedCount int64
TCPIPv4PortForwardFailedDuration time.Duration
TCPIPv6PortForwardDialedCount int64
TCPIPv6PortForwardDialedDuration time.Duration
TCPIPv6PortForwardFailedCount int64
TCPIPv6PortForwardFailedDuration time.Duration
}
type handshakeState struct {
completed bool
apiProtocol string
apiParams common.APIParameters
activeAuthorizationIDs []string
authorizedAccessTypes []string
authorizationsRevoked bool
expectDomainBytes bool
establishedTunnelsCount int
}
type handshakeStateInfo struct {
activeAuthorizationIDs []string
authorizedAccessTypes []string
upstreamBytesPerSecond int64
downstreamBytesPerSecond int64
}
func newSshClient(
sshServer *sshServer,
sshListener *sshListener,
tunnelProtocol string,
serverPacketManipulation string,
geoIPData GeoIPData) *sshClient {
runCtx, stopRunning := context.WithCancel(context.Background())
// isFirstTunnelInSession is defaulted to true so that the pre-handshake
// traffic rules won't apply UnthrottleFirstTunnelOnly and negate any
// unthrottled bytes during the initial protocol negotiation.
client := &sshClient{
sshServer: sshServer,
sshListener: sshListener,
tunnelProtocol: tunnelProtocol,
serverPacketManipulation: serverPacketManipulation,
geoIPData: geoIPData,
isFirstTunnelInSession: true,
tcpPortForwardLRU: common.NewLRUConns(),
signalIssueSLOKs: make(chan struct{}, 1),
runCtx: runCtx,
stopRunning: stopRunning,
stopped: make(chan struct{}),
sendAlertRequests: make(chan protocol.AlertRequest, ALERT_REQUEST_QUEUE_BUFFER_SIZE),
sentAlertRequests: make(map[protocol.AlertRequest]bool),
}
client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
return client
}
func (sshClient *sshClient) run(
baseConn net.Conn, onSSHHandshakeFinished func()) {
// When run returns, the client has fully stopped, with all SSH state torn
// down and no port forwards or API requests in progress.
defer close(sshClient.stopped)
// onSSHHandshakeFinished must be called even if the SSH handshake is aborted.
defer func() {
if onSSHHandshakeFinished != nil {
onSSHHandshakeFinished()
}
}()
// Set initial traffic rules, pre-handshake, based on currently known info.
sshClient.setTrafficRules()
conn := baseConn
// Wrap the base client connection with an ActivityMonitoredConn which will
// terminate the connection if no data is received before the deadline. This
// timeout is in effect for the entire duration of the SSH connection. Clients
// must actively use the connection or send SSH keep alive requests to keep
// the connection active. Writes are not considered reliable activity indicators
// due to buffering.
activityConn, err := common.NewActivityMonitoredConn(
conn,
SSH_CONNECTION_READ_DEADLINE,
false,
nil,
nil)
if err != nil {
conn.Close()
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
}
return
}
conn = activityConn
// Further wrap the connection in a rate limiting ThrottledConn.
throttledConn := common.NewThrottledConn(conn, sshClient.rateLimits())
conn = throttledConn
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
// respect shutdownBroadcast and implement a specific handshake timeout.
// The timeout is to reclaim network resources in case the handshake takes
// too long.
type sshNewServerConnResult struct {
obfuscatedSSHConn *obfuscator.ObfuscatedSSHConn
sshConn *ssh.ServerConn
channels <-chan ssh.NewChannel
requests <-chan *ssh.Request
err error
}
resultChannel := make(chan *sshNewServerConnResult, 2)
var afterFunc *time.Timer
if sshClient.sshServer.support.Config.sshHandshakeTimeout > 0 {
afterFunc = time.AfterFunc(sshClient.sshServer.support.Config.sshHandshakeTimeout, func() {
resultChannel <- &sshNewServerConnResult{err: std_errors.New("ssh handshake timeout")}
})
}
go func(baseConn, conn net.Conn) {
sshServerConfig := &ssh.ServerConfig{
PasswordCallback: sshClient.passwordCallback,
AuthLogCallback: sshClient.authLogCallback,
ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion,
}
sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
var err error
if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
// With Encrypt-then-MAC hash algorithms, packet length is
// transmitted in plaintext, which aids in traffic analysis;
// clients may still send Encrypt-then-MAC algorithms in their
// KEX_INIT message, but do not select these algorithms.
//
// The exception is TUNNEL_PROTOCOL_SSH, which is intended to appear
// like SSH on the wire.
sshServerConfig.NoEncryptThenMACHash = true
} else {
// For TUNNEL_PROTOCOL_SSH only, randomize KEX.
if sshClient.sshServer.support.Config.ObfuscatedSSHKey != "" {
sshServerConfig.KEXPRNGSeed, err = protocol.DeriveSSHServerKEXPRNGSeed(
sshClient.sshServer.support.Config.ObfuscatedSSHKey)
if err != nil {
err = errors.Trace(err)
}
}
}
result := &sshNewServerConnResult{}
// Wrap the connection in an SSH deobfuscator when required.
if err == nil && protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
// Note: NewServerObfuscatedSSHConn blocks on network I/O
// TODO: ensure this won't block shutdown
result.obfuscatedSSHConn, err = obfuscator.NewServerObfuscatedSSHConn(
conn,
sshClient.sshServer.support.Config.ObfuscatedSSHKey,
sshClient.sshServer.obfuscatorSeedHistory,
func(clientIP string, err error, logFields common.LogFields) {
logIrregularTunnel(
sshClient.sshServer.support,
sshClient.sshListener.tunnelProtocol,
sshClient.sshListener.port,
clientIP,
errors.Trace(err),
LogFields(logFields))
})
if err != nil {
err = errors.Trace(err)
} else {
conn = result.obfuscatedSSHConn
}
// Now seed fragmentor, when present, with seed derived from
// initial obfuscator message. See tactics.Listener.Accept.
// This must preceed ssh.NewServerConn to ensure fragmentor
// is seeded before downstream bytes are written.
if err == nil && sshClient.tunnelProtocol == protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
if fragmentorConn, ok := baseConn.(*fragmentor.Conn); ok {
var fragmentorPRNG *prng.PRNG
fragmentorPRNG, err = result.obfuscatedSSHConn.GetDerivedPRNG("server-side-fragmentor")
if err != nil {
err = errors.Trace(err)
} else {
fragmentorConn.SetPRNG(fragmentorPRNG)
}
}
}
}
if err == nil {
result.sshConn, result.channels, result.requests, err =
ssh.NewServerConn(conn, sshServerConfig)
if err != nil {
err = errors.Trace(err)
}
}
result.err = err
resultChannel <- result
}(baseConn, conn)
var result *sshNewServerConnResult
select {
case result = <-resultChannel:
case <-sshClient.sshServer.shutdownBroadcast:
// Close() will interrupt an ongoing handshake
// TODO: wait for SSH handshake goroutines to exit before returning?
conn.Close()
return
}
if afterFunc != nil {
afterFunc.Stop()
}
if result.err != nil {
conn.Close()
// This is a Debug log due to noise. The handshake often fails due to I/O
// errors as clients frequently interrupt connections in progress when
// client-side load balancing completes a connection to a different server.
log.WithTraceFields(LogFields{"error": result.err}).Debug("handshake failed")
return
}
// The SSH handshake has finished successfully; notify now to allow other
// blocked SSH handshakes to proceed.
if onSSHHandshakeFinished != nil {
onSSHHandshakeFinished()
}
onSSHHandshakeFinished = nil
sshClient.Lock()
sshClient.sshConn = result.sshConn
sshClient.activityConn = activityConn
sshClient.throttledConn = throttledConn
sshClient.Unlock()
if !sshClient.sshServer.registerEstablishedClient(sshClient) {
conn.Close()
log.WithTrace().Warning("register failed")
return
}
sshClient.runTunnel(result.channels, result.requests)
// Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
// which also closes underlying transport Conn.
sshClient.sshServer.unregisterEstablishedClient(sshClient)
// Some conns report additional metrics. Meek conns report resiliency
// metrics and fragmentor.Conns report fragmentor configs.
//
// Limitation: for meek, GetMetrics from underlying fragmentor.Conn(s)
// should be called in order to log fragmentor metrics for meek sessions.
var additionalMetrics []LogFields
if metricsSource, ok := baseConn.(common.MetricsSource); ok {
additionalMetrics = append(
additionalMetrics, LogFields(metricsSource.GetMetrics()))
}
if result.obfuscatedSSHConn != nil {
additionalMetrics = append(
additionalMetrics, LogFields(result.obfuscatedSSHConn.GetMetrics()))
}
sshClient.logTunnel(additionalMetrics)
// Transfer OSL seed state -- the OSL progress -- from the closing
// client to the session cache so the client can resume its progress
// if it reconnects to this same server.
// Note: following setOSLConfig order of locking.
sshClient.Lock()
if sshClient.oslClientSeedState != nil {
sshClient.sshServer.oslSessionCacheMutex.Lock()
sshClient.oslClientSeedState.Hibernate()
sshClient.sshServer.oslSessionCache.Set(
sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
sshClient.sshServer.oslSessionCacheMutex.Unlock()
sshClient.oslClientSeedState = nil
}
sshClient.Unlock()
// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
// final status requests, the lifetime of cached GeoIP records exceeds the
// lifetime of the sshClient.
sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
}
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
var sshPasswordPayload protocol.SSHPasswordPayload
err := json.Unmarshal(password, &sshPasswordPayload)
if err != nil {
// Backwards compatibility case: instead of a JSON payload, older clients
// send the hex encoded session ID prepended to the SSH password.
// Note: there's an even older case where clients don't send any session ID,
// but that's no longer supported.
if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
sshPasswordPayload.SshPassword = string(password[expectedSessionIDLength:])
} else {
return nil, errors.Tracef("invalid password payload for %q", conn.User())
}
}
if !isHexDigits(sshClient.sshServer.support.Config, sshPasswordPayload.SessionId) ||
len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
return nil, errors.Tracef("invalid session ID for %q", conn.User())
}
userOk := (subtle.ConstantTimeCompare(
[]byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
passwordOk := (subtle.ConstantTimeCompare(
[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
if !userOk || !passwordOk {
return nil, errors.Tracef("invalid password for %q", conn.User())
}
sessionID := sshPasswordPayload.SessionId
// The GeoIP session cache will be populated if there was a previous tunnel
// with this session ID. This will be true up to GEOIP_SESSION_CACHE_TTL, which
// is currently much longer than the OSL session cache, another option to use if
// the GeoIP session cache is retired (the GeoIP session cache currently only
// supports legacy use cases).
isFirstTunnelInSession := !sshClient.sshServer.support.GeoIPService.InSessionCache(sessionID)
supportsServerRequests := common.Contains(
sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
sshClient.Lock()
// After this point, these values are read-only as they are read
// without obtaining sshClient.Lock.
sshClient.sessionID = sessionID
sshClient.isFirstTunnelInSession = isFirstTunnelInSession
sshClient.supportsServerRequests = supportsServerRequests
geoIPData := sshClient.geoIPData
sshClient.Unlock()
// Store the GeoIP data associated with the session ID. This makes
// the GeoIP data available to the web server for web API requests.
// A cache that's distinct from the sshClient record is used to allow
// for or post-tunnel final status requests.
// If the client is reconnecting with the same session ID, this call
// will undo the expiry set by MarkSessionCacheToExpire.
sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
return nil, nil
}
func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
if err != nil {
if method == "none" && err.Error() == "ssh: no auth passed yet" {
// In this case, the callback invocation is noise from auth negotiation
return
}
// Note: here we previously logged messages for fail2ban to act on. This is no longer
// done as the complexity outweighs the benefits.
//
// - The SSH credential is not secret -- it's in the server entry. Attackers targeting
// the server likely already have the credential. On the other hand, random scanning and
// brute forcing is mitigated with high entropy random passwords, rate limiting
// (implemented on the host via iptables), and limited capabilities (the SSH session can
// only port forward).
//
// - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
// an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
// The X-Forwarded-For header cant be used instead as it may be forged and used to get IPs
// deliberately blocked; and in any case fail2ban adds iptables rules which can only block
// by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
//
// Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
// not every authentication failure is logged. A summary log is emitted periodically to
// retain some record of this activity in case this is relevant to, e.g., a performance
// investigation.
atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
now := int64(monotime.Now())
if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
log.WithTraceFields(
LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
}
}
log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
} else {
log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication success")
}
}
// stop signals the ssh connection to shutdown. After sshConn.Wait returns,
// the SSH connection has terminated but sshClient.run may still be running and
// in the process of exiting.
//
// The shutdown process must complete rapidly and not, e.g., block on network
// I/O, as newly connecting clients need to await stop completion of any
// existing connection that shares the same session ID.
func (sshClient *sshClient) stop() {
sshClient.sshConn.Close()
sshClient.sshConn.Wait()
}
// awaitStopped will block until sshClient.run has exited, at which point all
// worker goroutines associated with the sshClient, including any in-flight
// API handlers, will have exited.
func (sshClient *sshClient) awaitStopped() {
<-sshClient.stopped
}
// runTunnel handles/dispatches new channels and new requests from the client.
// When the SSH client connection closes, both the channels and requests channels
// will close and runTunnel will exit.
func (sshClient *sshClient) runTunnel(
channels <-chan ssh.NewChannel,
requests <-chan *ssh.Request) {
waitGroup := new(sync.WaitGroup)
// Start client SSH API request handler
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
sshClient.handleSSHRequests(requests)
}()
// Start request senders
if sshClient.supportsServerRequests {
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
sshClient.runOSLSender()
}()
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
sshClient.runAlertSender()
}()
}
// Start the TCP port forward manager
// The queue size is set to the traffic rules (MaxTCPPortForwardCount +
// MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
// limits per client; when that value is not set, a default is used.
// A limitation: this queue size is set once and doesn't change, for this client,
// when traffic rules are reloaded.
queueSize := sshClient.getTCPPortForwardQueueSize()
if queueSize == 0 {
queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
}
newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
sshClient.handleTCPPortForwards(waitGroup, newTCPPortForwards)
}()
// Handle new channel (port forward) requests from the client.
for newChannel := range channels {
switch newChannel.ChannelType() {
case protocol.RANDOM_STREAM_CHANNEL_TYPE:
sshClient.handleNewRandomStreamChannel(waitGroup, newChannel)
case protocol.PACKET_TUNNEL_CHANNEL_TYPE:
sshClient.handleNewPacketTunnelChannel(waitGroup, newChannel)
case "direct-tcpip":
sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, newTCPPortForwards)
default:
sshClient.rejectNewChannel(newChannel,
fmt.Sprintf("unknown or unsupported channel type: %s", newChannel.ChannelType()))
}
}
// The channel loop is interrupted by a client
// disconnect or by calling sshClient.stop().
// Stop the TCP port forward manager
close(newTCPPortForwards)
// Stop all other worker goroutines
sshClient.stopRunning()
if sshClient.sshServer.support.Config.RunPacketTunnel {
// PacketTunnelServer.ClientDisconnected stops packet tunnel workers.
sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
sshClient.sessionID)
}
waitGroup.Wait()
sshClient.cleanupAuthorizations()
}
func (sshClient *sshClient) handleSSHRequests(requests <-chan *ssh.Request) {
for request := range requests {
// Requests are processed serially; API responses must be sent in request order.
var responsePayload []byte
var err error
if request.Type == "keepalive@openssh.com" {
// SSH keep alive round trips are used as speed test samples.
responsePayload, err = tactics.MakeSpeedTestResponse(
SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES, SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES)
} else {
// All other requests are assumed to be API requests.
sshClient.Lock()
authorizedAccessTypes := sshClient.handshakeState.authorizedAccessTypes
sshClient.Unlock()
// Note: unlock before use is only safe as long as referenced sshClient data,
// such as slices in handshakeState, is read-only after initially set.
responsePayload, err = sshAPIRequestHandler(
sshClient.sshServer.support,
sshClient.geoIPData,
authorizedAccessTypes,
request.Type,
request.Payload)
}
if err == nil {
err = request.Reply(true, responsePayload)
} else {
log.WithTraceFields(LogFields{"error": err}).Warning("request failed")
err = request.Reply(false, nil)
}
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("response failed")
}
}
}
}
type newTCPPortForward struct {
enqueueTime time.Time
hostToConnect string
portToConnect int
newChannel ssh.NewChannel
}
func (sshClient *sshClient) handleTCPPortForwards(
waitGroup *sync.WaitGroup,
newTCPPortForwards chan *newTCPPortForward) {
// Lifecycle of a TCP port forward:
//
// 1. A "direct-tcpip" SSH request is received from the client.
//
// A new TCP port forward request is enqueued. The queue delivers TCP port
// forward requests to the TCP port forward manager, which enforces the TCP
// port forward dial limit.
//
// Enqueuing new requests allows for reading further SSH requests from the
// client without blocking when the dial limit is hit; this is to permit new
// UDP/udpgw port forwards to be restablished without delay. The maximum size
// of the queue enforces a hard cap on resources consumed by a client in the
// pre-dial phase. When the queue is full, new TCP port forwards are
// immediately rejected.
//
// 2. The TCP port forward manager dequeues the request.
//
// The manager calls dialingTCPPortForward(), which increments
// concurrentDialingPortForwardCount, and calls
// isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
// count.
//
// The manager enforces the concurrent TCP dial limit: when at the limit, the
// manager blocks waiting for the number of dials to drop below the limit before
// dispatching the request to handleTCPPortForward(), which will run in its own
// goroutine and will dial and relay the port forward.
//
// The block delays the current request and also halts dequeuing of subsequent
// requests and could ultimately cause requests to be immediately rejected if
// the queue fills. These actions are intended to apply back pressure when
// upstream network resources are impaired.
//
// The time spent in the queue is deducted from the port forward's dial timeout.
// The time spent blocking while at the dial limit is similarly deducted from
// the dial timeout. If the dial timeout has expired before the dial begins, the
// port forward is rejected and a stat is recorded.
//
// 3. handleTCPPortForward() performs the port forward dial and relaying.
//
// a. Dial the target, using the dial timeout remaining after queue and blocking
// time is deducted.
//
// b. If the dial fails, call abortedTCPPortForward() to decrement
// concurrentDialingPortForwardCount, freeing up a dial slot.
//
// c. If the dial succeeds, call establishedPortForward(), which decrements
// concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
// the "established" port forward count.
//
// d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
// concurrentPortForwardCount, the number of _established_ TCP port forwards.
// If the limit is exceeded, the LRU established TCP port forward is closed and
// the newly established TCP port forward proceeds. This LRU logic allows some
// dangling resource consumption (e.g., TIME_WAIT) while providing a better
// experience for clients.
//
// e. Relay data.
//
// f. Call closedPortForward() which decrements concurrentPortForwardCount and
// records bytes transferred.
for newPortForward := range newTCPPortForwards {
remainingDialTimeout :=
time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
time.Since(newPortForward.enqueueTime)
if remainingDialTimeout <= 0 {
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(
newPortForward.newChannel, "TCP port forward timed out in queue")
continue
}
// Reserve a TCP dialing slot.
//
// TOCTOU note: important to increment counts _before_ checking limits; otherwise,
// the client could potentially consume excess resources by initiating many port
// forwards concurrently.
sshClient.dialingTCPPortForward()
// When max dials are in progress, wait up to remainingDialTimeout for dialing
// to become available. This blocks all dequeing.
if sshClient.isTCPDialingPortForwardLimitExceeded() {
blockStartTime := time.Now()
ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
<-ctx.Done()
sshClient.setTCPPortForwardDialingAvailableSignal(nil)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
remainingDialTimeout -= time.Since(blockStartTime)
}
if remainingDialTimeout <= 0 {
// Release the dialing slot here since handleTCPChannel() won't be called.
sshClient.abortedTCPPortForward()
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(
newPortForward.newChannel, "TCP port forward timed out before dialing")
continue
}
// Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
// handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
// will deal with remainingDialTimeout <= 0.
waitGroup.Add(1)
go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
defer waitGroup.Done()
sshClient.handleTCPChannel(
remainingDialTimeout,
newPortForward.hostToConnect,
newPortForward.portToConnect,
newPortForward.newChannel)
}(remainingDialTimeout, newPortForward)
}
}
func (sshClient *sshClient) handleNewRandomStreamChannel(
waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) {
// A random stream channel returns the requested number of bytes -- random
// bytes -- to the client while also consuming and discarding bytes sent
// by the client.
//
// One use case for the random stream channel is a liveness test that the
// client performs to confirm that the tunnel is live. As the liveness
// test is performed in the concurrent establishment phase, before
// selecting a single candidate for handshake, the random stream channel
// is available pre-handshake, albeit with additional restrictions.
//
// The random stream is subject to throttling in traffic rules; for
// unthrottled liveness tests, set initial Read/WriteUnthrottledBytes as
// required. The random stream maximum count and response size cap
// mitigate clients abusing the facility to waste server resources.
//
// Like all other channels, this channel type is handled asynchronously,
// so it's possible to run at any point in the tunnel lifecycle.
//
// Up/downstream byte counts don't include SSH packet and request
// marshalling overhead.
var request protocol.RandomStreamRequest
err := json.Unmarshal(newChannel.ExtraData(), &request)
if err != nil {
sshClient.rejectNewChannel(newChannel, fmt.Sprintf("invalid request: %s", err))
return
}
if request.UpstreamBytes > RANDOM_STREAM_MAX_BYTES {
sshClient.rejectNewChannel(newChannel,
fmt.Sprintf("invalid upstream bytes: %d", request.UpstreamBytes))
return
}
if request.DownstreamBytes > RANDOM_STREAM_MAX_BYTES {
sshClient.rejectNewChannel(newChannel,
fmt.Sprintf("invalid downstream bytes: %d", request.DownstreamBytes))
return
}
var metrics *randomStreamMetrics
sshClient.Lock()
if !sshClient.handshakeState.completed {
metrics = &sshClient.preHandshakeRandomStreamMetrics
} else {
metrics = &sshClient.postHandshakeRandomStreamMetrics
}
countOk := true
if !sshClient.handshakeState.completed &&
metrics.count >= PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT {
countOk = false
} else {
metrics.count++
}
sshClient.Unlock()
if !countOk {
sshClient.rejectNewChannel(newChannel, "max count exceeded")
return
}
channel, requests, err := newChannel.Accept()
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
}
return
}
go ssh.DiscardRequests(requests)
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
received := 0
sent := 0
if request.UpstreamBytes > 0 {
n, err := io.CopyN(ioutil.Discard, channel, int64(request.UpstreamBytes))
received = int(n)
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("receive failed")
}
// Fall through and record any bytes received...
}
}
if request.DownstreamBytes > 0 {
n, err := io.CopyN(channel, rand.Reader, int64(request.DownstreamBytes))
sent = int(n)
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("send failed")
}
}
}
sshClient.Lock()
metrics.upstreamBytes += request.UpstreamBytes
metrics.receivedUpstreamBytes += received
metrics.downstreamBytes += request.DownstreamBytes
metrics.sentDownstreamBytes += sent
sshClient.Unlock()
channel.Close()
}()
}
func (sshClient *sshClient) handleNewPacketTunnelChannel(
waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) {
// packet tunnel channels are handled by the packet tunnel server
// component. Each client may have at most one packet tunnel channel.
if !sshClient.sshServer.support.Config.RunPacketTunnel {
sshClient.rejectNewChannel(newChannel, "unsupported packet tunnel channel type")
return
}
// Accept this channel immediately. This channel will replace any
// previously existing packet tunnel channel for this client.
packetTunnelChannel, requests, err := newChannel.Accept()
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
}
return
}
go ssh.DiscardRequests(requests)
sshClient.setPacketTunnelChannel(packetTunnelChannel)
// PacketTunnelServer will run the client's packet tunnel. If necessary, ClientConnected
// will stop packet tunnel workers for any previous packet tunnel channel.
checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
return sshClient.isPortForwardPermitted(portForwardTypeTCP, upstreamIPAddress, port)
}
checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
return sshClient.isPortForwardPermitted(portForwardTypeUDP, upstreamIPAddress, port)
}
checkAllowedDomainFunc := func(domain string) bool {
ok, _ := sshClient.isDomainPermitted(domain)
return ok
}
flowActivityUpdaterMaker := func(
upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
var updaters []tun.FlowActivityUpdater
oslUpdater := sshClient.newClientSeedPortForward(upstreamIPAddress)
if oslUpdater != nil {
updaters = append(updaters, oslUpdater)
}
return updaters
}
metricUpdater := func(
TCPApplicationBytesDown, TCPApplicationBytesUp,
UDPApplicationBytesDown, UDPApplicationBytesUp int64) {
sshClient.Lock()
sshClient.tcpTrafficState.bytesDown += TCPApplicationBytesDown
sshClient.tcpTrafficState.bytesUp += TCPApplicationBytesUp
sshClient.udpTrafficState.bytesDown += UDPApplicationBytesDown
sshClient.udpTrafficState.bytesUp += UDPApplicationBytesUp
sshClient.Unlock()
}
err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
sshClient.sessionID,
packetTunnelChannel,
checkAllowedTCPPortFunc,
checkAllowedUDPPortFunc,
checkAllowedDomainFunc,
flowActivityUpdaterMaker,
metricUpdater)
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
sshClient.setPacketTunnelChannel(nil)
}
}
func (sshClient *sshClient) handleNewTCPPortForwardChannel(
waitGroup *sync.WaitGroup, newChannel ssh.NewChannel,
newTCPPortForwards chan *newTCPPortForward) {
// udpgw client connections are dispatched immediately (clients use this for
// DNS, so it's essential to not block; and only one udpgw connection is
// retained at a time).
//
// All other TCP port forwards are dispatched via the TCP port forward
// manager queue.
// http://tools.ietf.org/html/rfc4254#section-7.2
var directTcpipExtraData struct {
HostToConnect string
PortToConnect uint32
OriginatorIPAddress string
OriginatorPort uint32
}
err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
if err != nil {
sshClient.rejectNewChannel(newChannel, "invalid extra data")
return
}
// Intercept TCP port forwards to a specified udpgw server and handle directly.
// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
if isUDPChannel {
// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
// own worker goroutine.
waitGroup.Add(1)
go func(channel ssh.NewChannel) {
defer waitGroup.Done()
sshClient.handleUDPChannel(channel)
}(newChannel)
} else {
// Dispatch via TCP port forward manager. When the queue is full, the channel
// is immediately rejected.
tcpPortForward := &newTCPPortForward{
enqueueTime: time.Now(),
hostToConnect: directTcpipExtraData.HostToConnect,
portToConnect: int(directTcpipExtraData.PortToConnect),
newChannel: newChannel,
}
select {
case newTCPPortForwards <- tcpPortForward:
default:
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(newChannel, "TCP port forward dial queue full")
}
}
}
func (sshClient *sshClient) cleanupAuthorizations() {
sshClient.Lock()
if sshClient.releaseAuthorizations != nil {
sshClient.releaseAuthorizations()
}
if sshClient.stopTimer != nil {
sshClient.stopTimer.Stop()
}
sshClient.Unlock()
}
// setPacketTunnelChannel sets the single packet tunnel channel
// for this sshClient. Any existing packet tunnel channel is
// closed.
func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
sshClient.Lock()
if sshClient.packetTunnelChannel != nil {
sshClient.packetTunnelChannel.Close()
}
sshClient.packetTunnelChannel = channel
sshClient.Unlock()
}
// setUDPChannel sets the single UDP channel for this sshClient.
// Each sshClient may have only one concurrent UDP channel. Each
// UDP channel multiplexes many UDP port forwards via the udpgw
// protocol. Any existing UDP channel is closed.
func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
sshClient.Lock()
if sshClient.udpChannel != nil {
sshClient.udpChannel.Close()
}
sshClient.udpChannel = channel
sshClient.Unlock()
}
var serverTunnelStatParams = append(
[]requestParamSpec{
{"last_connected", isLastConnected, requestParamOptional},
{"establishment_duration", isIntString, requestParamOptional}},
baseSessionAndDialParams...)
func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
// Note: reporting duration based on last confirmed data transfer, which
// is reads for sshClient.activityConn.GetActiveDuration(), and not
// connection closing is important for protocols such as meek. For
// meek, the connection remains open until the HTTP session expires,
// which may be some time after the tunnel has closed. (The meek
// protocol has no allowance for signalling payload EOF, and even if
// it did the client may not have the opportunity to send a final
// request with an EOF flag set.)
sshClient.Lock()
logFields := getRequestLogFields(
"server_tunnel",
sshClient.geoIPData,
sshClient.handshakeState.authorizedAccessTypes,
sshClient.handshakeState.apiParams,
serverTunnelStatParams)
// "relay_protocol" is sent with handshake API parameters. In pre-
// handshake logTunnel cases, this value is not yet known. As
// sshClient.tunnelProtocol is authoritative, set this value
// unconditionally, overwriting any value from handshake.
logFields["relay_protocol"] = sshClient.tunnelProtocol
if sshClient.serverPacketManipulation != "" {
logFields["server_packet_manipulation"] = sshClient.serverPacketManipulation
}
if sshClient.sshListener.BPFProgramName != "" {
logFields["server_bpf"] = sshClient.sshListener.BPFProgramName
}
logFields["session_id"] = sshClient.sessionID
logFields["handshake_completed"] = sshClient.handshakeState.completed
logFields["start_time"] = sshClient.activityConn.GetStartTime()
logFields["duration"] = int64(sshClient.activityConn.GetActiveDuration() / time.Millisecond)
logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
logFields["pre_handshake_random_stream_received_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.receivedUpstreamBytes
logFields["pre_handshake_random_stream_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.downstreamBytes
logFields["pre_handshake_random_stream_sent_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.sentDownstreamBytes
logFields["random_stream_count"] = sshClient.postHandshakeRandomStreamMetrics.count
logFields["random_stream_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.upstreamBytes
logFields["random_stream_received_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.receivedUpstreamBytes
logFields["random_stream_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.downstreamBytes
logFields["random_stream_sent_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.sentDownstreamBytes
// Pre-calculate a total-tunneled-bytes field. This total is used
// extensively in analytics and is more performant when pre-calculated.
logFields["bytes"] = sshClient.tcpTrafficState.bytesUp +
sshClient.tcpTrafficState.bytesDown +
sshClient.udpTrafficState.bytesUp +
sshClient.udpTrafficState.bytesDown
// Merge in additional metrics from the optional metrics source
for _, metrics := range additionalMetrics {
for name, value := range metrics {
// Don't overwrite any basic fields
if logFields[name] == nil {
logFields[name] = value
}
}
}
sshClient.Unlock()
// Note: unlock before use is only safe as long as referenced sshClient data,
// such as slices in handshakeState, is read-only after initially set.
log.LogRawFieldsWithTimestamp(logFields)
}
var blocklistHitsStatParams = []requestParamSpec{
{"propagation_channel_id", isHexDigits, 0},
{"sponsor_id", isHexDigits, 0},
{"client_version", isIntString, requestParamLogStringAsInt},
{"client_platform", isClientPlatform, 0},
{"client_build_rev", isHexDigits, requestParamOptional},
{"tunnel_whole_device", isBooleanFlag, requestParamOptional | requestParamLogFlagAsBool},
{"device_region", isAnyString, requestParamOptional},
{"egress_region", isRegionCode, requestParamOptional},
{"session_id", isHexDigits, 0},
{"last_connected", isLastConnected, requestParamOptional},
}
func (sshClient *sshClient) logBlocklistHits(IP net.IP, domain string, tags []BlocklistTag) {
sshClient.Lock()
logFields := getRequestLogFields(
"server_blocklist_hit",
sshClient.geoIPData,
sshClient.handshakeState.authorizedAccessTypes,
sshClient.handshakeState.apiParams,
blocklistHitsStatParams)
logFields["session_id"] = sshClient.sessionID
// Note: see comment in logTunnel regarding unlock and concurrent access.
sshClient.Unlock()
for _, tag := range tags {
if IP != nil {
logFields["blocklist_ip_address"] = IP.String()
}
if domain != "" {
logFields["blocklist_domain"] = domain
}
logFields["blocklist_source"] = tag.Source
logFields["blocklist_subject"] = tag.Subject
log.LogRawFieldsWithTimestamp(logFields)
}
}
func (sshClient *sshClient) runOSLSender() {
for {
// Await a signal that there are SLOKs to send
// TODO: use reflect.SelectCase, and optionally await timer here?
select {
case <-sshClient.signalIssueSLOKs:
case <-sshClient.runCtx.Done():
return
}
retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
for {
err := sshClient.sendOSLRequest()
if err == nil {
break
}
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
}
// If the request failed, retry after a delay (with exponential backoff)
// or when signaled that there are additional SLOKs to send
retryTimer := time.NewTimer(retryDelay)
select {
case <-retryTimer.C:
case <-sshClient.signalIssueSLOKs:
case <-sshClient.runCtx.Done():
retryTimer.Stop()
return
}
retryTimer.Stop()
retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
}
}
}
// sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
// generate a payload, and send an OSL request to the client when
// there are new SLOKs in the payload.
func (sshClient *sshClient) sendOSLRequest() error {
seedPayload := sshClient.getOSLSeedPayload()
// Don't send when no SLOKs. This will happen when signalIssueSLOKs
// is received but no new SLOKs are issued.
if len(seedPayload.SLOKs) == 0 {
return nil
}
oslRequest := protocol.OSLRequest{
SeedPayload: seedPayload,
}
requestPayload, err := json.Marshal(oslRequest)
if err != nil {
return errors.Trace(err)
}
ok, _, err := sshClient.sshConn.SendRequest(
protocol.PSIPHON_API_OSL_REQUEST_NAME,
true,
requestPayload)
if err != nil {
return errors.Trace(err)
}
if !ok {
return errors.TraceNew("client rejected request")
}
sshClient.clearOSLSeedPayload()
return nil
}
// runAlertSender dequeues and sends alert requests to the client. As these
// alerts are informational, there is no retry logic and no SSH client
// acknowledgement (wantReply) is requested. This worker scheme allows
// nonconcurrent components including udpgw and packet tunnel to enqueue
// alerts without blocking their traffic processing.
func (sshClient *sshClient) runAlertSender() {
for {
select {
case <-sshClient.runCtx.Done():
return
case request := <-sshClient.sendAlertRequests:
payload, err := json.Marshal(request)
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Warning("Marshal failed")
break
}
_, _, err = sshClient.sshConn.SendRequest(
protocol.PSIPHON_API_ALERT_REQUEST_NAME,
false,
payload)
if err != nil && !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("SendRequest failed")
break
}
sshClient.Lock()
sshClient.sentAlertRequests[request] = true
sshClient.Unlock()
}
}
}
// enqueueAlertRequest enqueues an alert request to be sent to the client.
// Only one request is sent per tunnel per protocol.AlertRequest value;
// subsequent alerts with the same value are dropped. enqueueAlertRequest will
// not block until the queue exceeds ALERT_REQUEST_QUEUE_BUFFER_SIZE.
func (sshClient *sshClient) enqueueAlertRequest(request protocol.AlertRequest) {
sshClient.Lock()
if sshClient.sentAlertRequests[request] {
sshClient.Unlock()
return
}
sshClient.Unlock()
select {
case <-sshClient.runCtx.Done():
case sshClient.sendAlertRequests <- request:
}
}
func (sshClient *sshClient) enqueueDisallowedTrafficAlertRequest() {
sshClient.enqueueAlertRequest(protocol.AlertRequest{
Reason: protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC,
})
}
func (sshClient *sshClient) enqueueUnsafeTrafficAlertRequest(tags []BlocklistTag) {
for _, tag := range tags {
sshClient.enqueueAlertRequest(protocol.AlertRequest{
Reason: protocol.PSIPHON_API_ALERT_UNSAFE_TRAFFIC,
Subject: tag.Subject,
})
}
}
func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessage string) {
// We always return the reject reason "Prohibited":
// - Traffic rules and connection limits may prohibit the connection.
// - External firewall rules may prohibit the connection, and this is not currently
// distinguishable from other failure modes.
// - We limit the failure information revealed to the client.
reason := ssh.Prohibited
// Note: Debug level, as logMessage may contain user traffic destination address information
log.WithTraceFields(
LogFields{
"channelType": newChannel.ChannelType(),
"logMessage": logMessage,
"rejectReason": reason.String(),
}).Debug("reject new channel")
// Note: logMessage is internal, for logging only; just the reject reason is sent to the client.
newChannel.Reject(reason, reason.String())
}
// setHandshakeState records that a client has completed a handshake API request.
// Some parameters from the handshake request may be used in future traffic rule
// selection. Port forwards are disallowed until a handshake is complete. The
// handshake parameters are included in the session summary log recorded in
// sshClient.stop().
func (sshClient *sshClient) setHandshakeState(
state handshakeState,
authorizations []string) (*handshakeStateInfo, error) {
sshClient.Lock()
completed := sshClient.handshakeState.completed
if !completed {
sshClient.handshakeState = state
}
sshClient.Unlock()
// Client must only perform one handshake
if completed {
return nil, errors.TraceNew("handshake already completed")
}
// Verify the authorizations submitted by the client. Verified, active
// (non-expired) access types will be available for traffic rules
// filtering.
//
// When an authorization is active but expires while the client is
// connected, the client is disconnected to ensure the access is reset.
// This is implemented by setting a timer to perform the disconnect at the
// expiry time of the soonest expiring authorization.
//
// sshServer.authorizationSessionIDs tracks the unique mapping of active
// authorization IDs to client session IDs and is used to detect and
// prevent multiple malicious clients from reusing a single authorization
// (within the scope of this server).
// authorizationIDs and authorizedAccessTypes are returned to the client
// and logged, respectively; initialize to empty lists so the
// protocol/logs don't need to handle 'null' values.
authorizationIDs := make([]string, 0)
authorizedAccessTypes := make([]string, 0)
var stopTime time.Time
for i, authorization := range authorizations {
// This sanity check mitigates malicious clients causing excess CPU use.
if i >= MAX_AUTHORIZATIONS {
log.WithTrace().Warning("too many authorizations")
break
}
verifiedAuthorization, err := accesscontrol.VerifyAuthorization(
&sshClient.sshServer.support.Config.AccessControlVerificationKeyRing,
authorization)
if err != nil {
log.WithTraceFields(
LogFields{"error": err}).Warning("verify authorization failed")
continue
}
authorizationID := base64.StdEncoding.EncodeToString(verifiedAuthorization.ID)
if common.Contains(authorizedAccessTypes, verifiedAuthorization.AccessType) {
log.WithTraceFields(
LogFields{"accessType": verifiedAuthorization.AccessType}).Warning("duplicate authorization access type")
continue
}
authorizationIDs = append(authorizationIDs, authorizationID)
authorizedAccessTypes = append(authorizedAccessTypes, verifiedAuthorization.AccessType)
if stopTime.IsZero() || stopTime.After(verifiedAuthorization.Expires) {
stopTime = verifiedAuthorization.Expires
}
}
// Associate all verified authorizationIDs with this client's session ID.
// Handle cases where previous associations exist:
//
// - Multiple malicious clients reusing a single authorization. In this
// case, authorizations are revoked from the previous client.
//
// - The client reconnected with a new session ID due to user toggling.
// This case is expected due to server affinity. This cannot be
// distinguished from the previous case and the same action is taken;
// this will have no impact on a legitimate client as the previous
// session is dangling.
//
// - The client automatically reconnected with the same session ID. This
// case is not expected as sshServer.registerEstablishedClient
// synchronously calls sshClient.releaseAuthorizations; as a safe guard,
// this case is distinguished and no revocation action is taken.
sshClient.sshServer.authorizationSessionIDsMutex.Lock()
for _, authorizationID := range authorizationIDs {
sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
if ok && sessionID != sshClient.sessionID {
logFields := LogFields{
"event_name": "irregular_tunnel",
"tunnel_error": "duplicate active authorization",
"duplicate_authorization_id": authorizationID,
}
sshClient.geoIPData.SetLogFields(logFields)
duplicateGeoIPData := sshClient.sshServer.support.GeoIPService.GetSessionCache(sessionID)
if duplicateGeoIPData != sshClient.geoIPData {
duplicateGeoIPData.SetLogFieldsWithPrefix("duplicate_authorization_", logFields)
}
log.LogRawFieldsWithTimestamp(logFields)
// Invoke asynchronously to avoid deadlocks.
// TODO: invoke only once for each distinct sessionID?
go sshClient.sshServer.revokeClientAuthorizations(sessionID)
}
sshClient.sshServer.authorizationSessionIDs[authorizationID] = sshClient.sessionID
}
sshClient.sshServer.authorizationSessionIDsMutex.Unlock()
if len(authorizationIDs) > 0 {
sshClient.Lock()
// Make the authorizedAccessTypes available for traffic rules filtering.
sshClient.handshakeState.activeAuthorizationIDs = authorizationIDs
sshClient.handshakeState.authorizedAccessTypes = authorizedAccessTypes
// On exit, sshClient.runTunnel will call releaseAuthorizations, which
// will release the authorization IDs so the client can reconnect and
// present the same authorizations again. sshClient.runTunnel will
// also cancel the stopTimer in case it has not yet fired.
// Note: termination of the stopTimer goroutine is not synchronized.
sshClient.releaseAuthorizations = func() {
sshClient.sshServer.authorizationSessionIDsMutex.Lock()
for _, authorizationID := range authorizationIDs {
sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
if ok && sessionID == sshClient.sessionID {
delete(sshClient.sshServer.authorizationSessionIDs, authorizationID)
}
}
sshClient.sshServer.authorizationSessionIDsMutex.Unlock()
}
sshClient.stopTimer = time.AfterFunc(
time.Until(stopTime),
func() {
sshClient.stop()
})
sshClient.Unlock()
}
upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules()
sshClient.setOSLConfig()
return &handshakeStateInfo{
activeAuthorizationIDs: authorizationIDs,
authorizedAccessTypes: authorizedAccessTypes,
upstreamBytesPerSecond: upstreamBytesPerSecond,
downstreamBytesPerSecond: downstreamBytesPerSecond,
}, nil
}
// getHandshaked returns whether the client has completed a handshake API
// request and whether the traffic rules that were selected after the
// handshake immediately exhaust the client.
//
// When the client is immediately exhausted it will be closed; but this
// takes effect asynchronously. The "exhausted" return value is used to
// prevent API requests by clients that will close.
func (sshClient *sshClient) getHandshaked() (bool, bool) {
sshClient.Lock()
defer sshClient.Unlock()
completed := sshClient.handshakeState.completed
exhausted := false
// Notes:
// - "Immediately exhausted" is when CloseAfterExhausted is set and
// either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
// 0, so no bytes would be read or written. This check does not
// examine whether 0 bytes _remain_ in the ThrottledConn.
// - This check is made against the current traffic rules, which
// could have changed in a hot reload since the handshake.
if completed &&
*sshClient.trafficRules.RateLimits.CloseAfterExhausted &&
(*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
*sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
exhausted = true
}
return completed, exhausted
}
func (sshClient *sshClient) updateAPIParameters(
apiParams common.APIParameters) {
sshClient.Lock()
defer sshClient.Unlock()
// Only update after handshake has initialized API params.
if !sshClient.handshakeState.completed {
return
}
for name, value := range apiParams {
sshClient.handshakeState.apiParams[name] = value
}
}
func (sshClient *sshClient) expectDomainBytes() bool {
sshClient.Lock()
defer sshClient.Unlock()
return sshClient.handshakeState.expectDomainBytes
}
// setTrafficRules resets the client's traffic rules based on the latest server config
// and client properties. As sshClient.trafficRules may be reset by a concurrent
// goroutine, trafficRules must only be accessed within the sshClient mutex.
func (sshClient *sshClient) setTrafficRules() (int64, int64) {
sshClient.Lock()
defer sshClient.Unlock()
isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
sshClient.handshakeState.establishedTunnelsCount == 0
sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
isFirstTunnelInSession,
sshClient.tunnelProtocol,
sshClient.geoIPData,
sshClient.handshakeState)
if sshClient.throttledConn != nil {
// Any existing throttling state is reset.
sshClient.throttledConn.SetLimits(
sshClient.trafficRules.RateLimits.CommonRateLimits())
}
return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
}
// setOSLConfig resets the client's OSL seed state based on the latest OSL config
// As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
// oslClientSeedState must only be accessed within the sshClient mutex.
func (sshClient *sshClient) setOSLConfig() {
sshClient.Lock()
defer sshClient.Unlock()
propagationChannelID, err := getStringRequestParam(
sshClient.handshakeState.apiParams, "propagation_channel_id")
if err != nil {
// This should not fail as long as client has sent valid handshake
return
}
// Use a cached seed state if one is found for the client's
// session ID. This enables resuming progress made in a previous
// tunnel.
// Note: go-cache is already concurency safe; the additional mutex
// is necessary to guarantee that Get/Delete is atomic; although in
// practice no two concurrent clients should ever supply the same
// session ID.
sshClient.sshServer.oslSessionCacheMutex.Lock()
oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
if found {
sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
sshClient.sshServer.oslSessionCacheMutex.Unlock()
sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
return
}
sshClient.sshServer.oslSessionCacheMutex.Unlock()
// Two limitations when setOSLConfig() is invoked due to an
// OSL config hot reload:
//
// 1. any partial progress towards SLOKs is lost.
//
// 2. all existing osl.ClientSeedPortForwards for existing
// port forwards will not send progress to the new client
// seed state.
sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
sshClient.geoIPData.Country,
propagationChannelID,
sshClient.signalIssueSLOKs)
}
// newClientSeedPortForward will return nil when no seeding is
// associated with the specified ipAddress.
func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
sshClient.Lock()
defer sshClient.Unlock()
// Will not be initialized before handshake.
if sshClient.oslClientSeedState == nil {
return nil
}
return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
}
// getOSLSeedPayload returns a payload containing all seeded SLOKs for
// this client's session.
func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
sshClient.Lock()
defer sshClient.Unlock()
// Will not be initialized before handshake.
if sshClient.oslClientSeedState == nil {
return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
}
return sshClient.oslClientSeedState.GetSeedPayload()
}
func (sshClient *sshClient) clearOSLSeedPayload() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.oslClientSeedState.ClearSeedPayload()
}
func (sshClient *sshClient) rateLimits() common.RateLimits {
sshClient.Lock()
defer sshClient.Unlock()
return sshClient.trafficRules.RateLimits.CommonRateLimits()
}
func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
sshClient.Lock()
defer sshClient.Unlock()
return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
}
func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
sshClient.Lock()
defer sshClient.Unlock()
return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
}
func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.tcpPortForwardDialingAvailableSignal = signal
}
const (
portForwardTypeTCP = iota
portForwardTypeUDP
)
func (sshClient *sshClient) isPortForwardPermitted(
portForwardType int,
remoteIP net.IP,
port int) bool {
// Disallow connection to bogons.
//
// As a security measure, this is a failsafe. The server should be run on a
// host with correctly configured firewall rules.
//
// This check also avoids spurious disallowed traffic alerts for destinations
// that are impossible to reach.
if !sshClient.sshServer.support.Config.AllowBogons && common.IsBogon(remoteIP) {
return false
}
// Blocklist check.
//
// Limitation: isPortForwardPermitted is not called in transparent DNS
// forwarding cases. As the destination IP address is rewritten in these
// cases, a blocklist entry won't be dialed in any case. However, no logs
// will be recorded.
tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP)
if len(tags) > 0 {
sshClient.logBlocklistHits(remoteIP, "", tags)
if sshClient.sshServer.support.Config.BlocklistActive {
// Actively alert and block
sshClient.enqueueUnsafeTrafficAlertRequest(tags)
return false
}
}
// Don't lock before calling logBlocklistHits.
// Unlock before calling enqueueDisallowedTrafficAlertRequest/log.
sshClient.Lock()
allowed := true
// Client must complete handshake before port forwards are permitted.
if !sshClient.handshakeState.completed {
allowed = false
}
if allowed {
// Traffic rules checks.
switch portForwardType {
case portForwardTypeTCP:
if !sshClient.trafficRules.AllowTCPPort(remoteIP, port) {
allowed = false
}
case portForwardTypeUDP:
if !sshClient.trafficRules.AllowUDPPort(remoteIP, port) {
allowed = false
}
}
}
sshClient.Unlock()
if allowed {
return true
}
switch portForwardType {
case portForwardTypeTCP:
sshClient.updateQualityMetricsWithTCPRejectedDisallowed()
case portForwardTypeUDP:
sshClient.updateQualityMetricsWithUDPRejectedDisallowed()
}
sshClient.enqueueDisallowedTrafficAlertRequest()
log.WithTraceFields(
LogFields{
"type": portForwardType,
"port": port,
}).Debug("port forward denied by traffic rules")
return false
}
// isDomainPermitted returns true when the specified domain may be resolved
// and returns false and a reject reason otherwise.
func (sshClient *sshClient) isDomainPermitted(domain string) (bool, string) {
// We're not doing comprehensive validation, to avoid overhead per port
// forward. This is a simple sanity check to ensure we don't process
// blantantly invalid input.
//
// TODO: validate with dns.IsDomainName?
if len(domain) > 255 {
return false, "invalid domain name"
}
tags := sshClient.sshServer.support.Blocklist.LookupDomain(domain)
if len(tags) > 0 {
sshClient.logBlocklistHits(nil, domain, tags)
if sshClient.sshServer.support.Config.BlocklistActive {
// Actively alert and block
sshClient.enqueueUnsafeTrafficAlertRequest(tags)
return false, "port forward not permitted"
}
}
return true, ""
}
func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
sshClient.Lock()
defer sshClient.Unlock()
state := &sshClient.tcpTrafficState
max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
return true
}
return false
}
func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
sshClient.Lock()
defer sshClient.Unlock()
return *sshClient.trafficRules.MaxTCPPortForwardCount +
*sshClient.trafficRules.MaxTCPDialingPortForwardCount
}
func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
sshClient.Lock()
defer sshClient.Unlock()
return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
}
func (sshClient *sshClient) dialingTCPPortForward() {
sshClient.Lock()
defer sshClient.Unlock()
state := &sshClient.tcpTrafficState
state.concurrentDialingPortForwardCount += 1
if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
}
}
func (sshClient *sshClient) abortedTCPPortForward() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
}
func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
sshClient.Lock()
defer sshClient.Unlock()
// Check if at port forward limit. The subsequent counter
// changes must be atomic with the limit check to ensure
// the counter never exceeds the limit in the case of
// concurrent allocations.
var max int
var state *trafficState
if portForwardType == portForwardTypeTCP {
max = *sshClient.trafficRules.MaxTCPPortForwardCount
state = &sshClient.tcpTrafficState
} else {
max = *sshClient.trafficRules.MaxUDPPortForwardCount
state = &sshClient.udpTrafficState
}
if max > 0 && state.concurrentPortForwardCount >= int64(max) {
return false
}
// Update port forward counters.
if portForwardType == portForwardTypeTCP {
// Assumes TCP port forwards called dialingTCPPortForward
state.concurrentDialingPortForwardCount -= 1
if sshClient.tcpPortForwardDialingAvailableSignal != nil {
max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
sshClient.tcpPortForwardDialingAvailableSignal()
}
}
}
state.concurrentPortForwardCount += 1
if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
}
state.totalPortForwardCount += 1
return true
}
// establishedPortForward increments the concurrent port
// forward counter. closedPortForward decrements it, so it
// must always be called for each establishedPortForward
// call.
//
// When at the limit of established port forwards, the LRU
// existing port forward is closed to make way for the newly
// established one. There can be a minor delay as, in addition
// to calling Close() on the port forward net.Conn,
// establishedPortForward waits for the LRU's closedPortForward()
// call which will decrement the concurrent counter. This
// ensures all resources associated with the LRU (socket,
// goroutine) are released or will very soon be released before
// proceeding.
func (sshClient *sshClient) establishedPortForward(
portForwardType int, portForwardLRU *common.LRUConns) {
// Do not lock sshClient here.
var state *trafficState
if portForwardType == portForwardTypeTCP {
state = &sshClient.tcpTrafficState
} else {
state = &sshClient.udpTrafficState
}
// When the maximum number of port forwards is already
// established, close the LRU. CloseOldest will call
// Close on the port forward net.Conn. Both TCP and
// UDP port forwards have handler goroutines that may
// be blocked calling Read on the net.Conn. Close will
// eventually interrupt the Read and cause the handlers
// to exit, but not immediately. So the following logic
// waits for a LRU handler to be interrupted and signal
// availability.
//
// Notes:
//
// - the port forward limit can change via a traffic
// rules hot reload; the condition variable handles
// this case whereas a channel-based semaphore would
// not.
//
// - if a number of goroutines exceeding the total limit
// arrive here all concurrently, some CloseOldest() calls
// will have no effect as there can be less existing port
// forwards than new ones. In this case, the new port
// forward will be delayed. This is highly unlikely in
// practise since UDP calls to establishedPortForward are
// serialized and TCP calls are limited by the dial
// queue/count.
if !sshClient.allocatePortForward(portForwardType) {
portForwardLRU.CloseOldest()
log.WithTrace().Debug("closed LRU port forward")
state.availablePortForwardCond.L.Lock()
for !sshClient.allocatePortForward(portForwardType) {
state.availablePortForwardCond.Wait()
}
state.availablePortForwardCond.L.Unlock()
}
}
func (sshClient *sshClient) closedPortForward(
portForwardType int, bytesUp, bytesDown int64) {
sshClient.Lock()
var state *trafficState
if portForwardType == portForwardTypeTCP {
state = &sshClient.tcpTrafficState
} else {
state = &sshClient.udpTrafficState
}
state.concurrentPortForwardCount -= 1
state.bytesUp += bytesUp
state.bytesDown += bytesDown
sshClient.Unlock()
// Signal any goroutine waiting in establishedPortForward
// that an established port forward slot is available.
state.availablePortForwardCond.Signal()
}
func (sshClient *sshClient) updateQualityMetricsWithDialResult(
tcpPortForwardDialSuccess bool, dialDuration time.Duration, IP net.IP) {
sshClient.Lock()
defer sshClient.Unlock()
if tcpPortForwardDialSuccess {
sshClient.qualityMetrics.TCPPortForwardDialedCount += 1
sshClient.qualityMetrics.TCPPortForwardDialedDuration += dialDuration
if IP.To4() != nil {
sshClient.qualityMetrics.TCPIPv4PortForwardDialedCount += 1
sshClient.qualityMetrics.TCPIPv4PortForwardDialedDuration += dialDuration
} else if IP != nil {
sshClient.qualityMetrics.TCPIPv6PortForwardDialedCount += 1
sshClient.qualityMetrics.TCPIPv6PortForwardDialedDuration += dialDuration
}
} else {
sshClient.qualityMetrics.TCPPortForwardFailedCount += 1
sshClient.qualityMetrics.TCPPortForwardFailedDuration += dialDuration
if IP.To4() != nil {
sshClient.qualityMetrics.TCPIPv4PortForwardFailedCount += 1
sshClient.qualityMetrics.TCPIPv4PortForwardFailedDuration += dialDuration
} else if IP != nil {
sshClient.qualityMetrics.TCPIPv6PortForwardFailedCount += 1
sshClient.qualityMetrics.TCPIPv6PortForwardFailedDuration += dialDuration
}
}
}
func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.qualityMetrics.TCPPortForwardRejectedDialingLimitCount += 1
}
func (sshClient *sshClient) updateQualityMetricsWithTCPRejectedDisallowed() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.qualityMetrics.TCPPortForwardRejectedDisallowedCount += 1
}
func (sshClient *sshClient) updateQualityMetricsWithUDPRejectedDisallowed() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.qualityMetrics.UDPPortForwardRejectedDisallowedCount += 1
}
func (sshClient *sshClient) handleTCPChannel(
remainingDialTimeout time.Duration,
hostToConnect string,
portToConnect int,
newChannel ssh.NewChannel) {
// Assumptions:
// - sshClient.dialingTCPPortForward() has been called
// - remainingDialTimeout > 0
established := false
defer func() {
if !established {
sshClient.abortedTCPPortForward()
}
}()
// Transparently redirect web API request connections.
isWebServerPortForward := false
config := sshClient.sshServer.support.Config
if config.WebServerPortForwardAddress != "" {
destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
if destination == config.WebServerPortForwardAddress {
isWebServerPortForward = true
if config.WebServerPortForwardRedirectAddress != "" {
// Note: redirect format is validated when config is loaded
host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
port, _ := strconv.Atoi(portStr)
hostToConnect = host
portToConnect = port
}
}
}
// Validate the domain name and check the domain blocklist before dialing.
//
// The IP blocklist is checked in isPortForwardPermitted, which also provides
// IP blocklist checking for the packet tunnel code path. When hostToConnect
// is an IP address, the following hostname resolution step effectively
// performs no actions and next immediate step is the isPortForwardPermitted
// check.
//
// Limitation: this case handles port forwards where the client sends the
// destination domain in the SSH port forward request but does not currently
// handle DNS-over-TCP; in the DNS-over-TCP case, a client may bypass the
// block list check.
if !isWebServerPortForward &&
net.ParseIP(hostToConnect) == nil {
ok, rejectMessage := sshClient.isDomainPermitted(hostToConnect)
if !ok {
// Note: not recording a port forward failure in this case
sshClient.rejectNewChannel(newChannel, rejectMessage)
return
}
}
// Dial the remote address.
//
// Hostname resolution is performed explicitly, as a separate step, as the
// target IP address is used for traffic rules (AllowSubnets), OSL seed
// progress, and IP address blocklists.
//
// Contexts are used for cancellation (via sshClient.runCtx, which is
// cancelled when the client is stopping) and timeouts.
dialStartTime := time.Now()
log.WithTraceFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
// IPv4 is preferred in case the host has limited IPv6 routing. IPv6 is
// selected and attempted only when there's no IPv4 option.
// TODO: shuffle list to try other IPs?
var IP net.IP
for _, ip := range IPs {
if ip.IP.To4() != nil {
IP = ip.IP
break
}
}
if IP == nil && len(IPs) > 0 {
// If there are no IPv4 IPs, the first IP is IPv6.
IP = IPs[0].IP
}
if err == nil && IP == nil {
err = std_errors.New("no IP address")
}
resolveElapsedTime := time.Since(dialStartTime)
if err != nil {
// Record a port forward failure
sshClient.updateQualityMetricsWithDialResult(false, resolveElapsedTime, IP)
sshClient.rejectNewChannel(newChannel, fmt.Sprintf("LookupIP failed: %s", err))
return
}
remainingDialTimeout -= resolveElapsedTime
if remainingDialTimeout <= 0 {
sshClient.rejectNewChannel(newChannel, "TCP port forward timed out resolving")
return
}
// Enforce traffic rules, using the resolved IP address.
if !isWebServerPortForward &&
!sshClient.isPortForwardPermitted(
portForwardTypeTCP,
IP,
portToConnect) {
// Note: not recording a port forward failure in this case
sshClient.rejectNewChannel(newChannel, "port forward not permitted")
return
}
// TCP dial.
remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
ctx, cancelCtx = context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
// Record port forward success or failure
sshClient.updateQualityMetricsWithDialResult(err == nil, time.Since(dialStartTime), IP)
if err != nil {
// Monitor for low resource error conditions
sshClient.sshServer.monitorPortForwardDialError(err)
sshClient.rejectNewChannel(newChannel, fmt.Sprintf("DialTimeout failed: %s", err))
return
}
// The upstream TCP port forward connection has been established. Schedule
// some cleanup and notify the SSH client that the channel is accepted.
defer fwdConn.Close()
fwdChannel, requests, err := newChannel.Accept()
if err != nil {
if !isExpectedTunnelIOError(err) {
log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
}
return
}
go ssh.DiscardRequests(requests)
defer fwdChannel.Close()
// Release the dialing slot and acquire an established slot.
//
// establishedPortForward increments the concurrent TCP port
// forward counter and closes the LRU existing TCP port forward
// when already at the limit.
//
// Known limitations:
//
// - Closed LRU TCP sockets will enter the TIME_WAIT state,
// continuing to consume some resources.
sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
// "established = true" cancels the deferred abortedTCPPortForward()
established = true
// TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
var bytesUp, bytesDown int64
defer func() {
sshClient.closedPortForward(
portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
}()
lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
defer lruEntry.Remove()
// ActivityMonitoredConn monitors the TCP port forward I/O and updates
// its LRU status. ActivityMonitoredConn also times out I/O on the port
// forward if both reads and writes have been idle for the specified
// duration.
// Ensure nil interface if newClientSeedPortForward returns nil
var updater common.ActivityUpdater
seedUpdater := sshClient.newClientSeedPortForward(IP)
if seedUpdater != nil {
updater = seedUpdater
}
fwdConn, err = common.NewActivityMonitoredConn(
fwdConn,
sshClient.idleTCPPortForwardTimeout(),
true,
updater,
lruEntry)
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
return
}
// Relay channel to forwarded connection.
log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
// TODO: relay errors to fwdChannel.Stderr()?
relayWaitGroup := new(sync.WaitGroup)
relayWaitGroup.Add(1)
go func() {
defer relayWaitGroup.Done()
// io.Copy allocates a 32K temporary buffer, and each port forward relay
// uses two of these buffers; using common.CopyBuffer with a smaller buffer
// reduces the overall memory footprint.
bytes, err := common.CopyBuffer(
fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
atomic.AddInt64(&bytesDown, bytes)
if err != nil && err != io.EOF {
// Debug since errors such as "connection reset by peer" occur during normal operation
log.WithTraceFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
}
// Interrupt upstream io.Copy when downstream is shutting down.
// TODO: this is done to quickly cleanup the port forward when
// fwdConn has a read timeout, but is it clean -- upstream may still
// be flowing?
fwdChannel.Close()
}()
bytes, err := common.CopyBuffer(
fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
atomic.AddInt64(&bytesUp, bytes)
if err != nil && err != io.EOF {
log.WithTraceFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
}
// Shutdown special case: fwdChannel will be closed and return EOF when
// the SSH connection is closed, but we need to explicitly close fwdConn
// to interrupt the downstream io.Copy, which may be blocked on a
// fwdConn.Read().
fwdConn.Close()
relayWaitGroup.Wait()
log.WithTraceFields(
LogFields{
"remoteAddr": remoteAddr,
"bytesUp": atomic.LoadInt64(&bytesUp),
"bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
}