|
@@ -26,7 +26,6 @@ import (
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"io"
|
|
"io"
|
|
|
"net"
|
|
"net"
|
|
|
- "runtime"
|
|
|
|
|
"sync"
|
|
"sync"
|
|
|
"sync/atomic"
|
|
"sync/atomic"
|
|
|
"time"
|
|
"time"
|
|
@@ -35,46 +34,66 @@ import (
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-// RunSSHServer runs an SSH server, the core tunneling component of the Psiphon
|
|
|
|
|
-// server. The SSH server runs a selection of listeners that handle connections
|
|
|
|
|
-// using various, optional obfuscation protocols layered on top of SSH.
|
|
|
|
|
-// (Currently, just Obfuscated SSH).
|
|
|
|
|
-//
|
|
|
|
|
-// RunSSHServer listens on the designated port(s) and spawns new goroutines to handle
|
|
|
|
|
-// each client connection. It halts when shutdownBroadcast is signaled. A list of active
|
|
|
|
|
-// clients is maintained, and when halting all clients are first shutdown.
|
|
|
|
|
-//
|
|
|
|
|
-// Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
|
|
|
|
|
-// authentication, and then looping on client new channel requests. At this time, only
|
|
|
|
|
-// "direct-tcpip" channels, dynamic port fowards, are expected and supported.
|
|
|
|
|
-//
|
|
|
|
|
-// A new goroutine is spawned to handle each port forward for each client. Each port
|
|
|
|
|
-// forward tracks its bytes transferred. Overall per-client stats for connection duration,
|
|
|
|
|
-// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
|
|
|
|
|
-// client shuts down.
|
|
|
|
|
-func RunSSHServer(
|
|
|
|
|
- config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
|
+// TunnelServer is the main server that accepts Psiphon client
|
|
|
|
|
+// connections, via various obfuscation protocols, and provides
|
|
|
|
|
+// port forwarding (TCP and UDP) services to the Psiphon client.
|
|
|
|
|
+// At its core, TunnelServer is an SSH server. SSH is the base
|
|
|
|
|
+// protocol that provides port forward multiplexing, and transport
|
|
|
|
|
+// security. Layered on top of SSH, optionally, is Obfuscated SSH
|
|
|
|
|
+// and meek protocols, which provide further circumvention
|
|
|
|
|
+// capabilities.
|
|
|
|
|
+type TunnelServer struct {
|
|
|
|
|
+ config *Config
|
|
|
|
|
+ runWaitGroup *sync.WaitGroup
|
|
|
|
|
+ listenerError chan error
|
|
|
|
|
+ shutdownBroadcast <-chan struct{}
|
|
|
|
|
+ sshServer *sshServer
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return psiphon.ContextError(err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+// NewTunnelServer initializes a new tunnel server.
|
|
|
|
|
+func NewTunnelServer(
|
|
|
|
|
+ config *Config, shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
|
|
|
|
|
|
|
|
- // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
|
|
|
|
|
- signer, err := ssh.NewSignerFromKey(privateKey)
|
|
|
|
|
|
|
+ sshServer, err := newSSHServer(config, shutdownBroadcast)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return psiphon.ContextError(err)
|
|
|
|
|
|
|
+ return nil, psiphon.ContextError(err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- sshServer := &sshServer{
|
|
|
|
|
|
|
+ return &TunnelServer{
|
|
|
config: config,
|
|
config: config,
|
|
|
runWaitGroup: new(sync.WaitGroup),
|
|
runWaitGroup: new(sync.WaitGroup),
|
|
|
listenerError: make(chan error),
|
|
listenerError: make(chan error),
|
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
|
- sshHostKey: signer,
|
|
|
|
|
- nextClientID: 1,
|
|
|
|
|
- clients: make(map[sshClientID]*sshClient),
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ sshServer: sshServer,
|
|
|
|
|
+ }, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// GetLoadStats returns load stats for the tunnel server. The stats are
|
|
|
|
|
+// broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
|
|
|
|
|
+// include current connected client count, total number of current port
|
|
|
|
|
+// forwards.
|
|
|
|
|
+func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
|
|
|
|
|
+ return server.sshServer.getLoadStats()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Run runs the tunnel server; this function blocks while running a selection of
|
|
|
|
|
+// listeners that handle connection using various obfuscation protocols.
|
|
|
|
|
+//
|
|
|
|
|
+// Run listens on each designated tunnel port and spawns new goroutines to handle
|
|
|
|
|
+// each client connection. It halts when shutdownBroadcast is signaled. A list of active
|
|
|
|
|
+// clients is maintained, and when halting all clients are cleanly shutdown.
|
|
|
|
|
+//
|
|
|
|
|
+// Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
|
|
|
|
|
+// authentication, and then looping on client new channel requests. "direct-tcpip"
|
|
|
|
|
+// channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
|
|
|
|
|
+// config parameter is configured, UDP port forwards over a TCP stream, following
|
|
|
|
|
+// the udpgw protocol, are handled.
|
|
|
|
|
+//
|
|
|
|
|
+// A new goroutine is spawned to handle each port forward for each client. Each port
|
|
|
|
|
+// forward tracks its bytes transferred. Overall per-client stats for connection duration,
|
|
|
|
|
+// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
|
|
|
|
|
+// client shuts down.
|
|
|
|
|
+func (server *TunnelServer) Run() error {
|
|
|
|
|
|
|
|
type sshListener struct {
|
|
type sshListener struct {
|
|
|
net.Listener
|
|
net.Listener
|
|
@@ -82,78 +101,75 @@ func RunSSHServer(
|
|
|
tunnelProtocol string
|
|
tunnelProtocol string
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // First bind all listeners; once all are successful,
|
|
|
|
|
+ // start accepting connections on each.
|
|
|
|
|
+
|
|
|
var listeners []*sshListener
|
|
var listeners []*sshListener
|
|
|
|
|
|
|
|
- if config.RunSSHServer() {
|
|
|
|
|
- listeners = append(listeners, &sshListener{
|
|
|
|
|
- localAddress: fmt.Sprintf(
|
|
|
|
|
- "%s:%d", config.ServerIPAddress, config.SSHServerPort),
|
|
|
|
|
- tunnelProtocol: psiphon.TUNNEL_PROTOCOL_SSH,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for tunnelProtocol, listenPort := range server.config.TunnelProtocolPorts {
|
|
|
|
|
|
|
|
- if config.RunObfuscatedSSHServer() {
|
|
|
|
|
- listeners = append(listeners, &sshListener{
|
|
|
|
|
- localAddress: fmt.Sprintf(
|
|
|
|
|
- "%s:%d", config.ServerIPAddress, config.ObfuscatedSSHServerPort),
|
|
|
|
|
- tunnelProtocol: psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- // TODO: add additional protocol listeners here (e.g, meek)
|
|
|
|
|
|
|
+ localAddress := fmt.Sprintf(
|
|
|
|
|
+ "%s:%d", server.config.ServerIPAddress, listenPort)
|
|
|
|
|
|
|
|
- for i, listener := range listeners {
|
|
|
|
|
- var err error
|
|
|
|
|
- listener.Listener, err = net.Listen("tcp", listener.localAddress)
|
|
|
|
|
|
|
+ listener, err := net.Listen("tcp", localAddress)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- for j := 0; j < i; j++ {
|
|
|
|
|
- listener.Listener.Close()
|
|
|
|
|
|
|
+ for _, existingListener := range listeners {
|
|
|
|
|
+ existingListener.Listener.Close()
|
|
|
}
|
|
}
|
|
|
return psiphon.ContextError(err)
|
|
return psiphon.ContextError(err)
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
log.WithContextFields(
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
LogFields{
|
|
|
- "localAddress": listener.localAddress,
|
|
|
|
|
- "tunnelProtocol": listener.tunnelProtocol,
|
|
|
|
|
|
|
+ "localAddress": localAddress,
|
|
|
|
|
+ "tunnelProtocol": tunnelProtocol,
|
|
|
}).Info("listening")
|
|
}).Info("listening")
|
|
|
|
|
+
|
|
|
|
|
+ listeners = append(
|
|
|
|
|
+ listeners,
|
|
|
|
|
+ &sshListener{
|
|
|
|
|
+ Listener: listener,
|
|
|
|
|
+ localAddress: localAddress,
|
|
|
|
|
+ tunnelProtocol: tunnelProtocol,
|
|
|
|
|
+ })
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for _, listener := range listeners {
|
|
for _, listener := range listeners {
|
|
|
- sshServer.runWaitGroup.Add(1)
|
|
|
|
|
|
|
+ server.runWaitGroup.Add(1)
|
|
|
go func(listener *sshListener) {
|
|
go func(listener *sshListener) {
|
|
|
- defer sshServer.runWaitGroup.Done()
|
|
|
|
|
|
|
+ defer server.runWaitGroup.Done()
|
|
|
|
|
|
|
|
- sshServer.runListener(
|
|
|
|
|
- listener.Listener, listener.tunnelProtocol)
|
|
|
|
|
|
|
+ log.WithContextFields(
|
|
|
|
|
+ LogFields{
|
|
|
|
|
+ "localAddress": listener.localAddress,
|
|
|
|
|
+ "tunnelProtocol": listener.tunnelProtocol,
|
|
|
|
|
+ }).Info("running")
|
|
|
|
|
+
|
|
|
|
|
+ server.sshServer.runListener(
|
|
|
|
|
+ listener.Listener,
|
|
|
|
|
+ server.listenerError,
|
|
|
|
|
+ listener.tunnelProtocol)
|
|
|
|
|
|
|
|
log.WithContextFields(
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
LogFields{
|
|
|
"localAddress": listener.localAddress,
|
|
"localAddress": listener.localAddress,
|
|
|
"tunnelProtocol": listener.tunnelProtocol,
|
|
"tunnelProtocol": listener.tunnelProtocol,
|
|
|
- }).Info("stopping")
|
|
|
|
|
|
|
+ }).Info("stopped")
|
|
|
|
|
|
|
|
}(listener)
|
|
}(listener)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if config.RunLoadMonitor() {
|
|
|
|
|
- sshServer.runWaitGroup.Add(1)
|
|
|
|
|
- go func() {
|
|
|
|
|
- defer sshServer.runWaitGroup.Done()
|
|
|
|
|
- sshServer.runLoadMonitor()
|
|
|
|
|
- }()
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- err = nil
|
|
|
|
|
|
|
+ var err error
|
|
|
select {
|
|
select {
|
|
|
- case <-sshServer.shutdownBroadcast:
|
|
|
|
|
- case err = <-sshServer.listenerError:
|
|
|
|
|
|
|
+ case <-server.shutdownBroadcast:
|
|
|
|
|
+ case err = <-server.listenerError:
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for _, listener := range listeners {
|
|
for _, listener := range listeners {
|
|
|
listener.Close()
|
|
listener.Close()
|
|
|
}
|
|
}
|
|
|
- sshServer.stopClients()
|
|
|
|
|
- sshServer.runWaitGroup.Wait()
|
|
|
|
|
|
|
+ server.sshServer.stopClients()
|
|
|
|
|
+ server.runWaitGroup.Wait()
|
|
|
|
|
|
|
|
log.WithContext().Info("stopped")
|
|
log.WithContext().Info("stopped")
|
|
|
|
|
|
|
@@ -164,8 +180,6 @@ type sshClientID uint64
|
|
|
|
|
|
|
|
type sshServer struct {
|
|
type sshServer struct {
|
|
|
config *Config
|
|
config *Config
|
|
|
- runWaitGroup *sync.WaitGroup
|
|
|
|
|
- listenerError chan error
|
|
|
|
|
shutdownBroadcast <-chan struct{}
|
|
shutdownBroadcast <-chan struct{}
|
|
|
sshHostKey ssh.Signer
|
|
sshHostKey ssh.Signer
|
|
|
nextClientID sshClientID
|
|
nextClientID sshClientID
|
|
@@ -174,69 +188,96 @@ type sshServer struct {
|
|
|
clients map[sshClientID]*sshClient
|
|
clients map[sshClientID]*sshClient
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func newSSHServer(
|
|
|
|
|
+ config *Config,
|
|
|
|
|
+ shutdownBroadcast <-chan struct{}) (*sshServer, error) {
|
|
|
|
|
+
|
|
|
|
|
+ privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, psiphon.ContextError(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
|
|
|
|
|
+ signer, err := ssh.NewSignerFromKey(privateKey)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, psiphon.ContextError(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return &sshServer{
|
|
|
|
|
+ config: config,
|
|
|
|
|
+ shutdownBroadcast: shutdownBroadcast,
|
|
|
|
|
+ sshHostKey: signer,
|
|
|
|
|
+ nextClientID: 1,
|
|
|
|
|
+ clients: make(map[sshClientID]*sshClient),
|
|
|
|
|
+ }, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// runListener is intended to run an a goroutine; it blocks
|
|
|
|
|
+// running a particular listener. If an unrecoverable error
|
|
|
|
|
+// occurs, it will send the error to the listenerError channel.
|
|
|
func (sshServer *sshServer) runListener(
|
|
func (sshServer *sshServer) runListener(
|
|
|
- listener net.Listener, tunnelProtocol string) {
|
|
|
|
|
|
|
+ listener net.Listener,
|
|
|
|
|
+ listenerError chan<- error,
|
|
|
|
|
+ tunnelProtocol string) {
|
|
|
|
|
|
|
|
- for {
|
|
|
|
|
- conn, err := listener.Accept()
|
|
|
|
|
|
|
+ handleClient := func(clientConn net.Conn) {
|
|
|
|
|
+ // process each client connection concurrently
|
|
|
|
|
+ go sshServer.handleClient(tunnelProtocol, clientConn)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if err == nil && tunnelProtocol == psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
|
|
|
|
|
- conn, err = psiphon.NewObfuscatedSshConn(
|
|
|
|
|
- psiphon.OBFUSCATION_CONN_MODE_SERVER,
|
|
|
|
|
- conn,
|
|
|
|
|
- sshServer.config.ObfuscatedSSHKey)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Note: when exiting due to a unrecoverable error, be sure
|
|
|
|
|
+ // to try to send the error to listenerError so that the outer
|
|
|
|
|
+ // TunnelServer.Run will properly shut down instead of remaining
|
|
|
|
|
+ // running.
|
|
|
|
|
|
|
|
- select {
|
|
|
|
|
- case <-sshServer.shutdownBroadcast:
|
|
|
|
|
- if err == nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- }
|
|
|
|
|
- return
|
|
|
|
|
- default:
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if psiphon.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
|
|
|
|
|
+ psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
|
|
|
|
|
|
|
|
|
|
+ meekServer, err := NewMeekServer(
|
|
|
|
|
+ sshServer.config,
|
|
|
|
|
+ listener,
|
|
|
|
|
+ psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
|
|
|
|
|
+ handleClient,
|
|
|
|
|
+ sshServer.shutdownBroadcast)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- if e, ok := err.(net.Error); ok && e.Temporary() {
|
|
|
|
|
- log.WithContextFields(LogFields{"error": err}).Error("accept failed")
|
|
|
|
|
- // Temporary error, keep running
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
select {
|
|
select {
|
|
|
- case sshServer.listenerError <- psiphon.ContextError(err):
|
|
|
|
|
|
|
+ case listenerError <- psiphon.ContextError(err):
|
|
|
default:
|
|
default:
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // process each client connection concurrently
|
|
|
|
|
- go sshServer.handleClient(tunnelProtocol, conn)
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
|
|
+ meekServer.Run()
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) runLoadMonitor() {
|
|
|
|
|
- ticker := time.NewTicker(
|
|
|
|
|
- time.Duration(sshServer.config.LoadMonitorPeriodSeconds) * time.Second)
|
|
|
|
|
- defer ticker.Stop()
|
|
|
|
|
- for {
|
|
|
|
|
- select {
|
|
|
|
|
- case <-sshServer.shutdownBroadcast:
|
|
|
|
|
- return
|
|
|
|
|
- case <-ticker.C:
|
|
|
|
|
- var memStats runtime.MemStats
|
|
|
|
|
- runtime.ReadMemStats(&memStats)
|
|
|
|
|
- fields := LogFields{
|
|
|
|
|
- "goroutines": runtime.NumGoroutine(),
|
|
|
|
|
- "memAlloc": memStats.Alloc,
|
|
|
|
|
- "memTotalAlloc": memStats.TotalAlloc,
|
|
|
|
|
- "memSysAlloc": memStats.Sys,
|
|
|
|
|
|
|
+ } else {
|
|
|
|
|
+
|
|
|
|
|
+ for {
|
|
|
|
|
+ conn, err := listener.Accept()
|
|
|
|
|
+
|
|
|
|
|
+ select {
|
|
|
|
|
+ case <-sshServer.shutdownBroadcast:
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ conn.Close()
|
|
|
|
|
+ }
|
|
|
|
|
+ return
|
|
|
|
|
+ default:
|
|
|
}
|
|
}
|
|
|
- for tunnelProtocol, count := range sshServer.countClients() {
|
|
|
|
|
- fields[tunnelProtocol] = count
|
|
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ if e, ok := err.(net.Error); ok && e.Temporary() {
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Error("accept failed")
|
|
|
|
|
+ // Temporary error, keep running
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ select {
|
|
|
|
|
+ case listenerError <- psiphon.ContextError(err):
|
|
|
|
|
+ default:
|
|
|
|
|
+ }
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
- log.WithContextFields(fields).Info("load")
|
|
|
|
|
|
|
+
|
|
|
|
|
+ handleClient(conn)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -270,16 +311,26 @@ func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) countClients() map[string]int {
|
|
|
|
|
|
|
+func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
|
|
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
sshServer.clientsMutex.Lock()
|
|
|
defer sshServer.clientsMutex.Unlock()
|
|
defer sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
|
|
- counts := make(map[string]int)
|
|
|
|
|
|
|
+ loadStats := make(map[string]map[string]int64)
|
|
|
for _, client := range sshServer.clients {
|
|
for _, client := range sshServer.clients {
|
|
|
- counts[client.tunnelProtocol] += 1
|
|
|
|
|
- }
|
|
|
|
|
- return counts
|
|
|
|
|
|
|
+ if loadStats[client.tunnelProtocol] == nil {
|
|
|
|
|
+ loadStats[client.tunnelProtocol] = make(map[string]int64)
|
|
|
|
|
+ }
|
|
|
|
|
+ // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
|
|
|
|
|
+ loadStats[client.tunnelProtocol]["CurrentClients"] += 1
|
|
|
|
|
+ client.Lock()
|
|
|
|
|
+ loadStats[client.tunnelProtocol]["CurrentTCPPortForwards"] += client.tcpTrafficState.concurrentPortForwardCount
|
|
|
|
|
+ loadStats[client.tunnelProtocol]["TotalTCPPortForwards"] += client.tcpTrafficState.totalPortForwardCount
|
|
|
|
|
+ loadStats[client.tunnelProtocol]["CurrentUDPPortForwards"] += client.udpTrafficState.concurrentPortForwardCount
|
|
|
|
|
+ loadStats[client.tunnelProtocol]["TotalUDPPortForwards"] += client.udpTrafficState.totalPortForwardCount
|
|
|
|
|
+ client.Unlock()
|
|
|
|
|
+ }
|
|
|
|
|
+ return loadStats
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
func (sshServer *sshServer) stopClients() {
|
|
@@ -304,22 +355,27 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
geoIPData,
|
|
geoIPData,
|
|
|
sshServer.config.GetTrafficRules(geoIPData.Country))
|
|
sshServer.config.GetTrafficRules(geoIPData.Country))
|
|
|
|
|
|
|
|
- // Wrap the base client connection with an IdleTimeoutConn which will terminate
|
|
|
|
|
- // the connection if no data is received before the deadline. This timeout is
|
|
|
|
|
- // in effect for the entire duration of the SSH connection. Clients must actively
|
|
|
|
|
- // use the connection or send SSH keep alive requests to keep the connection
|
|
|
|
|
- // active.
|
|
|
|
|
|
|
+ // Wrap the base client connection with an ActivityMonitoredConn which will
|
|
|
|
|
+ // terminate the connection if no data is received before the deadline. This
|
|
|
|
|
+ // timeout is in effect for the entire duration of the SSH connection. Clients
|
|
|
|
|
+ // must actively use the connection or send SSH keep alive requests to keep
|
|
|
|
|
+ // the connection active.
|
|
|
|
|
|
|
|
- var conn net.Conn
|
|
|
|
|
-
|
|
|
|
|
- conn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
|
|
|
|
+ clientConn = psiphon.NewActivityMonitoredConn(
|
|
|
|
|
+ clientConn,
|
|
|
|
|
+ SSH_CONNECTION_READ_DEADLINE,
|
|
|
|
|
+ false,
|
|
|
|
|
+ nil)
|
|
|
|
|
|
|
|
// Further wrap the connection in a rate limiting ThrottledConn.
|
|
// Further wrap the connection in a rate limiting ThrottledConn.
|
|
|
|
|
|
|
|
- conn = psiphon.NewThrottledConn(
|
|
|
|
|
- conn,
|
|
|
|
|
- int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
|
|
|
|
|
- int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
|
|
|
|
|
|
|
+ rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
|
|
|
|
|
+ clientConn = psiphon.NewThrottledConn(
|
|
|
|
|
+ clientConn,
|
|
|
|
|
+ rateLimits.DownstreamUnlimitedBytes,
|
|
|
|
|
+ int64(rateLimits.DownstreamBytesPerSecond),
|
|
|
|
|
+ rateLimits.UpstreamUnlimitedBytes,
|
|
|
|
|
+ int64(rateLimits.UpstreamBytesPerSecond))
|
|
|
|
|
|
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
@@ -350,17 +406,30 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
}
|
|
}
|
|
|
sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
|
|
|
|
|
|
- sshConn, channels, requests, err :=
|
|
|
|
|
- ssh.NewServerConn(conn, sshServerConfig)
|
|
|
|
|
|
|
+ result := &sshNewServerConnResult{}
|
|
|
|
|
+
|
|
|
|
|
+ // Wrap the connection in an SSH deobfuscator when required.
|
|
|
|
|
|
|
|
- resultChannel <- &sshNewServerConnResult{
|
|
|
|
|
- conn: conn,
|
|
|
|
|
- sshConn: sshConn,
|
|
|
|
|
- channels: channels,
|
|
|
|
|
- requests: requests,
|
|
|
|
|
- err: err,
|
|
|
|
|
|
|
+ if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
|
|
|
|
|
+ // Note: NewObfuscatedSshConn blocks on network I/O
|
|
|
|
|
+ // TODO: ensure this won't block shutdown
|
|
|
|
|
+ conn, result.err = psiphon.NewObfuscatedSshConn(
|
|
|
|
|
+ psiphon.OBFUSCATION_CONN_MODE_SERVER,
|
|
|
|
|
+ clientConn,
|
|
|
|
|
+ sshServer.config.ObfuscatedSSHKey)
|
|
|
|
|
+ if result.err != nil {
|
|
|
|
|
+ result.err = psiphon.ContextError(result.err)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if result.err == nil {
|
|
|
|
|
+ result.sshConn, result.channels, result.requests, result.err =
|
|
|
|
|
+ ssh.NewServerConn(conn, sshServerConfig)
|
|
|
}
|
|
}
|
|
|
- }(conn)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ resultChannel <- result
|
|
|
|
|
+
|
|
|
|
|
+ }(clientConn)
|
|
|
|
|
|
|
|
var result *sshNewServerConnResult
|
|
var result *sshNewServerConnResult
|
|
|
select {
|
|
select {
|
|
@@ -368,13 +437,16 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
case <-sshServer.shutdownBroadcast:
|
|
case <-sshServer.shutdownBroadcast:
|
|
|
// Close() will interrupt an ongoing handshake
|
|
// Close() will interrupt an ongoing handshake
|
|
|
// TODO: wait for goroutine to exit before returning?
|
|
// TODO: wait for goroutine to exit before returning?
|
|
|
- conn.Close()
|
|
|
|
|
|
|
+ clientConn.Close()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if result.err != nil {
|
|
if result.err != nil {
|
|
|
- conn.Close()
|
|
|
|
|
- log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
|
|
|
|
|
|
|
+ clientConn.Close()
|
|
|
|
|
+ // This is a Debug log due to noise. The handshake often fails due to I/O
|
|
|
|
|
+ // errors as clients frequently interrupt connections in progress when
|
|
|
|
|
+ // client-side load balancing completes a connection to a different server.
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -384,7 +456,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
|
|
|
|
|
clientID, ok := sshServer.registerClient(sshClient)
|
|
clientID, ok := sshServer.registerClient(sshClient)
|
|
|
if !ok {
|
|
if !ok {
|
|
|
- conn.Close()
|
|
|
|
|
|
|
+ clientConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
@@ -393,6 +465,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
go ssh.DiscardRequests(result.requests)
|
|
go ssh.DiscardRequests(result.requests)
|
|
|
|
|
|
|
|
sshClient.handleChannels(result.channels)
|
|
sshClient.handleChannels(result.channels)
|
|
|
|
|
+
|
|
|
|
|
+ // TODO: clientConn.Close()?
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
type sshClient struct {
|
|
type sshClient struct {
|
|
@@ -408,15 +482,16 @@ type sshClient struct {
|
|
|
tcpTrafficState *trafficState
|
|
tcpTrafficState *trafficState
|
|
|
udpTrafficState *trafficState
|
|
udpTrafficState *trafficState
|
|
|
channelHandlerWaitGroup *sync.WaitGroup
|
|
channelHandlerWaitGroup *sync.WaitGroup
|
|
|
|
|
+ tcpPortForwardLRU *psiphon.LRUConns
|
|
|
stopBroadcast chan struct{}
|
|
stopBroadcast chan struct{}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
type trafficState struct {
|
|
type trafficState struct {
|
|
|
bytesUp int64
|
|
bytesUp int64
|
|
|
bytesDown int64
|
|
bytesDown int64
|
|
|
- portForwardCount int64
|
|
|
|
|
concurrentPortForwardCount int64
|
|
concurrentPortForwardCount int64
|
|
|
peakConcurrentPortForwardCount int64
|
|
peakConcurrentPortForwardCount int64
|
|
|
|
|
+ totalPortForwardCount int64
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func newSshClient(
|
|
func newSshClient(
|
|
@@ -430,10 +505,94 @@ func newSshClient(
|
|
|
tcpTrafficState: &trafficState{},
|
|
tcpTrafficState: &trafficState{},
|
|
|
udpTrafficState: &trafficState{},
|
|
udpTrafficState: &trafficState{},
|
|
|
channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
|
|
+ tcpPortForwardLRU: psiphon.NewLRUConns(),
|
|
|
stopBroadcast: make(chan struct{}),
|
|
stopBroadcast: make(chan struct{}),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
+ var sshPasswordPayload struct {
|
|
|
|
|
+ SessionId string `json:"SessionId"`
|
|
|
|
|
+ SshPassword string `json:"SshPassword"`
|
|
|
|
|
+ }
|
|
|
|
|
+ err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ userOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
+ []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
|
|
|
|
|
+
|
|
|
|
|
+ passwordOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
+ []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
|
|
|
|
|
+
|
|
|
|
|
+ if !userOk || !passwordOk {
|
|
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ sshClient.psiphonSessionID = psiphonSessionID
|
|
|
|
|
+ geoIPData := sshClient.geoIPData
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
+
|
|
|
|
|
+ if sshClient.sshServer.config.UseRedis() {
|
|
|
|
|
+ err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.WithContextFields(LogFields{
|
|
|
|
|
+ "psiphonSessionID": psiphonSessionID,
|
|
|
|
|
+ "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
|
|
+ // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return nil, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ if sshClient.sshServer.config.UseFail2Ban() {
|
|
|
|
|
+ clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
|
|
|
|
|
+ if clientIPAddress != "" {
|
|
|
|
|
+ LogFail2Ban(clientIPAddress)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
|
|
|
|
|
+ } else {
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) stop() {
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.sshConn.Close()
|
|
|
|
|
+ sshClient.sshConn.Wait()
|
|
|
|
|
+
|
|
|
|
|
+ close(sshClient.stopBroadcast)
|
|
|
|
|
+ sshClient.channelHandlerWaitGroup.Wait()
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ log.WithContextFields(
|
|
|
|
|
+ LogFields{
|
|
|
|
|
+ "startTime": sshClient.startTime,
|
|
|
|
|
+ "duration": time.Now().Sub(sshClient.startTime),
|
|
|
|
|
+ "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
|
|
+ "country": sshClient.geoIPData.Country,
|
|
|
|
|
+ "city": sshClient.geoIPData.City,
|
|
|
|
|
+ "ISP": sshClient.geoIPData.ISP,
|
|
|
|
|
+ "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
|
|
+ "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
|
|
+ "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
|
|
+ "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
|
|
|
|
|
+ "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
|
|
+ "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
|
|
+ "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
|
|
+ "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
|
|
|
|
|
+ }).Info("tunnel closed")
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
for newChannel := range channels {
|
|
for newChannel := range channels {
|
|
|
|
|
|
|
@@ -478,8 +637,8 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
|
|
|
|
|
|
|
|
// Intercept TCP port forwards to a specified udpgw server and handle directly.
|
|
// Intercept TCP port forwards to a specified udpgw server and handle directly.
|
|
|
// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
|
|
// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
|
|
|
- isUDPChannel := sshClient.sshServer.config.UdpgwServerAddress != "" &&
|
|
|
|
|
- sshClient.sshServer.config.UdpgwServerAddress ==
|
|
|
|
|
|
|
+ isUDPChannel := sshClient.sshServer.config.UDPInterceptUdpgwServerAddress != "" &&
|
|
|
|
|
+ sshClient.sshServer.config.UDPInterceptUdpgwServerAddress ==
|
|
|
fmt.Sprintf("%s:%d",
|
|
fmt.Sprintf("%s:%d",
|
|
|
directTcpipExtraData.HostToConnect,
|
|
directTcpipExtraData.HostToConnect,
|
|
|
directTcpipExtraData.PortToConnect)
|
|
directTcpipExtraData.PortToConnect)
|
|
@@ -496,7 +655,7 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
port int, allowPorts []int, denyPorts []int) bool {
|
|
port int, allowPorts []int, denyPorts []int) bool {
|
|
|
|
|
|
|
|
// TODO: faster lookup?
|
|
// TODO: faster lookup?
|
|
|
- if allowPorts != nil {
|
|
|
|
|
|
|
+ if len(allowPorts) > 0 {
|
|
|
for _, allowPort := range allowPorts {
|
|
for _, allowPort := range allowPorts {
|
|
|
if port == allowPort {
|
|
if port == allowPort {
|
|
|
return true
|
|
return true
|
|
@@ -504,7 +663,7 @@ func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
}
|
|
}
|
|
|
return false
|
|
return false
|
|
|
}
|
|
}
|
|
|
- if denyPorts != nil {
|
|
|
|
|
|
|
+ if len(denyPorts) > 0 {
|
|
|
for _, denyPort := range denyPorts {
|
|
for _, denyPort := range denyPorts {
|
|
|
if port == denyPort {
|
|
if port == denyPort {
|
|
|
return false
|
|
return false
|
|
@@ -520,7 +679,7 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
limitExceeded := false
|
|
limitExceeded := false
|
|
|
if maxPortForwardCount > 0 {
|
|
if maxPortForwardCount > 0 {
|
|
|
sshClient.Lock()
|
|
sshClient.Lock()
|
|
|
- limitExceeded = state.portForwardCount >= int64(maxPortForwardCount)
|
|
|
|
|
|
|
+ limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
|
|
|
sshClient.Unlock()
|
|
sshClient.Unlock()
|
|
|
}
|
|
}
|
|
|
return limitExceeded
|
|
return limitExceeded
|
|
@@ -530,11 +689,11 @@ func (sshClient *sshClient) openedPortForward(
|
|
|
state *trafficState) {
|
|
state *trafficState) {
|
|
|
|
|
|
|
|
sshClient.Lock()
|
|
sshClient.Lock()
|
|
|
- state.portForwardCount += 1
|
|
|
|
|
state.concurrentPortForwardCount += 1
|
|
state.concurrentPortForwardCount += 1
|
|
|
if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
}
|
|
}
|
|
|
|
|
+ state.totalPortForwardCount += 1
|
|
|
sshClient.Unlock()
|
|
sshClient.Unlock()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -565,8 +724,12 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
|
|
var bytesUp, bytesDown int64
|
|
var bytesUp, bytesDown int64
|
|
|
sshClient.openedPortForward(sshClient.tcpTrafficState)
|
|
sshClient.openedPortForward(sshClient.tcpTrafficState)
|
|
|
- defer sshClient.closedPortForward(
|
|
|
|
|
- sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
|
|
|
|
|
|
|
+ defer func() {
|
|
|
|
|
+ sshClient.closedPortForward(
|
|
|
|
|
+ sshClient.tcpTrafficState,
|
|
|
|
|
+ atomic.LoadInt64(&bytesUp),
|
|
|
|
|
+ atomic.LoadInt64(&bytesDown))
|
|
|
|
|
+ }()
|
|
|
|
|
|
|
|
// TOCTOU note: important to increment the port forward count (via
|
|
// TOCTOU note: important to increment the port forward count (via
|
|
|
// openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
// openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
@@ -578,11 +741,43 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
sshClient.tcpTrafficState,
|
|
sshClient.tcpTrafficState,
|
|
|
sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
|
|
|
|
|
|
- sshClient.rejectNewChannel(
|
|
|
|
|
- newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
|
|
|
|
|
- return
|
|
|
|
|
|
|
+ // Close the oldest TCP port forward. CloseOldest() closes
|
|
|
|
|
+ // the conn and the port forward's goroutine will complete
|
|
|
|
|
+ // the cleanup asynchronously.
|
|
|
|
|
+ //
|
|
|
|
|
+ // Some known limitations:
|
|
|
|
|
+ //
|
|
|
|
|
+ // - Since CloseOldest() closes the upstream socket but does not
|
|
|
|
|
+ // clean up all resources associated with the port forward. These
|
|
|
|
|
+ // include the goroutine(s) relaying traffic as well as the SSH
|
|
|
|
|
+ // channel. Closing the socket will interrupt the goroutines which
|
|
|
|
|
+ // will then complete the cleanup. But, since the full cleanup is
|
|
|
|
|
+ // asynchronous, there exists a possibility that a client can consume
|
|
|
|
|
+ // more than max port forward resources -- just not upstream sockets.
|
|
|
|
|
+ //
|
|
|
|
|
+ // - An LRU list entry for this port forward is not added until
|
|
|
|
|
+ // after the dial completes, but the port forward is counted
|
|
|
|
|
+ // towards max limits. This means many dials in progress will
|
|
|
|
|
+ // put established connections in jeopardy.
|
|
|
|
|
+ //
|
|
|
|
|
+ // - We're closing the oldest open connection _before_ successfully
|
|
|
|
|
+ // dialing the new port forward. This means we are potentially
|
|
|
|
|
+ // discarding a good connection to make way for a failed connection.
|
|
|
|
|
+ // We cannot simply dial first and still maintain a limit on
|
|
|
|
|
+ // resources used, so to address this we'd need to add some
|
|
|
|
|
+ // accounting for connections still establishing.
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.tcpPortForwardLRU.CloseOldest()
|
|
|
|
|
+
|
|
|
|
|
+ log.WithContextFields(
|
|
|
|
|
+ LogFields{
|
|
|
|
|
+ "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
|
|
|
|
|
+ }).Debug("closed LRU TCP port forward")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // Dial the target remote address. This is done in a goroutine to
|
|
|
|
|
+ // ensure the shutdown signal is handled immediately.
|
|
|
|
|
+
|
|
|
remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
|
|
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
@@ -615,9 +810,25 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // The upstream TCP port forward connection has been established. Schedule
|
|
|
|
|
+ // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
|
|
+
|
|
|
fwdConn := result.conn
|
|
fwdConn := result.conn
|
|
|
defer fwdConn.Close()
|
|
defer fwdConn.Close()
|
|
|
|
|
|
|
|
|
|
+ lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
|
|
+ defer lruEntry.Remove()
|
|
|
|
|
+
|
|
|
|
|
+ // ActivityMonitoredConn monitors the TCP port forward I/O and updates
|
|
|
|
|
+ // its LRU status. ActivityMonitoredConn also times out read on the port
|
|
|
|
|
+ // forward if both reads and writes have been idle for the specified
|
|
|
|
|
+ // duration.
|
|
|
|
|
+ fwdConn = psiphon.NewActivityMonitoredConn(
|
|
|
|
|
+ fwdConn,
|
|
|
|
|
+ time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
|
|
|
|
|
+ true,
|
|
|
|
|
+ lruEntry)
|
|
|
|
|
+
|
|
|
fwdChannel, requests, err := newChannel.Accept()
|
|
fwdChannel, requests, err := newChannel.Accept()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
@@ -628,39 +839,35 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
|
|
|
|
|
|
|
|
- // When idle port forward traffic rules are in place, wrap fwdConn
|
|
|
|
|
- // in an IdleTimeoutConn configured to reset idle on writes as well
|
|
|
|
|
- // as read. This ensures the port forward idle timeout only happens
|
|
|
|
|
- // when both upstream and downstream directions are are idle.
|
|
|
|
|
-
|
|
|
|
|
- if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
|
|
|
|
|
- fwdConn = psiphon.NewIdleTimeoutConn(
|
|
|
|
|
- fwdConn,
|
|
|
|
|
- time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
|
|
|
|
|
- true)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Relay channel to forwarded connection.
|
|
|
|
|
|
|
|
- // relay channel to forwarded connection
|
|
|
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
|
- // TODO: use a low-memory io.Copy?
|
|
|
|
|
-
|
|
|
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
|
relayWaitGroup.Add(1)
|
|
relayWaitGroup.Add(1)
|
|
|
go func() {
|
|
go func() {
|
|
|
defer relayWaitGroup.Done()
|
|
defer relayWaitGroup.Done()
|
|
|
- bytes, err := io.Copy(fwdChannel, fwdConn)
|
|
|
|
|
|
|
+ // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
|
|
|
|
|
+ // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
|
|
|
|
|
+ // overall memory footprint.
|
|
|
|
|
+ bytes, err := io.CopyBuffer(
|
|
|
|
|
+ fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
|
|
|
atomic.AddInt64(&bytesDown, bytes)
|
|
atomic.AddInt64(&bytesDown, bytes)
|
|
|
if err != nil && err != io.EOF {
|
|
if err != nil && err != io.EOF {
|
|
|
// Debug since errors such as "connection reset by peer" occur during normal operation
|
|
// Debug since errors such as "connection reset by peer" occur during normal operation
|
|
|
log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
|
|
log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
|
|
|
}
|
|
}
|
|
|
|
|
+ // Interrupt upstream io.Copy when downstream is shutting down.
|
|
|
|
|
+ // TODO: this is done to quickly cleanup the port forward when
|
|
|
|
|
+ // fwdConn has a read timeout, but is it clean -- upstream may still
|
|
|
|
|
+ // be flowing?
|
|
|
|
|
+ fwdChannel.Close()
|
|
|
}()
|
|
}()
|
|
|
- bytes, err := io.Copy(fwdConn, fwdChannel)
|
|
|
|
|
|
|
+ bytes, err := io.CopyBuffer(
|
|
|
|
|
+ fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
|
|
|
atomic.AddInt64(&bytesUp, bytes)
|
|
atomic.AddInt64(&bytesUp, bytes)
|
|
|
if err != nil && err != io.EOF {
|
|
if err != nil && err != io.EOF {
|
|
|
log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
|
|
log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
// Shutdown special case: fwdChannel will be closed and return EOF when
|
|
// Shutdown special case: fwdChannel will be closed and return EOF when
|
|
|
// the SSH connection is closed, but we need to explicitly close fwdConn
|
|
// the SSH connection is closed, but we need to explicitly close fwdConn
|
|
|
// to interrupt the downstream io.Copy, which may be blocked on a
|
|
// to interrupt the downstream io.Copy, which may be blocked on a
|
|
@@ -675,86 +882,3 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
"bytesUp": atomic.LoadInt64(&bytesUp),
|
|
"bytesUp": atomic.LoadInt64(&bytesUp),
|
|
|
"bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
|
|
"bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
-func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
- var sshPasswordPayload struct {
|
|
|
|
|
- SessionId string `json:"SessionId"`
|
|
|
|
|
- SshPassword string `json:"SshPassword"`
|
|
|
|
|
- }
|
|
|
|
|
- err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- userOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
- []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
|
|
|
|
|
-
|
|
|
|
|
- passwordOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
- []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
|
|
|
|
|
-
|
|
|
|
|
- if !userOk || !passwordOk {
|
|
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
|
|
-
|
|
|
|
|
- sshClient.Lock()
|
|
|
|
|
- sshClient.psiphonSessionID = psiphonSessionID
|
|
|
|
|
- geoIPData := sshClient.geoIPData
|
|
|
|
|
- sshClient.Unlock()
|
|
|
|
|
-
|
|
|
|
|
- if sshClient.sshServer.config.UseRedis() {
|
|
|
|
|
- err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- log.WithContextFields(LogFields{
|
|
|
|
|
- "psiphonSessionID": psiphonSessionID,
|
|
|
|
|
- "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
|
|
- // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return nil, nil
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- if sshClient.sshServer.config.UseFail2Ban() {
|
|
|
|
|
- clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
|
|
|
|
|
- if clientIPAddress != "" {
|
|
|
|
|
- LogFail2Ban(clientIPAddress)
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
|
|
|
|
|
- } else {
|
|
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (sshClient *sshClient) stop() {
|
|
|
|
|
-
|
|
|
|
|
- sshClient.sshConn.Close()
|
|
|
|
|
- sshClient.sshConn.Wait()
|
|
|
|
|
-
|
|
|
|
|
- close(sshClient.stopBroadcast)
|
|
|
|
|
- sshClient.channelHandlerWaitGroup.Wait()
|
|
|
|
|
-
|
|
|
|
|
- sshClient.Lock()
|
|
|
|
|
- log.WithContextFields(
|
|
|
|
|
- LogFields{
|
|
|
|
|
- "startTime": sshClient.startTime,
|
|
|
|
|
- "duration": time.Now().Sub(sshClient.startTime),
|
|
|
|
|
- "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
|
|
- "country": sshClient.geoIPData.Country,
|
|
|
|
|
- "city": sshClient.geoIPData.City,
|
|
|
|
|
- "ISP": sshClient.geoIPData.ISP,
|
|
|
|
|
- "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
|
|
- "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
|
|
- "portForwardCountTCP": sshClient.tcpTrafficState.portForwardCount,
|
|
|
|
|
- "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
|
|
- "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
|
|
- "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
|
|
- "portForwardCountUDP": sshClient.udpTrafficState.portForwardCount,
|
|
|
|
|
- "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
|
|
- }).Info("tunnel closed")
|
|
|
|
|
- sshClient.Unlock()
|
|
|
|
|
-}
|
|
|