Просмотр исходного кода

Merge pull request #726 from rod-hynes/upgrade-ssh

Upgrade x/crypto/ssh
Rod Hynes 1 год назад
Родитель
Сommit
0b08d9040c
41 измененных файлов с 2030 добавлено и 351 удалено
  1. 0 10
      psiphon/common/crypto/.gitattributes
  2. 2 2
      psiphon/common/crypto/LICENSE
  3. 126 0
      psiphon/common/crypto/internal/poly1305/_asm/sum_amd64_asm.go
  4. 0 40
      psiphon/common/crypto/internal/poly1305/bits_compat.go
  5. 0 22
      psiphon/common/crypto/internal/poly1305/bits_go1.13.go
  6. 1 2
      psiphon/common/crypto/internal/poly1305/mac_noasm.go
  7. 59 74
      psiphon/common/crypto/internal/poly1305/sum_amd64.s
  8. 23 20
      psiphon/common/crypto/internal/poly1305/sum_generic.go
  9. 1 2
      psiphon/common/crypto/internal/poly1305/sum_ppc64x.go
  10. 25 20
      psiphon/common/crypto/internal/poly1305/sum_ppc64x.s
  11. 2 2
      psiphon/common/crypto/internal/testenv/exec.go
  12. 1 1
      psiphon/common/crypto/nacl/secretbox/secretbox.go
  13. 1 1
      psiphon/common/crypto/ssh/agent/client.go
  14. 16 9
      psiphon/common/crypto/ssh/agent/client_test.go
  15. 9 0
      psiphon/common/crypto/ssh/agent/keyring.go
  16. 46 0
      psiphon/common/crypto/ssh/agent/keyring_test.go
  17. 15 17
      psiphon/common/crypto/ssh/certs_test.go
  18. 20 3
      psiphon/common/crypto/ssh/client_auth.go
  19. 108 8
      psiphon/common/crypto/ssh/client_auth_test.go
  20. 1 1
      psiphon/common/crypto/ssh/doc.go
  21. 1 1
      psiphon/common/crypto/ssh/example_test.go
  22. 49 12
      psiphon/common/crypto/ssh/handshake.go
  23. 220 0
      psiphon/common/crypto/ssh/handshake_test.go
  24. 51 1
      psiphon/common/crypto/ssh/keys.go
  25. 86 2
      psiphon/common/crypto/ssh/keys_test.go
  26. 2 0
      psiphon/common/crypto/ssh/messages.go
  27. 56 0
      psiphon/common/crypto/ssh/messages_test.go
  28. 196 65
      psiphon/common/crypto/ssh/server.go
  29. 412 0
      psiphon/common/crypto/ssh/server_multi_auth_test.go
  30. 338 0
      psiphon/common/crypto/ssh/server_test.go
  31. 1 1
      psiphon/common/crypto/ssh/tcpip.go
  32. 4 3
      psiphon/common/crypto/ssh/test/agent_unix_test.go
  33. 1 1
      psiphon/common/crypto/ssh/test/cert_test.go
  34. 1 0
      psiphon/common/crypto/ssh/test/dial_unix_test.go
  35. 1 1
      psiphon/common/crypto/ssh/test/doc.go
  36. 10 0
      psiphon/common/crypto/ssh/test/forward_unix_test.go
  37. 1 1
      psiphon/common/crypto/ssh/test/server_test.go
  38. 91 6
      psiphon/common/crypto/ssh/test/session_test.go
  39. 23 5
      psiphon/common/crypto/ssh/test/test_unix_test.go
  40. 1 1
      psiphon/common/crypto/ssh/testdata/doc.go
  41. 29 17
      psiphon/common/crypto/ssh/testdata/keys.go

+ 0 - 10
psiphon/common/crypto/.gitattributes

@@ -1,10 +0,0 @@
-# Treat all files in this repo as binary, with no git magic updating
-# line endings. Windows users contributing to Go will need to use a
-# modern version of git and editors capable of LF line endings.
-#
-# We'll prevent accidental CRLF line endings from entering the repo
-# via the git-review gofmt checks.
-#
-# See golang.org/issue/9281
-
-* -text

+ 2 - 2
psiphon/common/crypto/LICENSE

@@ -1,4 +1,4 @@
-Copyright (c) 2009 The Go Authors. All rights reserved.
+Copyright 2009 The Go Authors.
 
 
 Redistribution and use in source and binary forms, with or without
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
 modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
 copyright notice, this list of conditions and the following disclaimer
 copyright notice, this list of conditions and the following disclaimer
 in the documentation and/or other materials provided with the
 in the documentation and/or other materials provided with the
 distribution.
 distribution.
-   * Neither the name of Google Inc. nor the names of its
+   * Neither the name of Google LLC nor the names of its
 contributors may be used to endorse or promote products derived from
 contributors may be used to endorse or promote products derived from
 this software without specific prior written permission.
 this software without specific prior written permission.
 
 

+ 126 - 0
psiphon/common/crypto/internal/poly1305/_asm/sum_amd64_asm.go

@@ -0,0 +1,126 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	. "github.com/mmcloughlin/avo/build"
+	. "github.com/mmcloughlin/avo/operand"
+	. "github.com/mmcloughlin/avo/reg"
+	_ "golang.org/x/crypto/sha3"
+)
+
+//go:generate go run . -out ../sum_amd64.s -pkg poly1305
+
+func main() {
+	Package("golang.org/x/crypto/internal/poly1305")
+	ConstraintExpr("gc,!purego")
+	update()
+	Generate()
+}
+
+func update() {
+	Implement("update")
+
+	Load(Param("state"), RDI)
+	MOVQ(NewParamAddr("msg_base", 8), RSI)
+	MOVQ(NewParamAddr("msg_len", 16), R15)
+
+	MOVQ(Mem{Base: DI}.Offset(0), R8)   // h0
+	MOVQ(Mem{Base: DI}.Offset(8), R9)   // h1
+	MOVQ(Mem{Base: DI}.Offset(16), R10) // h2
+	MOVQ(Mem{Base: DI}.Offset(24), R11) // r0
+	MOVQ(Mem{Base: DI}.Offset(32), R12) // r1
+
+	CMPQ(R15, Imm(16))
+	JB(LabelRef("bytes_between_0_and_15"))
+
+	Label("loop")
+	POLY1305_ADD(RSI, R8, R9, R10)
+
+	Label("multiply")
+	POLY1305_MUL(R8, R9, R10, R11, R12, RBX, RCX, R13, R14)
+	SUBQ(Imm(16), R15)
+	CMPQ(R15, Imm(16))
+	JAE(LabelRef("loop"))
+
+	Label("bytes_between_0_and_15")
+	TESTQ(R15, R15)
+	JZ(LabelRef("done"))
+	MOVQ(U32(1), RBX)
+	XORQ(RCX, RCX)
+	XORQ(R13, R13)
+	ADDQ(R15, RSI)
+
+	Label("flush_buffer")
+	SHLQ(Imm(8), RBX, RCX)
+	SHLQ(Imm(8), RBX)
+	MOVB(Mem{Base: SI}.Offset(-1), R13B)
+	XORQ(R13, RBX)
+	DECQ(RSI)
+	DECQ(R15)
+	JNZ(LabelRef("flush_buffer"))
+
+	ADDQ(RBX, R8)
+	ADCQ(RCX, R9)
+	ADCQ(Imm(0), R10)
+	MOVQ(U32(16), R15)
+	JMP(LabelRef("multiply"))
+
+	Label("done")
+	MOVQ(R8, Mem{Base: DI}.Offset(0))
+	MOVQ(R9, Mem{Base: DI}.Offset(8))
+	MOVQ(R10, Mem{Base: DI}.Offset(16))
+	RET()
+}
+
+func POLY1305_ADD(msg, h0, h1, h2 GPPhysical) {
+	ADDQ(Mem{Base: msg}.Offset(0), h0)
+	ADCQ(Mem{Base: msg}.Offset(8), h1)
+	ADCQ(Imm(1), h2)
+	LEAQ(Mem{Base: msg}.Offset(16), msg)
+}
+
+func POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3 GPPhysical) {
+	MOVQ(r0, RAX)
+	MULQ(h0)
+	MOVQ(RAX, t0)
+	MOVQ(RDX, t1)
+	MOVQ(r0, RAX)
+	MULQ(h1)
+	ADDQ(RAX, t1)
+	ADCQ(Imm(0), RDX)
+	MOVQ(r0, t2)
+	IMULQ(h2, t2)
+	ADDQ(RDX, t2)
+
+	MOVQ(r1, RAX)
+	MULQ(h0)
+	ADDQ(RAX, t1)
+	ADCQ(Imm(0), RDX)
+	MOVQ(RDX, h0)
+	MOVQ(r1, t3)
+	IMULQ(h2, t3)
+	MOVQ(r1, RAX)
+	MULQ(h1)
+	ADDQ(RAX, t2)
+	ADCQ(RDX, t3)
+	ADDQ(h0, t2)
+	ADCQ(Imm(0), t3)
+
+	MOVQ(t0, h0)
+	MOVQ(t1, h1)
+	MOVQ(t2, h2)
+	ANDQ(Imm(3), h2)
+	MOVQ(t2, t0)
+	ANDQ(I32(-4), t0)
+	ADDQ(t0, h0)
+	ADCQ(t3, h1)
+	ADCQ(Imm(0), h2)
+	SHRQ(Imm(2), t3, t2)
+	SHRQ(Imm(2), t3)
+	ADDQ(t2, h0)
+	ADCQ(t3, h1)
+	ADCQ(Imm(0), h2)
+}

+ 0 - 40
psiphon/common/crypto/internal/poly1305/bits_compat.go

@@ -1,40 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.13
-// +build !go1.13
-
-package poly1305
-
-// Generic fallbacks for the math/bits intrinsics, copied from
-// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had
-// variable time fallbacks until Go 1.13.
-
-func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
-	sum = x + y + carry
-	carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
-	return
-}
-
-func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
-	diff = x - y - borrow
-	borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
-	return
-}
-
-func bitsMul64(x, y uint64) (hi, lo uint64) {
-	const mask32 = 1<<32 - 1
-	x0 := x & mask32
-	x1 := x >> 32
-	y0 := y & mask32
-	y1 := y >> 32
-	w0 := x0 * y0
-	t := x1*y0 + w0>>32
-	w1 := t & mask32
-	w2 := t >> 32
-	w1 += x0 * y1
-	hi = x1*y1 + w2 + w1>>32
-	lo = x * y
-	return
-}

+ 0 - 22
psiphon/common/crypto/internal/poly1305/bits_go1.13.go

@@ -1,22 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.13
-// +build go1.13
-
-package poly1305
-
-import "math/bits"
-
-func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
-	return bits.Add64(x, y, carry)
-}
-
-func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
-	return bits.Sub64(x, y, borrow)
-}
-
-func bitsMul64(x, y uint64) (hi, lo uint64) {
-	return bits.Mul64(x, y)
-}

+ 1 - 2
psiphon/common/crypto/internal/poly1305/mac_noasm.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-//go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
-// +build !amd64,!ppc64le,!s390x !gc purego
+//go:build (!amd64 && !ppc64le && !ppc64 && !s390x) || !gc || purego
 
 
 package poly1305
 package poly1305
 
 

+ 59 - 74
psiphon/common/crypto/internal/poly1305/sum_amd64.s

@@ -1,109 +1,94 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
+// Code generated by command: go run sum_amd64_asm.go -out ../sum_amd64.s -pkg poly1305. DO NOT EDIT.
 
 
 //go:build gc && !purego
 //go:build gc && !purego
 // +build gc,!purego
 // +build gc,!purego
 
 
-#include "textflag.h"
-
-#define POLY1305_ADD(msg, h0, h1, h2) \
-	ADDQ 0(msg), h0;  \
-	ADCQ 8(msg), h1;  \
-	ADCQ $1, h2;      \
-	LEAQ 16(msg), msg
-
-#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \
-	MOVQ  r0, AX;                  \
-	MULQ  h0;                      \
-	MOVQ  AX, t0;                  \
-	MOVQ  DX, t1;                  \
-	MOVQ  r0, AX;                  \
-	MULQ  h1;                      \
-	ADDQ  AX, t1;                  \
-	ADCQ  $0, DX;                  \
-	MOVQ  r0, t2;                  \
-	IMULQ h2, t2;                  \
-	ADDQ  DX, t2;                  \
-	                               \
-	MOVQ  r1, AX;                  \
-	MULQ  h0;                      \
-	ADDQ  AX, t1;                  \
-	ADCQ  $0, DX;                  \
-	MOVQ  DX, h0;                  \
-	MOVQ  r1, t3;                  \
-	IMULQ h2, t3;                  \
-	MOVQ  r1, AX;                  \
-	MULQ  h1;                      \
-	ADDQ  AX, t2;                  \
-	ADCQ  DX, t3;                  \
-	ADDQ  h0, t2;                  \
-	ADCQ  $0, t3;                  \
-	                               \
-	MOVQ  t0, h0;                  \
-	MOVQ  t1, h1;                  \
-	MOVQ  t2, h2;                  \
-	ANDQ  $3, h2;                  \
-	MOVQ  t2, t0;                  \
-	ANDQ  $0xFFFFFFFFFFFFFFFC, t0; \
-	ADDQ  t0, h0;                  \
-	ADCQ  t3, h1;                  \
-	ADCQ  $0, h2;                  \
-	SHRQ  $2, t3, t2;              \
-	SHRQ  $2, t3;                  \
-	ADDQ  t2, h0;                  \
-	ADCQ  t3, h1;                  \
-	ADCQ  $0, h2
-
-// func update(state *[7]uint64, msg []byte)
+// func update(state *macState, msg []byte)
 TEXT ·update(SB), $0-32
 TEXT ·update(SB), $0-32
 	MOVQ state+0(FP), DI
 	MOVQ state+0(FP), DI
 	MOVQ msg_base+8(FP), SI
 	MOVQ msg_base+8(FP), SI
 	MOVQ msg_len+16(FP), R15
 	MOVQ msg_len+16(FP), R15
-
-	MOVQ 0(DI), R8   // h0
-	MOVQ 8(DI), R9   // h1
-	MOVQ 16(DI), R10 // h2
-	MOVQ 24(DI), R11 // r0
-	MOVQ 32(DI), R12 // r1
-
-	CMPQ R15, $16
+	MOVQ (DI), R8
+	MOVQ 8(DI), R9
+	MOVQ 16(DI), R10
+	MOVQ 24(DI), R11
+	MOVQ 32(DI), R12
+	CMPQ R15, $0x10
 	JB   bytes_between_0_and_15
 	JB   bytes_between_0_and_15
 
 
 loop:
 loop:
-	POLY1305_ADD(SI, R8, R9, R10)
+	ADDQ (SI), R8
+	ADCQ 8(SI), R9
+	ADCQ $0x01, R10
+	LEAQ 16(SI), SI
 
 
 multiply:
 multiply:
-	POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14)
-	SUBQ $16, R15
-	CMPQ R15, $16
-	JAE  loop
+	MOVQ  R11, AX
+	MULQ  R8
+	MOVQ  AX, BX
+	MOVQ  DX, CX
+	MOVQ  R11, AX
+	MULQ  R9
+	ADDQ  AX, CX
+	ADCQ  $0x00, DX
+	MOVQ  R11, R13
+	IMULQ R10, R13
+	ADDQ  DX, R13
+	MOVQ  R12, AX
+	MULQ  R8
+	ADDQ  AX, CX
+	ADCQ  $0x00, DX
+	MOVQ  DX, R8
+	MOVQ  R12, R14
+	IMULQ R10, R14
+	MOVQ  R12, AX
+	MULQ  R9
+	ADDQ  AX, R13
+	ADCQ  DX, R14
+	ADDQ  R8, R13
+	ADCQ  $0x00, R14
+	MOVQ  BX, R8
+	MOVQ  CX, R9
+	MOVQ  R13, R10
+	ANDQ  $0x03, R10
+	MOVQ  R13, BX
+	ANDQ  $-4, BX
+	ADDQ  BX, R8
+	ADCQ  R14, R9
+	ADCQ  $0x00, R10
+	SHRQ  $0x02, R14, R13
+	SHRQ  $0x02, R14
+	ADDQ  R13, R8
+	ADCQ  R14, R9
+	ADCQ  $0x00, R10
+	SUBQ  $0x10, R15
+	CMPQ  R15, $0x10
+	JAE   loop
 
 
 bytes_between_0_and_15:
 bytes_between_0_and_15:
 	TESTQ R15, R15
 	TESTQ R15, R15
 	JZ    done
 	JZ    done
