Преглед на файлове

Merge pull request #726 from rod-hynes/upgrade-ssh

Upgrade x/crypto/ssh
Rod Hynes преди 1 година
родител
ревизия
0b08d9040c
променени са 41 файла, в които са добавени 2030 реда и са изтрити 351 реда
  1. 0 10
      psiphon/common/crypto/.gitattributes
  2. 2 2
      psiphon/common/crypto/LICENSE
  3. 126 0
      psiphon/common/crypto/internal/poly1305/_asm/sum_amd64_asm.go
  4. 0 40
      psiphon/common/crypto/internal/poly1305/bits_compat.go
  5. 0 22
      psiphon/common/crypto/internal/poly1305/bits_go1.13.go
  6. 1 2
      psiphon/common/crypto/internal/poly1305/mac_noasm.go
  7. 59 74
      psiphon/common/crypto/internal/poly1305/sum_amd64.s
  8. 23 20
      psiphon/common/crypto/internal/poly1305/sum_generic.go
  9. 1 2
      psiphon/common/crypto/internal/poly1305/sum_ppc64x.go
  10. 25 20
      psiphon/common/crypto/internal/poly1305/sum_ppc64x.s
  11. 2 2
      psiphon/common/crypto/internal/testenv/exec.go
  12. 1 1
      psiphon/common/crypto/nacl/secretbox/secretbox.go
  13. 1 1
      psiphon/common/crypto/ssh/agent/client.go
  14. 16 9
      psiphon/common/crypto/ssh/agent/client_test.go
  15. 9 0
      psiphon/common/crypto/ssh/agent/keyring.go
  16. 46 0
      psiphon/common/crypto/ssh/agent/keyring_test.go
  17. 15 17
      psiphon/common/crypto/ssh/certs_test.go
  18. 20 3
      psiphon/common/crypto/ssh/client_auth.go
  19. 108 8
      psiphon/common/crypto/ssh/client_auth_test.go
  20. 1 1
      psiphon/common/crypto/ssh/doc.go
  21. 1 1
      psiphon/common/crypto/ssh/example_test.go
  22. 49 12
      psiphon/common/crypto/ssh/handshake.go
  23. 220 0
      psiphon/common/crypto/ssh/handshake_test.go
  24. 51 1
      psiphon/common/crypto/ssh/keys.go
  25. 86 2
      psiphon/common/crypto/ssh/keys_test.go
  26. 2 0
      psiphon/common/crypto/ssh/messages.go
  27. 56 0
      psiphon/common/crypto/ssh/messages_test.go
  28. 196 65
      psiphon/common/crypto/ssh/server.go
  29. 412 0
      psiphon/common/crypto/ssh/server_multi_auth_test.go
  30. 338 0
      psiphon/common/crypto/ssh/server_test.go
  31. 1 1
      psiphon/common/crypto/ssh/tcpip.go
  32. 4 3
      psiphon/common/crypto/ssh/test/agent_unix_test.go
  33. 1 1
      psiphon/common/crypto/ssh/test/cert_test.go
  34. 1 0
      psiphon/common/crypto/ssh/test/dial_unix_test.go
  35. 1 1
      psiphon/common/crypto/ssh/test/doc.go
  36. 10 0
      psiphon/common/crypto/ssh/test/forward_unix_test.go
  37. 1 1
      psiphon/common/crypto/ssh/test/server_test.go
  38. 91 6
      psiphon/common/crypto/ssh/test/session_test.go
  39. 23 5
      psiphon/common/crypto/ssh/test/test_unix_test.go
  40. 1 1
      psiphon/common/crypto/ssh/testdata/doc.go
  41. 29 17
      psiphon/common/crypto/ssh/testdata/keys.go

+ 0 - 10
psiphon/common/crypto/.gitattributes

@@ -1,10 +0,0 @@
-# Treat all files in this repo as binary, with no git magic updating
-# line endings. Windows users contributing to Go will need to use a
-# modern version of git and editors capable of LF line endings.
-#
-# We'll prevent accidental CRLF line endings from entering the repo
-# via the git-review gofmt checks.
-#
-# See golang.org/issue/9281
-
-* -text

+ 2 - 2
psiphon/common/crypto/LICENSE

@@ -1,4 +1,4 @@
-Copyright (c) 2009 The Go Authors. All rights reserved.
+Copyright 2009 The Go Authors.
 
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
 copyright notice, this list of conditions and the following disclaimer
 in the documentation and/or other materials provided with the
 distribution.
-   * Neither the name of Google Inc. nor the names of its
+   * Neither the name of Google LLC nor the names of its
 contributors may be used to endorse or promote products derived from
 this software without specific prior written permission.
 

+ 126 - 0
psiphon/common/crypto/internal/poly1305/_asm/sum_amd64_asm.go

@@ -0,0 +1,126 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	. "github.com/mmcloughlin/avo/build"
+	. "github.com/mmcloughlin/avo/operand"
+	. "github.com/mmcloughlin/avo/reg"
+	_ "golang.org/x/crypto/sha3"
+)
+
+//go:generate go run . -out ../sum_amd64.s -pkg poly1305
+
+func main() {
+	Package("golang.org/x/crypto/internal/poly1305")
+	ConstraintExpr("gc,!purego")
+	update()
+	Generate()
+}
+
+func update() {
+	Implement("update")
+
+	Load(Param("state"), RDI)
+	MOVQ(NewParamAddr("msg_base", 8), RSI)
+	MOVQ(NewParamAddr("msg_len", 16), R15)
+
+	MOVQ(Mem{Base: DI}.Offset(0), R8)   // h0
+	MOVQ(Mem{Base: DI}.Offset(8), R9)   // h1
+	MOVQ(Mem{Base: DI}.Offset(16), R10) // h2
+	MOVQ(Mem{Base: DI}.Offset(24), R11) // r0
+	MOVQ(Mem{Base: DI}.Offset(32), R12) // r1
+
+	CMPQ(R15, Imm(16))
+	JB(LabelRef("bytes_between_0_and_15"))
+
+	Label("loop")
+	POLY1305_ADD(RSI, R8, R9, R10)
+
+	Label("multiply")
+	POLY1305_MUL(R8, R9, R10, R11, R12, RBX, RCX, R13, R14)
+	SUBQ(Imm(16), R15)
+	CMPQ(R15, Imm(16))
+	JAE(LabelRef("loop"))
+
+	Label("bytes_between_0_and_15")
+	TESTQ(R15, R15)
+	JZ(LabelRef("done"))
+	MOVQ(U32(1), RBX)
+	XORQ(RCX, RCX)
+	XORQ(R13, R13)
+	ADDQ(R15, RSI)
+
+	Label("flush_buffer")
+	SHLQ(Imm(8), RBX, RCX)
+	SHLQ(Imm(8), RBX)
+	MOVB(Mem{Base: SI}.Offset(-1), R13B)
+	XORQ(R13, RBX)
+	DECQ(RSI)
+	DECQ(R15)
+	JNZ(LabelRef("flush_buffer"))
+
+	ADDQ(RBX, R8)
+	ADCQ(RCX, R9)
+	ADCQ(Imm(0), R10)
+	MOVQ(U32(16), R15)
+	JMP(LabelRef("multiply"))
+
+	Label("done")
+	MOVQ(R8, Mem{Base: DI}.Offset(0))
+	MOVQ(R9, Mem{Base: DI}.Offset(8))
+	MOVQ(R10, Mem{Base: DI}.Offset(16))
+	RET()
+}
+
+func POLY1305_ADD(msg, h0, h1, h2 GPPhysical) {
+	ADDQ(Mem{Base: msg}.Offset(0), h0)
+	ADCQ(Mem{Base: msg}.Offset(8), h1)
+	ADCQ(Imm(1), h2)
+	LEAQ(Mem{Base: msg}.Offset(16), msg)
+}
+
+func POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3 GPPhysical) {
+	MOVQ(r0, RAX)
+	MULQ(h0)
+	MOVQ(RAX, t0)
+	MOVQ(RDX, t1)
+	MOVQ(r0, RAX)
+	MULQ(h1)
+	ADDQ(RAX, t1)
+	ADCQ(Imm(0), RDX)
+	MOVQ(r0, t2)
+	IMULQ(h2, t2)
+	ADDQ(RDX, t2)
+
+	MOVQ(r1, RAX)
+	MULQ(h0)
+	ADDQ(RAX, t1)
+	ADCQ(Imm(0), RDX)
+	MOVQ(RDX, h0)
+	MOVQ(r1, t3)
+	IMULQ(h2, t3)
+	MOVQ(r1, RAX)
+	MULQ(h1)
+	ADDQ(RAX, t2)
+	ADCQ(RDX, t3)
+	ADDQ(h0, t2)
+	ADCQ(Imm(0), t3)
+
+	MOVQ(t0, h0)
+	MOVQ(t1, h1)
+	MOVQ(t2, h2)
+	ANDQ(Imm(3), h2)
+	MOVQ(t2, t0)
+	ANDQ(I32(-4), t0)
+	ADDQ(t0, h0)
+	ADCQ(t3, h1)
+	ADCQ(Imm(0), h2)
+	SHRQ(Imm(2), t3, t2)
+	SHRQ(Imm(2), t3)
+	ADDQ(t2, h0)
+	ADCQ(t3, h1)
+	ADCQ(Imm(0), h2)
+}

+ 0 - 40
psiphon/common/crypto/internal/poly1305/bits_compat.go

@@ -1,40 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.13
-// +build !go1.13
-
-package poly1305
-
-// Generic fallbacks for the math/bits intrinsics, copied from
-// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had
-// variable time fallbacks until Go 1.13.
-
-func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
-	sum = x + y + carry
-	carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
-	return
-}
-
-func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
-	diff = x - y - borrow
-	borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
-	return
-}
-
-func bitsMul64(x, y uint64) (hi, lo uint64) {
-	const mask32 = 1<<32 - 1
-	x0 := x & mask32
-	x1 := x >> 32
-	y0 := y & mask32
-	y1 := y >> 32
-	w0 := x0 * y0
-	t := x1*y0 + w0>>32
-	w1 := t & mask32
-	w2 := t >> 32
-	w1 += x0 * y1
-	hi = x1*y1 + w2 + w1>>32
-	lo = x * y
-	return
-}

+ 0 - 22
psiphon/common/crypto/internal/poly1305/bits_go1.13.go

@@ -1,22 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.13
-// +build go1.13
-
-package poly1305
-
-import "math/bits"
-
-func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
-	return bits.Add64(x, y, carry)
-}
-
-func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
-	return bits.Sub64(x, y, borrow)
-}
-
-func bitsMul64(x, y uint64) (hi, lo uint64) {
-	return bits.Mul64(x, y)
-}

+ 1 - 2
psiphon/common/crypto/internal/poly1305/mac_noasm.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
-// +build !amd64,!ppc64le,!s390x !gc purego
+//go:build (!amd64 && !ppc64le && !ppc64 && !s390x) || !gc || purego
 
 package poly1305
 

+ 59 - 74
psiphon/common/crypto/internal/poly1305/sum_amd64.s

@@ -1,109 +1,94 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Code generated by command: go run sum_amd64_asm.go -out ../sum_amd64.s -pkg poly1305. DO NOT EDIT.
 
 //go:build gc && !purego
 // +build gc,!purego
 
-#include "textflag.h"
-
-#define POLY1305_ADD(msg, h0, h1, h2) \
-	ADDQ 0(msg), h0;  \
-	ADCQ 8(msg), h1;  \
-	ADCQ $1, h2;      \
-	LEAQ 16(msg), msg
-
-#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \
-	MOVQ  r0, AX;                  \
-	MULQ  h0;                      \
-	MOVQ  AX, t0;                  \
-	MOVQ  DX, t1;                  \
-	MOVQ  r0, AX;                  \
-	MULQ  h1;                      \
-	ADDQ  AX, t1;                  \
-	ADCQ  $0, DX;                  \
-	MOVQ  r0, t2;                  \
-	IMULQ h2, t2;                  \
-	ADDQ  DX, t2;                  \
-	                               \
-	MOVQ  r1, AX;                  \
-	MULQ  h0;                      \
-	ADDQ  AX, t1;                  \
-	ADCQ  $0, DX;                  \
-	MOVQ  DX, h0;                  \
-	MOVQ  r1, t3;                  \
-	IMULQ h2, t3;                  \
-	MOVQ  r1, AX;                  \
-	MULQ  h1;                      \
-	ADDQ  AX, t2;                  \
-	ADCQ  DX, t3;                  \
-	ADDQ  h0, t2;                  \
-	ADCQ  $0, t3;                  \
-	                               \
-	MOVQ  t0, h0;                  \
-	MOVQ  t1, h1;                  \
-	MOVQ  t2, h2;                  \
-	ANDQ  $3, h2;                  \
-	MOVQ  t2, t0;                  \
-	ANDQ  $0xFFFFFFFFFFFFFFFC, t0; \
-	ADDQ  t0, h0;                  \
-	ADCQ  t3, h1;                  \
-	ADCQ  $0, h2;                  \
-	SHRQ  $2, t3, t2;              \
-	SHRQ  $2, t3;                  \
-	ADDQ  t2, h0;                  \
-	ADCQ  t3, h1;                  \
-	ADCQ  $0, h2
-
-// func update(state *[7]uint64, msg []byte)
+// func update(state *macState, msg []byte)
 TEXT ·update(SB), $0-32
 	MOVQ state+0(FP), DI
 	MOVQ msg_base+8(FP), SI
 	MOVQ msg_len+16(FP), R15
