Parcourir la source

Add configurable SSH handshake timeouts

Rod Hynes il y a 6 ans
Parent
commit
a3e6163785
2 fichiers modifiés avec 28 ajouts et 3 suppressions
  1. 24 0
      psiphon/server/config.go
  2. 4 3
      psiphon/server/tunnelServer.go

+ 24 - 0
psiphon/server/config.go

@@ -32,6 +32,7 @@ import (
 	"net"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
@@ -157,6 +158,16 @@ type Config struct {
 	// protocols, run by this server instance, which use SSH.
 	SSHPassword string
 
+	// SSHBeginHandshakeTimeoutMilliseconds specifies the timeout
+	// for clients queueing to begin an SSH handshake. The default
+	// is SSH_BEGIN_HANDSHAKE_TIMEOUT.
+	SSHBeginHandshakeTimeoutMilliseconds *int
+
+	// SSHHandshakeTimeoutMilliseconds specifies the timeout
+	// before which a client must complete its handshake. The default
+	// is SSH_HANDSHAKE_TIMEOUT.
+	SSHHandshakeTimeoutMilliseconds *int
+
 	// ObfuscatedSSHKey is the secret key for use in the Obfuscated
 	// SSH protocol. The same secret key is used for all protocols,
 	// run by this server instance, which use Obfuscated SSH.
@@ -325,6 +336,9 @@ type Config struct {
 	// BlocklistActive indicates whether to actively prevent blocklist hits in
 	// addition to logging events.
 	BlocklistActive bool
+
+	sshBeginHandshakeTimeout time.Duration
+	sshHandshakeTimeout      time.Duration
 }
 
 // RunWebServer indicates whether to run a web server component.
@@ -417,6 +431,16 @@ func LoadConfig(configJSON []byte) (*Config, error) {
 		}
 	}
 
+	config.sshBeginHandshakeTimeout = SSH_BEGIN_HANDSHAKE_TIMEOUT
+	if config.SSHBeginHandshakeTimeoutMilliseconds != nil {
+		config.sshBeginHandshakeTimeout = time.Duration(*config.SSHBeginHandshakeTimeoutMilliseconds) * time.Millisecond
+	}
+
+	config.sshHandshakeTimeout = SSH_HANDSHAKE_TIMEOUT
+	if config.SSHHandshakeTimeoutMilliseconds != nil {
+		config.sshHandshakeTimeout = time.Duration(*config.SSHHandshakeTimeoutMilliseconds) * time.Millisecond
+	}
+
 	if config.ObfuscatedSSHKey != "" {
 		seed, err := protocol.DeriveSSHServerVersionPRNGSeed(config.ObfuscatedSSHKey)
 		if err != nil {

+ 4 - 3
psiphon/server/tunnelServer.go

@@ -889,7 +889,8 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 {
 
 		ctx, cancelFunc := context.WithTimeout(
-			context.Background(), SSH_BEGIN_HANDSHAKE_TIMEOUT)
+			context.Background(),
+			sshServer.support.Config.sshBeginHandshakeTimeout)
 		defer cancelFunc()
 
 		err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1)
@@ -1097,8 +1098,8 @@ func (sshClient *sshClient) run(
 	resultChannel := make(chan *sshNewServerConnResult, 2)
 
 	var afterFunc *time.Timer
-	if SSH_HANDSHAKE_TIMEOUT > 0 {
-		afterFunc = time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
+	if sshClient.sshServer.support.Config.sshHandshakeTimeout > 0 {
+		afterFunc = time.AfterFunc(sshClient.sshServer.support.Config.sshHandshakeTimeout, func() {
 			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
 		})
 	}