/*
* Copyright (c) 2016, Psiphon Inc.
* All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package server
import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/Psiphon-Inc/goarista/monotime"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
"github.com/marusama/semaphore"
cache "github.com/patrickmn/go-cache"
)
const (
SSH_AUTH_LOG_PERIOD = 30 * time.Minute
SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
SSH_BEGIN_HANDSHAKE_TIMEOUT = 1 * time.Second
SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
SSH_TCP_PORT_FORWARD_QUEUE_SIZE = 1024
SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
SSH_SEND_OSL_RETRY_FACTOR = 2
OSL_SESSION_CACHE_TTL = 5 * time.Minute
)
// TunnelServer is the main server that accepts Psiphon client
// connections, via various obfuscation protocols, and provides
// port forwarding (TCP and UDP) services to the Psiphon client.
// At its core, TunnelServer is an SSH server. SSH is the base
* protocol that provides port forward multiplexing and transport
* security. Optionally layered on top of SSH are the Obfuscated SSH
* and meek protocols, which provide further circumvention
// capabilities.
type TunnelServer struct {
runWaitGroup *sync.WaitGroup
listenerError chan error
shutdownBroadcast <-chan struct{}
sshServer *sshServer
}
// NewTunnelServer initializes a new tunnel server.
func NewTunnelServer(
support *SupportServices,
shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
sshServer, err := newSSHServer(support, shutdownBroadcast)
if err != nil {
return nil, common.ContextError(err)
}
return &TunnelServer{
runWaitGroup: new(sync.WaitGroup),
listenerError: make(chan error),
shutdownBroadcast: shutdownBroadcast,
sshServer: sshServer,
}, nil
}
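// Illustrative usage sketch (not part of the server itself): a host
// program would construct the server with an initialized *SupportServices
// value ("support" below is assumed) and a shutdown channel, then block
// in Run:
//
//	shutdownBroadcast := make(chan struct{})
//	tunnelServer, err := NewTunnelServer(support, shutdownBroadcast)
//	if err != nil {
//		// handle error
//	}
//	// close(shutdownBroadcast) from another goroutine to stop the server
//	err = tunnelServer.Run()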
// Run runs the tunnel server; this function blocks while running a selection of
// listeners that handle connections using various obfuscation protocols.
//
// Run listens on each designated tunnel port and spawns new goroutines to handle
// each client connection. It halts when shutdownBroadcast is signaled. A list of active
// clients is maintained, and when halting all clients are cleanly shut down.
//
// Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
// authentication, and then looping on client new channel requests. "direct-tcpip"
// channels, dynamic port forwards, are supported. When the UDPInterceptUdpgwServerAddress
// config parameter is set, UDP port forwards over a TCP stream, following
// the udpgw protocol, are handled.
//
// A new goroutine is spawned to handle each port forward for each client. Each port
// forward tracks its bytes transferred. Overall per-client stats for connection duration,
// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
// client shuts down.
//
// Note: client handler goroutines may still be shutting down after Run() returns. See
// comment in sshClient.stop(). TODO: fully synchronized shutdown.
func (server *TunnelServer) Run() error {
type sshListener struct {
net.Listener
localAddress string
tunnelProtocol string
}
// TODO: should TunnelServer hold its own support pointer?
support := server.sshServer.support
// First bind all listeners; once all are successful,
// start accepting connections on each.
var listeners []*sshListener
for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
localAddress := fmt.Sprintf(
"%s:%d", support.Config.ServerIPAddress, listenPort)
listener, err := net.Listen("tcp", localAddress)
if err != nil {
for _, existingListener := range listeners {
existingListener.Listener.Close()
}
return common.ContextError(err)
}
log.WithContextFields(
LogFields{
"localAddress": localAddress,
"tunnelProtocol": tunnelProtocol,
}).Info("listening")
listeners = append(
listeners,
&sshListener{
Listener: listener,
localAddress: localAddress,
tunnelProtocol: tunnelProtocol,
})
}
for _, listener := range listeners {
server.runWaitGroup.Add(1)
go func(listener *sshListener) {
defer server.runWaitGroup.Done()
log.WithContextFields(
LogFields{
"localAddress": listener.localAddress,
"tunnelProtocol": listener.tunnelProtocol,
}).Info("running")
server.sshServer.runListener(
listener.Listener,
server.listenerError,
listener.tunnelProtocol)
log.WithContextFields(
LogFields{
"localAddress": listener.localAddress,
"tunnelProtocol": listener.tunnelProtocol,
}).Info("stopped")
}(listener)
}
var err error
select {
case <-server.shutdownBroadcast:
case err = <-server.listenerError:
}
for _, listener := range listeners {
listener.Close()
}
server.sshServer.stopClients()
server.runWaitGroup.Wait()
log.WithContext().Info("stopped")
return err
}
// GetLoadStats returns load stats for the tunnel server. The stats are
// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
// include the current connected client count and the total number of
// current port forwards.
func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
return server.sshServer.getLoadStats()
}
// ResetAllClientTrafficRules resets all established client traffic rules
// to use the latest config and client properties. Any existing traffic
// rule state is lost, including throttling state.
func (server *TunnelServer) ResetAllClientTrafficRules() {
server.sshServer.resetAllClientTrafficRules()
}
// ResetAllClientOSLConfigs resets all established client OSL state to use
// the latest OSL config. Any existing OSL state is lost, including partial
// progress towards SLOKs.
func (server *TunnelServer) ResetAllClientOSLConfigs() {
server.sshServer.resetAllClientOSLConfigs()
}
// SetClientHandshakeState sets the handshake state -- that it completed and
// what parameters were passed -- in sshClient. This state is used for allowing
// port forwards and for future traffic rule selection. SetClientHandshakeState
// also triggers an immediate traffic rule re-selection, as the rules selected
// upon tunnel establishment may no longer apply now that handshake values are
// set.
func (server *TunnelServer) SetClientHandshakeState(
sessionID string, state handshakeState) error {
return server.sshServer.setClientHandshakeState(sessionID, state)
}
// GetClientHandshaked indicates whether the client has completed a handshake
// and whether its traffic rules are immediately exhausted.
func (server *TunnelServer) GetClientHandshaked(
sessionID string) (bool, bool, error) {
return server.sshServer.getClientHandshaked(sessionID)
}
// SetEstablishTunnels sets whether new tunnels may be established or not.
// When not establishing, incoming connections are immediately closed.
func (server *TunnelServer) SetEstablishTunnels(establish bool) {
server.sshServer.setEstablishTunnels(establish)
}
// GetEstablishTunnels returns whether new tunnels may be established or not.
func (server *TunnelServer) GetEstablishTunnels() bool {
return server.sshServer.getEstablishTunnels()
}
type sshServer struct {
// Note: 64-bit ints used with atomic operations are placed
// at the start of struct to ensure 64-bit alignment.
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
lastAuthLog int64
authFailedCount int64
support *SupportServices
establishTunnels int32
concurrentSSHHandshakes semaphore.Semaphore
shutdownBroadcast <-chan struct{}
sshHostKey ssh.Signer
clientsMutex sync.Mutex
stoppingClients bool
acceptedClientCounts map[string]map[string]int64
clients map[string]*sshClient
oslSessionCacheMutex sync.Mutex
oslSessionCache *cache.Cache
}
func newSSHServer(
support *SupportServices,
shutdownBroadcast <-chan struct{}) (*sshServer, error) {
privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
if err != nil {
return nil, common.ContextError(err)
}
// TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
signer, err := ssh.NewSignerFromKey(privateKey)
if err != nil {
return nil, common.ContextError(err)
}
var concurrentSSHHandshakes semaphore.Semaphore
if support.Config.MaxConcurrentSSHHandshakes > 0 {
concurrentSSHHandshakes = semaphore.New(support.Config.MaxConcurrentSSHHandshakes)
}
// The OSL session cache temporarily retains OSL seed state
// progress for disconnected clients. This enables clients
// that disconnect and immediately reconnect to the same
// server to resume their OSL progress. Cached progress
// is referenced by session ID and is retained for
// OSL_SESSION_CACHE_TTL after disconnect.
//
// Note: session IDs are assumed to be unpredictable. If a
// rogue client could guess the session ID of another client,
// it could resume its OSL progress and, if the OSL config
// were known, infer some activity.
oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
return &sshServer{
support: support,
establishTunnels: 1,
concurrentSSHHandshakes: concurrentSSHHandshakes,
shutdownBroadcast: shutdownBroadcast,
sshHostKey: signer,
acceptedClientCounts: make(map[string]map[string]int64),
clients: make(map[string]*sshClient),
oslSessionCache: oslSessionCache,
}, nil
}
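// Illustrative sketch of the go-cache operations this file performs on
// oslSessionCache (Set on client disconnect, Get/Delete on reconnect,
// Flush on OSL config reload); "state" is a stand-in *osl.ClientSeedState:
//
//	c := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
//	c.Set(sessionID, state, cache.DefaultExpiration)
//	if v, found := c.Get(sessionID); found {
//		c.Delete(sessionID)
//		state = v.(*osl.ClientSeedState)
//	}
//	c.Flush()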
func (sshServer *sshServer) setEstablishTunnels(establish bool) {
// Do nothing when the setting is already correct. This avoids
// spurious log messages when setEstablishTunnels is called
// periodically with the same setting.
if establish == sshServer.getEstablishTunnels() {
return
}
establishFlag := int32(1)
if !establish {
establishFlag = 0
}
atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
log.WithContextFields(
LogFields{"establish": establish}).Info("establishing tunnels")
}
func (sshServer *sshServer) getEstablishTunnels() bool {
return atomic.LoadInt32(&sshServer.establishTunnels) == 1
}
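// For reference, the establish flag above is the common
// atomic-int32-as-boolean pattern:
//
//	var flag int32 = 1                      // enabled
//	atomic.StoreInt32(&flag, 0)             // disable
//	enabled := atomic.LoadInt32(&flag) == 1 // read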
// runListener is intended to run in a goroutine; it blocks
// while running a particular listener. If an unrecoverable error
// occurs, it will send the error to the listenerError channel.
func (sshServer *sshServer) runListener(
listener net.Listener,
listenerError chan<- error,
listenerTunnelProtocol string) {
runningProtocols := make([]string, 0)
for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
runningProtocols = append(runningProtocols, tunnelProtocol)
}
handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
// Note: establish tunnel limiter cannot simply stop TCP
// listeners in all cases (e.g., meek) since an SSH tunnel can
// span multiple TCP connections.
if !sshServer.getEstablishTunnels() {
log.WithContext().Debug("not establishing tunnels")
clientConn.Close()
return
}
// The tunnelProtocol passed to handleClient is used for stats,
// throttling, etc. When the tunnel protocol can be determined
// unambiguously from the listening port, use that protocol and
// don't use any client-declared value. Only use the client's
// value, if present, in special cases where the listening port
// cannot distinguish the protocol.
tunnelProtocol := listenerTunnelProtocol
if clientTunnelProtocol != "" &&
protocol.UseClientTunnelProtocol(
clientTunnelProtocol, runningProtocols) {
tunnelProtocol = clientTunnelProtocol
}
// process each client connection concurrently
go sshServer.handleClient(tunnelProtocol, clientConn)
}
// Note: when exiting due to an unrecoverable error, be sure
// to try to send the error to listenerError so that the outer
// TunnelServer.Run will properly shut down instead of remaining
// running.
if protocol.TunnelProtocolUsesMeekHTTP(listenerTunnelProtocol) ||
protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol) {
meekServer, err := NewMeekServer(
sshServer.support,
listener,
protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol),
protocol.TunnelProtocolUsesObfuscatedSessionTickets(listenerTunnelProtocol),
handleClient,
sshServer.shutdownBroadcast)
if err == nil {
err = meekServer.Run()
}
if err != nil {
select {
case listenerError <- common.ContextError(err):
default:
}
return
}
} else {
for {
conn, err := listener.Accept()
select {
case <-sshServer.shutdownBroadcast:
if err == nil {
conn.Close()
}
return
default:
}
if err != nil {
if e, ok := err.(net.Error); ok && e.Temporary() {
log.WithContextFields(LogFields{"error": err}).Error("accept failed")
// Temporary error, keep running
continue
}
select {
case listenerError <- common.ContextError(err):
default:
}
return
}
handleClient("", conn)
}
}
}
// An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
// is for tracking the number of connections.
func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
}
sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
}
func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
}
// An established client has completed its SSH handshake and has an ssh.Conn. Registration is
// for tracking the number of fully established clients and for maintaining a list of running
// clients (for stopping at shutdown time).
func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
sshServer.clientsMutex.Lock()
if sshServer.stoppingClients {
sshServer.clientsMutex.Unlock()
return false
}
// In the case of a duplicate client sessionID, the previous client is closed.
// - Well-behaved clients pick a random sessionID that should be
// unique (won't accidentally conflict) and hard to guess (can't be targeted
// by a malicious client).
// - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
// and reestablished. In this case, when the same server is selected, this logic
// will be hit; closing the old, dangling client is desirable.
// - Multi-tunnel clients should not normally use one server for multiple tunnels.
existingClient := sshServer.clients[client.sessionID]
sshServer.clients[client.sessionID] = client
sshServer.clientsMutex.Unlock()
// Call stop() outside the mutex to avoid deadlock.
if existingClient != nil {
existingClient.stop()
log.WithContext().Debug(
"stopped existing client with duplicate session ID")
}
return true
}
func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
sshServer.clientsMutex.Lock()
registeredClient := sshServer.clients[client.sessionID]
// registeredClient will differ from client when client
// is the existingClient terminated in registerEstablishedClient.
// In that case, registeredClient remains connected, and
// the sshServer.clients entry should be retained.
if registeredClient == client {
delete(sshServer.clients, client.sessionID)
}
sshServer.clientsMutex.Unlock()
// Call stop() outside the mutex to avoid deadlock.
client.stop()
}
type ProtocolStats map[string]map[string]int64
type RegionStats map[string]map[string]map[string]int64
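// Indexing sketch: ProtocolStats maps [tunnelProtocol][statName] to a
// count; RegionStats adds a leading region key. For example, assuming
// hypothetical "CA" and "OSSH" entries exist:
//
//	n := protocolStats["OSSH"]["established_clients"]
//	m := regionStats["CA"]["ALL"]["accepted_clients"]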
func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
sshServer.clientsMutex.Lock()
defer sshServer.clientsMutex.Unlock()
// Explicitly populate with zeros to ensure 0 counts in log messages
zeroStats := func() map[string]int64 {
stats := make(map[string]int64)
stats["accepted_clients"] = 0
stats["established_clients"] = 0
stats["dialing_tcp_port_forwards"] = 0
stats["tcp_port_forwards"] = 0
stats["total_tcp_port_forwards"] = 0
stats["udp_port_forwards"] = 0
stats["total_udp_port_forwards"] = 0
stats["tcp_port_forward_dialed_count"] = 0
stats["tcp_port_forward_dialed_duration"] = 0
stats["tcp_port_forward_failed_count"] = 0
stats["tcp_port_forward_failed_duration"] = 0
stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
return stats
}
zeroProtocolStats := func() map[string]map[string]int64 {
stats := make(map[string]map[string]int64)
stats["ALL"] = zeroStats()
for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
stats[tunnelProtocol] = zeroStats()
}
return stats
}
// [tunnelProtocol][statName] -> count
protocolStats := zeroProtocolStats()
// [region][tunnelProtocol][statName] -> count
regionStats := make(RegionStats)
// Note: as currently tracked/counted, each established client is also an accepted client
for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
for region, acceptedClientCount := range regionAcceptedClientCounts {
if acceptedClientCount > 0 {
if regionStats[region] == nil {
regionStats[region] = zeroProtocolStats()
}
protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
}
}
}
for _, client := range sshServer.clients {
client.Lock()
tunnelProtocol := client.tunnelProtocol
region := client.geoIPData.Country
if regionStats[region] == nil {
regionStats[region] = zeroProtocolStats()
}
stats := []map[string]int64{
protocolStats["ALL"],
protocolStats[tunnelProtocol],
regionStats[region]["ALL"],
regionStats[region][tunnelProtocol]}
for _, stat := range stats {
stat["established_clients"] += 1
// Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
// client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.tcpPortForwardDialedCount
stat["tcp_port_forward_dialed_duration"] +=
int64(client.qualityMetrics.tcpPortForwardDialedDuration / time.Millisecond)
stat["tcp_port_forward_failed_count"] += client.qualityMetrics.tcpPortForwardFailedCount
stat["tcp_port_forward_failed_duration"] +=
int64(client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond)
stat["tcp_port_forward_rejected_dialing_limit_count"] +=
client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
}
client.qualityMetrics.tcpPortForwardDialedCount = 0
client.qualityMetrics.tcpPortForwardDialedDuration = 0
client.qualityMetrics.tcpPortForwardFailedCount = 0
client.qualityMetrics.tcpPortForwardFailedDuration = 0
client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
client.Unlock()
}
return protocolStats, regionStats
}
func (sshServer *sshServer) resetAllClientTrafficRules() {
sshServer.clientsMutex.Lock()
clients := make(map[string]*sshClient)
for sessionID, client := range sshServer.clients {
clients[sessionID] = client
}
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.setTrafficRules()
}
}
func (sshServer *sshServer) resetAllClientOSLConfigs() {
// Flush cached seed state. This has the same effect
// and same limitations as calling setOSLConfig for
// currently connected clients -- all progress is lost.
sshServer.oslSessionCacheMutex.Lock()
sshServer.oslSessionCache.Flush()
sshServer.oslSessionCacheMutex.Unlock()
sshServer.clientsMutex.Lock()
clients := make(map[string]*sshClient)
for sessionID, client := range sshServer.clients {
clients[sessionID] = client
}
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.setOSLConfig()
}
}
func (sshServer *sshServer) setClientHandshakeState(
sessionID string, state handshakeState) error {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return common.ContextError(errors.New("unknown session ID"))
}
err := client.setHandshakeState(state)
if err != nil {
return common.ContextError(err)
}
return nil
}
func (sshServer *sshServer) getClientHandshaked(
sessionID string) (bool, bool, error) {
sshServer.clientsMutex.Lock()
client := sshServer.clients[sessionID]
sshServer.clientsMutex.Unlock()
if client == nil {
return false, false, common.ContextError(errors.New("unknown session ID"))
}
completed, exhausted := client.getHandshaked()
return completed, exhausted, nil
}
func (sshServer *sshServer) stopClients() {
sshServer.clientsMutex.Lock()
sshServer.stoppingClients = true
clients := sshServer.clients
sshServer.clients = make(map[string]*sshClient)
sshServer.clientsMutex.Unlock()
for _, client := range clients {
client.stop()
}
}
func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
geoIPData := sshServer.support.GeoIPService.Lookup(
common.IPAddressFromAddr(clientConn.RemoteAddr()))
sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
// When configured, enforce a cap on the number of concurrent SSH
// handshakes. This limits load spikes on busy servers when many clients
// attempt to connect at once. Wait a short time, SSH_BEGIN_HANDSHAKE_TIMEOUT,
// to acquire; waiting will avoid immediately creating more load on another
// server in the network when the client tries a new candidate. Disconnect the
// client when that wait time is exceeded.
//
// This mechanism limits memory allocations and CPU usage associated with the
// SSH handshake. At this point, new direct TCP connections or new meek
// connections, with associated resource usage, are already established. Those
// connections are expected to be rate or load limited using other mechanisms.
//
// TODO:
//
// - deduct time spent acquiring the semaphore from SSH_HANDSHAKE_TIMEOUT in
// sshClient.run, since the client is also applying an SSH handshake timeout
// and won't exclude time spent waiting.
// - each call to sshServer.handleClient (in sshServer.runListener) is invoked
// in its own goroutine, but shutdown doesn't synchronously await these
// goroutines. Once this shutdown is synchronized, the following context.WithTimeout
// should use an sshServer parent context to ensure blocking acquires
// interrupt immediately upon shutdown.
var onSSHHandshakeFinished func()
if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 {
ctx, cancelFunc := context.WithTimeout(
context.Background(), SSH_BEGIN_HANDSHAKE_TIMEOUT)
defer cancelFunc()
err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1)
if err != nil {
clientConn.Close()
// This is a debug log as the only possible error is context timeout.
log.WithContextFields(LogFields{"error": err}).Debug(
"acquire SSH handshake semaphore failed")
return
}
onSSHHandshakeFinished = func() {
sshServer.concurrentSSHHandshakes.Release(1)
}
}
sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
// in any error case, or as soon as the SSH handshake phase has successfully
// completed.
sshClient.run(clientConn, onSSHHandshakeFinished)
}
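// Concurrency-cap sketch using the same semaphore API as handleClient
// above (marusama/semaphore), with assumed placeholder values:
//
//	sem := semaphore.New(10)
//	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
//	defer cancel()
//	if err := sem.Acquire(ctx, 1); err != nil {
//		return // timed out waiting for a slot
//	}
//	defer sem.Release(1)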
func (sshServer *sshServer) monitorPortForwardDialError(err error) {
// "err" is the error returned from a failed TCP or UDP port
// forward dial. Certain system error codes indicate low resource
// conditions: insufficient file descriptors, ephemeral ports, or
// memory. For these cases, log an alert.
// TODO: also temporarily suspend new clients
// Note: don't log net.OpError.Error() as the full error string
// may contain client destination addresses.
opErr, ok := err.(*net.OpError)
if ok {
if opErr.Err == syscall.EADDRNOTAVAIL ||
opErr.Err == syscall.EAGAIN ||
opErr.Err == syscall.ENOMEM ||
opErr.Err == syscall.EMFILE ||
opErr.Err == syscall.ENFILE {
log.WithContextFields(
LogFields{"error": opErr.Err}).Error(
"port forward dial failed due to unavailable resource")
}
}
}
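// Assumed call-site sketch, per the doc comment above: the error from a
// failed upstream port forward dial is passed through for monitoring.
//
//	conn, err := net.DialTimeout("tcp", remoteAddr, dialTimeout)
//	if err != nil {
//		sshServer.monitorPortForwardDialError(err)
//	}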
type sshClient struct {
sync.Mutex
sshServer *sshServer
tunnelProtocol string
sshConn ssh.Conn
activityConn *common.ActivityMonitoredConn
throttledConn *common.ThrottledConn
geoIPData GeoIPData
sessionID string
supportsServerRequests bool
handshakeState handshakeState
udpChannel ssh.Channel
packetTunnelChannel ssh.Channel
trafficRules TrafficRules
tcpTrafficState trafficState
udpTrafficState trafficState
qualityMetrics qualityMetrics
tcpPortForwardLRU *common.LRUConns
oslClientSeedState *osl.ClientSeedState
signalIssueSLOKs chan struct{}
runCtx context.Context
stopRunning context.CancelFunc
tcpPortForwardDialingAvailableSignal context.CancelFunc
}
type trafficState struct {
bytesUp int64
bytesDown int64
concurrentDialingPortForwardCount int64
peakConcurrentDialingPortForwardCount int64
concurrentPortForwardCount int64
peakConcurrentPortForwardCount int64
totalPortForwardCount int64
availablePortForwardCond *sync.Cond
}
// qualityMetrics records upstream TCP dial attempts and
// elapsed time. Elapsed time includes the full TCP handshake
// and, in aggregate, is a measure of the quality of the
// upstream link. These stats are recorded by each sshClient
// and then reported and reset in sshServer.getLoadStats().
type qualityMetrics struct {
tcpPortForwardDialedCount int64
tcpPortForwardDialedDuration time.Duration
tcpPortForwardFailedCount int64
tcpPortForwardFailedDuration time.Duration
tcpPortForwardRejectedDialingLimitCount int64
}
type handshakeState struct {
completed bool
apiProtocol string
apiParams requestJSONObject
}
func newSshClient(
sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
runCtx, stopRunning := context.WithCancel(context.Background())
client := &sshClient{
sshServer: sshServer,
tunnelProtocol: tunnelProtocol,
geoIPData: geoIPData,
tcpPortForwardLRU: common.NewLRUConns(),
signalIssueSLOKs: make(chan struct{}, 1),
runCtx: runCtx,
stopRunning: stopRunning,
}
client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
return client
}
func (sshClient *sshClient) run(
clientConn net.Conn, onSSHHandshakeFinished func()) {
// onSSHHandshakeFinished must be called even if the SSH handshake is aborted.
defer func() {
if onSSHHandshakeFinished != nil {
onSSHHandshakeFinished()
}
}()
// Some conns report additional metrics
metricsSource, isMetricsSource := clientConn.(MetricsSource)
// Set initial traffic rules, pre-handshake, based on currently known info.
sshClient.setTrafficRules()
// Wrap the base client connection with an ActivityMonitoredConn which will
// terminate the connection if no data is received before the deadline. This
// timeout is in effect for the entire duration of the SSH connection. Clients
// must actively use the connection or send SSH keep alive requests to keep
// the connection active. Writes are not considered reliable activity indicators
// due to buffering.
activityConn, err := common.NewActivityMonitoredConn(
clientConn,
SSH_CONNECTION_READ_DEADLINE,
false,
nil,
nil)
if err != nil {
clientConn.Close()
log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
return
}
clientConn = activityConn
// Further wrap the connection in a rate limiting ThrottledConn.
throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
clientConn = throttledConn
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
// respect shutdownBroadcast and implement a specific handshake timeout.
// The timeout is to reclaim network resources in case the handshake takes
// too long.
type sshNewServerConnResult struct {
conn net.Conn
sshConn *ssh.ServerConn
channels <-chan ssh.NewChannel
requests <-chan *ssh.Request
err error
}
resultChannel := make(chan *sshNewServerConnResult, 2)
var afterFunc *time.Timer
if SSH_HANDSHAKE_TIMEOUT > 0 {
afterFunc = time.AfterFunc(SSH_HANDSHAKE_TIMEOUT, func() {
resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
})
}
go func(conn net.Conn) {
sshServerConfig := &ssh.ServerConfig{
PasswordCallback: sshClient.passwordCallback,
AuthLogCallback: sshClient.authLogCallback,
ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion,
}
sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
result := &sshNewServerConnResult{}
// Wrap the connection in an SSH deobfuscator when required.
if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
// Note: NewObfuscatedSshConn blocks on network I/O
// TODO: ensure this won't block shutdown
conn, result.err = common.NewObfuscatedSshConn(
common.OBFUSCATION_CONN_MODE_SERVER,
conn,
sshClient.sshServer.support.Config.ObfuscatedSSHKey)
if result.err != nil {
result.err = common.ContextError(result.err)
}
}
if result.err == nil {
result.sshConn, result.channels, result.requests, result.err =
ssh.NewServerConn(conn, sshServerConfig)
}
resultChannel <- result
}(clientConn)
var result *sshNewServerConnResult
select {
case result = <-resultChannel:
case <-sshClient.sshServer.shutdownBroadcast:
// Close() will interrupt an ongoing handshake
// TODO: wait for SSH handshake goroutines to exit before returning?
clientConn.Close()
return
}
if afterFunc != nil {
afterFunc.Stop()
}
if result.err != nil {
clientConn.Close()
// This is a Debug log due to noise. The handshake often fails due to I/O
// errors as clients frequently interrupt connections in progress when
// client-side load balancing completes a connection to a different server.
log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
return
}
// The SSH handshake has finished successfully; notify now to allow other
// blocked SSH handshakes to proceed.
if onSSHHandshakeFinished != nil {
onSSHHandshakeFinished()
}
onSSHHandshakeFinished = nil
sshClient.Lock()
sshClient.sshConn = result.sshConn
sshClient.activityConn = activityConn
sshClient.throttledConn = throttledConn
sshClient.Unlock()
if !sshClient.sshServer.registerEstablishedClient(sshClient) {
clientConn.Close()
log.WithContext().Warning("register failed")
return
}
sshClient.runTunnel(result.channels, result.requests)
// Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
// which also closes underlying transport Conn.
sshClient.sshServer.unregisterEstablishedClient(sshClient)
var additionalMetrics LogFields
if isMetricsSource {
additionalMetrics = metricsSource.GetMetrics()
}
sshClient.logTunnel(additionalMetrics)
// Transfer OSL seed state -- the OSL progress -- from the closing
// client to the session cache so the client can resume its progress
// if it reconnects to this same server.
// Note: following setOSLConfig order of locking.
sshClient.Lock()
if sshClient.oslClientSeedState != nil {
sshClient.sshServer.oslSessionCacheMutex.Lock()
sshClient.oslClientSeedState.Hibernate()
sshClient.sshServer.oslSessionCache.Set(
sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
sshClient.sshServer.oslSessionCacheMutex.Unlock()
sshClient.oslClientSeedState = nil
}
sshClient.Unlock()
// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
// final status requests, the lifetime of cached GeoIP records exceeds the
// lifetime of the sshClient.
sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
}
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
var sshPasswordPayload protocol.SSHPasswordPayload
err := json.Unmarshal(password, &sshPasswordPayload)
if err != nil {
// Backwards compatibility case: instead of a JSON payload, older clients
// send the hex encoded session ID prepended to the SSH password.
// Note: there's an even older case where clients don't send any session ID,
// but that's no longer supported.
if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
sshPasswordPayload.SshPassword = string(password[expectedSessionIDLength:])
} else {
return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
}
}
if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
}
userOk := (subtle.ConstantTimeCompare(
[]byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
passwordOk := (subtle.ConstantTimeCompare(
[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
if !userOk || !passwordOk {
return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
}
sessionID := sshPasswordPayload.SessionId
supportsServerRequests := common.Contains(
sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
sshClient.Lock()
sshClient.sessionID = sessionID
sshClient.supportsServerRequests = supportsServerRequests
geoIPData := sshClient.geoIPData
sshClient.Unlock()
// Store the GeoIP data associated with the session ID. This makes
// the GeoIP data available to the web server for web API requests.
// A cache that's distinct from the sshClient record is used to allow
// for post-tunnel final status requests.
// If the client is reconnecting with the same session ID, this call
// will undo the expiry set by MarkSessionCacheToExpire.
sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
return nil, nil
}
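// Sketch of the two password payload forms accepted above. The current
// form is JSON (fields per protocol.SSHPasswordPayload, as unmarshaled
// above); the legacy form concatenates two hex strings:
//
//	JSON:   {"SessionId":"...","SshPassword":"...","ClientCapabilities":[...]}
//	legacy: <expectedSessionIDLength hex chars><expectedSSHPasswordLength hex chars>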
func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
if err != nil {
if method == "none" && err.Error() == "no auth passed yet" {
// In this case, the callback invocation is noise from auth negotiation
return
}
// Note: here we previously logged messages for fail2ban to act on. This is no longer
// done as the complexity outweighs the benefits.
//
// - The SSH credential is not secret -- it's in the server entry. Attackers targeting
// the server likely already have the credential. On the other hand, random scanning and
// brute forcing is mitigated with high entropy random passwords, rate limiting
// (implemented on the host via iptables), and limited capabilities (the SSH session can
// only port forward).
//
// - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
// an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
// The X-Forwarded-For header can't be used instead as it may be forged and used to get IPs
// deliberately blocked; and in any case fail2ban adds iptables rules which can only block
// by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
//
// Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
// not every authentication failure is logged. A summary log is emitted periodically to
// retain some record of this activity in case this is relevant to, e.g., a performance
// investigation.
atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
now := int64(monotime.Now())
if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
log.WithContextFields(
LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
}
}
log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
} else {
log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
}
}
// stop signals the SSH connection to shut down. After stop() returns,
// the connection has terminated but sshClient.run() may still be
// running and in the process of exiting.
func (sshClient *sshClient) stop() {
sshClient.sshConn.Close()
sshClient.sshConn.Wait()
}
// runTunnel handles/dispatches new channels and new requests from the client.
// When the SSH client connection closes, both the channels and requests channels
// will close and runTunnel will exit.
func (sshClient *sshClient) runTunnel(
channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
waitGroup := new(sync.WaitGroup)
// Start client SSH API request handler
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
for request := range requests {
// Requests are processed serially; API responses must be sent in request order.
var responsePayload []byte
var err error
if request.Type == "keepalive@openssh.com" {
// Keepalive requests have an empty response.
} else {
// All other requests are assumed to be API requests.
responsePayload, err = sshAPIRequestHandler(
sshClient.sshServer.support,
sshClient.geoIPData,
request.Type,
request.Payload)
}
if err == nil {
err = request.Reply(true, responsePayload)
} else {
log.WithContextFields(LogFields{"error": err}).Warning("request failed")
err = request.Reply(false, nil)
}
if err != nil {
log.WithContextFields(LogFields{"error": err}).Warning("response failed")
}
}
}()
// Start OSL sender
if sshClient.supportsServerRequests {
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
sshClient.runOSLSender()
}()
}
// Lifecycle of a TCP port forward:
//
// 1. A "direct-tcpip" SSH request is received from the client.
//
// A new TCP port forward request is enqueued. The queue delivers TCP port
// forward requests to the TCP port forward manager, which enforces the TCP
// port forward dial limit.
//
// Enqueuing new requests allows for reading further SSH requests from the
// client without blocking when the dial limit is hit; this is to permit new
// UDP/udpgw port forwards to be reestablished without delay. The maximum size
// of the queue enforces a hard cap on resources consumed by a client in the
// pre-dial phase. When the queue is full, new TCP port forwards are
// immediately rejected.
//
// 2. The TCP port forward manager dequeues the request.
//
// The manager calls dialingTCPPortForward(), which increments
// concurrentDialingPortForwardCount, and calls
// isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
// count.
//
// The manager enforces the concurrent TCP dial limit: when at the limit, the
// manager blocks waiting for the number of dials to drop below the limit before
// dispatching the request to handleTCPPortForward(), which will run in its own
// goroutine and will dial and relay the port forward.
//
// The block delays the current request and also halts dequeuing of subsequent
// requests and could ultimately cause requests to be immediately rejected if
// the queue fills. These actions are intended to apply back pressure when
// upstream network resources are impaired.
//
// The time spent in the queue is deducted from the port forward's dial timeout.
// The time spent blocking while at the dial limit is similarly deducted from
// the dial timeout. If the dial timeout has expired before the dial begins, the
// port forward is rejected and a stat is recorded (see the worked example
// following this comment).
//
// 3. handleTCPPortForward() performs the port forward dial and relaying.
//
// a. Dial the target, using the dial timeout remaining after queue and blocking
// time is deducted.
//
// b. If the dial fails, call abortedTCPPortForward() to decrement
// concurrentDialingPortForwardCount, freeing up a dial slot.
//
// c. If the dial succeeds, call establishedPortForward(), which decrements
// concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
// the "established" port forward count.
//
// d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
// concurrentPortForwardCount, the number of _established_ TCP port forwards.
// If the limit is exceeded, the LRU established TCP port forward is closed and
// the newly established TCP port forward proceeds. This LRU logic allows some
// dangling resource consumption (e.g., TIME_WAIT) while providing a better
// experience for clients.
//
// e. Relay data.
//
// f. Call closedPortForward() which decrements concurrentPortForwardCount and
// records bytes transferred.
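// Worked example of the dial timeout deduction described above, with an
// assumed 10s dial timeout: a request that waits 4s in the queue has
//
//	remaining := 10*time.Second - monotime.Since(enqueueTime) // 6s
//
// left to dial; if it then blocks 6s or more at the dial limit,
// remainingDialTimeout reaches 0 and the port forward is rejected
// without dialing.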
// Start the TCP port forward manager
type newTCPPortForward struct {
enqueueTime monotime.Time
hostToConnect string
portToConnect int
newChannel ssh.NewChannel
}
// The queue size is set to the traffic rules (MaxTCPPortForwardCount +
// MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
// limits per client; when that value is not set, a default is used.
// A limitation: this queue size is set once and doesn't change, for this client,
// when traffic rules are reloaded.
queueSize := sshClient.getTCPPortForwardQueueSize()
if queueSize == 0 {
queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
}
newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
for newPortForward := range newTCPPortForwards {
remainingDialTimeout :=
time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
monotime.Since(newPortForward.enqueueTime)
if remainingDialTimeout <= 0 {
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(
newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out in queue")
continue
}
// Reserve a TCP dialing slot.
//
// TOCTOU note: important to increment counts _before_ checking limits; otherwise,
// the client could potentially consume excess resources by initiating many port
// forwards concurrently.
sshClient.dialingTCPPortForward()
// When max dials are in progress, wait up to remainingDialTimeout for dialing
// to become available. This blocks all dequeuing.
if sshClient.isTCPDialingPortForwardLimitExceeded() {
blockStartTime := monotime.Now()
ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
<-ctx.Done()
sshClient.setTCPPortForwardDialingAvailableSignal(nil)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
remainingDialTimeout -= monotime.Since(blockStartTime)
}
if remainingDialTimeout <= 0 {
// Release the dialing slot here since handleTCPChannel() won't be called.
sshClient.abortedTCPPortForward()
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(
newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out before dialing")
continue
}
// Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
// handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
// will deal with remainingDialTimeout <= 0.
waitGroup.Add(1)
go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
defer waitGroup.Done()
sshClient.handleTCPChannel(
remainingDialTimeout,
newPortForward.hostToConnect,
newPortForward.portToConnect,
newPortForward.newChannel)
}(remainingDialTimeout, newPortForward)
}
}()
// Handle new channel (port forward) requests from the client.
//
// packet tunnel channels are handled by the packet tunnel server
// component. Each client may have at most one packet tunnel channel.
//
// udpgw client connections are dispatched immediately (clients use this for
// DNS, so it's essential to not block; and only one udpgw connection is
// retained at a time).
//
// All other TCP port forwards are dispatched via the TCP port forward
// manager queue.
for newChannel := range channels {
if newChannel.ChannelType() == protocol.PACKET_TUNNEL_CHANNEL_TYPE {
if !sshClient.sshServer.support.Config.RunPacketTunnel {
sshClient.rejectNewChannel(
newChannel, ssh.Prohibited, "unsupported packet tunnel channel type")
continue
}
// Accept this channel immediately. This channel will replace any
// previously existing packet tunnel channel for this client.
packetTunnelChannel, requests, err := newChannel.Accept()
if err != nil {
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
continue
}
go ssh.DiscardRequests(requests)
sshClient.setPacketTunnelChannel(packetTunnelChannel)
// PacketTunnelServer will run the client's packet tunnel. If necessary, ClientConnected
// will stop packet tunnel workers for any previous packet tunnel channel.
checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
return sshClient.isPortForwardPermitted(portForwardTypeTCP, false, upstreamIPAddress, port)
}
checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
}
flowActivityUpdaterMaker := func(
upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
var updaters []tun.FlowActivityUpdater
oslUpdater := sshClient.newClientSeedPortForward(upstreamIPAddress)
if oslUpdater != nil {
updaters = append(updaters, oslUpdater)
}
return updaters
}
err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
sshClient.sessionID,
packetTunnelChannel,
checkAllowedTCPPortFunc,
checkAllowedUDPPortFunc,
flowActivityUpdaterMaker)
if err != nil {
log.WithContextFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
sshClient.setPacketTunnelChannel(nil)
}
continue
}
if newChannel.ChannelType() != "direct-tcpip" {
sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
continue
}
// http://tools.ietf.org/html/rfc4254#section-7.2
var directTcpipExtraData struct {
HostToConnect string
PortToConnect uint32
OriginatorIPAddress string
OriginatorPort uint32
}
err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
if err != nil {
sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
continue
}
// Intercept TCP port forwards to a specified udpgw server and handle directly.
// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
if isUDPChannel {
// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
// own worker goroutine.
waitGroup.Add(1)
go func(channel ssh.NewChannel) {
defer waitGroup.Done()
sshClient.handleUDPChannel(channel)
}(newChannel)
} else {
// Dispatch via TCP port forward manager. When the queue is full, the channel
// is immediately rejected.
tcpPortForward := &newTCPPortForward{
enqueueTime: monotime.Now(),
hostToConnect: directTcpipExtraData.HostToConnect,
portToConnect: int(directTcpipExtraData.PortToConnect),
newChannel: newChannel,
}
select {
case newTCPPortForwards <- tcpPortForward:
default:
sshClient.updateQualityMetricsWithRejectedDialingLimit()
sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "TCP port forward dial queue full")
}
}
}
// The channel loop is interrupted by a client
// disconnect or by calling sshClient.stop().
// Stop the TCP port forward manager
close(newTCPPortForwards)
// Stop all other worker goroutines
sshClient.stopRunning()
if sshClient.sshServer.support.Config.RunPacketTunnel {
// PacketTunnelServer.ClientDisconnected stops packet tunnel workers.
sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
sshClient.sessionID)
}
waitGroup.Wait()
}
// setPacketTunnelChannel sets the single packet tunnel channel
// for this sshClient. Any existing packet tunnel channel is
// closed.
func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
sshClient.Lock()
if sshClient.packetTunnelChannel != nil {
sshClient.packetTunnelChannel.Close()
}
sshClient.packetTunnelChannel = channel
sshClient.Unlock()
}
// setUDPChannel sets the single UDP channel for this sshClient.
// Each sshClient may have only one concurrent UDP channel. Each
// UDP channel multiplexes many UDP port forwards via the udpgw
// protocol. Any existing UDP channel is closed.
func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
sshClient.Lock()
if sshClient.udpChannel != nil {
sshClient.udpChannel.Close()
}
sshClient.udpChannel = channel
sshClient.Unlock()
}
func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
// Note: reporting duration based on the last confirmed data transfer,
// which is reads for sshClient.activityConn.GetActiveDuration(), and not
// on connection closing, is important for protocols such as meek. For
// meek, the connection remains open until the HTTP session expires,
// which may be some time after the tunnel has closed. (The meek
// protocol has no allowance for signalling payload EOF, and even if
// it did the client may not have the opportunity to send a final
// request with an EOF flag set.)
sshClient.Lock()
logFields := getRequestLogFields(
sshClient.sshServer.support,
"server_tunnel",
sshClient.geoIPData,
sshClient.handshakeState.apiParams,
baseRequestParams)
logFields["handshake_completed"] = sshClient.handshakeState.completed
logFields["start_time"] = sshClient.activityConn.GetStartTime()
logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
// Merge in additional metrics from the optional metrics source
if additionalMetrics != nil {
for name, value := range additionalMetrics {
// Don't overwrite any basic fields
if logFields[name] == nil {
logFields[name] = value
}
}
}
sshClient.Unlock()
log.LogRawFieldsWithTimestamp(logFields)
}
func (sshClient *sshClient) runOSLSender() {
for {
// Await a signal that there are SLOKs to send
// TODO: use reflect.SelectCase, and optionally await timer here?
select {
case <-sshClient.signalIssueSLOKs:
case <-sshClient.runCtx.Done():
return
}
retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
for {
err := sshClient.sendOSLRequest()
if err == nil {
break
}
log.WithContextFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
// If the request failed, retry after a delay (with exponential backoff)
// or when signaled that there are additional SLOKs to send
retryTimer := time.NewTimer(retryDelay)
select {
case <-retryTimer.C:
case <-sshClient.signalIssueSLOKs:
case <-sshClient.runCtx.Done():
retryTimer.Stop()
return
}
retryTimer.Stop()
retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
}
}
}
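// Backoff sketch: with the constants at the top of this file, the retry
// delay in runOSLSender grows as 30s, 1m, 2m, 4m, ...:
//
//	delay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
//	for i := 0; i < 4; i++ {
//		fmt.Println(delay) // 30s, 1m0s, 2m0s, 4m0s
//		delay *= SSH_SEND_OSL_RETRY_FACTOR
//	}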
// sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
// generate a payload, and send an OSL request to the client when
// there are new SLOKs in the payload.
func (sshClient *sshClient) sendOSLRequest() error {
seedPayload := sshClient.getOSLSeedPayload()
// Don't send when there are no SLOKs. This will happen when signalIssueSLOKs
// is received but no new SLOKs are issued.
if len(seedPayload.SLOKs) == 0 {
return nil
}
oslRequest := protocol.OSLRequest{
SeedPayload: seedPayload,
}
requestPayload, err := json.Marshal(oslRequest)
if err != nil {
return common.ContextError(err)
}
ok, _, err := sshClient.sshConn.SendRequest(
protocol.PSIPHON_API_OSL_REQUEST_NAME,
true,
requestPayload)
if err != nil {
return common.ContextError(err)
}
if !ok {
return common.ContextError(errors.New("client rejected request"))
}
sshClient.clearOSLSeedPayload()
return nil
}
func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, logMessage string) {
// Note: Debug level, as logMessage may contain user traffic destination address information
log.WithContextFields(
LogFields{
"channelType": newChannel.ChannelType(),
"logMessage": logMessage,
"rejectReason": reason.String(),
}).Debug("reject new channel")
// Note: logMessage is internal, for logging only; just the RejectionReason is sent to the client
newChannel.Reject(reason, reason.String())
}
// setHandshakeState records that a client has completed a handshake API request.
// Some parameters from the handshake request may be used in future traffic rule
// selection. Port forwards are disallowed until a handshake is complete. The
// handshake parameters are included in the session summary log recorded in
// sshClient.stop().
func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
sshClient.Lock()
completed := sshClient.handshakeState.completed
if !completed {
sshClient.handshakeState = state
}
sshClient.Unlock()
// Client must only perform one handshake
if completed {
return common.ContextError(errors.New("handshake already completed"))
}
sshClient.setTrafficRules()
sshClient.setOSLConfig()
return nil
}
// getHandshaked returns whether the client has completed a handshake API
// request and whether the traffic rules that were selected after the
// handshake immediately exhaust the client.
//
// When the client is immediately exhausted it will be closed; but this
// takes effect asynchronously. The "exhausted" return value is used to
// prevent API requests by clients that will close.
func (sshClient *sshClient) getHandshaked() (bool, bool) {
sshClient.Lock()
defer sshClient.Unlock()
completed := sshClient.handshakeState.completed
exhausted := false
// Notes:
// - "Immediately exhausted" is when CloseAfterExhausted is set and
// either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
// 0, so no bytes would be read or written. This check does not
// examine whether 0 bytes _remain_ in the ThrottledConn.
// - This check is made against the current traffic rules, which
// could have changed in a hot reload since the handshake.
if completed &&
*sshClient.trafficRules.RateLimits.CloseAfterExhausted &&
(*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
*sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
exhausted = true
}
return completed, exhausted
}
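// Example: with traffic rules where CloseAfterExhausted is true and
// ReadUnthrottledBytes is 0, a handshaked client is reported exhausted
// (completed == true, exhausted == true) and will be closed; with both
// unthrottled byte values > 0, exhausted is false.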
// setTrafficRules resets the client's traffic rules based on the latest server config
// and client properties. As sshClient.trafficRules may be reset by a concurrent
// goroutine, trafficRules must only be accessed within the sshClient mutex.
func (sshClient *sshClient) setTrafficRules() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
if sshClient.throttledConn != nil {
// Any existing throttling state is reset.
sshClient.throttledConn.SetLimits(
sshClient.trafficRules.RateLimits.CommonRateLimits())
}
}
// setOSLConfig resets the client's OSL seed state based on the latest OSL config.
// As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
// oslClientSeedState must only be accessed within the sshClient mutex.
func (sshClient *sshClient) setOSLConfig() {
sshClient.Lock()
defer sshClient.Unlock()
propagationChannelID, err := getStringRequestParam(
sshClient.handshakeState.apiParams, "propagation_channel_id")
if err != nil {
// This should not fail as long as the client has sent a valid handshake
return
}
// Use a cached seed state if one is found for the client's
// session ID. This enables resuming progress made in a previous
// tunnel.
// Note: go-cache is already concurrency safe; the additional mutex
// is necessary to guarantee that Get/Delete is atomic; although in
// practice no two concurrent clients should ever supply the same
// session ID.
sshClient.sshServer.oslSessionCacheMutex.Lock()
oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
if found {
sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
sshClient.sshServer.oslSessionCacheMutex.Unlock()
sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
return
}
sshClient.sshServer.oslSessionCacheMutex.Unlock()
// Two limitations when setOSLConfig() is invoked due to an
// OSL config hot reload:
//
// 1. any partial progress towards SLOKs is lost.
//
// 2. all existing osl.ClientSeedPortForwards for existing
// port forwards will not send progress to the new client
// seed state.
sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
sshClient.geoIPData.Country,
propagationChannelID,
sshClient.signalIssueSLOKs)
}
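
// The resume path above has a counterpart in the client stop path (a
// sketch, not the verbatim code): on disconnect, the seed state is parked
// in the TTL cache under the session ID so a reconnecting client can
// resume its partial SLOK progress.
//
//	sshServer.oslSessionCacheMutex.Lock()
//	sshServer.oslSessionCache.Set(
//	    sshClient.sessionID,
//	    sshClient.oslClientSeedState,
//	    cache.DefaultExpiration)
//	sshServer.oslSessionCacheMutex.Unlock()
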
// newClientSeedPortForward will return nil when no seeding is
// associated with the specified ipAddress.
func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
sshClient.Lock()
defer sshClient.Unlock()
// Will not be initialized before handshake.
if sshClient.oslClientSeedState == nil {
return nil
}
return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
}
// getOSLSeedPayload returns a payload containing all seeded SLOKs for
// this client's session.
func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
sshClient.Lock()
defer sshClient.Unlock()
// Will not be initialized before handshake.
if sshClient.oslClientSeedState == nil {
return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
}
return sshClient.oslClientSeedState.GetSeedPayload()
}
func (sshClient *sshClient) clearOSLSeedPayload() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.oslClientSeedState.ClearSeedPayload()
}
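
// Expected delivery pattern (a sketch; sendOSLPayload is a hypothetical
// stand-in for the actual transport): fetch the payload, transmit it, and
// clear it only after a successful send, so undelivered SLOKs are retained
// for retry.
//
//	payload := sshClient.getOSLSeedPayload()
//	if len(payload.SLOKs) > 0 {
//	    if err := sendOSLPayload(sshClient, payload); err == nil {
//	        sshClient.clearOSLSeedPayload()
//	    }
//	}
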
func (sshClient *sshClient) rateLimits() common.RateLimits {
sshClient.Lock()
defer sshClient.Unlock()
return sshClient.trafficRules.RateLimits.CommonRateLimits()
}
func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
sshClient.Lock()
defer sshClient.Unlock()
return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
}
func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
sshClient.Lock()
defer sshClient.Unlock()
return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
}
func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.tcpPortForwardDialingAvailableSignal = signal
}
const (
portForwardTypeTCP = iota
portForwardTypeUDP
portForwardTypeTransparentDNS
)
func (sshClient *sshClient) isPortForwardPermitted(
portForwardType int,
isTransparentDNSForwarding bool,
remoteIP net.IP,
port int) bool {
sshClient.Lock()
defer sshClient.Unlock()
if !sshClient.handshakeState.completed {
return false
}
	// Disallow connections to loopback. This is a failsafe. The server
	// should be run on a host with correctly configured firewall rules.
	// An exception is made in the case of transparent DNS forwarding,
	// where the remoteIP has been rewritten.
if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
return false
}
var allowPorts []int
if portForwardType == portForwardTypeTCP {
allowPorts = sshClient.trafficRules.AllowTCPPorts
} else {
allowPorts = sshClient.trafficRules.AllowUDPPorts
}
	if len(allowPorts) == 0 {
		return true
	}
	// TODO: faster lookup?
	for _, allowPort := range allowPorts {
		if port == allowPort {
			return true
		}
	}
for _, subnet := range sshClient.trafficRules.AllowSubnets {
// Note: ignoring error as config has been validated
_, network, _ := net.ParseCIDR(subnet)
if network.Contains(remoteIP) {
return true
}
}
return false
}
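
// For reference, the traffic rules fields consulted above might be
// configured like this (illustrative JSON values only):
//
//	"AllowTCPPorts": [80, 443],
//	"AllowUDPPorts": [53, 123],
//	"AllowSubnets": ["10.1.0.0/16", "192.0.2.0/24"]
//
// An empty AllowTCPPorts/AllowUDPPorts list permits all ports for that
// protocol; AllowSubnets additionally permits any port to destinations
// within the listed CIDR ranges.
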
func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
sshClient.Lock()
defer sshClient.Unlock()
	state := &sshClient.tcpTrafficState
	max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
	return max > 0 && state.concurrentDialingPortForwardCount >= int64(max)
}
func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
sshClient.Lock()
defer sshClient.Unlock()
return *sshClient.trafficRules.MaxTCPPortForwardCount +
*sshClient.trafficRules.MaxTCPDialingPortForwardCount
}
func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
sshClient.Lock()
defer sshClient.Unlock()
return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
}
func (sshClient *sshClient) dialingTCPPortForward() {
sshClient.Lock()
defer sshClient.Unlock()
state := &sshClient.tcpTrafficState
state.concurrentDialingPortForwardCount += 1
if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
}
}
func (sshClient *sshClient) abortedTCPPortForward() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
}
func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
sshClient.Lock()
defer sshClient.Unlock()
// Check if at port forward limit. The subsequent counter
// changes must be atomic with the limit check to ensure
// the counter never exceeds the limit in the case of
// concurrent allocations.
var max int
var state *trafficState
if portForwardType == portForwardTypeTCP {
max = *sshClient.trafficRules.MaxTCPPortForwardCount
state = &sshClient.tcpTrafficState
} else {
max = *sshClient.trafficRules.MaxUDPPortForwardCount
state = &sshClient.udpTrafficState
}
if max > 0 && state.concurrentPortForwardCount >= int64(max) {
return false
}
// Update port forward counters.
if portForwardType == portForwardTypeTCP {
		// Assumes the TCP port forward previously called dialingTCPPortForward
state.concurrentDialingPortForwardCount -= 1
if sshClient.tcpPortForwardDialingAvailableSignal != nil {
max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
sshClient.tcpPortForwardDialingAvailableSignal()
}
}
}
state.concurrentPortForwardCount += 1
if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
}
state.totalPortForwardCount += 1
return true
}
// establishedPortForward increments the concurrent port
// forward counter. closedPortForward decrements it, so it
// must always be called for each establishedPortForward
// call.
//
// When at the limit of established port forwards, the LRU
// existing port forward is closed to make way for the newly
// established one. There can be a minor delay as, in addition
// to calling Close() on the port forward net.Conn,
// establishedPortForward waits for the LRU's closedPortForward()
// call which will decrement the concurrent counter. This
// ensures all resources associated with the LRU (socket,
// goroutine) are released or will very soon be released before
// proceeding.
func (sshClient *sshClient) establishedPortForward(
portForwardType int, portForwardLRU *common.LRUConns) {
// Do not lock sshClient here.
var state *trafficState
if portForwardType == portForwardTypeTCP {
state = &sshClient.tcpTrafficState
} else {
state = &sshClient.udpTrafficState
}
// When the maximum number of port forwards is already
// established, close the LRU. CloseOldest will call
// Close on the port forward net.Conn. Both TCP and
// UDP port forwards have handler goroutines that may
// be blocked calling Read on the net.Conn. Close will
// eventually interrupt the Read and cause the handlers
// to exit, but not immediately. So the following logic
	// waits for an LRU handler to be interrupted and signal
// availability.
//
// Notes:
//
// - the port forward limit can change via a traffic
// rules hot reload; the condition variable handles
// this case whereas a channel-based semaphore would
// not.
//
	// - if a number of goroutines exceeding the total limit
	//   arrive here all concurrently, some CloseOldest() calls
	//   will have no effect, as there can be fewer existing port
	//   forwards than new ones. In this case, the new port
	//   forward will be delayed. This is highly unlikely in
	//   practice, since UDP calls to establishedPortForward are
	//   serialized and TCP calls are limited by the dial
	//   queue/count.
if !sshClient.allocatePortForward(portForwardType) {
portForwardLRU.CloseOldest()
log.WithContext().Debug("closed LRU port forward")
state.availablePortForwardCond.L.Lock()
for !sshClient.allocatePortForward(portForwardType) {
state.availablePortForwardCond.Wait()
}
state.availablePortForwardCond.L.Unlock()
}
}
func (sshClient *sshClient) closedPortForward(
portForwardType int, bytesUp, bytesDown int64) {
sshClient.Lock()
var state *trafficState
if portForwardType == portForwardTypeTCP {
state = &sshClient.tcpTrafficState
} else {
state = &sshClient.udpTrafficState
}
state.concurrentPortForwardCount -= 1
state.bytesUp += bytesUp
state.bytesDown += bytesDown
sshClient.Unlock()
// Signal any goroutine waiting in establishedPortForward
// that an established port forward slot is available.
state.availablePortForwardCond.Signal()
}
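
// Counter lifecycle for a single TCP port forward, tying the methods above
// together (a documentation sketch mirroring handleTCPChannel below):
//
//	sshClient.dialingTCPPortForward()      // +1 dialing, track peak
//	// ... resolve and dial ...
//	// on failure:
//	sshClient.abortedTCPPortForward()      // -1 dialing
//	// on success:
//	sshClient.establishedPortForward(...)  // -1 dialing, +1 established; may close the LRU
//	defer sshClient.closedPortForward(...) // -1 established, signals waiters
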
func (sshClient *sshClient) updateQualityMetricsWithDialResult(
tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
sshClient.Lock()
defer sshClient.Unlock()
if tcpPortForwardDialSuccess {
sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
} else {
sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
}
}
func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
sshClient.Lock()
defer sshClient.Unlock()
sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
}
func (sshClient *sshClient) handleTCPChannel(
remainingDialTimeout time.Duration,
hostToConnect string,
portToConnect int,
newChannel ssh.NewChannel) {
// Assumptions:
// - sshClient.dialingTCPPortForward() has been called
// - remainingDialTimeout > 0
established := false
defer func() {
if !established {
sshClient.abortedTCPPortForward()
}
}()
// Transparently redirect web API request connections.
isWebServerPortForward := false
config := sshClient.sshServer.support.Config
if config.WebServerPortForwardAddress != "" {
destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
if destination == config.WebServerPortForwardAddress {
isWebServerPortForward = true
if config.WebServerPortForwardRedirectAddress != "" {
// Note: redirect format is validated when config is loaded
host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
port, _ := strconv.Atoi(portStr)
hostToConnect = host
portToConnect = port
}
}
}
// Dial the remote address.
//
// Hostname resolution is performed explicitly, as a separate step, as the target IP
// address is used for traffic rules (AllowSubnets) and OSL seed progress.
//
// Contexts are used for cancellation (via sshClient.runCtx, which is cancelled
// when the client is stopping) and timeouts.
dialStartTime := monotime.Now()
log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
// TODO: shuffle list to try other IPs?
// TODO: IPv6 support
var IP net.IP
for _, ip := range IPs {
if ip.IP.To4() != nil {
IP = ip.IP
break
}
}
if err == nil && IP == nil {
err = errors.New("no IP address")
}
resolveElapsedTime := monotime.Since(dialStartTime)
if err != nil {
// Record a port forward failure
		sshClient.updateQualityMetricsWithDialResult(false, resolveElapsedTime)
sshClient.rejectNewChannel(
newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
return
}
remainingDialTimeout -= resolveElapsedTime
if remainingDialTimeout <= 0 {
sshClient.rejectNewChannel(
newChannel, ssh.Prohibited, "TCP port forward timed out resolving")
return
}
// Enforce traffic rules, using the resolved IP address.
if !isWebServerPortForward &&
!sshClient.isPortForwardPermitted(
portForwardTypeTCP,
false,
IP,
portToConnect) {
// Note: not recording a port forward failure in this case
sshClient.rejectNewChannel(
newChannel, ssh.Prohibited, "port forward not permitted")
return
}
// TCP dial.
remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
ctx, cancelCtx = context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
// Record port forward success or failure
sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
if err != nil {
// Monitor for low resource error conditions
sshClient.sshServer.monitorPortForwardDialError(err)
sshClient.rejectNewChannel(
newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
return
}
// The upstream TCP port forward connection has been established. Schedule
// some cleanup and notify the SSH client that the channel is accepted.
defer fwdConn.Close()
fwdChannel, requests, err := newChannel.Accept()
if err != nil {
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
return
}
go ssh.DiscardRequests(requests)
defer fwdChannel.Close()
// Release the dialing slot and acquire an established slot.
//
// establishedPortForward increments the concurrent TCP port
// forward counter and closes the LRU existing TCP port forward
// when already at the limit.
//
// Known limitations:
//
// - Closed LRU TCP sockets will enter the TIME_WAIT state,
// continuing to consume some resources.
sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
// "established = true" cancels the deferred abortedTCPPortForward()
established = true
// TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
var bytesUp, bytesDown int64
defer func() {
sshClient.closedPortForward(
portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
}()
lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
defer lruEntry.Remove()
// ActivityMonitoredConn monitors the TCP port forward I/O and updates
// its LRU status. ActivityMonitoredConn also times out I/O on the port
// forward if both reads and writes have been idle for the specified
// duration.
	// Ensure a nil interface value when newClientSeedPortForward returns
	// nil; assigning a nil concrete pointer directly to the interface
	// variable would yield a non-nil interface.
var updater common.ActivityUpdater
seedUpdater := sshClient.newClientSeedPortForward(IP)
if seedUpdater != nil {
updater = seedUpdater
}
fwdConn, err = common.NewActivityMonitoredConn(
fwdConn,
sshClient.idleTCPPortForwardTimeout(),
true,
updater,
lruEntry)
if err != nil {
log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
return
}
// Relay channel to forwarded connection.
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
// TODO: relay errors to fwdChannel.Stderr()?
relayWaitGroup := new(sync.WaitGroup)
relayWaitGroup.Add(1)
go func() {
defer relayWaitGroup.Done()
// io.Copy allocates a 32K temporary buffer, and each port forward relay uses
// two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
// overall memory footprint.
bytes, err := io.CopyBuffer(
fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
atomic.AddInt64(&bytesDown, bytes)
if err != nil && err != io.EOF {
// Debug since errors such as "connection reset by peer" occur during normal operation
log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
}
		// Interrupt upstream io.Copy when downstream is shutting down.
		// TODO: this is done to quickly clean up the port forward when
		// fwdConn has a read timeout, but is it clean? Upstream data may
		// still be flowing.
fwdChannel.Close()
}()
bytes, err := io.CopyBuffer(
fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
atomic.AddInt64(&bytesUp, bytes)
if err != nil && err != io.EOF {
log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
}
// Shutdown special case: fwdChannel will be closed and return EOF when
// the SSH connection is closed, but we need to explicitly close fwdConn
// to interrupt the downstream io.Copy, which may be blocked on a
// fwdConn.Read().
fwdConn.Close()
relayWaitGroup.Wait()
log.WithContextFields(
LogFields{
"remoteAddr": remoteAddr,
"bytesUp": atomic.LoadInt64(&bytesUp),
"bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
}
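
// As a possible further reduction of the io.CopyBuffer allocations noted
// above (a sketch only, not something this file implements), the copy
// buffers could be pooled across port forwards with sync.Pool:
//
//	var copyBufferPool = sync.Pool{
//	    New: func() interface{} {
//	        return make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)
//	    },
//	}
//
//	buffer := copyBufferPool.Get().([]byte)
//	bytes, err := io.CopyBuffer(fwdChannel, fwdConn, buffer)
//	copyBufferPool.Put(buffer)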