|
|
@@ -21,28 +21,43 @@ package server
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net"
|
|
|
"sync"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
)
|
|
|
|
|
|
+func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
+ return runSSHServer(config, false, shutdownBroadcast)
|
|
|
+}
|
|
|
+
|
|
|
+func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
+ return runSSHServer(config, true, shutdownBroadcast)
|
|
|
+}
|
|
|
+
|
|
|
type sshServer struct {
|
|
|
- config *Config
|
|
|
- sshConfig *ssh.ServerConfig
|
|
|
- clientMutex sync.Mutex
|
|
|
- stoppingClients bool
|
|
|
- clients map[string]ssh.Conn
|
|
|
+ config *Config
|
|
|
+ useObfuscation bool
|
|
|
+ shutdownBroadcast <-chan struct{}
|
|
|
+ sshConfig *ssh.ServerConfig
|
|
|
+ clientMutex sync.Mutex
|
|
|
+ stoppingClients bool
|
|
|
+ clients map[string]ssh.Conn
|
|
|
}
|
|
|
|
|
|
-func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
+func runSSHServer(
|
|
|
+ config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
sshServer := &sshServer{
|
|
|
- config: config,
|
|
|
- clients: make(map[string]ssh.Conn),
|
|
|
+ config: config,
|
|
|
+ useObfuscation: useObfuscation,
|
|
|
+ shutdownBroadcast: shutdownBroadcast,
|
|
|
+ clients: make(map[string]ssh.Conn),
|
|
|
}
|
|
|
|
|
|
sshServer.sshConfig = &ssh.ServerConfig{
|
|
|
@@ -64,13 +79,21 @@ func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
sshServer.sshConfig.AddHostKey(signer)
|
|
|
|
|
|
+ var serverPort int
|
|
|
+ if useObfuscation {
|
|
|
+ serverPort = config.ObfuscatedSSHServerPort
|
|
|
+ } else {
|
|
|
+ serverPort = config.SSHServerPort
|
|
|
+ }
|
|
|
+
|
|
|
listener, err := net.Listen(
|
|
|
- "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, config.SSHPort))
|
|
|
+ "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
|
|
|
if err != nil {
|
|
|
return psiphon.ContextError(err)
|
|
|
}
|
|
|
|
|
|
- log.WithContext().Info("starting")
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"useObfuscation": useObfuscation}).Info("starting")
|
|
|
|
|
|
err = nil
|
|
|
errors := make(chan error)
|
|
|
@@ -86,6 +109,9 @@ func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
select {
|
|
|
case <-shutdownBroadcast:
|
|
|
+ if err == nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
break loop
|
|
|
default:
|
|
|
}
|
|
|
@@ -111,7 +137,8 @@ func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
sshServer.stopClients()
|
|
|
|
|
|
- log.WithContext().Info("stopped")
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"useObfuscation": useObfuscation}).Info("stopped")
|
|
|
}()
|
|
|
|
|
|
select {
|
|
|
@@ -123,7 +150,8 @@ func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
|
|
|
|
|
|
waitGroup.Wait()
|
|
|
|
|
|
- log.WithContext().Info("exiting")
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"useObfuscation": useObfuscation}).Info("exiting")
|
|
|
|
|
|
return err
|
|
|
}
|
|
|
@@ -192,27 +220,72 @@ func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
func (sshServer *sshServer) handleClient(conn net.Conn) {
|
|
|
|
|
|
- // TODO: does this block on SSH handshake (so should be in goroutine)?
|
|
|
- sshConn, channels, requests, err := ssh.NewServerConn(conn, sshServer.sshConfig)
|
|
|
- if err != nil {
|
|
|
+ // Run the initial [obfuscated] SSH handshake in a goroutine
|
|
|
+ // so we can both respect shutdownBroadcast and implement a
|
|
|
+ // handshake timeout. The timeout is to reclaim network
|
|
|
+ // resources in case the handshake takes too long.
|
|
|
+
|
|
|
+ type sshNewServerConnResult struct {
|
|
|
+ conn net.Conn
|
|
|
+ sshConn *ssh.ServerConn
|
|
|
+ channels <-chan ssh.NewChannel
|
|
|
+ requests <-chan *ssh.Request
|
|
|
+ err error
|
|
|
+ }
|
|
|
+
|
|
|
+ resultChannel := make(chan *sshNewServerConnResult, 2)
|
|
|
+
|
|
|
+ if SSH_HANDSHAKE_TIMEOUT > 0 {
|
|
|
+ time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
|
|
|
+ resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ result := &sshNewServerConnResult{}
|
|
|
+ if sshServer.useObfuscation {
|
|
|
+ result.conn, result.err = psiphon.NewObfuscatedSshConn(
|
|
|
+ psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
|
|
|
+ } else {
|
|
|
+ result.conn = conn
|
|
|
+ }
|
|
|
+ if result.err == nil {
|
|
|
+ result.sshConn, result.channels,
|
|
|
+ result.requests, result.err = ssh.NewServerConn(conn, sshServer.sshConfig)
|
|
|
+ }
|
|
|
+ resultChannel <- result
|
|
|
+ }()
|
|
|
+
|
|
|
+ var result *sshNewServerConnResult
|
|
|
+ select {
|
|
|
+ case result = <-resultChannel:
|
|
|
+ case <-sshServer.shutdownBroadcast:
|
|
|
+ // Close() will interrupt an ongoing handshake
|
|
|
+ // TODO: wait for goroutine to exit before returning?
|
|
|
conn.Close()
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("establish failed")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if !sshServer.registerClient(sshConn) {
|
|
|
- sshConn.Close()
|
|
|
+ if result.err != nil {
|
|
|
+ conn.Close()
|
|
|
+ log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !sshServer.registerClient(result.sshConn) {
|
|
|
+ result.sshConn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
|
}
|
|
|
- defer sshServer.unregisterClient(sshConn)
|
|
|
+ defer sshServer.unregisterClient(result.sshConn)
|
|
|
|
|
|
// TODO: don't record IP; do GeoIP
|
|
|
- log.WithContextFields(LogFields{"remoteAddr": sshConn.RemoteAddr()}).Warning("connection accepted")
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"remoteAddr": result.sshConn.RemoteAddr()}).Warning("connection accepted")
|
|
|
|
|
|
- go ssh.DiscardRequests(requests)
|
|
|
+ go ssh.DiscardRequests(result.requests)
|
|
|
|
|
|
- for newChannel := range channels {
|
|
|
+ for newChannel := range result.channels {
|
|
|
|
|
|
if newChannel.ChannelType() != "direct-tcpip" {
|
|
|
sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
|