-	MOVQ  $1, BX
+	MOVQ  $0x00000001, BX
 	XORQ  CX, CX
 	XORQ  CX, CX
 	XORQ  R13, R13
 	XORQ  R13, R13
 	ADDQ  R15, SI
 	ADDQ  R15, SI
 
 
 flush_buffer:
 flush_buffer:
-	SHLQ $8, BX, CX
-	SHLQ $8, BX
+	SHLQ $0x08, BX, CX
+	SHLQ $0x08, BX
 	MOVB -1(SI), R13
 	MOVB -1(SI), R13
 	XORQ R13, BX
 	XORQ R13, BX
 	DECQ SI
 	DECQ SI
 	DECQ R15
 	DECQ R15
 	JNZ  flush_buffer
 	JNZ  flush_buffer
-
 	ADDQ BX, R8
 	ADDQ BX, R8
 	ADCQ CX, R9
 	ADCQ CX, R9
-	ADCQ $0, R10
-	MOVQ $16, R15
+	ADCQ $0x00, R10
+	MOVQ $0x00000010, R15
 	JMP  multiply
 	JMP  multiply
 
 
 done:
 done:
-	MOVQ R8, 0(DI)
+	MOVQ R8, (DI)
 	MOVQ R9, 8(DI)
 	MOVQ R9, 8(DI)
 	MOVQ R10, 16(DI)
 	MOVQ R10, 16(DI)
 	RET
 	RET

+ 23 - 20
psiphon/common/crypto/internal/poly1305/sum_generic.go

@@ -7,7 +7,10 @@
 
 
 package poly1305
 package poly1305
 
 
-import "encoding/binary"
+import (
+	"encoding/binary"
+	"math/bits"
+)
 
 
 // Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
 // Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
 // for a 64 bytes message is approximately
 // for a 64 bytes message is approximately
