Explorar el Código

Rearranged ssh server code
* Fix: populate sshClient.geoIPData before it is
used for UpdateRedisForLegacyPsiWeb in the
passwordCallback
* Associate more methods to sshClient type (this
was necessary in the case of the passwordCallback
as this was invoked before any reference to the
client GeoIPData could be associated with data
passed to the callback; but now the callback
method implicitly receives the sshClient itself)
* sshServer type still keeps a list of clients,
but has less client logic

Rod Hynes hace 10 años
padre
commit
b095940913
Se han modificado 1 ficheros con 144 adiciones y 156 borrados
  1. 144 156
      psiphon/server/sshService.go

+ 144 - 156
psiphon/server/sshService.go

@@ -41,45 +41,9 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
 	return runSSHServer(config, true, shutdownBroadcast)
 }
 
-type sshServer struct {
-	config            *Config
-	useObfuscation    bool
-	shutdownBroadcast <-chan struct{}
-	sshConfig         *ssh.ServerConfig
-	clientsMutex      sync.Mutex
-	stoppingClients   bool
-	clients           map[string]*sshClient
-}
-
-type sshClient struct {
-	sync.Mutex
-	sshConn                       ssh.Conn
-	startTime                     time.Time
-	geoIPData                     GeoIPData
-	psiphonSessionID              string
-	bytesUp                       int64
-	bytesDown                     int64
-	portForwardCount              int64
-	concurrentPortForwardCount    int64
-	maxConcurrentPortForwardCount int64
-}
-
 func runSSHServer(
 	config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
 
-	sshServer := &sshServer{
-		config:            config,
-		useObfuscation:    useObfuscation,
-		shutdownBroadcast: shutdownBroadcast,
-		clients:           make(map[string]*sshClient),
-	}
-
-	sshServer.sshConfig = &ssh.ServerConfig{
-		PasswordCallback: sshServer.passwordCallback,
-		AuthLogCallback:  sshServer.authLogCallback,
-		ServerVersion:    config.SSHServerVersion,
-	}
-
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
 	if err != nil {
 		return psiphon.ContextError(err)
@@ -91,7 +55,14 @@ func runSSHServer(
 		return psiphon.ContextError(err)
 	}
 
-	sshServer.sshConfig.AddHostKey(signer)
+	sshServer := &sshServer{
+		config:            config,
+		useObfuscation:    useObfuscation,
+		shutdownBroadcast: shutdownBroadcast,
+		sshHostKey:        signer,
+		nextClientID:      1,
+		clients:           make(map[sshClientID]*sshClient),
+	}
 
 	var serverPort int
 	if useObfuscation {
@@ -173,65 +144,53 @@ func runSSHServer(
 	return err
 }
 
-func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
+type sshClientID uint64
+
+type sshServer struct {
+	config            *Config
+	useObfuscation    bool
+	shutdownBroadcast <-chan struct{}
+	sshHostKey        ssh.Signer
+	nextClientID      sshClientID
+	clientsMutex      sync.Mutex
+	stoppingClients   bool
+	clients           map[sshClientID]*sshClient
+}
+
+func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
+
 	sshServer.clientsMutex.Lock()
 	defer sshServer.clientsMutex.Unlock()
+
 	if sshServer.stoppingClients {
-		return "", false
+		return 0, false
 	}
-	key := string(client.sshConn.SessionID())
-	existingClient := sshServer.clients[key]
-	if existingClient != nil {
-		log.WithContext().Warning("unexpected existing connection")
-		existingClient.sshConn.Close()
-		existingClient.sshConn.Wait()
-	}
-	sshServer.clients[key] = client
-	return key, true
-}
-
-func (sshServer *sshServer) getClientGeoIPData(clientKey string) GeoIPData {
-	sshServer.clientsMutex.Lock()
-	client, ok := sshServer.clients[clientKey]
-	sshServer.clientsMutex.Unlock()
 
-	geoIPData := NewGeoIPData()
+	clientID := sshServer.nextClientID
+	sshServer.nextClientID += 1
 
-	if ok {
-		client.Lock()
-		geoIPData = client.geoIPData
-		client.Unlock()
-	}
+	sshServer.clients[clientID] = client
 
-	return geoIPData
+	return clientID, true
 }
 
-func (sshServer *sshServer) updateClient(
-	clientKey string, updater func(*sshClient)) {
+func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
 
 	sshServer.clientsMutex.Lock()
-	client, ok := sshServer.clients[clientKey]
+	client := sshServer.clients[clientID]
+	delete(sshServer.clients, clientID)
 	sshServer.clientsMutex.Unlock()
-	if ok {
-		client.Lock()
-		updater(client)
-		client.Unlock()
-	}
-}
 
-func (sshServer *sshServer) unregisterClient(clientKey string) {
-	sshServer.clientsMutex.Lock()
-	client := sshServer.clients[clientKey]
-	delete(sshServer.clients, clientKey)
-	sshServer.clientsMutex.Unlock()
 	if client != nil {
 		sshServer.stopClient(client)
 	}
 }
 
 func (sshServer *sshServer) stopClient(client *sshClient) {
+
 	client.sshConn.Close()
 	client.sshConn.Wait()
+
 	client.Lock()
 	log.WithContextFields(
 		LogFields{
@@ -250,10 +209,12 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
 }
 
 func (sshServer *sshServer) stopClients() {
+
 	sshServer.clientsMutex.Lock()
 	sshServer.stoppingClients = true
-	sshServer.clients = make(map[string]*sshClient)
+	sshServer.clients = make(map[sshClientID]*sshClient)
 	sshServer.clientsMutex.Unlock()
+
 	for _, client := range sshServer.clients {
 		sshServer.stopClient(client)
 	}
@@ -261,8 +222,11 @@ func (sshServer *sshServer) stopClients() {
 
 func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
-	startTime := time.Now()
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
+	sshClient := &sshClient{
+		sshServer: sshServer,
+		startTime: time.Now(),
+		geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
+	}
 
 	// Wrap the base TCP connection in a TimeoutTCPConn which will terminate
 	// the connection if it's idle for too long. This timeout is in effect for
@@ -303,8 +267,16 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 			result.conn = conn
 		}
 		if result.err == nil {
-			result.sshConn, result.channels,
-				result.requests, result.err = ssh.NewServerConn(result.conn, sshServer.sshConfig)
+
+			sshServerConfig := &ssh.ServerConfig{
+				PasswordCallback: sshClient.passwordCallback,
+				AuthLogCallback:  sshClient.authLogCallback,
+				ServerVersion:    sshServer.config.SSHServerVersion,
+			}
+			sshServerConfig.AddHostKey(sshServer.sshHostKey)
+
+			result.sshConn, result.channels, result.requests, result.err =
+				ssh.NewServerConn(result.conn, sshServerConfig)
 		}
 		resultChannel <- result
 	}()
@@ -325,83 +297,51 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 		return
 	}
 
-	clientKey, ok := sshServer.registerClient(
-		&sshClient{
-			sshConn:   result.sshConn,
-			startTime: startTime,
-			geoIPData: geoIPData,
-		})
+	sshClient.Lock()
+	sshClient.sshConn = result.sshConn
+	sshClient.Unlock()
+
+	clientID, ok := sshServer.registerClient(sshClient)
 	if !ok {
-		result.sshConn.Close()
+		tcpConn.Close()
 		log.WithContext().Warning("register failed")
 		return
 	}
-	defer sshServer.unregisterClient(clientKey)
+	defer sshServer.unregisterClient(clientID)
 
 	go ssh.DiscardRequests(result.requests)
 
-	for newChannel := range result.channels {
-
-		if newChannel.ChannelType() != "direct-tcpip" {
-			sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
-			return
-		}
-
-		// process each port forward concurrently
-		go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
-	}
+	sshClient.handleChannels(result.channels)
 }
 
-func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
-	var sshPasswordPayload struct {
-		SessionId   string `json:"SessionId"`
-		SshPassword string `json:"SshPassword"`
-	}
-	err := json.Unmarshal(password, &sshPasswordPayload)
-	if err != nil {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
-	}
-
-	userOk := (subtle.ConstantTimeCompare(
-		[]byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
-
-	passwordOk := (subtle.ConstantTimeCompare(
-		[]byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
-
-	if !userOk || !passwordOk {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
-	}
-
-	clientKey := string(conn.SessionID())
-	psiphonSessionID := sshPasswordPayload.SessionId
+type sshClient struct {
+	sync.Mutex
+	sshServer                     *sshServer
+	sshConn                       ssh.Conn
+	startTime                     time.Time
+	geoIPData                     GeoIPData
+	psiphonSessionID              string
+	bytesUp                       int64
+	bytesDown                     int64
+	portForwardCount              int64
+	concurrentPortForwardCount    int64
+	maxConcurrentPortForwardCount int64
+}
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.psiphonSessionID = psiphonSessionID
-	})
+func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
+	for newChannel := range channels {
 
-	if sshServer.config.UseRedis() {
-		err = UpdateRedisForLegacyPsiWeb(
-			psiphonSessionID, sshServer.getClientGeoIPData(clientKey))
-		if err != nil {
-			log.WithContextFields(LogFields{
-				"psiphonSessionID": psiphonSessionID,
-				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
-			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
+		if newChannel.ChannelType() != "direct-tcpip" {
+			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
+			return
 		}
-	}
 
-	return nil, nil
-}
-
-func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
-	} else {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
+		// process each port forward concurrently
+		go sshClient.handleNewDirectTcpipChannel(newChannel)
 	}
 }
 
-func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
+func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
 	// TODO: log more details?
 	log.WithContextFields(
 		LogFields{
@@ -412,7 +352,7 @@ func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason s
 	newChannel.Reject(reason, message)
 }
 
-func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
@@ -424,7 +364,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 
 	err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
 	if err != nil {
-		sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
+		sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
 		return
 	}
 
@@ -439,7 +379,7 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 	// TODO: IPv6 support
 	fwdConn, err := net.Dial("tcp4", targetAddr)
 	if err != nil {
-		sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
+		sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
 		return
 	}
 	defer fwdConn.Close()
@@ -450,13 +390,13 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 		return
 	}
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.portForwardCount += 1
-		client.concurrentPortForwardCount += 1
-		if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
-			client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
-		}
-	})
+	sshClient.Lock()
+	sshClient.portForwardCount += 1
+	sshClient.concurrentPortForwardCount += 1
+	if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
+		sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
+	}
+	sshClient.Unlock()
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
 
@@ -488,11 +428,59 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newCha
 	fwdChannel.CloseWrite()
 	relayWaitGroup.Wait()
 
-	sshServer.updateClient(clientKey, func(client *sshClient) {
-		client.concurrentPortForwardCount -= 1
-		client.bytesUp += bytesUp
-		client.bytesDown += bytesDown
-	})
+	sshClient.Lock()
+	sshClient.concurrentPortForwardCount -= 1
+	sshClient.bytesUp += bytesUp
+	sshClient.bytesDown += bytesDown
+	sshClient.Unlock()
 
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 }
+
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
+	var sshPasswordPayload struct {
+		SessionId   string `json:"SessionId"`
+		SshPassword string `json:"SshPassword"`
+	}
+	err := json.Unmarshal(password, &sshPasswordPayload)
+	if err != nil {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+	}
+
+	userOk := (subtle.ConstantTimeCompare(
+		[]byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
+
+	passwordOk := (subtle.ConstantTimeCompare(
+		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
+
+	if !userOk || !passwordOk {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+	}
+
+	psiphonSessionID := sshPasswordPayload.SessionId
+
+	sshClient.Lock()
+	sshClient.psiphonSessionID = psiphonSessionID
+	geoIPData := sshClient.geoIPData
+	sshClient.Unlock()
+
+	if sshClient.sshServer.config.UseRedis() {
+		err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
+		if err != nil {
+			log.WithContextFields(LogFields{
+				"psiphonSessionID": psiphonSessionID,
+				"error":            err}).Warning("UpdateRedisForLegacyPsiWeb failed")
+			// Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
+		}
+	}
+
+	return nil, nil
+}
+
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
+	} else {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
+	}
+}