Ver Fonte

Rearranged ssh server code
* Fix: populate sshClient.geoIPData before it is
used for UpdateRedisForLegacyPsiWeb in the
passwordCallback
* Associate more methods to sshClient type (this
was necessary in the case of the passwordCallback
as this was invoked before any reference to the
client GeoIPData could be associated with data
passed to the callback; but now the callback
method implicitly receives the sshClient itself)
* sshServer type still keeps a list of clients,
but has less client logic

Rod Hynes há 10 anos atrás
pai
commit
b095940913
1 ficheiros alterados com 144 adições e 156 exclusões
  1. 144 156
      psiphon/server/sshService.go

+ 144 - 156
psiphon/server/sshService.go

@@ -41,45 +41,9 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
 	return runSSHServer(config, true, shutdownBroadcast)
 	return runSSHServer(config, true, shutdownBroadcast)
 }
 }
 
 
-type sshServer struct {
-	config            *Config
-	useObfuscation    bool
-	shutdownBroadcast <-chan struct{}
-	sshConfig         *ssh.ServerConfig
-	clientsMutex      sync.Mutex
-	stoppingClients   bool
-	clients           map[string]*sshClient
-}
-
-type sshClient struct {
-	sync.Mutex
-	sshConn                       ssh.Conn
-	startTime                     time.Time
-	geoIPData                     GeoIPData
-	psiphonSessionID              string
-	bytesUp                       int64
-	bytesDown                     int64
-	portForwardCount              int64
-	concurrentPortForwardCount    int64
-	maxConcurrentPortForwardCount int64
-}
-
 func runSSHServer(
 func runSSHServer(
 	config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
 	config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
 
 
-	sshServer := &sshServer{
-		config:            config,
-		useObfuscation:    useObfuscation,
-		shutdownBroadcast: shutdownBroadcast,
-		clients:           make(map[string]*sshClient),
-	}
-
-	sshServer.sshConfig = &ssh.ServerConfig{
-		PasswordCallback: sshServer.passwordCallback,
-		AuthLogCallback:  sshServer.authLogCallback,
-		ServerVersion:    config.SSHServerVersion,
-	}
-
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
 	if err != nil {
 	if err != nil {
 		return psiphon.ContextError(err)
 		return psiphon.ContextError(err)
@@ -91,7 +55,14 @@ func runSSHServer(
 		return psiphon.ContextError(err)
 		return psiphon.ContextError(err)
 	}
 	}
 
 
-	sshServer.sshConfig.AddHostKey(signer)
+	sshServer := &sshServer{
+		config:            config,
+		useObfuscation:    useObfuscation,
+		shutdownBroadcast: shutdownBroadcast,
+		sshHostKey:        signer,
+		nextClientID:      1,
+		clients:           make(map[sshClientID]*sshClient),
+	}
 
 
 	var serverPort int
 	var serverPort int
 	if useObfuscation {
 	if useObfuscation {
@@ -173,65 +144,53 @@ func runSSHServer(
 	return err
 	return err
 }
 }
 
 
-func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
+type sshClientID uint64
+
+type sshServer struct {
+	config            *Config
+	useObfuscation    bool
+	shutdownBroadcast <-chan struct{}
+	sshHostKey        ssh.Signer
+	nextClientID      sshClientID
+	clientsMutex      sync.Mutex
+	stoppingClients   bool
+	clients           map[sshClientID]*sshClient
+}
+
+func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
+
 	sshServer.clientsMutex.Lock()
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
 	defer sshServer.clientsMutex.Unlock()
+
 	if sshServer.stoppingClients {
 	if sshServer.stoppingClients {
-		return "", false
+		return 0, false
 	}
 	}
-	key := string(client.sshConn.SessionID())
-	existingClient := sshServer.clients[key]
-	if existingClient != nil {
-		log.WithContext().Warning("unexpected existing connection")
-		existingClient.sshConn.Close()
-		existingClient.sshConn.Wait()
-	}
-	sshServer.clients[key] = client
-	return key, true
-}
-
-func (sshServer *sshServer) getClientGeoIPData(clientKey string) GeoIPData {
-	sshServer.clientsMutex.Lock()
-	client, ok := sshServer.clients[clientKey]
-	sshServer.clientsMutex.Unlock()
 
 
-	geoIPData := NewGeoIPData()
+	clientID := sshServer.nextClientID
+	sshServer.nextClientID += 1
 
 
-	if ok {
-		client.Lock()
-		geoIPData = client.geoIPData
-		client.Unlock()
-	}
+	sshServer.clients[clientID] = client
 
 
-	return geoIPData
+	return clientID, true
 }
 }
 
 
-func (sshServer *sshServer) updateClient(
-	clientKey string, updater func(*sshClient)) {
+func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
 
 
 	sshServer.clientsMutex.Lock()
 	sshServer.clientsMutex.Lock()
-	client, ok := sshServer.clients[clientKey]
+	client := sshServer.clients[clientID]
+	delete(sshServer.clients, clientID)
 	sshServer.clientsMutex.Unlock()
 	sshServer.clientsMutex.Unlock()
-	if ok {
-		client.Lock()
-		updater(client)
-		client.Unlock()
-	}
-}
 
 
-func (sshServer *sshServer) unregisterClient(clientKey string) {
-	sshServer.clientsMutex.Lock()
-	client := sshServer.clients[clientKey]
-	delete(sshServer.clients, clientKey)
-	sshServer.clientsMutex.Unlock()
 	if client != nil {
 	if client != nil {
 		sshServer.stopClient(client)
 		sshServer.stopClient(client)
 	}
 	}
 }
 }
 
 
 func (sshServer *sshServer) stopClient(client *sshClient) {
 func (sshServer *sshServer) stopClient(client *sshClient) {
+
 	client.sshConn.Close()
 	client.sshConn.Close()
 	client.sshConn.Wait()
 	client.sshConn.Wait()
+
 	client.Lock()
 	client.Lock()
 	log.WithContextFields(
 	log.WithContextFields(
 		LogFields{
 		LogFields{
@@ -250,10 +209,12 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
 }
 }
 
 
 func (sshServer *sshServer) stopClients() {
 func (sshServer *sshServer) stopClients() {
+
 	sshServer.clientsMutex.Lock()
 	sshServer.clientsMutex.Lock()
 	sshServer.stoppingClients = true
 	sshServer.stoppingClients = true
-	sshServer.clients = make(map[string]*sshClient)
+	sshServer.clients = make(map[sshClientID]*sshClient)
 	sshServer.clientsMutex.Unlock()
 	sshServer.clientsMutex.Unlock()
+
 	for _, client := range sshServer.clients {
 	for _, client := range sshServer.clients {
 		sshServer.stopClient(client)
 		sshServer.stopClient(client)
 	}
 	}
@@ -261,8 +222,11 @@ func (sshServer *sshServer) stopClients() {
 
 
 func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 
-	startTime := time.Now()
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
+	sshClient := &sshClient{
+		sshServer: sshServer,
+		startTime: time.Now(),
+		geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
+	}
 
 
 	// Wrap the base TCP connection in a TimeoutTCPConn which will terminate
 	// Wrap the base TCP connection in a TimeoutTCPConn which will terminate
 	// the connection if it's idle for too long. This timeout is in effect for
 	// the connection if it's idle for too long. This timeout is in effect for
@@ -303,8 +267,16 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 			result.conn = conn
 			result.conn = conn
 		}
 		}
 		if result.err == nil {
 		if result.err == nil {
-			result.sshConn, result.channels,
-				result.requests, result.err = ssh.NewServerConn(result.conn, sshServer.sshConfig)
+
+			sshServerConfig := &ssh.ServerConfig{
+				PasswordCallback: sshClient.passwordCallback,
+				AuthLogCallback:  sshClient.authLogCallback,
+				ServerVersion:    sshServer.config.SSHServerVersion,
+			}
+			sshServerConfig.AddHostKey(sshServer.sshHostKey)
+
+			result.sshConn, result.channels, result.requests, result.err =
+				ssh.NewServerConn(result.conn, sshServerConfig)
 		}
 		}
 		resultChannel <- result
 		resultChannel <- result
 	}()
 	}()
@@ -325,83 +297,51 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 		return
 		return
 	}
 	}
 
 
-	clientKey, ok := sshServer.registerClient(
-		&sshClient{
-			sshConn:   result.sshConn,
-			startTime: startTime,
-			geoIPData: geoIPData,
-		})
+	sshClient.Lock()
+	sshClient.sshConn = result.sshConn
+	sshClient.Unlock()
+
+	clientID, ok := sshServer.registerClient(sshClient)
 	if !ok {
 	if !ok {
-		result.sshConn.Close()
+		tcpConn.Close()
 		log.WithContext().Warning("register failed")
 		log.WithContext().Warning("register failed")
 		return
 		return
 	}
 	}
-	defer sshServer.unregisterClient(clientKey)
+	defer sshServer.unregisterClient(clientID)
 
 
 	go ssh.DiscardRequests(result.requests)
 	go ssh.DiscardRequests(result.requests)
 
 
-	for newChannel := range result.channels {
-
-		if newChannel.ChannelType() != "direct-tcpip" {
-			sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
-			return
-		}
-
-		// process each port forward concurrently
-		go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
-	}
+	sshClient.handleChannels(result.channels)
 }
 }
 
 
