Kaynağa Gözat

Upgrade x/crypto/ssh subtree to v0.17.0

- "Merge commit 'af6ffb3a97a16d7e752bb76be4755ef4535e50db'"

- `git subtree pull --prefix psiphon/common/crypto https://github.com/golang/crypto.git v0.17.0 --squash`

- Adds additional KEX algorithms

- Includes fix for CVE-2023-48795, although it's not always enabled
  (as documented in a comment, Psiphon's usage of SSH should not be
  vulnerable to downgrade attacks related to stripping SSH_MSG_EXT_INFO)

- Re-deleted all non-ssh packages; retained older crypto/internal
  package dependencies
Rod Hynes 2 yıl önce
ebeveyn
işleme
435a6a3f21
57 değiştirilmiş dosya ile 3766 ekleme ve 883 silme
  1. 0 3
      psiphon/common/crypto/AUTHORS
  2. 0 3
      psiphon/common/crypto/CONTRIBUTORS
  3. 120 0
      psiphon/common/crypto/internal/testenv/exec.go
  4. 15 0
      psiphon/common/crypto/internal/testenv/testenv_notunix.go
  5. 15 0
      psiphon/common/crypto/internal/testenv/testenv_unix.go
  6. 59 18
      psiphon/common/crypto/ssh/agent/client.go
  7. 21 36
      psiphon/common/crypto/ssh/agent/client_test.go
  8. 3 3
      psiphon/common/crypto/ssh/agent/keyring.go
  9. 2 2
      psiphon/common/crypto/ssh/agent/server.go
  10. 11 4
      psiphon/common/crypto/ssh/agent/server_test.go
  11. 4 3
      psiphon/common/crypto/ssh/benchmark_test.go
  12. 86 41
      psiphon/common/crypto/ssh/certs.go
  13. 108 19
      psiphon/common/crypto/ssh/certs_test.go
  14. 45 32
      psiphon/common/crypto/ssh/channel.go
  15. 8 8
      psiphon/common/crypto/ssh/cipher.go
  16. 1 1
      psiphon/common/crypto/ssh/cipher_test.go
  17. 9 18
      psiphon/common/crypto/ssh/client.go
  18. 176 38
      psiphon/common/crypto/ssh/client_auth.go
  19. 388 2
      psiphon/common/crypto/ssh/client_auth_test.go
  20. 116 5
      psiphon/common/crypto/ssh/client_test.go
  21. 116 52
      psiphon/common/crypto/ssh/common.go
  22. 5 5
      psiphon/common/crypto/ssh/common_test.go
  23. 2 2
      psiphon/common/crypto/ssh/connection.go
  24. 4 2
      psiphon/common/crypto/ssh/doc.go
  25. 86 7
      psiphon/common/crypto/ssh/example_test.go
  26. 207 53
      psiphon/common/crypto/ssh/handshake.go
  27. 476 17
      psiphon/common/crypto/ssh/handshake_test.go
  28. 96 117
      psiphon/common/crypto/ssh/kex.go
  29. 42 1
      psiphon/common/crypto/ssh/kex_test.go
  30. 396 142
      psiphon/common/crypto/ssh/keys.go
  31. 115 5
      psiphon/common/crypto/ssh/keys_test.go
  32. 1 1
      psiphon/common/crypto/ssh/knownhosts/knownhosts.go
  33. 7 0
      psiphon/common/crypto/ssh/mac.go
  34. 17 3
      psiphon/common/crypto/ssh/mempipe_test.go
  35. 31 6
      psiphon/common/crypto/ssh/messages.go
  36. 6 0
      psiphon/common/crypto/ssh/mux.go
  37. 201 78
      psiphon/common/crypto/ssh/mux_test.go
  38. 1 1
      psiphon/common/crypto/ssh/randomized_kex_test.go
  39. 102 20
      psiphon/common/crypto/ssh/server.go
  40. 140 0
      psiphon/common/crypto/ssh/server_test.go
  41. 4 4
      psiphon/common/crypto/ssh/session.go
  42. 134 18
      psiphon/common/crypto/ssh/session_test.go
  43. 35 0
      psiphon/common/crypto/ssh/tcpip.go
  44. 33 0
      psiphon/common/crypto/ssh/tcpip_test.go
  45. 1 3
      psiphon/common/crypto/ssh/test/agent_unix_test.go
  46. 1 3
      psiphon/common/crypto/ssh/test/banner_test.go
  47. 1 3
      psiphon/common/crypto/ssh/test/cert_test.go
  48. 8 6
      psiphon/common/crypto/ssh/test/dial_unix_test.go
  49. 7 16
      psiphon/common/crypto/ssh/test/forward_unix_test.go
  50. 5 7
      psiphon/common/crypto/ssh/test/multi_auth_test.go
  51. 98 0
      psiphon/common/crypto/ssh/test/server_test.go
  52. 26 21
      psiphon/common/crypto/ssh/test/session_test.go
  53. 100 0
      psiphon/common/crypto/ssh/test/sshcli_test.go
  54. 1 1
      psiphon/common/crypto/ssh/test/sshd_test_pw.c
  55. 31 43
      psiphon/common/crypto/ssh/test/test_unix_test.go
  56. 7 1
      psiphon/common/crypto/ssh/testdata/keys.go
  57. 36 9
      psiphon/common/crypto/ssh/transport.go

+ 0 - 3
psiphon/common/crypto/AUTHORS

@@ -1,3 +0,0 @@
-# This source code refers to The Go Authors for copyright purposes.
-# The master list of authors is in the main Go distribution,
-# visible at https://tip.golang.org/AUTHORS.

+ 0 - 3
psiphon/common/crypto/CONTRIBUTORS

@@ -1,3 +0,0 @@
-# This source code was written by the Go contributors.
-# The master list of contributors is in the main Go distribution,
-# visible at https://tip.golang.org/CONTRIBUTORS.

+ 120 - 0
psiphon/common/crypto/internal/testenv/exec.go

@@ -0,0 +1,120 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testenv
+
+import (
+	"context"
+	"os"
+	"os/exec"
+	"reflect"
+	"strconv"
+	"testing"
+	"time"
+)
+
+// CommandContext is like exec.CommandContext, but:
+//   - skips t if the platform does not support os/exec,
+//   - sends SIGQUIT (if supported by the platform) instead of SIGKILL
+//     in its Cancel function
+//   - if the test has a deadline, adds a Context timeout and WaitDelay
+//     for an arbitrary grace period before the test's deadline expires,
+//   - fails the test if the command does not complete before the test's deadline, and
+//   - sets a Cleanup function that verifies that the test did not leak a subprocess.
+func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd {
+	t.Helper()
+
+	var (
+		cancelCtx   context.CancelFunc
+		gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging)
+	)
+
+	if t, ok := t.(interface {
+		testing.TB
+		Deadline() (time.Time, bool)
+	}); ok {
+		if td, ok := t.Deadline(); ok {
+			// Start with a minimum grace period, just long enough to consume the
+			// output of a reasonable program after it terminates.
+			gracePeriod = 100 * time.Millisecond
+			if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" {
+				scale, err := strconv.Atoi(s)
+				if err != nil {
+					t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err)
+				}
+				gracePeriod *= time.Duration(scale)
+			}
+
+			// If time allows, increase the termination grace period to 5% of the
+			// test's remaining time.
+			testTimeout := time.Until(td)
+			if gp := testTimeout / 20; gp > gracePeriod {
+				gracePeriod = gp
+			}
+
+			// When we run commands that execute subprocesses, we want to reserve two
+			// grace periods to clean up: one for the delay between the first
+			// termination signal being sent (via the Cancel callback when the Context
+			// expires) and the process being forcibly terminated (via the WaitDelay
+			// field), and a second one for the delay becween the process being
+			// terminated and and the test logging its output for debugging.
+			//
+			// (We want to ensure that the test process itself has enough time to
+			// log the output before it is also terminated.)
+			cmdTimeout := testTimeout - 2*gracePeriod
+
+			if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout {
+				// Either ctx doesn't have a deadline, or its deadline would expire
+				// after (or too close before) the test has already timed out.
+				// Add a shorter timeout so that the test will produce useful output.
+				ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout)
+			}
+		}
+	}
+
+	cmd := exec.CommandContext(ctx, name, args...)
+	// Set the Cancel and WaitDelay fields only if present (go 1.20 and later).
+	// TODO: When Go 1.19 is no longer supported, remove this use of reflection
+	// and instead set the fields directly.
+	if cmdCancel := reflect.ValueOf(cmd).Elem().FieldByName("Cancel"); cmdCancel.IsValid() {
+		cmdCancel.Set(reflect.ValueOf(func() error {
+			if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded {
+				// The command timed out due to running too close to the test's deadline.
+				// There is no way the test did that intentionally — it's too close to the
+				// wire! — so mark it as a test failure. That way, if the test expects the
+				// command to fail for some other reason, it doesn't have to distinguish
+				// between that reason and a timeout.
+				t.Errorf("test timed out while running command: %v", cmd)
+			} else {
+				// The command is being terminated due to ctx being canceled, but
+				// apparently not due to an explicit test deadline that we added.
+				// Log that information in case it is useful for diagnosing a failure,
+				// but don't actually fail the test because of it.
+				t.Logf("%v: terminating command: %v", ctx.Err(), cmd)
+			}
+			return cmd.Process.Signal(Sigquit)
+		}))
+	}
+	if cmdWaitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); cmdWaitDelay.IsValid() {
+		cmdWaitDelay.Set(reflect.ValueOf(gracePeriod))
+	}
+
+	t.Cleanup(func() {
+		if cancelCtx != nil {
+			cancelCtx()
+		}
+		if cmd.Process != nil && cmd.ProcessState == nil {
+			t.Errorf("command was started, but test did not wait for it to complete: %v", cmd)
+		}
+	})
+
+	return cmd
+}
+
+// Command is like exec.Command, but applies the same changes as
+// testenv.CommandContext (with a default Context).
+func Command(t testing.TB, name string, args ...string) *exec.Cmd {
+	t.Helper()
+	return CommandContext(t, context.Background(), name, args...)
+}

+ 15 - 0
psiphon/common/crypto/internal/testenv/testenv_notunix.go

@@ -0,0 +1,15 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows || plan9 || (js && wasm) || wasip1
+
+package testenv
+
+import (
+	"os"
+)
+
+// Sigquit is the signal to send to kill a hanging subprocess.
+// On Unix we send SIGQUIT, but on non-Unix we only have os.Kill.
+var Sigquit = os.Kill

+ 15 - 0
psiphon/common/crypto/internal/testenv/testenv_unix.go

@@ -0,0 +1,15 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package testenv
+
+import (
+	"syscall"
+)
+
+// Sigquit is the signal to send to kill a hanging subprocess.
+// Send SIGQUIT to get a stack trace.
+var Sigquit = syscall.SIGQUIT

+ 59 - 18
psiphon/common/crypto/ssh/agent/client.go

@@ -8,13 +8,15 @@
 // ssh-agent process using the sample server.
 //
 // References:
