Przeglądaj źródła

Refactoring and load monitor
* Optional load monitor periodically logs process
and connection stats.
* Single ssh server serves all tunnels, with
obfuscation protocols layered in front of
multiple listeners.
* Log relay socket errors at debug level only
to reduce noise from both normal network conditions
and use of closed sockets.

Rod Hynes 9 lat temu
rodzic
commit
80c956a35a

+ 11 - 0
psiphon/server/config.go

@@ -182,6 +182,12 @@ type Config struct {
 	// by this server, which parses the SSH channel using the udpgw
 	// protocol.
 	UdpgwServerAddress string
+
+	// LoadMonitorPeriodSeconds indicates how frequently to log server
+	// load information (number of connected clients per tunnel protocol,
+	// number of running goroutines, amount of memory allocated).
+	// The default, 0, disables load logging.
+	LoadMonitorPeriodSeconds int
 }
 
 // TrafficRules specify the limits placed on client traffic.
@@ -251,6 +257,11 @@ func (config *Config) RunObfuscatedSSHServer() bool {
 	return config.ObfuscatedSSHServerPort > 0
 }
 
+// RunLoadMonitor reports whether periodic server load logging is enabled,
+// which is the case when LoadMonitorPeriodSeconds is greater than zero.
+func (config *Config) RunLoadMonitor() bool {
+	period := config.LoadMonitorPeriodSeconds
+	return period > 0
+}
+
 // UseRedis indicates whether to store per-session GeoIP information in
 // redis. This is for integration with the legacy psi_web component.
 func (config *Config) UseRedis() bool {

+ 1 - 13
psiphon/server/services.go

@@ -79,7 +79,7 @@ func RunServices(encodedConfigs [][]byte) error {
 		}()
 	}
 
-	if config.RunSSHServer() {
+	if config.RunSSHServer() || config.RunObfuscatedSSHServer() {
 		waitGroup.Add(1)
 		go func() {
 			defer waitGroup.Done()
@@ -91,18 +91,6 @@ func RunServices(encodedConfigs [][]byte) error {
 		}()
 	}
 
-	if config.RunObfuscatedSSHServer() {
-		waitGroup.Add(1)
-		go func() {
-			defer waitGroup.Done()
-			err := RunObfuscatedSSHServer(config, shutdownBroadcast)
-			select {
-			case errors <- err:
-			default:
-			}
-		}()
-	}
-
 	// An OS signal triggers an orderly shutdown
 	systemStopSignal := make(chan os.Signal, 1)
 	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)

+ 188 - 107
psiphon/server/sshService.go

@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"runtime"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -34,21 +35,12 @@ import (
 	"golang.org/x/crypto/ssh"
 )
 
-// RunSSHServer runs an ssh server with plain SSH protocol.
-func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
-	return runSSHServer(config, false, shutdownBroadcast)
-}
-
-// RunSSHServer runs an ssh server with Obfuscated SSH protocol.
-func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
-	return runSSHServer(config, true, shutdownBroadcast)
-}
-
-// runSSHServer runs an SSH or Obfuscated SSH server. In the Obfuscated SSH case, an
-// ObfuscatedSSHConn is layered in front of the client TCP connection; otherwise, both
-// modes are identical.
+// RunSSHServer runs an SSH server, the core tunneling component of the Psiphon
+// server. The SSH server runs a selection of listeners that handle connections
+// using various, optional obfuscation protocols layered on top of SSH.
+// (Currently, just Obfuscated SSH).
 //
-// runSSHServer listens on the designated port and spawns new goroutines to handle
+// RunSSHServer listens on the designated port(s) and spawns new goroutines to handle
 // each client connection. It halts when shutdownBroadcast is signaled. A list of active
 // clients is maintained, and when halting all clients are first shutdown.
 //
@@ -56,11 +48,12 @@ func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) e
 // authentication, and then looping on client new channel requests. At this time, only
 // "direct-tcpip" channels, dynamic port forwards, are expected and supported.
 //
