Procházet zdrojové kódy

Log tunnel stats
* GeoIP lookup of client IP; record country, city, ISP
* logged at end of each tunnel (client ssh session)
* start time/duration
* number of port forwards and concurrent high water mark
* total bytes transferred
* include Psiphon session ID

Rod Hynes před 10 roky
rodič
revize
10ea3c505f
1 změnil soubory, kde provedl 136 přidání a 73 odebrání
  1. 136 73
      psiphon/server/sshService.go

+ 136 - 73
psiphon/server/sshService.go

@@ -21,7 +21,6 @@ package server
 
 import (
 	"crypto/subtle"
-	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -47,9 +46,22 @@ type sshServer struct {
 	useObfuscation    bool
 	shutdownBroadcast <-chan struct{}
 	sshConfig         *ssh.ServerConfig
-	clientMutex       sync.Mutex
+	clientsMutex      sync.Mutex
 	stoppingClients   bool
-	clients           map[string]ssh.Conn
+	clients           map[string]*sshClient
+}
+
+type sshClient struct {
+	sync.Mutex
+	sshConn                       ssh.Conn
+	startTime                     time.Time
+	geoIPData                     GeoIPData
+	psiphonSessionID              string
+	bytesUp                       int64
+	bytesDown                     int64
+	portForwardCount              int64
+	concurrentPortForwardCount    int64
+	maxConcurrentPortForwardCount int64
 }
 
 func runSSHServer(
@@ -59,7 +71,7 @@ func runSSHServer(
 		config:            config,
 		useObfuscation:    useObfuscation,
 		shutdownBroadcast: shutdownBroadcast,
-		clients:           make(map[string]ssh.Conn),
+		clients:           make(map[string]*sshClient),
 	}
 
 	sshServer.sshConfig = &ssh.ServerConfig{
@@ -161,86 +173,83 @@ func runSSHServer(
 	return err
 }
 
-func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
-	var sshPasswordPayload struct {
-		SessionId   string `json:"SessionId"`
-		SshPassword string `json:"SshPassword"`
-	}
-	err := json.Unmarshal(password, &sshPasswordPayload)
-	if err != nil {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
+	sshServer.clientsMutex.Lock()
+	defer sshServer.clientsMutex.Unlock()
+	if sshServer.stoppingClients {
+		return "", false
 	}
-
-	userOk := (subtle.ConstantTimeCompare(
-		[]byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
-
-	passwordOk := (subtle.ConstantTimeCompare(
-		[]byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
-
-	if !userOk || !passwordOk {
-		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+	key := string(client.sshConn.SessionID())
+	existingClient := sshServer.clients[key]
+	if existingClient != nil {
+		log.WithContext().Warning("unexpected existing connection")
+		client.sshConn.Close()
+		client.sshConn.Wait()
 	}
-
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
-
-	log.WithContextFields(
-		LogFields{
-			"sshSessionID":     hex.EncodeToString(conn.SessionID()),
-			"psiphonSessionID": sshPasswordPayload.SessionId,
-			"country":          geoIPData.Country,
-			"city":             geoIPData.City,
-			"ISP":              geoIPData.ISP,
-		}).Info("tunnel started")
-
-	return nil, nil
+	sshServer.clients[key] = client
+	return key, true
 }
 
-func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
-	} else {
-		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
+func (sshServer *sshServer) updateClient(
+	clientKey string, updater func(*sshClient)) {
+
+	sshServer.clientsMutex.Lock()
+	sshClient, ok := sshServer.clients[clientKey]
+	sshServer.clientsMutex.Unlock()
+	if ok {
+		sshClient.Lock()
+		updater(sshClient)
+		sshClient.Unlock()
 	}
 }
 
-func (sshServer *sshServer) registerClient(sshConn ssh.Conn) bool {
-	sshServer.clientMutex.Lock()
-	defer sshServer.clientMutex.Unlock()
+func (sshServer *sshServer) unregisterClient(clientKey string) {
+	sshServer.clientsMutex.Lock()
 	if sshServer.stoppingClients {
-		return false
-	}
-	existingSshConn := sshServer.clients[string(sshConn.SessionID())]
-	if existingSshConn != nil {
-		log.WithContext().Warning("unexpected existing connection")
-		existingSshConn.Close()
-		existingSshConn.Wait()
+		return
 	}
-	sshServer.clients[string(sshConn.SessionID())] = sshConn
-	return true
-}
+	client := sshServer.clients[clientKey]
+	delete(sshServer.clients, clientKey)
+	sshServer.clientsMutex.Unlock()
 
-func (sshServer *sshServer) unregisterClient(sshConn ssh.Conn) {
-	sshServer.clientMutex.Lock()
-	if sshServer.stoppingClients {
+	if client == nil {
 		return
 	}
-	delete(sshServer.clients, string(sshConn.SessionID()))
-	sshServer.clientMutex.Unlock()
-	sshConn.Close()
+
+	client.sshConn.Close()
+	client.Lock()
+	log.WithContextFields(
+		LogFields{
+			"startTime":                     client.startTime,
+			"duration":                      time.Now().Sub(client.startTime),
+			"psiphonSessionID":              client.psiphonSessionID,
+			"country":                       client.geoIPData.Country,
+			"city":                          client.geoIPData.City,
+			"ISP":                           client.geoIPData.ISP,
+			"bytesUp":                       client.bytesUp,
+			"bytesDown":                     client.bytesDown,
+			"portForwardCount":              client.portForwardCount,
+			"maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
+		}).Info("tunnel closed")
+	client.Unlock()
 }
 
 func (sshServer *sshServer) stopClients() {
-	sshServer.clientMutex.Lock()
+	sshServer.clientsMutex.Lock()
 	sshServer.stoppingClients = true
-	sshServer.clientMutex.Unlock()
-	for _, sshConn := range sshServer.clients {
-		sshConn.Close()
-		sshConn.Wait()
+	sshServer.clientsMutex.Unlock()
+	for _, client := range sshServer.clients {
+		client.sshConn.Close()
+		client.sshConn.Wait()
 	}
+	sshServer.clients = make(map[string]*sshClient)
 }
 
 func (sshServer *sshServer) handleClient(conn net.Conn) {
 
+	startTime := time.Now()
+	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
+
 	// Run the initial [obfuscated] SSH handshake in a goroutine
 	// so we can both respect shutdownBroadcast and implement a
 	// handshake timeout. The timeout is to reclaim network
@@ -293,16 +302,18 @@ func (sshServer *sshServer) handleClient(conn net.Conn) {
 		return
 	}
 
-	if !sshServer.registerClient(result.sshConn) {
+	clientKey, ok := sshServer.registerClient(
+		&sshClient{
+			sshConn:   result.sshConn,
+			startTime: startTime,
+			geoIPData: geoIPData,
+		})
+	if !ok {
 		result.sshConn.Close()
 		log.WithContext().Warning("register failed")
 		return
 	}
-	defer sshServer.unregisterClient(result.sshConn)
-
-	// TODO: don't record IP; do GeoIP
-	log.WithContextFields(
-		LogFields{"remoteAddr": result.sshConn.RemoteAddr()}).Warning("connection accepted")
+	defer sshServer.unregisterClient(clientKey)
 
 	go ssh.DiscardRequests(result.requests)
 
@@ -314,7 +325,42 @@ func (sshServer *sshServer) handleClient(conn net.Conn) {
 		}
 
 		// process each port forward concurrently
-		go sshServer.handleNewDirectTcpipChannel(newChannel)
+		go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
+	}
+}
+
+func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
+	var sshPasswordPayload struct {
+		SessionId   string `json:"SessionId"`
+		SshPassword string `json:"SshPassword"`
+	}
+	err := json.Unmarshal(password, &sshPasswordPayload)
+	if err != nil {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
+	}
+
+	userOk := (subtle.ConstantTimeCompare(
+		[]byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
+
+	passwordOk := (subtle.ConstantTimeCompare(
+		[]byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
+
+	if !userOk || !passwordOk {
+		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
+	}
+
+	clientKey := string(conn.SessionID())
+	sshServer.updateClient(clientKey, func(client *sshClient) {
+		client.psiphonSessionID = sshPasswordPayload.SessionId
+	})
+	return nil, nil
+}
+
+func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
+	} else {
+		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
 	}
 }
 
@@ -329,7 +375,7 @@ func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason s
 	newChannel.Reject(reason, message)
 }
 
-func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
+func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
 
 	// http://tools.ietf.org/html/rfc4254#section-7.2
 	var directTcpipExtraData struct {
@@ -366,6 +412,14 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 		return
 	}
 
+	sshServer.updateClient(clientKey, func(client *sshClient) {
+		client.portForwardCount += 1
+		client.concurrentPortForwardCount += 1
+		if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
+			client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
+		}
+	})
+
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
 
 	go ssh.DiscardRequests(requests)
@@ -377,21 +431,30 @@ func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
 	// TODO: use a low-memory io.Copy?
 	// TODO: relay errors to fwdChannel.Stderr()?
 
+	var bytesUp, bytesDown int64
+
 	relayWaitGroup := new(sync.WaitGroup)
 	relayWaitGroup.Add(1)
 	go func() {
 		defer relayWaitGroup.Done()
-		_, err := io.Copy(fwdConn, fwdChannel)
+		var err error
+		bytesUp, err = io.Copy(fwdConn, fwdChannel)
 		if err != nil {
 			log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
 		}
 	}()
-	_, err = io.Copy(fwdChannel, fwdConn)
+	bytesDown, err = io.Copy(fwdChannel, fwdConn)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
 	}
 	fwdChannel.CloseWrite()
 	relayWaitGroup.Wait()
 
+	sshServer.updateClient(clientKey, func(client *sshClient) {
+		client.concurrentPortForwardCount -= 1
+		client.bytesUp += bytesUp
+		client.bytesDown += bytesDown
+	})
+
 	log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
 }