-
-	MOVQ 0(DI), R8   // h0
-	MOVQ 8(DI), R9   // h1
-	MOVQ 16(DI), R10 // h2
-	MOVQ 24(DI), R11 // r0
-	MOVQ 32(DI), R12 // r1
-
-	CMPQ R15, $16
+	MOVQ (DI), R8
+	MOVQ 8(DI), R9
+	MOVQ 16(DI), R10
+	MOVQ 24(DI), R11
+	MOVQ 32(DI), R12
+	CMPQ R15, $0x10
 	JB   bytes_between_0_and_15
 
 loop:
-	POLY1305_ADD(SI, R8, R9, R10)
+	ADDQ (SI), R8
+	ADCQ 8(SI), R9
+	ADCQ $0x01, R10
+	LEAQ 16(SI), SI
 
 multiply:
-	POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14)
-	SUBQ $16, R15
-	CMPQ R15, $16
-	JAE  loop
+	MOVQ  R11, AX
+	MULQ  R8
+	MOVQ  AX, BX
+	MOVQ  DX, CX
+	MOVQ  R11, AX
+	MULQ  R9
+	ADDQ  AX, CX
+	ADCQ  $0x00, DX
+	MOVQ  R11, R13
+	IMULQ R10, R13
+	ADDQ  DX, R13
+	MOVQ  R12, AX
+	MULQ  R8
+	ADDQ  AX, CX
+	ADCQ  $0x00, DX
+	MOVQ  DX, R8
+	MOVQ  R12, R14
+	IMULQ R10, R14
+	MOVQ  R12, AX
+	MULQ  R9
+	ADDQ  AX, R13
+	ADCQ  DX, R14
+	ADDQ  R8, R13
+	ADCQ  $0x00, R14
+	MOVQ  BX, R8
+	MOVQ  CX, R9
+	MOVQ  R13, R10
+	ANDQ  $0x03, R10
+	MOVQ  R13, BX
+	ANDQ  $-4, BX
+	ADDQ  BX, R8
+	ADCQ  R14, R9
+	ADCQ  $0x00, R10
+	SHRQ  $0x02, R14, R13
+	SHRQ  $0x02, R14
+	ADDQ  R13, R8
+	ADCQ  R14, R9
+	ADCQ  $0x00, R10
+	SUBQ  $0x10, R15
+	CMPQ  R15, $0x10
+	JAE   loop
 
 bytes_between_0_and_15:
 	TESTQ R15, R15
 	JZ    done
-	MOVQ  $1, BX
+	MOVQ  $0x00000001, BX
 	XORQ  CX, CX
 	XORQ  R13, R13
 	ADDQ  R15, SI
 
 flush_buffer:
-	SHLQ $8, BX, CX
-	SHLQ $8, BX
+	SHLQ $0x08, BX, CX
+	SHLQ $0x08, BX
 	MOVB -1(SI), R13
 	XORQ R13, BX
 	DECQ SI
 	DECQ R15
 	JNZ  flush_buffer
-
 	ADDQ BX, R8
 	ADCQ CX, R9
-	ADCQ $0, R10
-	MOVQ $16, R15
+	ADCQ $0x00, R10
+	MOVQ $0x00000010, R15
 	JMP  multiply
 
 done:
-	MOVQ R8, 0(DI)
+	MOVQ R8, (DI)
 	MOVQ R9, 8(DI)
 	MOVQ R10, 16(DI)
 	RET

+ 23 - 20
psiphon/common/crypto/internal/poly1305/sum_generic.go

@@ -7,7 +7,10 @@
 
 package poly1305
 
-import "encoding/binary"
+import (
+	"encoding/binary"
+	"math/bits"
+)
 
 // Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
 // for a 64 bytes message is approximately