-func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
-	var sshPasswordPayload struct {
-		SessionId   string `json:"SessionId"`
-		SshPassword string `json:"SshPassword"`
-	}
-	err := json.Unmarshal(password, &sshPasswordPayload)
-	if err != nil {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
-	}
-
-	userOk := (subtle.ConstantTimeCompare(
-		[]byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
-
-	passwordOk := (subtle.ConstantTimeCompare(
-		[]byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
-
-	if !userOk || !passwordOk {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
-	}
-
-	clientKey := string(conn.SessionID())
-	psiphonSessionID := sshPasswordPayload.SessionId
+type sshClient struct {
+	sync.Mutex
+	sshServer                     *sshServer
+	sshConn                       ssh.Conn
+	startTime                     time.Time
+	geoIPData                     GeoIPData
+	psiphonSessionID              string
+	bytesUp                       int64
+	bytesDown                     int64
+	portForwardCount              int64
+	concurrentPortForwardCount    int64
+	maxConcurrentPortForwardCount int64
+}
 
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.psiphonSessionID = psiphonSessionID
-	})
+func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
+	for newChannel := range channels {
 
 
-	if sshServer.config.UseRedis() {
-		err = UpdateRedisForLegacyPsiWeb(
-			psiphonSessionID, sshServer.getClientGeoIPData(clientKey))
-		if err != nil {
-			log.WithContextFields(LogFields{
-				"psiphonSessionID": psiphonSessionID,
-				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
-			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
+		if newChannel.ChannelType() != "direct-tcpip" {
+			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
+			return
 		}
 		}
-	}
 
 
-	return nil, nil
-}
-
-func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
-	} else {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
+		// process each port forward concurrently
+		go sshClient.handleNewDirectTcpipChannel(newChannel)
 	}
 	}
 }
 }
 
 
