|
|
@@ -26,28 +26,21 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net"
|
|
|
+ "runtime"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
)
|
|
|
|
|
|
-// RunSSHServer runs an ssh server with plain SSH protocol.
|
|
|
-func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
- return runSSHServer(config, false, shutdownBroadcast)
|
|
|
-}
|
|
|
-
|
|
|
-// RunSSHServer runs an ssh server with Obfuscated SSH protocol.
|
|
|
-func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
- return runSSHServer(config, true, shutdownBroadcast)
|
|
|
-}
|
|
|
-
|
|
|
-// runSSHServer runs an SSH or Obfuscated SSH server. In the Obfuscated SSH case, an
|
|
|
-// ObfuscatedSSHConn is layered in front of the client TCP connection; otherwise, both
|
|
|
-// modes are identical.
|
|
|
+// RunSSHServer runs an SSH server, the core tunneling component of the Psiphon
|
|
|
+// server. The SSH server runs a selection of listeners that handle connections
|
|
|
+// using various, optional obfuscation protocols layered on top of SSH.
|
|
|
+// (Currently, just Obfuscated SSH).
|
|
|
//
|
|
|
-// runSSHServer listens on the designated port and spawns new goroutines to handle
|
|
|
+// RunSSHServer listens on the designated port(s) and spawns new goroutines to handle
|
|
|
// each client connection. It halts when shutdownBroadcast is signaled. A list of active
|
|
|
// clients is maintained, and when halting all clients are first shutdown.
|
|
|
//
|
|
|
@@ -55,11 +48,12 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
|
|
|
// authentication, and then looping on client new channel requests. At this time, only
|
|
|
// "direct-tcpip" channels, dynamic port fowards, are expected and supported.
|
|
|
//
|
|
|
-// A new goroutine is spawned to handle each port forward. Each port forward tracks its
|
|
|
-// bytes transferred. Overall per-client stats for connection duration, GeoIP, number of
|
|
|
-// port forwards, and bytes transferred are tracked and logged when the client shuts down.
|
|
|
-func runSSHServer(
|
|
|
- config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
|
|
|
+// A new goroutine is spawned to handle each port forward for each client. Each port
|
|
|
+// forward tracks its bytes transferred. Overall per-client stats for connection duration,
|
|
|
+// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
|
|
|
+// client shuts down.
|
|
|
+func RunSSHServer(
|
|
|
+ config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
|
|
|
if err != nil {
|
|
|
@@ -74,89 +68,94 @@ func runSSHServer(
|
|
|
|
|
|
sshServer := &sshServer{
|
|
|
config: config,
|
|
|
- useObfuscation: useObfuscation,
|
|
|
+ runWaitGroup: new(sync.WaitGroup),
|
|
|
+ listenerError: make(chan error),
|
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
|
sshHostKey: signer,
|
|
|
nextClientID: 1,
|
|
|
clients: make(map[sshClientID]*sshClient),
|
|
|
}
|
|
|
|
|
|
- var serverPort int
|
|
|
- if useObfuscation {
|
|
|
- serverPort = config.ObfuscatedSSHServerPort
|
|
|
- } else {
|
|
|
- serverPort = config.SSHServerPort
|
|
|
+ type sshListener struct {
|
|
|
+ net.Listener
|
|
|
+ localAddress string
|
|
|
+ tunnelProtocol string
|
|
|
}
|
|
|
|
|
|
- listener, err := net.Listen(
|
|
|
- "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
|
|
|
- if err != nil {
|
|
|
- return psiphon.ContextError(err)
|
|
|
- }
|
|
|
+ var listeners []*sshListener
|
|
|
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "useObfuscation": useObfuscation,
|
|
|
- "port": serverPort,
|
|
|
- }).Info("starting")
|
|
|
-
|
|
|
- err = nil
|
|
|
- errors := make(chan error)
|
|
|
- waitGroup := new(sync.WaitGroup)
|
|
|
+ if config.RunSSHServer() {
|
|
|
+ listeners = append(listeners, &sshListener{
|
|
|
+ localAddress: fmt.Sprintf(
|
|
|
+ "%s:%d", config.ServerIPAddress, config.SSHServerPort),
|
|
|
+ tunnelProtocol: psiphon.TUNNEL_PROTOCOL_SSH,
|
|
|
+ })
|
|
|
+ }
|
|
|
|
|
|
- waitGroup.Add(1)
|
|
|
- go func() {
|
|
|
- defer waitGroup.Done()
|
|
|
+ if config.RunObfuscatedSSHServer() {
|
|
|
+ listeners = append(listeners, &sshListener{
|
|
|
+ localAddress: fmt.Sprintf(
|
|
|
+ "%s:%d", config.ServerIPAddress, config.ObfuscatedSSHServerPort),
|
|
|
+ tunnelProtocol: psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
|
|
|
+ })
|
|
|
+ }
|
|
|
|
|
|
- loop:
|
|
|
- for {
|
|
|
- conn, err := listener.Accept()
|
|
|
+ // TODO: add additional protocol listeners here (e.g, meek)
|
|
|
|
|
|
- select {
|
|
|
- case <-shutdownBroadcast:
|
|
|
- if err == nil {
|
|
|
- conn.Close()
|
|
|
- }
|
|
|
- break loop
|
|
|
- default:
|
|
|
+ for i, listener := range listeners {
|
|
|
+ var err error
|
|
|
+ listener.Listener, err = net.Listen("tcp", listener.localAddress)
|
|
|
+ if err != nil {
|
|
|
+ for j := 0; j < i; j++ {
|
|
|
+ listener.Listener.Close()
|
|
|
}
|
|
|
+ return psiphon.ContextError(err)
|
|
|
+ }
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "localAddress": listener.localAddress,
|
|
|
+ "tunnelProtocol": listener.tunnelProtocol,
|
|
|
+ }).Info("listening")
|
|
|
+ }
|
|
|
|
|
|
- if err != nil {
|
|
|
- if e, ok := err.(net.Error); ok && e.Temporary() {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Error("accept failed")
|
|
|
- // Temporary error, keep running
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- select {
|
|
|
- case errors <- psiphon.ContextError(err):
|
|
|
- default:
|
|
|
- }
|
|
|
+ for _, listener := range listeners {
|
|
|
+ sshServer.runWaitGroup.Add(1)
|
|
|
+ go func(listener *sshListener) {
|
|
|
+ defer sshServer.runWaitGroup.Done()
|
|
|
|
|
|
- break loop
|
|
|
- }
|
|
|
+ sshServer.runListener(
|
|
|
+ listener.Listener, listener.tunnelProtocol)
|
|
|
|
|
|
- // process each client connection concurrently
|
|
|
- go sshServer.handleClient(conn.(*net.TCPConn))
|
|
|
- }
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "localAddress": listener.localAddress,
|
|
|
+ "tunnelProtocol": listener.tunnelProtocol,
|
|
|
+ }).Info("stopping")
|
|
|
|
|
|
- sshServer.stopClients()
|
|
|
+ }(listener)
|
|
|
+ }
|
|
|
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{"useObfuscation": useObfuscation}).Info("stopped")
|
|
|
- }()
|
|
|
+ if config.RunLoadMonitor() {
|
|
|
+ sshServer.runWaitGroup.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer sshServer.runWaitGroup.Done()
|
|
|
+ sshServer.runLoadMonitor()
|
|
|
+ }()
|
|
|
+ }
|
|
|
|
|
|
+ err = nil
|
|
|
select {
|
|
|
- case <-shutdownBroadcast:
|
|
|
- case err = <-errors:
|
|
|
+ case <-sshServer.shutdownBroadcast:
|
|
|
+ case err = <-sshServer.listenerError:
|
|
|
}
|
|
|
|
|
|
- listener.Close()
|
|
|
-
|
|
|
- waitGroup.Wait()
|
|
|
+ for _, listener := range listeners {
|
|
|
+ listener.Close()
|
|
|
+ }
|
|
|
+ sshServer.stopClients()
|
|
|
+ sshServer.runWaitGroup.Wait()
|
|
|
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{"useObfuscation": useObfuscation}).Info("exiting")
|
|
|
+ log.WithContext().Info("stopped")
|
|
|
|
|
|
return err
|
|
|
}
|
|
|
@@ -165,7 +164,8 @@ type sshClientID uint64
|
|
|
|
|
|
type sshServer struct {
|
|
|
config *Config
|
|
|
- useObfuscation bool
|
|
|
+ runWaitGroup *sync.WaitGroup
|
|
|
+ listenerError chan error
|
|
|
shutdownBroadcast <-chan struct{}
|
|
|
sshHostKey ssh.Signer
|
|
|
nextClientID sshClientID
|
|
|
@@ -174,6 +174,73 @@ type sshServer struct {
|
|
|
clients map[sshClientID]*sshClient
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) runListener(
|
|
|
+ listener net.Listener, tunnelProtocol string) {
|
|
|
+
|
|
|
+ for {
|
|
|
+ conn, err := listener.Accept()
|
|
|
+
|
|
|
+ if err == nil && tunnelProtocol == psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
|
|
|
+ conn, err = psiphon.NewObfuscatedSshConn(
|
|
|
+ psiphon.OBFUSCATION_CONN_MODE_SERVER,
|
|
|
+ conn,
|
|
|
+ sshServer.config.ObfuscatedSSHKey)
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-sshServer.shutdownBroadcast:
|
|
|
+ if err == nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ }
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ if e, ok := err.(net.Error); ok && e.Temporary() {
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Error("accept failed")
|
|
|
+ // Temporary error, keep running
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case sshServer.listenerError <- psiphon.ContextError(err):
|
|
|
+ default:
|
|
|
+ }
|
|
|
+
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // process each client connection concurrently
|
|
|
+ go sshServer.handleClient(tunnelProtocol, conn)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (sshServer *sshServer) runLoadMonitor() {
|
|
|
+ ticker := time.NewTicker(
|
|
|
+ time.Duration(sshServer.config.LoadMonitorPeriodSeconds) * time.Second)
|
|
|
+ defer ticker.Stop()
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-sshServer.shutdownBroadcast:
|
|
|
+ return
|
|
|
+ case <-ticker.C:
|
|
|
+ var memStats runtime.MemStats
|
|
|
+ runtime.ReadMemStats(&memStats)
|
|
|
+ fields := LogFields{
|
|
|
+ "goroutines": runtime.NumGoroutine(),
|
|
|
+ "memAlloc": memStats.Alloc,
|
|
|
+ "memTotalAlloc": memStats.TotalAlloc,
|
|
|
+ "memSysAlloc": memStats.Sys,
|
|
|
+ }
|
|
|
+ for tunnelProtocol, count := range sshServer.countClients() {
|
|
|
+ fields[tunnelProtocol] = count
|
|
|
+ }
|
|
|
+ log.WithContextFields(fields).Info("load")
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
@@ -199,34 +266,20 @@ func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
if client != nil {
|
|
|
- sshServer.stopClient(client)
|
|
|
+ client.stop()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) stopClient(client *sshClient) {
|
|
|
+func (sshServer *sshServer) countClients() map[string]int {
|
|
|
|
|
|
- client.sshConn.Close()
|
|
|
- client.sshConn.Wait()
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
+ defer sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
- client.Lock()
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "startTime": client.startTime,
|
|
|
- "duration": time.Now().Sub(client.startTime),
|
|
|
- "psiphonSessionID": client.psiphonSessionID,
|
|
|
- "country": client.geoIPData.Country,
|
|
|
- "city": client.geoIPData.City,
|
|
|
- "ISP": client.geoIPData.ISP,
|
|
|
- "bytesUpTCP": client.tcpTrafficState.bytesUp,
|
|
|
- "bytesDownTCP": client.tcpTrafficState.bytesDown,
|
|
|
- "portForwardCountTCP": client.tcpTrafficState.portForwardCount,
|
|
|
- "peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- "bytesUpUDP": client.udpTrafficState.bytesUp,
|
|
|
- "bytesDownUDP": client.udpTrafficState.bytesDown,
|
|
|
- "portForwardCountUDP": client.udpTrafficState.portForwardCount,
|
|
|
- "peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- }).Info("tunnel closed")
|
|
|
- client.Unlock()
|
|
|
+ counts := make(map[string]int)
|
|
|
+ for _, client := range sshServer.clients {
|
|
|
+ counts[client.tunnelProtocol] += 1
|
|
|
+ }
|
|
|
+ return counts
|
|
|
}
|
|
|
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
@@ -237,24 +290,21 @@ func (sshServer *sshServer) stopClients() {
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
for _, client := range sshServer.clients {
|
|
|
- sshServer.stopClient(client)
|
|
|
+ client.stop()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
+func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
|
|
|
|
|
|
- geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
|
|
|
+ geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
|
|
|
|
|
|
- sshClient := &sshClient{
|
|
|
- sshServer: sshServer,
|
|
|
- startTime: time.Now(),
|
|
|
- geoIPData: geoIPData,
|
|
|
- trafficRules: sshServer.config.GetTrafficRules(geoIPData.Country),
|
|
|
- tcpTrafficState: &trafficState{},
|
|
|
- udpTrafficState: &trafficState{},
|
|
|
- }
|
|
|
+ sshClient := newSshClient(
|
|
|
+ sshServer,
|
|
|
+ tunnelProtocol,
|
|
|
+ geoIPData,
|
|
|
+ sshServer.config.GetTrafficRules(geoIPData.Country))
|
|
|
|
|
|
- // Wrap the base TCP connection with an IdleTimeoutConn which will terminate
|
|
|
+ // Wrap the base client connection with an IdleTimeoutConn which will terminate
|
|
|
// the connection if no data is received before the deadline. This timeout is
|
|
|
// in effect for the entire duration of the SSH connection. Clients must actively
|
|
|
// use the connection or send SSH keep alive requests to keep the connection
|
|
|
@@ -262,7 +312,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
var conn net.Conn
|
|
|
|
|
|
- conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
+ conn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
|
|
|
// Further wrap the connection in a rate limiting ThrottledConn.
|
|
|
|
|
|
@@ -292,29 +342,25 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
- go func() {
|
|
|
-
|
|
|
- result := &sshNewServerConnResult{}
|
|
|
- if sshServer.useObfuscation {
|
|
|
- result.conn, result.err = psiphon.NewObfuscatedSshConn(
|
|
|
- psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
|
|
|
- } else {
|
|
|
- result.conn = conn
|
|
|
+ go func(conn net.Conn) {
|
|
|
+ sshServerConfig := &ssh.ServerConfig{
|
|
|
+ PasswordCallback: sshClient.passwordCallback,
|
|
|
+ AuthLogCallback: sshClient.authLogCallback,
|
|
|
+ ServerVersion: sshServer.config.SSHServerVersion,
|
|
|
}
|
|
|
- if result.err == nil {
|
|
|
+ sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
|
|
|
|
- sshServerConfig := &ssh.ServerConfig{
|
|
|
- PasswordCallback: sshClient.passwordCallback,
|
|
|
- AuthLogCallback: sshClient.authLogCallback,
|
|
|
- ServerVersion: sshServer.config.SSHServerVersion,
|
|
|
- }
|
|
|
- sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
|
+ sshConn, channels, requests, err :=
|
|
|
+ ssh.NewServerConn(conn, sshServerConfig)
|
|
|
|
|
|
- result.sshConn, result.channels, result.requests, result.err =
|
|
|
- ssh.NewServerConn(result.conn, sshServerConfig)
|
|
|
+ resultChannel <- &sshNewServerConnResult{
|
|
|
+ conn: conn,
|
|
|
+ sshConn: sshConn,
|
|
|
+ channels: channels,
|
|
|
+ requests: requests,
|
|
|
+ err: err,
|
|
|
}
|
|
|
- resultChannel <- result
|
|
|
- }()
|
|
|
+ }(conn)
|
|
|
|
|
|
var result *sshNewServerConnResult
|
|
|
select {
|
|
|
@@ -351,15 +397,18 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
type sshClient struct {
|
|
|
sync.Mutex
|
|
|
- sshServer *sshServer
|
|
|
- sshConn ssh.Conn
|
|
|
- startTime time.Time
|
|
|
- geoIPData GeoIPData
|
|
|
- psiphonSessionID string
|
|
|
- udpChannel ssh.Channel
|
|
|
- trafficRules TrafficRules
|
|
|
- tcpTrafficState *trafficState
|
|
|
- udpTrafficState *trafficState
|
|
|
+ sshServer *sshServer
|
|
|
+ tunnelProtocol string
|
|
|
+ sshConn ssh.Conn
|
|
|
+ startTime time.Time
|
|
|
+ geoIPData GeoIPData
|
|
|
+ psiphonSessionID string
|
|
|
+ udpChannel ssh.Channel
|
|
|
+ trafficRules TrafficRules
|
|
|
+ tcpTrafficState *trafficState
|
|
|
+ udpTrafficState *trafficState
|
|
|
+ channelHandlerWaitGroup *sync.WaitGroup
|
|
|
+ stopBroadcast chan struct{}
|
|
|
}
|
|
|
|
|
|
type trafficState struct {
|
|
|
@@ -370,15 +419,31 @@ type trafficState struct {
|
|
|
peakConcurrentPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
+func newSshClient(
|
|
|
+ sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
|
|
|
+ return &sshClient{
|
|
|
+ sshServer: sshServer,
|
|
|
+ tunnelProtocol: tunnelProtocol,
|
|
|
+ startTime: time.Now(),
|
|
|
+ geoIPData: geoIPData,
|
|
|
+ trafficRules: trafficRules,
|
|
|
+ tcpTrafficState: &trafficState{},
|
|
|
+ udpTrafficState: &trafficState{},
|
|
|
+ channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
+ stopBroadcast: make(chan struct{}),
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
for newChannel := range channels {
|
|
|
|
|
|
if newChannel.ChannelType() != "direct-tcpip" {
|
|
|
sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
|
|
|
- return
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
// process each port forward concurrently
|
|
|
+ sshClient.channelHandlerWaitGroup.Add(1)
|
|
|
go sshClient.handleNewPortForwardChannel(newChannel)
|
|
|
}
|
|
|
}
|
|
|
@@ -395,6 +460,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
|
|
|
+ defer sshClient.channelHandlerWaitGroup.Done()
|
|
|
|
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
var directTcpipExtraData struct {
|
|
|
@@ -460,7 +526,7 @@ func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
return limitExceeded
|
|
|
}
|
|
|
|
|
|
-func (sshClient *sshClient) establishedPortForward(
|
|
|
+func (sshClient *sshClient) openedPortForward(
|
|
|
state *trafficState) {
|
|
|
|
|
|
sshClient.Lock()
|
|
|
@@ -497,7 +563,17 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
|
|
|
+ var bytesUp, bytesDown int64
|
|
|
+ sshClient.openedPortForward(sshClient.tcpTrafficState)
|
|
|
+ defer sshClient.closedPortForward(
|
|
|
+ sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
|
|
|
+
|
|
|
+ // TOCTOU note: important to increment the port forward count (via
|
|
|
+ // openPortForward) _before_ checking isPortForwardLimitExceeded
|
|
|
+ // otherwise, the client could potentially consume excess resources
|
|
|
+ // by initiating many port forwards concurrently.
|
|
|
+ // TODO: close LRU connection (after successful Dial) instead of
|
|
|
+ // rejecting new connection?
|
|
|
if sshClient.isPortForwardLimitExceeded(
|
|
|
sshClient.tcpTrafficState,
|
|
|
sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
|
@@ -507,18 +583,39 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
+ remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
|
|
|
- log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
|
|
|
+ log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
|
|
|
- // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
- // TODO: port forward dial timeout
|
|
|
- // TODO: IPv6 support
|
|
|
- fwdConn, err := net.Dial("tcp4", targetAddr)
|
|
|
- if err != nil {
|
|
|
- sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
|
|
|
+ type dialTcpResult struct {
|
|
|
+ conn net.Conn
|
|
|
+ err error
|
|
|
+ }
|
|
|
+
|
|
|
+ resultChannel := make(chan *dialTcpResult, 1)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
+ // TODO: IPv6 support
|
|
|
+ conn, err := net.DialTimeout(
|
|
|
+ "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
|
|
|
+ resultChannel <- &dialTcpResult{conn, err}
|
|
|
+ }()
|
|
|
+
|
|
|
+ var result *dialTcpResult
|
|
|
+ select {
|
|
|
+ case result = <-resultChannel:
|
|
|
+ case <-sshClient.stopBroadcast:
|
|
|
+ // Note: may leave dial in progress
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if result.err != nil {
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+ fwdConn := result.conn
|
|
|
defer fwdConn.Close()
|
|
|
|
|
|
fwdChannel, requests, err := newChannel.Accept()
|
|
|
@@ -529,9 +626,7 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
go ssh.DiscardRequests(requests)
|
|
|
defer fwdChannel.Close()
|
|
|
|
|
|
- sshClient.establishedPortForward(sshClient.tcpTrafficState)
|
|
|
-
|
|
|
- log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
|
+ log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
|
|
|
|
|
|
// When idle port forward traffic rules are in place, wrap fwdConn
|
|
|
// in an IdleTimeoutConn configured to reset idle on writes as well
|
|
|
@@ -549,28 +644,36 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
|
// TODO: use a low-memory io.Copy?
|
|
|
|
|
|
- var bytesUp, bytesDown int64
|
|
|
-
|
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
|
relayWaitGroup.Add(1)
|
|
|
go func() {
|
|
|
defer relayWaitGroup.Done()
|
|
|
- var err error
|
|
|
- bytesUp, err = io.Copy(fwdConn, fwdChannel)
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
|
|
|
+ bytes, err := io.Copy(fwdChannel, fwdConn)
|
|
|
+ atomic.AddInt64(&bytesDown, bytes)
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
+ // Debug since errors such as "connection reset by peer" occur during normal operation
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
|
|
|
}
|
|
|
}()
|
|
|
- bytesDown, err = io.Copy(fwdChannel, fwdConn)
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
|
|
|
+ bytes, err := io.Copy(fwdConn, fwdChannel)
|
|
|
+ atomic.AddInt64(&bytesUp, bytes)
|
|
|
+ if err != nil && err != io.EOF {
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
|
|
|
}
|
|
|
- fwdChannel.CloseWrite()
|
|
|
- relayWaitGroup.Wait()
|
|
|
|
|
|
- sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
|
|
|
+ // Shutdown special case: fwdChannel will be closed and return EOF when
|
|
|
+ // the SSH connection is closed, but we need to explicitly close fwdConn
|
|
|
+ // to interrupt the downstream io.Copy, which may be blocked on a
|
|
|
+ // fwdConn.Read().
|
|
|
+ fwdConn.Close()
|
|
|
|
|
|
- log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
|
+ relayWaitGroup.Wait()
|
|
|
+
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "remoteAddr": remoteAddr,
|
|
|
+ "bytesUp": atomic.LoadInt64(&bytesUp),
|
|
|
+ "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
@@ -626,3 +729,32 @@ func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string
|
|
|
log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func (sshClient *sshClient) stop() {
|
|
|
+
|
|
|
+ sshClient.sshConn.Close()
|
|
|
+ sshClient.sshConn.Wait()
|
|
|
+
|
|
|
+ close(sshClient.stopBroadcast)
|
|
|
+ sshClient.channelHandlerWaitGroup.Wait()
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "startTime": sshClient.startTime,
|
|
|
+ "duration": time.Now().Sub(sshClient.startTime),
|
|
|
+ "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
+ "country": sshClient.geoIPData.Country,
|
|
|
+ "city": sshClient.geoIPData.City,
|
|
|
+ "ISP": sshClient.geoIPData.ISP,
|
|
|
+ "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
+ "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
+ "portForwardCountTCP": sshClient.tcpTrafficState.portForwardCount,
|
|
|
+ "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
+ "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
+ "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
+ "portForwardCountUDP": sshClient.udpTrafficState.portForwardCount,
|
|
|
+ "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
+ }).Info("tunnel closed")
|
|
|
+ sshClient.Unlock()
|
|
|
+}
|