@@ -114,13 +117,13 @@ type uint128 struct {
 }
 
 func mul64(a, b uint64) uint128 {
-	hi, lo := bitsMul64(a, b)
+	hi, lo := bits.Mul64(a, b)
 	return uint128{lo, hi}
 }
 
 func add128(a, b uint128) uint128 {
-	lo, c := bitsAdd64(a.lo, b.lo, 0)
-	hi, c := bitsAdd64(a.hi, b.hi, c)
+	lo, c := bits.Add64(a.lo, b.lo, 0)
+	hi, c := bits.Add64(a.hi, b.hi, c)
 	if c != 0 {
 		panic("poly1305: unexpected overflow")
 	}
@@ -155,8 +158,8 @@ func updateGeneric(state *macState, msg []byte) {
 		// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
 		// add 1 to the most significant (2¹²⁸) limb, h2.
 		if len(msg) >= TagSize {
-			h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
-			h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
+			h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
+			h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
 			h2 += c + 1
 
 			msg = msg[TagSize:]
@@ -165,8 +168,8 @@ func updateGeneric(state *macState, msg []byte) {
 			copy(buf[:], msg)
 			buf[len(msg)] = 1
 
-			h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
-			h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
+			h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
+			h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
 			h2 += c
 
 			msg = nil
@@ -219,9 +222,9 @@ func updateGeneric(state *macState, msg []byte) {
 		m3 := h2r1
 
 		t0 := m0.lo
-		t1, c := bitsAdd64(m1.lo, m0.hi, 0)
-		t2, c := bitsAdd64(m2.lo, m1.hi, c)
-		t3, _ := bitsAdd64(m3.lo, m2.hi, c)
+		t1, c := bits.Add64(m1.lo, m0.hi, 0)
+		t2, c := bits.Add64(m2.lo, m1.hi, c)
+		t3, _ := bits.Add64(m3.lo, m2.hi, c)
 
 		// Now we have the result as 4 64-bit limbs, and we need to reduce it
 		// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
@@ -243,14 +246,14 @@ func updateGeneric(state *macState, msg []byte) {
 
 		// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
 
-		h0, c = bitsAdd64(h0, cc.lo, 0)
-		h1, c = bitsAdd64(h1, cc.hi, c)
+		h0, c = bits.Add64(h0, cc.lo, 0)
+		h1, c = bits.Add64(h1, cc.hi, c)
 		h2 += c
 
 		cc = shiftRightBy2(cc)
 
-		h0, c = bitsAdd64(h0, cc.lo, 0)
-		h1, c = bitsAdd64(h1, cc.hi, c)
+		h0, c = bits.Add64(h0, cc.lo, 0)
+		h1, c = bits.Add64(h1, cc.hi, c)
 		h2 += c
 
 		// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
@@ -287,9 +290,9 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
 	// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
 	// result if the subtraction underflows, and t otherwise.
 
-	hMinusP0, b := bitsSub64(h0, p0, 0)
-	hMinusP1, b := bitsSub64(h1, p1, b)
-	_, b = bitsSub64(h2, p2, b)
+	hMinusP0, b := bits.Sub64(h0, p0, 0)
+	hMinusP1, b := bits.Sub64(h1, p1, b)
+	_, b = bits.Sub64(h2, p2, b)
 
 	// h = h if h < p else h - p
 	h0 = select64(b, h0, hMinusP0)
@@ -301,8 +304,8 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
 	//
 	// by just doing a wide addition with the 128 low bits of h and discarding
 	// the overflow.
-	h0, c := bitsAdd64(h0, s[0], 0)
-	h1, _ = bitsAdd64(h1, s[1], c)
+	h0, c := bits.Add64(h0, s[0], 0)
+	h1, _ = bits.Add64(h1, s[1], c)
 
 	binary.LittleEndian.PutUint64(out[0:8], h0)
 	binary.LittleEndian.PutUint64(out[8:16], h1)

+ 1 - 2
psiphon/common/crypto/internal/poly1305/sum_ppc64le.go → psiphon/common/crypto/internal/poly1305/sum_ppc64x.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build gc && !purego
-// +build gc,!purego
+//go:build gc && !purego && (ppc64 || ppc64le)
 
 package poly1305
 

+ 25 - 20
psiphon/common/crypto/internal/poly1305/sum_ppc64le.s → psiphon/common/crypto/internal/poly1305/sum_ppc64x.s

@@ -2,16 +2,25 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-//go:build gc && !purego
-// +build gc,!purego
+//go:build gc && !purego && (ppc64 || ppc64le)
 
 #include "textflag.h"
 
 // This was ported from the amd64 implementation.
 
+#ifdef GOARCH_ppc64le
+#define LE_MOVD MOVD
+#define LE_MOVWZ MOVWZ
+#define LE_MOVHZ MOVHZ
+#else
+#define LE_MOVD MOVDBR
+#define LE_MOVWZ MOVWBR
+#define LE_MOVHZ MOVHBR
+#endif
+
 #define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
-	MOVD (msg), t0;  \
-	MOVD 8(msg), t1; \
+	LE_MOVD (msg)( R0), t0; \
+	LE_MOVD (msg)(R24), t1; \
 	MOVD $1, t2;     \
 	ADDC t0, h0, h0; \
 	ADDE t1, h1, h1; \
@@ -20,15 +29,14 @@
 
 #define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
 	MULLD  r0, h0, t0;  \
-	MULLD  r0, h1, t4;  \
 	MULHDU r0, h0, t1;  \
+	MULLD  r0, h1, t4;  \
 	MULHDU r0, h1, t5;  \
 	ADDC   t4, t1, t1;  \
 	MULLD  r0, h2, t2;  \
-	ADDZE  t5;          \
 	MULHDU r1, h0, t4;  \
 	MULLD  r1, h0, h0;  \
-	ADD    t5, t2, t2;  \
+	ADDE   t5, t2, t2;  \
 	ADDC   h0, t1, t1;  \
 	MULLD  h2, r1, t3;  \
 	ADDZE  t4, h0;      \
@@ -38,13 +46,11 @@
 	ADDE   t5, t3, t3;  \
 	ADDC   h0, t2, t2;  \
 	MOVD   $-4, t4;     \
-	MOVD   t0, h0;      \
-	MOVD   t1, h1;      \
 	ADDZE  t3;          \
-	ANDCC  $3, t2, h2;  \
-	AND    t2, t4, t0;  \
+	RLDICL $0, t2, $62, h2; \
+	AND    t2, t4, h0;  \
 	ADDC   t0, h0, h0;  \
-	ADDE   t3, h1, h1;  \
+	ADDE   t3, t1, h1;  \
 	SLD    $62, t3, t4; \
 	SRD    $2, t2;      \
 	ADDZE  h2;          \
@@ -54,10 +60,6 @@
 	ADDE   t3, h1, h1;  \
 	ADDZE  h2
 
-DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
-DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
-GLOBL ·poly1305Mask<>(SB), RODATA, $16
-
 // func update(state *[7]uint64, msg []byte)
 TEXT ·update(SB), $0-32
 	MOVD state+0(FP), R3
@@ -70,12 +72,15 @@ TEXT ·update(SB), $0-32
 	MOVD 24(R3), R11 // r0
 	MOVD 32(R3), R12 // r1
 
+	MOVD $8, R24
+
 	CMP R5, $16
 	BLT bytes_between_0_and_15
 
 loop:
 	POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
 
+	PCALIGN $16
 multiply:
 	POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
 	ADD $-16, R5
@@ -97,7 +102,7 @@ flush_buffer:
 
 	// Greater than 8 -- load the rightmost remaining bytes in msg
 	// and put into R17 (h1)
-	MOVD (R4)(R21), R17
+	LE_MOVD (R4)(R21), R17
 	MOVD $16, R22
 
 	// Find the offset to those bytes
@@ -121,7 +126,7 @@ just1:
 	BLT less8
 
 	// Exactly 8
-	MOVD (R4), R16
+	LE_MOVD (R4), R16
 
 	CMP R17, $0
 
@@ -136,7 +141,7 @@ less8:
 	MOVD  $0, R22   // shift count
 	CMP   R5, $4
 	BLT   less4
-	MOVWZ (R4), R16
+	LE_MOVWZ (R4), R16
 	ADD   $4, R4
 	ADD   $-4, R5
 	MOVD  $32, R22
@@ -144,7 +149,7 @@ less8:
 less4:
 	CMP   R5, $2
 	BLT   less2
-	MOVHZ (R4), R21
+	LE_MOVHZ (R4), R21
 	SLD   R22, R21, R21
 	OR    R16, R21, R16
 	ADD   $16, R22

+ 2 - 2
psiphon/common/crypto/internal/testenv/exec.go

@@ -57,8 +57,8 @@ func CommandContext(t testing.TB, ctx context.Context, name string, args ...stri
 			// grace periods to clean up: one for the delay between the first
 			// termination signal being sent (via the Cancel callback when the Context
 			// expires) and the process being forcibly terminated (via the WaitDelay
-			// field), and a second one for the delay becween the process being
-			// terminated and and the test logging its output for debugging.
+			// field), and a second one for the delay between the process being
+			// terminated and the test logging its output for debugging.
 			//
 			// (We want to ensure that the test process itself has enough time to
 			// log the output before it is also terminated.)

+ 1 - 1
psiphon/common/crypto/nacl/secretbox/secretbox.go

@@ -32,7 +32,7 @@ chunk size.
 
 This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
 */
-package secretbox // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/nacl/secretbox"
+package secretbox
 
 import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"

+ 1 - 1
psiphon/common/crypto/ssh/agent/client.go

@@ -10,7 +10,7 @@
 // References:
 //
 //	[PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
-package agent // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/agent"
+package agent
 
 import (
 	"bytes"

+ 16 - 9
psiphon/common/crypto/ssh/agent/client_test.go

@@ -164,15 +164,23 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 	data := []byte("hello")
 	sig, err := agent.Sign(pubKey, data)
 	if err != nil {
-		t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
-	}
-
-	if err := pubKey.Verify(data, sig); err != nil {
-		t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
+		t.Logf("sign failed with key type %q", pubKey.Type())
+		// In integration tests ssh-rsa (SHA1 signatures) may be disabled for
+		// security reasons, we check SHA-2 variants later.
+		if pubKey.Type() != ssh.KeyAlgoRSA && pubKey.Type() != ssh.CertAlgoRSAv01 {
+			t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
+		}
+	} else {
+		if err := pubKey.Verify(data, sig); err != nil {
+			t.Logf("verify failed with key type %q", pubKey.Type())
+			if pubKey.Type() != ssh.KeyAlgoRSA {
+				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
+			}
+		}
 	}
 
 	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
-	if pubKey.Type() == "ssh-rsa" {
+	if pubKey.Type() == ssh.KeyAlgoRSA {
 		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
 			sig, err = agent.SignWithFlags(pubKey, data, flag)
 			if err != nil {
@@ -185,7 +193,6 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
 			}
 		}
-		sshFlagTest(0, ssh.KeyAlgoRSA)
 		sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
 		sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
 	}
@@ -244,7 +251,7 @@ func TestMalformedRequests(t *testing.T) {
 }
 
 func TestAgent(t *testing.T) {
-	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
+	for _, keyType := range []string{"rsa", "ecdsa", "ed25519"} {
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
 		testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
 	}
@@ -402,7 +409,7 @@ func testLockAgent(agent Agent, t *testing.T) {
 	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
 		t.Errorf("Add: %v", err)
 	}
-	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
+	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["ecdsa"], Comment: "comment ecdsa"}); err != nil {
 		t.Errorf("Add: %v", err)
 	}
 	if keys, err := agent.List(); err != nil {

+ 9 - 0
psiphon/common/crypto/ssh/agent/keyring.go

@@ -175,6 +175,15 @@ func (r *keyring) Add(key AddedKey) error {
 		p.expire = &t
 	}
 
+	// If we already have a Signer with the same public key, replace it with the
+	// new one.
+	for idx, k := range r.keys {
+		if bytes.Equal(k.signer.PublicKey().Marshal(), p.signer.PublicKey().Marshal()) {
+			r.keys[idx] = p
+			return nil
+		}
+	}
+
 	r.keys = append(r.keys, p)
 
 	return nil

+ 46 - 0
psiphon/common/crypto/ssh/agent/keyring_test.go

@@ -29,6 +29,10 @@ func validateListedKeys(t *testing.T, a Agent, expectedKeys []string) {
 		t.Fatalf("failed to list keys: %v", err)
 		return
 	}
+	if len(listedKeys) != len(expectedKeys) {
+		t.Fatalf("expeted %d key, got %d", len(expectedKeys), len(listedKeys))
+		return
+	}
 	actualKeys := make(map[string]bool)
 	for _, key := range listedKeys {
 		actualKeys[key.Comment] = true
@@ -74,3 +78,45 @@ func TestKeyringAddingAndRemoving(t *testing.T) {
 	}
 	validateListedKeys(t, k, []string{})
 }
+
+func TestAddDuplicateKey(t *testing.T) {
+	keyNames := []string{"rsa", "user"}
+
+	k := NewKeyring()
+	for _, keyName := range keyNames {
+		addTestKey(t, k, keyName)
+	}
+	validateListedKeys(t, k, keyNames)
+	// Add the keys again.
+	for _, keyName := range keyNames {
+		addTestKey(t, k, keyName)
+	}
+	validateListedKeys(t, k, keyNames)
+	// Add an existing key with an updated comment.
+	keyName := keyNames[0]
+	addedKey := AddedKey{
+		PrivateKey: testPrivateKeys[keyName],
+		Comment:    "comment updated",
+	}
+	err := k.Add(addedKey)
+	if err != nil {
+		t.Fatalf("failed to add key %q: %v", keyName, err)
+	}
+	// Check the that key is found and the comment was updated.
+	keys, err := k.List()
+	if err != nil {
+		t.Fatalf("failed to list keys: %v", err)
+	}
+	if len(keys) != len(keyNames) {
+		t.Fatalf("expected %d keys, got %d", len(keyNames), len(keys))
+	}
+	isFound := false
+	for _, key := range keys {
+		if key.Comment == addedKey.Comment {
+			isFound = true
+		}
+	}
+	if !isFound {
+		t.Fatal("key with the updated comment not found")
+	}
+}

+ 15 - 17
psiphon/common/crypto/ssh/certs_test.go

@@ -15,14 +15,12 @@ import (
 	"reflect"
 	"testing"
 	"time"
-)
 
-// Cert generated by ssh-keygen 6.0p1 Debian-4.
-// % ssh-keygen -s ca-key -I test user-key
-const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
+)
 
 func TestParseCert(t *testing.T) {
-	authKeyBytes := []byte(exampleSSHCert)
+	authKeyBytes := bytes.TrimSuffix(testdata.SSHCertificates["rsa"], []byte(" host.example.com\n"))
 
 	key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
 	if err != nil {
@@ -103,7 +101,7 @@ func TestParseCertWithOptions(t *testing.T) {
 }
 
 func TestValidateCert(t *testing.T) {
-	key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
+	key, _, _, _, err := ParseAuthorizedKey(testdata.SSHCertificates["rsa-user-testcertificate"])
 	if err != nil {
 		t.Fatalf("ParseAuthorizedKey: %v", err)
 	}
@@ -116,7 +114,7 @@ func TestValidateCert(t *testing.T) {
 		return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
 	}
 
-	if err := checker.CheckCert("user", validCert); err != nil {
+	if err := checker.CheckCert("testcertificate", validCert); err != nil {
 		t.Errorf("Unable to validate certificate: %v", err)
 	}
 	invalidCert := &Certificate{
@@ -125,7 +123,7 @@ func TestValidateCert(t *testing.T) {
 		ValidBefore:  CertTimeInfinity,
 		Signature:    &Signature{},
 	}
-	if err := checker.CheckCert("user", invalidCert); err == nil {
+	if err := checker.CheckCert("testcertificate", invalidCert); err == nil {
 		t.Error("Invalid cert signature passed validation")
 	}
 }
@@ -367,21 +365,21 @@ func TestCertTypes(t *testing.T) {
 
 func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
 	type testcase struct {
-		sigAlgo   string
-		algoritms []string
+		sigAlgo    string
+		algorithms []string
 	}
 	cases := []testcase{
 		{
-			sigAlgo:   KeyAlgoRSA,
-			algoritms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
+			sigAlgo:    KeyAlgoRSA,
+			algorithms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
 		},
 		{
-			sigAlgo:   KeyAlgoRSASHA256,
-			algoritms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
+			sigAlgo:    KeyAlgoRSASHA256,
+			algorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
 		},
 		{
-			sigAlgo:   KeyAlgoRSASHA512,
-			algoritms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
+			sigAlgo:    KeyAlgoRSASHA512,
+			algorithms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
 		},
 	}
 
@@ -393,7 +391,7 @@ func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
 
 	for _, c := range cases {
 		t.Run(c.sigAlgo, func(t *testing.T) {
-			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algoritms)
+			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algorithms)
 			if err != nil {
 				t.Fatalf("NewSignerWithAlgorithms error: %v", err)
 			}

+ 20 - 3
psiphon/common/crypto/ssh/client_auth.go

@@ -71,6 +71,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
 		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		if err != nil {
+			// On disconnect, return error immediately
+			if _, ok := err.(*disconnectMsg); ok {
+				return err
+			}
 			// We return the error later if there is no other method left to
 			// try.
 			ok = authFailure
@@ -404,10 +408,10 @@ func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, e
 		return false, err
 	}
 
-	return confirmKeyAck(key, algo, c)
+	return confirmKeyAck(key, c)
 }
 
-func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
+func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
 
 	for {
@@ -425,7 +429,15 @@ func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
 			if err := Unmarshal(packet, &msg); err != nil {
 				return false, err
 			}
-			if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
+			// According to RFC 4252 Section 7 the algorithm in
+			// SSH_MSG_USERAUTH_PK_OK should match that of the request but some
+			// servers send the key type instead. OpenSSH allows any algorithm
+			// that matches the public key, so we do the same.
+			// https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709
+			if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) {
+				return false, nil
+			}
+			if !bytes.Equal(msg.PubKey, pubKey) {
 				return false, nil
 			}
 			return true, nil
@@ -543,6 +555,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 	}
 
 	gotMsgExtInfo := false
+	gotUserAuthInfoRequest := false
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -573,6 +586,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 			if msg.PartialSuccess {
 				return authPartialSuccess, msg.Methods, nil
 			}
+			if !gotUserAuthInfoRequest {
+				return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
+			}
 			return authFailure, msg.Methods, nil
 		case msgUserAuthSuccess:
 			return authSuccess, nil, nil
@@ -584,6 +600,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		if err := Unmarshal(packet, &msg); err != nil {
 			return authFailure, nil, err
 		}
+		gotUserAuthInfoRequest = true
 
 		// Manually unpack the prompt/echo pairs.
 		rest := msg.Prompts

+ 108 - 8
psiphon/common/crypto/ssh/client_auth_test.go

@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
 	return err
 }
 
-// tryAuth runs a handshake with a given config against an SSH server
+// tryAuthWithGSSAPIWithMICConfig runs a handshake with a given config against an SSH server
 // with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors.
 func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error {
 	err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig)
@@ -641,17 +641,28 @@ func TestClientAuthMaxAuthTries(t *testing.T) {
 		defer c1.Close()
 		defer c2.Close()
 
-		go newServer(c1, serverConfig)
-		_, _, _, err = NewClientConn(c2, "", clientConfig)
-		if tries > 2 {
-			if err == nil {
+		errCh := make(chan error, 1)
+
+		go func() {
+			_, err := newServer(c1, serverConfig)
+			errCh <- err
+		}()
+		_, _, _, cliErr := NewClientConn(c2, "", clientConfig)
+		srvErr := <-errCh
+
+		if tries > serverConfig.MaxAuthTries {
+			if cliErr == nil {
 				t.Fatalf("client: got no error, want %s", expectedErr)
-			} else if err.Error() != expectedErr.Error() {
+			} else if cliErr.Error() != expectedErr.Error() {
 				t.Fatalf("client: got %s, want %s", err, expectedErr)
 			}
+			var authErr *ServerAuthError
+			if !errors.As(srvErr, &authErr) {
+				t.Errorf("expected ServerAuthError, got: %v", srvErr)
+			}
 		} else {
-			if err != nil {
-				t.Fatalf("client: got %s, want no error", err)
+			if cliErr != nil {
+				t.Fatalf("client: got %s, want no error", cliErr)
 			}
 		}
 	}
@@ -1282,3 +1293,92 @@ func TestCertAuthOpenSSHCompat(t *testing.T) {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
 }
+
+func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) {
+	const maxAuthTries = 2
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	// Start testserver
+	serverConfig := &ServerConfig{
+		MaxAuthTries: maxAuthTries,
+		KeyboardInteractiveCallback: func(c ConnMetadata,
+			client KeyboardInteractiveChallenge) (*Permissions, error) {
+			// Fail keyboard-interactive authentication early before
+			// any prompt is sent to client.
+			return nil, errors.New("keyboard-interactive auth failed")
+		},
+		PasswordCallback: func(c ConnMetadata,
+			pass []byte) (*Permissions, error) {
+			if string(pass) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("password auth failed")
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	serverDone := make(chan struct{})
+	go func() {
+		defer func() { serverDone <- struct{}{} }()
+		conn, chans, reqs, err := NewServerConn(c2, serverConfig)
+		if err != nil {
+			return
+		}
+		_ = conn.Close()
+
+		discarderDone := make(chan struct{})
+		go func() {
+			defer func() { discarderDone <- struct{}{} }()
+			DiscardRequests(reqs)
+		}()
+		for newChannel := range chans {
+			newChannel.Reject(Prohibited,
+				"testserver not accepting requests")
+		}
+
+		<-discarderDone
+	}()
+
+	// Connect to testserver, expect KeyboardInteractive() to be not called,
+	// PasswordCallback() to be called and connection to succeed.
+	passwordCallbackCalled := false
+	clientConfig := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			RetryableAuthMethod(KeyboardInteractive(func(name,
+				instruction string, questions []string,
+				echos []bool) ([]string, error) {
+				t.Errorf("unexpected call to KeyboardInteractive()")
+				return []string{clientPassword}, nil
+			}), maxAuthTries),
+			RetryableAuthMethod(PasswordCallback(func() (secret string,
+				err error) {
+				t.Logf("PasswordCallback()")
+				passwordCallbackCalled = true
+				return clientPassword, nil
+			}), maxAuthTries),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	conn, _, _, err := NewClientConn(c1, "", clientConfig)
+	if err != nil {
+		t.Errorf("unexpected NewClientConn() error: %v", err)
+	}
+	if conn != nil {
+		conn.Close()
+	}
+
+	// Wait for server to finish.
+	<-serverDone
+
+	if !passwordCallbackCalled {
+		t.Errorf("expected PasswordCallback() to be called")
+	}
+}

+ 1 - 1
psiphon/common/crypto/ssh/doc.go

@@ -20,4 +20,4 @@ References:
 This package does not fall under the stability promise of the Go language itself,
 so its API may be changed when pressing needs arise.
 */
-package ssh // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
+package ssh

+ 1 - 1
psiphon/common/crypto/ssh/example_test.go

@@ -384,7 +384,7 @@ func ExampleCertificate_SignCert() {
 	}
 	mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256})
 	if err != nil {
-		log.Fatal("unable to create signer with algoritms: ", err)
+		log.Fatal("unable to create signer with algorithms: ", err)
 	}
 	certificate := ssh.Certificate{
 		Key:      publicKey,

+ 49 - 12
psiphon/common/crypto/ssh/handshake.go

@@ -30,6 +30,11 @@ const debugHandshake = false
 // quickly.
 const chanSize = 16
 
+// maxPendingPackets sets the maximum number of packets to queue while waiting
+// for KEX to complete. This limits the total pending data to maxPendingPackets
+// * maxPacket bytes, which is ~16.8MB.
+const maxPendingPackets = 64
+
 // keyingTransport is a packet based transport that supports key
 // changes. It need not be thread-safe. It should pass through
 // msgNewKeys in both directions.
@@ -78,13 +83,22 @@ type handshakeTransport struct {
 	incoming  chan []byte
 	readError error
 
-	mu               sync.Mutex
-	writeError       error
-	sentInitPacket   []byte
-	sentInitMsg      *kexInitMsg
-	pendingPackets   [][]byte // Used when a key exchange is in progress.
+	mu sync.Mutex
+	// Condition for the above mutex. It is used to notify a completed key
+	// exchange or a write failure. Writes can wait for this condition while a
+	// key exchange is in progress.
+	writeCond      *sync.Cond
+	writeError     error
+	sentInitPacket []byte
+	sentInitMsg    *kexInitMsg
+	// Used to queue writes when a key exchange is in progress. The length is
+	// limited by pendingPacketsSize. Once full, writes will block until the key
+	// exchange is completed or an error occurs. If not empty, it is emptied
+	// all at once when the key exchange is completed in kexLoop.
+	pendingPackets   [][]byte
 	writePacketsLeft uint32
 	writeBytesLeft   int64
+	userAuthComplete bool // whether the user authentication phase is complete
 
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
@@ -138,6 +152,7 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
 
 		config: config,
 	}
+	t.writeCond = sync.NewCond(&t.mu)
 	t.resetReadThresholds()
 	t.resetWriteThresholds()
 
@@ -264,6 +279,7 @@ func (t *handshakeTransport) recordWriteError(err error) {
 	defer t.mu.Unlock()
 	if t.writeError == nil && err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 }
 
@@ -367,6 +383,8 @@ write:
 			}
 		}
 		t.pendingPackets = t.pendingPackets[:0]
+		// Unblock writePacket if waiting for KEX.
+		t.writeCond.Broadcast()
 		t.mu.Unlock()
 	}
 
@@ -941,26 +959,44 @@ func (t *handshakeTransport) sendKexInit() error {
 	return nil
 }
 
+var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
+
 func (t *handshakeTransport) writePacket(p []byte) error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
 	switch p[0] {
 	case msgKexInit:
 		return errors.New("ssh: only handshakeTransport can send kexInit")
 	case msgNewKeys:
 		return errors.New("ssh: only handshakeTransport can send newKeys")
+	case msgUserAuthBanner:
+		if t.userAuthComplete {
+			return errSendBannerPhase
+		}
+	case msgUserAuthSuccess:
+		t.userAuthComplete = true
 	}
 
-	t.mu.Lock()
-	defer t.mu.Unlock()
 	if t.writeError != nil {
 		return t.writeError
 	}
 
 	if t.sentInitMsg != nil {
-		// Copy the packet so the writer can reuse the buffer.
-		cp := make([]byte, len(p))
-		copy(cp, p)
-		t.pendingPackets = append(t.pendingPackets, cp)
-		return nil
+		if len(t.pendingPackets) < maxPendingPackets {
+			// Copy the packet so the writer can reuse the buffer.
+			cp := make([]byte, len(p))
+			copy(cp, p)
+			t.pendingPackets = append(t.pendingPackets, cp)
+			return nil
+		}
+		for t.sentInitMsg != nil {
+			// Block and wait for KEX to complete or an error.
+			t.writeCond.Wait()
+			if t.writeError != nil {
+				return t.writeError
+			}
+		}
 	}
 
 	if t.writeBytesLeft > 0 {
@@ -977,6 +1013,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 
 	if err := t.pushPacket(p); err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 
 	return nil

+ 220 - 0
psiphon/common/crypto/ssh/handshake_test.go

@@ -539,6 +539,226 @@ func TestDisconnect(t *testing.T) {
 	}
 }
 
+type mockKeyingTransport struct {
+	packetConn
+	kexInitAllowed chan struct{}
+	kexInitSent    chan struct{}
+}
+
+func (n *mockKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+	return nil
+}
+
+func (n *mockKeyingTransport) writePacket(packet []byte) error {
+	if packet[0] == msgKexInit {
+		<-n.kexInitAllowed
+		n.kexInitSent <- struct{}{}
+	}
+	return n.packetConn.writePacket(packet)
+}
+
+func (n *mockKeyingTransport) readPacket() ([]byte, error) {
+	return n.packetConn.readPacket()
+}
+
+func (n *mockKeyingTransport) setStrictMode() error { return nil }
+
+func (n *mockKeyingTransport) setInitialKEXDone() {}
+
+func TestHandshakePendingPacketsWait(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	server.mu.Lock()
+	if len(server.pendingPackets) != maxPendingPackets {
+		t.Errorf("unexpected pending packets size; got: %d, want: %d", len(server.pendingPackets), maxPendingPackets)
+	}
+	server.mu.Unlock()
+
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until KEX completes.
+		err := server.writePacket(p)
+		if err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Allow the client to reply to the KEX and so unblock the write goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	<-trC.kexInitSent
+	<-writeDone
+	// Close the client to unblock the read goroutine.
+	client.Close()
+	<-readDone
+	server.Close()
+}
+
+func TestHandshakePendingPacketsError(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until KEX completes.
+		err := server.writePacket(p)
+		if err != io.EOF {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Close the server to unblock the write after an error
+	server.Close()
+	<-writeDone
+	// Unblock the pending write and close the client to unblock the read
+	// goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	client.Close()
+	<-readDone
+}
+
 func TestHandshakeRekeyDefault(t *testing.T) {
 	clientConf := &ClientConfig{
 		Config: Config{

+ 51 - 1
psiphon/common/crypto/ssh/keys.go

@@ -488,7 +488,49 @@ func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
 	h := hash.New()
 	h.Write(data)
 	digest := h.Sum(nil)
-	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob)
+
+	// Signatures in PKCS1v15 must match the key's modulus in
+	// length. However with SSH, some signers provide RSA
+	// signatures which are missing the MSB 0's of the bignum
+	// represented. With ssh-rsa signatures, this is encouraged by
+	// the spec (even though e.g. OpenSSH will give the full
+	// length unconditionally). With rsa-sha2-* signatures, the
+	// verifier is allowed to support these, even though they are
+	// out of spec. See RFC 4253 Section 6.6 for ssh-rsa and RFC
+	// 8332 Section 3 for rsa-sha2-* details.
+	//
+	// In practice:
+	// * OpenSSH always allows "short" signatures:
+	//   https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L526
+	//   but always generates padded signatures:
+	//   https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L439
+	//
+	// * PuTTY versions 0.81 and earlier will generate short
+	//   signatures for all RSA signature variants. Note that
+	//   PuTTY is embedded in other software, such as WinSCP and
+	//   FileZilla. At the time of writing, a patch has been
+	//   applied to PuTTY to generate padded signatures for
+	//   rsa-sha2-*, but not yet released:
+	//   https://git.tartarus.org/?p=simon/putty.git;a=commitdiff;h=a5bcf3d384e1bf15a51a6923c3724cbbee022d8e
+	//
+	// * SSH.NET versions 2024.0.0 and earlier will generate short
+	//   signatures for all RSA signature variants, fixed in 2024.1.0:
+	//   https://github.com/sshnet/SSH.NET/releases/tag/2024.1.0
+	//
+	// As a result, we pad these up to the key size by inserting
+	// leading 0's.
+	//
+	// Note that support for short signatures with rsa-sha2-* may
+	// be removed in the future due to such signatures not being
+	// allowed by the spec.
+	blob := sig.Blob
+	keySize := (*rsa.PublicKey)(r).Size()
+	if len(blob) < keySize {
+		padded := make([]byte, keySize)
+		copy(padded[keySize-len(blob):], blob)
+		blob = padded
+	}
+	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, blob)
 }
 
 func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
@@ -904,6 +946,10 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
 	return errors.New("ssh: signature did not verify")
 }
 
+func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey {
+	return &k.PublicKey
+}
+
 type skEd25519PublicKey struct {
 	// application is a URL-like string, typically "ssh:" for SSH.
 	// see openssh/PROTOCOL.u2f for details.
@@ -1000,6 +1046,10 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
 	return nil
 }
 
+func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey {
+	return k.PublicKey
+}
+
 // NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
 // *ecdsa.PrivateKey or any other crypto.Signer and returns a
 // corresponding Signer instance. ECDSA keys must use P-256, P-384 or

+ 86 - 2
psiphon/common/crypto/ssh/keys_test.go

@@ -154,6 +154,44 @@ func TestKeySignWithAlgorithmVerify(t *testing.T) {
 	}
 }
 
+func TestKeySignWithShortSignature(t *testing.T) {
+	signer := testSigners["rsa"].(AlgorithmSigner)
+	pub := signer.PublicKey()
+	// Note: data obtained by empirically trying until a result
+	// starting with 0 appeared
+	tests := []struct {
+		algorithm string
+		data      []byte
+	}{
+		{
+			algorithm: KeyAlgoRSA,
+			data:      []byte("sign me92"),
+		},
+		{
+			algorithm: KeyAlgoRSASHA256,
+			data:      []byte("sign me294"),
+		},
+		{
+			algorithm: KeyAlgoRSASHA512,
+			data:      []byte("sign me60"),
+		},
+	}
+
+	for _, tt := range tests {
+		sig, err := signer.SignWithAlgorithm(rand.Reader, tt.data, tt.algorithm)
+		if err != nil {
+			t.Fatalf("Sign(%T): %v", signer, err)
+		}
+		if sig.Blob[0] != 0 {
+			t.Errorf("%s: Expected signature with a leading 0", tt.algorithm)
+		}
+		sig.Blob = sig.Blob[1:]
+		if err := pub.Verify(tt.data, sig); err != nil {
+			t.Errorf("publicKey.Verify(%s): %v", tt.algorithm, err)
+		}
+	}
+}
+
 func TestParseRSAPrivateKey(t *testing.T) {
 	key := testPrivateKeys["rsa"]
 
@@ -610,7 +648,7 @@ func TestKnownHostsParsing(t *testing.T) {
 func TestFingerprintLegacyMD5(t *testing.T) {
 	pub, _ := getTestKey()
 	fingerprint := FingerprintLegacyMD5(pub)
-	want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa
+	want := "b7:ef:d3:d5:89:29:52:96:9f:df:47:41:4d:15:37:f4" // ssh-keygen -lf -E md5 rsa
 	if fingerprint != want {
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 	}
@@ -619,7 +657,7 @@ func TestFingerprintLegacyMD5(t *testing.T) {
 func TestFingerprintSHA256(t *testing.T) {
 	pub, _ := getTestKey()
 	fingerprint := FingerprintSHA256(pub)
-	want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa
+	want := "SHA256:fi5+D7UmDZDE9Q2sAVvvlpcQSIakN4DERdINgXd2AnE" // ssh-keygen -lf rsa
 	if fingerprint != want {
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 	}
@@ -726,3 +764,49 @@ func TestNewSignerWithAlgos(t *testing.T) {
 		t.Error("signer with algos created with restricted algorithms")
 	}
 }
+
+func TestCryptoPublicKey(t *testing.T) {
+	for _, priv := range testSigners {
+		p1 := priv.PublicKey()
+		key, ok := p1.(CryptoPublicKey)
+		if !ok {
+			continue
+		}
+		p2, err := NewPublicKey(key.CryptoPublicKey())
+		if err != nil {
+			t.Fatalf("NewPublicKey(CryptoPublicKey) failed for %s, got: %v", p1.Type(), err)
+		}
+		if !reflect.DeepEqual(p1, p2) {
+			t.Errorf("got %#v in NewPublicKey, want %#v", p2, p1)
+		}
+	}
+	for _, d := range testdata.SKData {
+		p1, _, _, _, err := ParseAuthorizedKey(d.PubKey)
+		if err != nil {
+			t.Fatalf("parseAuthorizedKey returned error: %v", err)
+		}
+		k1, ok := p1.(CryptoPublicKey)
+		if !ok {
+			t.Fatalf("%T does not implement CryptoPublicKey", p1)
+		}
+
+		var p2 PublicKey
+		switch pub := k1.CryptoPublicKey().(type) {
+		case *ecdsa.PublicKey:
+			p2 = &skECDSAPublicKey{
+				application: "ssh:",
+				PublicKey:   *pub,
+			}
+		case ed25519.PublicKey:
+			p2 = &skEd25519PublicKey{
+				application: "ssh:",
+				PublicKey:   pub,
+			}
+		default:
+			t.Fatalf("unexpected type %T from CryptoPublicKey()", pub)
+		}
+		if !reflect.DeepEqual(p1, p2) {
+			t.Errorf("got %#v, want %#v", p2, p1)
+		}
+	}
+}

+ 2 - 0
psiphon/common/crypto/ssh/messages.go

@@ -818,6 +818,8 @@ func decode(packet []byte) (interface{}, error) {
 		return new(userAuthSuccessMsg), nil
 	case msgUserAuthFailure:
 		msg = new(userAuthFailureMsg)
+	case msgUserAuthBanner:
+		msg = new(userAuthBannerMsg)
 	case msgUserAuthPubKeyOk:
 		msg = new(userAuthPubKeyOkMsg)
 	case msgGlobalRequest:

+ 56 - 0
psiphon/common/crypto/ssh/messages_test.go

@@ -206,6 +206,62 @@ func TestMarshalMultiTag(t *testing.T) {
 	}
 }
 
+func TestDecode(t *testing.T) {
+	rnd := rand.New(rand.NewSource(0))
+	kexInit := new(kexInitMsg).Generate(rnd, 10).Interface()
+	kexDHInit := new(kexDHInitMsg).Generate(rnd, 10).Interface()
+	kexDHReply := new(kexDHReplyMsg)
+	kexDHReply.Y = randomInt(rnd)
+	// Note: userAuthSuccessMsg can't be tested directly since it
+	// doesn't have a field for sshtype. So it's tested separately
+	// at the end.
+	decodeMessageTypes := []interface{}{
+		new(disconnectMsg),
+		new(serviceRequestMsg),
+		new(serviceAcceptMsg),
+		new(extInfoMsg),
+		kexInit,
+		kexDHInit,
+		kexDHReply,
+		new(userAuthRequestMsg),
+		new(userAuthFailureMsg),
+		new(userAuthBannerMsg),
+		new(userAuthPubKeyOkMsg),
+		new(globalRequestMsg),
+		new(globalRequestSuccessMsg),
+		new(globalRequestFailureMsg),
+		new(channelOpenMsg),
+		new(channelDataMsg),
+		new(channelOpenConfirmMsg),
+		new(channelOpenFailureMsg),
+		new(windowAdjustMsg),
+		new(channelEOFMsg),
+		new(channelCloseMsg),
+		new(channelRequestMsg),
+		new(channelRequestSuccessMsg),
+		new(channelRequestFailureMsg),
+		new(userAuthGSSAPIToken),
+		new(userAuthGSSAPIMIC),
+		new(userAuthGSSAPIErrTok),
+		new(userAuthGSSAPIError),
+	}
+	for _, msg := range decodeMessageTypes {
+		decoded, err := decode(Marshal(msg))
+		if err != nil {
+			t.Errorf("error decoding %T", msg)
+		} else if reflect.TypeOf(msg) != reflect.TypeOf(decoded) {
+			t.Errorf("error decoding %T, unexpected %T", msg, decoded)
+		}
+	}
+
+	userAuthSuccess, err := decode([]byte{msgUserAuthSuccess})
+	if err != nil {
+		t.Errorf("error decoding userAuthSuccessMsg")
+	} else if reflect.TypeOf(userAuthSuccess) != reflect.TypeOf((*userAuthSuccessMsg)(nil)) {
+		t.Errorf("error decoding userAuthSuccessMsg, unexpected %T", userAuthSuccess)
+	}
+}
+
 func randomBytes(out []byte, rand *rand.Rand) {
 	for i := 0; i < len(out); i++ {
 		out[i] = byte(rand.Int31())

+ 196 - 65
psiphon/common/crypto/ssh/server.go

@@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
 	Server GSSAPIServer
 }
 
+// SendAuthBanner implements [ServerPreAuthConn].
+func (s *connection) SendAuthBanner(msg string) error {
+	return s.transport.writePacket(Marshal(&userAuthBannerMsg{
+		Message: msg,
+	}))
+}
+
+func (*connection) unexportedMethodForFutureProofing() {}
+
+// ServerPreAuthConn is the interface available on an incoming server
+// connection before authentication has completed.
+type ServerPreAuthConn interface {
+	unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB
+
+	ConnMetadata
+
+	// SendAuthBanner sends a banner message to the client.
+	// It returns an error once the authentication phase has ended.
+	SendAuthBanner(string) error
+}
+
 // ServerConfig holds server specific configuration data.
 type ServerConfig struct {
 	// Config contains configuration shared between client and server.
@@ -118,6 +139,12 @@ type ServerConfig struct {
 	// attempts.
 	AuthLogCallback func(conn ConnMetadata, method string, err error)
 
+	// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
+	// before any authentication has started. The provided ServerPreAuthConn
+	// can be used at any time before authentication is complete, including
+	// after this callback has returned.
+	PreAuthConnCallback func(ServerPreAuthConn)
+
 	// ServerVersion is the version identification string to announce in
 	// the public handshake.
 	// If empty, a reasonable default is used.
@@ -149,7 +176,7 @@ func (s *ServerConfig) AddHostKey(key Signer) {
 }
 
 // cachedPubKey contains the results of querying whether a public key is
-// acceptable for a user.
+// acceptable for a user. This is a FIFO cache.
 type cachedPubKey struct {
 	user       string
 	pubKeyData []byte
@@ -157,7 +184,13 @@ type cachedPubKey struct {
 	perms      *Permissions
 }
 
-const maxCachedPubKeys = 16
+// maxCachedPubKeys is the number of cache entries we store.
+//
+// Due to consistent misuse of the PublicKeyCallback API, we have reduced this
+// to 1, such that the only key in the cache is the most recently seen one. This
+// forces the behavior that the last call to PublicKeyCallback will always be
+// with the key that is used for authentication.
+const maxCachedPubKeys = 1
 
 // pubKeyCache caches tests for public keys.  Since SSH clients
 // will query whether a public key is acceptable before attempting to
@@ -179,9 +212,10 @@ func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
 
 // add adds the given tuple to the cache.
 func (c *pubKeyCache) add(candidate cachedPubKey) {
-	if len(c.keys) < maxCachedPubKeys {
-		c.keys = append(c.keys, candidate)
+	if len(c.keys) >= maxCachedPubKeys {
+		c.keys = c.keys[1:]
 	}
+	c.keys = append(c.keys, candidate)
 }
 
 // ServerConn is an authenticated SSH connection, as seen from the
@@ -426,6 +460,35 @@ func (l ServerAuthError) Error() string {
 	return "[" + strings.Join(errs, ", ") + "]"
 }
 
+// ServerAuthCallbacks defines server-side authentication callbacks.
+type ServerAuthCallbacks struct {
+	// PasswordCallback behaves like [ServerConfig.PasswordCallback].
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
+
+	// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
+
+	// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
+	GSSAPIWithMICConfig *GSSAPIWithMICConfig
+}
+
+// PartialSuccessError can be returned by any of the [ServerConfig]
+// authentication callbacks to indicate to the client that authentication has
+// partially succeeded, but further steps are required.
+type PartialSuccessError struct {
+	// Next defines the authentication callbacks to apply to further steps. The
+	// available methods communicated to the client are based on the non-nil
+	// ServerAuthCallbacks fields.
+	Next ServerAuthCallbacks
+}
+
+func (p *PartialSuccessError) Error() string {
+	return "ssh: authenticated with partial success"
+}
+
 // ErrNoAuth is the error value returned if no
 // authentication method has been passed yet. This happens as a normal
 // part of the authentication loop, since the client first tries
@@ -433,14 +496,46 @@ func (l ServerAuthError) Error() string {
 // It is returned in ServerAuthError.Errors from NewServerConn.
 var ErrNoAuth = errors.New("ssh: no auth passed yet")
 
+// BannerError is an error that can be returned by authentication handlers in
+// ServerConfig to send a banner message to the client.
+type BannerError struct {
+	Err     error
+	Message string
+}
+
+func (b *BannerError) Unwrap() error {
+	return b.Err
+}
+
+func (b *BannerError) Error() string {
+	if b.Err == nil {
+		return b.Message
+	}
+	return b.Err.Error()
+}
+
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
+	if config.PreAuthConnCallback != nil {
+		config.PreAuthConnCallback(s)
+	}
+
 	sessionID := s.transport.getSessionID()
 	var cache pubKeyCache
 	var perms *Permissions
 
 	authFailures := 0
+	noneAuthCount := 0
 	var authErrs []error
-	var displayedBanner bool
+	var calledBannerCallback bool
+	partialSuccessReturned := false
+	// Set the initial authentication callbacks from the config. They can be
+	// changed if a PartialSuccessError is returned.
+	authConfig := ServerAuthCallbacks{
+		PasswordCallback:            config.PasswordCallback,
+		PublicKeyCallback:           config.PublicKeyCallback,
+		KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
+		GSSAPIWithMICConfig:         config.GSSAPIWithMICConfig,
+	}
 
 userAuthLoop:
 	for {
@@ -453,8 +548,8 @@ userAuthLoop:
 			if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
 				return nil, err
 			}
-
-			return nil, discMsg
+			authErrs = append(authErrs, discMsg)
+			return nil, &ServerAuthError{Errors: authErrs}
 		}
 
 		var userAuthReq userAuthRequestMsg
@@ -471,16 +566,17 @@ userAuthLoop:
 			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
 		}
 
+		if s.user != userAuthReq.User && partialSuccessReturned {
+			return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
+				s.user, userAuthReq.User)
+		}
+
 		s.user = userAuthReq.User
 
-		if !displayedBanner && config.BannerCallback != nil {
-			displayedBanner = true
-			msg := config.BannerCallback(s)
-			if msg != "" {
-				bannerMsg := &userAuthBannerMsg{
-					Message: msg,
-				}
-				if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+		if !calledBannerCallback && config.BannerCallback != nil {
+			calledBannerCallback = true
+			if msg := config.BannerCallback(s); msg != "" {
+				if err := s.SendAuthBanner(msg); err != nil {
 					return nil, err
 				}
 			}
@@ -491,20 +587,18 @@ userAuthLoop:
 
 		switch userAuthReq.Method {
 		case "none":
-			if config.NoClientAuth {
+			noneAuthCount++
+			// We don't allow none authentication after a partial success
+			// response.
+			if config.NoClientAuth && !partialSuccessReturned {
 				if config.NoClientAuthCallback != nil {
 					perms, authErr = config.NoClientAuthCallback(s)
 				} else {
 					authErr = nil
 				}
 			}
-
-			// allow initial attempt of 'none' without penalty
-			if authFailures == 0 {
-				authFailures--
-			}
 		case "password":
-			if config.PasswordCallback == nil {
+			if authConfig.PasswordCallback == nil {
 				authErr = errors.New("ssh: password auth not configured")
 				break
 			}
@@ -518,17 +612,17 @@ userAuthLoop:
 				return nil, parseError(msgUserAuthRequest)
 			}
 
-			perms, authErr = config.PasswordCallback(s, password)
+			perms, authErr = authConfig.PasswordCallback(s, password)
 		case "keyboard-interactive":
-			if config.KeyboardInteractiveCallback == nil {
+			if authConfig.KeyboardInteractiveCallback == nil {
 				authErr = errors.New("ssh: keyboard-interactive auth not configured")
 				break
 			}
 
 			prompter := &sshClientKeyboardInteractive{s}
-			perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
+			perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
 		case "publickey":
-			if config.PublicKeyCallback == nil {
+			if authConfig.PublicKeyCallback == nil {
 				authErr = errors.New("ssh: publickey auth not configured")
 				break
 			}
@@ -562,11 +656,18 @@ userAuthLoop:
 			if !ok {
 				candidate.user = s.user
 				candidate.pubKeyData = pubKeyData
-				candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
-				if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
-					candidate.result = checkSourceAddress(
+				candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+
+				if (candidate.result == nil || isPartialSuccessError) &&
+					candidate.perms != nil &&
+					candidate.perms.CriticalOptions != nil &&
+					candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+					if err := checkSourceAddress(
 						s.RemoteAddr(),
-						candidate.perms.CriticalOptions[sourceAddressCriticalOption])
+						candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
+						candidate.result = err
+					}
 				}
 				cache.add(candidate)
 			}
@@ -578,8 +679,8 @@ userAuthLoop:
 				if len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 				}
-
-				if candidate.result == nil {
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+				if candidate.result == nil || isPartialSuccessError {
 					okMsg := userAuthPubKeyOkMsg{
 						Algo:   algo,
 						PubKey: pubKeyData,
@@ -629,11 +730,11 @@ userAuthLoop:
 				perms = candidate.perms
 			}
 		case "gssapi-with-mic":
-			if config.GSSAPIWithMICConfig == nil {
+			if authConfig.GSSAPIWithMICConfig == nil {
 				authErr = errors.New("ssh: gssapi-with-mic auth not configured")
 				break
 			}
-			gssapiConfig := config.GSSAPIWithMICConfig
+			gssapiConfig := authConfig.GSSAPIWithMICConfig
 			userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
 			if err != nil {
 				return nil, parseError(msgUserAuthRequest)
@@ -685,53 +786,83 @@ userAuthLoop:
 			config.AuthLogCallback(s, userAuthReq.Method, authErr)
 		}
 
+		var bannerErr *BannerError
+		if errors.As(authErr, &bannerErr) {
+			if bannerErr.Message != "" {
+				if err := s.SendAuthBanner(bannerErr.Message); err != nil {
+					return nil, err
+				}
+			}
+		}
+
 		if authErr == nil {
 			break userAuthLoop
 		}
 
-		authFailures++
-		if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
-			// If we have hit the max attempts, don't bother sending the
-			// final SSH_MSG_USERAUTH_FAILURE message, since there are
-			// no more authentication methods which can be attempted,
-			// and this message may cause the client to re-attempt
-			// authentication while we send the disconnect message.
-			// Continue, and trigger the disconnect at the start of
-			// the loop.
-			//
-			// The SSH specification is somewhat confusing about this,
-			// RFC 4252 Section 5.1 requires each authentication failure
-			// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
-			// message, but Section 4 says the server should disconnect
-			// after some number of attempts, but it isn't explicit which
-			// message should take precedence (i.e. should there be a failure
-			// message than a disconnect message, or if we are going to
-			// disconnect, should we only send that message.)
-			//
-			// Either way, OpenSSH disconnects immediately after the last
-			// failed authnetication attempt, and given they are typically
-			// considered the golden implementation it seems reasonable
-			// to match that behavior.
-			continue
+		var failureMsg userAuthFailureMsg
+
+		if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
+			// After a partial success error we don't allow changing the user
+			// name and execute the NoClientAuthCallback.
+			partialSuccessReturned = true
+
+			// In case a partial success is returned, the server may send
+			// a new set of authentication methods.
+			authConfig = partialSuccess.Next
+
+			// Reset pubkey cache, as the new PublicKeyCallback might
+			// accept a different set of public keys.
+			cache = pubKeyCache{}
+
+			// Send back a partial success message to the user.
+			failureMsg.PartialSuccess = true
+		} else {
+			// Allow initial attempt of 'none' without penalty.
+			if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
+				authFailures++
+			}
+			if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
+				// If we have hit the max attempts, don't bother sending the
+				// final SSH_MSG_USERAUTH_FAILURE message, since there are
+				// no more authentication methods which can be attempted,
+				// and this message may cause the client to re-attempt
+				// authentication while we send the disconnect message.
+				// Continue, and trigger the disconnect at the start of
+				// the loop.
+				//
+				// The SSH specification is somewhat confusing about this,
+				// RFC 4252 Section 5.1 requires each authentication failure
+				// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
+				// message, but Section 4 says the server should disconnect
+				// after some number of attempts, but it isn't explicit which
+				// message should take precedence (i.e. should there be a failure
+				// message than a disconnect message, or if we are going to
+				// disconnect, should we only send that message.)
+				//
+				// Either way, OpenSSH disconnects immediately after the last
+				// failed authentication attempt, and given they are typically
+				// considered the golden implementation it seems reasonable
+				// to match that behavior.
+				continue
+			}
 		}
 
-		var failureMsg userAuthFailureMsg
-		if config.PasswordCallback != nil {
+		if authConfig.PasswordCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "password")
 		}
-		if config.PublicKeyCallback != nil {
+		if authConfig.PublicKeyCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "publickey")
 		}
-		if config.KeyboardInteractiveCallback != nil {
+		if authConfig.KeyboardInteractiveCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
 		}
-		if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
-			config.GSSAPIWithMICConfig.AllowLogin != nil {
+		if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
+			authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
 		}
 
 		if len(failureMsg.Methods) == 0 {
-			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+			return nil, errors.New("ssh: no authentication methods available")
 		}
 
 		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {

+ 412 - 0
psiphon/common/crypto/ssh/server_multi_auth_test.go

@@ -0,0 +1,412 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+	"testing"
+)
+
+func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	var serverAuthErrors []error
+
+	serverConfig.AddHostKey(testSigners["rsa"])
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err == nil {
+		c.Close()
+	}
+	return serverAuthErrors, err
+}
+
+func TestMultiStepAuth(t *testing.T) {
+	// This user can login with password, public key or public key + password.
+	username := "testuser"
+	// This user can login with public key + password only.
+	usernameSecondFactor := "testuser_second_factor"
+	errPwdAuthFailed := errors.New("password auth failed")
+	errWrongSequence := errors.New("wrong sequence")
+
+	serverConfig := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == usernameSecondFactor {
+				return nil, errWrongSequence
+			}
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errPwdAuthFailed
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+				if conn.User() == usernameSecondFactor {
+					return nil, &PartialSuccessError{
+						Next: ServerAuthCallbacks{
+							PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+								if string(password) == clientPassword {
+									return nil, nil
+								}
+								return nil, errPwdAuthFailed
+							},
+						},
+					}
+				}
+				return nil, nil
+			}
+			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: usernameSecondFactor,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
+	}
+	// Now test a wrong sequence.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong sequence must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	// - partial success
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+	if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
+	}
+	// Now test using a correct sequence but a wrong password before the right
+	// one.
+	n := 0
+	passwords := []string{"WRONG", "WRONG", clientPassword}
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		RetryableAuthMethod(PasswordCallback(func() (string, error) {
+			p := passwords[n]
+			n++
+			return p, nil
+		}), 3),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - wrong password
+	// - wrong password
+	// - nil
+	if len(serverAuthErrors) != 5 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	if serverAuthErrors[3] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	// Only password authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with password only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+
+	// Only public key authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with public key only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// Public key and wrong password.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		Password("WRONG"),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong password after public key must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - password auth failed
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+
+	// Public key, public key again and then correct password. Public key
+	// authentication is attempted only once because the partial success error
+	// returns only "password" as the allowed authentication method.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		PublicKeys(testSigners["rsa"]),
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// The unrestricted username can do anything
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+}
+
+func TestDynamicAuthCallbacks(t *testing.T) {
+	user1 := "user1"
+	user2 := "user2"
+	errInvalidCredentials := errors.New("invalid credentials")
+
+	serverConfig := &ServerConfig{
+		NoClientAuth: true,
+		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
+			switch conn.User() {
+			case user1:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+							if conn.User() == user1 && string(password) == clientPassword {
+								return nil, nil
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			case user2:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+							if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+								if conn.User() == user2 {
+									return nil, nil
+								}
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			default:
+				return nil, errInvalidCredentials
+			}
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// user1 cannot login with public key
+	clientConfig = &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user1 login with public key must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	// user2 cannot login with password
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user2 login with password must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+}

+ 338 - 0
psiphon/common/crypto/ssh/server_test.go

@@ -5,8 +5,13 @@
 package ssh
 
 import (
+	"bytes"
+	"errors"
+	"fmt"
 	"io"
 	"net"
+	"reflect"
+	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -62,6 +67,133 @@ func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
 	}
 }
 
+func TestMaxAuthTriesNoneMethod(t *testing.T) {
+	username := "testuser"
+	serverConfig := &ServerConfig{
+		MaxAuthTries: 2,
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("invalid credentials")
+		},
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	var serverAuthErrors []error
+
+	serverConfig.AddHostKey(testSigners["rsa"])
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+	go newServer(c1, serverConfig)
+
+	clientConfig := ClientConfig{
+		User:            username,
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConfig.SetDefaults()
+	// Our client will send 'none' auth only once, so we need to send the
+	// requests manually.
+	c := &connection{
+		sshConn: sshConn{
+			conn:          c2,
+			user:          username,
+			clientVersion: []byte(packageVersion),
+		},
+	}
+	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
+	if err != nil {
+		t.Fatalf("unable to exchange version: %v", err)
+	}
+	c.transport = newClientTransport(
+		newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */),
+		c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr())
+	if err := c.transport.waitSession(); err != nil {
+		t.Fatalf("unable to wait session: %v", err)
+	}
+	c.sessionID = c.transport.getSessionID()
+	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
+		t.Fatalf("unable to send ssh-userauth message: %v", err)
+	}
+	packet, err := c.transport.readPacket()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(packet) > 0 && packet[0] == msgExtInfo {
+		packet, err = c.transport.readPacket()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	var serviceAccept serviceAcceptMsg
+	if err := Unmarshal(packet, &serviceAccept); err != nil {
+		t.Fatal(err)
+	}
+	for i := 0; i <= serverConfig.MaxAuthTries; i++ {
+		auth := new(noneAuth)
+		_, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil)
+		if i < serverConfig.MaxAuthTries {
+			if err != nil {
+				t.Fatal(err)
+			}
+			continue
+		}
+		if err == nil {
+			t.Fatal("client: got no error")
+		} else if !strings.Contains(err.Error(), "too many authentication failures") {
+			t.Fatalf("client: got unexpected error: %v", err)
+		}
+	}
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	for _, err := range serverAuthErrors {
+		if !errors.Is(err, ErrNoAuth) {
+			t.Errorf("go error: %v; want: %v", err, ErrNoAuth)
+		}
+	}
+}
+
+func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
+	username := "testuser"
+	serverConfig := &ServerConfig{
+		MaxAuthTries: 1,
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("invalid credentials")
+		},
+	}
+	clientConfig := &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
+		t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
+	}
+	if serverAuthErrors[1] != nil {
+		t.Errorf("unexpected error: %v", serverAuthErrors[1])
+	}
+}
+
 func TestNewServerConnValidationErrors(t *testing.T) {
 	serverConf := &ServerConfig{
 		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
@@ -96,6 +228,212 @@ func TestNewServerConnValidationErrors(t *testing.T) {
 	}
 }
 
+func TestBannerError(t *testing.T) {
+	serverConfig := &ServerConfig{
+		BannerCallback: func(ConnMetadata) string {
+			return "banner from BannerCallback"
+		},
+		NoClientAuth: true,
+		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+			err := &BannerError{
+				Err:     errors.New("error from NoClientAuthCallback"),
+				Message: "banner from NoClientAuthCallback",
+			}
+			return nil, fmt.Errorf("wrapped: %w", err)
+		},
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			return nil, &BannerError{
+				Err:     errors.New("error from PublicKeyCallback"),
+				Message: "banner from PublicKeyCallback",
+			}
+		},
+		KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
+			return nil, &BannerError{
+				Err:     nil, // make sure that a nil inner error is allowed
+				Message: "banner from KeyboardInteractiveCallback",
+			}
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	var banners []string
+	clientConfig := &ClientConfig{
+		User: "test",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
+				return []string{"letmein"}, nil
+			}),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(msg string) error {
+			banners = append(banners, msg)
+			return nil
+		},
+	}
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	defer c.Close()
+
+	wantBanners := []string{
+		"banner from BannerCallback",
+		"banner from NoClientAuthCallback",
+		"banner from PublicKeyCallback",
+		"banner from KeyboardInteractiveCallback",
+	}
+	if !reflect.DeepEqual(banners, wantBanners) {
+		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
+	}
+}
+
+func TestPublicKeyCallbackLastSeen(t *testing.T) {
+	var lastSeenKey PublicKey
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	serverConf := &ServerConfig{
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			lastSeenKey = key
+			fmt.Printf("seen %#v\n", key)
+			if _, ok := key.(*dsaPublicKey); !ok {
+				return nil, errors.New("nope")
+			}
+			return nil, nil
+		},
+	}
+	serverConf.AddHostKey(testSigners["ecdsap256"])
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		NewServerConn(c1, serverConf)
+	}()
+
+	clientConf := ClientConfig{
+		User: "user",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, _, _, err = NewClientConn(c2, "", &clientConf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	<-done
+
+	expectedPublicKey := testSigners["dsa"].PublicKey().Marshal()
+	lastSeenMarshalled := lastSeenKey.Marshal()
+	if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) {
+		t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey())
+	}
+}
+
+func TestPreAuthConnAndBanners(t *testing.T) {
+	testDone := make(chan struct{})
+	defer close(testDone)
+
+	authConnc := make(chan ServerPreAuthConn, 1)
+	serverConfig := &ServerConfig{
+		PreAuthConnCallback: func(c ServerPreAuthConn) {
+			t.Logf("got ServerPreAuthConn: %v", c)
+			authConnc <- c // for use later in the test
+			for _, s := range []string{"hello1", "hello2"} {
+				if err := c.SendAuthBanner(s); err != nil {
+					t.Errorf("failed to send banner %q: %v", s, err)
+				}
+			}
+			// Now start a goroutine to spam SendAuthBanner in hopes
+			// of hitting a race.
+			go func() {
+				for {
+					select {
+					case <-testDone:
+						return
+					default:
+						if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase {
+							t.Errorf("unexpected error from SendAuthBanner: %v", err)
+						}
+						time.Sleep(5 * time.Millisecond)
+					}
+				}
+			}()
+		},
+		NoClientAuth: true,
+		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+			t.Logf("got NoClientAuthCallback")
+			return &Permissions{}, nil
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	var banners []string
+	clientConfig := &ClientConfig{
+		User:            "test",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(msg string) error {
+			if msg != "attempted-race" {
+				banners = append(banners, msg)
+			}
+			return nil
+		},
+	}
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	defer c.Close()
+
+	wantBanners := []string{
+		"hello1",
+		"hello2",
+	}
+	if !reflect.DeepEqual(banners, wantBanners) {
+		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
+	}
+
+	// Now that we're authenticated, verify that use of SendBanner
+	// is an error.
+	var bc ServerPreAuthConn
+	select {
+	case bc = <-authConnc:
+	default:
+		t.Fatal("expected ServerPreAuthConn")
+	}
+	if err := bc.SendAuthBanner("wrong-phase"); err == nil {
+		t.Error("unexpected success of SendAuthBanner after authentication")
+	} else if err != errSendBannerPhase {
+		t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
+	}
+}
+
 type markerConn struct {
 	closed uint32
 	used   uint32

+ 1 - 1
psiphon/common/crypto/ssh/tcpip.go

@@ -472,7 +472,7 @@ func (c *Client) dial(channelType string, laddr string, lport int, raddr string,
 		return nil, err
 	}
 	go DiscardRequests(in)
-	return ch, err
+	return ch, nil
 }
 
 type tcpChan struct {

+ 4 - 3
psiphon/common/crypto/ssh/test/agent_unix_test.go

@@ -20,20 +20,21 @@ func TestAgentForward(t *testing.T) {
 	defer conn.Close()
 
 	keyring := agent.NewKeyring()
-	if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil {
+	if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["ecdsa"]}); err != nil {
 		t.Fatalf("Error adding key: %s", err)
 	}
 	if err := keyring.Add(agent.AddedKey{
-		PrivateKey:       testPrivateKeys["dsa"],
+		PrivateKey:       testPrivateKeys["ecdsa"],
 		ConfirmBeforeUse: true,
 		LifetimeSecs:     3600,
 	}); err != nil {
 		t.Fatalf("Error adding key with constraints: %s", err)
 	}
-	pub := testPublicKeys["dsa"]
+	pub := testPublicKeys["ecdsa"]
 
 	sess, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("NewSession: %v", err)
 	}
 	if err := agent.RequestAgentForwarding(sess); err != nil {

+ 1 - 1
psiphon/common/crypto/ssh/test/cert_test.go

@@ -19,7 +19,7 @@ func TestCertLogin(t *testing.T) {
 	s := newServer(t)
 
 	// Use a key different from the default.
-	clientKey := testSigners["dsa"]
+	clientKey := testSigners["ed25519"]
 	caAuthKey := testSigners["ecdsa"]
 	cert := &ssh.Certificate{
 		Key:             clientKey.PublicKey(),

+ 1 - 0
psiphon/common/crypto/ssh/test/dial_unix_test.go

@@ -53,6 +53,7 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) {
 	// on the opened connection.
 	cancel()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("Dial: %v", err)
 	}
 	x.TestClientConn(t, conn)

+ 1 - 1
psiphon/common/crypto/ssh/test/doc.go

@@ -4,4 +4,4 @@
 
 // Package test contains integration tests for the
 // github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package.
-package test // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/test"
+package test

+ 10 - 0
psiphon/common/crypto/ssh/test/forward_unix_test.go

@@ -12,6 +12,7 @@ import (
 	"io"
 	"math/rand"
 	"net"
+	"runtime"
 	"testing"
 	"time"
 )
@@ -27,6 +28,9 @@ func testPortForward(t *testing.T, n, listenAddr string) {
 
 	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 	}
 
@@ -122,6 +126,9 @@ func testAcceptClose(t *testing.T, n, listenAddr string) {
 
 	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 	}
 
@@ -163,6 +170,9 @@ func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 
 	sshListener, err := client.Listen(n, listenAddr)
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 	}
 

+ 1 - 1
psiphon/common/crypto/ssh/test/server_test.go

@@ -14,7 +14,7 @@ type exitStatusMsg struct {
 	Status uint32
 }
 
-// goServer is a test Go SSH server that accepts public key and certificate
+// goTestServer is a test Go SSH server that accepts public key and certificate
 // authentication and replies with a 0 exit status to any exec request without
 // running any commands.
 type goTestServer struct {

+ 91 - 6
psiphon/common/crypto/ssh/test/session_test.go

@@ -22,6 +22,13 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 )
 
+func skipIfIssue64959(t *testing.T, err error) {
+	if err != nil && runtime.GOOS == "darwin" && strings.Contains(err.Error(), "ssh: unexpected packet in response to channel open: <nil>") {
+		t.Helper()
+		t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+	}
+}
+
 func TestRunCommandSuccess(t *testing.T) {
 	server := newServer(t)
 	conn := server.Dial(clientConfig())
@@ -29,6 +36,7 @@ func TestRunCommandSuccess(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -66,6 +74,7 @@ func TestRunCommandStdin(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -88,6 +97,7 @@ func TestRunCommandStdinError(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -111,6 +121,7 @@ func TestRunCommandFailed(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -127,6 +138,7 @@ func TestRunCommandWeClosed(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	err = session.Shell()
@@ -146,6 +158,7 @@ func TestFuncLargeRead(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("unable to create new session: %s", err)
 	}
 
@@ -182,6 +195,7 @@ func TestKeyChange(t *testing.T) {
 	for i := 0; i < 4; i++ {
 		session, err := conn.NewSession()
 		if err != nil {
+			skipIfIssue64959(t, err)
 			t.Fatalf("unable to create new session: %s", err)
 		}
 
@@ -223,6 +237,7 @@ func TestValidTerminalMode(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -287,6 +302,7 @@ func TestWindowChange(t *testing.T) {
 
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 	}
 	defer session.Close()
@@ -341,14 +357,10 @@ func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 
 	numBytes := 4096
 
-	// Exercise sending data to the server
-	if _, _, err := conn.Conn.SendRequest("drop-me", false, make([]byte, numBytes)); err != nil {
-		t.Fatalf("SendRequest: %v", err)
-	}
-
 	// Exercise receiving data from the server
 	session, err := conn.NewSession()
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("NewSession: %v", err)
 	}
 
@@ -360,6 +372,11 @@ func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 	if len(out) != numBytes {
 		t.Fatalf("got %d bytes, want %d bytes", len(out), numBytes)
 	}
+
+	// Exercise sending data to the server
+	if _, _, err := conn.Conn.SendRequest("drop-me", false, make([]byte, numBytes)); err != nil {
+		t.Fatalf("SendRequest: %v", err)
+	}
 }
 
 var deprecatedCiphers = []string{
@@ -431,7 +448,6 @@ func TestKeyExchanges(t *testing.T) {
 func TestClientAuthAlgorithms(t *testing.T) {
 	for _, key := range []string{
 		"rsa",
-		"dsa",
 		"ecdsa",
 		"ed25519",
 	} {
@@ -452,3 +468,72 @@ func TestClientAuthAlgorithms(t *testing.T) {
 		})
 	}
 }
+
+func TestClientAuthDisconnect(t *testing.T) {
+	// Use a static key that is not accepted by server.
+	// This key has been generated with following ssh-keygen command and
+	// used exclusively in this unit test:
+	// $ ssh-keygen -t RSA -b 2048 -f /tmp/static_key \
+	//   -C "Static RSA key for golang.org/x/crypto/ssh unit test"
+
+	const privKeyData = `-----BEGIN OPENSSH PRIVATE KEY-----
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
+NhAAAAAwEAAQAAAQEAwV1Zg3MqX27nIQQNWd8V09P4q4F1fx7H2xNJdL3Yg3y91GFLJ92+
+0IiGV8n1VMGL/71PPhzyqBpUYSTpWjiU2JZSfA+iTg1GJBcOaEOA6vrXsTtXTHZ//mOT4d
+mlvuP4+9NqfCBLGXN7ZJpT+amkD8AVW9YW9QN3ipY61ZWxPaAocVpDd8rVgJTk54KvaPa7
+t4ddOSQDQq61aubIDR1Z3P+XjkB4piWOsbck3HJL+veTALy12C09tAhwUnZUAXS+DjhxOL
+xpDVclF/yXYhAvBvsjwyk/OC3+nK9F799hpQZsjxmbP7oN+tGwz06BUcAKi7u7QstENvvk
+85SDZy1q1QAAA/A7ylbJO8pWyQAAAAdzc2gtcnNhAAABAQDBXVmDcypfbuchBA1Z3xXT0/
+irgXV/HsfbE0l0vdiDfL3UYUsn3b7QiIZXyfVUwYv/vU8+HPKoGlRhJOlaOJTYllJ8D6JO
+DUYkFw5oQ4Dq+texO1dMdn/+Y5Ph2aW+4/j702p8IEsZc3tkmlP5qaQPwBVb1hb1A3eKlj
+rVlbE9oChxWkN3ytWAlOTngq9o9ru3h105JANCrrVq5sgNHVnc/5eOQHimJY6xtyTcckv6
+95MAvLXYLT20CHBSdlQBdL4OOHE4vGkNVyUX/JdiEC8G+yPDKT84Lf6cr0Xv32GlBmyPGZ
+s/ug360bDPToFRwAqLu7tCy0Q2++TzlINnLWrVAAAAAwEAAQAAAQAIvPDHMiyIxgCksGPF
+uyv9F9U4XjVip8/abE9zkAMJWW5++wuT/bRlBOUPRrWIXZEM9ETbtsqswo3Wxah+7CjRIH
+qR7SdFlYTP1jPk4yIKXF4OvggBUPySkMpAGJ9hwOMW8Ymcb4gn77JJ4aMoWIcXssje+XiC
+8iO+4UWU3SV2i6K7flK1UDCI5JVCyBr3DVf3QhMOgvwJl9TgD7FzWy1hkjuZq/Pzdv+fA2
+OfrUFiSukLNolidNoI9+KWa1yxixE+B2oN4Xan3ZbqGbL6Wc1dB+K9h/bNcu+SKf7fXWRi
+/vVG44A61xGDZzen1+eQlqFp7narkKXoaU71+45VXDThAAAAgBPWUdQykEEm0yOS6hPIW+
+hS8z1LXWGTEcag9fMwJXKE7cQFO3LEk+dXMbClHdhD/ydswOZYGSNepxwvmo/a5LiO2ulp
+W+5tnsNhcK3skdaf71t+boUEXBNZ6u3WNTkU7tDN8h9tebI+xlNceDGSGjOlNoHQVMKZdA
+W9TA4ZqXUPAAAAgQDWU0UZVOSCAOODPz4PYsbFKdCfXNP8O4+t9txyc9E3hsLAsVs+CpVX
+Gr219MGLrublzAxojipyzuQb6Tp1l9nsu7VkcBrPL8I1tokz0AyTnmNF3A9KszBal7gGNS
+a2qYuf6JO4cub1KzonxUJQHZPZq9YhCxOtDwTd+uyHZiPy9QAAAIEA5vayd+nfVJgCKTdf
+z5MFsxBSUj/cAYg7JYPS/0bZ5bEkLosL22wl5Tm/ZftJa8apkyBPhguAWt6jEWLoDiK+kn
+Fv0SaEq1HUdXgWmISVnWzv2pxdAtq/apmbxTg3iIJyrAwEDo13iImR3k6rNPx1m3i/jX56
+HLcvWM4Y6bFzbGEAAAA0U3RhdGljIFJTQSBrZXkgZm9yIGdvbGFuZy5vcmcveC9jcnlwdG
+8vc3NoIHVuaXQgdGVzdAECAwQFBgc=
+-----END OPENSSH PRIVATE KEY-----`
+
+	signer, err := ssh.ParsePrivateKey([]byte(privKeyData))
+	if err != nil {
+		t.Fatalf("failed to create signer from key: %v", err)
+	}
+
+	// Start server with MaxAuthTries 1 and publickey and password auth
+	// enabled
+	server := newServerForConfig(t, "MaxAuthTries", map[string]string{})
+
+	// Connect to server, expect failure, that PublicKeysCallback is called
+	// and that PasswordCallback is not called.
+	publicKeysCallbackCalled := false
+	config := clientConfig()
+	config.Auth = []ssh.AuthMethod{
+		ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
+			publicKeysCallbackCalled = true
+			return []ssh.Signer{signer}, nil
+		}),
+		ssh.PasswordCallback(func() (string, error) {
+			t.Errorf("unexpected call to PasswordCallback()")
+			return "notaverygoodpassword", nil
+		}),
+	}
+	client, err := server.TryDial(config)
+	if err == nil {
+		t.Errorf("expected TryDial() to fail")
+		_ = client.Close()
+	}
+	if !publicKeysCallbackCalled {
+		t.Errorf("expected PublicKeysCallback() to be called")
+	}
+}

+ 23 - 5
psiphon/common/crypto/ssh/test/test_unix_test.go

@@ -58,12 +58,17 @@ UsePAM yes
 PasswordAuthentication yes
 ChallengeResponseAuthentication yes
 AuthenticationMethods {{.AuthMethods}}
+`
+	maxAuthTriesSshdConfigTail = `
+PasswordAuthentication yes
+MaxAuthTries 1
 `
 )
 
 var configTmpl = map[string]*template.Template{
-	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
-	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
+	"default":      template.Must(template.New("").Parse(defaultSshdConfig)),
+	"MultiAuth":    template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail)),
+	"MaxAuthTries": template.Must(template.New("").Parse(defaultSshdConfig + maxAuthTriesSshdConfigTail))}
 
 type server struct {
 	t          *testing.T
@@ -178,7 +183,7 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
 
 // addr is the user specified host:port. While we don't actually dial it,
 // we need to know this for host key matching
-func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
 	sshd, err := exec.LookPath("sshd")
 	if err != nil {
 		s.t.Skipf("skipping test: %v", err)
@@ -188,13 +193,26 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 	if err != nil {
 		s.t.Fatalf("unixConnection: %v", err)
 	}
+	defer func() {
+		// Close c2 after we've started the sshd command so that it won't prevent c1
+		// from returning EOF when the sshd command exits.
+		c2.Close()
+
+		// Leave c1 open if we're returning a client that wraps it.
+		// (The client is responsible for closing it.)
+		// Otherwise, close it to free up the socket.
+		if client == nil {
+			c1.Close()
+		}
+	}()
 
-	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	f, err := c2.File()
 	if err != nil {
 		s.t.Fatalf("UnixConn.File: %v", err)
 	}
 	defer f.Close()
+
+	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	cmd.Stdin = f
 	cmd.Stdout = f
 	cmd.Stderr = new(bytes.Buffer)
@@ -223,7 +241,7 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 		// processes are killed too.
 		cmd.Process.Signal(os.Interrupt)
 		cmd.Wait()
-		if s.t.Failed() {
+		if s.t.Failed() || testing.Verbose() {
 			// log any output from sshd process
 			s.t.Logf("sshd:\n%s", cmd.Stderr)
 		}

+ 1 - 1
psiphon/common/crypto/ssh/testdata/doc.go

@@ -5,4 +5,4 @@
 // This package contains test data shared between the various subpackages of
 // the github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package. Under no circumstance should
 // this data be used for production code.
-package testdata // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
+package testdata

+ 29 - 17
psiphon/common/crypto/ssh/testdata/keys.go

@@ -46,19 +46,31 @@ FFlRjzoB3Oxu8UQgb+MWPedtH9XYBbg9biz4jJLkXQ==
 -----END EC PRIVATE KEY-----
 `),
 	"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2
-a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8
-Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQIDAQAB
-AoGAJMCk5vqfSRzyXOTXLGIYCuR4Kj6pdsbNSeuuRGfYBeR1F2c/XdFAg7D/8s5R
-38p/Ih52/Ty5S8BfJtwtvgVY9ecf/JlU/rl/QzhG8/8KC0NG7KsyXklbQ7gJT8UT
-Ojmw5QpMk+rKv17ipDVkQQmPaj+gJXYNAHqImke5mm/K/h0CQQDciPmviQ+DOhOq
-2ZBqUfH8oXHgFmp7/6pXw80DpMIxgV3CwkxxIVx6a8lVH9bT/AFySJ6vXq4zTuV9
-6QmZcZzDAkEA2j/UXJPIs1fQ8z/6sONOkU/BjtoePFIWJlRxdN35cZjXnBraX5UR
-fFHkePv4YwqmXNqrBOvSu+w2WdSDci+IKwJAcsPRc/jWmsrJW1q3Ha0hSf/WG/Bu
-X7MPuXaKpP/DkzGoUmb8ks7yqj6XWnYkPNLjCc8izU5vRwIiyWBRf4mxMwJBAILa
-NDvRS0rjwt6lJGv7zPZoqDc65VfrK2aNyHx2PgFyzwrEOtuF57bu7pnvEIxpLTeM
-z26i6XVMeYXAWZMTloMCQBbpGgEERQpeUknLBqUHhg/wXF6+lFA+vEGnkY+Dwab2
-KCXFGd+SQ5GdUcEMe9isUH6DYj/6/yCDoFrXXmpQb+M=
+MIIEpQIBAAKCAQEAnuozKMtcQkIImZGSe4IujS4+Lkas9jmlBivziWGU3waivkpU
+vYspgJbh7vSvnHOPtKscdIJ+3UUyViDUoM73GumsmHvfeRCoA9YROZK4fQR9G0a1
+wfoRqsrJXGToCzTvr/I2KIwpUG0bRE9rUvsW+JN9xgri+cIJWtu/dGYDkILO4bkF
+IxtEvHNVvhGLenyOHFhPw3hAZ7/bKq8kvKzm9D2zOllHe1wWncMkhVmEFF9Houeh
+jbddmeIAAxBpRUFfzp1dD7503ADBlJdK306D4CeI4KIFiqE1VrmfcMgP8fti0S0b
+4JtmvevYoPd+/wB9ItFqvhc6nyuxF0PfWH+SvwIDAQABAoIBAQCCcpNeSFjKVvRC
+Q1nwIrPd1njaec+fK0CIqWl3e2++B+9trwySrvp5gOGjyp2hGsd7Mf7gsQI81oF0
+a+y+uEXlhK3WWdDey0pwI7ft/7+LeDTOQCQRQBpijaXvPzGviVu7nWLRtARx7a41
+S9A4xL5dfI0BFYyuIpaVS8+EV/1TEJIbceZ5q5RBlARA1rc+nBvjygNzYdRq9Rao
+yyehvnXZ7pQrATnwofPolbZNseW2Q9sRMmtm1E60XJJ433P7nFbxXtsMAYgxWSQc
+V/92iRPYeD7sN/b7qulLEgC6e8el2gLGIB9aQyG7B6KFqloqvx/ymYs07+bWIQCU
+6i9y7LABAoGBANKB2Rs0lF5c+gpEE3w0AWoxGyL04TZECtl2hQ/b2jc2n8hLqLhv
+zNIKN+xzEP6j0ijXajjEMLiQH10qQ6+Plv8C7o8GJC1V/Oj4u4kbPqfVV1kSxuWm
+FBjz6+c8VPbEBgXq5lCMEgC2Ii8XVRoyd3iSSOh+LIMBu/Br3JEsjJq/AoGBAMFC
+CvODTiThrZo8v765dRSHrLOvB4jZOPrEKLWECLaQDQpuhgzZhyWm9zFfxGB+LWE9
+R9pU6ZCFPtPfd4cRZ9cezrp+lgdrqjcUX/2ZLLMk21WXFuRaw/4KTlKNsMY6lbAK
+qVkUFWIZWaCQ8bdVCP48polipRNmN9mAHsIhIwgBAoGBAM6yn1qeS11I0GAKLlPT
+wNvjsfCmIQmm0DxtqwRCbUevxD7pQ5cueCB51iW/ap2OgEqIEo4A3pIrOhDB8kpN
+pQdrepFHh3hYqYicy5A6B1DHJAibbl+Krss9n5KjZA4VtpBS8al/kCHQtUomD/M0
+QKlMgnh/g/dzWXYegyqtYraDAoGAGbeBH5B8iJnjcR/eYDHrq5S2XZ7QANzvISeT
+RzxPsIOQyK+WdQVJX7BNOqvExRZlUYhHFH2yKwIgLy+Qh0/Aora9ycFok4o3N2cl
+suh8M0aXTVdyu2Z8qESU0ZV7TZWkL63rhSgQBGLdM2m2ULAnJzXI74VJ9D/o9K+A
+6FJiiAECgYEAujJ/hKxVKEUvxwloGSDhKCUH86+7UOkb/EM2zZFrlPYAz1VcCwr3
+K14r8BtmLFXuLXOlACpoH0Wf4uia+t6n8m9JK3mvpJ7fempAsptP3AdZMQFe7xUm
+SXEGQBYxcyS5Q+ncwWZuPgby5wJ9D4Fd6TQH+wwG52sFugt/fGxbPug=
 -----END RSA PRIVATE KEY-----
 `),
 	"rsa-sha2-256": []byte(`-----BEGIN RSA PRIVATE KEY-----
@@ -222,17 +234,17 @@ var SSHCertificates = map[string][]byte{
 	//
 	// 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub":
 	//    ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub
-	"rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com
+	"rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgWcP1x+t9PQTxkt/fRa8v5HXz0FFW/fwN3mpR5jdGo3UAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzUQAAAAAd8sPtQAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEPAAAAB3NzaC1yc2EAAAEAs8/78EsftS/0+6CfUAtRhmgpOPZVH8XJc2o4Qx/lQCBg/uB2sucGGx/5BuFUwAjhMNQ1UejJ6+AbNgZNlH7LIssGxpcW+eKQXLfd9KwWNPCU3DqZGiMBsYESNtVLjKAMuvJmxKFFu4rceyya+GZfsQFJR7G++XK0cQQ4rgDSnQo3pxQnHwguFWeHtwjdOj0helYER4XDKjKQlo0SxI1nVo7n1jkl8QghdxX6HKia6MP7MqstgNujTvrCEUJ6L3YflD/SJ00Y3ySD9+5hqvfqrDh4c4rRuc4+k0e3fYZwyWO2NNtKHRBOZAgW8zzfR3G99YdTw8N8+2oXeAdvB9k0tA== host.example.com
 `),
-	"rsa-sha2-256": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgOyK28gunJkM60qp4EbsYAjgbUsyjS8u742OLjipIgc0AAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMJ4AAAAAHBPBLwAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi0yNTYAAAEAbG4De/+QiqopPS3O1H7ySeEUCY56qmdgr02sFErnihdXPDaWXUXxacvJHaEtLrSTSaPL/3v3iKvjLWDOHaQ5c+cN9J7Tqzso7RQCXZD2nK9bwCUyBoiDyBCRe8w4DQEtfL5okpVzQsSAiojQ8hBohMOpy3gFfXrdm4PVC1ZKqlZh4fAc7ajieRq/Tpq2xOLdHwxkcgPNR83WVHva6K9/xjev/5n227/gkHo0qbGs8YYDOFXIEhENi+B23IzxdNVieWdyQpYpe0C2i95Jhyo0wJmaFY2ArruTS+D1jGQQpMPvAQRy26/A5hI83GLhpwyhrN/M8wCxzAhyPL6Ieuh5tQ== host.example.com
+	"rsa-sha2-256": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg4+hKHVPKv183MU/Q7XD/mzDBFSc2YY3eraltxLMGJo0AAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzYAAAAAAd8sP4wAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTI1NgAAAQA/ByIegNZYJRRl413S/8LxGvTZnbxsPwaluoJ/54niGZV9P28THz7d9jXfSHPjalhH93jNPfTYXvI4opnDC37ua1Nu8KKfk40IWXnnDdZLWraUxEidIzhmfVtz8kGdGoFQ8H0EzubL7zKNOTlfSfOoDlmQVOuxT/+eh2mEp4ri0/+8J1mLfLBr8tREX0/iaNjK+RKdcyTMicKursAYMCDdu8vlaphxea+ocyHM9izSX/l33t44V13ueTqIOh2Zbl2UE2k+jk+0dc1CmV0SEoiWiIyt8TRM4yQry1vPlQLsrf28sYM/QMwnhCVhyZO3vs5F25aQWrB9d51VEzBW9/fd host.example.com
 `),
-	"rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgFGv4IpXfs4L/Y0b3rmUdPFhWoUrVnXuPxXr6aHGs7wgAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMRYAAAAAHBPBp4AAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAnF4fVj6mm+UFeNCIf9AKJCv9WzymjjPvzzmaMWWkPWqoV0P0m5SiYfvbY9SbA73Blpv8SOr0DmpublF183kodREia4KyVuC8hLhSCV2Y16hy9MBegOZMepn80w+apj7Rn9QCz5OfEakDdztp6OWTBtqxnZFcTQ4XrgFkNWeWRElGdEvAVNn2WHwHi4EIdz0mdv48Imv5SPlOuW862ZdFG4Do1dUfDIiGsBofLlgcyIYlf+eNHul6sBeUkuwFxisMpI5DQzNp8PX1g/QJA2wzwT674PTqDXNttKjyh50Fdr4sXxm9Gz1+jVLoESvFNa55ERdSyAqNu4wTy11MZsWwSA== host.example.com
+	"rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgylCoUX+0nYXGG9k/KEepcgEd22921eUwSQsYuQXKOswAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzbwAAAAAd8sQBAAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTUxMgAAAQADOjYUNxzwYZ9O3zjjZSKhX7Ix/vUBq8lIltBRckbKJtcjn2/qqtaeZK2ijkbMlnPFCJ59U+Z6m4DMU6gxYF8R9/WJrANlWYNQ4+52fXjE/FPDdJkwB/kPWABt+ffEiM1HfzP4/zXgTtOJ0GogEzTYoMSrhAlpdKACUaC9nCQ7gjmO+owGkrB3ZbyQS4gltqQZjiBNqF0P/vxmXP+0Kx66/ei8JNYUpEPCXT4ZhLZmsVptaYD1gpvc2i5T5LnjRrsvf4Q6EKR+gwJEjlwnylhH5+h+ZU/KXydkw1hhEVP+ZiIBkVo9nN78Qb/VMJum4Fdoltuyh/9k2lMDeyDuRVh7 host.example.com
 `),
 	// Assumes "ca" key above in file named "ca", sign a user cert for "rsa.pub"
 	// using "testcertificate" as principal:
 	//
 	// ssh-keygen -s ca -I username -n testcertificate rsa.pub
-	"rsa-user-testcertificate": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgr0MnhSf+KkQWEQmweJOGsJfOrUt80pQZDaI9YiwSfDUAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAQAAAAh1c2VybmFtZQAAABMAAAAPdGVzdGNlcnRpZmljYXRlAAAAAAAAAAD//////////wAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAFuA+67KvnlmcodIp0Lv4mR9UW/CHghAaN1csBJTkI8mx3wXKyIPTsS2uXboEhWD0a7S9gps2SEwC5m6E3kV2Rzg7aH1S03GZqMvVlK2VHe7fzuoW2yOKk6yEPjeTF0pKCFbUQ6mce8pRpD/zdvjG0Z287XM3c3Axlrn7qq7TS0MDTjEZ/dsUNFHxep3co/HuAsWVWPVDItr/FPnvZ6WVH1yc8N/AJn0gLHobkGgug22vI9rNIge1wrnXxU9BUSouzkau/PQsrCQapnn+I1H7HaQt0wdk45nxMP+ags+HRI9qpX/p8WDn6+zpqYqN/nfw2aoytyaJqhsV32Teuqtrgg== rsa.pub
+	"rsa-user-testcertificate": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgEKR9xnhXkbi/pgP669VjBH6XVTYR0yx1wunCtOIjzjUAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAABAAAACHVzZXJuYW1lAAAAEwAAAA90ZXN0Y2VydGlmaWNhdGUAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTUxMgAAAQCKVn2S7FJYhXTRVbcz1Di1HLz1g5Yae5WBhd0Tg471XkNw7ylcCK23Wnrzj1GxrW0oWCCGHROtUnxQXei1xNWt8HONN+eeafrJSZJR6ald3Yd4OveXlHNT6mEDPgqRj4B56OPoY33LzpaFlQZlZ6U9KXySshNaCTjVp3ojTj6uPNxcuOnG9O5emEPC2eaM4QYsz4cqHNJ9SWWEu+HIQgpx5SM12qcrq/KhN6WJhG3edL1YAsxkf1/THEcfi6wvsHi+DKewJ5hZ876//japjHBKA9SB7gsCrLx3m+0XI8TkWUZTZJHETJHHVjJN8xjRvMv5gWKLvRsz7ScmnN1+ckL6 rsa.pub
 `),
 }