|
@@ -21,7 +21,6 @@ package server
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"crypto/subtle"
|
|
"crypto/subtle"
|
|
|
- "encoding/hex"
|
|
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
@@ -47,9 +46,22 @@ type sshServer struct {
|
|
|
useObfuscation bool
|
|
useObfuscation bool
|
|
|
shutdownBroadcast <-chan struct{}
|
|
shutdownBroadcast <-chan struct{}
|
|
|
sshConfig *ssh.ServerConfig
|
|
sshConfig *ssh.ServerConfig
|
|
|
- clientMutex sync.Mutex
|
|
|
|
|
|
|
+ clientsMutex sync.Mutex
|
|
|
stoppingClients bool
|
|
stoppingClients bool
|
|
|
- clients map[string]ssh.Conn
|
|
|
|
|
|
|
+ clients map[string]*sshClient
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type sshClient struct {
|
|
|
|
|
+ sync.Mutex
|
|
|
|
|
+ sshConn ssh.Conn
|
|
|
|
|
+ startTime time.Time
|
|
|
|
|
+ geoIPData GeoIPData
|
|
|
|
|
+ psiphonSessionID string
|
|
|
|
|
+ bytesUp int64
|
|
|
|
|
+ bytesDown int64
|
|
|
|
|
+ portForwardCount int64
|
|
|
|
|
+ concurrentPortForwardCount int64
|
|
|
|
|
+ maxConcurrentPortForwardCount int64
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func runSSHServer(
|
|
func runSSHServer(
|
|
@@ -59,7 +71,7 @@ func runSSHServer(
|
|
|
config: config,
|
|
config: config,
|
|
|
useObfuscation: useObfuscation,
|
|
useObfuscation: useObfuscation,
|
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
|
- clients: make(map[string]ssh.Conn),
|
|
|
|
|
|
|
+ clients: make(map[string]*sshClient),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
sshServer.sshConfig = &ssh.ServerConfig{
|
|
sshServer.sshConfig = &ssh.ServerConfig{
|
|
@@ -161,86 +173,83 @@ func runSSHServer(
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
- var sshPasswordPayload struct {
|
|
|
|
|
- SessionId string `json:"SessionId"`
|
|
|
|
|
- SshPassword string `json:"SshPassword"`
|
|
|
|
|
- }
|
|
|
|
|
- err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
|
|
|
|
+func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
|
|
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
|
|
+ defer sshServer.clientsMutex.Unlock()
|
|
|
|
|
+ if sshServer.stoppingClients {
|
|
|
|
|
+ return "", false
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- userOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
- []byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
|
|
|
|
|
-
|
|
|
|
|
- passwordOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
- []byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
|
|
|
|
|
-
|
|
|
|
|
- if !userOk || !passwordOk {
|
|
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
|
|
|
|
+ key := string(client.sshConn.SessionID())
|
|
|
|
|
+ existingClient := sshServer.clients[key]
|
|
|
|
|
+ if existingClient != nil {
|
|
|
|
|
+ log.WithContext().Warning("unexpected existing connection")
|
|
|
|
|
+ client.sshConn.Close()
|
|
|
|
|
+ client.sshConn.Wait()
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
|
|
|
|
|
-
|
|
|
|
|
- log.WithContextFields(
|
|
|
|
|
- LogFields{
|
|
|
|
|
- "sshSessionID": hex.EncodeToString(conn.SessionID()),
|
|
|
|
|
- "psiphonSessionID": sshPasswordPayload.SessionId,
|
|
|
|
|
- "country": geoIPData.Country,
|
|
|
|
|
- "city": geoIPData.City,
|
|
|
|
|
- "ISP": geoIPData.ISP,
|
|
|
|
|
- }).Info("tunnel started")
|
|
|
|
|
-
|
|
|
|
|
- return nil, nil
|
|
|
|
|
|
|
+ sshServer.clients[key] = client
|
|
|
|
|
+ return key, true
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
|
|
|
|
|
- } else {
|
|
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
|
|
|
|
+func (sshServer *sshServer) updateClient(
|
|
|
|
|
+ clientKey string, updater func(*sshClient)) {
|
|
|
|
|
+
|
|
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
|
|
+ sshClient, ok := sshServer.clients[clientKey]
|
|
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
|
|
+ if ok {
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ updater(sshClient)
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) registerClient(sshConn ssh.Conn) bool {
|
|
|
|
|
- sshServer.clientMutex.Lock()
|
|
|
|
|
- defer sshServer.clientMutex.Unlock()
|
|
|
|
|
|
|
+func (sshServer *sshServer) unregisterClient(clientKey string) {
|
|
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
if sshServer.stoppingClients {
|
|
if sshServer.stoppingClients {
|
|
|
- return false
|
|
|
|
|
- }
|
|
|
|
|
- existingSshConn := sshServer.clients[string(sshConn.SessionID())]
|
|
|
|
|
- if existingSshConn != nil {
|
|
|
|
|
- log.WithContext().Warning("unexpected existing connection")
|
|
|
|
|
- existingSshConn.Close()
|
|
|
|
|
- existingSshConn.Wait()
|
|
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
- sshServer.clients[string(sshConn.SessionID())] = sshConn
|
|
|
|
|
- return true
|
|
|
|
|
-}
|
|
|
|
|
|
|
+ client := sshServer.clients[clientKey]
|
|
|
|
|
+ delete(sshServer.clients, clientKey)
|
|
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) unregisterClient(sshConn ssh.Conn) {
|
|
|
|
|
- sshServer.clientMutex.Lock()
|
|
|
|
|
- if sshServer.stoppingClients {
|
|
|
|
|
|
|
+ if client == nil {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
- delete(sshServer.clients, string(sshConn.SessionID()))
|
|
|
|
|
- sshServer.clientMutex.Unlock()
|
|
|
|
|
- sshConn.Close()
|
|
|
|
|
|
|
+
|
|
|
|
|
+ client.sshConn.Close()
|
|
|
|
|
+ client.Lock()
|
|
|
|
|
+ log.WithContextFields(
|
|
|
|
|
+ LogFields{
|
|
|
|
|
+ "startTime": client.startTime,
|
|
|
|
|
+ "duration": time.Now().Sub(client.startTime),
|
|
|
|
|
+ "psiphonSessionID": client.psiphonSessionID,
|
|
|
|
|
+ "country": client.geoIPData.Country,
|
|
|
|
|
+ "city": client.geoIPData.City,
|
|
|
|
|
+ "ISP": client.geoIPData.ISP,
|
|
|
|
|
+ "bytesUp": client.bytesUp,
|
|
|
|
|
+ "bytesDown": client.bytesDown,
|
|
|
|
|
+ "portForwardCount": client.portForwardCount,
|
|
|
|
|
+ "maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
|
|
|
|
|
+ }).Info("tunnel closed")
|
|
|
|
|
+ client.Unlock()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
- sshServer.clientMutex.Lock()
|
|
|
|
|
|
|
+ sshServer.clientsMutex.Lock()
|
|
|
sshServer.stoppingClients = true
|
|
sshServer.stoppingClients = true
|
|
|
- sshServer.clientMutex.Unlock()
|
|
|
|
|
- for _, sshConn := range sshServer.clients {
|
|
|
|
|
- sshConn.Close()
|
|
|
|
|
- sshConn.Wait()
|
|
|
|
|
|
|
+ sshServer.clientsMutex.Unlock()
|
|
|
|
|
+ for _, client := range sshServer.clients {
|
|
|
|
|
+ client.sshConn.Close()
|
|
|
|
|
+ client.sshConn.Wait()
|
|
|
}
|
|
}
|
|
|
|
|
+ sshServer.clients = make(map[string]*sshClient)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (sshServer *sshServer) handleClient(conn net.Conn) {
|
|
func (sshServer *sshServer) handleClient(conn net.Conn) {
|
|
|
|
|
|
|
|
|
|
+ startTime := time.Now()
|
|
|
|
|
+ geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
|
|
|
|
|
+
|
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine
|
|
|
// so we can both respect shutdownBroadcast and implement a
|
|
// so we can both respect shutdownBroadcast and implement a
|
|
|
// handshake timeout. The timeout is to reclaim network
|
|
// handshake timeout. The timeout is to reclaim network
|
|
@@ -293,16 +302,18 @@ func (sshServer *sshServer) handleClient(conn net.Conn) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if !sshServer.registerClient(result.sshConn) {
|
|
|
|
|
|
|
+ clientKey, ok := sshServer.registerClient(
|
|
|
|
|
+ &sshClient{
|
|
|
|
|
+ sshConn: result.sshConn,
|
|
|
|
|
+ startTime: startTime,
|
|
|
|
|
+ geoIPData: geoIPData,
|
|
|
|
|
+ })
|
|
|
|
|
+ if !ok {
|
|
|
result.sshConn.Close()
|
|
result.sshConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
- defer sshServer.unregisterClient(result.sshConn)
|
|
|
|
|
-
|
|
|
|
|
- // TODO: don't record IP; do GeoIP
|
|
|
|
|
- log.WithContextFields(
|
|
|
|
|
- LogFields{"remoteAddr": result.sshConn.RemoteAddr()}).Warning("connection accepted")
|
|
|
|
|
|
|
+ defer sshServer.unregisterClient(clientKey)
|
|
|
|
|
|
|
|
go ssh.DiscardRequests(result.requests)
|
|
go ssh.DiscardRequests(result.requests)
|
|
|
|
|
|
|
@@ -314,7 +325,42 @@ func (sshServer *sshServer) handleClient(conn net.Conn) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// process each port forward concurrently
|
|
// process each port forward concurrently
|
|
|
- go sshServer.handleNewDirectTcpipChannel(newChannel)
|
|
|
|
|
|
|
+ go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
|
|
+ var sshPasswordPayload struct {
|
|
|
|
|
+ SessionId string `json:"SessionId"`
|
|
|
|
|
+ SshPassword string `json:"SshPassword"`
|
|
|
|
|
+ }
|
|
|
|
|
+ err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ userOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
+ []byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
|
|
|
|
|
+
|
|
|
|
|
+ passwordOk := (subtle.ConstantTimeCompare(
|
|
|
|
|
+ []byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
|
|
|
|
|
+
|
|
|
|
|
+ if !userOk || !passwordOk {
|
|
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ clientKey := string(conn.SessionID())
|
|
|
|
|
+ sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
|
|
+ client.psiphonSessionID = sshPasswordPayload.SessionId
|
|
|
|
|
+ })
|
|
|
|
|
+ return nil, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
|
|
|
|
|
+ } else {
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -329,7 +375,7 @@ func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason s
|
|
|
newChannel.Reject(reason, message)
|
|
newChannel.Reject(reason, message)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
|
+func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
|
|
|
|
|
|
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
var directTcpipExtraData struct {
|
|
var directTcpipExtraData struct {
|
|
@@ -366,6 +412,14 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
|
|
+ client.portForwardCount += 1
|
|
|
|
|
+ client.concurrentPortForwardCount += 1
|
|
|
|
|
+ if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
|
|
|
|
|
+ client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
|
|
|
|
|
+ }
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
|
|
|
|
|
|
go ssh.DiscardRequests(requests)
|
|
go ssh.DiscardRequests(requests)
|
|
@@ -377,21 +431,30 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
// TODO: use a low-memory io.Copy?
|
|
// TODO: use a low-memory io.Copy?
|
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
|
|
|
|
|
|
|
|
+ var bytesUp, bytesDown int64
|
|
|
|
|
+
|
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
|
relayWaitGroup.Add(1)
|
|
relayWaitGroup.Add(1)
|
|
|
go func() {
|
|
go func() {
|
|
|
defer relayWaitGroup.Done()
|
|
defer relayWaitGroup.Done()
|
|
|
- _, err := io.Copy(fwdConn, fwdChannel)
|
|
|
|
|
|
|
+ var err error
|
|
|
|
|
+ bytesUp, err = io.Copy(fwdConn, fwdChannel)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
|
|
log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
|
|
|
}
|
|
}
|
|
|
}()
|
|
}()
|
|
|
- _, err = io.Copy(fwdChannel, fwdConn)
|
|
|
|
|
|
|
+ bytesDown, err = io.Copy(fwdChannel, fwdConn)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
|
|
log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
|
|
|
}
|
|
}
|
|
|
fwdChannel.CloseWrite()
|
|
fwdChannel.CloseWrite()
|
|
|
relayWaitGroup.Wait()
|
|
relayWaitGroup.Wait()
|
|
|
|
|
|
|
|
|
|
+ sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
|
|
+ client.concurrentPortForwardCount -= 1
|
|
|
|
|
+ client.bytesUp += bytesUp
|
|
|
|
|
+ client.bytesDown += bytesDown
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
|
}
|
|
}
|