-// A new goroutine is spawned to handle each port forward. Each port forward tracks its
-// bytes transferred. Overall per-client stats for connection duration, GeoIP, number of
-// port forwards, and bytes transferred are tracked and logged when the client shuts down.
-func runSSHServer(
-	config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
+// A new goroutine is spawned to handle each port forward for each client. Each port
+// forward tracks its bytes transferred. Overall per-client stats for connection duration,
+// GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
+// client shuts down.
+func RunSSHServer(
+	config *Config, shutdownBroadcast <-chan struct{}) error {
 
 	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
 	if err != nil {
@@ -75,89 +68,94 @@ func runSSHServer(
 
 	sshServer := &sshServer{
 		config:            config,
-		useObfuscation:    useObfuscation,
+		runWaitGroup:      new(sync.WaitGroup),
+		listenerError:     make(chan error),
 		shutdownBroadcast: shutdownBroadcast,
 		sshHostKey:        signer,
 		nextClientID:      1,
 		clients:           make(map[sshClientID]*sshClient),
 	}
 
-	var serverPort int
-	if useObfuscation {
-		serverPort = config.ObfuscatedSSHServerPort
-	} else {
-		serverPort = config.SSHServerPort
-	}
-
-	listener, err := net.Listen(
-		"tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
-	if err != nil {
-		return psiphon.ContextError(err)
+	type sshListener struct {
+		net.Listener
+		localAddress   string
+		tunnelProtocol string
 	}
 
-	log.WithContextFields(
-		LogFields{
-			"useObfuscation": useObfuscation,
-			"port":           serverPort,
-		}).Info("starting")
+	var listeners []*sshListener
 
-	err = nil
-	errors := make(chan error)
-	waitGroup := new(sync.WaitGroup)
+	if config.RunSSHServer() {
+		listeners = append(listeners, &sshListener{
+			localAddress: fmt.Sprintf(
+				"%s:%d", config.ServerIPAddress, config.SSHServerPort),
+			tunnelProtocol: psiphon.TUNNEL_PROTOCOL_SSH,
+		})
+	}
 
-	waitGroup.Add(1)
-	go func() {
-		defer waitGroup.Done()
+	if config.RunObfuscatedSSHServer() {
+		listeners = append(listeners, &sshListener{
+			localAddress: fmt.Sprintf(
+				"%s:%d", config.ServerIPAddress, config.ObfuscatedSSHServerPort),
+			tunnelProtocol: psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
+		})
+	}
 
-	loop:
-		for {
-			conn, err := listener.Accept()
+	// TODO: add additional protocol listeners here (e.g., meek)
 
-			select {
-			case <-shutdownBroadcast:
-				if err == nil {
-					conn.Close()
-				}
-				break loop
-			default:
+	// Open every configured listener before serving on any of them, so that
+	// a failure to bind one address aborts startup as a whole.
+	for i, listener := range listeners {
+		var err error
+		listener.Listener, err = net.Listen("tcp", listener.localAddress)
+		if err != nil {
+			// Close only the listeners opened by earlier iterations. The
+			// current listener.Listener holds the result of the failed
+			// net.Listen call and must not be used.
+			for _, opened := range listeners[:i] {
+				opened.Close()
 			}
+			return psiphon.ContextError(err)
+		}
+		log.WithContextFields(
+			LogFields{
+				"localAddress":   listener.localAddress,
+				"tunnelProtocol": listener.tunnelProtocol,
+			}).Info("listening")
+	}
 
-			if err != nil {
-				if e, ok := err.(net.Error); ok && e.Temporary() {
-					log.WithContextFields(LogFields{"error": err}).Error("accept failed")
-					// Temporary error, keep running
-					continue
-				}
-
-				select {
-				case errors <- psiphon.ContextError(err):
-				default:
-				}
+	for _, listener := range listeners {
+		sshServer.runWaitGroup.Add(1)
+		go func(listener *sshListener) {
+			defer sshServer.runWaitGroup.Done()
 
-				break loop
-			}
+			sshServer.runListener(
+				listener.Listener, listener.tunnelProtocol)
 
-			// process each client connection concurrently
-			go sshServer.handleClient(conn.(*net.TCPConn))
-		}
+			log.WithContextFields(
+				LogFields{
+					"localAddress":   listener.localAddress,
+					"tunnelProtocol": listener.tunnelProtocol,
+				}).Info("stopping")
 
-		sshServer.stopClients()
+		}(listener)
+	}
 
-		log.WithContextFields(
-			LogFields{"useObfuscation": useObfuscation}).Info("stopped")
-	}()
+	if config.RunLoadMonitor() {
+		sshServer.runWaitGroup.Add(1)
+		go func() {
+			defer sshServer.runWaitGroup.Done()
+			sshServer.runLoadMonitor()
+		}()
+	}
 
+	err = nil
 	select {
-	case <-shutdownBroadcast:
-	case err = <-errors:
+	case <-sshServer.shutdownBroadcast:
+	case err = <-sshServer.listenerError:
 	}
 
-	listener.Close()
-
-	waitGroup.Wait()
+	for _, listener := range listeners {
+		listener.Close()
+	}
+	sshServer.stopClients()
+	sshServer.runWaitGroup.Wait()
 
-	log.WithContextFields(
-		LogFields{"useObfuscation": useObfuscation}).Info("exiting")
+	log.WithContext().Info("stopped")
 
 	return err
 }
@@ -166,7 +164,8 @@ type sshClientID uint64
 
 type sshServer struct {
 	config            *Config
-	useObfuscation    bool
+	runWaitGroup      *sync.WaitGroup
+	listenerError     chan error
 	shutdownBroadcast <-chan struct{}
 	sshHostKey        ssh.Signer
 	nextClientID      sshClientID
@@ -175,6 +174,73 @@ type sshServer struct {
 	clients           map[sshClientID]*sshClient
 }
 
+// runListener accepts connections on listener until shutdownBroadcast is
+// signaled or Accept fails with a non-temporary error, and starts a goroutine
+// to handle each accepted connection.
+//
+// Temporary accept errors are logged and the loop continues. A non-temporary
+// accept error is offered to sshServer.listenerError with a non-blocking
+// send (it is dropped if no receiver is ready) and the loop exits.
+//
+// For TUNNEL_PROTOCOL_OBFUSCATED_SSH, the accepted connection is wrapped with
+// psiphon.NewObfuscatedSshConn inside the per-client goroutine. A failure
+// there affects only that client: its connection is closed and the error is
+// logged, rather than being reported as a listener failure.
+func (sshServer *sshServer) runListener(
+	listener net.Listener, tunnelProtocol string) {
+
+	for {
+		conn, err := listener.Accept()
+
+		// Once shutdown is signaled, discard whatever Accept returned.
+		select {
+		case <-sshServer.shutdownBroadcast:
+			if err == nil {
+				conn.Close()
+			}
+			return
+		default:
+		}
+
+		if err != nil {
+			if e, ok := err.(net.Error); ok && e.Temporary() {
+				log.WithContextFields(LogFields{"error": err}).Error("accept failed")
+				// Temporary error, keep running
+				continue
+			}
+
+			select {
+			case sshServer.listenerError <- psiphon.ContextError(err):
+			default:
+			}
+
+			return
+		}
+
+		// process each client connection concurrently
+		go func(conn net.Conn) {
+			if tunnelProtocol == psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
+				// Wrap here, not in the accept loop, so that whatever the
+				// wrapper does with this connection cannot delay or stop
+				// the acceptance of other clients.
+				obfuscatedConn, err := psiphon.NewObfuscatedSshConn(
+					psiphon.OBFUSCATION_CONN_MODE_SERVER,
+					conn,
+					sshServer.config.ObfuscatedSSHKey)
+				if err != nil {
+					// Release the accepted socket; it would otherwise leak.
+					conn.Close()
+					log.WithContextFields(
+						LogFields{"error": err}).Warning("NewObfuscatedSshConn failed")
+					return
+				}
+				conn = obfuscatedConn
+			}
+			sshServer.handleClient(tunnelProtocol, conn)
+		}(conn)
+	}
+}
+
+// runLoadMonitor logs a "load" entry once every LoadMonitorPeriodSeconds
+// until shutdownBroadcast is signaled. Each entry carries the goroutine
+// count, selected runtime.MemStats values, and one field per tunnel protocol
+// holding the client count reported by countClients.
+func (sshServer *sshServer) runLoadMonitor() {
+	period := time.Duration(sshServer.config.LoadMonitorPeriodSeconds) * time.Second
+	ticker := time.NewTicker(period)
+	defer ticker.Stop()
+
+	for {
+		// Wait for the next tick, or stop on shutdown.
+		select {
+		case <-ticker.C:
+		case <-sshServer.shutdownBroadcast:
+			return
+		}
+
+		var memStats runtime.MemStats
+		runtime.ReadMemStats(&memStats)
+
+		fields := LogFields{
+			"goroutines":    runtime.NumGoroutine(),
+			"memAlloc":      memStats.Alloc,
+			"memTotalAlloc": memStats.TotalAlloc,
+			"memSysAlloc":   memStats.Sys,
+		}
+		for protocol, clientCount := range sshServer.countClients() {
+			fields[protocol] = clientCount
+		}
+		log.WithContextFields(fields).Info("load")
+	}
+}
+
 func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
 
 	sshServer.clientsMutex.Lock()
@@ -204,6 +270,18 @@ func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
 	}
 }
 
+// countClients returns the number of currently registered clients, keyed by
+// each client's tunnelProtocol. The clients map is read while holding
+// clientsMutex.
+func (sshServer *sshServer) countClients() map[string]int {
+	sshServer.clientsMutex.Lock()
+	defer sshServer.clientsMutex.Unlock()
+
+	countsByProtocol := make(map[string]int)
+	for _, registered := range sshServer.clients {
+		countsByProtocol[registered.tunnelProtocol]++
+	}
+
+	return countsByProtocol
+}
+
 func (sshServer *sshServer) stopClients() {
 
 	sshServer.clientsMutex.Lock()
@@ -216,14 +294,17 @@ func (sshServer *sshServer) stopClients() {
 	}
 }
 
-func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
+func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
 
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
+	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
 
 	sshClient := newSshClient(
-		sshServer, geoIPData, sshServer.config.GetTrafficRules(geoIPData.Country))
+		sshServer,
+		tunnelProtocol,
+		geoIPData,
+		sshServer.config.GetTrafficRules(geoIPData.Country))
 
-	// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
+	// Wrap the base client connection with an IdleTimeoutConn which will terminate
 	// the connection if no data is received before the deadline. This timeout is
 	// in effect for the entire duration of the SSH connection. Clients must actively
 	// use the connection or send SSH keep alive requests to keep the connection
@@ -231,7 +312,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 
 	var conn net.Conn
 
-	conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
+	conn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
 
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
@@ -261,29 +342,25 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 		})
 	}
 