@@ -114,13 +117,13 @@ type uint128 struct {
 }
 }
 
 
 func mul64(a, b uint64) uint128 {
 func mul64(a, b uint64) uint128 {
-	hi, lo := bitsMul64(a, b)
+	hi, lo := bits.Mul64(a, b)
 	return uint128{lo, hi}
 	return uint128{lo, hi}
 }
 }
 
 
 func add128(a, b uint128) uint128 {
 func add128(a, b uint128) uint128 {
-	lo, c := bitsAdd64(a.lo, b.lo, 0)
-	hi, c := bitsAdd64(a.hi, b.hi, c)
+	lo, c := bits.Add64(a.lo, b.lo, 0)
+	hi, c := bits.Add64(a.hi, b.hi, c)
 	if c != 0 {
 	if c != 0 {
 		panic("poly1305: unexpected overflow")
 		panic("poly1305: unexpected overflow")
 	}
 	}
@@ -155,8 +158,8 @@ func updateGeneric(state *macState, msg []byte) {
 		// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
 		// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
 		// add 1 to the most significant (2¹²⁸) limb, h2.
 		// add 1 to the most significant (2¹²⁸) limb, h2.
 		if len(msg) >= TagSize {
 		if len(msg) >= TagSize {
-			h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
-			h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
+			h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
+			h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
 			h2 += c + 1
 			h2 += c + 1
 
 
 			msg = msg[TagSize:]
 			msg = msg[TagSize:]
@@ -165,8 +168,8 @@ func updateGeneric(state *macState, msg []byte) {
 			copy(buf[:], msg)
 			copy(buf[:], msg)
 			buf[len(msg)] = 1
 			buf[len(msg)] = 1
 
 
-			h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
-			h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
+			h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
+			h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
 			h2 += c
 			h2 += c
 
 
 			msg = nil
 			msg = nil
@@ -219,9 +222,9 @@ func updateGeneric(state *macState, msg []byte) {
 		m3 := h2r1
 		m3 := h2r1
 
 
 		t0 := m0.lo
 		t0 := m0.lo
-		t1, c := bitsAdd64(m1.lo, m0.hi, 0)
-		t2, c := bitsAdd64(m2.lo, m1.hi, c)
-		t3, _ := bitsAdd64(m3.lo, m2.hi, c)
+		t1, c := bits.Add64(m1.lo, m0.hi, 0)
+		t2, c := bits.Add64(m2.lo, m1.hi, c)
+		t3, _ := bits.Add64(m3.lo, m2.hi, c)
 
 
 		// Now we have the result as 4 64-bit limbs, and we need to reduce it
 		// Now we have the result as 4 64-bit limbs, and we need to reduce it
 		// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
 		// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
@@ -243,14 +246,14 @@ func updateGeneric(state *macState, msg []byte) {
 
 
 		// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
 		// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
 
 
-		h0, c = bitsAdd64(h0, cc.lo, 0)
-		h1, c = bitsAdd64(h1, cc.hi, c)
+		h0, c = bits.Add64(h0, cc.lo, 0)
+		h1, c = bits.Add64(h1, cc.hi, c)
 		h2 += c
 		h2 += c
 
 
 		cc = shiftRightBy2(cc)
 		cc = shiftRightBy2(cc)
 
 
-		h0, c = bitsAdd64(h0, cc.lo, 0)
-		h1, c = bitsAdd64(h1, cc.hi, c)
+		h0, c = bits.Add64(h0, cc.lo, 0)
+		h1, c = bits.Add64(h1, cc.hi, c)
 		h2 += c
 		h2 += c
 
 
 		// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
 		// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
@@ -287,9 +290,9 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
 	// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
 	// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
 	// result if the subtraction underflows, and t otherwise.
 	// result if the subtraction underflows, and t otherwise.
 
 
-	hMinusP0, b := bitsSub64(h0, p0, 0)
-	hMinusP1, b := bitsSub64(h1, p1, b)
-	_, b = bitsSub64(h2, p2, b)
+	hMinusP0, b := bits.Sub64(h0, p0, 0)
+	hMinusP1, b := bits.Sub64(h1, p1, b)
+	_, b = bits.Sub64(h2, p2, b)
 
 
 	// h = h if h < p else h - p
 	// h = h if h < p else h - p
 	h0 = select64(b, h0, hMinusP0)
 	h0 = select64(b, h0, hMinusP0)
@@ -301,8 +304,8 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
 	//
 	//
 	// by just doing a wide addition with the 128 low bits of h and discarding
 	// by just doing a wide addition with the 128 low bits of h and discarding
 	// the overflow.
 	// the overflow.
-	h0, c := bitsAdd64(h0, s[0], 0)
-	h1, _ = bitsAdd64(h1, s[1], c)
+	h0, c := bits.Add64(h0, s[0], 0)
+	h1, _ = bits.Add64(h1, s[1], c)
 
 
 	binary.LittleEndian.PutUint64(out[0:8], h0)
 	binary.LittleEndian.PutUint64(out[0:8], h0)
 	binary.LittleEndian.PutUint64(out[8:16], h1)
 	binary.LittleEndian.PutUint64(out[8:16], h1)

+ 1 - 2
psiphon/common/crypto/internal/poly1305/sum_ppc64le.go → psiphon/common/crypto/internal/poly1305/sum_ppc64x.go

@@ -2,8 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-//go:build gc && !purego
-// +build gc,!purego
+//go:build gc && !purego && (ppc64 || ppc64le)
 
 
 package poly1305
 package poly1305
 
 

+ 25 - 20
psiphon/common/crypto/internal/poly1305/sum_ppc64le.s → psiphon/common/crypto/internal/poly1305/sum_ppc64x.s

@@ -2,16 +2,25 @@
 // Use of this source code is governed by a BSD-style
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 // license that can be found in the LICENSE file.
 
 
-//go:build gc && !purego
-// +build gc,!purego
+//go:build gc && !purego && (ppc64 || ppc64le)
 
 
 #include "textflag.h"
 #include "textflag.h"
 
 
 // This was ported from the amd64 implementation.
 // This was ported from the amd64 implementation.
 
 
+#ifdef GOARCH_ppc64le
+#define LE_MOVD MOVD
+#define LE_MOVWZ MOVWZ
+#define LE_MOVHZ MOVHZ
+#else
+#define LE_MOVD MOVDBR
+#define LE_MOVWZ MOVWBR
+#define LE_MOVHZ MOVHBR
+#endif
+
 #define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
 #define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
-	MOVD (msg), t0;  \
-	MOVD 8(msg), t1; \
+	LE_MOVD (msg)( R0), t0; \
+	LE_MOVD (msg)(R24), t1; \
 	MOVD $1, t2;     \
 	MOVD $1, t2;     \
 	ADDC t0, h0, h0; \
 	ADDC t0, h0, h0; \
 	ADDE t1, h1, h1; \
 	ADDE t1, h1, h1; \
@@ -20,15 +29,14 @@
 
 
 #define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
 #define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
 	MULLD  r0, h0, t0;  \
 	MULLD  r0, h0, t0;  \
-	MULLD  r0, h1, t4;  \
 	MULHDU r0, h0, t1;  \
 	MULHDU r0, h0, t1;  \
+	MULLD  r0, h1, t4;  \
 	MULHDU r0, h1, t5;  \
 	MULHDU r0, h1, t5;  \
 	ADDC   t4, t1, t1;  \
 	ADDC   t4, t1, t1;  \
 	MULLD  r0, h2, t2;  \
 	MULLD  r0, h2, t2;  \
-	ADDZE  t5;          \
 	MULHDU r1, h0, t4;  \
 	MULHDU r1, h0, t4;  \
 	MULLD  r1, h0, h0;  \
 	MULLD  r1, h0, h0;  \
-	ADD    t5, t2, t2;  \
+	ADDE   t5, t2, t2;  \
 	ADDC   h0, t1, t1;  \
 	ADDC   h0, t1, t1;  \
 	MULLD  h2, r1, t3;  \
 	MULLD  h2, r1, t3;  \
 	ADDZE  t4, h0;      \
 	ADDZE  t4, h0;      \
@@ -38,13 +46,11 @@
 	ADDE   t5, t3, t3;  \
 	ADDE   t5, t3, t3;  \
 	ADDC   h0, t2, t2;  \
 	ADDC   h0, t2, t2;  \
 	MOVD   $-4, t4;     \
 	MOVD   $-4, t4;     \
-	MOVD   t0, h0;      \
-	MOVD   t1, h1;      \
 	ADDZE  t3;          \
 	ADDZE  t3;          \
-	ANDCC  $3, t2, h2;  \
-	AND    t2, t4, t0;  \
+	RLDICL $0, t2, $62, h2; \
+	AND    t2, t4, h0;  \
 	ADDC   t0, h0, h0;  \
 	ADDC   t0, h0, h0;  \
-	ADDE   t3, h1, h1;  \
+	ADDE   t3, t1, h1;  \
 	SLD    $62, t3, t4; \
 	SLD    $62, t3, t4; \
 	SRD    $2, t2;      \
 	SRD    $2, t2;      \
 	ADDZE  h2;          \
 	ADDZE  h2;          \
@@ -54,10 +60,6 @@
 	ADDE   t3, h1, h1;  \
 	ADDE   t3, h1, h1;  \
 	ADDZE  h2
 	ADDZE  h2
 
 
-DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
-DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
-GLOBL ·poly1305Mask<>(SB), RODATA, $16
-
 // func update(state *[7]uint64, msg []byte)
 // func update(state *[7]uint64, msg []byte)
 TEXT ·update(SB), $0-32
 TEXT ·update(SB), $0-32
 	MOVD state+0(FP), R3
 	MOVD state+0(FP), R3
@@ -70,12 +72,15 @@ TEXT ·update(SB), $0-32
 	MOVD 24(R3), R11 // r0
 	MOVD 24(R3), R11 // r0
 	MOVD 32(R3), R12 // r1
 	MOVD 32(R3), R12 // r1
 
 
+	MOVD $8, R24
+
 	CMP R5, $16
 	CMP R5, $16
 	BLT bytes_between_0_and_15
 	BLT bytes_between_0_and_15
 
 
 loop:
 loop:
 	POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
 	POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
 
 
+	PCALIGN $16
 multiply:
 multiply:
 	POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
 	POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
 	ADD $-16, R5
 	ADD $-16, R5
@@ -97,7 +102,7 @@ flush_buffer:
 
 
 	// Greater than 8 -- load the rightmost remaining bytes in msg
 	// Greater than 8 -- load the rightmost remaining bytes in msg
 	// and put into R17 (h1)
 	// and put into R17 (h1)
-	MOVD (R4)(R21), R17
+	LE_MOVD (R4)(R21), R17
 	MOVD $16, R22
 	MOVD $16, R22
 
 
 	// Find the offset to those bytes
 	// Find the offset to those bytes
@@ -121,7 +126,7 @@ just1:
 	BLT less8
 	BLT less8
 
 
 	// Exactly 8
 	// Exactly 8
-	MOVD (R4), R16
+	LE_MOVD (R4), R16
 
 
 	CMP R17, $0
 	CMP R17, $0
 
 
@@ -136,7 +141,7 @@ less8:
 	MOVD  $0, R22   // shift count
 	MOVD  $0, R22   // shift count
 	CMP   R5, $4
 	CMP   R5, $4
 	BLT   less4
 	BLT   less4
-	MOVWZ (R4), R16
+	LE_MOVWZ (R4), R16
 	ADD   $4, R4
 	ADD   $4, R4
 	ADD   $-4, R5
 	ADD   $-4, R5
 	MOVD  $32, R22
 	MOVD  $32, R22
@@ -144,7 +149,7 @@ less8:
 less4:
 less4:
 	CMP   R5, $2
 	CMP   R5, $2
 	BLT   less2
 	BLT   less2
-	MOVHZ (R4), R21
+	LE_MOVHZ (R4), R21
 	SLD   R22, R21, R21
 	SLD   R22, R21, R21
 	OR    R16, R21, R16
 	OR    R16, R21, R16
 	ADD   $16, R22
 	ADD   $16, R22

+ 2 - 2
psiphon/common/crypto/internal/testenv/exec.go

@@ -57,8 +57,8 @@ func CommandContext(t testing.TB, ctx context.Context, name string, args ...stri
 			// grace periods to clean up: one for the delay between the first
 			// grace periods to clean up: one for the delay between the first
 			// termination signal being sent (via the Cancel callback when the Context
 			// termination signal being sent (via the Cancel callback when the Context
 			// expires) and the process being forcibly terminated (via the WaitDelay
 			// expires) and the process being forcibly terminated (via the WaitDelay
-			// field), and a second one for the delay becween the process being
-			// terminated and and the test logging its output for debugging.
+			// field), and a second one for the delay between the process being
+			// terminated and the test logging its output for debugging.
 			//
 			//
 			// (We want to ensure that the test process itself has enough time to
 			// (We want to ensure that the test process itself has enough time to
 			// log the output before it is also terminated.)
 			// log the output before it is also terminated.)

+ 1 - 1
psiphon/common/crypto/nacl/secretbox/secretbox.go

@@ -32,7 +32,7 @@ chunk size.
 
 
 This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
 This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
 */
 */
-package secretbox // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/nacl/secretbox"
+package secretbox
 
 
 import (
 import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"

+ 1 - 1
psiphon/common/crypto/ssh/agent/client.go

@@ -10,7 +10,7 @@
 // References:
 // References:
 //
 //
 //	[PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
 //	[PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
-package agent // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/agent"
+package agent
 
 
 import (
 import (
 	"bytes"
 	"bytes"

+ 16 - 9
psiphon/common/crypto/ssh/agent/client_test.go

@@ -164,15 +164,23 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 	data := []byte("hello")
 	data := []byte("hello")
 	sig, err := agent.Sign(pubKey, data)
 	sig, err := agent.Sign(pubKey, data)
 	if err != nil {
 	if err != nil {
-		t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
-	}
-
-	if err := pubKey.Verify(data, sig); err != nil {
-		t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
+		t.Logf("sign failed with key type %q", pubKey.Type())
+		// In integration tests ssh-rsa (SHA1 signatures) may be disabled for
+		// security reasons, we check SHA-2 variants later.
+		if pubKey.Type() != ssh.KeyAlgoRSA && pubKey.Type() != ssh.CertAlgoRSAv01 {
+			t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
+		}
+	} else {
+		if err := pubKey.Verify(data, sig); err != nil {
+			t.Logf("verify failed with key type %q", pubKey.Type())
+			if pubKey.Type() != ssh.KeyAlgoRSA {
+				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
+			}
+		}
 	}
 	}
 
 
 	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
 	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
-	if pubKey.Type() == "ssh-rsa" {
+	if pubKey.Type() == ssh.KeyAlgoRSA {
 		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
 		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
 			sig, err = agent.SignWithFlags(pubKey, data, flag)
 			sig, err = agent.SignWithFlags(pubKey, data, flag)
 			if err != nil {
 			if err != nil {
@@ -185,7 +193,6 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
 				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
 			}
 			}
 		}
 		}
-		sshFlagTest(0, ssh.KeyAlgoRSA)
 		sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
 		sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
 		sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
 		sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
 	}
 	}
@@ -244,7 +251,7 @@ func TestMalformedRequests(t *testing.T) {
 }
 }
 
 
 func TestAgent(t *testing.T) {
 func TestAgent(t *testing.T) {
-	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
+	for _, keyType := range []string{"rsa", "ecdsa", "ed25519"} {
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
 		testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
 		testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
 	}
 	}
@@ -402,7 +409,7 @@ func testLockAgent(agent Agent, t *testing.T) {
 	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
 	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
 		t.Errorf("Add: %v", err)
 		t.Errorf("Add: %v", err)
 	}
 	}
-	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
+	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["ecdsa"], Comment: "comment ecdsa"}); err != nil {
 		t.Errorf("Add: %v", err)
 		t.Errorf("Add: %v", err)
 	}
 	}
 	if keys, err := agent.List(); err != nil {
 	if keys, err := agent.List(); err != nil {

+ 9 - 0
psiphon/common/crypto/ssh/agent/keyring.go

@@ -175,6 +175,15 @@ func (r *keyring) Add(key AddedKey) error {
 		p.expire = &t
 		p.expire = &t
 	}
 	}
 
 
+	// If we already have a Signer with the same public key, replace it with the
+	// new one.
+	for idx, k := range r.keys {
+		if bytes.Equal(k.signer.PublicKey().Marshal(), p.signer.PublicKey().Marshal()) {
+			r.keys[idx] = p
+			return nil
+		}
+	}
+
 	r.keys = append(r.keys, p)
 	r.keys = append(r.keys, p)
 
 
 	return nil
 	return nil

+ 46 - 0
psiphon/common/crypto/ssh/agent/keyring_test.go

@@ -29,6 +29,10 @@ func validateListedKeys(t *testing.T, a Agent, expectedKeys []string) {
 		t.Fatalf("failed to list keys: %v", err)
 		t.Fatalf("failed to list keys: %v", err)
 		return
 		return
 	}
 	}
+	if len(listedKeys) != len(expectedKeys) {
+		t.Fatalf("expeted %d key, got %d", len(expectedKeys), len(listedKeys))
+		return
+	}
 	actualKeys := make(map[string]bool)
 	actualKeys := make(map[string]bool)
 	for _, key := range listedKeys {
 	for _, key := range listedKeys {
 		actualKeys[key.Comment] = true
 		actualKeys[key.Comment] = true
@@ -74,3 +78,45 @@ func TestKeyringAddingAndRemoving(t *testing.T) {
 	}
 	}
 	validateListedKeys(t, k, []string{})
 	validateListedKeys(t, k, []string{})
 }
 }
+
+func TestAddDuplicateKey(t *testing.T) {
+	keyNames := []string{"rsa", "user"}
+
+	k := NewKeyring()
+	for _, keyName := range keyNames {
+		addTestKey(t, k, keyName)
+	}
+	validateListedKeys(t, k, keyNames)
+	// Add the keys again.
+	for _, keyName := range keyNames {
+		addTestKey(t, k, keyName)
+	}
+	validateListedKeys(t, k, keyNames)
+	// Add an existing key with an updated comment.
+	keyName := keyNames[0]
+	addedKey := AddedKey{
+		PrivateKey: testPrivateKeys[keyName],
+		Comment:    "comment updated",
+	}
+	err := k.Add(addedKey)
+	if err != nil {
+		t.Fatalf("failed to add key %q: %v", keyName, err)
+	}
+	// Check the that key is found and the comment was updated.
+	keys, err := k.List()
+	if err != nil {
+		t.Fatalf("failed to list keys: %v", err)
+	}
+	if len(keys) != len(keyNames) {
+		t.Fatalf("expected %d keys, got %d", len(keyNames), len(keys))
+	}
+	isFound := false
+	for _, key := range keys {
+		if key.Comment == addedKey.Comment {
+			isFound = true
+		}
+	}
+	if !isFound {
+		t.Fatal("key with the updated comment not found")
+	}
+}

+ 15 - 17
psiphon/common/crypto/ssh/certs_test.go

@@ -15,14 +15,12 @@ import (
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 	"time"
 	"time"
-)
 
 
-// Cert generated by ssh-keygen 6.0p1 Debian-4.
-// % ssh-keygen -s ca-key -I test user-key
-const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
+)
 
 
 func TestParseCert(t *testing.T) {
 func TestParseCert(t *testing.T) {
-	authKeyBytes := []byte(exampleSSHCert)
+	authKeyBytes := bytes.TrimSuffix(testdata.SSHCertificates["rsa"], []byte(" host.example.com\n"))
 
 
 	key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
 	key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
 	if err != nil {
 	if err != nil {
@@ -103,7 +101,7 @@ func TestParseCertWithOptions(t *testing.T) {
 }
 }
 
 
 func TestValidateCert(t *testing.T) {
 func TestValidateCert(t *testing.T) {
-	key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
+	key, _, _, _, err := ParseAuthorizedKey(testdata.SSHCertificates["rsa-user-testcertificate"])
 	if err != nil {
 	if err != nil {
 		t.Fatalf("ParseAuthorizedKey: %v", err)
 		t.Fatalf("ParseAuthorizedKey: %v", err)
 	}
 	}
@@ -116,7 +114,7 @@ func TestValidateCert(t *testing.T) {
 		return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
 		return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
 	}
 	}
 
 
-	if err := checker.CheckCert("user", validCert); err != nil {
+	if err := checker.CheckCert("testcertificate", validCert); err != nil {
 		t.Errorf("Unable to validate certificate: %v", err)
 		t.Errorf("Unable to validate certificate: %v", err)
 	}
 	}
 	invalidCert := &Certificate{
 	invalidCert := &Certificate{
@@ -125,7 +123,7 @@ func TestValidateCert(t *testing.T) {
 		ValidBefore:  CertTimeInfinity,
 		ValidBefore:  CertTimeInfinity,
 		Signature:    &Signature{},
 		Signature:    &Signature{},
 	}
 	}
-	if err := checker.CheckCert("user", invalidCert); err == nil {
+	if err := checker.CheckCert("testcertificate", invalidCert); err == nil {
 		t.Error("Invalid cert signature passed validation")
 		t.Error("Invalid cert signature passed validation")
 	}
 	}
 }
 }
@@ -367,21 +365,21 @@ func TestCertTypes(t *testing.T) {
 
 
 func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
 func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
 	type testcase struct {
 	type testcase struct {
-		sigAlgo   string
-		algoritms []string
+		sigAlgo    string
+		algorithms []string
 	}
 	}
 	cases := []testcase{
 	cases := []testcase{
 		{
 		{
-			sigAlgo:   KeyAlgoRSA,
-			algoritms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
+			sigAlgo:    KeyAlgoRSA,
+			algorithms: []string{KeyAlgoRSA, KeyAlgoRSASHA512},
 		},
 		},
 		{
 		{
-			sigAlgo:   KeyAlgoRSASHA256,
-			algoritms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
+			sigAlgo:    KeyAlgoRSASHA256,
+			algorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512},
 		},
 		},
 		{
 		{
-			sigAlgo:   KeyAlgoRSASHA512,
-			algoritms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
+			sigAlgo:    KeyAlgoRSASHA512,
+			algorithms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256},
 		},
 		},
 	}
 	}
 
 
@@ -393,7 +391,7 @@ func TestCertSignWithMultiAlgorithmSigner(t *testing.T) {
 
 
 	for _, c := range cases {
 	for _, c := range cases {
 		t.Run(c.sigAlgo, func(t *testing.T) {
 		t.Run(c.sigAlgo, func(t *testing.T) {
-			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algoritms)
+			signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algorithms)
 			if err != nil {
 			if err != nil {
 				t.Fatalf("NewSignerWithAlgorithms error: %v", err)
 				t.Fatalf("NewSignerWithAlgorithms error: %v", err)
 			}
 			}

+ 20 - 3
psiphon/common/crypto/ssh/client_auth.go

@@ -71,6 +71,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
 	for auth := AuthMethod(new(noneAuth)); auth != nil; {
 		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
 		if err != nil {
 		if err != nil {
+			// On disconnect, return error immediately
+			if _, ok := err.(*disconnectMsg); ok {
+				return err
+			}
 			// We return the error later if there is no other method left to
 			// We return the error later if there is no other method left to
 			// try.
 			// try.
 			ok = authFailure
 			ok = authFailure
@@ -404,10 +408,10 @@ func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, e
 		return false, err
 		return false, err
 	}
 	}
 
 
-	return confirmKeyAck(key, algo, c)
+	return confirmKeyAck(key, c)
 }
 }
 
 
-func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
+func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
 	pubKey := key.Marshal()
 	pubKey := key.Marshal()
 
 
 	for {
 	for {
@@ -425,7 +429,15 @@ func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
 			if err := Unmarshal(packet, &msg); err != nil {
 			if err := Unmarshal(packet, &msg); err != nil {
 				return false, err
 				return false, err
 			}
 			}
-			if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
+			// According to RFC 4252 Section 7 the algorithm in
+			// SSH_MSG_USERAUTH_PK_OK should match that of the request but some
+			// servers send the key type instead. OpenSSH allows any algorithm
+			// that matches the public key, so we do the same.
+			// https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709
+			if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) {
+				return false, nil
+			}
+			if !bytes.Equal(msg.PubKey, pubKey) {
 				return false, nil
 				return false, nil
 			}
 			}
 			return true, nil
 			return true, nil
@@ -543,6 +555,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 	}
 	}
 
 
 	gotMsgExtInfo := false
 	gotMsgExtInfo := false
+	gotUserAuthInfoRequest := false
 	for {
 	for {
 		packet, err := c.readPacket()
 		packet, err := c.readPacket()
 		if err != nil {
 		if err != nil {
@@ -573,6 +586,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 			if msg.PartialSuccess {
 			if msg.PartialSuccess {
 				return authPartialSuccess, msg.Methods, nil
 				return authPartialSuccess, msg.Methods, nil
 			}
 			}
+			if !gotUserAuthInfoRequest {
+				return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
+			}
 			return authFailure, msg.Methods, nil
 			return authFailure, msg.Methods, nil
 		case msgUserAuthSuccess:
 		case msgUserAuthSuccess:
 			return authSuccess, nil, nil
 			return authSuccess, nil, nil
@@ -584,6 +600,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		if err := Unmarshal(packet, &msg); err != nil {
 		if err := Unmarshal(packet, &msg); err != nil {
 			return authFailure, nil, err
 			return authFailure, nil, err
 		}
 		}
+		gotUserAuthInfoRequest = true
 
 
 		// Manually unpack the prompt/echo pairs.
 		// Manually unpack the prompt/echo pairs.
 		rest := msg.Prompts
 		rest := msg.Prompts

+ 108 - 8
psiphon/common/crypto/ssh/client_auth_test.go

@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
 	return err
 	return err
 }
 }
 
 
-// tryAuth runs a handshake with a given config against an SSH server
+// tryAuthWithGSSAPIWithMICConfig runs a handshake with a given config against an SSH server
 // with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors.
 // with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors.
 func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error {
 func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error {
 	err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig)
 	err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig)
@@ -641,17 +641,28 @@ func TestClientAuthMaxAuthTries(t *testing.T) {
 		defer c1.Close()
 		defer c1.Close()
 		defer c2.Close()
 		defer c2.Close()
 
 
-		go newServer(c1, serverConfig)
-		_, _, _, err = NewClientConn(c2, "", clientConfig)
-		if tries > 2 {
-			if err == nil {
+		errCh := make(chan error, 1)
+
+		go func() {
+			_, err := newServer(c1, serverConfig)
+			errCh <- err
+		}()
+		_, _, _, cliErr := NewClientConn(c2, "", clientConfig)
+		srvErr := <-errCh
+
+		if tries > serverConfig.MaxAuthTries {
+			if cliErr == nil {
 				t.Fatalf("client: got no error, want %s", expectedErr)
 				t.Fatalf("client: got no error, want %s", expectedErr)
-			} else if err.Error() != expectedErr.Error() {
+			} else if cliErr.Error() != expectedErr.Error() {
 				t.Fatalf("client: got %s, want %s", err, expectedErr)
 				t.Fatalf("client: got %s, want %s", err, expectedErr)
 			}
 			}
+			var authErr *ServerAuthError
+			if !errors.As(srvErr, &authErr) {
+				t.Errorf("expected ServerAuthError, got: %v", srvErr)
+			}
 		} else {
 		} else {
-			if err != nil {
-				t.Fatalf("client: got %s, want no error", err)
+			if cliErr != nil {
+				t.Fatalf("client: got %s, want no error", cliErr)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -1282,3 +1293,92 @@ func TestCertAuthOpenSSHCompat(t *testing.T) {
 		t.Fatalf("unable to dial remote side: %s", err)
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
 	}
 }
 }
+
+func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) {
+	const maxAuthTries = 2
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	// Start testserver
+	serverConfig := &ServerConfig{
+		MaxAuthTries: maxAuthTries,
+		KeyboardInteractiveCallback: func(c ConnMetadata,
+			client KeyboardInteractiveChallenge) (*Permissions, error) {
+			// Fail keyboard-interactive authentication early before
+			// any prompt is sent to client.
+			return nil, errors.New("keyboard-interactive auth failed")
+		},
+		PasswordCallback: func(c ConnMetadata,
+			pass []byte) (*Permissions, error) {
+			if string(pass) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("password auth failed")
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	serverDone := make(chan struct{})
+	go func() {
+		defer func() { serverDone <- struct{}{} }()
+		conn, chans, reqs, err := NewServerConn(c2, serverConfig)
+		if err != nil {
+			return
+		}
+		_ = conn.Close()
+
+		discarderDone := make(chan struct{})
+		go func() {
+			defer func() { discarderDone <- struct{}{} }()
+			DiscardRequests(reqs)
+		}()
+		for newChannel := range chans {
+			newChannel.Reject(Prohibited,
+				"testserver not accepting requests")
+		}
+
+		<-discarderDone
+	}()
+
+	// Connect to testserver, expect KeyboardInteractive() to be not called,
+	// PasswordCallback() to be called and connection to succeed.
+	passwordCallbackCalled := false
+	clientConfig := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			RetryableAuthMethod(KeyboardInteractive(func(name,
+				instruction string, questions []string,
+				echos []bool) ([]string, error) {
+				t.Errorf("unexpected call to KeyboardInteractive()")
+				return []string{clientPassword}, nil
+			}), maxAuthTries),
+			RetryableAuthMethod(PasswordCallback(func() (secret string,
+				err error) {
+				t.Logf("PasswordCallback()")
+				passwordCallbackCalled = true
+				return clientPassword, nil
+			}), maxAuthTries),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	conn, _, _, err := NewClientConn(c1, "", clientConfig)
+	if err != nil {
+		t.Errorf("unexpected NewClientConn() error: %v", err)
+	}
+	if conn != nil {
+		conn.Close()
+	}
+
+	// Wait for server to finish.
+	<-serverDone
+
+	if !passwordCallbackCalled {
+		t.Errorf("expected PasswordCallback() to be called")
+	}
+}

+ 1 - 1
psiphon/common/crypto/ssh/doc.go

@@ -20,4 +20,4 @@ References:
 This package does not fall under the stability promise of the Go language itself,
 This package does not fall under the stability promise of the Go language itself,
 so its API may be changed when pressing needs arise.
 so its API may be changed when pressing needs arise.
 */
 */
-package ssh // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
+package ssh

+ 1 - 1
psiphon/common/crypto/ssh/example_test.go

@@ -384,7 +384,7 @@ func ExampleCertificate_SignCert() {
 	}
 	}
 	mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256})
 	mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256})
 	if err != nil {
 	if err != nil {
-		log.Fatal("unable to create signer with algoritms: ", err)
+		log.Fatal("unable to create signer with algorithms: ", err)
 	}
 	}
 	certificate := ssh.Certificate{
 	certificate := ssh.Certificate{
 		Key:      publicKey,
 		Key:      publicKey,

+ 49 - 12
psiphon/common/crypto/ssh/handshake.go

@@ -30,6 +30,11 @@ const debugHandshake = false
 // quickly.
 // quickly.
 const chanSize = 16
 const chanSize = 16
 
 
+// maxPendingPackets sets the maximum number of packets to queue while waiting
+// for KEX to complete. This limits the total pending data to maxPendingPackets
+// * maxPacket bytes, which is ~16.8MB.
+const maxPendingPackets = 64
+
 // keyingTransport is a packet based transport that supports key
 // keyingTransport is a packet based transport that supports key
 // changes. It need not be thread-safe. It should pass through
 // changes. It need not be thread-safe. It should pass through
 // msgNewKeys in both directions.
 // msgNewKeys in both directions.
@@ -78,13 +83,22 @@ type handshakeTransport struct {
 	incoming  chan []byte
 	incoming  chan []byte
 	readError error
 	readError error
 
 
-	mu               sync.Mutex
-	writeError       error
-	sentInitPacket   []byte
-	sentInitMsg      *kexInitMsg
-	pendingPackets   [][]byte // Used when a key exchange is in progress.
+	mu sync.Mutex
+	// Condition for the above mutex. It is used to notify a completed key
+	// exchange or a write failure. Writes can wait for this condition while a
+	// key exchange is in progress.
+	writeCond      *sync.Cond
+	writeError     error
+	sentInitPacket []byte
+	sentInitMsg    *kexInitMsg
+	// Used to queue writes when a key exchange is in progress. The length is
+	// limited by pendingPacketsSize. Once full, writes will block until the key
+	// exchange is completed or an error occurs. If not empty, it is emptied
+	// all at once when the key exchange is completed in kexLoop.
+	pendingPackets   [][]byte
 	writePacketsLeft uint32
 	writePacketsLeft uint32
 	writeBytesLeft   int64
 	writeBytesLeft   int64
+	userAuthComplete bool // whether the user authentication phase is complete
 
 
 	// If the read loop wants to schedule a kex, it pings this
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
 	// channel, and the write loop will send out a kex
@@ -138,6 +152,7 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
 
 
 		config: config,
 		config: config,
 	}
 	}
+	t.writeCond = sync.NewCond(&t.mu)
 	t.resetReadThresholds()
 	t.resetReadThresholds()
 	t.resetWriteThresholds()
 	t.resetWriteThresholds()
 
 
@@ -264,6 +279,7 @@ func (t *handshakeTransport) recordWriteError(err error) {
 	defer t.mu.Unlock()
 	defer t.mu.Unlock()
 	if t.writeError == nil && err != nil {
 	if t.writeError == nil && err != nil {
 		t.writeError = err
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 	}
 }
 }
 
 
@@ -367,6 +383,8 @@ write:
 			}
 			}
 		}
 		}
 		t.pendingPackets = t.pendingPackets[:0]
 		t.pendingPackets = t.pendingPackets[:0]
+		// Unblock writePacket if waiting for KEX.
+		t.writeCond.Broadcast()
 		t.mu.Unlock()
 		t.mu.Unlock()
 	}
 	}
 
 
@@ -941,26 +959,44 @@ func (t *handshakeTransport) sendKexInit() error {
 	return nil
 	return nil
 }
 }
 
 
+var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
+
 func (t *handshakeTransport) writePacket(p []byte) error {
 func (t *handshakeTransport) writePacket(p []byte) error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
 	switch p[0] {
 	switch p[0] {
 	case msgKexInit:
 	case msgKexInit:
 		return errors.New("ssh: only handshakeTransport can send kexInit")
 		return errors.New("ssh: only handshakeTransport can send kexInit")
 	case msgNewKeys:
 	case msgNewKeys:
 		return errors.New("ssh: only handshakeTransport can send newKeys")
 		return errors.New("ssh: only handshakeTransport can send newKeys")
+	case msgUserAuthBanner:
+		if t.userAuthComplete {
+			return errSendBannerPhase
+		}
+	case msgUserAuthSuccess:
+		t.userAuthComplete = true
 	}
 	}
 
 
-	t.mu.Lock()
-	defer t.mu.Unlock()
 	if t.writeError != nil {
 	if t.writeError != nil {
 		return t.writeError
 		return t.writeError
 	}
 	}
 
 
 	if t.sentInitMsg != nil {
 	if t.sentInitMsg != nil {
-		// Copy the packet so the writer can reuse the buffer.
-		cp := make([]byte, len(p))
-		copy(cp, p)
-		t.pendingPackets = append(t.pendingPackets, cp)
-		return nil
+		if len(t.pendingPackets) < maxPendingPackets {
+			// Copy the packet so the writer can reuse the buffer.
+			cp := make([]byte, len(p))
+			copy(cp, p)
+			t.pendingPackets = append(t.pendingPackets, cp)
+			return nil
+		}
+		for t.sentInitMsg != nil {
+			// Block and wait for KEX to complete or an error.
+			t.writeCond.Wait()
+			if t.writeError != nil {
+				return t.writeError
+			}
+		}
 	}
 	}
 
 
 	if t.writeBytesLeft > 0 {
 	if t.writeBytesLeft > 0 {
@@ -977,6 +1013,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 
 
 	if err := t.pushPacket(p); err != nil {
 	if err := t.pushPacket(p); err != nil {
 		t.writeError = err
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 	}
 
 
 	return nil
 	return nil

+ 220 - 0
psiphon/common/crypto/ssh/handshake_test.go

@@ -539,6 +539,226 @@ func TestDisconnect(t *testing.T) {
 	}
 	}
 }
 }
 
 
+type mockKeyingTransport struct {
+	packetConn
+	kexInitAllowed chan struct{}
+	kexInitSent    chan struct{}
+}
+
+func (n *mockKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+	return nil
+}
+
+func (n *mockKeyingTransport) writePacket(packet []byte) error {
+	if packet[0] == msgKexInit {
+		<-n.kexInitAllowed
+		n.kexInitSent <- struct{}{}
+	}
+	return n.packetConn.writePacket(packet)
+}
+
+func (n *mockKeyingTransport) readPacket() ([]byte, error) {
+	return n.packetConn.readPacket()
+}
+
+func (n *mockKeyingTransport) setStrictMode() error { return nil }
+
+func (n *mockKeyingTransport) setInitialKEXDone() {}
+
+func TestHandshakePendingPacketsWait(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	server.mu.Lock()
+	if len(server.pendingPackets) != maxPendingPackets {
+		t.Errorf("unexpected pending packets size; got: %d, want: %d", len(server.pendingPackets), maxPendingPackets)
+	}
+	server.mu.Unlock()
+
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until KEX completes.
+		err := server.writePacket(p)
+		if err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Allow the client to reply to the KEX and so unblock the write goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	<-trC.kexInitSent
+	<-writeDone
+	// Close the client to unblock the read goroutine.
+	client.Close()
+	<-readDone
+	server.Close()
+}
+
+func TestHandshakePendingPacketsError(t *testing.T) {
+	a, b := memPipe()
+
+	trS := &mockKeyingTransport{
+		packetConn:     a,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trS.kexInitAllowed <- struct{}{}
+
+	trC := &mockKeyingTransport{
+		packetConn:     b,
+		kexInitAllowed: make(chan struct{}, 2),
+		kexInitSent:    make(chan struct{}, 2),
+	}
+	// Allow the first KEX.
+	trC.kexInitAllowed <- struct{}{}
+
+	clientConf := &ClientConfig{
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.SetDefaults()
+	server := newServerTransport(trS, v, v, serverConf)
+
+	if err := server.waitSession(); err != nil {
+		t.Fatalf("server.waitSession: %v", err)
+	}
+	if err := client.waitSession(); err != nil {
+		t.Fatalf("client.waitSession: %v", err)
+	}
+
+	<-trC.kexInitSent
+	<-trS.kexInitSent
+
+	// Allow and request new KEX server side.
+	trS.kexInitAllowed <- struct{}{}
+	server.requestKeyExchange()
+	// Wait until the KEX init is sent.
+	<-trS.kexInitSent
+	// The client is not allowed to respond to the KEX, so writes will be
+	// blocked on the server side once the packets queue is full.
+	for i := 0; i < maxPendingPackets; i++ {
+		p := []byte{msgRequestSuccess, byte(i)}
+		if err := server.writePacket(p); err != nil {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}
+	// The packets queue is now full, the next write will block.
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+
+		p := []byte{msgRequestSuccess, byte(65)}
+		// This write will block until KEX completes.
+		err := server.writePacket(p)
+		if err != io.EOF {
+			t.Errorf("unexpected write error: %v", err)
+		}
+	}()
+
+	// Consume packets on the client side
+	readDone := make(chan bool)
+	go func() {
+		defer close(readDone)
+
+		for {
+			if _, err := client.readPacket(); err != nil {
+				if err != io.EOF {
+					t.Errorf("unexpected read error: %v", err)
+				}
+				break
+			}
+		}
+	}()
+
+	// Close the server to unblock the write after an error
+	server.Close()
+	<-writeDone
+	// Unblock the pending write and close the client to unblock the read
+	// goroutine.
+	trC.kexInitAllowed <- struct{}{}
+	client.Close()
+	<-readDone
+}
+
 func TestHandshakeRekeyDefault(t *testing.T) {
 func TestHandshakeRekeyDefault(t *testing.T) {
 	clientConf := &ClientConfig{
 	clientConf := &ClientConfig{
 		Config: Config{
 		Config: Config{

+ 51 - 1
psiphon/common/crypto/ssh/keys.go

@@ -488,7 +488,49 @@ func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
 	h := hash.New()
 	h := hash.New()
 	h.Write(data)
 	h.Write(data)
 	digest := h.Sum(nil)
 	digest := h.Sum(nil)
-	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob)
+
+	// Signatures in PKCS1v15 must match the key's modulus in
+	// length. However with SSH, some signers provide RSA
+	// signatures which are missing the MSB 0's of the bignum
+	// represented. With ssh-rsa signatures, this is encouraged by
+	// the spec (even though e.g. OpenSSH will give the full
+	// length unconditionally). With rsa-sha2-* signatures, the
+	// verifier is allowed to support these, even though they are
+	// out of spec. See RFC 4253 Section 6.6 for ssh-rsa and RFC
+	// 8332 Section 3 for rsa-sha2-* details.
+	//
+	// In practice:
+	// * OpenSSH always allows "short" signatures:
+	//   https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L526
+	//   but always generates padded signatures:
+	//   https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L439
+	//
+	// * PuTTY versions 0.81 and earlier will generate short
+	//   signatures for all RSA signature variants. Note that
+	//   PuTTY is embedded in other software, such as WinSCP and
+	//   FileZilla. At the time of writing, a patch has been
+	//   applied to PuTTY to generate padded signatures for
+	//   rsa-sha2-*, but not yet released:
+	//   https://git.tartarus.org/?p=simon/putty.git;a=commitdiff;h=a5bcf3d384e1bf15a51a6923c3724cbbee022d8e
+	//
+	// * SSH.NET versions 2024.0.0 and earlier will generate short
+	//   signatures for all RSA signature variants, fixed in 2024.1.0:
+	//   https://github.com/sshnet/SSH.NET/releases/tag/2024.1.0
+	//
+	// As a result, we pad these up to the key size by inserting
+	// leading 0's.
+	//
+	// Note that support for short signatures with rsa-sha2-* may
+	// be removed in the future due to such signatures not being
+	// allowed by the spec.
+	blob := sig.Blob
+	keySize := (*rsa.PublicKey)(r).Size()
+	if len(blob) < keySize {
+		padded := make([]byte, keySize)
+		copy(padded[keySize-len(blob):], blob)
+		blob = padded
+	}
+	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, blob)
 }
 }
 
 
 func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
 func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
@@ -904,6 +946,10 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
 	return errors.New("ssh: signature did not verify")
 	return errors.New("ssh: signature did not verify")
 }
 }
 
 
+func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey {
+	return &k.PublicKey
+}
+
 type skEd25519PublicKey struct {
 type skEd25519PublicKey struct {
 	// application is a URL-like string, typically "ssh:" for SSH.
 	// application is a URL-like string, typically "ssh:" for SSH.
 	// see openssh/PROTOCOL.u2f for details.
 	// see openssh/PROTOCOL.u2f for details.
@@ -1000,6 +1046,10 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
 	return nil
 	return nil
 }
 }
 
 
+func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey {
+	return k.PublicKey
+}
+
 // NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
 // NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
 // *ecdsa.PrivateKey or any other crypto.Signer and returns a
 // *ecdsa.PrivateKey or any other crypto.Signer and returns a
 // corresponding Signer instance. ECDSA keys must use P-256, P-384 or
 // corresponding Signer instance. ECDSA keys must use P-256, P-384 or

+ 86 - 2
psiphon/common/crypto/ssh/keys_test.go

@@ -154,6 +154,44 @@ func TestKeySignWithAlgorithmVerify(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestKeySignWithShortSignature(t *testing.T) {
+	signer := testSigners["rsa"].(AlgorithmSigner)
+	pub := signer.PublicKey()
+	// Note: data obtained by empirically trying until a result
+	// starting with 0 appeared
+	tests := []struct {
+		algorithm string
+		data      []byte
+	}{
+		{
+			algorithm: KeyAlgoRSA,
+			data:      []byte("sign me92"),
+		},
+		{
+			algorithm: KeyAlgoRSASHA256,
+			data:      []byte("sign me294"),
+		},
+		{
+			algorithm: KeyAlgoRSASHA512,
+			data:      []byte("sign me60"),
+		},
+	}
+
+	for _, tt := range tests {
+		sig, err := signer.SignWithAlgorithm(rand.Reader, tt.data, tt.algorithm)
+		if err != nil {
+			t.Fatalf("Sign(%T): %v", signer, err)
+		}
+		if sig.Blob[0] != 0 {
+			t.Errorf("%s: Expected signature with a leading 0", tt.algorithm)
+		}
+		sig.Blob = sig.Blob[1:]
+		if err := pub.Verify(tt.data, sig); err != nil {
+			t.Errorf("publicKey.Verify(%s): %v", tt.algorithm, err)
+		}
+	}
+}
+
 func TestParseRSAPrivateKey(t *testing.T) {
 func TestParseRSAPrivateKey(t *testing.T) {
 	key := testPrivateKeys["rsa"]
 	key := testPrivateKeys["rsa"]
 
 
@@ -610,7 +648,7 @@ func TestKnownHostsParsing(t *testing.T) {
 func TestFingerprintLegacyMD5(t *testing.T) {
 func TestFingerprintLegacyMD5(t *testing.T) {
 	pub, _ := getTestKey()
 	pub, _ := getTestKey()
 	fingerprint := FingerprintLegacyMD5(pub)
 	fingerprint := FingerprintLegacyMD5(pub)
-	want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa
+	want := "b7:ef:d3:d5:89:29:52:96:9f:df:47:41:4d:15:37:f4" // ssh-keygen -lf -E md5 rsa
 	if fingerprint != want {
 	if fingerprint != want {
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 	}
 	}
@@ -619,7 +657,7 @@ func TestFingerprintLegacyMD5(t *testing.T) {
 func TestFingerprintSHA256(t *testing.T) {
 func TestFingerprintSHA256(t *testing.T) {
 	pub, _ := getTestKey()
 	pub, _ := getTestKey()
 	fingerprint := FingerprintSHA256(pub)
 	fingerprint := FingerprintSHA256(pub)
-	want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa
+	want := "SHA256:fi5+D7UmDZDE9Q2sAVvvlpcQSIakN4DERdINgXd2AnE" // ssh-keygen -lf rsa
 	if fingerprint != want {
 	if fingerprint != want {
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 		t.Errorf("got fingerprint %q want %q", fingerprint, want)
 	}
 	}
@@ -726,3 +764,49 @@ func TestNewSignerWithAlgos(t *testing.T) {
 		t.Error("signer with algos created with restricted algorithms")
 		t.Error("signer with algos created with restricted algorithms")
 	}
 	}
 }
 }
+
+func TestCryptoPublicKey(t *testing.T) {
+	for _, priv := range testSigners {
+		p1 := priv.PublicKey()
+		key, ok := p1.(CryptoPublicKey)
+		if !ok {
+			continue
+		}
+		p2, err := NewPublicKey(key.CryptoPublicKey())
+		if err != nil {
+			t.Fatalf("NewPublicKey(CryptoPublicKey) failed for %s, got: %v", p1.Type(), err)
+		}
+		if !reflect.DeepEqual(p1, p2) {
+			t.Errorf("got %#v in NewPublicKey, want %#v", p2, p1)
+		}
+	}
+	for _, d := range testdata.SKData {
+		p1, _, _, _, err := ParseAuthorizedKey(d.PubKey)
+		if err != nil {
+			t.Fatalf("parseAuthorizedKey returned error: %v", err)
+		}
+		k1, ok := p1.(CryptoPublicKey)
+		if !ok {
+			t.Fatalf("%T does not implement CryptoPublicKey", p1)
+		}
+
+		var p2 PublicKey
+		switch pub := k1.CryptoPublicKey().(type) {
+		case *ecdsa.PublicKey:
+			p2 = &skECDSAPublicKey{
+				application: "ssh:",
+				PublicKey:   *pub,
+			}
+		case ed25519.PublicKey:
+			p2 = &skEd25519PublicKey{
+				application: "ssh:",
+				PublicKey:   pub,
+			}
+		default:
+			t.Fatalf("unexpected type %T from CryptoPublicKey()", pub)
+		}
+		if !reflect.DeepEqual(p1, p2) {
+			t.Errorf("got %#v, want %#v", p2, p1)
+		}
+	}
+}

+ 2 - 0
psiphon/common/crypto/ssh/messages.go

@@ -818,6 +818,8 @@ func decode(packet []byte) (interface{}, error) {
 		return new(userAuthSuccessMsg), nil
 		return new(userAuthSuccessMsg), nil
 	case msgUserAuthFailure:
 	case msgUserAuthFailure:
 		msg = new(userAuthFailureMsg)
 		msg = new(userAuthFailureMsg)
+	case msgUserAuthBanner:
+		msg = new(userAuthBannerMsg)
 	case msgUserAuthPubKeyOk:
 	case msgUserAuthPubKeyOk:
 		msg = new(userAuthPubKeyOkMsg)
 		msg = new(userAuthPubKeyOkMsg)
 	case msgGlobalRequest:
 	case msgGlobalRequest:

+ 56 - 0
psiphon/common/crypto/ssh/messages_test.go

@@ -206,6 +206,62 @@ func TestMarshalMultiTag(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestDecode(t *testing.T) {
+	rnd := rand.New(rand.NewSource(0))
+	kexInit := new(kexInitMsg).Generate(rnd, 10).Interface()
+	kexDHInit := new(kexDHInitMsg).Generate(rnd, 10).Interface()
+	kexDHReply := new(kexDHReplyMsg)
+	kexDHReply.Y = randomInt(rnd)
+	// Note: userAuthSuccessMsg can't be tested directly since it
+	// doesn't have a field for sshtype. So it's tested separately
+	// at the end.
+	decodeMessageTypes := []interface{}{
+		new(disconnectMsg),
+		new(serviceRequestMsg),
+		new(serviceAcceptMsg),
+		new(extInfoMsg),
+		kexInit,
+		kexDHInit,
+		kexDHReply,
+		new(userAuthRequestMsg),
+		new(userAuthFailureMsg),
+		new(userAuthBannerMsg),
+		new(userAuthPubKeyOkMsg),
+		new(globalRequestMsg),
+		new(globalRequestSuccessMsg),
+		new(globalRequestFailureMsg),
+		new(channelOpenMsg),
+		new(channelDataMsg),
+		new(channelOpenConfirmMsg),
+		new(channelOpenFailureMsg),
+		new(windowAdjustMsg),
+		new(channelEOFMsg),
+		new(channelCloseMsg),
+		new(channelRequestMsg),
+		new(channelRequestSuccessMsg),
+		new(channelRequestFailureMsg),
+		new(userAuthGSSAPIToken),
+		new(userAuthGSSAPIMIC),
+		new(userAuthGSSAPIErrTok),
+		new(userAuthGSSAPIError),
+	}
+	for _, msg := range decodeMessageTypes {
+		decoded, err := decode(Marshal(msg))
+		if err != nil {
+			t.Errorf("error decoding %T", msg)
+		} else if reflect.TypeOf(msg) != reflect.TypeOf(decoded) {
+			t.Errorf("error decoding %T, unexpected %T", msg, decoded)
+		}
+	}
+
+	userAuthSuccess, err := decode([]byte{msgUserAuthSuccess})
+	if err != nil {
+		t.Errorf("error decoding userAuthSuccessMsg")
+	} else if reflect.TypeOf(userAuthSuccess) != reflect.TypeOf((*userAuthSuccessMsg)(nil)) {
+		t.Errorf("error decoding userAuthSuccessMsg, unexpected %T", userAuthSuccess)
+	}
+}
+
 func randomBytes(out []byte, rand *rand.Rand) {
 func randomBytes(out []byte, rand *rand.Rand) {
 	for i := 0; i < len(out); i++ {
 	for i := 0; i < len(out); i++ {
 		out[i] = byte(rand.Int31())
 		out[i] = byte(rand.Int31())

+ 196 - 65
psiphon/common/crypto/ssh/server.go

@@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
 	Server GSSAPIServer
 	Server GSSAPIServer
 }
 }
 
 
+// SendAuthBanner implements [ServerPreAuthConn].
+func (s *connection) SendAuthBanner(msg string) error {
+	return s.transport.writePacket(Marshal(&userAuthBannerMsg{
+		Message: msg,
+	}))
+}
+
+func (*connection) unexportedMethodForFutureProofing() {}
+
+// ServerPreAuthConn is the interface available on an incoming server
+// connection before authentication has completed.
+type ServerPreAuthConn interface {
+	unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB
+
+	ConnMetadata
+
+	// SendAuthBanner sends a banner message to the client.
+	// It returns an error once the authentication phase has ended.
+	SendAuthBanner(string) error
+}
+
 // ServerConfig holds server specific configuration data.
 // ServerConfig holds server specific configuration data.
 type ServerConfig struct {
 type ServerConfig struct {
 	// Config contains configuration shared between client and server.
 	// Config contains configuration shared between client and server.
@@ -118,6 +139,12 @@ type ServerConfig struct {
 	// attempts.
 	// attempts.
 	AuthLogCallback func(conn ConnMetadata, method string, err error)
 	AuthLogCallback func(conn ConnMetadata, method string, err error)
 
 
+	// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
+	// before any authentication has started. The provided ServerPreAuthConn
+	// can be used at any time before authentication is complete, including
+	// after this callback has returned.
+	PreAuthConnCallback func(ServerPreAuthConn)
+
 	// ServerVersion is the version identification string to announce in
 	// ServerVersion is the version identification string to announce in
 	// the public handshake.
 	// the public handshake.
 	// If empty, a reasonable default is used.
 	// If empty, a reasonable default is used.
@@ -149,7 +176,7 @@ func (s *ServerConfig) AddHostKey(key Signer) {
 }
 }
 
 
 // cachedPubKey contains the results of querying whether a public key is
 // cachedPubKey contains the results of querying whether a public key is
-// acceptable for a user.
+// acceptable for a user. This is a FIFO cache.
 type cachedPubKey struct {
 type cachedPubKey struct {
 	user       string
 	user       string
 	pubKeyData []byte
 	pubKeyData []byte
@@ -157,7 +184,13 @@ type cachedPubKey struct {
 	perms      *Permissions
 	perms      *Permissions
 }
 }
 
 
-const maxCachedPubKeys = 16
+// maxCachedPubKeys is the number of cache entries we store.
+//
+// Due to consistent misuse of the PublicKeyCallback API, we have reduced this
+// to 1, such that the only key in the cache is the most recently seen one. This
+// forces the behavior that the last call to PublicKeyCallback will always be
+// with the key that is used for authentication.
+const maxCachedPubKeys = 1
 
 
 // pubKeyCache caches tests for public keys.  Since SSH clients
 // pubKeyCache caches tests for public keys.  Since SSH clients
 // will query whether a public key is acceptable before attempting to
 // will query whether a public key is acceptable before attempting to
@@ -179,9 +212,10 @@ func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
 
 
 // add adds the given tuple to the cache.
 // add adds the given tuple to the cache.
 func (c *pubKeyCache) add(candidate cachedPubKey) {
 func (c *pubKeyCache) add(candidate cachedPubKey) {
-	if len(c.keys) < maxCachedPubKeys {
-		c.keys = append(c.keys, candidate)
+	if len(c.keys) >= maxCachedPubKeys {
+		c.keys = c.keys[1:]
 	}
 	}
+	c.keys = append(c.keys, candidate)
 }
 }
 
 
 // ServerConn is an authenticated SSH connection, as seen from the
 // ServerConn is an authenticated SSH connection, as seen from the
@@ -426,6 +460,35 @@ func (l ServerAuthError) Error() string {
 	return "[" + strings.Join(errs, ", ") + "]"
 	return "[" + strings.Join(errs, ", ") + "]"
 }
 }
 
 
+// ServerAuthCallbacks defines server-side authentication callbacks.
+type ServerAuthCallbacks struct {
+	// PasswordCallback behaves like [ServerConfig.PasswordCallback].
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
+
+	// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
+
+	// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
+	GSSAPIWithMICConfig *GSSAPIWithMICConfig
+}
+
+// PartialSuccessError can be returned by any of the [ServerConfig]
+// authentication callbacks to indicate to the client that authentication has
+// partially succeeded, but further steps are required.
+type PartialSuccessError struct {
+	// Next defines the authentication callbacks to apply to further steps. The
+	// available methods communicated to the client are based on the non-nil
+	// ServerAuthCallbacks fields.
+	Next ServerAuthCallbacks
+}
+
+func (p *PartialSuccessError) Error() string {
+	return "ssh: authenticated with partial success"
+}
+
 // ErrNoAuth is the error value returned if no
 // ErrNoAuth is the error value returned if no
 // authentication method has been passed yet. This happens as a normal
 // authentication method has been passed yet. This happens as a normal
 // part of the authentication loop, since the client first tries
 // part of the authentication loop, since the client first tries
@@ -433,14 +496,46 @@ func (l ServerAuthError) Error() string {
 // It is returned in ServerAuthError.Errors from NewServerConn.
 // It is returned in ServerAuthError.Errors from NewServerConn.
 var ErrNoAuth = errors.New("ssh: no auth passed yet")
 var ErrNoAuth = errors.New("ssh: no auth passed yet")
 
 
+// BannerError is an error that can be returned by authentication handlers in
+// ServerConfig to send a banner message to the client.
+type BannerError struct {
+	Err     error
+	Message string
+}
+
+func (b *BannerError) Unwrap() error {
+	return b.Err
+}
+
+func (b *BannerError) Error() string {
+	if b.Err == nil {
+		return b.Message
+	}
+	return b.Err.Error()
+}
+
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
+	if config.PreAuthConnCallback != nil {
+		config.PreAuthConnCallback(s)
+	}
+
 	sessionID := s.transport.getSessionID()
 	sessionID := s.transport.getSessionID()
 	var cache pubKeyCache
 	var cache pubKeyCache
 	var perms *Permissions
 	var perms *Permissions
 
 
 	authFailures := 0
 	authFailures := 0
+	noneAuthCount := 0
 	var authErrs []error
 	var authErrs []error
-	var displayedBanner bool
+	var calledBannerCallback bool
+	partialSuccessReturned := false
+	// Set the initial authentication callbacks from the config. They can be
+	// changed if a PartialSuccessError is returned.
+	authConfig := ServerAuthCallbacks{
+		PasswordCallback:            config.PasswordCallback,
+		PublicKeyCallback:           config.PublicKeyCallback,
+		KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
+		GSSAPIWithMICConfig:         config.GSSAPIWithMICConfig,
+	}
 
 
 userAuthLoop:
 userAuthLoop:
 	for {
 	for {
@@ -453,8 +548,8 @@ userAuthLoop:
 			if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
 			if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-
-			return nil, discMsg
+			authErrs = append(authErrs, discMsg)
+			return nil, &ServerAuthError{Errors: authErrs}
 		}
 		}
 
 
 		var userAuthReq userAuthRequestMsg
 		var userAuthReq userAuthRequestMsg
@@ -471,16 +566,17 @@ userAuthLoop:
 			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
 			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
 		}
 		}
 
 
+		if s.user != userAuthReq.User && partialSuccessReturned {
+			return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
+				s.user, userAuthReq.User)
+		}
+
 		s.user = userAuthReq.User
 		s.user = userAuthReq.User
 
 
-		if !displayedBanner && config.BannerCallback != nil {
-			displayedBanner = true
-			msg := config.BannerCallback(s)
-			if msg != "" {
-				bannerMsg := &userAuthBannerMsg{
-					Message: msg,
-				}
-				if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+		if !calledBannerCallback && config.BannerCallback != nil {
+			calledBannerCallback = true
+			if msg := config.BannerCallback(s); msg != "" {
+				if err := s.SendAuthBanner(msg); err != nil {
 					return nil, err
 					return nil, err
 				}
 				}
 			}
 			}
@@ -491,20 +587,18 @@ userAuthLoop:
 
 
 		switch userAuthReq.Method {
 		switch userAuthReq.Method {
 		case "none":
 		case "none":
-			if config.NoClientAuth {
+			noneAuthCount++
+			// We don't allow none authentication after a partial success
+			// response.
+			if config.NoClientAuth && !partialSuccessReturned {
 				if config.NoClientAuthCallback != nil {
 				if config.NoClientAuthCallback != nil {
 					perms, authErr = config.NoClientAuthCallback(s)
 					perms, authErr = config.NoClientAuthCallback(s)
 				} else {
 				} else {
 					authErr = nil
 					authErr = nil
 				}
 				}
 			}
 			}
-
-			// allow initial attempt of 'none' without penalty
-			if authFailures == 0 {
-				authFailures--
-			}
 		case "password":
 		case "password":
-			if config.PasswordCallback == nil {
+			if authConfig.PasswordCallback == nil {
 				authErr = errors.New("ssh: password auth not configured")
 				authErr = errors.New("ssh: password auth not configured")
 				break
 				break
 			}
 			}
@@ -518,17 +612,17 @@ userAuthLoop:
 				return nil, parseError(msgUserAuthRequest)
 				return nil, parseError(msgUserAuthRequest)
 			}
 			}
 
 
-			perms, authErr = config.PasswordCallback(s, password)
+			perms, authErr = authConfig.PasswordCallback(s, password)
 		case "keyboard-interactive":
 		case "keyboard-interactive":
-			if config.KeyboardInteractiveCallback == nil {
+			if authConfig.KeyboardInteractiveCallback == nil {
 				authErr = errors.New("ssh: keyboard-interactive auth not configured")
 				authErr = errors.New("ssh: keyboard-interactive auth not configured")
 				break
 				break
 			}
 			}
 
 
 			prompter := &sshClientKeyboardInteractive{s}
 			prompter := &sshClientKeyboardInteractive{s}
-			perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
+			perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
 		case "publickey":
 		case "publickey":
-			if config.PublicKeyCallback == nil {
+			if authConfig.PublicKeyCallback == nil {
 				authErr = errors.New("ssh: publickey auth not configured")
 				authErr = errors.New("ssh: publickey auth not configured")
 				break
 				break
 			}
 			}
@@ -562,11 +656,18 @@ userAuthLoop:
 			if !ok {
 			if !ok {
 				candidate.user = s.user
 				candidate.user = s.user
 				candidate.pubKeyData = pubKeyData
 				candidate.pubKeyData = pubKeyData
-				candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
-				if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
-					candidate.result = checkSourceAddress(
+				candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+
+				if (candidate.result == nil || isPartialSuccessError) &&
+					candidate.perms != nil &&
+					candidate.perms.CriticalOptions != nil &&
+					candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+					if err := checkSourceAddress(
 						s.RemoteAddr(),
 						s.RemoteAddr(),
-						candidate.perms.CriticalOptions[sourceAddressCriticalOption])
+						candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
+						candidate.result = err
+					}
 				}
 				}
 				cache.add(candidate)
 				cache.add(candidate)
 			}
 			}
@@ -578,8 +679,8 @@ userAuthLoop:
 				if len(payload) > 0 {
 				if len(payload) > 0 {
 					return nil, parseError(msgUserAuthRequest)
 					return nil, parseError(msgUserAuthRequest)
 				}
 				}
-
-				if candidate.result == nil {
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+				if candidate.result == nil || isPartialSuccessError {
 					okMsg := userAuthPubKeyOkMsg{
 					okMsg := userAuthPubKeyOkMsg{
 						Algo:   algo,
 						Algo:   algo,
 						PubKey: pubKeyData,
 						PubKey: pubKeyData,
@@ -629,11 +730,11 @@ userAuthLoop:
 				perms = candidate.perms
 				perms = candidate.perms
 			}
 			}
 		case "gssapi-with-mic":
 		case "gssapi-with-mic":
-			if config.GSSAPIWithMICConfig == nil {
+			if authConfig.GSSAPIWithMICConfig == nil {
 				authErr = errors.New("ssh: gssapi-with-mic auth not configured")
 				authErr = errors.New("ssh: gssapi-with-mic auth not configured")
 				break
 				break
 			}
 			}
-			gssapiConfig := config.GSSAPIWithMICConfig
+			gssapiConfig := authConfig.GSSAPIWithMICConfig
 			userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
 			userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
 			if err != nil {
 			if err != nil {
 				return nil, parseError(msgUserAuthRequest)
 				return nil, parseError(msgUserAuthRequest)
@@ -685,53 +786,83 @@ userAuthLoop:
 			config.AuthLogCallback(s, userAuthReq.Method, authErr)
 			config.AuthLogCallback(s, userAuthReq.Method, authErr)
 		}
 		}
 
 
+		var bannerErr *BannerError
+		if errors.As(authErr, &bannerErr) {
+			if bannerErr.Message != "" {
+				if err := s.SendAuthBanner(bannerErr.Message); err != nil {
+					return nil, err
+				}
+			}
+		}
+
 		if authErr == nil {
 		if authErr == nil {
 			break userAuthLoop
 			break userAuthLoop
 		}
 		}
 
 
-		authFailures++
-		if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
-			// If we have hit the max attempts, don't bother sending the
-			// final SSH_MSG_USERAUTH_FAILURE message, since there are
-			// no more authentication methods which can be attempted,
-			// and this message may cause the client to re-attempt
-			// authentication while we send the disconnect message.
-			// Continue, and trigger the disconnect at the start of
-			// the loop.
-			//
-			// The SSH specification is somewhat confusing about this,
-			// RFC 4252 Section 5.1 requires each authentication failure
-			// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
-			// message, but Section 4 says the server should disconnect
-			// after some number of attempts, but it isn't explicit which
-			// message should take precedence (i.e. should there be a failure
-			// message than a disconnect message, or if we are going to
-			// disconnect, should we only send that message.)
-			//
-			// Either way, OpenSSH disconnects immediately after the last
-			// failed authnetication attempt, and given they are typically
-			// considered the golden implementation it seems reasonable
-			// to match that behavior.
-			continue
+		var failureMsg userAuthFailureMsg
+
+		if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
+			// After a partial success error we don't allow changing the user
+			// name and execute the NoClientAuthCallback.
+			partialSuccessReturned = true
+
+			// In case a partial success is returned, the server may send
+			// a new set of authentication methods.
+			authConfig = partialSuccess.Next
+
+			// Reset pubkey cache, as the new PublicKeyCallback might
+			// accept a different set of public keys.
+			cache = pubKeyCache{}
+
+			// Send back a partial success message to the user.
+			failureMsg.PartialSuccess = true
+		} else {
+			// Allow initial attempt of 'none' without penalty.
+			if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
+				authFailures++
+			}
+			if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
+				// If we have hit the max attempts, don't bother sending the
+				// final SSH_MSG_USERAUTH_FAILURE message, since there are
+				// no more authentication methods which can be attempted,
+				// and this message may cause the client to re-attempt
+				// authentication while we send the disconnect message.
+				// Continue, and trigger the disconnect at the start of
+				// the loop.
+				//
+				// The SSH specification is somewhat confusing about this,
+				// RFC 4252 Section 5.1 requires each authentication failure
+				// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
+				// message, but Section 4 says the server should disconnect
+				// after some number of attempts, but it isn't explicit which
+				// message should take precedence (i.e. should there be a failure
+				// message than a disconnect message, or if we are going to
+				// disconnect, should we only send that message.)
+				//
+				// Either way, OpenSSH disconnects immediately after the last
+				// failed authentication attempt, and given they are typically
+				// considered the golden implementation it seems reasonable
+				// to match that behavior.
+				continue
+			}
 		}
 		}
 
 
-		var failureMsg userAuthFailureMsg
-		if config.PasswordCallback != nil {
+		if authConfig.PasswordCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "password")
 			failureMsg.Methods = append(failureMsg.Methods, "password")
 		}
 		}
-		if config.PublicKeyCallback != nil {
+		if authConfig.PublicKeyCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "publickey")
 			failureMsg.Methods = append(failureMsg.Methods, "publickey")
 		}
 		}
-		if config.KeyboardInteractiveCallback != nil {
+		if authConfig.KeyboardInteractiveCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
 			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
 		}
 		}
-		if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
-			config.GSSAPIWithMICConfig.AllowLogin != nil {
+		if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
+			authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
 			failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
 		}
 		}
 
 
 		if len(failureMsg.Methods) == 0 {
 		if len(failureMsg.Methods) == 0 {
-			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+			return nil, errors.New("ssh: no authentication methods available")
 		}
 		}
 
 
 		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
 		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {

+ 412 - 0
psiphon/common/crypto/ssh/server_multi_auth_test.go

@@ -0,0 +1,412 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+	"testing"
+)
+
+func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	var serverAuthErrors []error
+
+	serverConfig.AddHostKey(testSigners["rsa"])
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err == nil {
+		c.Close()
+	}
+	return serverAuthErrors, err
+}
+
+func TestMultiStepAuth(t *testing.T) {
+	// This user can login with password, public key or public key + password.
+	username := "testuser"
+	// This user can login with public key + password only.
+	usernameSecondFactor := "testuser_second_factor"
+	errPwdAuthFailed := errors.New("password auth failed")
+	errWrongSequence := errors.New("wrong sequence")
+
+	serverConfig := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == usernameSecondFactor {
+				return nil, errWrongSequence
+			}
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errPwdAuthFailed
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+				if conn.User() == usernameSecondFactor {
+					return nil, &PartialSuccessError{
+						Next: ServerAuthCallbacks{
+							PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+								if string(password) == clientPassword {
+									return nil, nil
+								}
+								return nil, errPwdAuthFailed
+							},
+						},
+					}
+				}
+				return nil, nil
+			}
+			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: usernameSecondFactor,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
+	}
+	// Now test a wrong sequence.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong sequence must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	// - partial success
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+	if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
+		t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
+	}
+	// Now test using a correct sequence but a wrong password before the right
+	// one.
+	n := 0
+	passwords := []string{"WRONG", "WRONG", clientPassword}
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		RetryableAuthMethod(PasswordCallback(func() (string, error) {
+			p := passwords[n]
+			n++
+			return p, nil
+		}), 3),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - wrong password
+	// - wrong password
+	// - nil
+	if len(serverAuthErrors) != 5 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	if serverAuthErrors[3] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+	// Only password authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with password only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - wrong sequence
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if serverAuthErrors[1] != errWrongSequence {
+		t.Fatal("server not returned wrong sequence")
+	}
+
+	// Only public key authentication should fail.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with public key only must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// Public key and wrong password.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		Password("WRONG"),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("client login with wrong password after public key must fail")
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - password auth failed
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	if serverAuthErrors[2] != errPwdAuthFailed {
+		t.Fatal("server not returned password authentication failed")
+	}
+
+	// Public key, public key again and then correct password. Public key
+	// authentication is attempted only once because the partial success error
+	// returns only "password" as the allowed authentication method.
+	clientConfig.Auth = []AuthMethod{
+		PublicKeys(testSigners["rsa"]),
+		PublicKeys(testSigners["rsa"]),
+		Password(clientPassword),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - no auth passed yet
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// The unrestricted username can do anything
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+
+	clientConfig = &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("unrestricted client login error: %s", err)
+	}
+}
+
+func TestDynamicAuthCallbacks(t *testing.T) {
+	user1 := "user1"
+	user2 := "user2"
+	errInvalidCredentials := errors.New("invalid credentials")
+
+	serverConfig := &ServerConfig{
+		NoClientAuth: true,
+		NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
+			switch conn.User() {
+			case user1:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+							if conn.User() == user1 && string(password) == clientPassword {
+								return nil, nil
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			case user2:
+				return nil, &PartialSuccessError{
+					Next: ServerAuthCallbacks{
+						PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+							if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+								if conn.User() == user2 {
+									return nil, nil
+								}
+							}
+							return nil, errInvalidCredentials
+						},
+					},
+				}
+			default:
+				return nil, errInvalidCredentials
+			}
+		},
+	}
+
+	clientConfig := &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	// The error sequence is:
+	// - partial success
+	// - nil
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+
+	// user1 cannot login with public key
+	clientConfig = &ClientConfig{
+		User: user1,
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user1 login with public key must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+	// user2 cannot login with password
+	clientConfig = &ClientConfig{
+		User: user2,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
+	if err == nil {
+		t.Fatal("user2 login with password must fail")
+	}
+	if !strings.Contains(err.Error(), "no supported methods remain") {
+		t.Errorf("got %v, expected 'no supported methods remain'", err)
+	}
+	if len(serverAuthErrors) != 1 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
+		t.Fatal("server not returned partial success")
+	}
+}

+ 338 - 0
psiphon/common/crypto/ssh/server_test.go

@@ -5,8 +5,13 @@
 package ssh
 package ssh
 
 
 import (
 import (
+	"bytes"
+	"errors"
+	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"reflect"
+	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -62,6 +67,133 @@ func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestMaxAuthTriesNoneMethod(t *testing.T) {
+	username := "testuser"
+	serverConfig := &ServerConfig{
+		MaxAuthTries: 2,
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("invalid credentials")
+		},
+	}
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	var serverAuthErrors []error
+
+	serverConfig.AddHostKey(testSigners["rsa"])
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+	go newServer(c1, serverConfig)
+
+	clientConfig := ClientConfig{
+		User:            username,
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+	clientConfig.SetDefaults()
+	// Our client will send 'none' auth only once, so we need to send the
+	// requests manually.
+	c := &connection{
+		sshConn: sshConn{
+			conn:          c2,
+			user:          username,
+			clientVersion: []byte(packageVersion),
+		},
+	}
+	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
+	if err != nil {
+		t.Fatalf("unable to exchange version: %v", err)
+	}
+	c.transport = newClientTransport(
+		newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */),
+		c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr())
+	if err := c.transport.waitSession(); err != nil {
+		t.Fatalf("unable to wait session: %v", err)
+	}
+	c.sessionID = c.transport.getSessionID()
+	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
+		t.Fatalf("unable to send ssh-userauth message: %v", err)
+	}
+	packet, err := c.transport.readPacket()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(packet) > 0 && packet[0] == msgExtInfo {
+		packet, err = c.transport.readPacket()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	var serviceAccept serviceAcceptMsg
+	if err := Unmarshal(packet, &serviceAccept); err != nil {
+		t.Fatal(err)
+	}
+	for i := 0; i <= serverConfig.MaxAuthTries; i++ {
+		auth := new(noneAuth)
+		_, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil)
+		if i < serverConfig.MaxAuthTries {
+			if err != nil {
+				t.Fatal(err)
+			}
+			continue
+		}
+		if err == nil {
+			t.Fatal("client: got no error")
+		} else if !strings.Contains(err.Error(), "too many authentication failures") {
+			t.Fatalf("client: got unexpected error: %v", err)
+		}
+	}
+	if len(serverAuthErrors) != 3 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	for _, err := range serverAuthErrors {
+		if !errors.Is(err, ErrNoAuth) {
+			t.Errorf("go error: %v; want: %v", err, ErrNoAuth)
+		}
+	}
+}
+
+func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
+	username := "testuser"
+	serverConfig := &ServerConfig{
+		MaxAuthTries: 1,
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("invalid credentials")
+		},
+	}
+	clientConfig := &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
+		t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
+	}
+	if serverAuthErrors[1] != nil {
+		t.Errorf("unexpected error: %v", serverAuthErrors[1])
+	}
+}
+
 func TestNewServerConnValidationErrors(t *testing.T) {
 func TestNewServerConnValidationErrors(t *testing.T) {
 	serverConf := &ServerConfig{
 	serverConf := &ServerConfig{
 		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
 		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
@@ -96,6 +228,212 @@ func TestNewServerConnValidationErrors(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestBannerError(t *testing.T) {
+	serverConfig := &ServerConfig{
+		BannerCallback: func(ConnMetadata) string {
+			return "banner from BannerCallback"
+		},
+		NoClientAuth: true,
+		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+			err := &BannerError{
+				Err:     errors.New("error from NoClientAuthCallback"),
+				Message: "banner from NoClientAuthCallback",
+			}
+			return nil, fmt.Errorf("wrapped: %w", err)
+		},
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			return &Permissions{}, nil
+		},
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			return nil, &BannerError{
+				Err:     errors.New("error from PublicKeyCallback"),
+				Message: "banner from PublicKeyCallback",
+			}
+		},
+		KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
+			return nil, &BannerError{
+				Err:     nil, // make sure that a nil inner error is allowed
+				Message: "banner from KeyboardInteractiveCallback",
+			}
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	var banners []string
+	clientConfig := &ClientConfig{
+		User: "test",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
+			KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
+				return []string{"letmein"}, nil
+			}),
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(msg string) error {
+			banners = append(banners, msg)
+			return nil
+		},
+	}
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	defer c.Close()
+
+	wantBanners := []string{
+		"banner from BannerCallback",
+		"banner from NoClientAuthCallback",
+		"banner from PublicKeyCallback",
+		"banner from KeyboardInteractiveCallback",
+	}
+	if !reflect.DeepEqual(banners, wantBanners) {
+		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
+	}
+}
+
+func TestPublicKeyCallbackLastSeen(t *testing.T) {
+	var lastSeenKey PublicKey
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	serverConf := &ServerConfig{
+		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			lastSeenKey = key
+			fmt.Printf("seen %#v\n", key)
+			if _, ok := key.(*dsaPublicKey); !ok {
+				return nil, errors.New("nope")
+			}
+			return nil, nil
+		},
+	}
+	serverConf.AddHostKey(testSigners["ecdsap256"])
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		NewServerConn(c1, serverConf)
+	}()
+
+	clientConf := ClientConfig{
+		User: "user",
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	_, _, _, err = NewClientConn(c2, "", &clientConf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	<-done
+
+	expectedPublicKey := testSigners["dsa"].PublicKey().Marshal()
+	lastSeenMarshalled := lastSeenKey.Marshal()
+	if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) {
+		t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey())
+	}
+}
+
+func TestPreAuthConnAndBanners(t *testing.T) {
+	testDone := make(chan struct{})
+	defer close(testDone)
+
+	authConnc := make(chan ServerPreAuthConn, 1)
+	serverConfig := &ServerConfig{
+		PreAuthConnCallback: func(c ServerPreAuthConn) {
+			t.Logf("got ServerPreAuthConn: %v", c)
+			authConnc <- c // for use later in the test
+			for _, s := range []string{"hello1", "hello2"} {
+				if err := c.SendAuthBanner(s); err != nil {
+					t.Errorf("failed to send banner %q: %v", s, err)
+				}
+			}
+			// Now start a goroutine to spam SendAuthBanner in hopes
+			// of hitting a race.
+			go func() {
+				for {
+					select {
+					case <-testDone:
+						return
+					default:
+						if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase {
+							t.Errorf("unexpected error from SendAuthBanner: %v", err)
+						}
+						time.Sleep(5 * time.Millisecond)
+					}
+				}
+			}()
+		},
+		NoClientAuth: true,
+		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+			t.Logf("got NoClientAuthCallback")
+			return &Permissions{}, nil
+		},
+	}
+	serverConfig.AddHostKey(testSigners["rsa"])
+
+	var banners []string
+	clientConfig := &ClientConfig{
+		User:            "test",
+		HostKeyCallback: InsecureIgnoreHostKey(),
+		BannerCallback: func(msg string) error {
+			if msg != "attempted-race" {
+				banners = append(banners, msg)
+			}
+			return nil
+		},
+	}
+
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	go newServer(c1, serverConfig)
+	c, _, _, err := NewClientConn(c2, "", clientConfig)
+	if err != nil {
+		t.Fatalf("client connection failed: %v", err)
+	}
+	defer c.Close()
+
+	wantBanners := []string{
+		"hello1",
+		"hello2",
+	}
+	if !reflect.DeepEqual(banners, wantBanners) {
+		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
+	}
+
+	// Now that we're authenticated, verify that use of SendBanner
+	// is an error.
+	var bc ServerPreAuthConn
+	select {
+	case bc = <-authConnc:
+	default:
+		t.Fatal("expected ServerPreAuthConn")
+	}
+	if err := bc.SendAuthBanner("wrong-phase"); err == nil {
+		t.Error("unexpected success of SendAuthBanner after authentication")
+	} else if err != errSendBannerPhase {
+		t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
+	}
+}
+
 type markerConn struct {
 type markerConn struct {
 	closed uint32
 	closed uint32
 	used   uint32
 	used   uint32

+ 1 - 1
psiphon/common/crypto/ssh/tcpip.go

@@ -472,7 +472,7 @@ func (c *Client) dial(channelType string, laddr string, lport int, raddr string,
 		return nil, err
 		return nil, err
 	}
 	}
 	go DiscardRequests(in)
 	go DiscardRequests(in)
-	return ch, err
+	return ch, nil
 }
 }
 
 
 type tcpChan struct {
 type tcpChan struct {

+ 4 - 3
psiphon/common/crypto/ssh/test/agent_unix_test.go

@@ -20,20 +20,21 @@ func TestAgentForward(t *testing.T) {
 	defer conn.Close()
 	defer conn.Close()
 
 
 	keyring := agent.NewKeyring()
 	keyring := agent.NewKeyring()
-	if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil {
+	if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["ecdsa"]}); err != nil {
 		t.Fatalf("Error adding key: %s", err)
 		t.Fatalf("Error adding key: %s", err)
 	}
 	}
 	if err := keyring.Add(agent.AddedKey{
 	if err := keyring.Add(agent.AddedKey{
-		PrivateKey:       testPrivateKeys["dsa"],
+		PrivateKey:       testPrivateKeys["ecdsa"],
 		ConfirmBeforeUse: true,
 		ConfirmBeforeUse: true,
 		LifetimeSecs:     3600,
 		LifetimeSecs:     3600,
 	}); err != nil {
 	}); err != nil {
 		t.Fatalf("Error adding key with constraints: %s", err)
 		t.Fatalf("Error adding key with constraints: %s", err)
 	}
 	}
-	pub := testPublicKeys["dsa"]
+	pub := testPublicKeys["ecdsa"]
 
 
 	sess, err := conn.NewSession()
 	sess, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("NewSession: %v", err)
 		t.Fatalf("NewSession: %v", err)
 	}
 	}
 	if err := agent.RequestAgentForwarding(sess); err != nil {
 	if err := agent.RequestAgentForwarding(sess); err != nil {

+ 1 - 1
psiphon/common/crypto/ssh/test/cert_test.go

@@ -19,7 +19,7 @@ func TestCertLogin(t *testing.T) {
 	s := newServer(t)
 	s := newServer(t)
 
 
 	// Use a key different from the default.
 	// Use a key different from the default.
-	clientKey := testSigners["dsa"]
+	clientKey := testSigners["ed25519"]
 	caAuthKey := testSigners["ecdsa"]
 	caAuthKey := testSigners["ecdsa"]
 	cert := &ssh.Certificate{
 	cert := &ssh.Certificate{
 		Key:             clientKey.PublicKey(),
 		Key:             clientKey.PublicKey(),

+ 1 - 0
psiphon/common/crypto/ssh/test/dial_unix_test.go

@@ -53,6 +53,7 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) {
 	// on the opened connection.
 	// on the opened connection.
 	cancel()
 	cancel()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("Dial: %v", err)
 		t.Fatalf("Dial: %v", err)
 	}
 	}
 	x.TestClientConn(t, conn)
 	x.TestClientConn(t, conn)

+ 1 - 1
psiphon/common/crypto/ssh/test/doc.go

@@ -4,4 +4,4 @@
 
 
 // Package test contains integration tests for the
 // Package test contains integration tests for the
 // github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package.
 // github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package.
-package test // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/test"
+package test

+ 10 - 0
psiphon/common/crypto/ssh/test/forward_unix_test.go

@@ -12,6 +12,7 @@ import (
 	"io"
 	"io"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
+	"runtime"
 	"testing"
 	"testing"
 	"time"
 	"time"
 )
 )
@@ -27,6 +28,9 @@ func testPortForward(t *testing.T, n, listenAddr string) {
 
 
 	sshListener, err := conn.Listen(n, listenAddr)
 	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
@@ -122,6 +126,9 @@ func testAcceptClose(t *testing.T, n, listenAddr string) {
 
 
 	sshListener, err := conn.Listen(n, listenAddr)
 	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
@@ -163,6 +170,9 @@ func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 
 
 	sshListener, err := client.Listen(n, listenAddr)
 	sshListener, err := client.Listen(n, listenAddr)
 	if err != nil {
 	if err != nil {
+		if runtime.GOOS == "darwin" && err == io.EOF {
+			t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+		}
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 

+ 1 - 1
psiphon/common/crypto/ssh/test/server_test.go

@@ -14,7 +14,7 @@ type exitStatusMsg struct {
 	Status uint32
 	Status uint32
 }
 }
 
 
-// goServer is a test Go SSH server that accepts public key and certificate
+// goTestServer is a test Go SSH server that accepts public key and certificate
 // authentication and replies with a 0 exit status to any exec request without
 // authentication and replies with a 0 exit status to any exec request without
 // running any commands.
 // running any commands.
 type goTestServer struct {
 type goTestServer struct {

+ 91 - 6
psiphon/common/crypto/ssh/test/session_test.go

@@ -22,6 +22,13 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 )
 )
 
 
+func skipIfIssue64959(t *testing.T, err error) {
+	if err != nil && runtime.GOOS == "darwin" && strings.Contains(err.Error(), "ssh: unexpected packet in response to channel open: <nil>") {
+		t.Helper()
+		t.Skipf("skipping test broken on some versions of macOS; see https://go.dev/issue/64959")
+	}
+}
+
 func TestRunCommandSuccess(t *testing.T) {
 func TestRunCommandSuccess(t *testing.T) {
 	server := newServer(t)
 	server := newServer(t)
 	conn := server.Dial(clientConfig())
 	conn := server.Dial(clientConfig())
@@ -29,6 +36,7 @@ func TestRunCommandSuccess(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -66,6 +74,7 @@ func TestRunCommandStdin(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -88,6 +97,7 @@ func TestRunCommandStdinError(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -111,6 +121,7 @@ func TestRunCommandFailed(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -127,6 +138,7 @@ func TestRunCommandWeClosed(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	err = session.Shell()
 	err = session.Shell()
@@ -146,6 +158,7 @@ func TestFuncLargeRead(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("unable to create new session: %s", err)
 		t.Fatalf("unable to create new session: %s", err)
 	}
 	}
 
 
@@ -182,6 +195,7 @@ func TestKeyChange(t *testing.T) {
 	for i := 0; i < 4; i++ {
 	for i := 0; i < 4; i++ {
 		session, err := conn.NewSession()
 		session, err := conn.NewSession()
 		if err != nil {
 		if err != nil {
+			skipIfIssue64959(t, err)
 			t.Fatalf("unable to create new session: %s", err)
 			t.Fatalf("unable to create new session: %s", err)
 		}
 		}
 
 
@@ -223,6 +237,7 @@ func TestValidTerminalMode(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -287,6 +302,7 @@ func TestWindowChange(t *testing.T) {
 
 
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("session failed: %v", err)
 		t.Fatalf("session failed: %v", err)
 	}
 	}
 	defer session.Close()
 	defer session.Close()
@@ -341,14 +357,10 @@ func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 
 
 	numBytes := 4096
 	numBytes := 4096
 
 
-	// Exercise sending data to the server
-	if _, _, err := conn.Conn.SendRequest("drop-me", false, make([]byte, numBytes)); err != nil {
-		t.Fatalf("SendRequest: %v", err)
-	}
-
 	// Exercise receiving data from the server
 	// Exercise receiving data from the server
 	session, err := conn.NewSession()
 	session, err := conn.NewSession()
 	if err != nil {
 	if err != nil {
+		skipIfIssue64959(t, err)
 		t.Fatalf("NewSession: %v", err)
 		t.Fatalf("NewSession: %v", err)
 	}
 	}
 
 
@@ -360,6 +372,11 @@ func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
 	if len(out) != numBytes {
 	if len(out) != numBytes {
 		t.Fatalf("got %d bytes, want %d bytes", len(out), numBytes)
 		t.Fatalf("got %d bytes, want %d bytes", len(out), numBytes)
 	}
 	}
+
+	// Exercise sending data to the server
+	if _, _, err := conn.Conn.SendRequest("drop-me", false, make([]byte, numBytes)); err != nil {
+		t.Fatalf("SendRequest: %v", err)
+	}
 }
 }
 
 
 var deprecatedCiphers = []string{
 var deprecatedCiphers = []string{
@@ -431,7 +448,6 @@ func TestKeyExchanges(t *testing.T) {
 func TestClientAuthAlgorithms(t *testing.T) {
 func TestClientAuthAlgorithms(t *testing.T) {
 	for _, key := range []string{
 	for _, key := range []string{
 		"rsa",
 		"rsa",
-		"dsa",
 		"ecdsa",
 		"ecdsa",
 		"ed25519",
 		"ed25519",
 	} {
 	} {
@@ -452,3 +468,72 @@ func TestClientAuthAlgorithms(t *testing.T) {
 		})
 		})
 	}
 	}
 }
 }
+
+func TestClientAuthDisconnect(t *testing.T) {
+	// Use a static key that is not accepted by server.
+	// This key has been generated with following ssh-keygen command and
+	// used exclusively in this unit test:
+	// $ ssh-keygen -t RSA -b 2048 -f /tmp/static_key \
+	//   -C "Static RSA key for golang.org/x/crypto/ssh unit test"
+
+	const privKeyData = `-----BEGIN OPENSSH PRIVATE KEY-----
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
+NhAAAAAwEAAQAAAQEAwV1Zg3MqX27nIQQNWd8V09P4q4F1fx7H2xNJdL3Yg3y91GFLJ92+
+0IiGV8n1VMGL/71PPhzyqBpUYSTpWjiU2JZSfA+iTg1GJBcOaEOA6vrXsTtXTHZ//mOT4d
+mlvuP4+9NqfCBLGXN7ZJpT+amkD8AVW9YW9QN3ipY61ZWxPaAocVpDd8rVgJTk54KvaPa7
+t4ddOSQDQq61aubIDR1Z3P+XjkB4piWOsbck3HJL+veTALy12C09tAhwUnZUAXS+DjhxOL
+xpDVclF/yXYhAvBvsjwyk/OC3+nK9F799hpQZsjxmbP7oN+tGwz06BUcAKi7u7QstENvvk
+85SDZy1q1QAAA/A7ylbJO8pWyQAAAAdzc2gtcnNhAAABAQDBXVmDcypfbuchBA1Z3xXT0/
+irgXV/HsfbE0l0vdiDfL3UYUsn3b7QiIZXyfVUwYv/vU8+HPKoGlRhJOlaOJTYllJ8D6JO
+DUYkFw5oQ4Dq+texO1dMdn/+Y5Ph2aW+4/j702p8IEsZc3tkmlP5qaQPwBVb1hb1A3eKlj
+rVlbE9oChxWkN3ytWAlOTngq9o9ru3h105JANCrrVq5sgNHVnc/5eOQHimJY6xtyTcckv6
+95MAvLXYLT20CHBSdlQBdL4OOHE4vGkNVyUX/JdiEC8G+yPDKT84Lf6cr0Xv32GlBmyPGZ
+s/ug360bDPToFRwAqLu7tCy0Q2++TzlINnLWrVAAAAAwEAAQAAAQAIvPDHMiyIxgCksGPF
+uyv9F9U4XjVip8/abE9zkAMJWW5++wuT/bRlBOUPRrWIXZEM9ETbtsqswo3Wxah+7CjRIH
+qR7SdFlYTP1jPk4yIKXF4OvggBUPySkMpAGJ9hwOMW8Ymcb4gn77JJ4aMoWIcXssje+XiC
+8iO+4UWU3SV2i6K7flK1UDCI5JVCyBr3DVf3QhMOgvwJl9TgD7FzWy1hkjuZq/Pzdv+fA2
+OfrUFiSukLNolidNoI9+KWa1yxixE+B2oN4Xan3ZbqGbL6Wc1dB+K9h/bNcu+SKf7fXWRi
+/vVG44A61xGDZzen1+eQlqFp7narkKXoaU71+45VXDThAAAAgBPWUdQykEEm0yOS6hPIW+
+hS8z1LXWGTEcag9fMwJXKE7cQFO3LEk+dXMbClHdhD/ydswOZYGSNepxwvmo/a5LiO2ulp
+W+5tnsNhcK3skdaf71t+boUEXBNZ6u3WNTkU7tDN8h9tebI+xlNceDGSGjOlNoHQVMKZdA
+W9TA4ZqXUPAAAAgQDWU0UZVOSCAOODPz4PYsbFKdCfXNP8O4+t9txyc9E3hsLAsVs+CpVX
+Gr219MGLrublzAxojipyzuQb6Tp1l9nsu7VkcBrPL8I1tokz0AyTnmNF3A9KszBal7gGNS
+a2qYuf6JO4cub1KzonxUJQHZPZq9YhCxOtDwTd+uyHZiPy9QAAAIEA5vayd+nfVJgCKTdf
+z5MFsxBSUj/cAYg7JYPS/0bZ5bEkLosL22wl5Tm/ZftJa8apkyBPhguAWt6jEWLoDiK+kn
+Fv0SaEq1HUdXgWmISVnWzv2pxdAtq/apmbxTg3iIJyrAwEDo13iImR3k6rNPx1m3i/jX56
+HLcvWM4Y6bFzbGEAAAA0U3RhdGljIFJTQSBrZXkgZm9yIGdvbGFuZy5vcmcveC9jcnlwdG
+8vc3NoIHVuaXQgdGVzdAECAwQFBgc=
+-----END OPENSSH PRIVATE KEY-----`
+
+	signer, err := ssh.ParsePrivateKey([]byte(privKeyData))
+	if err != nil {
+		t.Fatalf("failed to create signer from key: %v", err)
+	}
+
+	// Start server with MaxAuthTries 1 and publickey and password auth
+	// enabled
+	server := newServerForConfig(t, "MaxAuthTries", map[string]string{})
+
+	// Connect to server, expect failure, that PublicKeysCallback is called
+	// and that PasswordCallback is not called.
+	publicKeysCallbackCalled := false
+	config := clientConfig()
+	config.Auth = []ssh.AuthMethod{
+		ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
+			publicKeysCallbackCalled = true
+			return []ssh.Signer{signer}, nil
+		}),
+		ssh.PasswordCallback(func() (string, error) {
+			t.Errorf("unexpected call to PasswordCallback()")
+			return "notaverygoodpassword", nil
+		}),
+	}
+	client, err := server.TryDial(config)
+	if err == nil {
+		t.Errorf("expected TryDial() to fail")
+		_ = client.Close()
+	}
+	if !publicKeysCallbackCalled {
+		t.Errorf("expected PublicKeysCallback() to be called")
+	}
+}

+ 23 - 5
psiphon/common/crypto/ssh/test/test_unix_test.go

@@ -58,12 +58,17 @@ UsePAM yes
 PasswordAuthentication yes
 PasswordAuthentication yes
 ChallengeResponseAuthentication yes
 ChallengeResponseAuthentication yes
 AuthenticationMethods {{.AuthMethods}}
 AuthenticationMethods {{.AuthMethods}}
+`
+	maxAuthTriesSshdConfigTail = `
+PasswordAuthentication yes
+MaxAuthTries 1
 `
 `
 )
 )
 
 
 var configTmpl = map[string]*template.Template{
 var configTmpl = map[string]*template.Template{
-	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
-	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
+	"default":      template.Must(template.New("").Parse(defaultSshdConfig)),
+	"MultiAuth":    template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail)),
+	"MaxAuthTries": template.Must(template.New("").Parse(defaultSshdConfig + maxAuthTriesSshdConfigTail))}
 
 
 type server struct {
 type server struct {
 	t          *testing.T
 	t          *testing.T
@@ -178,7 +183,7 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
 
 
 // addr is the user specified host:port. While we don't actually dial it,
 // addr is the user specified host:port. While we don't actually dial it,
 // we need to know this for host key matching
 // we need to know this for host key matching
-func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
 	sshd, err := exec.LookPath("sshd")
 	sshd, err := exec.LookPath("sshd")
 	if err != nil {
 	if err != nil {
 		s.t.Skipf("skipping test: %v", err)
 		s.t.Skipf("skipping test: %v", err)
@@ -188,13 +193,26 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 	if err != nil {
 	if err != nil {
 		s.t.Fatalf("unixConnection: %v", err)
 		s.t.Fatalf("unixConnection: %v", err)
 	}
 	}
+	defer func() {
+		// Close c2 after we've started the sshd command so that it won't prevent c1
+		// from returning EOF when the sshd command exits.
+		c2.Close()
+
+		// Leave c1 open if we're returning a client that wraps it.
+		// (The client is responsible for closing it.)
+		// Otherwise, close it to free up the socket.
+		if client == nil {
+			c1.Close()
+		}
+	}()
 
 
-	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	f, err := c2.File()
 	f, err := c2.File()
 	if err != nil {
 	if err != nil {
 		s.t.Fatalf("UnixConn.File: %v", err)
 		s.t.Fatalf("UnixConn.File: %v", err)
 	}
 	}
 	defer f.Close()
 	defer f.Close()
+
+	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
 	cmd.Stdin = f
 	cmd.Stdin = f
 	cmd.Stdout = f
 	cmd.Stdout = f
 	cmd.Stderr = new(bytes.Buffer)
 	cmd.Stderr = new(bytes.Buffer)
@@ -223,7 +241,7 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 		// processes are killed too.
 		// processes are killed too.
 		cmd.Process.Signal(os.Interrupt)
 		cmd.Process.Signal(os.Interrupt)
 		cmd.Wait()
 		cmd.Wait()
-		if s.t.Failed() {
+		if s.t.Failed() || testing.Verbose() {
 			// log any output from sshd process
 			// log any output from sshd process
 			s.t.Logf("sshd:\n%s", cmd.Stderr)
 			s.t.Logf("sshd:\n%s", cmd.Stderr)
 		}
 		}

+ 1 - 1
psiphon/common/crypto/ssh/testdata/doc.go

@@ -5,4 +5,4 @@
 // This package contains test data shared between the various subpackages of
 // This package contains test data shared between the various subpackages of
 // the github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package. Under no circumstance should
 // the github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh package. Under no circumstance should
 // this data be used for production code.
 // this data be used for production code.
-package testdata // import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/testdata"
+package testdata

+ 29 - 17
psiphon/common/crypto/ssh/testdata/keys.go

@@ -46,19 +46,31 @@ FFlRjzoB3Oxu8UQgb+MWPedtH9XYBbg9biz4jJLkXQ==
 -----END EC PRIVATE KEY-----
 -----END EC PRIVATE KEY-----
 `),
 `),
 	"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
 	"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2
-a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8
-Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQIDAQAB
-AoGAJMCk5vqfSRzyXOTXLGIYCuR4Kj6pdsbNSeuuRGfYBeR1F2c/XdFAg7D/8s5R
-38p/Ih52/Ty5S8BfJtwtvgVY9ecf/JlU/rl/QzhG8/8KC0NG7KsyXklbQ7gJT8UT
-Ojmw5QpMk+rKv17ipDVkQQmPaj+gJXYNAHqImke5mm/K/h0CQQDciPmviQ+DOhOq
-2ZBqUfH8oXHgFmp7/6pXw80DpMIxgV3CwkxxIVx6a8lVH9bT/AFySJ6vXq4zTuV9
-6QmZcZzDAkEA2j/UXJPIs1fQ8z/6sONOkU/BjtoePFIWJlRxdN35cZjXnBraX5UR
-fFHkePv4YwqmXNqrBOvSu+w2WdSDci+IKwJAcsPRc/jWmsrJW1q3Ha0hSf/WG/Bu
-X7MPuXaKpP/DkzGoUmb8ks7yqj6XWnYkPNLjCc8izU5vRwIiyWBRf4mxMwJBAILa
-NDvRS0rjwt6lJGv7zPZoqDc65VfrK2aNyHx2PgFyzwrEOtuF57bu7pnvEIxpLTeM
-z26i6XVMeYXAWZMTloMCQBbpGgEERQpeUknLBqUHhg/wXF6+lFA+vEGnkY+Dwab2
-KCXFGd+SQ5GdUcEMe9isUH6DYj/6/yCDoFrXXmpQb+M=
+MIIEpQIBAAKCAQEAnuozKMtcQkIImZGSe4IujS4+Lkas9jmlBivziWGU3waivkpU
+vYspgJbh7vSvnHOPtKscdIJ+3UUyViDUoM73GumsmHvfeRCoA9YROZK4fQR9G0a1
+wfoRqsrJXGToCzTvr/I2KIwpUG0bRE9rUvsW+JN9xgri+cIJWtu/dGYDkILO4bkF
+IxtEvHNVvhGLenyOHFhPw3hAZ7/bKq8kvKzm9D2zOllHe1wWncMkhVmEFF9Houeh
+jbddmeIAAxBpRUFfzp1dD7503ADBlJdK306D4CeI4KIFiqE1VrmfcMgP8fti0S0b
+4JtmvevYoPd+/wB9ItFqvhc6nyuxF0PfWH+SvwIDAQABAoIBAQCCcpNeSFjKVvRC
+Q1nwIrPd1njaec+fK0CIqWl3e2++B+9trwySrvp5gOGjyp2hGsd7Mf7gsQI81oF0
+a+y+uEXlhK3WWdDey0pwI7ft/7+LeDTOQCQRQBpijaXvPzGviVu7nWLRtARx7a41
+S9A4xL5dfI0BFYyuIpaVS8+EV/1TEJIbceZ5q5RBlARA1rc+nBvjygNzYdRq9Rao
+yyehvnXZ7pQrATnwofPolbZNseW2Q9sRMmtm1E60XJJ433P7nFbxXtsMAYgxWSQc
+V/92iRPYeD7sN/b7qulLEgC6e8el2gLGIB9aQyG7B6KFqloqvx/ymYs07+bWIQCU
+6i9y7LABAoGBANKB2Rs0lF5c+gpEE3w0AWoxGyL04TZECtl2hQ/b2jc2n8hLqLhv
+zNIKN+xzEP6j0ijXajjEMLiQH10qQ6+Plv8C7o8GJC1V/Oj4u4kbPqfVV1kSxuWm
+FBjz6+c8VPbEBgXq5lCMEgC2Ii8XVRoyd3iSSOh+LIMBu/Br3JEsjJq/AoGBAMFC
+CvODTiThrZo8v765dRSHrLOvB4jZOPrEKLWECLaQDQpuhgzZhyWm9zFfxGB+LWE9
+R9pU6ZCFPtPfd4cRZ9cezrp+lgdrqjcUX/2ZLLMk21WXFuRaw/4KTlKNsMY6lbAK
+qVkUFWIZWaCQ8bdVCP48polipRNmN9mAHsIhIwgBAoGBAM6yn1qeS11I0GAKLlPT
+wNvjsfCmIQmm0DxtqwRCbUevxD7pQ5cueCB51iW/ap2OgEqIEo4A3pIrOhDB8kpN
+pQdrepFHh3hYqYicy5A6B1DHJAibbl+Krss9n5KjZA4VtpBS8al/kCHQtUomD/M0
+QKlMgnh/g/dzWXYegyqtYraDAoGAGbeBH5B8iJnjcR/eYDHrq5S2XZ7QANzvISeT
+RzxPsIOQyK+WdQVJX7BNOqvExRZlUYhHFH2yKwIgLy+Qh0/Aora9ycFok4o3N2cl
+suh8M0aXTVdyu2Z8qESU0ZV7TZWkL63rhSgQBGLdM2m2ULAnJzXI74VJ9D/o9K+A
+6FJiiAECgYEAujJ/hKxVKEUvxwloGSDhKCUH86+7UOkb/EM2zZFrlPYAz1VcCwr3
+K14r8BtmLFXuLXOlACpoH0Wf4uia+t6n8m9JK3mvpJ7fempAsptP3AdZMQFe7xUm
+SXEGQBYxcyS5Q+ncwWZuPgby5wJ9D4Fd6TQH+wwG52sFugt/fGxbPug=
 -----END RSA PRIVATE KEY-----
 -----END RSA PRIVATE KEY-----
 `),
 `),
 	"rsa-sha2-256": []byte(`-----BEGIN RSA PRIVATE KEY-----
 	"rsa-sha2-256": []byte(`-----BEGIN RSA PRIVATE KEY-----
@@ -222,17 +234,17 @@ var SSHCertificates = map[string][]byte{
 	//
 	//
 	// 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub":
 	// 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub":
 	//    ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub
 	//    ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub
-	"rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com
+	"rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgWcP1x+t9PQTxkt/fRa8v5HXz0FFW/fwN3mpR5jdGo3UAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzUQAAAAAd8sPtQAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEPAAAAB3NzaC1yc2EAAAEAs8/78EsftS/0+6CfUAtRhmgpOPZVH8XJc2o4Qx/lQCBg/uB2sucGGx/5BuFUwAjhMNQ1UejJ6+AbNgZNlH7LIssGxpcW+eKQXLfd9KwWNPCU3DqZGiMBsYESNtVLjKAMuvJmxKFFu4rceyya+GZfsQFJR7G++XK0cQQ4rgDSnQo3pxQnHwguFWeHtwjdOj0helYER4XDKjKQlo0SxI1nVo7n1jkl8QghdxX6HKia6MP7MqstgNujTvrCEUJ6L3YflD/SJ00Y3ySD9+5hqvfqrDh4c4rRuc4+k0e3fYZwyWO2NNtKHRBOZAgW8zzfR3G99YdTw8N8+2oXeAdvB9k0tA== host.example.com
 `),
 `),
-	"rsa-sha2-256": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgOyK28gunJkM60qp4EbsYAjgbUsyjS8u742OLjipIgc0AAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMJ4AAAAAHBPBLwAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi0yNTYAAAEAbG4De/+QiqopPS3O1H7ySeEUCY56qmdgr02sFErnihdXPDaWXUXxacvJHaEtLrSTSaPL/3v3iKvjLWDOHaQ5c+cN9J7Tqzso7RQCXZD2nK9bwCUyBoiDyBCRe8w4DQEtfL5okpVzQsSAiojQ8hBohMOpy3gFfXrdm4PVC1ZKqlZh4fAc7ajieRq/Tpq2xOLdHwxkcgPNR83WVHva6K9/xjev/5n227/gkHo0qbGs8YYDOFXIEhENi+B23IzxdNVieWdyQpYpe0C2i95Jhyo0wJmaFY2ArruTS+D1jGQQpMPvAQRy26/A5hI83GLhpwyhrN/M8wCxzAhyPL6Ieuh5tQ== host.example.com
+	"rsa-sha2-256": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg4+hKHVPKv183MU/Q7XD/mzDBFSc2YY3eraltxLMGJo0AAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzYAAAAAAd8sP4wAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTI1NgAAAQA/ByIegNZYJRRl413S/8LxGvTZnbxsPwaluoJ/54niGZV9P28THz7d9jXfSHPjalhH93jNPfTYXvI4opnDC37ua1Nu8KKfk40IWXnnDdZLWraUxEidIzhmfVtz8kGdGoFQ8H0EzubL7zKNOTlfSfOoDlmQVOuxT/+eh2mEp4ri0/+8J1mLfLBr8tREX0/iaNjK+RKdcyTMicKursAYMCDdu8vlaphxea+ocyHM9izSX/l33t44V13ueTqIOh2Zbl2UE2k+jk+0dc1CmV0SEoiWiIyt8TRM4yQry1vPlQLsrf28sYM/QMwnhCVhyZO3vs5F25aQWrB9d51VEzBW9/fd host.example.com
 `),
 `),
-	"rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgFGv4IpXfs4L/Y0b3rmUdPFhWoUrVnXuPxXr6aHGs7wgAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABeSMRYAAAAAHBPBp4AAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAnF4fVj6mm+UFeNCIf9AKJCv9WzymjjPvzzmaMWWkPWqoV0P0m5SiYfvbY9SbA73Blpv8SOr0DmpublF183kodREia4KyVuC8hLhSCV2Y16hy9MBegOZMepn80w+apj7Rn9QCz5OfEakDdztp6OWTBtqxnZFcTQ4XrgFkNWeWRElGdEvAVNn2WHwHi4EIdz0mdv48Imv5SPlOuW862ZdFG4Do1dUfDIiGsBofLlgcyIYlf+eNHul6sBeUkuwFxisMpI5DQzNp8PX1g/QJA2wzwT674PTqDXNttKjyh50Fdr4sXxm9Gz1+jVLoESvFNa55ERdSyAqNu4wTy11MZsWwSA== host.example.com
+	"rsa-sha2-512": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgylCoUX+0nYXGG9k/KEepcgEd22921eUwSQsYuQXKOswAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzbwAAAAAd8sQBAAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTUxMgAAAQADOjYUNxzwYZ9O3zjjZSKhX7Ix/vUBq8lIltBRckbKJtcjn2/qqtaeZK2ijkbMlnPFCJ59U+Z6m4DMU6gxYF8R9/WJrANlWYNQ4+52fXjE/FPDdJkwB/kPWABt+ffEiM1HfzP4/zXgTtOJ0GogEzTYoMSrhAlpdKACUaC9nCQ7gjmO+owGkrB3ZbyQS4gltqQZjiBNqF0P/vxmXP+0Kx66/ei8JNYUpEPCXT4ZhLZmsVptaYD1gpvc2i5T5LnjRrsvf4Q6EKR+gwJEjlwnylhH5+h+ZU/KXydkw1hhEVP+ZiIBkVo9nN78Qb/VMJum4Fdoltuyh/9k2lMDeyDuRVh7 host.example.com
 `),
 `),
 	// Assumes "ca" key above in file named "ca", sign a user cert for "rsa.pub"
 	// Assumes "ca" key above in file named "ca", sign a user cert for "rsa.pub"
 	// using "testcertificate" as principal:
 	// using "testcertificate" as principal:
 	//
 	//
 	// ssh-keygen -s ca -I username -n testcertificate rsa.pub
 	// ssh-keygen -s ca -I username -n testcertificate rsa.pub
-	"rsa-user-testcertificate": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgr0MnhSf+KkQWEQmweJOGsJfOrUt80pQZDaI9YiwSfDUAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAQAAAAh1c2VybmFtZQAAABMAAAAPdGVzdGNlcnRpZmljYXRlAAAAAAAAAAD//////////wAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABFAAAAAxyc2Etc2hhMi01MTIAAAEAFuA+67KvnlmcodIp0Lv4mR9UW/CHghAaN1csBJTkI8mx3wXKyIPTsS2uXboEhWD0a7S9gps2SEwC5m6E3kV2Rzg7aH1S03GZqMvVlK2VHe7fzuoW2yOKk6yEPjeTF0pKCFbUQ6mce8pRpD/zdvjG0Z287XM3c3Axlrn7qq7TS0MDTjEZ/dsUNFHxep3co/HuAsWVWPVDItr/FPnvZ6WVH1yc8N/AJn0gLHobkGgug22vI9rNIge1wrnXxU9BUSouzkau/PQsrCQapnn+I1H7HaQt0wdk45nxMP+ags+HRI9qpX/p8WDn6+zpqYqN/nfw2aoytyaJqhsV32Teuqtrgg== rsa.pub
+	"rsa-user-testcertificate": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgEKR9xnhXkbi/pgP669VjBH6XVTYR0yx1wunCtOIjzjUAAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAABAAAACHVzZXJuYW1lAAAAEwAAAA90ZXN0Y2VydGlmaWNhdGUAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTUxMgAAAQCKVn2S7FJYhXTRVbcz1Di1HLz1g5Yae5WBhd0Tg471XkNw7ylcCK23Wnrzj1GxrW0oWCCGHROtUnxQXei1xNWt8HONN+eeafrJSZJR6ald3Yd4OveXlHNT6mEDPgqRj4B56OPoY33LzpaFlQZlZ6U9KXySshNaCTjVp3ojTj6uPNxcuOnG9O5emEPC2eaM4QYsz4cqHNJ9SWWEu+HIQgpx5SM12qcrq/KhN6WJhG3edL1YAsxkf1/THEcfi6wvsHi+DKewJ5hZ876//japjHBKA9SB7gsCrLx3m+0XI8TkWUZTZJHETJHHVjJN8xjRvMv5gWKLvRsz7ScmnN1+ckL6 rsa.pub
 `),
 `),
 }
 }