-//  [PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
+//
+//	[PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
 package agent // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/agent"
 
 import (
 	"bytes"
 	"crypto/dsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/rsa"
 	"encoding/base64"
@@ -25,9 +27,7 @@ import (
 	"math/big"
 	"sync"
 
-	"crypto"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
-	"golang.org/x/crypto/ed25519"
 )
 
 // SignatureFlags represent additional flags that can be passed to the signature
@@ -93,7 +93,7 @@ type ExtendedAgent interface {
 type ConstraintExtension struct {
 	// ExtensionName consist of a UTF-8 string suffixed by the
 	// implementation domain following the naming scheme defined
-	// in Section 4.2 of [RFC4251], e.g.  "foo@example.com".
+	// in Section 4.2 of RFC 4251, e.g.  "foo@example.com".
 	ExtensionName string
 	// ExtensionDetails contains the actual content of the extended
 	// constraint.
@@ -141,9 +141,14 @@ const (
 	agentAddSmartcardKeyConstrained = 26
 
 	// 3.7 Key constraint identifiers
-	agentConstrainLifetime  = 1
-	agentConstrainConfirm   = 2
-	agentConstrainExtension = 3
+	agentConstrainLifetime = 1
+	agentConstrainConfirm  = 2
+	// Constraint extension identifier up to version 2 of the protocol. A
+	// backward incompatible change will be required if we want to add support
+	// for SSH_AGENT_CONSTRAIN_MAXSIGN which uses the same ID.
+	agentConstrainExtensionV00 = 3
+	// Constraint extension identifier in version 3 and later of the protocol.
+	agentConstrainExtension = 255
 )
 
 // maxAgentResponseBytes is the maximum agent reply size that is accepted. This
@@ -205,7 +210,7 @@ type constrainLifetimeAgentMsg struct {
 }
 
 type constrainExtensionAgentMsg struct {
-	ExtensionName    string `sshtype:"3"`
+	ExtensionName    string `sshtype:"255|3"`
 	ExtensionDetails []byte
 
 	// Rest is a field used for parsing, not part of message
@@ -226,7 +231,9 @@ var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
 
 type extensionAgentMsg struct {
 	ExtensionType string `sshtype:"27"`
-	Contents      []byte
+	// NOTE: this matches OpenSSH's PROTOCOL.agent, not the IETF draft [PROTOCOL.agent],
+	// so that it matches what OpenSSH actually implements in the wild.
+	Contents []byte `ssh:"rest"`
 }
 
 // Key represents a protocol 2 public key as defined in
@@ -729,7 +736,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
 	if err != nil {
 		return err
 	}
-	if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
+	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
 		return errors.New("agent: signer and cert have different public key")
 	}
 
@@ -771,19 +778,53 @@ func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature,
 	return s.agent.Sign(s.pub, data)
 }
 
-func (s *agentKeyringSigner) SignWithOpts(rand io.Reader, data []byte, opts crypto.SignerOpts) (*ssh.Signature, error) {
+func (s *agentKeyringSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
+	if algorithm == "" || algorithm == underlyingAlgo(s.pub.Type()) {
+		return s.Sign(rand, data)
+	}
+
 	var flags SignatureFlags
-	if opts != nil {
-		switch opts.HashFunc() {
-		case crypto.SHA256:
-			flags = SignatureFlagRsaSha256
-		case crypto.SHA512:
-			flags = SignatureFlagRsaSha512
-		}
+	switch algorithm {
+	case ssh.KeyAlgoRSASHA256:
+		flags = SignatureFlagRsaSha256
+	case ssh.KeyAlgoRSASHA512:
+		flags = SignatureFlagRsaSha512
+	default:
+		return nil, fmt.Errorf("agent: unsupported algorithm %q", algorithm)
 	}
+
 	return s.agent.SignWithFlags(s.pub, data, flags)
 }
 
+var _ ssh.AlgorithmSigner = &agentKeyringSigner{}
+
+// certKeyAlgoNames is a mapping from known certificate algorithm names to the
+// corresponding public key signature algorithm.
+//
+// This map must be kept in sync with the one in certs.go.
+var certKeyAlgoNames = map[string]string{
+	ssh.CertAlgoRSAv01:        ssh.KeyAlgoRSA,
+	ssh.CertAlgoRSASHA256v01:  ssh.KeyAlgoRSASHA256,
+	ssh.CertAlgoRSASHA512v01:  ssh.KeyAlgoRSASHA512,
+	ssh.CertAlgoDSAv01:        ssh.KeyAlgoDSA,
+	ssh.CertAlgoECDSA256v01:   ssh.KeyAlgoECDSA256,
+	ssh.CertAlgoECDSA384v01:   ssh.KeyAlgoECDSA384,
+	ssh.CertAlgoECDSA521v01:   ssh.KeyAlgoECDSA521,
+	ssh.CertAlgoSKECDSA256v01: ssh.KeyAlgoSKECDSA256,
+	ssh.CertAlgoED25519v01:    ssh.KeyAlgoED25519,
+	ssh.CertAlgoSKED25519v01:  ssh.KeyAlgoSKED25519,
+}
+
+// underlyingAlgo returns the signature algorithm associated with algo (which is
+// an advertised or negotiated public key or host key algorithm). These are
+// usually the same, except for certificate algorithms.
+func underlyingAlgo(algo string) string {
+	if a, ok := certKeyAlgoNames[algo]; ok {
+		return a
+	}
+	return algo
+}
+
 // Calls an extension method. It is up to the agent implementation as to whether or not
 // any particular extension is supported and may always return an error. Because the
 // type of the response is up to the implementation, this returns the bytes of the

+ 21 - 36
psiphon/common/crypto/ssh/agent/client_test.go

@@ -16,7 +16,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -30,6 +29,9 @@ func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, clean
 		// types supported vary by platform.
 		t.Skip("skipping test due to -short")
 	}
+	if runtime.GOOS == "windows" {
+		t.Skip("skipping on windows, we don't support connecting to the ssh-agent via a named pipe")
+	}
 
 	bin, err := exec.LookPath("ssh-agent")
 	if err != nil {
@@ -183,9 +185,9 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
 			}
 		}
-		sshFlagTest(0, ssh.SigAlgoRSA)
-		sshFlagTest(SignatureFlagRsaSha256, ssh.SigAlgoRSASHA2256)
-		sshFlagTest(SignatureFlagRsaSha512, ssh.SigAlgoRSASHA2512)
+		sshFlagTest(0, ssh.KeyAlgoRSA)
+		sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
+		sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
 	}
 
 	// If the key has a lifetime, is it removed when it should be?
@@ -204,44 +206,26 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 
 func TestMalformedRequests(t *testing.T) {
 	keyringAgent := NewKeyring()
-	listener, err := netListener()
-	if err != nil {
-		t.Fatalf("netListener: %v", err)
-	}
-	defer listener.Close()
 
 	testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
-		var wg sync.WaitGroup
-		wg.Add(1)
+		c, s := net.Pipe()
+		defer c.Close()
+		defer s.Close()
 		go func() {
-			defer wg.Done()
-			c, err := listener.Accept()
+			_, err := c.Write(requestBytes)
 			if err != nil {
-				t.Errorf("listener.Accept: %v", err)
-				return
-			}
-			defer c.Close()
-
-			err = ServeAgent(keyringAgent, c)
-			if err == nil {
-				t.Error("ServeAgent should have returned an error to malformed input")
-			} else {
-				if (err != io.EOF) != wantServerErr {
-					t.Errorf("ServeAgent returned expected error: %v", err)
-				}
+				t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
 			}
+			c.Close()
 		}()
-
-		c, err := net.Dial("tcp", listener.Addr().String())
-		if err != nil {
-			t.Fatalf("net.Dial: %v", err)
-		}
-		_, err = c.Write(requestBytes)
-		if err != nil {
-			t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
+		err := ServeAgent(keyringAgent, s)
+		if err == nil {
+			t.Error("ServeAgent should have returned an error to malformed input")
+		} else {
+			if (err != io.EOF) != wantServerErr {
+				t.Errorf("ServeAgent returned expected error: %v", err)
+			}
 		}
-		c.Close()
-		wg.Wait()
 	}
 
 	var testCases = []struct {
@@ -385,7 +369,8 @@ func TestAuth(t *testing.T) {
 	go func() {
 		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
 		if err != nil {
-			t.Fatalf("Server: %v", err)
+			t.Errorf("NewServerConn error: %v", err)
+			return
 		}
 		conn.Close()
 	}()

+ 3 - 3
psiphon/common/crypto/ssh/agent/keyring.go

@@ -113,7 +113,7 @@ func (r *keyring) Unlock(passphrase []byte) error {
 
 // expireKeysLocked removes expired keys from the keyring. If a key was added
 // with a lifetimesecs contraint and seconds >= lifetimesecs seconds have
-// ellapsed, it is removed. The caller *must* be holding the keyring mutex.
+// elapsed, it is removed. The caller *must* be holding the keyring mutex.
 func (r *keyring) expireKeysLocked() {
 	for _, k := range r.keys {
 		if k.expire != nil && time.Now().After(*k.expire) {
@@ -205,9 +205,9 @@ func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureF
 					var algorithm string
 					switch flags {
 					case SignatureFlagRsaSha256:
-						algorithm = ssh.SigAlgoRSASHA2256
+						algorithm = ssh.KeyAlgoRSASHA256
 					case SignatureFlagRsaSha512:
-						algorithm = ssh.SigAlgoRSASHA2512
+						algorithm = ssh.KeyAlgoRSASHA512
 					default:
 						return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
 					}

+ 2 - 2
psiphon/common/crypto/ssh/agent/server.go

@@ -20,7 +20,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 )
 
-// Server wraps an Agent and uses it to implement the agent side of
+// server wraps an Agent and uses it to implement the agent side of
 // the SSH-agent, wire protocol.
 type server struct {
 	agent Agent
@@ -208,7 +208,7 @@ func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse
 		case agentConstrainConfirm:
 			confirmBeforeUse = true
 			constraints = constraints[1:]
-		case agentConstrainExtension:
+		case agentConstrainExtension, agentConstrainExtensionV00:
 			var msg constrainExtensionAgentMsg
 			if err = ssh.Unmarshal(constraints, &msg); err != nil {
 				return 0, false, nil, err

+ 11 - 4
psiphon/common/crypto/ssh/agent/server_test.go

@@ -53,10 +53,11 @@ func TestSetupForwardAgent(t *testing.T) {
 	incoming := make(chan *ssh.ServerConn, 1)
 	go func() {
 		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
+		incoming <- conn
 		if err != nil {
-			t.Fatalf("Server: %v", err)
+			t.Errorf("NewServerConn error: %v", err)
+			return
 		}
-		incoming <- conn
 	}()
 
 	conf := ssh.ClientConfig{
@@ -71,8 +72,10 @@ func TestSetupForwardAgent(t *testing.T) {
 	if err := ForwardToRemote(client, socket); err != nil {
 		t.Fatalf("SetupForwardAgent: %v", err)
 	}
-
 	server := <-incoming
+	if server == nil {
+		t.Fatal("Unable to get server")
+	}
 	ch, reqs, err := server.OpenChannel(channelType, nil)
 	if err != nil {
 		t.Fatalf("OpenChannel(%q): %v", channelType, err)
@@ -240,7 +243,11 @@ func TestParseConstraints(t *testing.T) {
 			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
 		}
 		expect = append(expect, ext)
-		data = append(data, agentConstrainExtension)
+		if i%2 == 0 {
+			data = append(data, agentConstrainExtension)
+		} else {
+			data = append(data, agentConstrainExtensionV00)
+		}
 		data = append(data, ssh.Marshal(ext)...)
 	}
 	_, _, extensions, err := parseConstraints(data)

+ 4 - 3
psiphon/common/crypto/ssh/benchmark_test.go

@@ -6,6 +6,7 @@ package ssh
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"net"
 	"testing"
@@ -90,16 +91,16 @@ func BenchmarkEndToEnd(b *testing.B) {
 	go func() {
 		newCh, err := server.Accept()
 		if err != nil {
-			b.Fatalf("Client: %v", err)
+			panic(fmt.Sprintf("Client: %v", err))
 		}
 		ch, incoming, err := newCh.Accept()
 		if err != nil {
-			b.Fatalf("Accept: %v", err)
+			panic(fmt.Sprintf("Accept: %v", err))
 		}
 		go DiscardRequests(incoming)
 		for i := 0; i < b.N; i++ {
 			if _, err := io.ReadFull(ch, output); err != nil {
-				b.Fatalf("ReadFull: %v", err)
+				panic(fmt.Sprintf("ReadFull: %v", err))
 			}
 		}
 		ch.Close()

+ 86 - 41
psiphon/common/crypto/ssh/certs.go

@@ -14,8 +14,11 @@ import (
 	"time"
 )
 
-// These constants from [PROTOCOL.certkeys] represent the key algorithm names
-// for certificate types supported by this package.
+// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
+// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
+// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
+// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
+// field.
 const (
 	CertAlgoRSAv01        = "ssh-rsa-cert-v01@openssh.com"
 	CertAlgoDSAv01        = "ssh-dss-cert-v01@openssh.com"
@@ -25,14 +28,21 @@ const (
 	CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
 	CertAlgoED25519v01    = "ssh-ed25519-cert-v01@openssh.com"
 	CertAlgoSKED25519v01  = "sk-ssh-ed25519-cert-v01@openssh.com"
+
+	// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
+	// Certificate.Type (or PublicKey.Type), but only in
+	// ClientConfig.HostKeyAlgorithms.
+	CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com"
+	CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com"
 )
 
-// These constants from [PROTOCOL.certkeys] represent additional signature
-// algorithm names for certificate types supported by this package.
 const (
-	CertSigAlgoRSAv01        = "ssh-rsa-cert-v01@openssh.com"
-	CertSigAlgoRSASHA2256v01 = "rsa-sha2-256-cert-v01@openssh.com"
-	CertSigAlgoRSASHA2512v01 = "rsa-sha2-512-cert-v01@openssh.com"
+	// Deprecated: use CertAlgoRSAv01.
+	CertSigAlgoRSAv01 = CertAlgoRSAv01
+	// Deprecated: use CertAlgoRSASHA256v01.
+	CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01
+	// Deprecated: use CertAlgoRSASHA512v01.
+	CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01
 )
 
 // Certificate types distinguish between host and user
@@ -242,14 +252,21 @@ type algorithmOpenSSHCertSigner struct {
 // private key is held by signer. It returns an error if the public key in cert
 // doesn't match the key used by signer.
 func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
-	if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
+	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
 		return nil, errors.New("ssh: signer and cert have different public key")
 	}
 
-	if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
+	switch s := signer.(type) {
+	case MultiAlgorithmSigner:
+		return &multiAlgorithmSigner{
+			AlgorithmSigner: &algorithmOpenSSHCertSigner{
+				&openSSHCertSigner{cert, signer}, s},
+			supportedAlgorithms: s.Algorithms(),
+		}, nil
+	case AlgorithmSigner:
 		return &algorithmOpenSSHCertSigner{
-			&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
-	} else {
+			&openSSHCertSigner{cert, signer}, s}, nil
+	default:
 		return &openSSHCertSigner{cert, signer}, nil
 	}
 }
@@ -423,7 +440,9 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
 }
 
 // SignCert signs the certificate with an authority, setting the Nonce,
-// SignatureKey, and Signature fields.
+// SignatureKey, and Signature fields. If the authority implements the
+// MultiAlgorithmSigner interface the first algorithm in the list is used. This
+// is useful if you want to sign with a specific algorithm.
 func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
 	c.Nonce = make([]byte, 32)
 	if _, err := io.ReadFull(rand, c.Nonce); err != nil {
@@ -431,10 +450,26 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
 	}
 	c.SignatureKey = authority.PublicKey()
 
-	if v, ok := authority.(AlgorithmSigner); ok {
-		if v.PublicKey().Type() == KeyAlgoRSA {
-			authority = &rsaSigner{v, SigAlgoRSASHA2512}
+	if v, ok := authority.(MultiAlgorithmSigner); ok {
+		if len(v.Algorithms()) == 0 {
+			return errors.New("the provided authority has no signature algorithm")
+		}
+		// Use the first algorithm in the list.
+		sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0])
+		if err != nil {
+			return err
 		}
+		c.Signature = sig
+		return nil
+	} else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
+		// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
+		// TODO: consider using KeyAlgoRSASHA256 as default.
+		sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
+		if err != nil {
+			return err
+		}
+		c.Signature = sig
+		return nil
 	}
 
 	sig, err := authority.Sign(rand, c.bytesForSigning())
@@ -445,32 +480,42 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
 	return nil
 }
 
-// certAlgoNames includes a mapping from signature algorithms to the
-// corresponding certificate signature algorithm. When a key type (such
-// as ED25516) is associated with only one algorithm, the KeyAlgo
-// constant is used instead of the SigAlgo.
-var certAlgoNames = map[string]string{
-	SigAlgoRSA:        CertSigAlgoRSAv01,
-	SigAlgoRSASHA2256: CertSigAlgoRSASHA2256v01,
-	SigAlgoRSASHA2512: CertSigAlgoRSASHA2512v01,
-	KeyAlgoDSA:        CertAlgoDSAv01,
-	KeyAlgoECDSA256:   CertAlgoECDSA256v01,
-	KeyAlgoECDSA384:   CertAlgoECDSA384v01,
-	KeyAlgoECDSA521:   CertAlgoECDSA521v01,
-	KeyAlgoSKECDSA256: CertAlgoSKECDSA256v01,
-	KeyAlgoED25519:    CertAlgoED25519v01,
-	KeyAlgoSKED25519:  CertAlgoSKED25519v01,
+// certKeyAlgoNames is a mapping from known certificate algorithm names to the
+// corresponding public key signature algorithm.
+//
+// This map must be kept in sync with the one in agent/client.go.
+var certKeyAlgoNames = map[string]string{
+	CertAlgoRSAv01:        KeyAlgoRSA,
+	CertAlgoRSASHA256v01:  KeyAlgoRSASHA256,
+	CertAlgoRSASHA512v01:  KeyAlgoRSASHA512,
+	CertAlgoDSAv01:        KeyAlgoDSA,
+	CertAlgoECDSA256v01:   KeyAlgoECDSA256,
+	CertAlgoECDSA384v01:   KeyAlgoECDSA384,
+	CertAlgoECDSA521v01:   KeyAlgoECDSA521,
+	CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
+	CertAlgoED25519v01:    KeyAlgoED25519,
+	CertAlgoSKED25519v01:  KeyAlgoSKED25519,
+}
+
+// underlyingAlgo returns the signature algorithm associated with algo (which is
+// an advertised or negotiated public key or host key algorithm). These are
+// usually the same, except for certificate algorithms.
+func underlyingAlgo(algo string) string {
+	if a, ok := certKeyAlgoNames[algo]; ok {
+		return a
+	}
+	return algo
 }
 
-// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
-// Panics if a non-certificate algorithm is passed.
-func certToPrivAlgo(algo string) string {
-	for privAlgo, pubAlgo := range certAlgoNames {
-		if pubAlgo == algo {
-			return privAlgo
+// certificateAlgo returns the certificate algorithms that uses the provided
+// underlying signature algorithm.
+func certificateAlgo(algo string) (certAlgo string, ok bool) {
+	for certName, algoName := range certKeyAlgoNames {
+		if algoName == algo {
+			return certName, true
 		}
 	}
-	panic("unknown cert algorithm")
+	return "", false
 }
 
 func (cert *Certificate) bytesForSigning() []byte {
@@ -514,13 +559,13 @@ func (c *Certificate) Marshal() []byte {
 	return result
 }
 
-// Type returns the key name. It is part of the PublicKey interface.
+// Type returns the certificate algorithm name. It is part of the PublicKey interface.
 func (c *Certificate) Type() string {
-	algo, ok := certAlgoNames[c.Key.Type()]
+	certName, ok := certificateAlgo(c.Key.Type())
 	if !ok {
-		panic("unknown cert key type " + c.Key.Type())
+		panic("unknown certificate type for key type " + c.Key.Type())
 	}
-	return algo
+	return certName
 }
 
 // Verify verifies a signature against the certificate's public

+ 108 - 19
psiphon/common/crypto/ssh/certs_test.go

@@ -49,14 +49,17 @@ func TestParseCert(t *testing.T) {
 // % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
 // user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
 // Critical Options:
-//         force-command /bin/sleep
-//         source-address 192.168.1.0/24
+//
+//	force-command /bin/sleep
+//	source-address 192.168.1.0/24
+//
 // Extensions:
-//         permit-X11-forwarding
-//         permit-agent-forwarding
-//         permit-port-forwarding
-//         permit-pty
-//         permit-user-rc
+//
+//	permit-X11-forwarding
+//	permit-agent-forwarding
+//	permit-port-forwarding
+//	permit-pty
+//	permit-user-rc
 const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ`
 
 func TestParseCertWithOptions(t *testing.T) {
@@ -184,10 +187,30 @@ func TestHostKeyCert(t *testing.T) {
 	}
 
 	for _, test := range []struct {
-		addr    string
-		succeed bool
+		addr                    string
+		succeed                 bool
+		certSignerAlgorithms    []string // Empty means no algorithm restrictions.
+		clientHostKeyAlgorithms []string
 	}{
 		{addr: "hostname:22", succeed: true},
+		{
+			addr:                    "hostname:22",
+			succeed:                 true,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{CertAlgoRSASHA512v01},
+		},
+		{
+			addr:                    "hostname:22",
+			succeed:                 false,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{CertAlgoRSAv01},
+		},
+		{
+			addr:                    "hostname:22",
+			succeed:                 false,
+			certSignerAlgorithms:    []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			clientHostKeyAlgorithms: []string{KeyAlgoRSASHA512}, // Not a certificate algorithm.
+		},
 		{addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
 		{addr: "lasthost:22", succeed: false},
 	} {
@@ -204,24 +227,34 @@ func TestHostKeyCert(t *testing.T) {
 			conf := ServerConfig{
 				NoClientAuth: true,
 			}
-			conf.AddHostKey(certSigner)
+			if len(test.certSignerAlgorithms) > 0 {
+				mas, err := NewSignerWithAlgorithms(certSigner.(AlgorithmSigner), test.certSignerAlgorithms)
+				if err != nil {
+					errc <- err
+					return
+				}
+				conf.AddHostKey(mas)
+			} else {
+				conf.AddHostKey(certSigner)
+			}
 			_, _, _, err := NewServerConn(c1, &conf)
 			errc <- err
 		}()
 
 		config := &ClientConfig{
-			User:            "user",
-			HostKeyCallback: checker.CheckHostKey,
+			User:              "user",
+			HostKeyCallback:   checker.CheckHostKey,
+			HostKeyAlgorithms: test.clientHostKeyAlgorithms,
 		}
 		_, _, _, err = NewClientConn(c2, test.addr, config)
 
 		if (err == nil) != test.succeed {
-			t.Fatalf("NewClientConn(%q): %v", test.addr, err)
+			t.Errorf("NewClientConn(%q): %v", test.addr, err)
 		}
 
 		err = <-errc
 		if (err == nil) != test.succeed {
-			t.Fatalf("NewServerConn(%q): %v", test.addr, err)
+			t.Errorf("NewServerConn(%q): %v", test.addr, err)
 		}
 	}
 }
@@ -235,10 +268,24 @@ func (s *legacyRSASigner) Sign(rand io.Reader, data []byte) (*Signature, error)
 	if !ok {
 		return nil, fmt.Errorf("invalid signer")
 	}
-	return v.SignWithAlgorithm(rand, data, SigAlgoRSA)
+	return v.SignWithAlgorithm(rand, data, KeyAlgoRSA)
 }
 
 func TestCertTypes(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSignerSHA256, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer SHA256: %v", err)
+	}
+	// Algorithms are in order of preference, we expect rsa-sha2-512 to be used.
+	multiAlgoSignerSHA512, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer SHA512: %v", err)
+	}
+
 	var testVars = []struct {
 		name   string
 		signer Signer
@@ -248,10 +295,10 @@ func TestCertTypes(t *testing.T) {
 		{CertAlgoECDSA384v01, testSigners["ecdsap384"], ""},
 		{CertAlgoECDSA521v01, testSigners["ecdsap521"], ""},
 		{CertAlgoED25519v01, testSigners["ed25519"], ""},
-		{CertAlgoRSAv01, testSigners["rsa"], SigAlgoRSASHA2512},
-		{CertAlgoRSAv01, &legacyRSASigner{testSigners["rsa"]}, SigAlgoRSA},
-		{CertAlgoRSAv01, testSigners["rsa-sha2-256"], SigAlgoRSASHA2512},
-		{CertAlgoRSAv01, testSigners["rsa-sha2-512"], SigAlgoRSASHA2512},
+		{CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA256},
+		{"legacyRSASigner", &legacyRSASigner{testSigners["rsa"]}, KeyAlgoRSA},
+		{"multiAlgoRSASignerSHA256", multiAlgoSignerSHA256, KeyAlgoRSASHA256},
+		{"multiAlgoRSASignerSHA512", multiAlgoSignerSHA512, KeyAlgoRSASHA512},
 		{CertAlgoDSAv01, testSigners["dsa"], ""},
 	}
 
@@ -317,3 +364,45 @@ func TestCertTypes(t *testing.T) {
 		})
 	}
 }
+
+func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
+	type testcase struct {
+		sigAlgo   string
+		algoritms []string
+	}
+	cases := []testcase{
+		{
+			sigAlgo:   KeyAlgoRSA,
+			algoritms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
+		},
+		{
+			sigAlgo:   KeyAlgoRSASHA256,
+			algoritms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
+		},
+		{
+			sigAlgo:   KeyAlgoRSASHA512,
+			algoritms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
+		},
+	}
+
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+
+	for _, c := range cases {
+		t.Run(c.sigAlgo, func(t *testing.T) {
+			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algoritms)
+			if err != nil {
+				t.Fatalf("NewSignerWithAlgorithms error: %v", err)
+			}
+			if err := cert.SignCert(rand.Reader, signer); err != nil {
+				t.Fatalf("SignCert error: %v", err)
+			}
+			if cert.Signature.Format != c.sigAlgo {
+				t.Fatalf("got signature format %q, want %q", cert.Signature.Format, c.sigAlgo)
+			}
+		})
+	}
+}

+ 45 - 32
psiphon/common/crypto/ssh/channel.go

@@ -30,30 +30,30 @@ const (
 
 // [Psiphon]
 //
-// - Use a smaller initial/max channel window size.
-// - Testing with the full Psiphon stack shows that
-//   this smaller channel window size is more performant
-//   for low bandwidth connections while still adequate for
-//   higher bandwidth connections.
-// - In Psiphon, a single SSH connection is used for all
-//   client port forwards. Bulk data transfers with large
-//   channel windows can immediately backlog the connection
-//   with many large SSH packets, introducing large latency
-//   for opening new channels. For Psiphon, we don't wish to
-//   optimize for a single bulk transfer throughput.
-// - TODO: can we implement some sort of adaptive max
-//   channel window size, starting with this small initial
-//   value and only growing based on connection properties?
-// - channelWindowSize directly defines the local channel
-//   window initial and max size. We also cap remote channel
-//   window sizes via an extra customization in the
-//   channelOpenConfirmMsg handler. Both upstream and
-//   downstream bulk data transfers have the same latency
-//   issue.
-// - For packet tunnel, use a larger channel window size,
-//   since all tunneled traffic flows through a single
-//   channel; we still select a size smaller than the stock
-//   channelWindowSize due to client memory constraints.
+//   - Use a smaller initial/max channel window size.
+//   - Testing with the full Psiphon stack shows that
+//     this smaller channel window size is more performant
+//     for low bandwidth connections while still adequate for
+//     higher bandwidth connections.
+//   - In Psiphon, a single SSH connection is used for all
+//     client port forwards. Bulk data transfers with large
+//     channel windows can immediately backlog the connection
+//     with many large SSH packets, introducing large latency
+//     for opening new channels. For Psiphon, we don't wish to
+//     optimize for a single bulk transfer throughput.
+//   - TODO: can we implement some sort of adaptive max
+//     channel window size, starting with this small initial
+//     value and only growing based on connection properties?
+//   - channelWindowSize directly defines the local channel
+//     window initial and max size. We also cap remote channel
+//     window sizes via an extra customization in the
+//     channelOpenConfirmMsg handler. Both upstream and
+//     downstream bulk data transfers have the same latency
+//     issue.
+//   - For packet tunnel, use a larger channel window size,
+//     since all tunneled traffic flows through a single
+//     channel; we still select a size smaller than the stock
+//     channelWindowSize due to client memory constraints.
 func getChannelWindowSize(chanType string) int {
 
 	// From "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol".
@@ -230,9 +230,11 @@ type channel struct {
 	pending    *buffer
 	extPending *buffer
 
-	// windowMu protects myWindow, the flow-control window.
-	windowMu sync.Mutex
-	myWindow uint32
+	// windowMu protects myWindow, the flow-control window, and myConsumed,
+	// the number of bytes consumed since we last increased myWindow
+	windowMu   sync.Mutex
+	myWindow   uint32
+	myConsumed uint32
 
 	// writeMu serializes calls to mux.conn.writePacket() and
 	// protects sentClose and packetPool. This mutex must be
@@ -375,14 +377,25 @@ func (ch *channel) handleData(packet []byte) error {
 	return nil
 }
 
-func (c *channel) adjustWindow(n uint32) error {
+func (c *channel) adjustWindow(adj uint32) error {
 	c.windowMu.Lock()
-	// Since myWindow is managed on our side, and can never exceed
-	// the initial window setting, we don't worry about overflow.
-	c.myWindow += uint32(n)
+	// Since myConsumed and myWindow are managed on our side, and can never
+	// exceed the initial window setting, we don't worry about overflow.
+	c.myConsumed += adj
+	var sendAdj uint32
+	channelWindowSize := uint32(getChannelWindowSize(c.chanType))
+	if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
+		(c.myWindow < channelWindowSize/2) {
+		sendAdj = c.myConsumed
+		c.myConsumed = 0
+		c.myWindow += sendAdj
+	}
 	c.windowMu.Unlock()
+	if sendAdj == 0 {
+		return nil
+	}
 	return c.sendMessage(windowAdjustMsg{
-		AdditionalBytes: uint32(n),
+		AdditionalBytes: sendAdj,
 	})
 }
 

+ 8 - 8
psiphon/common/crypto/ssh/cipher.go

@@ -15,7 +15,6 @@ import (
 	"fmt"
 	"hash"
 	"io"
-	"io/ioutil"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"
 	"golang.org/x/crypto/chacha20"
@@ -97,13 +96,13 @@ func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream,
 // are not supported and will not be negotiated, even if explicitly requested in
 // ClientConfig.Crypto.Ciphers.
 var cipherModes = map[string]*cipherMode{
-	// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
+	// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
 	// are defined in the order specified in the RFC.
 	"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
 	"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
 	"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
 
-	// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
+	// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
 	// They are defined in the order specified in the RFC.
 	"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
 	"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
@@ -111,11 +110,12 @@ var cipherModes = map[string]*cipherMode{
 	// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
 	// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
 	// RC4) has problems with weak keys, and should be used with caution."
-	// RFC4345 introduces improved versions of Arcfour.
+	// RFC 4345 introduces improved versions of Arcfour.
 	"arcfour": {16, 0, streamCipherMode(0, newRC4)},
 
 	// AEAD ciphers
-	gcmCipherID:        {16, 12, newGCMCipher},
+	gcm128CipherID:     {16, 12, newGCMCipher},
+	gcm256CipherID:     {32, 12, newGCMCipher},
 	chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
 
 	// CBC mode is insecure and so is not included in the default config.
@@ -497,7 +497,7 @@ func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error)
 			// data, to make distinguishing between
 			// failing MAC and failing length check more
 			// difficult.
-			io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage))
+			io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
 		}
 	}
 	return p, err
@@ -640,9 +640,9 @@ const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
 // chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
 // AEAD, which is described here:
 //
-//   https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
+//	https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
 //
-// the methods here also implement padding, which RFC4253 Section 6
+// the methods here also implement padding, which RFC 4253 Section 6
 // also requires of stream ciphers.
 type chacha20Poly1305Cipher struct {
 	lengthKey  [32]byte

+ 1 - 1
psiphon/common/crypto/ssh/cipher_test.go

@@ -141,7 +141,7 @@ func TestCVE202143565(t *testing.T) {
 		constructPacket func(packetCipher) io.Reader
 	}{
 		{
-			cipher: gcmCipherID,
+			cipher: gcm128CipherID,
 			constructPacket: func(client packetCipher) io.Reader {
 				internalCipher := client.(*gcmCipher)
 				b := &bytes.Buffer{}

+ 9 - 18
psiphon/common/crypto/ssh/client.go

@@ -82,7 +82,7 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan
 
 	if err := conn.clientHandshake(addr, &fullConf); err != nil {
 		c.Close()
-		return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
+		return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
 	}
 	conn.mux = newMux(conn.transport)
 	return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
@@ -113,25 +113,16 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
 	return c.clientAuthenticate(config)
 }
 
-// verifyHostKeySignature verifies the host key obtained in the key
-// exchange.
+// verifyHostKeySignature verifies the host key obtained in the key exchange.
+// algo is the negotiated algorithm, and may be a certificate type.
 func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error {
 	sig, rest, ok := parseSignatureBody(result.Signature)
 	if len(rest) > 0 || !ok {
 		return errors.New("ssh: signature parse error")
 	}
 
-	// For keys, underlyingAlgo is exactly algo. For certificates,
-	// we have to look up the underlying key algorithm that SSH
-	// uses to evaluate signatures.
-	underlyingAlgo := algo
-	for sigAlgo, certAlgo := range certAlgoNames {
-		if certAlgo == algo {
-			underlyingAlgo = sigAlgo
-		}
-	}
-	if sig.Format != underlyingAlgo {
-		return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, underlyingAlgo)
+	if a := underlyingAlgo(algo); sig.Format != a {
+		return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a)
 	}
 
 	return hostKey.Verify(result.H, sig)
@@ -237,11 +228,11 @@ type ClientConfig struct {
 	// be used for the connection. If empty, a reasonable default is used.
 	ClientVersion string
 
-	// HostKeyAlgorithms lists the key types that the client will
-	// accept from the server as host key, in order of
+	// HostKeyAlgorithms lists the public key algorithms that the client will
+	// accept from the server for host key authentication, in order of
 	// preference. If empty, a reasonable default is used. Any
-	// string returned from PublicKey.Type method may be used, or
-	// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
+	// string returned from a PublicKey.Type method may be used, or
+	// any of the CertAlgo and KeyAlgo constants.
 	HostKeyAlgorithms []string
 
 	// Timeout is the maximum amount of time for the TCP connection to establish.

+ 176 - 38
psiphon/common/crypto/ssh/client_auth.go

@@ -9,6 +9,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"strings"
 )
 
 type authResult int
@@ -29,6 +30,33 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 	if err != nil {
 		return err
 	}
+	// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
+	// advertised willingness to receive one, which we always do) or not. See
+	// RFC 8308, Section 2.4.
+	extensions := make(map[string][]byte)
+	if len(packet) > 0 && packet[0] == msgExtInfo {
+		var extInfo extInfoMsg
+		if err := Unmarshal(packet, &extInfo); err != nil {
+			return err
+		}
+		payload := extInfo.Payload
+		for i := uint32(0); i < extInfo.NumExtensions; i++ {
+			name, rest, ok := parseString(payload)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			value, rest, ok := parseString(rest)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			extensions[string(name)] = value
+			payload = rest
+		}
+		packet, err = c.transport.readPacket()
+		if err != nil {
+			return err
+		}
+	}
 	var serviceAccept serviceAcceptMsg
 	if err := Unmarshal(packet, &serviceAccept); err != nil {
 		return err
@@ -41,9 +69,11 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 
 	sessionID := c.transport.getSessionID()
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
-		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
+		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		if err != nil {
-			return err
+			// We return the error later if there is no other method left to
+			// try.
+			ok = authFailure
 		}
 		if ok == authSuccess {
 			// success
@@ -73,6 +103,12 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 				}
 			}
 		}
+
+		if auth == nil && err != nil {
+			// We have an error and there are no other authentication methods to
+			// try, so we return it.
+			return err
+		}
 	}
 	return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
 }
@@ -93,7 +129,7 @@ type AuthMethod interface {
 	// If authentication is not successful, a []string of alternative
 	// method names is returned. If the slice is nil, it will be ignored
 	// and the previous set of possible methods will be reused.
-	auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
+	auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
 
 	// method returns the RFC 4252 method name.
 	method() string
@@ -102,7 +138,7 @@ type AuthMethod interface {
 // "none" authentication, RFC 4252 section 5.2.
 type noneAuth int
 
-func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	if err := c.writePacket(Marshal(&userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,
@@ -122,7 +158,7 @@ func (n *noneAuth) method() string {
 // a function call, e.g. by prompting the user.
 type passwordCallback func() (password string, err error)
 
-func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	type passwordAuthMsg struct {
 		User     string `sshtype:"50"`
 		Service  string
@@ -189,7 +225,77 @@ func (cb publicKeyCallback) method() string {
 	return "publickey"
 }
 
-func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
+	var as MultiAlgorithmSigner
+	keyFormat := signer.PublicKey().Type()
+
+	// If the signer implements MultiAlgorithmSigner we use the algorithms it
+	// support, if it implements AlgorithmSigner we assume it supports all
+	// algorithms, otherwise only the key format one.
+	switch s := signer.(type) {
+	case MultiAlgorithmSigner:
+		as = s
+	case AlgorithmSigner:
+		as = &multiAlgorithmSigner{
+			AlgorithmSigner:     s,
+			supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
+		}
+	default:
+		as = &multiAlgorithmSigner{
+			AlgorithmSigner:     algorithmSignerWrapper{signer},
+			supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
+		}
+	}
+
+	getFallbackAlgo := func() (string, error) {
+		// Fallback to use if there is no "server-sig-algs" extension or a
+		// common algorithm cannot be found. We use the public key format if the
+		// MultiAlgorithmSigner supports it, otherwise we return an error.
+		if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
+			return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
+				underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
+		}
+		return keyFormat, nil
+	}
+
+	extPayload, ok := extensions["server-sig-algs"]
+	if !ok {
+		// If there is no "server-sig-algs" extension use the fallback
+		// algorithm.
+		algo, err := getFallbackAlgo()
+		return as, algo, err
+	}
+
+	// The server-sig-algs extension only carries underlying signature
+	// algorithm, but we are trying to select a protocol-level public key
+	// algorithm, which might be a certificate type. Extend the list of server
+	// supported algorithms to include the corresponding certificate algorithms.
+	serverAlgos := strings.Split(string(extPayload), ",")
+	for _, algo := range serverAlgos {
+		if certAlgo, ok := certificateAlgo(algo); ok {
+			serverAlgos = append(serverAlgos, certAlgo)
+		}
+	}
+
+	// Filter algorithms based on those supported by MultiAlgorithmSigner.
+	var keyAlgos []string
+	for _, algo := range algorithmsForKeyFormat(keyFormat) {
+		if contains(as.Algorithms(), underlyingAlgo(algo)) {
+			keyAlgos = append(keyAlgos, algo)
+		}
+	}
+
+	algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
+	if err != nil {
+		// If there is no overlap, return the fallback algorithm to support
+		// servers that fail to list all supported algorithms.
+		algo, err := getFallbackAlgo()
+		return as, algo, err
+	}
+	return as, algo, nil
+}
+
+func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
 	// Authentication is performed by sending an enquiry to test if a key is
 	// acceptable to the remote. If the key is acceptable, the client will
 	// attempt to authenticate with the valid key.  If not the client will repeat
@@ -200,22 +306,50 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 		return authFailure, nil, err
 	}
 	var methods []string
-	for _, signer := range signers {
-		ok, err := validateKey(signer.PublicKey(), user, c)
+	var errSigAlgo error
+
+	origSignersLen := len(signers)
+	for idx := 0; idx < len(signers); idx++ {
+		signer := signers[idx]
+		pub := signer.PublicKey()
+		as, algo, err := pickSignatureAlgorithm(signer, extensions)
+		if err != nil && errSigAlgo == nil {
+			// If we cannot negotiate a signature algorithm store the first
+			// error so we can return it to provide a more meaningful message if
+			// no other signers work.
+			errSigAlgo = err
+			continue
+		}
+		ok, err := validateKey(pub, algo, user, c)
 		if err != nil {
 			return authFailure, nil, err
 		}
+		// OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512
+		// in the "server-sig-algs" extension but doesn't support these
+		// algorithms for certificate authentication, so if the server rejects
+		// the key try to use the obtained algorithm as if "server-sig-algs" had
+		// not been implemented if supported from the algorithm signer.
+		if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
+			if contains(as.Algorithms(), KeyAlgoRSA) {
+				// We retry using the compat algorithm after all signers have
+				// been tried normally.
+				signers = append(signers, &multiAlgorithmSigner{
+					AlgorithmSigner:     as,
+					supportedAlgorithms: []string{KeyAlgoRSA},
+				})
+			}
+		}
 		if !ok {
 			continue
 		}
 
-		pub := signer.PublicKey()
 		pubKey := pub.Marshal()
-		sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
+		data := buildDataSignedForAuth(session, userAuthRequestMsg{
 			User:    user,
 			Service: serviceSSH,
 			Method:  cb.method(),
-		}, []byte(pub.Type()), pubKey))
+		}, algo, pubKey)
+		sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
 		if err != nil {
 			return authFailure, nil, err
 		}
@@ -229,7 +363,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 			Service:  serviceSSH,
 			Method:   cb.method(),
 			HasSig:   true,
-			Algoname: pub.Type(),
+			Algoname: algo,
 			PubKey:   pubKey,
 			Sig:      sig,
 		}
@@ -247,45 +381,34 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 		// contain the "publickey" method, do not attempt to authenticate with any
 		// other keys.  According to RFC 4252 Section 7, the latter can occur when
 		// additional authentication methods are required.
-		if success == authSuccess || !containsMethod(methods, cb.method()) {
+		if success == authSuccess || !contains(methods, cb.method()) {
 			return success, methods, err
 		}
 	}
 
-	return authFailure, methods, nil
-}
-
-func containsMethod(methods []string, method string) bool {
-	for _, m := range methods {
-		if m == method {
-			return true
-		}
-	}
-
-	return false
+	return authFailure, methods, errSigAlgo
 }
 
 // validateKey validates the key provided is acceptable to the server.
-func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
+func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
 	msg := publickeyAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
 		Method:   "publickey",
 		HasSig:   false,
-		Algoname: key.Type(),
+		Algoname: algo,
 		PubKey:   pubKey,
 	}
 	if err := c.writePacket(Marshal(&msg)); err != nil {
 		return false, err
 	}
 
-	return confirmKeyAck(key, c)
+	return confirmKeyAck(key, algo, c)
 }
 
-func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
+func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
-	algoname := key.Type()
 
 	for {
 		packet, err := c.readPacket()
@@ -302,14 +425,14 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
 			if err := Unmarshal(packet, &msg); err != nil {
 				return false, err
 			}
-			if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
+			if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
 				return false, nil
 			}
 			return true, nil
 		case msgUserAuthFailure:
 			return false, nil
 		default:
-			return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
+			return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
 		}
 	}
 }
@@ -330,6 +453,7 @@ func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMet
 // along with a list of remaining authentication methods to try next and
 // an error if an unexpected response was received.
 func handleAuthResponse(c packetConn) (authResult, []string, error) {
+	gotMsgExtInfo := false
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -341,6 +465,12 @@ func handleAuthResponse(c packetConn) (authResult, []string, error) {
 			if err := handleBannerResponse(c, packet); err != nil {
 				return authFailure, nil, err
 			}
+		case msgExtInfo:
+			// Ignore post-authentication RFC 8308 extensions, once.
+			if gotMsgExtInfo {
+				return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
+			}
+			gotMsgExtInfo = true
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
 			if err := Unmarshal(packet, &msg); err != nil {
@@ -380,10 +510,10 @@ func handleBannerResponse(c packetConn, packet []byte) error {
 // disabling echoing (e.g. for passwords), and return all the answers.
 // Challenge may be called multiple times in a single session. After
 // successful authentication, the server may send a challenge with no
-// questions, for which the user and instruction messages should be
+// questions, for which the name and instruction messages should be
 // printed.  RFC 4256 section 3.3 details how the UI should behave for
 // both CLI and GUI environments.
-type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
+type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)
 
 // KeyboardInteractive returns an AuthMethod using a prompt/response
 // sequence controlled by the server.
@@ -395,7 +525,7 @@ func (cb KeyboardInteractiveChallenge) method() string {
 	return "keyboard-interactive"
 }
 
-func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	type initiateMsg struct {
 		User       string `sshtype:"50"`
 		Service    string
@@ -412,6 +542,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		return authFailure, nil, err
 	}
 
+	gotMsgExtInfo := false
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -425,6 +556,13 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 				return authFailure, nil, err
 			}
 			continue
+		case msgExtInfo:
+			// Ignore post-authentication RFC 8308 extensions, once.
+			if gotMsgExtInfo {
+				return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
+			}
+			gotMsgExtInfo = true
+			continue
 		case msgUserAuthInfoRequest:
 			// OK
 		case msgUserAuthFailure:
@@ -465,7 +603,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 			return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
 		}
 
-		answers, err := cb(msg.User, msg.Instruction, prompts, echos)
+		answers, err := cb(msg.Name, msg.Instruction, prompts, echos)
 		if err != nil {
 			return authFailure, nil, err
 		}
@@ -497,9 +635,9 @@ type retryableAuthMethod struct {
 	maxTries   int
 }
 
-func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
+func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
 	for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
-		ok, methods, err = r.authMethod.auth(session, user, c, rand)
+		ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
 		if ok != authFailure || err != nil { // either success, partial success or error terminate
 			return ok, methods, err
 		}
@@ -542,7 +680,7 @@ type gssAPIWithMICCallback struct {
 	target       string
 }
 
-func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
+func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
 	m := &userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,

+ 388 - 2
psiphon/common/crypto/ssh/client_auth_test.go

@@ -13,6 +13,7 @@ import (
 	"log"
 	"net"
 	"os"
+	"runtime"
 	"strings"
 	"testing"
 )
@@ -104,11 +105,61 @@ func tryAuthBothSides(t *testing.T, config *ClientConfig, gssAPIWithMICConfig *G
 	return err, serverAuthErrors
 }
 
+type loggingAlgorithmSigner struct {
+	used []string
+	AlgorithmSigner
+}
+
+func (l *loggingAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	l.used = append(l.used, "[Sign]")
+	return l.AlgorithmSigner.Sign(rand, data)
+}
+
+func (l *loggingAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	l.used = append(l.used, algorithm)
+	return l.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
+}
+
 func TestClientAuthPublicKey(t *testing.T) {
+	signer := &loggingAlgorithmSigner{AlgorithmSigner: testSigners["rsa"].(AlgorithmSigner)}
 	config := &ClientConfig{
 		User: "testuser",
 		Auth: []AuthMethod{
-			PublicKeys(testSigners["rsa"]),
+			PublicKeys(signer),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	if err := tryAuth(t, config); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+	if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 {
+		t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used)
+	}
+}
+
+// TestClientAuthNoSHA2 tests a ssh-rsa Signer that doesn't implement AlgorithmSigner.
+func TestClientAuthNoSHA2(t *testing.T) {
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			PublicKeys(&legacyRSASigner{testSigners["rsa"]}),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	if err := tryAuth(t, config); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+}
+
+// TestClientAuthThirdKey checks that the third configured can succeed. If we
+// were to do three attempts for each key (rsa-sha2-256, rsa-sha2-512, ssh-rsa),
+// we'd hit the six maximum attempts before reaching it.
+func TestClientAuthThirdKey(t *testing.T) {
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa-openssh-format"],
+				testSigners["rsa-openssh-format"], testSigners["rsa"]),
 		},
 		HostKeyCallback: InsecureIgnoreHostKey(),
 	}
@@ -639,7 +690,15 @@ func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
 	if err := tryAuth(t, invalidConfig); err == nil {
 		t.Fatalf("client: got no error, want %s", expectedErr)
 	} else if err.Error() != expectedErr.Error() {
-		t.Fatalf("client: got %s, want %s", err, expectedErr)
+		// On Windows we can see a WSAECONNABORTED error
+		// if the client writes another authentication request
+		// before the client goroutine reads the disconnection
+		// message.  See issue 50805.
+		if runtime.GOOS == "windows" && strings.Contains(err.Error(), "wsarecv: An established connection was aborted") {
+			// OK.
+		} else {
+			t.Fatalf("client: got %s, want %s", err, expectedErr)
+		}
 	}
 }
 
@@ -896,3 +955,330 @@ func TestAuthMethodGSSAPIWithMIC(t *testing.T) {
 		}
 	}
 }
+
+func TestCompatibleAlgoAndSignatures(t *testing.T) {
+	type testcase struct {
+		algo       string
+		sigFormat  string
+		compatible bool
+	}
+	testcases := []*testcase{
+		{
+			KeyAlgoRSA,
+			KeyAlgoRSA,
+			true,
+		},
+		{
+			KeyAlgoRSA,
+			KeyAlgoRSASHA256,
+			true,
+		},
+		{
+			KeyAlgoRSA,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			KeyAlgoRSASHA256,
+			KeyAlgoRSA,
+			true,
+		},
+		{
+			KeyAlgoRSASHA512,
+			KeyAlgoRSA,
+			true,
+		},
+		{
+			KeyAlgoRSASHA512,
+			KeyAlgoRSASHA256,
+			true,
+		},
+		{
+			KeyAlgoRSASHA256,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			KeyAlgoRSASHA512,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			CertAlgoRSAv01,
+			KeyAlgoRSA,
+			true,
+		},
+		{
+			CertAlgoRSAv01,
+			KeyAlgoRSASHA256,
+			true,
+		},
+		{
+			CertAlgoRSAv01,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			CertAlgoRSASHA256v01,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			CertAlgoRSASHA512v01,
+			KeyAlgoRSASHA512,
+			true,
+		},
+		{
+			CertAlgoRSASHA512v01,
+			KeyAlgoRSASHA256,
+			true,
+		},
+		{
+			CertAlgoRSASHA256v01,
+			CertAlgoRSAv01,
+			true,
+		},
+		{
+			CertAlgoRSAv01,
+			CertAlgoRSASHA512v01,
+			true,
+		},
+		{
+			KeyAlgoECDSA256,
+			KeyAlgoRSA,
+			false,
+		},
+		{
+			KeyAlgoECDSA256,
+			KeyAlgoECDSA521,
+			false,
+		},
+		{
+			KeyAlgoECDSA256,
+			KeyAlgoECDSA256,
+			true,
+		},
+		{
+			KeyAlgoECDSA256,
+			KeyAlgoED25519,
+			false,
+		},
+		{
+			KeyAlgoED25519,
+			KeyAlgoED25519,
+			true,
+		},
+	}
+
+	for _, c := range testcases {
+		if isAlgoCompatible(c.algo, c.sigFormat) != c.compatible {
+			t.Errorf("algorithm %q, signature format %q, expected compatible to be %t", c.algo, c.sigFormat, c.compatible)
+		}
+	}
+}
+
+func TestPickSignatureAlgorithm(t *testing.T) {
+	type testcase struct {
+		name       string
+		extensions map[string][]byte
+	}
+	cases := []testcase{
+		{
+			name: "server with empty server-sig-algs",
+			extensions: map[string][]byte{
+				"server-sig-algs": []byte(``),
+			},
+		},
+		{
+			name:       "server with no server-sig-algs",
+			extensions: nil,
+		},
+	}
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			signer, ok := testSigners["rsa"].(MultiAlgorithmSigner)
+			if !ok {
+				t.Fatalf("rsa test signer does not implement the MultiAlgorithmSigner interface")
+			}
+			// The signer supports the public key algorithm which is then returned.
+			_, algo, err := pickSignatureAlgorithm(signer, c.extensions)
+			if err != nil {
+				t.Fatalf("got %v, want no error", err)
+			}
+			if algo != signer.PublicKey().Type() {
+				t.Fatalf("got algo %q, want %q", algo, signer.PublicKey().Type())
+			}
+			// Test a signer that uses a certificate algorithm as the public key
+			// type.
+			cert := &Certificate{
+				CertType: UserCert,
+				Key:      signer.PublicKey(),
+			}
+			cert.SignCert(rand.Reader, signer)
+
+			certSigner, err := NewCertSigner(cert, signer)
+			if err != nil {
+				t.Fatalf("error generating cert signer: %v", err)
+			}
+			// The signer supports the public key algorithm and the
+			// public key format is a certificate type so the cerificate
+			// algorithm matching the key format must be returned
+			_, algo, err = pickSignatureAlgorithm(certSigner, c.extensions)
+			if err != nil {
+				t.Fatalf("got %v, want no error", err)
+			}
+			if algo != certSigner.PublicKey().Type() {
+				t.Fatalf("got algo %q, want %q", algo, certSigner.PublicKey().Type())
+			}
+			signer, err = NewSignerWithAlgorithms(signer.(AlgorithmSigner), []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256})
+			if err != nil {
+				t.Fatalf("unable to create signer with algorithms: %v", err)
+			}
+			// The signer does not support the public key algorithm so an error
+			// is returned.
+			_, _, err = pickSignatureAlgorithm(signer, c.extensions)
+			if err == nil {
+				t.Fatal("got no error, no common public key signature algorithm error expected")
+			}
+		})
+	}
+}
+
+// configurablePublicKeyCallback is a public key callback that allows to
+// configure the signature algorithm and format. This way we can emulate the
+// behavior of buggy clients.
+type configurablePublicKeyCallback struct {
+	signer          AlgorithmSigner
+	signatureAlgo   string
+	signatureFormat string
+}
+
+func (cb configurablePublicKeyCallback) method() string {
+	return "publickey"
+}
+
+func (cb configurablePublicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
+	pub := cb.signer.PublicKey()
+
+	ok, err := validateKey(pub, cb.signatureAlgo, user, c)
+	if err != nil {
+		return authFailure, nil, err
+	}
+	if !ok {
+		return authFailure, nil, fmt.Errorf("invalid public key")
+	}
+
+	pubKey := pub.Marshal()
+	data := buildDataSignedForAuth(session, userAuthRequestMsg{
+		User:    user,
+		Service: serviceSSH,
+		Method:  cb.method(),
+	}, cb.signatureAlgo, pubKey)
+	sign, err := cb.signer.SignWithAlgorithm(rand, data, underlyingAlgo(cb.signatureFormat))
+	if err != nil {
+		return authFailure, nil, err
+	}
+
+	s := Marshal(sign)
+	sig := make([]byte, stringLength(len(s)))
+	marshalString(sig, s)
+	msg := publickeyAuthMsg{
+		User:     user,
+		Service:  serviceSSH,
+		Method:   cb.method(),
+		HasSig:   true,
+		Algoname: cb.signatureAlgo,
+		PubKey:   pubKey,
+		Sig:      sig,
+	}
+	p := Marshal(&msg)
+	if err := c.writePacket(p); err != nil {
+		return authFailure, nil, err
+	}
+	var success authResult
+	success, methods, err := handleAuthResponse(c)
+	if err != nil {
+		return authFailure, nil, err
+	}
+	if success == authSuccess || !contains(methods, cb.method()) {
+		return success, methods, err
+	}
+
+	return authFailure, methods, nil
+}
+
+func TestPublicKeyAndAlgoCompatibility(t *testing.T) {
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	clientConfig := &ClientConfig{
+		User:            "user",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		Auth: []AuthMethod{
+			configurablePublicKeyCallback{
+				signer:          certSigner.(AlgorithmSigner),
+				signatureAlgo:   KeyAlgoRSASHA256,
+				signatureFormat: KeyAlgoRSASHA256,
+			},
+		},
+	}
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Error("cert login passed with incompatible public key type and algorithm")
+	}
+}
+
+func TestClientAuthGPGAgentCompat(t *testing.T) {
+	clientConfig := &ClientConfig{
+		User:            "testuser",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		Auth: []AuthMethod{
+			// algorithm rsa-sha2-512 and signature format ssh-rsa.
+			configurablePublicKeyCallback{
+				signer:          testSigners["rsa"].(AlgorithmSigner),
+				signatureAlgo:   KeyAlgoRSASHA512,
+				signatureFormat: KeyAlgoRSA,
+			},
+		},
+	}
+	if err := tryAuth(t, clientConfig); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+}
+
+func TestCertAuthOpenSSHCompat(t *testing.T) {
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	clientConfig := &ClientConfig{
+		User:            "user",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		Auth: []AuthMethod{
+			// algorithm ssh-rsa-cert-v01@openssh.com and signature format
+			// rsa-sha2-256.
+			configurablePublicKeyCallback{
+				signer:          certSigner.(AlgorithmSigner),
+				signatureAlgo:   CertAlgoRSAv01,
+				signatureFormat: KeyAlgoRSASHA256,
+			},
+		},
+	}
+	if err := tryAuth(t, clientConfig); err != nil {
+		t.Fatalf("unable to dial remote side: %s", err)
+	}
+}

+ 116 - 5
psiphon/common/crypto/ssh/client_test.go

@@ -7,6 +7,9 @@ package ssh
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
+	"fmt"
+	"net"
 	"strings"
 	"testing"
 )
@@ -125,9 +128,9 @@ func TestVerifyHostKeySignature(t *testing.T) {
 		verifyAlgo string
 		wantError  string
 	}{
-		{"rsa", SigAlgoRSA, SigAlgoRSA, ""},
-		{"rsa", SigAlgoRSASHA2256, SigAlgoRSASHA2256, ""},
-		{"rsa", SigAlgoRSA, SigAlgoRSASHA2512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`},
+		{"rsa", KeyAlgoRSA, KeyAlgoRSA, ""},
+		{"rsa", KeyAlgoRSASHA256, KeyAlgoRSASHA256, ""},
+		{"rsa", KeyAlgoRSA, KeyAlgoRSASHA512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`},
 		{"ed25519", KeyAlgoED25519, KeyAlgoED25519, ""},
 	} {
 		key := testSigners[tt.key].PublicKey()
@@ -207,9 +210,12 @@ func TestBannerCallback(t *testing.T) {
 }
 
 func TestNewClientConn(t *testing.T) {
+	errHostKeyMismatch := errors.New("host key mismatch")
+
 	for _, tt := range []struct {
-		name string
-		user string
+		name                    string
+		user                    string
+		simulateHostKeyMismatch HostKeyCallback
 	}{
 		{
 			name: "good user field for ConnMetadata",
@@ -219,6 +225,13 @@ func TestNewClientConn(t *testing.T) {
 			name: "empty user field for ConnMetadata",
 			user: "",
 		},
+		{
+			name: "host key mismatch",
+			user: "testuser",
+			simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error {
+				return fmt.Errorf("%w: %s", errHostKeyMismatch, bytes.TrimSpace(MarshalAuthorizedKey(key)))
+			},
+		},
 	} {
 		t.Run(tt.name, func(t *testing.T) {
 			c1, c2, err := netPipe()
@@ -243,8 +256,16 @@ func TestNewClientConn(t *testing.T) {
 				},
 				HostKeyCallback: InsecureIgnoreHostKey(),
 			}
+
+			if tt.simulateHostKeyMismatch != nil {
+				clientConf.HostKeyCallback = tt.simulateHostKeyMismatch
+			}
+
 			clientConn, _, _, err := NewClientConn(c2, "", clientConf)
 			if err != nil {
+				if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) {
+					return
+				}
 				t.Fatal(err)
 			}
 
@@ -254,3 +275,93 @@ func TestNewClientConn(t *testing.T) {
 		})
 	}
 }
+
+func TestUnsupportedAlgorithm(t *testing.T) {
+	for _, tt := range []struct {
+		name      string
+		config    Config
+		wantError string
+	}{
+		{
+			"unsupported KEX",
+			Config{
+				KeyExchanges: []string{"unsupported"},
+			},
+			"no common algorithm",
+		},
+		{
+			"unsupported and supported KEXs",
+			Config{
+				KeyExchanges: []string{"unsupported", kexAlgoCurve25519SHA256},
+			},
+			"",
+		},
+		{
+			"unsupported cipher",
+			Config{
+				Ciphers: []string{"unsupported"},
+			},
+			"no common algorithm",
+		},
+		{
+			"unsupported and supported ciphers",
+			Config{
+				Ciphers: []string{"unsupported", chacha20Poly1305ID},
+			},
+			"",
+		},
+		{
+			"unsupported MAC",
+			Config{
+				MACs: []string{"unsupported"},
+				// MAC is used for non AAED ciphers.
+				Ciphers: []string{"aes256-ctr"},
+			},
+			"no common algorithm",
+		},
+		{
+			"unsupported and supported MACs",
+			Config{
+				MACs: []string{"unsupported", "hmac-sha2-256-etm@openssh.com"},
+				// MAC is used for non AAED ciphers.
+				Ciphers: []string{"aes256-ctr"},
+			},
+			"",
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			c1, c2, err := netPipe()
+			if err != nil {
+				t.Fatalf("netPipe: %v", err)
+			}
+			defer c1.Close()
+			defer c2.Close()
+
+			serverConf := &ServerConfig{
+				Config: tt.config,
+				PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+					return &Permissions{}, nil
+				},
+			}
+			serverConf.AddHostKey(testSigners["rsa"])
+			go NewServerConn(c1, serverConf)
+
+			clientConf := &ClientConfig{
+				User:   "testuser",
+				Config: tt.config,
+				Auth: []AuthMethod{
+					Password("testpw"),
+				},
+				HostKeyCallback: InsecureIgnoreHostKey(),
+			}
+			_, _, _, err = NewClientConn(c2, "", clientConf)
+			if err != nil {
+				if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
+					t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
+				}
+			} else if tt.wantError != "" {
+				t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
+			}
+		})
+	}
+}

+ 116 - 52
psiphon/common/crypto/ssh/common.go

@@ -30,7 +30,7 @@ const (
 // supportedCiphers lists ciphers we support but might not recommend.
 var supportedCiphers = []string{
 	"aes128-ctr", "aes192-ctr", "aes256-ctr",
-	"aes128-gcm@openssh.com",
+	"aes128-gcm@openssh.com", gcm256CipherID,
 	chacha20Poly1305ID,
 	"arcfour256", "arcfour128", "arcfour",
 	aes128cbcID,
@@ -39,7 +39,7 @@ var supportedCiphers = []string{
 
 // preferredCiphers specifies the default preference for ciphers.
 var preferredCiphers = []string{
-	"aes128-gcm@openssh.com",
+	"aes128-gcm@openssh.com", gcm256CipherID,
 	chacha20Poly1305ID,
 	"aes128-ctr", "aes192-ctr", "aes256-ctr",
 }
@@ -47,14 +47,12 @@ var preferredCiphers = []string{
 // supportedKexAlgos specifies the supported key-exchange algorithms in
 // preference order.
 var supportedKexAlgos = []string{
-	kexAlgoCurve25519SHA256,
+	kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
 	// P384 and P521 are not constant-time yet, but since we don't
 	// reuse ephemeral keys, using them for ECDH should be OK.
 	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
-
-	// [Psiphon]
-	// Remove kexAlgoDH1SHA1 and add kexAlgoDH14SHA256
-	kexAlgoDH14SHA256, kexAlgoDH14SHA1,
+	kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
+	kexAlgoDH1SHA1,
 }
 
 // serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
@@ -64,27 +62,29 @@ var serverForbiddenKexAlgos = map[string]struct{}{
 	kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
 }
 
-// preferredKexAlgos specifies the default preference for key-exchange algorithms
-// in preference order.
+// preferredKexAlgos specifies the default preference for key-exchange
+// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
+// is disabled by default because it is a bit slower than the others.
 var preferredKexAlgos = []string{
-	kexAlgoCurve25519SHA256,
+	kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
 	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
+	kexAlgoDH14SHA256, kexAlgoDH14SHA1,
 
 	// [Psiphon]
-	// Add kexAlgoDH14SHA256
-	kexAlgoDH14SHA256, kexAlgoDH14SHA1,
+	// Enable kexAlgoDH16SHA512
+	kexAlgoDH16SHA512,
 }
 
 // supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
 // of authenticating servers) in preference order.
 var supportedHostKeyAlgos = []string{
-	CertSigAlgoRSASHA2512v01, CertSigAlgoRSASHA2256v01,
-	CertSigAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
+	CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
+	CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
 	CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
 
 	KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
-	SigAlgoRSASHA2512, SigAlgoRSASHA2256,
-	SigAlgoRSA, KeyAlgoDSA,
+	KeyAlgoRSASHA256, KeyAlgoRSASHA512,
+	KeyAlgoRSA, KeyAlgoDSA,
 
 	KeyAlgoED25519,
 }
@@ -93,28 +93,65 @@ var supportedHostKeyAlgos = []string{
 // This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
 // because they have reached the end of their useful life.
 var supportedMACs = []string{
-	"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
+	"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
 }
 
 var supportedCompressions = []string{compressionNone}
 
-// hashFuncs keeps the mapping of supported algorithms to their respective
-// hashes needed for signature verification.
+// hashFuncs keeps the mapping of supported signature algorithms to their
+// respective hashes needed for signing and verification.
 var hashFuncs = map[string]crypto.Hash{
-	SigAlgoRSA:               crypto.SHA1,
-	SigAlgoRSASHA2256:        crypto.SHA256,
-	SigAlgoRSASHA2512:        crypto.SHA512,
-	KeyAlgoDSA:               crypto.SHA1,
-	KeyAlgoECDSA256:          crypto.SHA256,
-	KeyAlgoECDSA384:          crypto.SHA384,
-	KeyAlgoECDSA521:          crypto.SHA512,
-	CertSigAlgoRSAv01:        crypto.SHA1,
-	CertSigAlgoRSASHA2256v01: crypto.SHA256,
-	CertSigAlgoRSASHA2512v01: crypto.SHA512,
-	CertAlgoDSAv01:           crypto.SHA1,
-	CertAlgoECDSA256v01:      crypto.SHA256,
-	CertAlgoECDSA384v01:      crypto.SHA384,
-	CertAlgoECDSA521v01:      crypto.SHA512,
+	KeyAlgoRSA:       crypto.SHA1,
+	KeyAlgoRSASHA256: crypto.SHA256,
+	KeyAlgoRSASHA512: crypto.SHA512,
+	KeyAlgoDSA:       crypto.SHA1,
+	KeyAlgoECDSA256:  crypto.SHA256,
+	KeyAlgoECDSA384:  crypto.SHA384,
+	KeyAlgoECDSA521:  crypto.SHA512,
+	// KeyAlgoED25519 doesn't pre-hash.
+	KeyAlgoSKECDSA256: crypto.SHA256,
+	KeyAlgoSKED25519:  crypto.SHA256,
+}
+
+// algorithmsForKeyFormat returns the supported signature algorithms for a given
+// public key format (PublicKey.Type), in order of preference. See RFC 8332,
+// Section 2. See also the note in sendKexInit on backwards compatibility.
+func algorithmsForKeyFormat(keyFormat string) []string {
+	switch keyFormat {
+	case KeyAlgoRSA:
+		return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}
+	case CertAlgoRSAv01:
+		return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01}
+	default:
+		return []string{keyFormat}
+	}
+}
+
+// isRSA returns whether algo is a supported RSA algorithm, including certificate
+// algorithms.
+func isRSA(algo string) bool {
+	algos := algorithmsForKeyFormat(KeyAlgoRSA)
+	return contains(algos, underlyingAlgo(algo))
+}
+
+func isRSACert(algo string) bool {
+	_, ok := certKeyAlgoNames[algo]
+	if !ok {
+		return false
+	}
+	return isRSA(algo)
+}
+
+// supportedPubKeyAuthAlgos specifies the supported client public key
+// authentication algorithms. Note that this doesn't include certificate types
+// since those use the underlying algorithm. This list is sent to the client if
+// it supports the server-sig-algs extension. Order is irrelevant.
+var supportedPubKeyAuthAlgos = []string{
+	KeyAlgoED25519,
+	KeyAlgoSKED25519, KeyAlgoSKECDSA256,
+	KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
+	KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
+	KeyAlgoDSA,
 }
 
 // unexpectedMessageError results when the SSH message that we received didn't
@@ -148,19 +185,25 @@ type directionAlgorithms struct {
 
 // rekeyBytes returns a rekeying intervals in bytes.
 func (a *directionAlgorithms) rekeyBytes() int64 {
-	// According to RFC4344 block ciphers should rekey after
+	// According to RFC 4344 block ciphers should rekey after
 	// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
 	// 128.
 	switch a.Cipher {
-	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
+	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
 		return 16 * (1 << 32)
 
 	}
 
-	// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
+	// For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data.
 	return 1 << 30
 }
 
+var aeadCiphers = map[string]bool{
+	gcm128CipherID:     true,
+	gcm256CipherID:     true,
+	chacha20Poly1305ID: true,
+}
+
 type algorithms struct {
 	kex     string
 	hostKey string
@@ -196,14 +239,18 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
 		return
 	}
 
-	ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
-	if err != nil {
-		return
+	if !aeadCiphers[ctos.Cipher] {
+		ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+		if err != nil {
+			return
+		}
 	}
 
-	stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
-	if err != nil {
-		return
+	if !aeadCiphers[stoc.Cipher] {
+		stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+		if err != nil {
+			return
+		}
 	}
 
 	ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
@@ -236,16 +283,16 @@ type Config struct {
 	// unspecified, a size suitable for the chosen cipher is used.
 	RekeyThreshold uint64
 
-	// The allowed key exchanges algorithms. If unspecified then a
-	// default set of algorithms is used.
+	// The allowed key exchanges algorithms. If unspecified then a default set
+	// of algorithms is used. Unsupported values are silently ignored.
 	KeyExchanges []string
 
-	// The allowed cipher algorithms. If unspecified then a sensible
-	// default is used.
+	// The allowed cipher algorithms. If unspecified then a sensible default is
+	// used. Unsupported values are silently ignored.
 	Ciphers []string
 
-	// The allowed MAC algorithms. If unspecified then a sensible default
-	// is used.
+	// The allowed MAC algorithms. If unspecified then a sensible default is
+	// used. Unsupported values are silently ignored.
 	MACs []string
 
 	// [Psiphon]
@@ -275,7 +322,7 @@ func (c *Config) SetDefaults() {
 	var ciphers []string
 	for _, c := range c.Ciphers {
 		if cipherModes[c] != nil {
-			// reject the cipher if we have no cipherModes definition
+			// Ignore the cipher if we have no cipherModes definition.
 			ciphers = append(ciphers, c)
 		}
 	}
@@ -284,10 +331,26 @@ func (c *Config) SetDefaults() {
 	if c.KeyExchanges == nil {
 		c.KeyExchanges = preferredKexAlgos
 	}
+	var kexs []string
+	for _, k := range c.KeyExchanges {
+		if kexAlgoMap[k] != nil {
+			// Ignore the KEX if we have no kexAlgoMap definition.
+			kexs = append(kexs, k)
+		}
+	}
+	c.KeyExchanges = kexs
 
 	if c.MACs == nil {
 		c.MACs = supportedMACs
 	}
+	var macs []string
+	for _, m := range c.MACs {
+		if macModes[m] != nil {
+			// Ignore the MAC if we have no macModes definition.
+			macs = append(macs, m)
+		}
+	}
+	c.MACs = macs
 
 	if c.RekeyThreshold == 0 {
 		// cipher specific default
@@ -300,8 +363,9 @@ func (c *Config) SetDefaults() {
 }
 
 // buildDataSignedForAuth returns the data that is signed in order to prove
-// possession of a private key. See RFC 4252, section 7.
-func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+// possession of a private key. See RFC 4252, section 7. algo is the advertised
+// algorithm, and may be a certificate type.
+func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte {
 	data := struct {
 		Session []byte
 		Type    byte
@@ -309,7 +373,7 @@ func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubK
 		Service string
 		Method  string
 		Sign    bool
-		Algo    []byte
+		Algo    string
 		PubKey  []byte
 	}{
 		sessionID,

+ 5 - 5
psiphon/common/crypto/ssh/common_test.go

@@ -82,11 +82,11 @@ func TestFindAgreedAlgorithms(t *testing.T) {
 	}
 
 	cases := []testcase{
-		testcase{
+		{
 			name: "standard",
 		},
 
-		testcase{
+		{
 			name: "no common hostkey",
 			serverIn: kexInitMsg{
 				ServerHostKeyAlgos: []string{"hostkey2"},
@@ -94,7 +94,7 @@ func TestFindAgreedAlgorithms(t *testing.T) {
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "no common kex",
 			serverIn: kexInitMsg{
 				KexAlgos: []string{"kex2"},
@@ -102,7 +102,7 @@ func TestFindAgreedAlgorithms(t *testing.T) {
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "no common cipher",
 			serverIn: kexInitMsg{
 				CiphersClientServer: []string{"cipher2"},
@@ -110,7 +110,7 @@ func TestFindAgreedAlgorithms(t *testing.T) {
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "client decides cipher",
 			serverIn: kexInitMsg{
 				CiphersClientServer: []string{"cipher1", "cipher2"},

+ 2 - 2
psiphon/common/crypto/ssh/connection.go

@@ -52,7 +52,7 @@ type Conn interface {
 
 	// SendRequest sends a global request, and returns the
 	// reply. If wantReply is true, it returns the response status
-	// and payload. See also RFC4254, section 4.
+	// and payload. See also RFC 4254, section 4.
 	SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
 
 	// OpenChannel tries to open an channel. If the request is
@@ -97,7 +97,7 @@ func (c *connection) Close() error {
 	return c.sshConn.conn.Close()
 }
 
-// sshconn provides net.Conn metadata, but disallows direct reads and
+// sshConn provides net.Conn metadata, but disallows direct reads and
 // writes.
 type sshConn struct {
 	conn net.Conn

+ 4 - 2
psiphon/common/crypto/ssh/doc.go

@@ -12,8 +12,10 @@ the multiplexed nature of SSH is exposed to users that wish to support
 others.
 
 References:
-  [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
-  [SSH-PARAMETERS]:    http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
+
+	[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
+	[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
+	[SSH-PARAMETERS]:    http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
 
 This package does not fall under the stability promise of the Go language itself,
 so its API may be changed when pressing needs arise.

+ 86 - 7
psiphon/common/crypto/ssh/example_test.go

@@ -7,14 +7,16 @@ package ssh_test
 import (
 	"bufio"
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
 	"fmt"
-	"io/ioutil"
 	"log"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/terminal"
@@ -24,7 +26,7 @@ func ExampleNewServerConn() {
 	// Public key authentication is done by comparing
 	// the public key of a received connection
 	// with the entries in the authorized_keys file.
-	authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys")
+	authorizedKeysBytes, err := os.ReadFile("authorized_keys")
 	if err != nil {
 		log.Fatalf("Failed to load authorized_keys, err: %v", err)
 	}
@@ -67,7 +69,7 @@ func ExampleNewServerConn() {
 		},
 	}
 
-	privateBytes, err := ioutil.ReadFile("id_rsa")
+	privateBytes, err := os.ReadFile("id_rsa")
 	if err != nil {
 		log.Fatal("Failed to load private key: ", err)
 	}
@@ -76,7 +78,6 @@ func ExampleNewServerConn() {
 	if err != nil {
 		log.Fatal("Failed to parse private key: ", err)
 	}
-
 	config.AddHostKey(private)
 
 	// Once a ServerConfig has been configured, connections can be
@@ -98,8 +99,15 @@ func ExampleNewServerConn() {
 	}
 	log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
 
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
 	// The incoming Request channel must be serviced.
-	go ssh.DiscardRequests(reqs)
+	wg.Add(1)
+	go func() {
+		ssh.DiscardRequests(reqs)
+		wg.Done()
+	}()
 
 	// Service the incoming Channel channel.
 	for newChannel := range chans {
@@ -119,16 +127,22 @@ func ExampleNewServerConn() {
 		// Sessions have out-of-band requests such as "shell",
 		// "pty-req" and "env".  Here we handle only the
 		// "shell" request.
+		wg.Add(1)
 		go func(in <-chan *ssh.Request) {
 			for req := range in {
 				req.Reply(req.Type == "shell", nil)
 			}
+			wg.Done()
 		}(requests)
 
 		term := terminal.NewTerminal(channel, "> ")
 
+		wg.Add(1)
 		go func() {
-			defer channel.Close()
+			defer func() {
+				channel.Close()
+				wg.Done()
+			}()
 			for {
 				line, err := term.ReadLine()
 				if err != nil {
@@ -140,6 +154,36 @@ func ExampleNewServerConn() {
 	}
 }
 
+func ExampleServerConfig_AddHostKey() {
+	// Minimal ServerConfig supporting only password authentication.
+	config := &ssh.ServerConfig{
+		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
+			// Should use constant-time compare (or better, salt+hash) in
+			// a production setting.
+			if c.User() == "testuser" && string(pass) == "tiger" {
+				return nil, nil
+			}
+			return nil, fmt.Errorf("password rejected for %q", c.User())
+		},
+	}
+
+	privateBytes, err := os.ReadFile("id_rsa")
+	if err != nil {
+		log.Fatal("Failed to load private key: ", err)
+	}
+
+	private, err := ssh.ParsePrivateKey(privateBytes)
+	if err != nil {
+		log.Fatal("Failed to parse private key: ", err)
+	}
+	// Restrict host key algorithms to disable ssh-rsa.
+	signer, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512})
+	if err != nil {
+		log.Fatal("Failed to create private key with restricted algorithms: ", err)
+	}
+	config.AddHostKey(signer)
+}
+
 func ExampleClientConfig_HostKeyCallback() {
 	// Every client must provide a host key check.  Here is a
 	// simple-minded parse of OpenSSH's known_hosts file
@@ -225,7 +269,7 @@ func ExamplePublicKeys() {
 	//
 	// If you have an encrypted private key, the crypto/x509 package
 	// can be used to decrypt it.
-	key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa")
+	key, err := os.ReadFile("/home/user/.ssh/id_rsa")
 	if err != nil {
 		log.Fatalf("unable to read private key: %v", err)
 	}
@@ -319,3 +363,38 @@ func ExampleSession_RequestPty() {
 		log.Fatal("failed to start shell: ", err)
 	}
 }
+
+func ExampleCertificate_SignCert() {
+	// Sign a certificate with a specific algorithm.
+	privateKey, err := rsa.GenerateKey(rand.Reader, 3072)
+	if err != nil {
+		log.Fatal("unable to generate RSA key: ", err)
+	}
+	publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+	if err != nil {
+		log.Fatal("unable to get RSA public key: ", err)
+	}
+	caKey, err := rsa.GenerateKey(rand.Reader, 3072)
+	if err != nil {
+		log.Fatal("unable to generate CA key: ", err)
+	}
+	signer, err := ssh.NewSignerFromKey(caKey)
+	if err != nil {
+		log.Fatal("unable to generate signer from key: ", err)
+	}
+	mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256})
+	if err != nil {
+		log.Fatal("unable to create signer with algoritms: ", err)
+	}
+	certificate := ssh.Certificate{
+		Key:      publicKey,
+		CertType: ssh.UserCert,
+	}
+	if err := certificate.SignCert(rand.Reader, mas); err != nil {
+		log.Fatal("unable to sign certificate: ", err)
+	}
+	// Save the public key to a file and check that rsa-sha-256 is used for
+	// signing:
+	// ssh-keygen -L -f <path to the file>
+	fmt.Println(string(ssh.MarshalAuthorizedKey(&certificate)))
+}

+ 207 - 53
psiphon/common/crypto/ssh/handshake.go

@@ -11,6 +11,7 @@ import (
 	"io"
 	"log"
 	"net"
+	"strings"
 	"sync"
 
 	// [Psiphon]
@@ -37,6 +38,16 @@ type keyingTransport interface {
 	// direction will be effected if a msgNewKeys message is sent
 	// or received.
 	prepareKeyChange(*algorithms, *kexResult) error
+
+	// setStrictMode sets the strict KEX mode, notably triggering
+	// sequence number resets on sending or receiving msgNewKeys.
+	// If the sequence number is already > 1 when setStrictMode
+	// is called, an error is returned.
+	setStrictMode() error
+
+	// setInitialKEXDone indicates to the transport that the initial key exchange
+	// was completed
+	setInitialKEXDone()
 }
 
 // handshakeTransport implements rekeying on top of a keyingTransport
@@ -53,6 +64,10 @@ type handshakeTransport struct {
 	// connection.
 	hostKeys []Signer
 
+	// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
+	// it contains the supported client public key authentication algorithms.
+	publicKeyAuthAlgorithms []string
+
 	// hostKeyAlgorithms is non-empty if we are the client. In that case,
 	// we accept these key types from the server as host key.
 	hostKeyAlgorithms []string
@@ -61,11 +76,13 @@ type handshakeTransport struct {
 	incoming  chan []byte
 	readError error
 
-	mu             sync.Mutex
-	writeError     error
-	sentInitPacket []byte
-	sentInitMsg    *kexInitMsg
-	pendingPackets [][]byte // Used when a key exchange is in progress.
+	mu               sync.Mutex
+	writeError       error
+	sentInitPacket   []byte
+	sentInitMsg      *kexInitMsg
+	pendingPackets   [][]byte // Used when a key exchange is in progress.
+	writePacketsLeft uint32
+	writeBytesLeft   int64
 
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
@@ -74,7 +91,8 @@ type handshakeTransport struct {
 
 	// If the other side requests or confirms a kex, its kexInit
 	// packet is sent here for the write loop to find it.
-	startKex chan *pendingKex
+	startKex    chan *pendingKex
+	kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
 
 	// data for host key checking
 	hostKeyCallback HostKeyCallback
@@ -89,14 +107,16 @@ type handshakeTransport struct {
 	// Algorithms agreed in the last key exchange.
 	algorithms *algorithms
 
+	// Counters exclusively owned by readLoop.
 	readPacketsLeft uint32
 	readBytesLeft   int64
 
-	writePacketsLeft uint32
-	writeBytesLeft   int64
-
 	// The session ID or nil if first kex did not complete yet.
 	sessionID []byte
+
+	// strictMode indicates if the other side of the handshake indicated
+	// that we should be following the strict KEX protocol restrictions.
+	strictMode bool
 }
 
 type pendingKex struct {
@@ -111,7 +131,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
 		clientVersion: clientVersion,
 		incoming:      make(chan []byte, chanSize),
 		requestKex:    make(chan struct{}, 1),
-		startKex:      make(chan *pendingKex, 1),
+		startKex:      make(chan *pendingKex),
+		kexLoopDone:   make(chan struct{}),
 
 		config: config,
 	}
@@ -142,6 +163,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
 	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 	t.hostKeys = config.hostKeys
+	t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
 	go t.readLoop()
 	go t.kexLoop()
 	return t
@@ -204,7 +226,10 @@ func (t *handshakeTransport) readLoop() {
 			close(t.incoming)
 			break
 		}
-		if p[0] == msgIgnore || p[0] == msgDebug {
+		// If this is the first kex, and strict KEX mode is enabled,
+		// we don't ignore any messages, as they may be used to manipulate
+		// the packet sequence numbers.
+		if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
 			continue
 		}
 		t.incoming <- p
@@ -343,16 +368,17 @@ write:
 		t.mu.Unlock()
 	}
 
+	// Unblock reader.
+	t.conn.Close()
+
 	// drain startKex channel. We don't service t.requestKex
 	// because nobody does blocking sends there.
-	go func() {
-		for init := range t.startKex {
-			init.done <- t.writeError
-		}
-	}()
+	for request := range t.startKex {
+		request.done <- t.getWriteError()
+	}
 
-	// Unblock reader.
-	t.conn.Close()
+	// Mark that the loop is done so that Close can return.
+	close(t.kexLoopDone)
 }
 
 // The protocol uses uint32 for packet counters, so we can't let them
@@ -435,6 +461,11 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
 	return successPacket, nil
 }
 
+const (
+	kexStrictClient = "kex-strict-c-v00@openssh.com"
+	kexStrictServer = "kex-strict-s-v00@openssh.com"
+)
+
 // sendKexInit sends a key change message.
 func (t *handshakeTransport) sendKexInit() error {
 	t.mu.Lock()
@@ -448,7 +479,6 @@ func (t *handshakeTransport) sendKexInit() error {
 	}
 
 	msg := &kexInitMsg{
-		KexAlgos:                t.config.KeyExchanges,
 		CiphersClientServer:     t.config.Ciphers,
 		CiphersServerClient:     t.config.Ciphers,
 		MACsClientServer:        t.config.MACs,
@@ -458,20 +488,55 @@ func (t *handshakeTransport) sendKexInit() error {
 	}
 	io.ReadFull(rand.Reader, msg.Cookie[:])
 
-	if len(t.hostKeys) > 0 {
+	// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
+	// and possibly to add the ext-info extension algorithm. Since the slice may be the
+	// user owned KeyExchanges, we create our own slice in order to avoid using user
+	// owned memory by mistake.
+	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
+	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
+
+	isServer := len(t.hostKeys) > 0
+	if isServer {
 		for _, k := range t.hostKeys {
-			algo := k.PublicKey().Type()
-			switch algo {
-			case KeyAlgoRSA:
-				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, []string{SigAlgoRSASHA2512, SigAlgoRSASHA2256, SigAlgoRSA}...)
-			case CertAlgoRSAv01:
-				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, []string{CertSigAlgoRSASHA2512v01, CertSigAlgoRSASHA2256v01, CertSigAlgoRSAv01}...)
+			// If k is a MultiAlgorithmSigner, we restrict the signature
+			// algorithms. If k is a AlgorithmSigner, presume it supports all
+			// signature algorithms associated with the key format. If k is not
+			// an AlgorithmSigner, we can only assume it only supports the
+			// algorithms that matches the key format. (This means that Sign
+			// can't pick a different default).
+			keyFormat := k.PublicKey().Type()
+
+			switch s := k.(type) {
+			case MultiAlgorithmSigner:
+				for _, algo := range algorithmsForKeyFormat(keyFormat) {
+					if contains(s.Algorithms(), underlyingAlgo(algo)) {
+						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
+					}
+				}
+			case AlgorithmSigner:
+				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
 			default:
-				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
+				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
 			}
 		}
+
+		if t.sessionID == nil {
+			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
+		}
 	} else {
 		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
+
+		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
+		// algorithms the server supports for public key authentication. See RFC
+		// 8308, Section 2.1.
+		//
+		// We also send the strict KEX mode extension algorithm, in order to opt
+		// into the strict KEX mode.
+		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
+			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
+			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
+		}
+
 	}
 
 	// [Psiphon]
@@ -486,6 +551,17 @@ func (t *handshakeTransport) sendKexInit() error {
 	//
 	// When NoEncryptThenMACHash is specified, do not use Encrypt-then-MAC has
 	// algorithms.
+	//
+	// Limitations:
+	//
+	// - "ext-info-c" and "kex-strict-c/s-v00@openssh.com" extensions included
+	//    in KexAlgos may be truncated; Psiphon's usage of SSH does not
+	//    request SSH_MSG_EXT_INFO for client authentication and should not
+	//    be vulnerable to downgrade attacks related to stripping
+	//    SSH_MSG_EXT_INFO.
+	//
+	// - KEX algorithms are not synchronized with the version identification
+	//   string.
 
 	equal := func(list1, list2 []string) bool {
 		if len(list1) != len(list2) {
@@ -501,7 +577,7 @@ func (t *handshakeTransport) sendKexInit() error {
 
 	// Psiphon transforms assume that default algorithms are configured.
 	if (t.config.NoEncryptThenMACHash || t.config.KEXPRNGSeed != nil) &&
-		(!equal(t.config.KeyExchanges, supportedKexAlgos) ||
+		(!equal(t.config.KeyExchanges, preferredKexAlgos) ||
 			!equal(t.config.Ciphers, preferredCiphers) ||
 			!equal(t.config.MACs, supportedMACs)) {
 		return errors.New("ssh: custom algorithm preferences not supported")
@@ -684,7 +760,16 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 }
 
 func (t *handshakeTransport) Close() error {
-	return t.conn.Close()
+	// Close the connection. This should cause the readLoop goroutine to wake up
+	// and close t.startKex, which will shut down kexLoop if running.
+	err := t.conn.Close()
+
+	// Wait for the kexLoop goroutine to complete.
+	// At that point we know that the readLoop goroutine is complete too,
+	// because kexLoop itself waits for readLoop to close the startKex channel.
+	<-t.kexLoopDone
+
+	return err
 }
 
 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
@@ -720,6 +805,19 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 		return err
 	}
 
+	if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) &&
+
+		// [Psiphon]
+		// When KEX randomization omits "kex-strict-c/s-v00@openssh.com"
+		// (see comment in sendKexInit), do not enable strict mode.
+		((isClient && contains(t.sentInitMsg.KexAlgos, kexStrictClient)) || (!isClient && contains(t.sentInitMsg.KexAlgos, kexStrictServer))) {
+
+		t.strictMode = true
+		if err := t.conn.setStrictMode(); err != nil {
+			return err
+		}
+	}
+
 	// We don't send FirstKexFollows, but we handle receiving it.
 	//
 	// RFC 4253 section 7 defines the kex and the agreement method for
@@ -745,16 +843,17 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 
 	var result *kexResult
 	if len(t.hostKeys) > 0 {
-		result, err = t.server(kex, t.algorithms, &magics)
+		result, err = t.server(kex, &magics)
 	} else {
-		result, err = t.client(kex, t.algorithms, &magics)
+		result, err = t.client(kex, &magics)
 	}
 
 	if err != nil {
 		return err
 	}
 
-	if t.sessionID == nil {
+	firstKeyExchange := t.sessionID == nil
+	if firstKeyExchange {
 		t.sessionID = result.H
 	}
 	result.SessionID = t.sessionID
@@ -765,42 +864,97 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}
+
+	// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
+	// message with the server-sig-algs extension if the client supports it. See
+	// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
+	if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
+		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
+		extInfo := &extInfoMsg{
+			NumExtensions: 2,
+			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
+		}
+		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
+		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
+		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
+		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
+		extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
+		extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
+		extInfo.Payload = appendInt(extInfo.Payload, 1)
+		extInfo.Payload = append(extInfo.Payload, "0"...)
+		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
+			return err
+		}
+	}
+
 	if packet, err := t.conn.readPacket(); err != nil {
 		return err
 	} else if packet[0] != msgNewKeys {
 		return unexpectedMessageError(msgNewKeys, packet[0])
 	}
 
+	if firstKeyExchange {
+		// Indicates to the transport that the first key exchange is completed
+		// after receiving SSH_MSG_NEWKEYS.
+		t.conn.setInitialKEXDone()
+	}
+
 	return nil
 }
 
-func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
-	var hostKey Signer
-	for _, k := range t.hostKeys {
-		kt := k.PublicKey().Type()
-		if kt == algs.hostKey {
-			hostKey = k
-		} else if signer, ok := k.(AlgorithmSigner); ok {
-			// Some signature algorithms don't show up as key types
-			// so we have to manually check for a compatible host key.
-			switch kt {
-			case KeyAlgoRSA:
-				if algs.hostKey == SigAlgoRSASHA2256 || algs.hostKey == SigAlgoRSASHA2512 {
-					hostKey = &rsaSigner{signer, algs.hostKey}
-				}
-			case CertAlgoRSAv01:
-				if algs.hostKey == CertSigAlgoRSASHA2256v01 || algs.hostKey == CertSigAlgoRSASHA2512v01 {
-					hostKey = &rsaSigner{signer, certToPrivAlgo(algs.hostKey)}
-				}
+// algorithmSignerWrapper is an AlgorithmSigner that only supports the default
+// key format algorithm.
+//
+// This is technically a violation of the AlgorithmSigner interface, but it
+// should be unreachable given where we use this. Anyway, at least it returns an
+// error instead of panicing or producing an incorrect signature.
+type algorithmSignerWrapper struct {
+	Signer
+}
+
+func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	if algorithm != underlyingAlgo(a.PublicKey().Type()) {
+		return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
+	}
+	return a.Sign(rand, data)
+}
+
+func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
+	for _, k := range hostKeys {
+		if s, ok := k.(MultiAlgorithmSigner); ok {
+			if !contains(s.Algorithms(), underlyingAlgo(algo)) {
+				continue
 			}
 		}
+
+		if algo == k.PublicKey().Type() {
+			return algorithmSignerWrapper{k}
+		}
+
+		k, ok := k.(AlgorithmSigner)
+		if !ok {
+			continue
+		}
+		for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
+			if algo == a {
+				return k
+			}
+		}
+	}
+	return nil
+}
+
+func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
+	hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
+	if hostKey == nil {
+		return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
 	}
 
-	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
+	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
 	return r, err
 }
 
-func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
+func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
 	result, err := kex.Client(t.conn, t.config.Rand, magics)
 	if err != nil {
 		return nil, err
@@ -811,7 +965,7 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
 		return nil, err
 	}
 
-	if err := verifyHostKeySignature(hostKey, algs.hostKey, result); err != nil {
+	if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
 		return nil, err
 	}
 

+ 476 - 17
psiphon/common/crypto/ssh/handshake_test.go

@@ -148,6 +148,7 @@ func TestHandshakeBasic(t *testing.T) {
 	clientDone := make(chan int, 0)
 	gotHalf := make(chan int, 0)
 	const N = 20
+	errorCh := make(chan error, 1)
 
 	go func() {
 		defer close(clientDone)
@@ -158,7 +159,9 @@ func TestHandshakeBasic(t *testing.T) {
 		for i := 0; i < N; i++ {
 			p := []byte{msgRequestSuccess, byte(i)}
 			if err := trC.writePacket(p); err != nil {
-				t.Fatalf("sendPacket: %v", err)
+				errorCh <- err
+				trC.Close()
+				return
 			}
 			if (i % 10) == 5 {
 				<-gotHalf
@@ -177,16 +180,15 @@ func TestHandshakeBasic(t *testing.T) {
 				checker.waitCall <- 1
 			}
 		}
+		errorCh <- nil
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
-	err = nil
 	for ; i < N; i++ {
-		var p []byte
-		p, err = trS.readPacket()
-		if err != nil {
-			break
+		p, err := trS.readPacket()
+		if err != nil && err != io.EOF {
+			t.Fatalf("server error: %v", err)
 		}
 		if (i % 10) == 5 {
 			gotHalf <- 1
@@ -198,8 +200,8 @@ func TestHandshakeBasic(t *testing.T) {
 		}
 	}
 	<-clientDone
-	if err != nil && err != io.EOF {
-		t.Fatalf("server error: %v", err)
+	if err := <-errorCh; err != nil {
+		t.Fatalf("sendPacket: %v", err)
 	}
 	if i != N {
 		t.Errorf("received %d messages, want 10.", i)
@@ -345,16 +347,16 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
 
 	// While we read out the packet, a key change will be
 	// initiated.
-	done := make(chan int, 1)
+	errorCh := make(chan error, 1)
 	go func() {
-		defer close(done)
-		if _, err := trC.readPacket(); err != nil {
-			t.Fatalf("readPacket(client): %v", err)
-		}
-
+		_, err := trC.readPacket()
+		errorCh <- err
 	}()
 
-	<-done
+	if err := <-errorCh; err != nil {
+		t.Fatalf("readPacket(client): %v", err)
+	}
+
 	<-sync.called
 }
 
@@ -393,6 +395,10 @@ func (n *errorKeyingTransport) readPacket() ([]byte, error) {
 	return n.packetConn.readPacket()
 }
 
+func (n *errorKeyingTransport) setStrictMode() error { return nil }
+
+func (n *errorKeyingTransport) setInitialKEXDone() {}
+
 func TestHandshakeErrorHandlingRead(t *testing.T) {
 	for i := 0; i < 20; i++ {
 		testHandshakeErrorHandlingN(t, i, -1, false)
@@ -421,8 +427,8 @@ func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
 // handshakeTransport deadlocks, the go runtime will detect it and
 // panic.
 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
-	if runtime.GOOS == "js" && runtime.GOARCH == "wasm" {
-		t.Skip("skipping on js/wasm; see golang.org/issue/32840")
+	if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" {
+		t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS)
 	}
 	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
 
@@ -560,3 +566,456 @@ func TestHandshakeRekeyDefault(t *testing.T) {
 		t.Errorf("got rekey after %dG write, want 64G", wgb)
 	}
 }
+
+func TestHandshakeAEADCipherNoMAC(t *testing.T) {
+	for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} {
+		checker := &syncChecker{
+			called: make(chan int, 1),
+		}
+		clientConf := &ClientConfig{
+			Config: Config{
+				Ciphers: []string{cipher},
+				MACs:    []string{},
+			},
+			HostKeyCallback: checker.Check,
+		}
+		trC, trS, err := handshakePair(clientConf, "addr", false)
+		if err != nil {
+			t.Fatalf("handshakePair: %v", err)
+		}
+		defer trC.Close()
+		defer trS.Close()
+
+		<-checker.called
+	}
+}
+
+// TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and
+// therefore can't do SHA-2 signatures. Ensures the server does not advertise
+// support for them in this case.
+func TestNoSHA2Support(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	serverConf := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]})
+	go func() {
+		_, _, _, err := NewServerConn(c1, serverConf)
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+
+	clientConf := &ClientConfig{
+		User:            "test",
+		Auth:            []AuthMethod{Password("testpw")},
+		HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
+	}
+
+	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestMultiAlgoSignerHandshake(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	serverConf := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	serverConf.AddHostKey(multiAlgoSigner)
+	go NewServerConn(c1, serverConf)
+
+	clientConf := &ClientConfig{
+		User:              "test",
+		Auth:              []AuthMethod{Password("testpw")},
+		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
+		HostKeyAlgorithms: []string{KeyAlgoRSASHA512},
+	}
+
+	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	// ssh-rsa is disabled server side
+	serverConf := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+	}
+	serverConf.AddHostKey(multiAlgoSigner)
+	go NewServerConn(c1, serverConf)
+
+	// the client only supports ssh-rsa
+	clientConf := &ClientConfig{
+		User:              "test",
+		Auth:              []AuthMethod{Password("testpw")},
+		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
+		HostKeyAlgorithms: []string{KeyAlgoRSA},
+	}
+
+	_, _, _, err = NewClientConn(c2, "", clientConf)
+	if err == nil {
+		t.Fatal("succeeded connecting with no common hostkey algorithm")
+	}
+}
+
+func TestPickIncompatibleHostKeyAlgo(t *testing.T) {
+	algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Fatalf("unable to create multi algorithm signer: %v", err)
+	}
+	signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA)
+	if signer != nil {
+		t.Fatal("incompatible signer returned")
+	}
+}
+
+func TestStrictKEXResetSeqFirstKEX(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	// check that the sequence number counters. We reset after msgNewKeys, but
+	// then the server immediately writes msgExtInfo, and we close the
+	// transports so we expect read 2, write 0 on the client and read 1, write 1
+	// on the server.
+	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
+		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// write and read five packets on either side to bump the sequence numbers
+	for i := 0; i < 5; i++ {
+		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trS.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trC.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+	}
+
+	// Request a key exchange, which should cause the sequence numbers to reset
+	checker.waitCall <- 1
+	trC.requestKeyExchange()
+	<-checker.called
+
+	// write a packet on the client, and then read it, to verify the key change has actually happened, since
+	// the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake
+	// finishing.
+	dummyPacket := []byte{99}
+	if err := trS.writePacket(dummyPacket); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	if p, err := trC.readPacket(); err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	} else if !bytes.Equal(p, dummyPacket) {
+		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 ||
+		trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestSeqNumIncrease(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	checker := &syncChecker{
+		waitCall: make(chan int, 10),
+		called:   make(chan int, 10),
+	}
+
+	checker.waitCall <- 1
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	<-checker.called
+
+	t.Cleanup(func() {
+		trC.Close()
+		trS.Close()
+	})
+
+	// Throw away the msgExtInfo packet sent during the handshake by the server
+	_, err = trC.readPacket()
+	if err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	}
+
+	// write and read five packets on either side to bump the sequence numbers
+	for i := 0; i < 5; i++ {
+		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trS.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
+			t.Fatalf("writePacket failed: %s", err)
+		}
+		if _, err := trC.readPacket(); err != nil {
+			t.Fatalf("readPacket failed: %s", err)
+		}
+	}
+
+	// close the handshake transports before checking the sequence number to
+	// avoid races.
+	trC.Close()
+	trS.Close()
+
+	if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 ||
+		trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestStrictKEXUnexpectedMsg(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	// Check that unexpected messages during the handshake cause failure
+	_, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
+	if err == nil {
+		t.Fatal("handshake should fail when there are unexpected messages during the handshake")
+	}
+
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshake failed: %s", err)
+	}
+
+	// Check that ignore/debug pacekts are still ignored outside of the handshake
+	if err := trC.writePacket([]byte{msgIgnore}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	if err := trC.writePacket([]byte{msgDebug}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	dummyPacket := []byte{99}
+	if err := trC.writePacket(dummyPacket); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+
+	if p, err := trS.readPacket(); err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	} else if !bytes.Equal(p, dummyPacket) {
+		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
+	}
+}
+
+func TestStrictKEXMixed(t *testing.T) {
+	// Test that we still support a mixed connection, where one side sends kex-strict but the other
+	// side doesn't.
+
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe failed: %s", err)
+	}
+
+	var trC, trS keyingTransport
+
+	trC = newTransport(a, rand.Reader, true)
+	trS = newTransport(b, rand.Reader, false)
+	trS = addNoiseTransport(trS)
+
+	clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+
+	transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version"))
+	transport.hostKeys = serverConf.hostKeys
+	transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms
+
+	readOneFailure := make(chan error, 1)
+	go func() {
+		if _, err := transport.readOnePacket(true); err != nil {
+			readOneFailure <- err
+		}
+	}()
+
+	// Basically sendKexInit, but without the kex-strict extension algorithm
+	msg := &kexInitMsg{
+		KexAlgos:                transport.config.KeyExchanges,
+		CiphersClientServer:     transport.config.Ciphers,
+		CiphersServerClient:     transport.config.Ciphers,
+		MACsClientServer:        transport.config.MACs,
+		MACsServerClient:        transport.config.MACs,
+		CompressionClientServer: supportedCompressions,
+		CompressionServerClient: supportedCompressions,
+		ServerHostKeyAlgos:      []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA},
+	}
+	packet := Marshal(msg)
+	// writePacket destroys the contents, so save a copy.
+	packetCopy := make([]byte, len(packet))
+	copy(packetCopy, packet)
+	if err := transport.pushPacket(packetCopy); err != nil {
+		t.Fatalf("pushPacket: %s", err)
+	}
+	transport.sentInitMsg = msg
+	transport.sentInitPacket = packet
+
+	if err := transport.getWriteError(); err != nil {
+		t.Fatalf("getWriteError failed: %s", err)
+	}
+	var request *pendingKex
+	select {
+	case err = <-readOneFailure:
+		t.Fatalf("server readOnePacket failed: %s", err)
+	case request = <-transport.startKex:
+		break
+	}
+
+	// We expect the following calls to fail if the side which does not support
+	// kex-strict sends unexpected/ignored packets during the handshake, even if
+	// the other side does support kex-strict.
+
+	if err := transport.enterKeyExchange(request.otherInit); err != nil {
+		t.Fatalf("enterKeyExchange failed: %s", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+}

+ 96 - 117
psiphon/common/crypto/ssh/kex.go

@@ -20,13 +20,15 @@ import (
 )
 
 const (
-	kexAlgoDH1SHA1          = "diffie-hellman-group1-sha1"
-	kexAlgoDH14SHA1         = "diffie-hellman-group14-sha1"
-	kexAlgoDH14SHA256       = "diffie-hellman-group14-sha256"
-	kexAlgoECDH256          = "ecdh-sha2-nistp256"
-	kexAlgoECDH384          = "ecdh-sha2-nistp384"
-	kexAlgoECDH521          = "ecdh-sha2-nistp521"
-	kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org"
+	kexAlgoDH1SHA1                = "diffie-hellman-group1-sha1"
+	kexAlgoDH14SHA1               = "diffie-hellman-group14-sha1"
+	kexAlgoDH14SHA256             = "diffie-hellman-group14-sha256"
+	kexAlgoDH16SHA512             = "diffie-hellman-group16-sha512"
+	kexAlgoECDH256                = "ecdh-sha2-nistp256"
+	kexAlgoECDH384                = "ecdh-sha2-nistp384"
+	kexAlgoECDH521                = "ecdh-sha2-nistp521"
+	kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
+	kexAlgoCurve25519SHA256       = "curve25519-sha256"
 
 	// For the following kex only the client half contains a production
 	// ready implementation. The server half only consists of a minimal
@@ -76,8 +78,9 @@ func (m *handshakeMagics) write(w io.Writer) {
 // kexAlgorithm abstracts different key exchange algorithms.
 type kexAlgorithm interface {
 	// Server runs server-side key agreement, signing the result
-	// with a hostkey.
-	Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error)
+	// with a hostkey. algo is the negotiated algorithm, and may
+	// be a certificate type.
+	Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error)
 
 	// Client runs the client-side key agreement. Caller is
 	// responsible for verifying the host key signature.
@@ -87,9 +90,7 @@ type kexAlgorithm interface {
 // dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
 type dhGroup struct {
 	g, p, pMinus1 *big.Int
-
-	// [Psiphon]
-	hashFunc crypto.Hash
+	hashFunc      crypto.Hash
 }
 
 func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
@@ -100,10 +101,6 @@ func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int,
 }
 
 func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
-
-	// [Psiphon]
-	hashFunc := group.hashFunc
-
 	var x *big.Int
 	for {
 		var err error
@@ -138,7 +135,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
 		return nil, err
 	}
 
-	h := hashFunc.New()
+	h := group.hashFunc.New()
 	magics.write(h)
 	writeString(h, kexDHReply.HostKey)
 	writeInt(h, X)
@@ -152,15 +149,11 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
 		K:         K,
 		HostKey:   kexDHReply.HostKey,
 		Signature: kexDHReply.Signature,
-		Hash:      hashFunc,
+		Hash:      group.hashFunc,
 	}, nil
 }
 
-func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
-
-	// [Psiphon]
-	hashFunc := group.hashFunc
-
+func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
 	packet, err := c.readPacket()
 	if err != nil {
 		return
@@ -188,7 +181,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
 
 	hostKeyBytes := priv.PublicKey().Marshal()
 
-	h := hashFunc.New()
+	h := group.hashFunc.New()
 	magics.write(h)
 	writeString(h, hostKeyBytes)
 	writeInt(h, kexDHInit.X)
@@ -202,7 +195,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
 
 	// H is already a hash, but the hostkey signing will apply its
 	// own key-specific hash algorithm.
-	sig, err := signAndMarshal(priv, randSource, H)
+	sig, err := signAndMarshal(priv, randSource, H, algo)
 	if err != nil {
 		return nil, err
 	}
@@ -220,7 +213,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
 		K:         K,
 		HostKey:   hostKeyBytes,
 		Signature: sig,
-		Hash:      hashFunc,
+		Hash:      group.hashFunc,
 	}, err
 }
 
@@ -323,7 +316,7 @@ func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
 	return true
 }
 
-func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
 	packet, err := c.readPacket()
 	if err != nil {
 		return nil, err
@@ -368,7 +361,7 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p
 
 	// H is already a hash, but the hostkey signing will apply its
 	// own key-specific hash algorithm.
-	sig, err := signAndMarshal(priv, rand, H)
+	sig, err := signAndMarshal(priv, rand, H, algo)
 	if err != nil {
 		return nil, err
 	}
@@ -393,25 +386,37 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p
 	}, nil
 }
 
+// ecHash returns the hash to match the given elliptic curve, see RFC
+// 5656, section 6.2.1
+func ecHash(curve elliptic.Curve) crypto.Hash {
+	bitSize := curve.Params().BitSize
+	switch {
+	case bitSize <= 256:
+		return crypto.SHA256
+	case bitSize <= 384:
+		return crypto.SHA384
+	}
+	return crypto.SHA512
+}
+
 var kexAlgoMap = map[string]kexAlgorithm{}
 
 func init() {
-	// This is the group called diffie-hellman-group1-sha1 in RFC
-	// 4253 and Oakley Group 2 in RFC 2409.
+	// This is the group called diffie-hellman-group1-sha1 in
+	// RFC 4253 and Oakley Group 2 in RFC 2409.
 	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
 	kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
-		g:       new(big.Int).SetInt64(2),
-		p:       p,
-		pMinus1: new(big.Int).Sub(p, bigOne),
-
+		g:        new(big.Int).SetInt64(2),
+		p:        p,
+		pMinus1:  new(big.Int).Sub(p, bigOne),
 		hashFunc: crypto.SHA1,
 	}
 
-	// This is the group called diffie-hellman-group14-sha1 in RFC
-	// 4253 and Oakley Group 14 in RFC 3526.
+	// This are the groups called diffie-hellman-group14-sha1 and
+	// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
+	// and Oakley Group 14 in RFC 3526.
 	p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-
-	kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
+	group14 := &dhGroup{
 		g:       new(big.Int).SetInt64(2),
 		p:       p,
 		pMinus1: new(big.Int).Sub(p, bigOne),
@@ -419,31 +424,37 @@ func init() {
 		hashFunc: crypto.SHA1,
 	}
 
-	// [Psiphon]
-	// RFC 8268:
-	// > The method of key exchange used for the name "diffie-hellman-
-	// > group14-sha256" is the same as that for "diffie-hellman-group14-sha1"
-	// > except that the SHA256 hash algorithm is used.
-
+	kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
+		g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
+		hashFunc: crypto.SHA1,
+	}
 	kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
-		g:       new(big.Int).SetInt64(2),
-		p:       p,
-		pMinus1: new(big.Int).Sub(p, bigOne),
-
+		g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
 		hashFunc: crypto.SHA256,
 	}
 
+	// This is the group called diffie-hellman-group16-sha512 in RFC
+	// 8268 and Oakley Group 16 in RFC 3526.
+	p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
+
+	kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
+		g:        new(big.Int).SetInt64(2),
+		p:        p,
+		pMinus1:  new(big.Int).Sub(p, bigOne),
+		hashFunc: crypto.SHA512,
+	}
+
 	kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
 	kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
 	kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
 	kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
+	kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
 	kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
 	kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
 }
 
-// curve25519sha256 implements the curve25519-sha256@libssh.org key
-// agreement protocol, as described in
-// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
+// curve25519sha256 implements the curve25519-sha256 (formerly known as
+// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731.
 type curve25519sha256 struct{}
 
 type curve25519KeyPair struct {
@@ -513,7 +524,7 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
 	}, nil
 }
 
-func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
 	packet, err := c.readPacket()
 	if err != nil {
 		return
@@ -554,7 +565,7 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
 
 	H := h.Sum(nil)
 
-	sig, err := signAndMarshal(priv, rand, H)
+	sig, err := signAndMarshal(priv, rand, H, algo)
 	if err != nil {
 		return nil, err
 	}
@@ -580,7 +591,6 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
 // diffie-hellman-group-exchange-sha256 key agreement protocols,
 // as described in RFC 4419
 type dhGEXSHA struct {
-	g, p     *big.Int
 	hashFunc crypto.Hash
 }
 
@@ -590,14 +600,7 @@ const (
 	dhGroupExchangeMaximumBits   = 8192
 )
 
-func (gex *dhGEXSHA) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
-	if theirPublic.Sign() <= 0 || theirPublic.Cmp(gex.p) >= 0 {
-		return nil, fmt.Errorf("ssh: DH parameter out of bounds")
-	}
-	return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil
-}
-
-func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
 	// Send GexRequest
 	kexDHGexRequest := kexDHGexRequestMsg{
 		MinBits:      dhGroupExchangeMinimumBits,
@@ -614,35 +617,29 @@ func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshake
 		return nil, err
 	}
 
-	var kexDHGexGroup kexDHGexGroupMsg
-	if err = Unmarshal(packet, &kexDHGexGroup); err != nil {
+	var msg kexDHGexGroupMsg
+	if err = Unmarshal(packet, &msg); err != nil {
 		return nil, err
 	}
 
 	// reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits
-	if kexDHGexGroup.P.BitLen() < dhGroupExchangeMinimumBits || kexDHGexGroup.P.BitLen() > dhGroupExchangeMaximumBits {
-		return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", kexDHGexGroup.P.BitLen())
+	if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits {
+		return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen())
 	}
 
-	gex.p = kexDHGexGroup.P
-	gex.g = kexDHGexGroup.G
-
-	// Check if g is safe by verifing that g > 1 and g < p - 1
-	one := big.NewInt(1)
-	var pMinusOne = &big.Int{}
-	pMinusOne.Sub(gex.p, one)
-	if gex.g.Cmp(one) != 1 && gex.g.Cmp(pMinusOne) != -1 {
+	// Check if g is safe by verifying that 1 < g < p-1
+	pMinusOne := new(big.Int).Sub(msg.P, bigOne)
+	if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 {
 		return nil, fmt.Errorf("ssh: server provided gex g is not safe")
 	}
 
 	// Send GexInit
-	var pHalf = &big.Int{}
-	pHalf.Rsh(gex.p, 1)
+	pHalf := new(big.Int).Rsh(msg.P, 1)
 	x, err := rand.Int(randSource, pHalf)
 	if err != nil {
 		return nil, err
 	}
-	X := new(big.Int).Exp(gex.g, x, gex.p)
+	X := new(big.Int).Exp(msg.G, x, msg.P)
 	kexDHGexInit := kexDHGexInitMsg{
 		X: X,
 	}
@@ -661,13 +658,13 @@ func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshake
 		return nil, err
 	}
 
-	kInt, err := gex.diffieHellman(kexDHGexReply.Y, x)
-	if err != nil {
-		return nil, err
+	if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 {
+		return nil, errors.New("ssh: DH parameter out of bounds")
 	}
+	kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P)
 
-	// Check if k is safe by verifing that k > 1 and k < p - 1
-	if kInt.Cmp(one) != 1 && kInt.Cmp(pMinusOne) != -1 {
+	// Check if k is safe by verifying that k > 1 and k < p - 1
+	if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 {
 		return nil, fmt.Errorf("ssh: derived k is not safe")
 	}
 
@@ -677,8 +674,8 @@ func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshake
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
-	writeInt(h, gex.p)
-	writeInt(h, gex.g)
+	writeInt(h, msg.P)
+	writeInt(h, msg.G)
 	writeInt(h, X)
 	writeInt(h, kexDHGexReply.Y)
 	K := make([]byte, intLength(kInt))
@@ -697,7 +694,7 @@ func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshake
 // Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
 //
 // This is a minimal implementation to satisfy the automated tests.
-func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
 	// Receive GexRequest
 	packet, err := c.readPacket()
 	if err != nil {
@@ -708,35 +705,17 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
 		return
 	}
 
-	// smoosh the user's preferred size into our own limits
-	if kexDHGexRequest.PreferedBits > dhGroupExchangeMaximumBits {
-		kexDHGexRequest.PreferedBits = dhGroupExchangeMaximumBits
-	}
-	if kexDHGexRequest.PreferedBits < dhGroupExchangeMinimumBits {
-		kexDHGexRequest.PreferedBits = dhGroupExchangeMinimumBits
-	}
-	// fix min/max if they're inconsistent.  technically, we could just pout
-	// and hang up, but there's no harm in giving them the benefit of the
-	// doubt and just picking a bitsize for them.
-	if kexDHGexRequest.MinBits > kexDHGexRequest.PreferedBits {
-		kexDHGexRequest.MinBits = kexDHGexRequest.PreferedBits
-	}
-	if kexDHGexRequest.MaxBits < kexDHGexRequest.PreferedBits {
-		kexDHGexRequest.MaxBits = kexDHGexRequest.PreferedBits
-	}
-
 	// Send GexGroup
 	// This is the group called diffie-hellman-group14-sha1 in RFC
 	// 4253 and Oakley Group 14 in RFC 3526.
 	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-	gex.p = p
-	gex.g = big.NewInt(2)
+	g := big.NewInt(2)
 
-	kexDHGexGroup := kexDHGexGroupMsg{
-		P: gex.p,
-		G: gex.g,
+	msg := &kexDHGexGroupMsg{
+		P: p,
+		G: g,
 	}
-	if err := c.writePacket(Marshal(&kexDHGexGroup)); err != nil {
+	if err := c.writePacket(Marshal(msg)); err != nil {
 		return nil, err
 	}
 
@@ -750,19 +729,19 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
 		return
 	}
 
-	var pHalf = &big.Int{}
-	pHalf.Rsh(gex.p, 1)
+	pHalf := new(big.Int).Rsh(p, 1)
 
 	y, err := rand.Int(randSource, pHalf)
 	if err != nil {
 		return
 	}
+	Y := new(big.Int).Exp(g, y, p)
 
-	Y := new(big.Int).Exp(gex.g, y, gex.p)
-	kInt, err := gex.diffieHellman(kexDHGexInit.X, y)
-	if err != nil {
-		return nil, err
+	pMinusOne := new(big.Int).Sub(p, bigOne)
+	if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 {
+		return nil, errors.New("ssh: DH parameter out of bounds")
 	}
+	kInt := new(big.Int).Exp(kexDHGexInit.X, y, p)
 
 	hostKeyBytes := priv.PublicKey().Marshal()
 
@@ -772,8 +751,8 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
 	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
-	writeInt(h, gex.p)
-	writeInt(h, gex.g)
+	writeInt(h, p)
+	writeInt(h, g)
 	writeInt(h, kexDHGexInit.X)
 	writeInt(h, Y)
 
@@ -785,7 +764,7 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
 
 	// H is already a hash, but the hostkey signing will apply its
 	// own key-specific hash algorithm.
-	sig, err := signAndMarshal(priv, randSource, H)
+	sig, err := signAndMarshal(priv, randSource, H, algo)
 	if err != nil {
 		return nil, err
 	}

+ 42 - 1
psiphon/common/crypto/ssh/kex_test.go

@@ -8,6 +8,7 @@ package ssh
 
 import (
 	"crypto/rand"
+	"fmt"
 	"reflect"
 	"sync"
 	"testing"
@@ -41,7 +42,7 @@ func TestKexes(t *testing.T) {
 						c <- kexResultErr{r, e}
 					}()
 					go func() {
-						r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
+						r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type())
 						b.Close()
 						s <- kexResultErr{r, e}
 					}()
@@ -63,3 +64,43 @@ func TestKexes(t *testing.T) {
 		})
 	}
 }
+
+func BenchmarkKexes(b *testing.B) {
+	type kexResultErr struct {
+		result *kexResult
+		err    error
+	}
+
+	for name, kex := range kexAlgoMap {
+		b.Run(name, func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				t1, t2 := memPipe()
+
+				s := make(chan kexResultErr, 1)
+				c := make(chan kexResultErr, 1)
+				var magics handshakeMagics
+
+				go func() {
+					r, e := kex.Client(t1, rand.Reader, &magics)
+					t1.Close()
+					c <- kexResultErr{r, e}
+				}()
+				go func() {
+					r, e := kex.Server(t2, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type())
+					t2.Close()
+					s <- kexResultErr{r, e}
+				}()
+
+				clientRes := <-c
+				serverRes := <-s
+
+				if clientRes.err != nil {
+					panic(fmt.Sprintf("client: %v", clientRes.err))
+				}
+				if serverRes.err != nil {
+					panic(fmt.Sprintf("server: %v", serverRes.err))
+				}
+			}
+		})
+	}
+}

+ 396 - 142
psiphon/common/crypto/ssh/keys.go

@@ -11,13 +11,16 @@ import (
 	"crypto/cipher"
 	"crypto/dsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/md5"
+	"crypto/rand"
 	"crypto/rsa"
 	"crypto/sha256"
 	"crypto/x509"
 	"encoding/asn1"
 	"encoding/base64"
+	"encoding/binary"
 	"encoding/hex"
 	"encoding/pem"
 	"errors"
@@ -27,11 +30,11 @@ import (
 	"strings"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/internal/bcrypt_pbkdf"
-	"golang.org/x/crypto/ed25519"
 )
 
-// These constants represent the algorithm names for key types supported by this
-// package.
+// Public key algorithms names. These values can appear in PublicKey.Type,
+// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner
+// arguments.
 const (
 	KeyAlgoRSA        = "ssh-rsa"
 	KeyAlgoDSA        = "ssh-dss"
@@ -41,16 +44,21 @@ const (
 	KeyAlgoECDSA521   = "ecdsa-sha2-nistp521"
 	KeyAlgoED25519    = "ssh-ed25519"
 	KeyAlgoSKED25519  = "sk-ssh-ed25519@openssh.com"
+
+	// KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not
+	// public key formats, so they can't appear as a PublicKey.Type. The
+	// corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2.
+	KeyAlgoRSASHA256 = "rsa-sha2-256"
+	KeyAlgoRSASHA512 = "rsa-sha2-512"
 )
 
-// These constants represent non-default signature algorithms that are supported
-// as algorithm parameters to AlgorithmSigner.SignWithAlgorithm methods. See
-// [PROTOCOL.agent] section 4.5.1 and
-// https://tools.ietf.org/html/draft-ietf-curdle-rsa-sha2-10
 const (
-	SigAlgoRSA        = "ssh-rsa"
-	SigAlgoRSASHA2256 = "rsa-sha2-256"
-	SigAlgoRSASHA2512 = "rsa-sha2-512"
+	// Deprecated: use KeyAlgoRSA.
+	SigAlgoRSA = KeyAlgoRSA
+	// Deprecated: use KeyAlgoRSASHA256.
+	SigAlgoRSASHA2256 = KeyAlgoRSASHA256
+	// Deprecated: use KeyAlgoRSASHA512.
+	SigAlgoRSASHA2512 = KeyAlgoRSASHA512
 )
 
 // parsePubKey parses a public key of the given algorithm.
@@ -70,7 +78,7 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
 	case KeyAlgoSKED25519:
 		return parseSKEd25519(in)
 	case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
-		cert, err := parseCert(in, certToPrivAlgo(algo))
+		cert, err := parseCert(in, certKeyAlgoNames[algo])
 		if err != nil {
 			return nil, nil, err
 		}
@@ -178,7 +186,7 @@ func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey
 	return "", nil, nil, "", nil, io.EOF
 }
 
-// ParseAuthorizedKeys parses a public key from an authorized_keys
+// ParseAuthorizedKey parses a public key from an authorized_keys
 // file used in OpenSSH according to the sshd(8) manual page.
 func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
 	for len(in) > 0 {
@@ -289,18 +297,33 @@ func MarshalAuthorizedKey(key PublicKey) []byte {
 	return b.Bytes()
 }
 
-// PublicKey is an abstraction of different types of public keys.
+// MarshalPrivateKey returns a PEM block with the private key serialized in the
+// OpenSSH format.
+func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) {
+	return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler)
+}
+
+// MarshalPrivateKeyWithPassphrase returns a PEM block holding the encrypted
+// private key serialized in the OpenSSH format.
+func MarshalPrivateKeyWithPassphrase(key crypto.PrivateKey, comment string, passphrase []byte) (*pem.Block, error) {
+	return marshalOpenSSHPrivateKey(key, comment, passphraseProtectedOpenSSHMarshaler(passphrase))
+}
+
+// PublicKey represents a public key using an unspecified algorithm.
+//
+// Some PublicKeys provided by this package also implement CryptoPublicKey.
 type PublicKey interface {
-	// Type returns the key's type, e.g. "ssh-rsa".
+	// Type returns the key format name, e.g. "ssh-rsa".
 	Type() string
 
-	// Marshal returns the serialized key data in SSH wire format,
-	// with the name prefix. To unmarshal the returned data, use
-	// the ParsePublicKey function.
+	// Marshal returns the serialized key data in SSH wire format, with the name
+	// prefix. To unmarshal the returned data, use the ParsePublicKey function.
 	Marshal() []byte
 
-	// Verify that sig is a signature on the given data using this
-	// key. This function will hash the data appropriately first.
+	// Verify that sig is a signature on the given data using this key. This
+	// method will hash the data appropriately first. sig.Format is allowed to
+	// be any signature algorithm compatible with the key type, the caller
+	// should check if it has more stringent requirements.
 	Verify(data []byte, sig *Signature) error
 }
 
@@ -311,28 +334,104 @@ type CryptoPublicKey interface {
 }
 
 // A Signer can create signatures that verify against a public key.
+//
+// Some Signers provided by this package also implement MultiAlgorithmSigner.
 type Signer interface {
-	// PublicKey returns an associated PublicKey instance.
+	// PublicKey returns the associated PublicKey.
 	PublicKey() PublicKey
 
-	// Sign returns raw signature for the given data. This method
-	// will apply the hash specified for the keytype to the data.
+	// Sign returns a signature for the given data. This method will hash the
+	// data appropriately first. The signature algorithm is expected to match
+	// the key format returned by the PublicKey.Type method (and not to be any
+	// alternative algorithm supported by the key format).
 	Sign(rand io.Reader, data []byte) (*Signature, error)
 }
 
-// A AlgorithmSigner is a Signer that also supports specifying a specific
-// algorithm to use for signing.
+// An AlgorithmSigner is a Signer that also supports specifying an algorithm to
+// use for signing.
+//
+// An AlgorithmSigner can't advertise the algorithms it supports, unless it also
+// implements MultiAlgorithmSigner, so it should be prepared to be invoked with
+// every algorithm supported by the public key format.
 type AlgorithmSigner interface {
 	Signer
 
-	// SignWithAlgorithm is like Signer.Sign, but allows specification of a
-	// non-default signing algorithm. See the SigAlgo* constants in this
-	// package for signature algorithms supported by this package. Callers may
-	// pass an empty string for the algorithm in which case the AlgorithmSigner
-	// will use its default algorithm.
+	// SignWithAlgorithm is like Signer.Sign, but allows specifying a desired
+	// signing algorithm. Callers may pass an empty string for the algorithm in
+	// which case the AlgorithmSigner will use a default algorithm. This default
+	// doesn't currently control any behavior in this package.
 	SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
 }
 
+// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms
+// supported by that signer.
+type MultiAlgorithmSigner interface {
+	AlgorithmSigner
+
+	// Algorithms returns the available algorithms in preference order. The list
+	// must not be empty, and it must not include certificate types.
+	Algorithms() []string
+}
+
+// NewSignerWithAlgorithms returns a signer restricted to the specified
+// algorithms. The algorithms must be set in preference order. The list must not
+// be empty, and it must not include certificate types. An error is returned if
+// the specified algorithms are incompatible with the public key type.
+func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) {
+	if len(algorithms) == 0 {
+		return nil, errors.New("ssh: please specify at least one valid signing algorithm")
+	}
+	var signerAlgos []string
+	supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type()))
+	if s, ok := signer.(*multiAlgorithmSigner); ok {
+		signerAlgos = s.Algorithms()
+	} else {
+		signerAlgos = supportedAlgos
+	}
+
+	for _, algo := range algorithms {
+		if !contains(supportedAlgos, algo) {
+			return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q",
+				algo, signer.PublicKey().Type())
+		}
+		if !contains(signerAlgos, algo) {
+			return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo)
+		}
+	}
+	return &multiAlgorithmSigner{
+		AlgorithmSigner:     signer,
+		supportedAlgorithms: algorithms,
+	}, nil
+}
+
+type multiAlgorithmSigner struct {
+	AlgorithmSigner
+	supportedAlgorithms []string
+}
+
+func (s *multiAlgorithmSigner) Algorithms() []string {
+	return s.supportedAlgorithms
+}
+
+func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool {
+	if algorithm == "" {
+		algorithm = underlyingAlgo(s.PublicKey().Type())
+	}
+	for _, algo := range s.supportedAlgorithms {
+		if algorithm == algo {
+			return true
+		}
+	}
+	return false
+}
+
+func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
+	if !s.isAlgorithmSupported(algorithm) {
+		return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms)
+	}
+	return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
+}
+
 type rsaPublicKey rsa.PublicKey
 
 func (r *rsaPublicKey) Type() string {
@@ -381,17 +480,11 @@ func (r *rsaPublicKey) Marshal() []byte {
 }
 
 func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
-	var hash crypto.Hash
-	switch sig.Format {
-	case SigAlgoRSA:
-		hash = crypto.SHA1
-	case SigAlgoRSASHA2256:
-		hash = crypto.SHA256
-	case SigAlgoRSASHA2512:
-		hash = crypto.SHA512
-	default:
+	supportedAlgos := algorithmsForKeyFormat(r.Type())
+	if !contains(supportedAlgos, sig.Format) {
 		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
 	}
+	hash := hashFuncs[sig.Format]
 	h := hash.New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -466,7 +559,7 @@ func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
 	if sig.Format != k.Type() {
 		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
 	}
-	h := crypto.SHA1.New()
+	h := hashFuncs[sig.Format].New()
 	h.Write(data)
 	digest := h.Sum(nil)
 
@@ -499,7 +592,11 @@ func (k *dsaPrivateKey) PublicKey() PublicKey {
 }
 
 func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
-	return k.SignWithAlgorithm(rand, data, "")
+	return k.SignWithAlgorithm(rand, data, k.PublicKey().Type())
+}
+
+func (k *dsaPrivateKey) Algorithms() []string {
+	return []string{k.PublicKey().Type()}
 }
 
 func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
@@ -507,7 +604,7 @@ func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm
 		return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
 	}
 
-	h := crypto.SHA1.New()
+	h := hashFuncs[k.PublicKey().Type()].New()
 	h.Write(data)
 	digest := h.Sum(nil)
 	r, s, err := dsa.Sign(rand, k.PrivateKey, digest)
@@ -603,19 +700,6 @@ func supportedEllipticCurve(curve elliptic.Curve) bool {
 	return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
 }
 
-// ecHash returns the hash to match the given elliptic curve, see RFC
-// 5656, section 6.2.1
-func ecHash(curve elliptic.Curve) crypto.Hash {
-	bitSize := curve.Params().BitSize
-	switch {
-	case bitSize <= 256:
-		return crypto.SHA256
-	case bitSize <= 384:
-		return crypto.SHA384
-	}
-	return crypto.SHA512
-}
-
 // parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
 func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
 	var w struct {
@@ -671,7 +755,7 @@ func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
 		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
 	}
 
-	h := ecHash(k.Curve).New()
+	h := hashFuncs[sig.Format].New()
 	h.Write(data)
 	digest := h.Sum(nil)
 
@@ -775,7 +859,7 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
 		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
 	}
 
-	h := ecHash(k.Curve).New()
+	h := hashFuncs[sig.Format].New()
 	h.Write([]byte(k.application))
 	appDigest := h.Sum(nil)
 
@@ -874,7 +958,7 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
 		return fmt.Errorf("invalid size %d for Ed25519 public key", l)
 	}
 
-	h := sha256.New()
+	h := hashFuncs[sig.Format].New()
 	h.Write([]byte(k.application))
 	appDigest := h.Sum(nil)
 
@@ -970,44 +1054,23 @@ func (s *wrappedSigner) PublicKey() PublicKey {
 }
 
 func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
-	return s.SignWithAlgorithm(rand, data, "")
+	return s.SignWithAlgorithm(rand, data, s.pubKey.Type())
+}
+
+func (s *wrappedSigner) Algorithms() []string {
+	return algorithmsForKeyFormat(s.pubKey.Type())
 }
 
 func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
-	var hashFunc crypto.Hash
-
-	if _, ok := s.pubKey.(*rsaPublicKey); ok {
-		// RSA keys support a few hash functions determined by the requested signature algorithm
-		switch algorithm {
-		case "", SigAlgoRSA:
-			algorithm = SigAlgoRSA
-			hashFunc = crypto.SHA1
-		case SigAlgoRSASHA2256:
-			hashFunc = crypto.SHA256
-		case SigAlgoRSASHA2512:
-			hashFunc = crypto.SHA512
-		default:
-			return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
-		}
-	} else {
-		// The only supported algorithm for all other key types is the same as the type of the key
-		if algorithm == "" {
-			algorithm = s.pubKey.Type()
-		} else if algorithm != s.pubKey.Type() {
-			return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
-		}
+	if algorithm == "" {
+		algorithm = s.pubKey.Type()
+	}
 
-		switch key := s.pubKey.(type) {
-		case *dsaPublicKey:
-			hashFunc = crypto.SHA1
-		case *ecdsaPublicKey:
-			hashFunc = ecHash(key.Curve)
-		case ed25519PublicKey:
-		default:
-			return nil, fmt.Errorf("ssh: unsupported key type %T", key)
-		}
+	if !contains(s.Algorithms(), algorithm) {
+		return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
 	}
 
+	hashFunc := hashFuncs[algorithm]
 	var digest []byte
 	if hashFunc != 0 {
 		h := hashFunc.New()
@@ -1123,9 +1186,9 @@ func (*PassphraseMissingError) Error() string {
 	return "ssh: this private key is passphrase protected"
 }
 
-// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
-// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the
-// private key is encrypted, it will return a PassphraseMissingError.
+// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports
+// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH
+// formats. If the private key is encrypted, it will return a PassphraseMissingError.
 func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
@@ -1178,16 +1241,27 @@ func ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase []byte) (interface{},
 		return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
 	}
 
+	var result interface{}
+
 	switch block.Type {
 	case "RSA PRIVATE KEY":
-		return x509.ParsePKCS1PrivateKey(buf)
+		result, err = x509.ParsePKCS1PrivateKey(buf)
 	case "EC PRIVATE KEY":
-		return x509.ParseECPrivateKey(buf)
+		result, err = x509.ParseECPrivateKey(buf)
 	case "DSA PRIVATE KEY":
-		return ParseDSAPrivateKey(buf)
+		result, err = ParseDSAPrivateKey(buf)
 	default:
-		return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
+		err = fmt.Errorf("ssh: unsupported key type %q", block.Type)
+	}
+	// Because of deficiencies in the format, DecryptPEMBlock does not always
+	// detect an incorrect password. In these cases decrypted DER bytes is
+	// random noise. If the parsing of the key returns an asn1.StructuralError
+	// we return x509.IncorrectPasswordError.
+	if _, ok := err.(asn1.StructuralError); ok {
+		return nil, x509.IncorrectPasswordError
 	}
+
+	return result, err
 }
 
 // ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
@@ -1277,28 +1351,106 @@ func passphraseProtectedOpenSSHKey(passphrase []byte) openSSHDecryptFunc {
 	}
 }
 
+func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) {
+	key := generateOpenSSHPadding(privKeyBlock, 8)
+	return key, "none", "none", "", nil
+}
+
+func passphraseProtectedOpenSSHMarshaler(passphrase []byte) openSSHEncryptFunc {
+	return func(privKeyBlock []byte) ([]byte, string, string, string, error) {
+		salt := make([]byte, 16)
+		if _, err := rand.Read(salt); err != nil {
+			return nil, "", "", "", err
+		}
+
+		opts := struct {
+			Salt   []byte
+			Rounds uint32
+		}{salt, 16}
+
+		// Derive key to encrypt the private key block.
+		k, err := bcrypt_pbkdf.Key(passphrase, salt, int(opts.Rounds), 32+aes.BlockSize)
+		if err != nil {
+			return nil, "", "", "", err
+		}
+
+		// Add padding matching the block size of AES.
+		keyBlock := generateOpenSSHPadding(privKeyBlock, aes.BlockSize)
+
+		// Encrypt the private key using the derived secret.
+
+		dst := make([]byte, len(keyBlock))
+		key, iv := k[:32], k[32:]
+		block, err := aes.NewCipher(key)
+		if err != nil {
+			return nil, "", "", "", err
+		}
+
+		stream := cipher.NewCTR(block, iv)
+		stream.XORKeyStream(dst, keyBlock)
+
+		return dst, "aes256-ctr", "bcrypt", string(Marshal(opts)), nil
+	}
+}
+
+const privateKeyAuthMagic = "openssh-key-v1\x00"
+
 type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
+type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error)
+
+type openSSHEncryptedPrivateKey struct {
+	CipherName   string
+	KdfName      string
+	KdfOpts      string
+	NumKeys      uint32
+	PubKey       []byte
+	PrivKeyBlock []byte
+}
+
+type openSSHPrivateKey struct {
+	Check1  uint32
+	Check2  uint32
+	Keytype string
+	Rest    []byte `ssh:"rest"`
+}
+
+type openSSHRSAPrivateKey struct {
+	N       *big.Int
+	E       *big.Int
+	D       *big.Int
+	Iqmp    *big.Int
+	P       *big.Int
+	Q       *big.Int
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
+
+type openSSHEd25519PrivateKey struct {
+	Pub     []byte
+	Priv    []byte
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
+
+type openSSHECDSAPrivateKey struct {
+	Curve   string
+	Pub     []byte
+	D       *big.Int
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
 
 // parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
 // function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
 // as the decrypt function to parse an unencrypted private key. See
 // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
 func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
-	const magic = "openssh-key-v1\x00"
-	if len(key) < len(magic) || string(key[:len(magic)]) != magic {
+	if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic {
 		return nil, errors.New("ssh: invalid openssh private key format")
 	}
-	remaining := key[len(magic):]
-
-	var w struct {
-		CipherName   string
-		KdfName      string
-		KdfOpts      string
-		NumKeys      uint32
-		PubKey       []byte
-		PrivKeyBlock []byte
-	}
+	remaining := key[len(privateKeyAuthMagic):]
 
+	var w openSSHEncryptedPrivateKey
 	if err := Unmarshal(remaining, &w); err != nil {
 		return nil, err
 	}
@@ -1320,13 +1472,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
 		return nil, err
 	}
 
-	pk1 := struct {
-		Check1  uint32
-		Check2  uint32
-		Keytype string
-		Rest    []byte `ssh:"rest"`
-	}{}
-
+	var pk1 openSSHPrivateKey
 	if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 {
 		if w.CipherName != "none" {
 			return nil, x509.IncorrectPasswordError
@@ -1336,18 +1482,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
 
 	switch pk1.Keytype {
 	case KeyAlgoRSA:
-		// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
-		key := struct {
-			N       *big.Int
-			E       *big.Int
-			D       *big.Int
-			Iqmp    *big.Int
-			P       *big.Int
-			Q       *big.Int
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHRSAPrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1373,13 +1508,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
 
 		return pk, nil
 	case KeyAlgoED25519:
-		key := struct {
-			Pub     []byte
-			Priv    []byte
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHEd25519PrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1396,14 +1525,7 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
 		copy(pk, key.Priv)
 		return &pk, nil
 	case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
-		key := struct {
-			Curve   string
-			Pub     []byte
-			D       *big.Int
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHECDSAPrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1451,6 +1573,131 @@ func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.Priv
 	}
 }
 
+func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) {
+	var w openSSHEncryptedPrivateKey
+	var pk1 openSSHPrivateKey
+
+	// Random check bytes.
+	var check uint32
+	if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil {
+		return nil, err
+	}
+
+	pk1.Check1 = check
+	pk1.Check2 = check
+	w.NumKeys = 1
+
+	// Use a []byte directly on ed25519 keys.
+	if k, ok := key.(*ed25519.PrivateKey); ok {
+		key = *k
+	}
+
+	switch k := key.(type) {
+	case *rsa.PrivateKey:
+		E := new(big.Int).SetInt64(int64(k.PublicKey.E))
+		// Marshal public key:
+		// E and N are in reversed order in the public and private key.
+		pubKey := struct {
+			KeyType string
+			E       *big.Int
+			N       *big.Int
+		}{
+			KeyAlgoRSA,
+			E, k.PublicKey.N,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHRSAPrivateKey{
+			N:       k.PublicKey.N,
+			E:       E,
+			D:       k.D,
+			Iqmp:    k.Precomputed.Qinv,
+			P:       k.Primes[0],
+			Q:       k.Primes[1],
+			Comment: comment,
+		}
+		pk1.Keytype = KeyAlgoRSA
+		pk1.Rest = Marshal(key)
+	case ed25519.PrivateKey:
+		pub := make([]byte, ed25519.PublicKeySize)
+		priv := make([]byte, ed25519.PrivateKeySize)
+		copy(pub, k[32:])
+		copy(priv, k)
+
+		// Marshal public key.
+		pubKey := struct {
+			KeyType string
+			Pub     []byte
+		}{
+			KeyAlgoED25519, pub,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHEd25519PrivateKey{
+			Pub:     pub,
+			Priv:    priv,
+			Comment: comment,
+		}
+		pk1.Keytype = KeyAlgoED25519
+		pk1.Rest = Marshal(key)
+	case *ecdsa.PrivateKey:
+		var curve, keyType string
+		switch name := k.Curve.Params().Name; name {
+		case "P-256":
+			curve = "nistp256"
+			keyType = KeyAlgoECDSA256
+		case "P-384":
+			curve = "nistp384"
+			keyType = KeyAlgoECDSA384
+		case "P-521":
+			curve = "nistp521"
+			keyType = KeyAlgoECDSA521
+		default:
+			return nil, errors.New("ssh: unhandled elliptic curve " + name)
+		}
+
+		pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y)
+
+		// Marshal public key.
+		pubKey := struct {
+			KeyType string
+			Curve   string
+			Pub     []byte
+		}{
+			keyType, curve, pub,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHECDSAPrivateKey{
+			Curve:   curve,
+			Pub:     pub,
+			D:       k.D,
+			Comment: comment,
+		}
+		pk1.Keytype = keyType
+		pk1.Rest = Marshal(key)
+	default:
+		return nil, fmt.Errorf("ssh: unsupported key type %T", k)
+	}
+
+	var err error
+	// Add padding and encrypt the key if necessary.
+	w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1))
+	if err != nil {
+		return nil, err
+	}
+
+	b := Marshal(w)
+	block := &pem.Block{
+		Type:  "OPENSSH PRIVATE KEY",
+		Bytes: append([]byte(privateKeyAuthMagic), b...),
+	}
+	return block, nil
+}
+
 func checkOpenSSHKeyPadding(pad []byte) error {
 	for i, b := range pad {
 		if int(b) != i+1 {
@@ -1460,6 +1707,13 @@ func checkOpenSSHKeyPadding(pad []byte) error {
 	return nil
 }
 
+func generateOpenSSHPadding(block []byte, blockSize int) []byte {
+	for i, l := 0, len(block); (l+i)%blockSize != 0; i++ {
+		block = append(block, byte(i+1))
+	}
+	return block
+}
+
 // FingerprintLegacyMD5 returns the user presentation of the key's
 // fingerprint as described by RFC 4716 section 4.
 func FingerprintLegacyMD5(pubKey PublicKey) string {

+ 115 - 5
psiphon/common/crypto/ssh/keys_test.go

@@ -8,6 +8,7 @@ import (
 	"bytes"
 	"crypto/dsa"
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rsa"
@@ -15,6 +16,7 @@ import (
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/pem"
+	"errors"
 	"fmt"
 	"io"
 	"reflect"
@@ -22,7 +24,6 @@ import (
 	"testing"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
-	"golang.org/x/crypto/ed25519"
 )
 
 func rawKey(pub PublicKey) interface{} {
@@ -111,9 +112,9 @@ func TestKeySignVerify(t *testing.T) {
 }
 
 func TestKeySignWithAlgorithmVerify(t *testing.T) {
-	for _, priv := range testSigners {
-		if algorithmSigner, ok := priv.(AlgorithmSigner); !ok {
-			t.Errorf("Signers constructed by ssh package should always implement the AlgorithmSigner interface: %T", priv)
+	for k, priv := range testSigners {
+		if algorithmSigner, ok := priv.(MultiAlgorithmSigner); !ok {
+			t.Errorf("Signers %q constructed by ssh package should always implement the MultiAlgorithmSigner interface: %T", k, priv)
 		} else {
 			pub := priv.PublicKey()
 			data := []byte("sign me")
@@ -145,7 +146,7 @@ func TestKeySignWithAlgorithmVerify(t *testing.T) {
 
 			// RSA keys are the only ones which currently support more than one signing algorithm
 			if pub.Type() == KeyAlgoRSA {
-				for _, algorithm := range []string{SigAlgoRSA, SigAlgoRSASHA2256, SigAlgoRSASHA2512} {
+				for _, algorithm := range []string{KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512} {
 					signWithAlgTestCase(algorithm, algorithm)
 				}
 			}
@@ -221,6 +222,16 @@ func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
 	}
 }
 
+func TestParseEncryptedPrivateKeysWithIncorrectPassphrase(t *testing.T) {
+	pem := testdata.PEMEncryptedKeys[0].PEMBytes
+	for i := 0; i < 4096; i++ {
+		_, err := ParseRawPrivateKeyWithPassphrase(pem, []byte(fmt.Sprintf("%d", i)))
+		if !errors.Is(err, x509.IncorrectPasswordError) {
+			t.Fatalf("expected error: %v, got: %v", x509.IncorrectPasswordError, err)
+		}
+	}
+}
+
 func TestParseDSA(t *testing.T) {
 	// We actually exercise the ParsePrivateKey codepath here, as opposed to
 	// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
@@ -281,6 +292,74 @@ func TestMarshalParsePublicKey(t *testing.T) {
 	}
 }
 
+func TestMarshalPrivateKey(t *testing.T) {
+	tests := []struct {
+		name string
+	}{
+		{"rsa-openssh-format"},
+		{"ed25519"},
+		{"p256-openssh-format"},
+		{"p384-openssh-format"},
+		{"p521-openssh-format"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			expected, ok := testPrivateKeys[tt.name]
+			if !ok {
+				t.Fatalf("cannot find key %s", tt.name)
+			}
+
+			block, err := MarshalPrivateKey(expected, "test@golang.org")
+			if err != nil {
+				t.Fatalf("cannot marshal %s: %v", tt.name, err)
+			}
+
+			key, err := ParseRawPrivateKey(pem.EncodeToMemory(block))
+			if err != nil {
+				t.Fatalf("cannot parse %s: %v", tt.name, err)
+			}
+
+			if !reflect.DeepEqual(expected, key) {
+				t.Errorf("unexpected marshaled key %s", tt.name)
+			}
+		})
+	}
+}
+
+func TestMarshalPrivateKeyWithPassphrase(t *testing.T) {
+	tests := []struct {
+		name string
+	}{
+		{"rsa-openssh-format"},
+		{"ed25519"},
+		{"p256-openssh-format"},
+		{"p384-openssh-format"},
+		{"p521-openssh-format"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			expected, ok := testPrivateKeys[tt.name]
+			if !ok {
+				t.Fatalf("cannot find key %s", tt.name)
+			}
+
+			block, err := MarshalPrivateKeyWithPassphrase(expected, "test@golang.org", []byte("test-passphrase"))
+			if err != nil {
+				t.Fatalf("cannot marshal %s: %v", tt.name, err)
+			}
+
+			key, err := ParseRawPrivateKeyWithPassphrase(pem.EncodeToMemory(block), []byte("test-passphrase"))
+			if err != nil {
+				t.Fatalf("cannot parse %s: %v", tt.name, err)
+			}
+
+			if !reflect.DeepEqual(expected, key) {
+				t.Errorf("unexpected marshaled key %s", tt.name)
+			}
+		})
+	}
+}
+
 type testAuthResult struct {
 	pubKey   PublicKey
 	options  []string
@@ -616,3 +695,34 @@ func TestSKKeys(t *testing.T) {
 		}
 	}
 }
+
+func TestNewSignerWithAlgos(t *testing.T) {
+	algorithSigner, ok := testSigners["rsa"].(AlgorithmSigner)
+	if !ok {
+		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
+	}
+	_, err := NewSignerWithAlgorithms(algorithSigner, nil)
+	if err == nil {
+		t.Error("signer with algos created with no algorithms")
+	}
+
+	_, err = NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoED25519})
+	if err == nil {
+		t.Error("signer with algos created with invalid algorithms")
+	}
+
+	_, err = NewSignerWithAlgorithms(algorithSigner, []string{CertAlgoRSASHA256v01})
+	if err == nil {
+		t.Error("signer with algos created with certificate algorithms")
+	}
+
+	mas, err := NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
+	if err != nil {
+		t.Errorf("unable to create signer with valid algorithms: %v", err)
+	}
+
+	_, err = NewSignerWithAlgorithms(mas, []string{KeyAlgoRSA})
+	if err == nil {
+		t.Error("signer with algos created with restricted algorithms")
+	}
+}

+ 1 - 1
psiphon/common/crypto/ssh/knownhosts/knownhosts.go

@@ -142,7 +142,7 @@ func keyEq(a, b ssh.PublicKey) bool {
 	return bytes.Equal(a.Marshal(), b.Marshal())
 }
 
-// IsAuthorityForHost can be used as a callback in ssh.CertChecker
+// IsHostAuthority can be used as a callback in ssh.CertChecker
 func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
 	h, p, err := net.SplitHostPort(address)
 	if err != nil {

+ 7 - 0
psiphon/common/crypto/ssh/mac.go

@@ -10,6 +10,7 @@ import (
 	"crypto/hmac"
 	"crypto/sha1"
 	"crypto/sha256"
+	"crypto/sha512"
 	"hash"
 )
 
@@ -46,9 +47,15 @@ func (t truncatingMAC) Size() int {
 func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
 
 var macModes = map[string]*macMode{
+	"hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
+		return hmac.New(sha512.New, key)
+	}},
 	"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
 		return hmac.New(sha256.New, key)
 	}},
+	"hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
+		return hmac.New(sha512.New, key)
+	}},
 	"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
 		return hmac.New(sha256.New, key)
 	}},

+ 17 - 3
psiphon/common/crypto/ssh/mempipe_test.go

@@ -13,9 +13,10 @@ import (
 // An in-memory packetConn. It is safe to call Close and writePacket
 // from different goroutines.
 type memTransport struct {
-	eof     bool
-	pending [][]byte
-	write   *memTransport
+	eof        bool
+	pending    [][]byte
+	write      *memTransport
+	writeCount uint64
 	sync.Mutex
 	*sync.Cond
 }
@@ -63,9 +64,16 @@ func (t *memTransport) writePacket(p []byte) error {
 	copy(c, p)
 	t.write.pending = append(t.write.pending, c)
 	t.write.Cond.Signal()
+	t.writeCount++
 	return nil
 }
 
+func (t *memTransport) getWriteCount() uint64 {
+	t.write.Lock()
+	defer t.write.Unlock()
+	return t.writeCount
+}
+
 func memPipe() (a, b packetConn) {
 	t1 := memTransport{}
 	t2 := memTransport{}
@@ -81,6 +89,9 @@ func TestMemPipe(t *testing.T) {
 	if err := a.writePacket([]byte{42}); err != nil {
 		t.Fatalf("writePacket: %v", err)
 	}
+	if wc := a.(*memTransport).getWriteCount(); wc != 1 {
+		t.Fatalf("got %v, want 1", wc)
+	}
 	if err := a.Close(); err != nil {
 		t.Fatal("Close: ", err)
 	}
@@ -95,6 +106,9 @@ func TestMemPipe(t *testing.T) {
 	if err != io.EOF {
 		t.Fatalf("got %v, %v, want EOF", p, err)
 	}
+	if wc := b.(*memTransport).getWriteCount(); wc != 0 {
+		t.Fatalf("got %v, want 0", wc)
+	}
 }
 
 func TestDoubleClose(t *testing.T) {

+ 31 - 6
psiphon/common/crypto/ssh/messages.go

@@ -68,7 +68,7 @@ type kexInitMsg struct {
 
 // See RFC 4253, section 8.
 
-// Diffie-Helman
+// Diffie-Hellman
 const msgKexDHInit = 30
 
 type kexDHInitMsg struct {
@@ -141,6 +141,14 @@ type serviceAcceptMsg struct {
 	Service string `sshtype:"6"`
 }
 
+// See RFC 8308, section 2.3
+const msgExtInfo = 7
+
+type extInfoMsg struct {
+	NumExtensions uint32 `sshtype:"7"`
+	Payload       []byte `ssh:"rest"`
+}
+
 // See RFC 4252, section 5.
 const msgUserAuthRequest = 50
 
@@ -180,11 +188,11 @@ const msgUserAuthInfoRequest = 60
 const msgUserAuthInfoResponse = 61
 
 type userAuthInfoRequestMsg struct {
-	User               string `sshtype:"60"`
-	Instruction        string
-	DeprecatedLanguage string
-	NumPrompts         uint32
-	Prompts            []byte `ssh:"rest"`
+	Name        string `sshtype:"60"`
+	Instruction string
+	Language    string
+	NumPrompts  uint32
+	Prompts     []byte `ssh:"rest"`
 }
 
 // See RFC 4254, section 5.1.
@@ -341,6 +349,20 @@ type userAuthGSSAPIError struct {
 	LanguageTag string
 }
 
+// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
+const msgPing = 192
+
+type pingMsg struct {
+	Data string `sshtype:"192"`
+}
+
+// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
+const msgPong = 193
+
+type pongMsg struct {
+	Data string `sshtype:"193"`
+}
+
 // typeTags returns the possible type bytes for the given reflect.Type, which
 // should be a struct. The possible values are separated by a '|' character.
 func typeTags(structType reflect.Type) (tags []byte) {
@@ -782,6 +804,8 @@ func decode(packet []byte) (interface{}, error) {
 		msg = new(serviceRequestMsg)
 	case msgServiceAccept:
 		msg = new(serviceAcceptMsg)
+	case msgExtInfo:
+		msg = new(extInfoMsg)
 	case msgKexInit:
 		msg = new(kexInitMsg)
 	case msgKexDHInit:
@@ -843,6 +867,7 @@ var packetTypeNames = map[byte]string{
 	msgDisconnect:          "disconnectMsg",
 	msgServiceRequest:      "serviceRequestMsg",
 	msgServiceAccept:       "serviceAcceptMsg",
+	msgExtInfo:             "extInfoMsg",
 	msgKexInit:             "kexInitMsg",
 	msgKexDHInit:           "kexDHInitMsg",
 	msgKexDHReply:          "kexDHReplyMsg",

+ 6 - 0
psiphon/common/crypto/ssh/mux.go

@@ -231,6 +231,12 @@ func (m *mux) onePacket() error {
 		return m.handleChannelOpen(packet)
 	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
 		return m.handleGlobalPacket(packet)
+	case msgPing:
+		var msg pingMsg
+		if err := Unmarshal(packet, &msg); err != nil {
+			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
+		}
+		return m.sendMessage(pongMsg(msg))
 	}
 
 	// assume a channel packet.

+ 201 - 78
psiphon/common/crypto/ssh/mux_test.go

@@ -5,15 +5,14 @@
 package ssh
 
 import (
+	"errors"
+	"fmt"
 	"io"
-	"io/ioutil"
 	"sync"
 	"testing"
-	"time"
 )
 
-// PSIPHON
-// =======
+// [Psiphon]
 // See comment in channel.go
 var testChannelWindowSize = getChannelWindowSize("")
 
@@ -35,14 +34,21 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) {
 	go func() {
 		newCh, ok := <-s.incomingChannels
 		if !ok {
-			t.Fatalf("No incoming channel")
+			t.Error("no incoming channel")
+			close(res)
+			return
 		}
 		if newCh.ChannelType() != "chan" {
-			t.Fatalf("got type %q want chan", newCh.ChannelType())
+			t.Errorf("got type %q want chan", newCh.ChannelType())
+			newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
+			close(res)
+			return
 		}
 		ch, _, err := newCh.Accept()
 		if err != nil {
-			t.Fatalf("Accept %v", err)
+			t.Errorf("accept: %v", err)
+			close(res)
+			return
 		}
 		res <- ch.(*channel)
 	}()
@@ -51,8 +57,12 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) {
 	if err != nil {
 		t.Fatalf("OpenChannel: %v", err)
 	}
+	w := <-res
+	if w == nil {
+		t.Fatal("unable to get write channel")
+	}
 
-	return <-res, ch, c
+	return w, ch, c
 }
 
 // Test that stderr and stdout can be addressed from different
@@ -78,16 +88,16 @@ func TestMuxChannelExtendedThreadSafety(t *testing.T) {
 
 	rd.Add(2)
 	go func() {
-		c, err := ioutil.ReadAll(reader)
+		c, err := io.ReadAll(reader)
 		if string(c) != magic {
-			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
+			t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
 		}
 		rd.Done()
 	}()
 	go func() {
-		c, err := ioutil.ReadAll(reader.Stderr())
+		c, err := io.ReadAll(reader.Stderr())
 		if string(c) != magic {
-			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
+			t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
 		}
 		rd.Done()
 	}()
@@ -105,14 +115,20 @@ func TestMuxReadWrite(t *testing.T) {
 
 	magic := "hello world"
 	magicExt := "hello stderr"
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		_, err := s.Write([]byte(magic))
 		if err != nil {
-			t.Fatalf("Write: %v", err)
+			t.Errorf("Write: %v", err)
+			return
 		}
 		_, err = s.Extended(1).Write([]byte(magicExt))
 		if err != nil {
-			t.Fatalf("Write: %v", err)
+			t.Errorf("Write: %v", err)
+			return
 		}
 	}()
 
@@ -143,13 +159,15 @@ func TestMuxChannelOverflow(t *testing.T) {
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		writer.Write(make([]byte, 1))
-		wDone <- 1
 	}()
 	writer.remoteWin.waitWriterBlocked()
 
@@ -166,7 +184,40 @@ func TestMuxChannelOverflow(t *testing.T) {
 	if _, err := reader.SendRequest("hello", true, nil); err == nil {
 		t.Errorf("SendRequest succeeded.")
 	}
-	<-wDone
+}
+
+func TestMuxChannelReadUnblock(t *testing.T) {
+	reader, writer, mux := channelPair(t)
+	defer reader.Close()
+	defer writer.Close()
+	defer mux.Close()
+
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
+			t.Errorf("could not fill window: %v", err)
+		}
+		if _, err := writer.Write(make([]byte, 1)); err != nil {
+			t.Errorf("Write: %v", err)
+		}
+		writer.Close()
+	}()
+
+	writer.remoteWin.waitWriterBlocked()
+
+	buf := make([]byte, 32768)
+	for {
+		_, err := reader.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Read: %v", err)
+		}
+	}
 }
 
 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
@@ -175,20 +226,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) {
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
 			t.Errorf("got %v, want EOF for unblock write", err)
 		}
-		wDone <- 1
 	}()
 
 	writer.remoteWin.waitWriterBlocked()
 	reader.Close()
-	<-wDone
 }
 
 func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
@@ -197,20 +249,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
 			t.Errorf("got %v, want EOF for unblock write", err)
 		}
-		wDone <- 1
 	}()
 
 	writer.remoteWin.waitWriterBlocked()
 	mux.Close()
-	<-wDone
 }
 
 func TestMuxReject(t *testing.T) {
@@ -218,13 +271,21 @@ func TestMuxReject(t *testing.T) {
 	defer server.Close()
 	defer client.Close()
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
+
 		ch, ok := <-server.incomingChannels
 		if !ok {
-			t.Fatalf("Accept")
+			t.Error("cannot accept channel")
+			return
 		}
 		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
-			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+			t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+			ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
+			return
 		}
 		ch.Reject(RejectionReason(42), "message")
 	}()
@@ -255,6 +316,7 @@ func TestMuxChannelRequest(t *testing.T) {
 
 	var received int
 	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
 	wg.Add(1)
 	go func() {
 		for r := range server.incomingRequests {
@@ -283,7 +345,6 @@ func TestMuxChannelRequest(t *testing.T) {
 	}
 	if ok {
 		t.Errorf("SendRequest(no): %v", ok)
-
 	}
 
 	client.Close()
@@ -300,7 +361,7 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
 	defer serverPipe.Close()
 	defer client.Close()
 
-	kDone := make(chan struct{})
+	kDone := make(chan error, 1)
 	go func() {
 		// Ignore unknown channel messages that don't want a reply.
 		err := serverPipe.writePacket(Marshal(channelRequestMsg{
@@ -310,7 +371,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Send a keepalive, which should get a channel failure message
@@ -322,44 +384,53 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		packet, err := serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		decoded, err := decode(packet)
 		if err != nil {
-			t.Fatalf("decode failed: %v", err)
+			kDone <- fmt.Errorf("decode failed: %w", err)
+			return
 		}
 
 		switch msg := decoded.(type) {
 		case *channelRequestFailureMsg:
 			if msg.PeersID != 2 {
-				t.Fatalf("received response to wrong message: %v", msg)
+				kDone <- fmt.Errorf("received response to wrong message: %v", msg)
+				return
+
 			}
 		default:
-			t.Fatalf("unexpected channel message: %v", msg)
+			kDone <- fmt.Errorf("unexpected channel message: %v", msg)
+			return
 		}
 
-		kDone <- struct{}{}
+		kDone <- nil
 
 		// Receive and respond to the keepalive to confirm the mux is
 		// still processing requests.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgGlobalRequest {
-			t.Fatalf("expected global request")
+			kDone <- errors.New("expected global request")
+			return
 		}
 
 		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
 			Data: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("failed to send failure msg: %v", err)
+			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+			return
 		}
 
 		close(kDone)
@@ -367,10 +438,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
 
 	// Wait for the server to send the keepalive message and receive back a
 	// response.
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never received ack")
+	if err := <-kDone; err != nil {
+		t.Fatal(err)
 	}
 
 	// Confirm client hasn't closed.
@@ -378,10 +447,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
 		t.Fatalf("failed to send keepalive: %v", err)
 	}
 
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never shut down")
+	// Wait for the server to shut down.
+	if err := <-kDone; err != nil {
+		t.Fatal(err)
 	}
 }
 
@@ -391,20 +459,23 @@ func TestMuxClosedChannel(t *testing.T) {
 	defer serverPipe.Close()
 	defer client.Close()
 
-	kDone := make(chan struct{})
+	kDone := make(chan error, 1)
 	go func() {
 		// Open the channel.
 		packet, err := serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelOpen {
-			t.Fatalf("expected chan open")
+			kDone <- errors.New("expected chan open")
+			return
 		}
 
 		var openMsg channelOpenMsg
 		if err := Unmarshal(packet, &openMsg); err != nil {
-			t.Fatalf("unmarshal: %v", err)
+			kDone <- fmt.Errorf("unmarshal: %w", err)
+			return
 		}
 
 		// Send back the opened channel confirmation.
@@ -415,7 +486,8 @@ func TestMuxClosedChannel(t *testing.T) {
 			MaxPacketSize: channelMaxPacket,
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Close the channel.
@@ -423,7 +495,8 @@ func TestMuxClosedChannel(t *testing.T) {
 			PeersID: openMsg.PeersID,
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Send a keepalive message on the channel we just closed.
@@ -434,43 +507,51 @@ func TestMuxClosedChannel(t *testing.T) {
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Receive the channel closed response.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelClose {
-			t.Fatalf("expected channel close")
+			kDone <- errors.New("expected channel close")
+			return
 		}
 
 		// Receive the keepalive response failure.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelFailure {
-			t.Fatalf("expected channel close")
+			kDone <- errors.New("expected channel failure")
+			return
 		}
-		kDone <- struct{}{}
+		kDone <- nil
 
 		// Receive and respond to the keepalive to confirm the mux is
 		// still processing requests.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgGlobalRequest {
-			t.Fatalf("expected global request")
+			kDone <- errors.New("expected global request")
+			return
 		}
 
 		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
 			Data: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("failed to send failure msg: %v", err)
+			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+			return
 		}
 
 		close(kDone)
@@ -484,11 +565,7 @@ func TestMuxClosedChannel(t *testing.T) {
 	defer ch.Close()
 
 	// Wait for the server to close the channel and send the keepalive.
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never received ack")
-	}
+	<-kDone
 
 	// Make sure the channel closed.
 	if _, ok := <-ch.incomingRequests; ok {
@@ -500,22 +577,29 @@ func TestMuxClosedChannel(t *testing.T) {
 		t.Fatalf("failed to send keepalive: %v", err)
 	}
 
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never shut down")
-	}
+	// Wait for the server to shut down.
+	<-kDone
 }
 
 func TestMuxGlobalRequest(t *testing.T) {
+	var sawPeek bool
+	var wg sync.WaitGroup
+	defer func() {
+		wg.Wait()
+		if !sawPeek {
+			t.Errorf("never saw 'peek' request")
+		}
+	}()
+
 	clientMux, serverMux := muxPair()
 	defer serverMux.Close()
 	defer clientMux.Close()
 
-	var seen bool
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		for r := range serverMux.incomingRequests {
-			seen = seen || r.Type == "peek"
+			sawPeek = sawPeek || r.Type == "peek"
 			if r.WantReply {
 				err := r.Reply(r.Type == "yes",
 					append([]byte(r.Type), r.Payload...))
@@ -545,10 +629,6 @@ func TestMuxGlobalRequest(t *testing.T) {
 		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
 			ok, data, err)
 	}
-
-	if !seen {
-		t.Errorf("never saw 'peek' request")
-	}
 }
 
 func TestMuxGlobalRequestUnblock(t *testing.T) {
@@ -675,7 +755,7 @@ func TestZeroWindowAdjust(t *testing.T) {
 	}()
 
 	want := "helloworld"
-	c, _ := ioutil.ReadAll(b)
+	c, _ := io.ReadAll(b)
 	if string(c) != want {
 		t.Errorf("got %q want %q", c, want)
 	}
@@ -698,7 +778,13 @@ func TestMuxMaxPacketSize(t *testing.T) {
 		t.Errorf("could not send packet")
 	}
 
-	go a.SendRequest("hello", false, nil)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
+	go func() {
+		a.SendRequest("hello", false, nil)
+		wg.Done()
+	}()
 
 	_, ok := <-b.incomingRequests
 	if ok {
@@ -706,6 +792,43 @@ func TestMuxMaxPacketSize(t *testing.T) {
 	}
 }
 
+func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
+	s, c, mux := channelPair(t)
+	cTransport := mux.conn.(*memTransport)
+	defer s.Close()
+	defer c.Close()
+	defer mux.Close()
+
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+
+	data := make([]byte, 1024)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		_, err := s.Write(data)
+		if err != nil {
+			t.Errorf("Write: %v", err)
+			return
+		}
+	}()
+	cWritesInit := cTransport.getWriteCount()
+	buf := make([]byte, 1)
+	for i := 0; i < len(data); i++ {
+		n, err := c.Read(buf)
+		if n != len(buf) || err != nil {
+			t.Fatalf("Read: %v, %v", n, err)
+		}
+	}
+	cWrites := cTransport.getWriteCount() - cWritesInit
+	// reading 1 KiB should not cause any window updates to be sent, but allow
+	// for some unexpected writes
+	if cWrites > 30 {
+		t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
+	}
+}
+
 // Don't ship code with debug=true.
 func TestDebug(t *testing.T) {
 	if debugMux {

+ 1 - 1
psiphon/common/crypto/ssh/randomized_kex_test.go

@@ -27,8 +27,8 @@ import (
 	"net"
 	"testing"
 
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"golang.org/x/sync/errgroup"
 )
 

+ 102 - 20
psiphon/common/crypto/ssh/server.go

@@ -64,12 +64,27 @@ type ServerConfig struct {
 	// Config contains configuration shared between client and server.
 	Config
 
+	// PublicKeyAuthAlgorithms specifies the supported client public key
+	// authentication algorithms. Note that this should not include certificate
+	// types since those use the underlying algorithm. This list is sent to the
+	// client if it supports the server-sig-algs extension. Order is irrelevant.
+	// If unspecified then a default set of algorithms is used.
+	PublicKeyAuthAlgorithms []string
+
 	hostKeys []Signer
 
 	// NoClientAuth is true if clients are allowed to connect without
 	// authenticating.
+	// To determine NoClientAuth at runtime, set NoClientAuth to true
+	// and the optional NoClientAuthCallback to a non-nil value.
 	NoClientAuth bool
 
+	// NoClientAuthCallback, if non-nil, is called when a user
+	// attempts to authenticate with auth method "none".
+	// NoClientAuth must also be set to true for this be used, or
+	// this func is unused.
+	NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
+
 	// MaxAuthTries specifies the maximum number of authentication attempts
 	// permitted per connection. If set to a negative number, the number of
 	// attempts are unlimited. If set to zero, the number of attempts are limited
@@ -120,7 +135,7 @@ type ServerConfig struct {
 }
 
 // AddHostKey adds a private key as a host key. If an existing host
-// key exists with the same algorithm, it is overwritten. Each server
+// key exists with the same public key format, it is replaced. Each server
 // config must have at least one host key.
 func (s *ServerConfig) AddHostKey(key Signer) {
 	for i, k := range s.hostKeys {
@@ -193,9 +208,20 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
 	if fullConf.MaxAuthTries == 0 {
 		fullConf.MaxAuthTries = 6
 	}
+	if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
+		fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
+	} else {
+		for _, algo := range fullConf.PublicKeyAuthAlgorithms {
+			if !contains(supportedPubKeyAuthAlgos, algo) {
+				c.Close()
+				return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
+			}
+		}
+	}
 	// Check if the config contains any unsupported key exchanges
 	for _, kex := range fullConf.KeyExchanges {
 		if _, ok := serverForbiddenKexAlgos[kex]; ok {
+			c.Close()
 			return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
 		}
 	}
@@ -212,9 +238,10 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
 }
 
 // signAndMarshal signs the data with the appropriate algorithm,
-// and serializes the result in SSH wire format.
-func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
-	sig, err := k.Sign(rand, data)
+// and serializes the result in SSH wire format. algo is the negotiate
+// algorithm and may be a certificate type.
+func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) {
+	sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
 	if err != nil {
 		return nil, err
 	}
@@ -282,15 +309,6 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
 	return perms, err
 }
 
-func isAcceptableAlgo(algo string) bool {
-	switch algo {
-	case SigAlgoRSA, SigAlgoRSASHA2256, SigAlgoRSASHA2512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519,
-		CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
-		return true
-	}
-	return false
-}
-
 func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
 	if addr == nil {
 		return errors.New("ssh: no address known for client, but source-address match required")
@@ -321,7 +339,7 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
 	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
 }
 
-func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *connection,
+func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
 	sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
 	gssAPIServer := gssapiConfig.Server
 	defer gssAPIServer.DeleteSecContext()
@@ -331,7 +349,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
 			outToken     []byte
 			needContinue bool
 		)
-		outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(firstToken)
+		outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
 		if err != nil {
 			return err, nil, nil
 		}
@@ -353,6 +371,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
 		if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
 			return nil, nil, err
 		}
+		token = userAuthGSSAPITokenReq.Token
 	}
 	packet, err := s.transport.readPacket()
 	if err != nil {
@@ -370,6 +389,25 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
 	return authErr, perms, nil
 }
 
+// isAlgoCompatible checks if the signature format is compatible with the
+// selected algorithm taking into account edge cases that occur with old
+// clients.
+func isAlgoCompatible(algo, sigFormat string) bool {
+	// Compatibility for old clients.
+	//
+	// For certificate authentication with OpenSSH 7.2-7.7 signature format can
+	// be rsa-sha2-256 or rsa-sha2-512 for the algorithm
+	// ssh-rsa-cert-v01@openssh.com.
+	//
+	// With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512
+	// for signature format ssh-rsa.
+	if isRSA(algo) && isRSA(sigFormat) {
+		return true
+	}
+	// Standard case: the underlying algorithm must match the signature format.
+	return underlyingAlgo(algo) == sigFormat
+}
+
 // ServerAuthError represents server authentication errors and is
 // sometimes returned by NewServerConn. It appends any authentication
 // errors that may occur, and is returned if all of the authentication
@@ -454,7 +492,11 @@ userAuthLoop:
 		switch userAuthReq.Method {
 		case "none":
 			if config.NoClientAuth {
-				authErr = nil
+				if config.NoClientAuthCallback != nil {
+					perms, authErr = config.NoClientAuthCallback(s)
+				} else {
+					authErr = nil
+				}
 			}
 
 			// allow initial attempt of 'none' without penalty
@@ -501,7 +543,7 @@ userAuthLoop:
 				return nil, parseError(msgUserAuthRequest)
 			}
 			algo := string(algoBytes)
-			if !isAcceptableAlgo(algo) {
+			if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
 				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
 				break
 			}
@@ -553,16 +595,31 @@ userAuthLoop:
 				if !ok || len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 				}
+				// Ensure the declared public key algo is compatible with the
+				// decoded one. This check will ensure we don't accept e.g.
+				// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
+				// key type. The algorithm and public key type must be
+				// consistent: both must be certificate algorithms, or neither.
+				if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
+					authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
+						pubKey.Type(), algo)
+					break
+				}
 				// Ensure the public key algo and signature algo
 				// are supported.  Compare the private key
 				// algorithm name that corresponds to algo with
 				// sig.Format.  This is usually the same, but
 				// for certs, the names differ.
-				if !isAcceptableAlgo(sig.Format) {
+				if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
 					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
 					break
 				}
-				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
+				if !isAlgoCompatible(algo, sig.Format) {
+					authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
+					break
+				}
+
+				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
 
 				if err := pubKey.Verify(signedData, sig); err != nil {
 					return nil, err
@@ -633,6 +690,30 @@ userAuthLoop:
 		}
 
 		authFailures++
+		if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
+			// If we have hit the max attempts, don't bother sending the
+			// final SSH_MSG_USERAUTH_FAILURE message, since there are
+			// no more authentication methods which can be attempted,
+			// and this message may cause the client to re-attempt
+			// authentication while we send the disconnect message.
+			// Continue, and trigger the disconnect at the start of
+			// the loop.
+			//
+			// The SSH specification is somewhat confusing about this,
+			// RFC 4252 Section 5.1 requires each authentication failure
+			// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
+			// message, but Section 4 says the server should disconnect
+			// after some number of attempts, but it isn't explicit which
+			// message should take precedence (i.e. should there be a failure
+			// message than a disconnect message, or if we are going to
+			// disconnect, should we only send that message.)
+			//
+			// Either way, OpenSSH disconnects immediately after the last
+			// failed authnetication attempt, and given they are typically
+			// considered the golden implementation it seems reasonable
+			// to match that behavior.
+			continue
+		}
 
 		var failureMsg userAuthFailureMsg
 		if config.PasswordCallback != nil {
@@ -670,7 +751,7 @@ type sshClientKeyboardInteractive struct {
 	*connection
 }
 
-func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
+func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
 	if len(questions) != len(echos) {
 		return nil, errors.New("ssh: echos and questions must have equal length")
 	}
@@ -682,6 +763,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
 	}
 
 	if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
+		Name:        name,
 		Instruction: instruction,
 		NumPrompts:  uint32(len(questions)),
 		Prompts:     prompts,

+ 140 - 0
psiphon/common/crypto/ssh/server_test.go

@@ -0,0 +1,140 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"io"
+	"net"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
+	for _, tt := range []struct {
+		name      string
+		key       Signer
+		wantError bool
+	}{
+		{"rsa", testSigners["rsa"], false},
+		{"dsa", testSigners["dsa"], true},
+		{"ed25519", testSigners["ed25519"], true},
+	} {
+		c1, c2, err := netPipe()
+		if err != nil {
+			t.Fatalf("netPipe: %v", err)
+		}
+		defer c1.Close()
+		defer c2.Close()
+		serverConf := &ServerConfig{
+			PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
+			PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+				return nil, nil
+			},
+		}
+		serverConf.AddHostKey(testSigners["ecdsap256"])
+
+		done := make(chan struct{})
+		go func() {
+			defer close(done)
+			NewServerConn(c1, serverConf)
+		}()
+
+		clientConf := ClientConfig{
+			User: "user",
+			Auth: []AuthMethod{
+				PublicKeys(tt.key),
+			},
+			HostKeyCallback: InsecureIgnoreHostKey(),
+		}
+
+		_, _, _, err = NewClientConn(c2, "", &clientConf)
+		if err != nil {
+			if !tt.wantError {
+				t.Errorf("%s: got unexpected error %q", tt.name, err.Error())
+			}
+		} else if tt.wantError {
+			t.Errorf("%s: succeeded, but want error", tt.name)
+		}
+		<-done
+	}
+}
+
+func TestNewServerConnValidationErrors(t *testing.T) {
+	serverConf := &ServerConfig{
+		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
+	}
+	c := &markerConn{}
+	_, _, _, err := NewServerConn(c, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with invalid public key auth algorithms succeeded")
+	}
+	if !c.isClosed() {
+		t.Fatal("NewServerConn with invalid public key auth algorithms left connection open")
+	}
+	if c.isUsed() {
+		t.Fatal("NewServerConn with invalid public key auth algorithms used connection")
+	}
+
+	serverConf = &ServerConfig{
+		Config: Config{
+			KeyExchanges: []string{kexAlgoDHGEXSHA256},
+		},
+	}
+	c = &markerConn{}
+	_, _, _, err = NewServerConn(c, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with unsupported key exchange succeeded")
+	}
+	if !c.isClosed() {
+		t.Fatal("NewServerConn with unsupported key exchange left connection open")
+	}
+	if c.isUsed() {
+		t.Fatal("NewServerConn with unsupported key exchange used connection")
+	}
+}
+
+type markerConn struct {
+	closed uint32
+	used   uint32
+}
+
+func (c *markerConn) isClosed() bool {
+	return atomic.LoadUint32(&c.closed) != 0
+}
+
+func (c *markerConn) isUsed() bool {
+	return atomic.LoadUint32(&c.used) != 0
+}
+
+func (c *markerConn) Close() error {
+	atomic.StoreUint32(&c.closed, 1)
+	return nil
+}
+
+func (c *markerConn) Read(b []byte) (n int, err error) {
+	atomic.StoreUint32(&c.used, 1)
+	if atomic.LoadUint32(&c.closed) != 0 {
+		return 0, net.ErrClosed
+	} else {
+		return 0, io.EOF
+	}
+}
+
+func (c *markerConn) Write(b []byte) (n int, err error) {
+	atomic.StoreUint32(&c.used, 1)
+	if atomic.LoadUint32(&c.closed) != 0 {
+		return 0, net.ErrClosed
+	} else {
+		return 0, io.ErrClosedPipe
+	}
+}
+
+func (*markerConn) LocalAddr() net.Addr  { return nil }
+func (*markerConn) RemoteAddr() net.Addr { return nil }
+
+func (*markerConn) SetDeadline(t time.Time) error      { return nil }
+func (*markerConn) SetReadDeadline(t time.Time) error  { return nil }
+func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }

+ 4 - 4
psiphon/common/crypto/ssh/session.go

@@ -13,7 +13,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"sync"
 )
 
@@ -85,6 +84,7 @@ const (
 	IXANY         = 39
 	IXOFF         = 40
 	IMAXBEL       = 41
+	IUTF8         = 42 // RFC 8160
 	ISIG          = 50
 	ICANON        = 51
 	XCASE         = 52
@@ -123,7 +123,7 @@ type Session struct {
 	// output and error.
 	//
 	// If either is nil, Run connects the corresponding file
-	// descriptor to an instance of ioutil.Discard. There is a
+	// descriptor to an instance of io.Discard. There is a
 	// fixed amount of buffering that is shared for the two streams.
 	// If either blocks it may eventually cause the remote
 	// command to block.
@@ -505,7 +505,7 @@ func (s *Session) stdout() {
 		return
 	}
 	if s.Stdout == nil {
-		s.Stdout = ioutil.Discard
+		s.Stdout = io.Discard
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
 		_, err := io.Copy(s.Stdout, s.ch)
@@ -518,7 +518,7 @@ func (s *Session) stderr() {
 		return
 	}
 	if s.Stderr == nil {
-		s.Stderr = ioutil.Discard
+		s.Stderr = io.Discard
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
 		_, err := io.Copy(s.Stderr, s.ch.Stderr())

+ 134 - 18
psiphon/common/crypto/ssh/session_test.go

@@ -11,9 +11,9 @@ import (
 	crypto_rand "crypto/rand"
 	"errors"
 	"io"
-	"io/ioutil"
 	"math/rand"
 	"net"
+	"sync"
 	"testing"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/terminal"
@@ -28,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client {
 		t.Fatalf("netPipe: %v", err)
 	}
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
-		defer c1.Close()
+		defer func() {
+			c1.Close()
+			wg.Done()
+		}()
 		conf := ServerConfig{
 			NoClientAuth: true,
 		}
@@ -37,9 +43,14 @@ func dial(handler serverType, t *testing.T) *Client {
 
 		conn, chans, reqs, err := NewServerConn(c1, &conf)
 		if err != nil {
-			t.Fatalf("Unable to handshake: %v", err)
+			t.Errorf("Unable to handshake: %v", err)
+			return
 		}
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 
 		for newCh := range chans {
 			if newCh.ChannelType() != "session" {
@@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client {
 				t.Errorf("Accept: %v", err)
 				continue
 			}
+			wg.Add(1)
 			go func() {
 				handler(ch, inReqs, t)
+				wg.Done()
 			}()
 		}
 		if err := conn.Wait(); err != io.EOF {
@@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) {
 		t.Fatal(err)
 	}
 	defer session.Close()
-	result := make(chan []byte)
 
+	serverStdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("StdinPipe failed: %v", err)
+	}
+
+	result := make(chan []byte)
 	go func() {
 		defer close(result)
 		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) {
 		result <- echoedBuf.Bytes()
 	}()
 
-	serverStdin, err := session.StdinPipe()
-	if err != nil {
-		t.Fatalf("StdinPipe failed: %v", err)
-	}
 	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
 	if err != nil {
 		t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@@ -531,7 +545,7 @@ func sendSignal(signal string, ch Channel, t *testing.T) {
 
 func discardHandler(ch Channel, t *testing.T) {
 	defer ch.Close()
-	io.Copy(ioutil.Discard, ch)
+	io.Copy(io.Discard, ch)
 }
 
 func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
@@ -606,7 +620,7 @@ func TestClientWriteEOF(t *testing.T) {
 	}
 	stdin.Close()
 
-	res, err := ioutil.ReadAll(stdout)
+	res, err := io.ReadAll(stdout)
 	if err != nil {
 		t.Fatalf("Read failed: %v", err)
 	}
@@ -618,7 +632,7 @@ func TestClientWriteEOF(t *testing.T) {
 
 func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	data, err := ioutil.ReadAll(ch)
+	data, err := io.ReadAll(ch)
 	if err != nil {
 		t.Errorf("handler read error: %v", err)
 	}
@@ -648,30 +662,57 @@ func TestSessionID(t *testing.T) {
 		User:            "user",
 	}
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+
+	srvErrCh := make(chan error, 1)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		conn, chans, reqs, err := NewServerConn(c1, serverConf)
+		srvErrCh <- err
 		if err != nil {
-			t.Fatalf("server handshake: %v", err)
+			return
 		}
 		serverID <- conn.SessionID()
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 		for ch := range chans {
 			ch.Reject(Prohibited, "")
 		}
 	}()
 
+	cliErrCh := make(chan error, 1)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
+		cliErrCh <- err
 		if err != nil {
-			t.Fatalf("client handshake: %v", err)
+			return
 		}
 		clientID <- conn.SessionID()
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 		for ch := range chans {
 			ch.Reject(Prohibited, "")
 		}
 	}()
 
+	if err := <-srvErrCh; err != nil {
+		t.Fatalf("server handshake: %v", err)
+	}
+
+	if err := <-cliErrCh; err != nil {
+		t.Fatalf("client handshake: %v", err)
+	}
+
 	s := <-serverID
 	c := <-clientID
 	if bytes.Compare(s, c) != 0 {
@@ -726,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) {
 	serverConf.AddHostKey(testSigners["rsa"])
 	serverConf.AddHostKey(testSigners["ecdsa"])
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
 	connect := func(clientConf *ClientConfig, want string) {
 		var alg string
 		clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@@ -739,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
 		defer c1.Close()
 		defer c2.Close()
 
-		go NewServerConn(c1, serverConf)
+		wg.Add(1)
+		go func() {
+			NewServerConn(c1, serverConf)
+			wg.Done()
+		}()
 		_, _, _, err = NewClientConn(c2, "", clientConf)
 		if err != nil {
 			t.Fatalf("NewClientConn: %v", err)
@@ -766,6 +813,12 @@ func TestHostKeyAlgorithms(t *testing.T) {
 	// with an RSA-SHA2-512 signature.
 	connect(clientConf, KeyAlgoRSA)
 
+	// Client asks for RSA-SHA2-512 explicitly.
+	clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512}
+	// We get back an "ssh-rsa" key but the verification happened
+	// with an RSA-SHA2-512 signature.
+	connect(clientConf, KeyAlgoRSA)
+
 	c1, c2, err := netPipe()
 	if err != nil {
 		t.Fatalf("netPipe: %v", err)
@@ -773,10 +826,73 @@ func TestHostKeyAlgorithms(t *testing.T) {
 	defer c1.Close()
 	defer c2.Close()
 
-	go NewServerConn(c1, serverConf)
+	wg.Add(1)
+	go func() {
+		NewServerConn(c1, serverConf)
+		wg.Done()
+	}()
 	clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
 	_, _, _, err = NewClientConn(c2, "", clientConf)
 	if err == nil {
 		t.Fatal("succeeded connecting with unknown hostkey algorithm")
 	}
 }
+
+func TestServerClientAuthCallback(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	userCh := make(chan string, 1)
+
+	serverConf := &ServerConfig{
+		NoClientAuth: true,
+		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
+			userCh <- conn.User()
+			return nil, nil
+		},
+	}
+	const someUsername = "some-username"
+
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		User:            someUsername,
+	}
+
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		_, chans, reqs, err := NewServerConn(c1, serverConf)
+		if err != nil {
+			t.Errorf("server handshake: %v", err)
+			userCh <- "error"
+			return
+		}
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
+		for ch := range chans {
+			ch.Reject(Prohibited, "")
+		}
+	}()
+
+	conn, _, _, err := NewClientConn(c2, "", clientConf)
+	if err != nil {
+		t.Fatalf("client handshake: %v", err)
+		return
+	}
+	conn.Close()
+
+	got := <-userCh
+	if got != someUsername {
+		t.Errorf("username = %q; want %q", got, someUsername)
+	}
+}

+ 35 - 0
psiphon/common/crypto/ssh/tcpip.go

@@ -5,6 +5,7 @@
 package ssh
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr {
 	return l.laddr
 }
 
+// DialContext initiates a connection to the addr from the remote host.
+//
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Once successfully connected,
+// any expiration of the context will not affect the connection.
+//
+// See func Dial for additional information.
+func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	type connErr struct {
+		conn net.Conn
+		err  error
+	}
+	ch := make(chan connErr)
+	go func() {
+		conn, err := c.Dial(n, addr)
+		select {
+		case ch <- connErr{conn, err}:
+		case <-ctx.Done():
+			if conn != nil {
+				conn.Close()
+			}
+		}
+	}()
+	select {
+	case res := <-ch:
+		return res.conn, res.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
 // [Psiphon]
 // directTCPIPNoSplitTunnel is the same as "direct-tcpip", except it indicates
 // custom split tunnel behavior. It shares the same payload. We allow the

+ 33 - 0
psiphon/common/crypto/ssh/tcpip_test.go

@@ -5,7 +5,10 @@
 package ssh
 
 import (
+	"context"
+	"net"
 	"testing"
+	"time"
 )
 
 func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,33 @@ func TestAutoPortListenBroken(t *testing.T) {
 		t.Errorf("version %q marked as broken", works)
 	}
 }
+
+func TestClientImplementsDialContext(t *testing.T) {
+	type ContextDialer interface {
+		DialContext(context.Context, string, string) (net.Conn, error)
+	}
+	// Belt and suspenders assertion, since package net does not
+	// declare a ContextDialer type.
+	var _ ContextDialer = &net.Dialer{}
+	var _ ContextDialer = &Client{}
+}
+
+func TestClientDialContextWithCancel(t *testing.T) {
+	c := &Client{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	_, err := c.DialContext(ctx, "tcp", "localhost:1000")
+	if err != context.Canceled {
+		t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
+	}
+}
+
+func TestClientDialContextWithDeadline(t *testing.T) {
+	c := &Client{}
+	ctx, cancel := context.WithDeadline(context.Background(), time.Now())
+	defer cancel()
+	_, err := c.DialContext(ctx, "tcp", "localhost:1000")
+	if err != context.DeadlineExceeded {
+		t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
+	}
+}

+ 1 - 3
psiphon/common/crypto/ssh/test/agent_unix_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
-// +build aix darwin dragonfly freebsd linux netbsd openbsd
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
 
 package test
 
@@ -17,7 +16,6 @@ import (
 
 func TestAgentForward(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 

+ 1 - 3
psiphon/common/crypto/ssh/test/banner_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
-// +build aix darwin dragonfly freebsd linux netbsd openbsd
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
 
 package test
 
@@ -13,7 +12,6 @@ import (
 
 func TestBannerCallbackAgainstOpenSSH(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 
 	clientConf := clientConfig()
 

+ 1 - 3
psiphon/common/crypto/ssh/test/cert_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
-// +build aix darwin dragonfly freebsd linux netbsd openbsd
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
 
 package test
 
@@ -18,7 +17,6 @@ import (
 // Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly
 func TestCertLogin(t *testing.T) {
 	s := newServer(t)
-	defer s.Shutdown()
 
 	// Use a key different from the default.
 	clientKey := testSigners["dsa"]

+ 8 - 6
psiphon/common/crypto/ssh/test/dial_unix_test.go

@@ -2,17 +2,16 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build !windows && !solaris && !js
-// +build !windows,!solaris,!js
+//go:build !windows && !js && !wasip1
 
 package test
 
 // direct-tcpip and direct-streamlocal functional tests
 
 import (
+	"context"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net"
 	"strings"
 	"testing"
@@ -25,7 +24,6 @@ type dialTester interface {
 
 func testDial(t *testing.T, n, listenAddr string, x dialTester) {
 	server := newServer(t)
-	defer server.Shutdown()
 	sshConn := server.Dial(clientConfig())
 	defer sshConn.Close()
 
@@ -49,13 +47,17 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) {
 		}
 	}()
 
-	conn, err := sshConn.Dial(n, l.Addr().String())
+	ctx, cancel := context.WithCancel(context.Background())
+	conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
+	// Canceling the context after dial should have no effect
+	// on the opened connection.
+	cancel()
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
 	}
 	x.TestClientConn(t, conn)
 	defer conn.Close()
-	b, err := ioutil.ReadAll(conn)
+	b, err := io.ReadAll(conn)
 	if err != nil {
 		t.Fatalf("ReadAll: %v", err)
 	}

+ 7 - 16
psiphon/common/crypto/ssh/test/forward_unix_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
-// +build aix darwin dragonfly freebsd linux netbsd openbsd
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
 
 package test
 
@@ -11,7 +10,6 @@ import (
 	"bytes"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"math/rand"
 	"net"
 	"testing"
@@ -24,7 +22,6 @@ type closeWriter interface {
 
 func testPortForward(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -58,7 +55,7 @@ func testPortForward(t *testing.T, n, listenAddr string) {
 
 	readChan := make(chan []byte)
 	go func() {
-		data, _ := ioutil.ReadAll(netConn)
+		data, _ := io.ReadAll(netConn)
 		readChan <- data
 	}()
 
@@ -121,7 +118,6 @@ func TestPortForwardUnix(t *testing.T) {
 
 func testAcceptClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 
 	sshListener, err := conn.Listen(n, listenAddr)
@@ -163,10 +159,9 @@ func TestAcceptCloseUnix(t *testing.T) {
 // Check that listeners exit if the underlying client transport dies.
 func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
-	defer server.Shutdown()
-	conn := server.Dial(clientConfig())
+	client := server.Dial(clientConfig())
 
-	sshListener, err := conn.Listen(n, listenAddr)
+	sshListener, err := client.Listen(n, listenAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -185,14 +180,10 @@ func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 
 	// It would be even nicer if we closed the server side, but it
 	// is more involved as the fd for that side is dup()ed.
-	server.clientConn.Close()
+	server.lastDialConn.Close()
 
-	select {
-	case <-time.After(1 * time.Second):
-		t.Errorf("timeout: listener did not close.")
-	case err := <-quit:
-		t.Logf("quit as expected (error %v)", err)
-	}
+	err = <-quit
+	t.Logf("quit as expected (error %v)", err)
 }
 
 func TestPortForwardConnectionCloseTCP(t *testing.T) {

+ 5 - 7
psiphon/common/crypto/ssh/test/multi_auth_test.go

@@ -15,7 +15,6 @@
 // (for linux) in file ./sshd_test_pw.c.
 
 //go:build linux
-// +build linux
 
 package test
 
@@ -77,27 +76,27 @@ func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []stri
 func TestMultiAuth(t *testing.T) {
 	testCases := []multiAuthTestCase{
 		// Test password,publickey authentication, assert that password callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"password", "publickey"},
 			expectedPasswordCbs: 1,
 		},
 		// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:       []string{"keyboard-interactive", "publickey"},
 			expectedKbdIntCbs: 1,
 		},
 		// Test publickey,password authentication, assert that password callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"publickey", "password"},
 			expectedPasswordCbs: 1,
 		},
 		// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:       []string{"publickey", "keyboard-interactive"},
 			expectedKbdIntCbs: 1,
 		},
 		// Test password,password authentication, assert that password callback is called 2 times
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"password", "password"},
 			expectedPasswordCbs: 2,
 		},
@@ -108,7 +107,6 @@ func TestMultiAuth(t *testing.T) {
 			ctx := newMultiAuthTestCtx(t)
 
 			server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
-			defer server.Shutdown()
 
 			clientConfig := clientConfig()
 			server.setTestPassword(clientConfig.User, ctx.password)

+ 98 - 0
psiphon/common/crypto/ssh/test/server_test.go

@@ -0,0 +1,98 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package test
+
+import (
+	"net"
+
+	"golang.org/x/crypto/ssh"
+)
+
+type exitStatusMsg struct {
+	Status uint32
+}
+
+// goServer is a test Go SSH server that accepts public key and certificate
+// authentication and replies with a 0 exit status to any exec request without
+// running any commands.
+type goTestServer struct {
+	listener net.Listener
+	config   *ssh.ServerConfig
+	done     <-chan struct{}
+}
+
+func newTestServer(config *ssh.ServerConfig) (*goTestServer, error) {
+	server := &goTestServer{
+		config: config,
+	}
+	listener, err := net.Listen("tcp", "127.0.0.1:")
+	if err != nil {
+		return nil, err
+	}
+	server.listener = listener
+	done := make(chan struct{}, 1)
+	server.done = done
+	go server.acceptConnections(done)
+
+	return server, nil
+}
+
+func (s *goTestServer) port() (string, error) {
+	_, port, err := net.SplitHostPort(s.listener.Addr().String())
+	return port, err
+}
+
+func (s *goTestServer) acceptConnections(done chan<- struct{}) {
+	defer close(done)
+
+	for {
+		c, err := s.listener.Accept()
+		if err != nil {
+			return
+		}
+		_, chans, reqs, err := ssh.NewServerConn(c, s.config)
+		if err != nil {
+			return
+		}
+		go ssh.DiscardRequests(reqs)
+		defer c.Close()
+
+		for newChannel := range chans {
+			if newChannel.ChannelType() != "session" {
+				newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+				continue
+			}
+
+			channel, requests, err := newChannel.Accept()
+			if err != nil {
+				continue
+			}
+
+			go func(in <-chan *ssh.Request) {
+				for req := range in {
+					ok := false
+					switch req.Type {
+					case "exec":
+						ok = true
+						go func() {
+							channel.SendRequest("exit-status", false, ssh.Marshal(&exitStatusMsg{Status: 0}))
+							channel.Close()
+						}()
+					}
+					if req.WantReply {
+						req.Reply(ok, nil)
+					}
+				}
+			}(requests)
+		}
+	}
+}
+
+func (s *goTestServer) Close() error {
+	err := s.listener.Close()
+	// wait for the accept loop to exit
+	<-s.done
+	return err
+}

+ 26 - 21
psiphon/common/crypto/ssh/test/session_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build !windows && !solaris && !js
-// +build !windows,!solaris,!js
+//go:build !windows && !js && !wasip1
 
 package test
 
@@ -14,6 +13,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"path/filepath"
+	"regexp"
 	"runtime"
 	"strings"
 	"testing"
@@ -23,7 +24,6 @@ import (
 
 func TestRunCommandSuccess(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -40,7 +40,6 @@ func TestRunCommandSuccess(t *testing.T) {
 
 func TestHostKeyCheck(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 
 	conf := clientConfig()
 	hostDB := hostKeyDB()
@@ -62,7 +61,6 @@ func TestHostKeyCheck(t *testing.T) {
 
 func TestRunCommandStdin(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -85,7 +83,6 @@ func TestRunCommandStdin(t *testing.T) {
 
 func TestRunCommandStdinError(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -109,7 +106,6 @@ func TestRunCommandStdinError(t *testing.T) {
 
 func TestRunCommandFailed(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -126,7 +122,6 @@ func TestRunCommandFailed(t *testing.T) {
 
 func TestRunCommandWeClosed(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -146,7 +141,6 @@ func TestRunCommandWeClosed(t *testing.T) {
 
 func TestFuncLargeRead(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -178,7 +172,6 @@ func TestFuncLargeRead(t *testing.T) {
 
 func TestKeyChange(t *testing.T) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conf := clientConfig()
 	hostDB := hostKeyDB()
 	conf.HostKeyCallback = hostDB.Check
@@ -225,7 +218,6 @@ func TestValidTerminalMode(t *testing.T) {
 		t.Skipf("skipping on %s", runtime.GOOS)
 	}
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -255,15 +247,31 @@ func TestValidTerminalMode(t *testing.T) {
 		t.Fatalf("session failed: %s", err)
 	}
 
-	stdin.Write([]byte("stty -a && exit\n"))
+	if _, err := io.WriteString(stdin, "echo && echo SHELL $SHELL && stty -a && exit\n"); err != nil {
+		t.Fatal(err)
+	}
 
-	var buf bytes.Buffer
-	if _, err := io.Copy(&buf, stdout); err != nil {
+	buf := new(strings.Builder)
+	if _, err := io.Copy(buf, stdout); err != nil {
 		t.Fatalf("reading failed: %s", err)
 	}
 
+	if testing.Verbose() {
+		t.Logf("echo && echo SHELL $SHELL && stty -a && exit:\n%s", buf)
+	}
+
+	shellLine := regexp.MustCompile("(?m)^SHELL (.*)$").FindStringSubmatch(buf.String())
+	if len(shellLine) != 2 {
+		t.Fatalf("missing output from echo SHELL $SHELL")
+	}
+	switch shell := filepath.Base(strings.TrimSpace(shellLine[1])); shell {
+	case "sh", "bash":
+	default:
+		t.Skipf("skipping test on non-Bourne shell %q", shell)
+	}
+
 	if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") {
-		t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput)
+		t.Fatal("terminal mode failure: expected -echo in stty output")
 	}
 }
 
@@ -274,7 +282,6 @@ func TestWindowChange(t *testing.T) {
 		t.Skipf("skipping on %s", runtime.GOOS)
 	}
 	server := newServer(t)
-	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
@@ -322,7 +329,6 @@ func TestWindowChange(t *testing.T) {
 
 func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 	server := newServer(t)
-	defer server.Shutdown()
 	conf := clientConfig()
 	conf.Ciphers = []string{cipher}
 	// Don't fail if sshd doesn't have the cipher.
@@ -381,7 +387,6 @@ func TestMACs(t *testing.T) {
 	for _, mac := range macOrder {
 		t.Run(mac, func(t *testing.T) {
 			server := newServer(t)
-			defer server.Shutdown()
 			conf := clientConfig()
 			conf.MACs = []string{mac}
 			// Don't fail if sshd doesn't have the MAC.
@@ -404,10 +409,12 @@ func TestKeyExchanges(t *testing.T) {
 	// are not included in the default list of supported kex so we have to add them
 	// here manually.
 	kexOrder = append(kexOrder, "diffie-hellman-group-exchange-sha1", "diffie-hellman-group-exchange-sha256")
+	// The key exchange algorithms diffie-hellman-group16-sha512 is disabled by
+	// default so we add it here manually.
+	kexOrder = append(kexOrder, "diffie-hellman-group16-sha512")
 	for _, kex := range kexOrder {
 		t.Run(kex, func(t *testing.T) {
 			server := newServer(t)
-			defer server.Shutdown()
 			conf := clientConfig()
 			// Don't fail if sshd doesn't have the kex.
 			conf.KeyExchanges = append([]string{kex}, kexOrder...)
@@ -442,8 +449,6 @@ func TestClientAuthAlgorithms(t *testing.T) {
 			} else {
 				t.Errorf("failed for key %q", key)
 			}
-
-			server.Shutdown()
 		})
 	}
 }

+ 100 - 0
psiphon/common/crypto/ssh/test/sshcli_test.go

@@ -0,0 +1,100 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package test
+
+import (
+	"bytes"
+	"fmt"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"runtime"
+	"testing"
+
+	"golang.org/x/crypto/internal/testenv"
+	"golang.org/x/crypto/ssh"
+	"golang.org/x/crypto/ssh/testdata"
+)
+
+func sshClient(t *testing.T) string {
+	if testing.Short() {
+		t.Skip("Skipping test that executes OpenSSH in -short mode")
+	}
+	sshCLI := os.Getenv("SSH_CLI_PATH")
+	if sshCLI == "" {
+		sshCLI = "ssh"
+	}
+	var err error
+	sshCLI, err = exec.LookPath(sshCLI)
+	if err != nil {
+		t.Skipf("Can't find an ssh(1) client to test against: %v", err)
+	}
+	return sshCLI
+}
+
+func TestSSHCLIAuth(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skipf("always fails on Windows, see #64403")
+	}
+	sshCLI := sshClient(t)
+	dir := t.TempDir()
+	keyPrivPath := filepath.Join(dir, "rsa")
+
+	for fn, content := range map[string][]byte{
+		keyPrivPath:                        testdata.PEMBytes["rsa"],
+		keyPrivPath + ".pub":               ssh.MarshalAuthorizedKey(testPublicKeys["rsa"]),
+		filepath.Join(dir, "rsa-cert.pub"): testdata.SSHCertificates["rsa-user-testcertificate"],
+	} {
+		if err := os.WriteFile(fn, content, 0600); err != nil {
+			t.Fatalf("WriteFile(%q): %v", fn, err)
+		}
+	}
+
+	certChecker := ssh.CertChecker{
+		IsUserAuthority: func(k ssh.PublicKey) bool {
+			return bytes.Equal(k.Marshal(), testPublicKeys["ca"].Marshal())
+		},
+		UserKeyFallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+			if conn.User() == "testpubkey" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+				return nil, nil
+			}
+
+			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
+		},
+	}
+
+	config := &ssh.ServerConfig{
+		PublicKeyCallback: certChecker.Authenticate,
+	}
+	config.AddHostKey(testSigners["rsa"])
+
+	server, err := newTestServer(config)
+	if err != nil {
+		t.Fatalf("unable to start test server: %v", err)
+	}
+	defer server.Close()
+
+	port, err := server.port()
+	if err != nil {
+		t.Fatalf("unable to get server port: %v", err)
+	}
+
+	// test public key authentication.
+	cmd := testenv.Command(t, sshCLI, "-vvv", "-i", keyPrivPath, "-o", "StrictHostKeyChecking=no",
+		"-p", port, "testpubkey@127.0.0.1", "true")
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("public key authentication failed, error: %v, command output %q", err, string(out))
+	}
+	// Test SSH user certificate authentication.
+	// The username must match one of the principals included in the certificate.
+	// The certificate "rsa-user-testcertificate" has "testcertificate" as principal.
+	cmd = testenv.Command(t, sshCLI, "-vvv", "-i", keyPrivPath, "-o", "StrictHostKeyChecking=no",
+		"-p", port, "testcertificate@127.0.0.1", "true")
+	out, err = cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("user certificate authentication failed, error: %v, command output %q", err, string(out))
+	}
+}

+ 1 - 1
psiphon/common/crypto/ssh/test/sshd_test_pw.c

@@ -20,7 +20,7 @@
 // Run sshd:
 // LD_PRELOAD="sshd_test_pw.so" TEST_USER="..." TEST_PASSWD="..." sshd ...
 
-// +build ignore
+//go:build ignore
 
 #define _GNU_SOURCE
 #include <string.h>

+ 31 - 43
psiphon/common/crypto/ssh/test/test_unix_test.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || plan9
-// +build aix darwin dragonfly freebsd linux netbsd openbsd plan9
+//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || plan9 || solaris
 
 package test
 
@@ -14,7 +13,6 @@ import (
 	"crypto/rand"
 	"encoding/base64"
 	"fmt"
-	"io/ioutil"
 	"log"
 	"net"
 	"os"
@@ -24,6 +22,7 @@ import (
 	"testing"
 	"text/template"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/testenv"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
 )
@@ -68,17 +67,13 @@ var configTmpl = map[string]*template.Template{
 
 type server struct {
 	t          *testing.T
-	cleanup    func() // executed during Shutdown
 	configfile string
-	cmd        *exec.Cmd
-	output     bytes.Buffer // holds stderr from sshd process
 
 	testUser     string // test username for sshd
 	testPasswd   string // test password for sshd
 	sshdTestPwSo string // dynamic library to inject a custom password into sshd
 
-	// Client half of the network connection.
-	clientConn net.Conn
+	lastDialConn net.Conn
 }
 
 func username() string {
@@ -151,7 +146,7 @@ func clientConfig() *ssh.ClientConfig {
 // is used for connecting the Go SSH client with sshd without opening
 // ports.
 func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
-	dir, err := ioutil.TempDir("", "unixConnection")
+	dir, err := os.MkdirTemp("", "unixConnection")
 	if err != nil {
 		return nil, nil, err
 	}
@@ -194,15 +189,15 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 		s.t.Fatalf("unixConnection: %v", err)
 	}
 
-	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
+	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	f, err := c2.File()
 	if err != nil {
 		s.t.Fatalf("UnixConn.File: %v", err)
 	}
 	defer f.Close()
-	s.cmd.Stdin = f
-	s.cmd.Stdout = f
-	s.cmd.Stderr = &s.output
+	cmd.Stdin = f
+	cmd.Stdout = f
+	cmd.Stderr = new(bytes.Buffer)
 
 	if s.sshdTestPwSo != "" {
 		if s.testUser == "" {
@@ -211,18 +206,29 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 		if s.testPasswd == "" {
 			s.t.Fatal("password missing from sshd_test_pw.so config")
 		}
-		s.cmd.Env = append(os.Environ(),
+		cmd.Env = append(os.Environ(),
 			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
 			fmt.Sprintf("TEST_USER=%s", s.testUser),
 			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
 	}
 
-	if err := s.cmd.Start(); err != nil {
-		s.t.Fail()
-		s.Shutdown()
+	if err := cmd.Start(); err != nil {
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
-	s.clientConn = c1
+	s.lastDialConn = c1
+	s.t.Cleanup(func() {
+		// Don't check for errors; if it fails it's most
+		// likely "os: process already finished", and we don't
+		// care about that. Use os.Interrupt, so child
+		// processes are killed too.
+		cmd.Process.Signal(os.Interrupt)
+		cmd.Wait()
+		if s.t.Failed() {
+			// log any output from sshd process
+			s.t.Logf("sshd:\n%s", cmd.Stderr)
+		}
+	})
+
 	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
 	if err != nil {
 		return nil, err
@@ -233,29 +239,11 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
 	conn, err := s.TryDial(config)
 	if err != nil {
-		s.t.Fail()
-		s.Shutdown()
 		s.t.Fatalf("ssh.Client: %v", err)
 	}
 	return conn
 }
 
-func (s *server) Shutdown() {
-	if s.cmd != nil && s.cmd.Process != nil {
-		// Don't check for errors; if it fails it's most
-		// likely "os: process already finished", and we don't
-		// care about that. Use os.Interrupt, so child
-		// processes are killed too.
-		s.cmd.Process.Signal(os.Interrupt)
-		s.cmd.Wait()
-	}
-	if s.t.Failed() {
-		// log any output from sshd process
-		s.t.Logf("sshd: %s", s.output.String())
-	}
-	s.cleanup()
-}
-
 func writeFile(path string, contents []byte) {
 	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
 	if err != nil {
@@ -316,7 +304,7 @@ func newServerForConfig(t *testing.T, config string, configVars map[string]strin
 	if uname == "root" {
 		t.Skip("skipping test because current user is root")
 	}
-	dir, err := ioutil.TempDir("", "sshtest")
+	dir, err := os.MkdirTemp("", "sshtest")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -352,20 +340,20 @@ func newServerForConfig(t *testing.T, config string, configVars map[string]strin
 		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
 	}
 	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
+	t.Cleanup(func() {
+		if err := os.RemoveAll(dir); err != nil {
+			t.Error(err)
+		}
+	})
 
 	return &server{
 		t:          t,
 		configfile: f.Name(),
-		cleanup: func() {
-			if err := os.RemoveAll(dir); err != nil {
-				t.Error(err)
-			}
-		},
 	}
 }
 
 func newTempSocket(t *testing.T) (string, func()) {
-	dir, err := ioutil.TempDir("", "socket")
+	dir, err := os.MkdirTemp("", "socket")
 	if err != nil {
 		t.Fatal(err)
 	}

+ 7 - 1
psiphon/common/crypto/ssh/testdata/keys.go

@@ -226,7 +226,13 @@ var SSHCertificates = map[string][]byte{
 `),
 	"rsa-sha2-256": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgOyK28gunJkM60qp4EbsYAjgbUsyjS8u742OLjipIgc0AAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMJ4AAAAAHBPBLwAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi0yNTYAAAEAbG4De/+QiqopPS3O1H7ySeEUCY56qmdgr02sFErnihdXPDaWXUXxacvJHaEtLrSTSaPL/3v3iKvjLWDOHaQ5c+cN9J7Tqzso7RQCXZD2nK9bwCUyBoiDyBCRe8w4DQEtfL5okpVzQsSAiojQ8hBohMOpy3gFfXrdm4PVC1ZKqlZh4fAc7ajieRq/Tpq2xOLdHwxkcgPNR83WVHva6K9/xjev/5n227/gkHo0qbGs8YYDOFXIEhENi+B23IzxdNVieWdyQpYpe0C2i95Jhyo0wJmaFY2ArruTS+D1jGQQpMPvAQRy26/A5hI83GLhpwyhrN/M8wCxzAhyPL6Ieuh5tQ== host.example.com
 `),
-       "rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgFGv4IpXfs4L/Y0b3rmUdPFhWoUrVnXuPxXr6aHGs7wgAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMRYAAAAAHBPBp4AAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAnF4fVj6mm+UFeNCIf9AKJCv9WzymjjPvzzmaMWWkPWqoV0P0m5SiYfvbY9SbA73Blpv8SOr0DmpublF183kodREia4KyVuC8hLhSCV2Y16hy9MBegOZMepn80w+apj7Rn9QCz5OfEakDdztp6OWTBtqxnZFcTQ4XrgFkNWeWRElGdEvAVNn2WHwHi4EIdz0mdv48Imv5SPlOuW862ZdFG4Do1dUfDIiGsBofLlgcyIYlf+eNHul6sBeUkuwFxisMpI5DQzNp8PX1g/QJA2wzwT674PTqDXNttKjyh50Fdr4sXxm9Gz1+jVLoESvFNa55ERdSyAqNu4wTy11MZsWwSA== host.example.com
+	"rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgFGv4IpXfs4L/Y0b3rmUdPFhWoUrVnXuPxXr6aHGs7wgAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMRYAAAAAHBPBp4AAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAnF4fVj6mm+UFeNCIf9AKJCv9WzymjjPvzzmaMWWkPWqoV0P0m5SiYfvbY9SbA73Blpv8SOr0DmpublF183kodREia4KyVuC8hLhSCV2Y16hy9MBegOZMepn80w+apj7Rn9QCz5OfEakDdztp6OWTBtqxnZFcTQ4XrgFkNWeWRElGdEvAVNn2WHwHi4EIdz0mdv48Imv5SPlOuW862ZdFG4Do1dUfDIiGsBofLlgcyIYlf+eNHul6sBeUkuwFxisMpI5DQzNp8PX1g/QJA2wzwT674PTqDXNttKjyh50Fdr4sXxm9Gz1+jVLoESvFNa55ERdSyAqNu4wTy11MZsWwSA== host.example.com
+`),
+	// Assumes "ca" key above in file named "ca", sign a user cert for "rsa.pub"
+	// using "testcertificate" as principal:
+	//
+	// ssh-keygen -s ca -I username -n testcertificate rsa.pub
+	"rsa-user-testcertificate": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgr0MnhSf+KkQWEQmweJOGsJfOrUt80pQZDaI9YiwSfDUAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAQAAAAh1c2VybmFtZQAAABMAAAAPdGVzdGNlcnRpZmljYXRlAAAAAAAAAAD//////////wAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAFuA+67KvnlmcodIp0Lv4mR9UW/CHghAaN1csBJTkI8mx3wXKyIPTsS2uXboEhWD0a7S9gps2SEwC5m6E3kV2Rzg7aH1S03GZqMvVlK2VHe7fzuoW2yOKk6yEPjeTF0pKCFbUQ6mce8pRpD/zdvjG0Z287XM3c3Axlrn7qq7TS0MDTjEZ/dsUNFHxep3co/HuAsWVWPVDItr/FPnvZ6WVH1yc8N/AJn0gLHobkGgug22vI9rNIge1wrnXxU9BUSouzkau/PQsrCQapnn+I1H7HaQt0wdk45nxMP+ags+HRI9qpX/p8WDn6+zpqYqN/nfw2aoytyaJqhsV32Teuqtrgg== rsa.pub
 `),
 }
 

+ 36 - 9
psiphon/common/crypto/ssh/transport.go

@@ -17,7 +17,8 @@ import (
 const debugTransport = false
 
 const (
-	gcmCipherID    = "aes128-gcm@openssh.com"
+	gcm128CipherID = "aes128-gcm@openssh.com"
+	gcm256CipherID = "aes256-gcm@openssh.com"
 	aes128cbcID    = "aes128-cbc"
 	tripledescbcID = "3des-cbc"
 )
@@ -48,6 +49,9 @@ type transport struct {
 	rand      io.Reader
 	isClient  bool
 	io.Closer
+
+	strictMode     bool
+	initialKEXDone bool
 }
 
 // packetCipher represents a combination of SSH encryption/MAC
@@ -73,6 +77,18 @@ type connectionState struct {
 	pendingKeyChange chan packetCipher
 }
 
+func (t *transport) setStrictMode() error {
+	if t.reader.seqNum != 1 {
+		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
+	}
+	t.strictMode = true
+	return nil
+}
+
+func (t *transport) setInitialKEXDone() {
+	t.initialKEXDone = true
+}
+
 // prepareKeyChange sets up key material for a keychange. The key changes in
 // both directions are triggered by reading and writing a msgNewKey packet
 // respectively.
@@ -111,11 +127,12 @@ func (t *transport) printPacket(p []byte, write bool) {
 // Read and decrypt next packet.
 func (t *transport) readPacket() (p []byte, err error) {
 	for {
-		p, err = t.reader.readPacket(t.bufReader)
+		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
 		if err != nil {
 			break
 		}
-		if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
+		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
+		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
 			break
 		}
 	}
@@ -126,7 +143,7 @@ func (t *transport) readPacket() (p []byte, err error) {
 	return p, err
 }
 
-func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
+func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
 	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
 	s.seqNum++
 	if err == nil && len(packet) == 0 {
@@ -139,6 +156,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
 			select {
 			case cipher := <-s.pendingKeyChange:
 				s.packetCipher = cipher
+				if strictMode {
+					s.seqNum = 0
+				}
 			default:
 				return nil, errors.New("ssh: got bogus newkeys message")
 			}
@@ -169,10 +189,10 @@ func (t *transport) writePacket(packet []byte) error {
 	if debugTransport {
 		t.printPacket(packet, true)
 	}
-	return t.writer.writePacket(t.bufWriter, t.rand, packet)
+	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
 }
 
-func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
+func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
 	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
 
 	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
@@ -187,6 +207,9 @@ func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []
 		select {
 		case cipher := <-s.pendingKeyChange:
 			s.packetCipher = cipher
+			if strictMode {
+				s.seqNum = 0
+			}
 		default:
 			panic("ssh: no key material for msgNewKeys")
 		}
@@ -238,15 +261,19 @@ var (
 // (to setup server->client keys) or clientKeys (for client->server keys).
 func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
 	cipherMode := cipherModes[algs.Cipher]
-	macMode := macModes[algs.MAC]
 
 	iv := make([]byte, cipherMode.ivSize)
 	key := make([]byte, cipherMode.keySize)
-	macKey := make([]byte, macMode.keySize)
 
 	generateKeyMaterial(iv, d.ivTag, kex)
 	generateKeyMaterial(key, d.keyTag, kex)
-	generateKeyMaterial(macKey, d.macKeyTag, kex)
+
+	var macKey []byte
+	if !aeadCiphers[algs.Cipher] {
+		macMode := macModes[algs.MAC]
+		macKey = make([]byte, macMode.keySize)
+		generateKeyMaterial(macKey, d.macKeyTag, kex)
+	}
 
 	return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
 }