-	go func() {
-
-		result := &sshNewServerConnResult{}
-		if sshServer.useObfuscation {
-			result.conn, result.err = psiphon.NewObfuscatedSshConn(
-				psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
-		} else {
-			result.conn = conn
+	go func(conn net.Conn) {
+		sshServerConfig := &ssh.ServerConfig{
+			PasswordCallback: sshClient.passwordCallback,
+			AuthLogCallback:  sshClient.authLogCallback,
+			ServerVersion:    sshServer.config.SSHServerVersion,
 		}
-		if result.err == nil {
+		sshServerConfig.AddHostKey(sshServer.sshHostKey)
 
-			sshServerConfig := &ssh.ServerConfig{
-				PasswordCallback: sshClient.passwordCallback,
-				AuthLogCallback:  sshClient.authLogCallback,
-				ServerVersion:    sshServer.config.SSHServerVersion,
-			}
-			sshServerConfig.AddHostKey(sshServer.sshHostKey)
+		sshConn, channels, requests, err :=
+			ssh.NewServerConn(conn, sshServerConfig)
 
-			result.sshConn, result.channels, result.requests, result.err =
-				ssh.NewServerConn(result.conn, sshServerConfig)
+		resultChannel <- &sshNewServerConnResult{
+			conn:     conn,
+			sshConn:  sshConn,
+			channels: channels,
+			requests: requests,
+			err:      err,
 		}
-		resultChannel <- result
-	}()
+	}(conn)
 
 	var result *sshNewServerConnResult
 	select {
@@ -321,6 +398,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
 type sshClient struct {
 	sync.Mutex
 	sshServer               *sshServer
+	tunnelProtocol          string
 	sshConn                 ssh.Conn
 	startTime               time.Time
 	geoIPData               GeoIPData
@@ -341,9 +419,11 @@ type trafficState struct {
 	peakConcurrentPortForwardCount int64
 }
 
-func newSshClient(sshServer *sshServer, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
+func newSshClient(
+	sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
 	return &sshClient{
 		sshServer:               sshServer,
+		tunnelProtocol:          tunnelProtocol,
 		startTime:               time.Now(),
 		geoIPData:               geoIPData,
 		trafficRules:            trafficRules,
@@ -571,13 +651,14 @@ func (sshClient *sshClient) handleTCPChannel(
 		bytes, err := io.Copy(fwdChannel, fwdConn)
 		atomic.AddInt64(&bytesDown, bytes)
 		if err != nil && err != io.EOF {
-			log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
+			// Debug since errors such as "connection reset by peer" occur during normal operation
+			log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
 		}
 	}()
 	bytes, err := io.Copy(fwdConn, fwdChannel)
 	atomic.AddInt64(&bytesUp, bytes)
 	if err != nil && err != io.EOF {
-		log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
+		log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
 	}
 
 	// Shutdown special case: fwdChannel will be closed and return EOF when

+ 4 - 2
psiphon/server/udpChannel.go

@@ -211,7 +211,8 @@ func (mux *udpPortForwardMultiplexer) run() {
 		// Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP)
 		_, err = portForward.conn.Write(message.packet)
 		if err != nil {
-			log.WithContextFields(LogFields{"error": err}).Warning("upstream UDP relay failed")
+			// Debug since errors such as "write: operation not permitted" occur during normal operation
+			log.WithContextFields(LogFields{"error": err}).Debug("upstream UDP relay failed")
 			// The port forward's goroutine will complete cleanup
 			portForward.conn.Close()
 		}
@@ -301,6 +302,7 @@ func (portForward *udpPortForward) relayDownstream() {
 		}
 		if err != nil {
 			if err != io.EOF {
-				log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+				// Debug since errors such as "use of closed network connection" occur during normal operation
+				log.WithContextFields(LogFields{"error": err}).Debug("downstream UDP relay failed")
 			}
 			break
@@ -320,7 +322,7 @@ func (portForward *udpPortForward) relayDownstream() {
 		if err != nil {
 			// Close the channel, which will interrupt the main loop.
 			portForward.mux.sshChannel.Close()
-			log.WithContextFields(LogFields{"error": err}).Warning("downstream UDP relay failed")
+			log.WithContextFields(LogFields{"error": err}).Debug("downstream UDP relay failed")
 			break
 		}