|
|
@@ -41,45 +41,9 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
|
|
|
return runSSHServer(config, true, shutdownBroadcast)
|
|
|
}
|
|
|
|
|
|
-type sshServer struct {
|
|
|
- config *Config
|
|
|
- useObfuscation bool
|
|
|
- shutdownBroadcast <-chan struct{}
|
|
|
- sshConfig *ssh.ServerConfig
|
|
|
- clientsMutex sync.Mutex
|
|
|
- stoppingClients bool
|
|
|
- clients map[string]*sshClient
|
|
|
-}
|
|
|
-
|
|
|
-type sshClient struct {
|
|
|
- sync.Mutex
|
|
|
- sshConn ssh.Conn
|
|
|
- startTime time.Time
|
|
|
- geoIPData GeoIPData
|
|
|
- psiphonSessionID string
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- portForwardCount int64
|
|
|
- concurrentPortForwardCount int64
|
|
|
- maxConcurrentPortForwardCount int64
|
|
|
-}
|
|
|
-
|
|
|
func runSSHServer(
|
|
|
config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
- sshServer := &sshServer{
|
|
|
- config: config,
|
|
|
- useObfuscation: useObfuscation,
|
|
|
- shutdownBroadcast: shutdownBroadcast,
|
|
|
- clients: make(map[string]*sshClient),
|
|
|
- }
|
|
|
-
|
|
|
- sshServer.sshConfig = &ssh.ServerConfig{
|
|
|
- PasswordCallback: sshServer.passwordCallback,
|
|
|
- AuthLogCallback: sshServer.authLogCallback,
|
|
|
- ServerVersion: config.SSHServerVersion,
|
|
|
- }
|
|
|
-
|
|
|
privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
|
|
|
if err != nil {
|
|
|
return psiphon.ContextError(err)
|
|
|
@@ -91,7 +55,14 @@ func runSSHServer(
|
|
|
return psiphon.ContextError(err)
|
|
|
}
|
|
|
|
|
|
- sshServer.sshConfig.AddHostKey(signer)
|
|
|
+ sshServer := &sshServer{
|
|
|
+ config: config,
|
|
|
+ useObfuscation: useObfuscation,
|
|
|
+ shutdownBroadcast: shutdownBroadcast,
|
|
|
+ sshHostKey: signer,
|
|
|
+ nextClientID: 1,
|
|
|
+ clients: make(map[sshClientID]*sshClient),
|
|
|
+ }
|
|
|
|
|
|
var serverPort int
|
|
|
if useObfuscation {
|
|
|
@@ -173,65 +144,53 @@ func runSSHServer(
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
|
|
|
+type sshClientID uint64
|
|
|
+
|
|
|
+type sshServer struct {
|
|
|
+ config *Config
|
|
|
+ useObfuscation bool
|
|
|
+ shutdownBroadcast <-chan struct{}
|
|
|
+ sshHostKey ssh.Signer
|
|
|
+ nextClientID sshClientID
|
|
|
+ clientsMutex sync.Mutex
|
|
|
+ stoppingClients bool
|
|
|
+ clients map[sshClientID]*sshClient
|
|
|
+}
|
|
|
+
|
|
|
+func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
|
|
|
+
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
defer sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
if sshServer.stoppingClients {
|
|
|
- return "", false
|
|
|
+ return 0, false
|
|
|
}
|
|
|
- key := string(client.sshConn.SessionID())
|
|
|
- existingClient := sshServer.clients[key]
|
|
|
- if existingClient != nil {
|
|
|
- log.WithContext().Warning("unexpected existing connection")
|
|
|
- existingClient.sshConn.Close()
|
|
|
- existingClient.sshConn.Wait()
|
|
|
- }
|
|
|
- sshServer.clients[key] = client
|
|
|
- return key, true
|
|
|
-}
|
|
|
-
|
|
|
-func (sshServer *sshServer) getClientGeoIPData(clientKey string) GeoIPData {
|
|
|
- sshServer.clientsMutex.Lock()
|
|
|
- client, ok := sshServer.clients[clientKey]
|
|
|
- sshServer.clientsMutex.Unlock()
|
|
|
|
|
|
- geoIPData := NewGeoIPData()
|
|
|
+ clientID := sshServer.nextClientID
|
|
|
+ sshServer.nextClientID += 1
|
|
|
|
|
|
- if ok {
|
|
|
- client.Lock()
|
|
|
- geoIPData = client.geoIPData
|
|
|
- client.Unlock()
|
|
|
- }
|
|
|
+ sshServer.clients[clientID] = client
|
|
|
|
|
|
- return geoIPData
|
|
|
+ return clientID, true
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) updateClient(
|
|
|
- clientKey string, updater func(*sshClient)) {
|
|
|
+func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
|
|
|
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
- client, ok := sshServer.clients[clientKey]
|
|
|
+ client := sshServer.clients[clientID]
|
|
|
+ delete(sshServer.clients, clientID)
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
- if ok {
|
|
|
- client.Lock()
|
|
|
- updater(client)
|
|
|
- client.Unlock()
|
|
|
- }
|
|
|
-}
|
|
|
|
|
|
-func (sshServer *sshServer) unregisterClient(clientKey string) {
|
|
|
- sshServer.clientsMutex.Lock()
|
|
|
- client := sshServer.clients[clientKey]
|
|
|
- delete(sshServer.clients, clientKey)
|
|
|
- sshServer.clientsMutex.Unlock()
|
|
|
if client != nil {
|
|
|
sshServer.stopClient(client)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (sshServer *sshServer) stopClient(client *sshClient) {
|
|
|
+
|
|
|
client.sshConn.Close()
|
|
|
client.sshConn.Wait()
|
|
|
+
|
|
|
client.Lock()
|
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
|
@@ -250,10 +209,12 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
|
|
|
}
|
|
|
|
|
|
func (sshServer *sshServer) stopClients() {
|
|
|
+
|
|
|
sshServer.clientsMutex.Lock()
|
|
|
sshServer.stoppingClients = true
|
|
|
- sshServer.clients = make(map[string]*sshClient)
|
|
|
+ sshServer.clients = make(map[sshClientID]*sshClient)
|
|
|
sshServer.clientsMutex.Unlock()
|
|
|
+
|
|
|
for _, client := range sshServer.clients {
|
|
|
sshServer.stopClient(client)
|
|
|
}
|
|
|
@@ -261,8 +222,11 @@ func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
- startTime := time.Now()
|
|
|
- geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
|
|
|
+ sshClient := &sshClient{
|
|
|
+ sshServer: sshServer,
|
|
|
+ startTime: time.Now(),
|
|
|
+ geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
|
|
|
+ }
|
|
|
|
|
|
// Wrap the base TCP connection in a TimeoutTCPConn which will terminate
|
|
|
// the connection if it's idle for too long. This timeout is in effect for
|
|
|
@@ -303,8 +267,16 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
result.conn = conn
|
|
|
}
|
|
|
if result.err == nil {
|
|
|
- result.sshConn, result.channels,
|
|
|
- result.requests, result.err = ssh.NewServerConn(result.conn, sshServer.sshConfig)
|
|
|
+
|
|
|
+ sshServerConfig := &ssh.ServerConfig{
|
|
|
+ PasswordCallback: sshClient.passwordCallback,
|
|
|
+ AuthLogCallback: sshClient.authLogCallback,
|
|
|
+ ServerVersion: sshServer.config.SSHServerVersion,
|
|
|
+ }
|
|
|
+ sshServerConfig.AddHostKey(sshServer.sshHostKey)
|
|
|
+
|
|
|
+ result.sshConn, result.channels, result.requests, result.err =
|
|
|
+ ssh.NewServerConn(result.conn, sshServerConfig)
|
|
|
}
|
|
|
resultChannel <- result
|
|
|
}()
|
|
|
@@ -325,83 +297,51 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- clientKey, ok := sshServer.registerClient(
|
|
|
- &sshClient{
|
|
|
- sshConn: result.sshConn,
|
|
|
- startTime: startTime,
|
|
|
- geoIPData: geoIPData,
|
|
|
- })
|
|
|
+ sshClient.Lock()
|
|
|
+ sshClient.sshConn = result.sshConn
|
|
|
+ sshClient.Unlock()
|
|
|
+
|
|
|
+ clientID, ok := sshServer.registerClient(sshClient)
|
|
|
if !ok {
|
|
|
- result.sshConn.Close()
|
|
|
+ tcpConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
|
}
|
|
|
- defer sshServer.unregisterClient(clientKey)
|
|
|
+ defer sshServer.unregisterClient(clientID)
|
|
|
|
|
|
go ssh.DiscardRequests(result.requests)
|
|
|
|
|
|
- for newChannel := range result.channels {
|
|
|
-
|
|
|
- if newChannel.ChannelType() != "direct-tcpip" {
|
|
|
- sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // process each port forward concurrently
|
|
|
- go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
|
|
|
- }
|
|
|
+ sshClient.handleChannels(result.channels)
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
- var sshPasswordPayload struct {
|
|
|
- SessionId string `json:"SessionId"`
|
|
|
- SshPassword string `json:"SshPassword"`
|
|
|
- }
|
|
|
- err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
- if err != nil {
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
- }
|
|
|
-
|
|
|
- userOk := (subtle.ConstantTimeCompare(
|
|
|
- []byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
|
|
|
-
|
|
|
- passwordOk := (subtle.ConstantTimeCompare(
|
|
|
- []byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
|
|
|
-
|
|
|
- if !userOk || !passwordOk {
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
- }
|
|
|
-
|
|
|
- clientKey := string(conn.SessionID())
|
|
|
- psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
+type sshClient struct {
|
|
|
+ sync.Mutex
|
|
|
+ sshServer *sshServer
|
|
|
+ sshConn ssh.Conn
|
|
|
+ startTime time.Time
|
|
|
+ geoIPData GeoIPData
|
|
|
+ psiphonSessionID string
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ portForwardCount int64
|
|
|
+ concurrentPortForwardCount int64
|
|
|
+ maxConcurrentPortForwardCount int64
|
|
|
+}
|
|
|
|
|
|
- sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
- client.psiphonSessionID = psiphonSessionID
|
|
|
- })
|
|
|
+func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
+ for newChannel := range channels {
|
|
|
|
|
|
- if sshServer.config.UseRedis() {
|
|
|
- err = UpdateRedisForLegacyPsiWeb(
|
|
|
- psiphonSessionID, sshServer.getClientGeoIPData(clientKey))
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{
|
|
|
- "psiphonSessionID": psiphonSessionID,
|
|
|
- "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
- // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
+ if newChannel.ChannelType() != "direct-tcpip" {
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
|
|
|
+ return
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- return nil, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
|
|
|
- } else {
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
+ // process each port forward concurrently
|
|
|
+ go sshClient.handleNewDirectTcpipChannel(newChannel)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
|
|
|
+func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
|
|
|
// TODO: log more details?
|
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
|
@@ -412,7 +352,7 @@ func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason s
|
|
|
newChannel.Reject(reason, message)
|
|
|
}
|
|
|
|
|
|
-func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
|
|
|
+func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
var directTcpipExtraData struct {
|
|
|
@@ -424,7 +364,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
|
|
|
|
|
|
err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
|
|
|
if err != nil {
|
|
|
- sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -439,7 +379,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
|
|
|
// TODO: IPv6 support
|
|
|
fwdConn, err := net.Dial("tcp4", targetAddr)
|
|
|
if err != nil {
|
|
|
- sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
|
|
|
+ sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
|
|
|
return
|
|
|
}
|
|
|
defer fwdConn.Close()
|
|
|
@@ -450,13 +390,13 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
- client.portForwardCount += 1
|
|
|
- client.concurrentPortForwardCount += 1
|
|
|
- if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
|
|
|
- client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
|
|
|
- }
|
|
|
- })
|
|
|
+ sshClient.Lock()
|
|
|
+ sshClient.portForwardCount += 1
|
|
|
+ sshClient.concurrentPortForwardCount += 1
|
|
|
+ if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
|
|
|
+ sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
|
|
|
+ }
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
|
|
|
|
@@ -488,11 +428,59 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
|
|
|
fwdChannel.CloseWrite()
|
|
|
relayWaitGroup.Wait()
|
|
|
|
|
|
- sshServer.updateClient(clientKey, func(client *sshClient) {
|
|
|
- client.concurrentPortForwardCount -= 1
|
|
|
- client.bytesUp += bytesUp
|
|
|
- client.bytesDown += bytesDown
|
|
|
- })
|
|
|
+ sshClient.Lock()
|
|
|
+ sshClient.concurrentPortForwardCount -= 1
|
|
|
+ sshClient.bytesUp += bytesUp
|
|
|
+ sshClient.bytesDown += bytesDown
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
|
}
|
|
|
+
|
|
|
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
+ var sshPasswordPayload struct {
|
|
|
+ SessionId string `json:"SessionId"`
|
|
|
+ SshPassword string `json:"SshPassword"`
|
|
|
+ }
|
|
|
+ err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
+ if err != nil {
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
+ }
|
|
|
+
|
|
|
+ userOk := (subtle.ConstantTimeCompare(
|
|
|
+ []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
|
|
|
+
|
|
|
+ passwordOk := (subtle.ConstantTimeCompare(
|
|
|
+ []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
|
|
|
+
|
|
|
+ if !userOk || !passwordOk {
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
+ }
|
|
|
+
|
|
|
+ psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ sshClient.psiphonSessionID = psiphonSessionID
|
|
|
+ geoIPData := sshClient.geoIPData
|
|
|
+ sshClient.Unlock()
|
|
|
+
|
|
|
+ if sshClient.sshServer.config.UseRedis() {
|
|
|
+ err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
|
|
|
+ if err != nil {
|
|
|
+ log.WithContextFields(LogFields{
|
|
|
+ "psiphonSessionID": psiphonSessionID,
|
|
|
+ "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
+ // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
+ if err != nil {
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
|
|
|
+ } else {
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
|
|
|
+ }
|
|
|
+}
|