-func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
+func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
 	// TODO: log more details?
 	// TODO: log more details?
 	log.WithContextFields(
 	log.WithContextFields(
 		LogFields{
 		LogFields{
@@ -412,7 +352,7 @@ func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason s
 	newChannel.Reject(reason, message)
 	newChannel.Reject(reason, message)
 }
 }
 
 
-func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
 
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
 	var directTcpipExtraData struct {
@@ -424,7 +364,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 
 
 	err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
 	err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
 	if err != nil {
 	if err != nil {
-		sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
+		sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
 		return
 		return
 	}
 	}
 
 
@@ -439,7 +379,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 	// TODO: IPv6 support
 	// TODO: IPv6 support
 	fwdConn, err := net.Dial("tcp4", targetAddr)
 	fwdConn, err := net.Dial("tcp4", targetAddr)
 	if err != nil {
 	if err != nil {
-		sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
+		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
 		return
 		return
 	}
 	}
 	defer fwdConn.Close()
 	defer fwdConn.Close()
@@ -450,13 +390,13 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 		return
 		return
 	}
 	}
 
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.portForwardCount += 1
-		client.concurrentPortForwardCount += 1
-		if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
-			client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
-		}
-	})
+	sshClient.Lock()
+	sshClient.portForwardCount += 1
+	sshClient.concurrentPortForwardCount += 1
+	if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
+		sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
+	}
+	sshClient.Unlock()
 
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
 
 
@@ -488,11 +428,59 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 	fwdChannel.CloseWrite()
 	fwdChannel.CloseWrite()
 	relayWaitGroup.Wait()
 	relayWaitGroup.Wait()
 
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.concurrentPortForwardCount -= 1
-		client.bytesUp += bytesUp
-		client.bytesDown += bytesDown
-	})
+	sshClient.Lock()
+	sshClient.concurrentPortForwardCount -= 1
+	sshClient.bytesUp += bytesUp
+	sshClient.bytesDown += bytesDown
+	sshClient.Unlock()
 
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 }
 }
+
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
+	var sshPasswordPayload struct {
+		SessionId   string `json:"SessionId"`
+		SshPassword string `json:"SshPassword"`
+	}
+	err := json.Unmarshal(password, &sshPasswordPayload)
+	if err != nil {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+	}
+
+	userOk := (subtle.ConstantTimeCompare(
+		[]byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
+
+	passwordOk := (subtle.ConstantTimeCompare(
+		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
+
+	if !userOk || !passwordOk {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+	}
+
+	psiphonSessionID := sshPasswordPayload.SessionId
+
+	sshClient.Lock()
+	sshClient.psiphonSessionID = psiphonSessionID
+	geoIPData := sshClient.geoIPData
+	sshClient.Unlock()
+
+	if sshClient.sshServer.config.UseRedis() {
+		err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
+		if err != nil {
+			log.WithContextFields(LogFields{
+				"psiphonSessionID": psiphonSessionID,
+				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
+			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
+		}
+	}
+
+	return nil, nil
+}
+
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
+	} else {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
+	}
+}