Browse Source

Update to Go 1.23

Amir Khan 1 year ago
parent
commit
cebcf25b74
100 changed files with 7008 additions and 3505 deletions
  1. 19 17
      go.mod
  2. 30 29
      go.sum
  3. 0 21
      vendor/github.com/Psiphon-Labs/psiphon-tls/.gitignore
  4. 2 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/alert.go
  5. 1 1
      vendor/github.com/Psiphon-Labs/psiphon-tls/auth.go
  6. 0 102
      vendor/github.com/Psiphon-Labs/psiphon-tls/boring.go
  7. 149 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/byteorder/byteorder.go
  8. 65 34
      vendor/github.com/Psiphon-Labs/psiphon-tls/cipher_suites.go
  9. 206 98
      vendor/github.com/Psiphon-Labs/psiphon-tls/common.go
  10. 5 1
      vendor/github.com/Psiphon-Labs/psiphon-tls/common_string.go
  11. 54 20
      vendor/github.com/Psiphon-Labs/psiphon-tls/conn.go
  12. 139 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/defaults.go
  13. 284 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/ech.go
  14. 248 74
      vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_client.go
  15. 190 44
      vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_client_tls13.go
  16. 368 263
      vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_messages.go
  17. 27 19
      vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_server.go
  18. 91 45
      vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_server_tls13.go
  19. 21 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/internal/boring/notboring.go
  20. 259 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/internal/hpke/hpke.go
  21. 887 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768/mlkem768.go
  22. 4 4
      vendor/github.com/Psiphon-Labs/psiphon-tls/key_agreement.go
  23. 42 0
      vendor/github.com/Psiphon-Labs/psiphon-tls/key_schedule.go
  24. 0 13
      vendor/github.com/Psiphon-Labs/psiphon-tls/notboring.go
  25. 9 2
      vendor/github.com/Psiphon-Labs/psiphon-tls/prf.go
  26. 2 2
      vendor/github.com/Psiphon-Labs/psiphon-tls/public.go
  27. 91 12
      vendor/github.com/Psiphon-Labs/psiphon-tls/quic.go
  28. 11 7
      vendor/github.com/Psiphon-Labs/psiphon-tls/ticket.go
  29. 35 13
      vendor/github.com/Psiphon-Labs/psiphon-tls/tls.go
  30. 49 18
      vendor/github.com/Psiphon-Labs/psiphon-tls/unsafe.go
  31. 1 0
      vendor/github.com/Psiphon-Labs/quic-go/.gitignore
  32. 22 15
      vendor/github.com/Psiphon-Labs/quic-go/.golangci.yml
  33. 7 203
      vendor/github.com/Psiphon-Labs/quic-go/README.md
  34. 1 149
      vendor/github.com/Psiphon-Labs/quic-go/client.go
  35. 17 23
      vendor/github.com/Psiphon-Labs/quic-go/closed_conn.go
  36. 4 0
      vendor/github.com/Psiphon-Labs/quic-go/codecov.yml
  37. 11 12
      vendor/github.com/Psiphon-Labs/quic-go/config.go
  38. 19 19
      vendor/github.com/Psiphon-Labs/quic-go/conn_id_generator.go
  39. 22 0
      vendor/github.com/Psiphon-Labs/quic-go/conn_id_manager.go
  40. 216 186
      vendor/github.com/Psiphon-Labs/quic-go/connection.go
  41. 168 0
      vendor/github.com/Psiphon-Labs/quic-go/connection_logging.go
  42. 13 36
      vendor/github.com/Psiphon-Labs/quic-go/crypto_stream.go
  43. 23 28
      vendor/github.com/Psiphon-Labs/quic-go/crypto_stream_manager.go
  44. 5 5
      vendor/github.com/Psiphon-Labs/quic-go/errors.go
  45. 185 84
      vendor/github.com/Psiphon-Labs/quic-go/framer.go
  46. 3 98
      vendor/github.com/Psiphon-Labs/quic-go/http3/README.md
  47. 70 73
      vendor/github.com/Psiphon-Labs/quic-go/http3/body.go
  48. 26 7
      vendor/github.com/Psiphon-Labs/quic-go/http3/capsule.go
  49. 198 326
      vendor/github.com/Psiphon-Labs/quic-go/http3/client.go
  50. 329 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/conn.go
  51. 98 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/datagram.go
  52. 5 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/error.go
  53. 72 15
      vendor/github.com/Psiphon-Labs/quic-go/http3/frames.go
  54. 80 18
      vendor/github.com/Psiphon-Labs/quic-go/http3/headers.go
  55. 211 40
      vendor/github.com/Psiphon-Labs/quic-go/http3/http_stream.go
  56. 48 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/ip_addr.go
  57. 2 2
      vendor/github.com/Psiphon-Labs/quic-go/http3/mockgen.go
  58. 20 18
      vendor/github.com/Psiphon-Labs/quic-go/http3/request_writer.go
  59. 244 107
      vendor/github.com/Psiphon-Labs/quic-go/http3/response_writer.go
  60. 0 303
      vendor/github.com/Psiphon-Labs/quic-go/http3/roundtrip.go
  61. 298 218
      vendor/github.com/Psiphon-Labs/quic-go/http3/server.go
  62. 116 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/state_tracking_stream.go
  63. 110 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/trace.go
  64. 466 0
      vendor/github.com/Psiphon-Labs/quic-go/http3/transport.go
  65. 50 33
      vendor/github.com/Psiphon-Labs/quic-go/interface.go
  66. 1 1
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/ackhandler.go
  67. 6 7
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/interfaces.go
  68. 14 31
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_handler.go
  69. 47 57
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_history.go
  70. 91 67
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_tracker.go
  71. 76 89
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/sent_packet_handler.go
  72. 8 3
      vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/sent_packet_history.go
  73. 4 4
      vendor/github.com/Psiphon-Labs/quic-go/internal/congestion/cubic.go
  74. 1 1
      vendor/github.com/Psiphon-Labs/quic-go/internal/congestion/cubic_sender.go
  75. 8 12
      vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/base_flow_controller.go
  76. 30 29
      vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/connection_flow_controller.go
  77. 16 11
      vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/interface.go
  78. 29 24
      vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/stream_flow_controller.go
  79. 18 21
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/aead.go
  80. 4 5
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/cipher_suite.go
  81. 28 57
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/crypto_setup.go
  82. 5 7
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/header_protector.go
  83. 1 1
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/hkdf.go
  84. 4 5
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/initial_aead.go
  85. 26 3
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/interface.go
  86. 9 6
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/retry.go
  87. 6 6
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/session_ticket.go
  88. 1 1
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/token_generator.go
  89. 6 14
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/token_protector.go
  90. 4 5
      vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/updatable_aead.go
  91. 0 50
      vendor/github.com/Psiphon-Labs/quic-go/internal/logutils/frame.go
  92. 15 1
      vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/connection_id.go
  93. 23 45
      vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/packet_number.go
  94. 7 12
      vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/params.go
  95. 2 2
      vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/perspective.go
  96. 34 25
      vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/version.go
  97. 1 2
      vendor/github.com/Psiphon-Labs/quic-go/internal/qerr/error_codes.go
  98. 25 30
      vendor/github.com/Psiphon-Labs/quic-go/internal/qerr/errors.go
  99. 0 12
      vendor/github.com/Psiphon-Labs/quic-go/internal/qtls/cipher_suite.go
  100. 10 2
      vendor/github.com/Psiphon-Labs/quic-go/internal/qtls/client_session_cache.go

+ 19 - 17
go.mod

@@ -1,6 +1,8 @@
 module github.com/Psiphon-Labs/psiphon-tunnel-core
 
-go 1.21
+go 1.23
+
+toolchain go1.23.6
 
 // The following replace is required only when the build tag
 // PSIPHON_ENABLE_REFRACTION_NETWORKING is specified.
@@ -37,9 +39,9 @@ require (
 	github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7
 	github.com/Psiphon-Labs/consistent v0.0.0-20240322131436-20aaa4e05737
 	github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464
-	github.com/Psiphon-Labs/psiphon-tls v0.0.0-20240824224428-ca6969e315a9
-	github.com/Psiphon-Labs/quic-go v0.0.0-20250203210204-a4381c68e52f
-	github.com/Psiphon-Labs/utls v1.1.1-0.20241107183331-b18909f8ccaa
+	github.com/Psiphon-Labs/psiphon-tls v0.0.0-20250219165059-533f95b512e9
+	github.com/Psiphon-Labs/quic-go v0.0.0-20250226213529-818b69c11139
+	github.com/Psiphon-Labs/utls v0.0.0-20250228222508-0e6c20273fcc
 	github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f
 	github.com/bifurcation/mint v0.0.0-20180306135233-198357931e61
 	github.com/bits-and-blooms/bloom/v3 v3.6.0
@@ -66,7 +68,9 @@ require (
 	github.com/pion/datachannel v1.5.5
 	github.com/pion/dtls/v2 v2.2.7
 	github.com/pion/ice/v2 v2.3.24
+	github.com/pion/interceptor v0.1.25
 	github.com/pion/logging v0.2.2
+	github.com/pion/rtp v1.8.5
 	github.com/pion/sctp v1.8.16
 	github.com/pion/sdp/v3 v3.0.9
 	github.com/pion/stun v0.6.1
@@ -81,11 +85,11 @@ require (
 	github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8
 	github.com/wader/filtertransport v0.0.0-20200316221534-bdd9e61eee78
 	github.com/wlynxg/anet v0.0.1
-	golang.org/x/crypto v0.22.0
-	golang.org/x/net v0.24.0
-	golang.org/x/sync v0.6.0
-	golang.org/x/sys v0.20.0
-	golang.org/x/term v0.19.0
+	golang.org/x/crypto v0.32.0
+	golang.org/x/net v0.34.0
+	golang.org/x/sync v0.10.0
+	golang.org/x/sys v0.29.0
+	golang.org/x/term v0.28.0
 	golang.org/x/time v0.5.0
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
@@ -97,9 +101,9 @@ require (
 	filippo.io/keygen v0.0.0-20230306160926-5201437acf8e // indirect
 	github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
 	github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect
-	github.com/andybalholm/brotli v1.0.6 // indirect
+	github.com/andybalholm/brotli v1.1.1 // indirect
 	github.com/bits-and-blooms/bitset v1.10.0 // indirect
-	github.com/cloudflare/circl v1.3.7 // indirect
+	github.com/cloudflare/circl v1.5.0 // indirect
 	github.com/coreos/go-iptables v0.7.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/dblohm7/wingoes v0.0.0-20230929194252-e994401fc077 // indirect
@@ -115,7 +119,7 @@ require (
 	github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 // indirect
 	github.com/jsimonetti/rtnetlink v1.3.5 // indirect
 	github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
-	github.com/klauspost/compress v1.17.4 // indirect
+	github.com/klauspost/compress v1.17.11 // indirect
 	github.com/libp2p/go-reuseport v0.4.0 // indirect
 	github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
 	github.com/mdlayher/netlink v1.7.2 // indirect
@@ -123,11 +127,9 @@ require (
 	github.com/mroth/weightedrand v1.0.0 // indirect
 	github.com/onsi/ginkgo/v2 v2.12.0 // indirect
 	github.com/pelletier/go-toml v1.9.5 // indirect
-	github.com/pion/interceptor v0.1.25 // indirect
 	github.com/pion/mdns v0.0.12 // indirect
 	github.com/pion/randutil v0.1.0 // indirect
 	github.com/pion/rtcp v1.2.12 // indirect
-	github.com/pion/rtp v1.8.5 // indirect
 	github.com/pion/srtp/v2 v2.0.18 // indirect
 	github.com/pion/turn/v2 v2.1.3 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
@@ -153,9 +155,9 @@ require (
 	go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
 	go4.org/netipx v0.0.0-20230824141953-6213f710f925 // indirect
 	golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e // indirect
-	golang.org/x/mod v0.14.0 // indirect
-	golang.org/x/text v0.14.0 // indirect
-	golang.org/x/tools v0.16.0 // indirect
+	golang.org/x/mod v0.17.0 // indirect
+	golang.org/x/text v0.21.0 // indirect
+	golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
 	google.golang.org/protobuf v1.31.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 30 - 29
go.sum

@@ -22,18 +22,16 @@ github.com/Psiphon-Labs/consistent v0.0.0-20240322131436-20aaa4e05737 h1:QTMy7Uc
 github.com/Psiphon-Labs/consistent v0.0.0-20240322131436-20aaa4e05737/go.mod h1:Enj/Gszv2zCbuRbHbabmNvfO9EM+5kmaGj8CyjwNPlY=
 github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464 h1:VmnMMMheFXwLV0noxYhbJbLmkV4iaVW3xNnj6xcCNHo=
 github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464/go.mod h1:Pe5BqN2DdIdChorAXl6bDaQd/wghpCleJfid2NoSli0=
-github.com/Psiphon-Labs/psiphon-tls v0.0.0-20240824224428-ca6969e315a9 h1:AJj1cSg5gW6vWi1spfMmRi8UmVG0PSJU2NXUtWNBelE=
-github.com/Psiphon-Labs/psiphon-tls v0.0.0-20240824224428-ca6969e315a9/go.mod h1:AaKKoshr8RI1LZTheeNDtNuZ39qNVPWVK4uir2c2XIs=
-github.com/Psiphon-Labs/quic-go v0.0.0-20240821052333-b6316b594e39 h1:ft0K9EDdBtMl+Q/akZ+qt3SdcmbtnTQOgE3OlWI6uz0=
-github.com/Psiphon-Labs/quic-go v0.0.0-20240821052333-b6316b594e39/go.mod h1:2MTiPsgoOqWs3Bo6Xr3ElMBX6zzfjd3YkDFpQJLwHdQ=
-github.com/Psiphon-Labs/quic-go v0.0.0-20250203210204-a4381c68e52f h1:18J32AZR8UYwmANk4V00UqSwsOus4wbD+0wp4FTW88k=
-github.com/Psiphon-Labs/quic-go v0.0.0-20250203210204-a4381c68e52f/go.mod h1:2MTiPsgoOqWs3Bo6Xr3ElMBX6zzfjd3YkDFpQJLwHdQ=
-github.com/Psiphon-Labs/utls v1.1.1-0.20241107183331-b18909f8ccaa h1:5FszHIhxb7yO267qt47tTfJOtD31k7R80L88EwNm4tc=
-github.com/Psiphon-Labs/utls v1.1.1-0.20241107183331-b18909f8ccaa/go.mod h1:dxmztdV9lf59cq44YY8r21m3b+xSjhg98cgZW8WK1p0=
+github.com/Psiphon-Labs/psiphon-tls v0.0.0-20250219165059-533f95b512e9 h1:PjzuvkU8C0My+ixI+FWiJYV9PbALsw8uA1F8HrqPG/w=
+github.com/Psiphon-Labs/psiphon-tls v0.0.0-20250219165059-533f95b512e9/go.mod h1:7ZUnPnWT5z8J8hxfsVjKHYK77Zme/Y0If1b/zeziiJs=
+github.com/Psiphon-Labs/quic-go v0.0.0-20250226213529-818b69c11139 h1:FG1ovy7+hwLuHRl59LOC937Q1pk6+tMkVRF4FeFWd5g=
+github.com/Psiphon-Labs/quic-go v0.0.0-20250226213529-818b69c11139/go.mod h1:rONdWgPMbFjyyBai7gB1IBF4pT9r4l0GyiDst5XR1SY=
+github.com/Psiphon-Labs/utls v0.0.0-20250228222508-0e6c20273fcc h1:ojzcP5Hia0pAidJvnNAd2DaA/siX9vPDTPC9kvhDRFY=
+github.com/Psiphon-Labs/utls v0.0.0-20250228222508-0e6c20273fcc/go.mod h1:1vv0gVAzq9e2XYkW8HAKrmtuuZrBdDixQFx5H22KAjI=
 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
-github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
-github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
+github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
 github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f h1:SaJ6yqg936TshyeFZqQE+N+9hYkIeL9AMr7S4voCl10=
 github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU=
 github.com/bifurcation/mint v0.0.0-20180306135233-198357931e61 h1:BU+NxuoaYPIvvp8NNkNlLr8aA0utGyuunf4Q3LJ0bh0=
@@ -53,8 +51,8 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P
 github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
 github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y=
 github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs=
-github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU=
-github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA=
+github.com/cloudflare/circl v1.5.0 h1:hxIWksrX6XN5a1L2TI/h53AGPhNHoUBo+TD1ms9+pys=
+github.com/cloudflare/circl v1.5.0/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
 github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea h1:9C2rdYRp8Vzwhm3sbFX0yYfB+70zKFRjn7cnPCucHSw=
 github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea/go.mod h1:MdyNkAe06D7xmJsf+MsLvbZKYNXuOHLKJrvw+x4LlcQ=
 github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
@@ -133,8 +131,8 @@ github.com/jsimonetti/rtnetlink v1.3.5 h1:hVlNQNRlLDGZz31gBPicsG7Q53rnlsz1l1Ix/9
 github.com/jsimonetti/rtnetlink v1.3.5/go.mod h1:0LFedyiTkebnd43tE4YAkWGIq9jQphow4CcwxaT2Y00=
 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA=
 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8=
-github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
-github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
+github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
+github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -291,6 +289,8 @@ github.com/wlynxg/anet v0.0.1 h1:VbkEEgHxPSrRQSiyRd0pmrbcEQAEU2TTb8fb4DmSYoQ=
 github.com/wlynxg/anet v0.0.1/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
 github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
+github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
 github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
@@ -311,16 +311,16 @@ golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE
 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
 golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
 golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
-golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
+golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
+golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE=
 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
-golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
+golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -336,14 +336,14 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
 golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
 golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
-golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
+golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
+golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
-golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -373,8 +373,8 @@ golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
+golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -385,8 +385,8 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
 golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
 golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
-golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
-golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
+golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
+golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -395,8 +395,9 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
-golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
 golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -404,8 +405,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.16.0 h1:GO788SKMRunPIBCXiQyo2AaexLstOrVhuAL5YwsckQM=
-golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
+golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
+golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

+ 0 - 21
vendor/github.com/Psiphon-Labs/psiphon-tls/.gitignore

@@ -1,21 +0,0 @@
-# If you prefer the allow list template instead of the deny list, see community template:
-# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
-#
-# Binaries for programs and plugins
-*.exe
-*.exe~
-*.dll
-*.so
-*.dylib
-
-# Test binary, built with `go test -c`
-*.test
-
-# Output of the go coverage tool, specifically when used with LiteIDE
-*.out
-
-# Dependency directories (remove the comment below to include it)
-# vendor/
-
-# Go workspace file
-go.work

+ 2 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/alert.go

@@ -58,6 +58,7 @@ const (
 	alertUnknownPSKIdentity           alert = 115
 	alertCertificateRequired          alert = 116
 	alertNoApplicationProtocol        alert = 120
+	alertECHRequired                  alert = 121
 )
 
 var alertText = map[alert]string{
@@ -94,6 +95,7 @@ var alertText = map[alert]string{
 	alertUnknownPSKIdentity:           "unknown PSK identity",
 	alertCertificateRequired:          "certificate required",
 	alertNoApplicationProtocol:        "no application protocol",
+	alertECHRequired:                  "encrypted client hello required",
 }
 
 func (e alert) String() string {

+ 1 - 1
vendor/github.com/Psiphon-Labs/psiphon-tls/auth.go

@@ -242,7 +242,7 @@ func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureSche
 	// Pick signature scheme in the peer's preference order, as our
 	// preference order is not configurable.
 	for _, preferredAlg := range peerAlgs {
-		if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
+		if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, defaultSupportedSignatureAlgorithmsFIPS) {
 			continue
 		}
 		if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {

+ 0 - 102
vendor/github.com/Psiphon-Labs/psiphon-tls/boring.go

@@ -1,102 +0,0 @@
-// Copyright 2017 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build boringcrypto
-
-package tls
-
-import "crypto/internal/boring/fipstls"
-
-// The FIPS-only policies enforced here currently match BoringSSL's
-// ssl_policy_fips_202205.
-
-// needFIPS returns fipstls.Required(); it avoids a new import in common.go.
-func needFIPS() bool {
-	return fipstls.Required()
-}
-
-// fipsMinVersion replaces c.minVersion in FIPS-only mode.
-func fipsMinVersion(c *Config) uint16 {
-	// FIPS requires TLS 1.2 or TLS 1.3.
-	return VersionTLS12
-}
-
-// fipsMaxVersion replaces c.maxVersion in FIPS-only mode.
-func fipsMaxVersion(c *Config) uint16 {
-	// FIPS requires TLS 1.2 or TLS 1.3.
-	return VersionTLS13
-}
-
-// default defaultFIPSCurvePreferences is the FIPS-allowed curves,
-// in preference order (most preferable first).
-var defaultFIPSCurvePreferences = []CurveID{CurveP256, CurveP384}
-
-// fipsCurvePreferences replaces c.curvePreferences in FIPS-only mode.
-func fipsCurvePreferences(c *Config) []CurveID {
-	if c == nil || len(c.CurvePreferences) == 0 {
-		return defaultFIPSCurvePreferences
-	}
-	var list []CurveID
-	for _, id := range c.CurvePreferences {
-		for _, allowed := range defaultFIPSCurvePreferences {
-			if id == allowed {
-				list = append(list, id)
-				break
-			}
-		}
-	}
-	return list
-}
-
-// defaultCipherSuitesFIPS are the FIPS-allowed cipher suites.
-var defaultCipherSuitesFIPS = []uint16{
-	TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
-	TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
-	TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
-	TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
-}
-
-// fipsCipherSuites replaces c.cipherSuites in FIPS-only mode.
-func fipsCipherSuites(c *Config) []uint16 {
-	if c == nil || c.CipherSuites == nil {
-		return defaultCipherSuitesFIPS
-	}
-	list := make([]uint16, 0, len(defaultCipherSuitesFIPS))
-	for _, id := range c.CipherSuites {
-		for _, allowed := range defaultCipherSuitesFIPS {
-			if id == allowed {
-				list = append(list, id)
-				break
-			}
-		}
-	}
-	return list
-}
-
-// defaultCipherSuitesTLS13FIPS are the FIPS-allowed cipher suites for TLS 1.3.
-var defaultCipherSuitesTLS13FIPS = []uint16{
-	TLS_AES_128_GCM_SHA256,
-	TLS_AES_256_GCM_SHA384,
-}
-
-// fipsSupportedSignatureAlgorithms currently are a subset of
-// defaultSupportedSignatureAlgorithms without Ed25519, SHA-1, and P-521.
-var fipsSupportedSignatureAlgorithms = []SignatureScheme{
-	PSSWithSHA256,
-	PSSWithSHA384,
-	PSSWithSHA512,
-	PKCS1WithSHA256,
-	ECDSAWithP256AndSHA256,
-	PKCS1WithSHA384,
-	ECDSAWithP384AndSHA384,
-	PKCS1WithSHA512,
-}
-
-// supportedSignatureAlgorithms returns the supported signature algorithms.
-func supportedSignatureAlgorithms() []SignatureScheme {
-	if !needFIPS() {
-		return defaultSupportedSignatureAlgorithms
-	}
-	return fipsSupportedSignatureAlgorithms
-}

+ 149 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/byteorder/byteorder.go

@@ -0,0 +1,149 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package byteorder provides functions for decoding and encoding
+// little and big endian integer types from/to byte slices.
+package byteorder
+
+func LeUint16(b []byte) uint16 {
+	_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint16(b[0]) | uint16(b[1])<<8
+}
+
+func LePutUint16(b []byte, v uint16) {
+	_ = b[1] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v)
+	b[1] = byte(v >> 8)
+}
+
+func LeAppendUint16(b []byte, v uint16) []byte {
+	return append(b,
+		byte(v),
+		byte(v>>8),
+	)
+}
+
+func LeUint32(b []byte) uint32 {
+	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+func LePutUint32(b []byte, v uint32) {
+	_ = b[3] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v)
+	b[1] = byte(v >> 8)
+	b[2] = byte(v >> 16)
+	b[3] = byte(v >> 24)
+}
+
+func LeAppendUint32(b []byte, v uint32) []byte {
+	return append(b,
+		byte(v),
+		byte(v>>8),
+		byte(v>>16),
+		byte(v>>24),
+	)
+}
+
+func LeUint64(b []byte) uint64 {
+	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
+		uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
+}
+
+func LePutUint64(b []byte, v uint64) {
+	_ = b[7] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v)
+	b[1] = byte(v >> 8)
+	b[2] = byte(v >> 16)
+	b[3] = byte(v >> 24)
+	b[4] = byte(v >> 32)
+	b[5] = byte(v >> 40)
+	b[6] = byte(v >> 48)
+	b[7] = byte(v >> 56)
+}
+
+func LeAppendUint64(b []byte, v uint64) []byte {
+	return append(b,
+		byte(v),
+		byte(v>>8),
+		byte(v>>16),
+		byte(v>>24),
+		byte(v>>32),
+		byte(v>>40),
+		byte(v>>48),
+		byte(v>>56),
+	)
+}
+
+func BeUint16(b []byte) uint16 {
+	_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint16(b[1]) | uint16(b[0])<<8
+}
+
+func BePutUint16(b []byte, v uint16) {
+	_ = b[1] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v >> 8)
+	b[1] = byte(v)
+}
+
+func BeAppendUint16(b []byte, v uint16) []byte {
+	return append(b,
+		byte(v>>8),
+		byte(v),
+	)
+}
+
+func BeUint32(b []byte) uint32 {
+	_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
+}
+
+func BePutUint32(b []byte, v uint32) {
+	_ = b[3] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v >> 24)
+	b[1] = byte(v >> 16)
+	b[2] = byte(v >> 8)
+	b[3] = byte(v)
+}
+
+func BeAppendUint32(b []byte, v uint32) []byte {
+	return append(b,
+		byte(v>>24),
+		byte(v>>16),
+		byte(v>>8),
+		byte(v),
+	)
+}
+
+func BeUint64(b []byte) uint64 {
+	_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
+	return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
+		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
+}
+
+func BePutUint64(b []byte, v uint64) {
+	_ = b[7] // early bounds check to guarantee safety of writes below
+	b[0] = byte(v >> 56)
+	b[1] = byte(v >> 48)
+	b[2] = byte(v >> 40)
+	b[3] = byte(v >> 32)
+	b[4] = byte(v >> 24)
+	b[5] = byte(v >> 16)
+	b[6] = byte(v >> 8)
+	b[7] = byte(v)
+}
+
+func BeAppendUint64(b []byte, v uint64) []byte {
+	return append(b,
+		byte(v>>56),
+		byte(v>>48),
+		byte(v>>40),
+		byte(v>>32),
+		byte(v>>24),
+		byte(v>>16),
+		byte(v>>8),
+		byte(v),
+	)
+}

+ 65 - 34
vendor/github.com/Psiphon-Labs/psiphon-tls/cipher_suites.go

@@ -16,6 +16,9 @@ import (
 	"fmt"
 	"hash"
 	"runtime"
+	_ "unsafe" // for linkname
+
+	"github.com/Psiphon-Labs/psiphon-tls/internal/boring"
 
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/sys/cpu"
@@ -44,18 +47,13 @@ var (
 
 // CipherSuites returns a list of cipher suites currently implemented by this
 // package, excluding those with security issues, which are returned by
-// InsecureCipherSuites.
+// [InsecureCipherSuites].
 //
 // The list is sorted by ID. Note that the default cipher suites selected by
 // this package might depend on logic that can't be captured by a static list,
 // and might not match those returned by this function.
 func CipherSuites() []*CipherSuite {
 	return []*CipherSuite{
-		{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
-		{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
-		{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
-		{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
-
 		{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
 		{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
 		{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
@@ -77,14 +75,18 @@ func CipherSuites() []*CipherSuite {
 // this package and which have security issues.
 //
 // Most applications should not use the cipher suites in this list, and should
-// only use those returned by CipherSuites.
+// only use those returned by [CipherSuites].
 func InsecureCipherSuites() []*CipherSuite {
 	// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
 	// cipherSuitesPreferenceOrder for details.
 	return []*CipherSuite{
 		{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
 		{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
+		{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, true},
+		{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, true},
 		{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
+		{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, true},
+		{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, true},
 		{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
 		{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
 		{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
@@ -197,6 +199,16 @@ type cipherSuiteTLS13 struct {
 	hash   crypto.Hash
 }
 
+// cipherSuitesTLS13 should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+//   - github.com/quic-go/quic-go
+//   - github.com/sagernet/quic-go
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cipherSuitesTLS13
 var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
 	{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
 	{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
@@ -321,36 +333,36 @@ var cipherSuitesPreferenceOrderNoAES = []uint16{
 	TLS_RSA_WITH_RC4_128_SHA,
 }
 
-// disabledCipherSuites are not used unless explicitly listed in
-// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
-var disabledCipherSuites = []uint16{
+// disabledCipherSuites are not used unless explicitly listed in Config.CipherSuites.
+var disabledCipherSuites = map[uint16]bool{
 	// CBC_SHA256
-	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
-	TLS_RSA_WITH_AES_128_CBC_SHA256,
+	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: true,
+	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:   true,
+	TLS_RSA_WITH_AES_128_CBC_SHA256:         true,
 
 	// RC4
-	TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
-	TLS_RSA_WITH_RC4_128_SHA,
+	TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: true,
+	TLS_ECDHE_RSA_WITH_RC4_128_SHA:   true,
+	TLS_RSA_WITH_RC4_128_SHA:         true,
 }
 
-var (
-	defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
-	defaultCipherSuites    = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
-)
-
-// defaultCipherSuitesTLS13 is also the preference order, since there are no
-// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
-// cipherSuitesPreferenceOrder applies.
-var defaultCipherSuitesTLS13 = []uint16{
-	TLS_AES_128_GCM_SHA256,
-	TLS_AES_256_GCM_SHA384,
-	TLS_CHACHA20_POLY1305_SHA256,
+// rsaKexCiphers contains the ciphers which use RSA based key exchange,
+// which we also disable by default unless a GODEBUG is set.
+var rsaKexCiphers = map[uint16]bool{
+	TLS_RSA_WITH_RC4_128_SHA:        true,
+	TLS_RSA_WITH_3DES_EDE_CBC_SHA:   true,
+	TLS_RSA_WITH_AES_128_CBC_SHA:    true,
+	TLS_RSA_WITH_AES_256_CBC_SHA:    true,
+	TLS_RSA_WITH_AES_128_CBC_SHA256: true,
+	TLS_RSA_WITH_AES_128_GCM_SHA256: true,
+	TLS_RSA_WITH_AES_256_GCM_SHA384: true,
 }
 
-var defaultCipherSuitesTLS13NoAES = []uint16{
-	TLS_CHACHA20_POLY1305_SHA256,
-	TLS_AES_128_GCM_SHA256,
-	TLS_AES_256_GCM_SHA384,
+// tdesCiphers contains 3DES ciphers,
+// which we also disable by default unless a GODEBUG is set.
+var tdesCiphers = map[uint16]bool{
+	TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: true,
+	TLS_RSA_WITH_3DES_EDE_CBC_SHA:       true,
 }
 
 var (
@@ -414,7 +426,11 @@ func cipherAES(key, iv []byte, isRead bool) any {
 // macSHA1 returns a SHA-1 based constant time MAC.
 func macSHA1(key []byte) hash.Hash {
 	h := sha1.New
-	h = newConstantTimeHash(h)
+	// The BoringCrypto SHA1 does not have a constant-time
+	// checksum function, so don't try to use it.
+	if !boring.Enabled {
+		h = newConstantTimeHash(h)
+	}
 	return hmac.New(h, key)
 }
 
@@ -504,7 +520,12 @@ func aeadAESGCM(key, noncePrefix []byte) aead {
 		panic(err)
 	}
 	var aead cipher.AEAD
-	aead, err = cipher.NewGCM(aes)
+	if boring.Enabled {
+		aead, err = boring.NewGCMTLS(aes)
+	} else {
+		boring.Unreachable()
+		aead, err = cipher.NewGCM(aes)
+	}
 	if err != nil {
 		panic(err)
 	}
@@ -514,6 +535,16 @@ func aeadAESGCM(key, noncePrefix []byte) aead {
 	return ret
 }
 
+// aeadAESGCMTLS13 should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+//   - github.com/xtls/xray-core
+//   - github.com/v2fly/v2ray-core
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname aeadAESGCMTLS13
 func aeadAESGCMTLS13(key, nonceMask []byte) aead {
 	if len(nonceMask) != aeadNonceLength {
 		panic("tls: internal error: wrong nonce length")
@@ -522,8 +553,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) aead {
 	if err != nil {
 		panic(err)
 	}
-	var aead cipher.AEAD
-	aead, err = cipher.NewGCM(aes)
+	aead, err := cipher.NewGCM(aes)
 	if err != nil {
 		panic(err)
 	}
@@ -565,6 +595,7 @@ func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
 func (c *cthWrapper) Sum(b []byte) []byte         { return c.h.ConstantTimeSum(b) }
 
 func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
+	boring.Unreachable()
 	return func() hash.Hash {
 		return &cthWrapper{h().(constantTimeHash)}
 	}

+ 206 - 98
vendor/github.com/Psiphon-Labs/psiphon-tls/common.go

@@ -18,11 +18,17 @@ import (
 	"crypto/x509"
 	"errors"
 	"fmt"
+
+	// [Psiphon]
+	// "internal/godebug"
+
 	"io"
 	"net"
+	"slices"
 	"strings"
 	"sync"
 	"time"
+	_ "unsafe" // for linkname
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 )
@@ -59,12 +65,13 @@ func VersionName(version uint16) string {
 }
 
 const (
-	maxPlaintext       = 16384        // maximum plaintext payload length
-	maxCiphertext      = 16384 + 2048 // maximum ciphertext payload length
-	maxCiphertextTLS13 = 16384 + 256  // maximum ciphertext length in TLS 1.3
-	recordHeaderLen    = 5            // record header length
-	maxHandshake       = 65536        // maximum handshake we support (protocol max is 16 MB)
-	maxUselessRecords  = 16           // maximum number of consecutive non-advancing records
+	maxPlaintext               = 16384        // maximum plaintext payload length
+	maxCiphertext              = 16384 + 2048 // maximum ciphertext payload length
+	maxCiphertextTLS13         = 16384 + 256  // maximum ciphertext length in TLS 1.3
+	recordHeaderLen            = 5            // record header length
+	maxHandshake               = 65536        // maximum handshake we support (protocol max is 16 MB)
+	maxHandshakeCertificateMsg = 262144       // maximum certificate message size (256 KiB)
+	maxUselessRecords          = 16           // maximum number of consecutive non-advancing records
 )
 
 // TLS record types.
@@ -94,7 +101,6 @@ const (
 	typeFinished            uint8 = 20
 	typeCertificateStatus   uint8 = 22
 	typeKeyUpdate           uint8 = 24
-	typeNextProtocol        uint8 = 67  // Not IANA assigned
 	typeMessageHash         uint8 = 254 // synthetic message
 )
 
@@ -124,6 +130,8 @@ const (
 	extensionKeyShare                uint16 = 51
 	extensionQUICTransportParameters uint16 = 57
 	extensionRenegotiationInfo       uint16 = 0xff01
+	extensionECHOuterExtensions      uint16 = 0xfd00
+	extensionEncryptedClientHello    uint16 = 0xfe0d
 )
 
 // TLS signaling cipher suite values
@@ -131,11 +139,13 @@ const (
 	scsvRenegotiation uint16 = 0x00ff
 )
 
-// CurveID is the type of a TLS identifier for an elliptic curve. See
+// CurveID is the type of a TLS identifier for a key exchange mechanism. See
 // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8.
 //
-// In TLS 1.3, this type is called NamedGroup, but at this time this library
-// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7.
+// In TLS 1.2, this registry used to support only elliptic curves. In TLS 1.3,
+// it was extended to other groups and renamed NamedGroup. See RFC 8446, Section
+// 4.2.7. It was then also extended to other mechanisms, such as hybrid
+// post-quantum KEMs.
 type CurveID uint16
 
 const (
@@ -143,6 +153,11 @@ const (
 	CurveP384 CurveID = 24
 	CurveP521 CurveID = 25
 	X25519    CurveID = 29
+
+	// Experimental codepoint for X25519Kyber768Draft00, specified in
+	// draft-tls-westerbaan-xyber768d00-03. Not exported, as support might be
+	// removed in the future.
+	x25519Kyber768Draft00 CurveID = 0x6399 // X25519Kyber768Draft00
 )
 
 // TLS 1.3 Key Share. See RFC 8446, Section 4.2.8.
@@ -195,25 +210,6 @@ const (
 // hash function associated with the Ed25519 signature scheme.
 var directSigning crypto.Hash = 0
 
-// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that
-// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+
-// CertificateRequest. The two fields are merged to match with TLS 1.3.
-// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc.
-var defaultSupportedSignatureAlgorithms = []SignatureScheme{
-	PSSWithSHA256,
-	ECDSAWithP256AndSHA256,
-	Ed25519,
-	PSSWithSHA384,
-	PSSWithSHA512,
-	PKCS1WithSHA256,
-	PKCS1WithSHA384,
-	PKCS1WithSHA512,
-	ECDSAWithP384AndSHA384,
-	ECDSAWithP521AndSHA512,
-	PKCS1WithSHA1,
-	ECDSAWithSHA1,
-}
-
 // helloRetryRequestRandom is set as the Random value of a ServerHello
 // to signal that the message is actually a HelloRetryRequest.
 var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
@@ -307,18 +303,32 @@ type ConnectionState struct {
 	// resumed connections that don't support Extended Master Secret (RFC 7627).
 	TLSUnique []byte
 
+	// ECHAccepted indicates if Encrypted Client Hello was offered by the client
+	// and accepted by the server. Currently, ECH is supported only on the
+	// client side.
+	ECHAccepted bool
+
 	// ekm is a closure exposed via ExportKeyingMaterial.
 	ekm func(label string, context []byte, length int) ([]byte, error)
+
+	// testingOnlyDidHRR is true if a HelloRetryRequest was sent/received.
+	testingOnlyDidHRR bool
+
+	// testingOnlyCurveID is the selected CurveID, or zero if an RSA exchanges
+	// is performed.
+	testingOnlyCurveID CurveID
 }
 
 // ExportKeyingMaterial returns length bytes of exported key material in a new
 // slice as defined in RFC 5705. If context is nil, it is not used as part of
 // the seed. If the connection was set to allow renegotiation via
-// Config.Renegotiation, this function will return an error.
+// Config.Renegotiation, or if the connections supports neither TLS 1.3 nor
+// Extended Master Secret, this function will return an error.
 //
-// There are conditions in which the returned values might not be unique to a
-// connection. See the Security Considerations sections of RFC 5705 and RFC 7627,
-// and https://mitls.org/pages/attacks/3SHAKE#channelbindings.
+// Exporting key material without Extended Master Secret or TLS 1.3 was disabled
+// in Go 1.22 due to security issues (see the Security Considerations sections
+// of RFC 5705 and RFC 7627), but can be re-enabled with the GODEBUG setting
+// tlsunsafeekm=1.
 func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
 	return cs.ekm(label, context, length)
 }
@@ -380,7 +390,7 @@ type ClientSessionCache interface {
 	Put(sessionKey string, cs *ClientSessionState)
 }
 
-//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go
+//go:generate stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go
 
 // SignatureScheme identifies a signature algorithm supported by TLS. See
 // RFC 8446, Section 4.2.3.
@@ -687,7 +697,11 @@ type Config struct {
 	// the list is ignored. Note that TLS 1.3 ciphersuites are not configurable.
 	//
 	// If CipherSuites is nil, a safe default list is used. The default cipher
-	// suites might change over time.
+	// suites might change over time. In Go 1.22 RSA key exchange based cipher
+	// suites were removed from the default list, but can be re-added with the
+	// GODEBUG setting tlsrsakex=1. In Go 1.23 3DES cipher suites were removed
+	// from the default list, but can be re-added with the GODEBUG setting
+	// tls3des=1.
 	CipherSuites []uint16
 
 	// PreferServerCipherSuites is a legacy field and has no effect.
@@ -750,14 +764,11 @@ type Config struct {
 
 	// MinVersion contains the minimum TLS version that is acceptable.
 	//
-	// By default, TLS 1.2 is currently used as the minimum when acting as a
-	// client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum
-	// supported by this package, both as a client and as a server.
+	// By default, TLS 1.2 is currently used as the minimum. TLS 1.0 is the
+	// minimum supported by this package.
 	//
-	// The client-side default can temporarily be reverted to TLS 1.0 by
-	// including the value "x509sha1=1" in the GODEBUG environment variable.
-	// Note that this option will be removed in Go 1.19 (but it will still be
-	// possible to set this field to VersionTLS10 explicitly).
+	// The server-side default can be reverted to TLS 1.0 by including the value
+	// "tls10server=1" in the GODEBUG environment variable.
 	MinVersion uint16
 
 	// MaxVersion contains the maximum TLS version that is acceptable.
@@ -770,6 +781,10 @@ type Config struct {
 	// an ECDHE handshake, in preference order. If empty, the default will
 	// be used. The client will use the first preference as the type for
 	// its key share in TLS 1.3. This may change in the future.
+	//
+	// From Go 1.23, the default includes the X25519Kyber768Draft00 hybrid
+	// post-quantum key exchange. To disable it, set CurvePreferences explicitly
+	// or use the GODEBUG=tlskyber=0 environment variable.
 	CurvePreferences []CurveID
 
 	// DynamicRecordSizingDisabled disables adaptive sizing of TLS records.
@@ -790,6 +805,41 @@ type Config struct {
 	// used for debugging.
 	KeyLogWriter io.Writer
 
+	// EncryptedClientHelloConfigList is a serialized ECHConfigList. If
+	// provided, clients will attempt to connect to servers using Encrypted
+	// Client Hello (ECH) using one of the provided ECHConfigs. Servers
+	// currently ignore this field.
+	//
+	// If the list contains no valid ECH configs, the handshake will fail
+	// and return an error.
+	//
+	// If EncryptedClientHelloConfigList is set, MinVersion, if set, must
+	// be VersionTLS13.
+	//
+	// When EncryptedClientHelloConfigList is set, the handshake will only
+	// succeed if ECH is sucessfully negotiated. If the server rejects ECH,
+	// an ECHRejectionError error will be returned, which may contain a new
+	// ECHConfigList that the server suggests using.
+	//
+	// How this field is parsed may change in future Go versions, if the
+	// encoding described in the final Encrypted Client Hello RFC changes.
+	EncryptedClientHelloConfigList []byte
+
+	// EncryptedClientHelloRejectionVerify, if not nil, is called when ECH is
+	// rejected, in order to verify the ECH provider certificate in the outer
+	// Client Hello. If it returns a non-nil error, the handshake is aborted and
+	// that error results.
+	//
+	// Unlike VerifyPeerCertificate and VerifyConnection, normal certificate
+	// verification will not be performed before calling
+	// EncryptedClientHelloRejectionVerify.
+	//
+	// If EncryptedClientHelloRejectionVerify is nil and ECH is rejected, the
+	// roots in RootCAs will be used to verify the ECH providers public
+	// certificate. VerifyPeerCertificate and VerifyConnection are not called
+	// when ECH is rejected, even if set, and InsecureSkipVerify is ignored.
+	EncryptedClientHelloRejectionVerify func(ConnectionState) error
+	
 	// [Psiphon]
 	// ClientHelloPRNG is used for Client Hello randomization and replay.
 	ClientHelloPRNG *prng.PRNG
@@ -894,7 +944,7 @@ func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) {
 // ticket, and the lifetime we set for all tickets we send.
 const maxSessionTicketLifetime = 7 * 24 * time.Hour
 
-// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is
+// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a [Config] that is
 // being used concurrently by a TLS client or server.
 func (c *Config) Clone() *Config {
 	if c == nil {
@@ -903,37 +953,41 @@ func (c *Config) Clone() *Config {
 	c.mutex.RLock()
 	defer c.mutex.RUnlock()
 	return &Config{
-		Rand:                        c.Rand,
-		Time:                        c.Time,
-		Certificates:                c.Certificates,
-		NameToCertificate:           c.NameToCertificate,
-		GetCertificate:              c.GetCertificate,
-		GetClientCertificate:        c.GetClientCertificate,
-		GetConfigForClient:          c.GetConfigForClient,
-		VerifyPeerCertificate:       c.VerifyPeerCertificate,
-		VerifyConnection:            c.VerifyConnection,
-		RootCAs:                     c.RootCAs,
-		NextProtos:                  c.NextProtos,
-		ServerName:                  c.ServerName,
-		ClientAuth:                  c.ClientAuth,
-		ClientCAs:                   c.ClientCAs,
-		InsecureSkipVerify:          c.InsecureSkipVerify,
+		Rand:                                c.Rand,
+		Time:                                c.Time,
+		Certificates:                        c.Certificates,
+		NameToCertificate:                   c.NameToCertificate,
+		GetCertificate:                      c.GetCertificate,
+		GetClientCertificate:                c.GetClientCertificate,
+		GetConfigForClient:                  c.GetConfigForClient,
+		VerifyPeerCertificate:               c.VerifyPeerCertificate,
+		VerifyConnection:                    c.VerifyConnection,
+		RootCAs:                             c.RootCAs,
+		NextProtos:                          c.NextProtos,
+		ServerName:                          c.ServerName,
+		ClientAuth:                          c.ClientAuth,
+		ClientCAs:                           c.ClientCAs,
+		InsecureSkipVerify:                  c.InsecureSkipVerify,
+		CipherSuites:                        c.CipherSuites,
+		PreferServerCipherSuites:            c.PreferServerCipherSuites,
+		SessionTicketsDisabled:              c.SessionTicketsDisabled,
+		SessionTicketKey:                    c.SessionTicketKey,
+		ClientSessionCache:                  c.ClientSessionCache,
+		UnwrapSession:                       c.UnwrapSession,
+		WrapSession:                         c.WrapSession,
+		MinVersion:                          c.MinVersion,
+		MaxVersion:                          c.MaxVersion,
+		CurvePreferences:                    c.CurvePreferences,
+		DynamicRecordSizingDisabled:         c.DynamicRecordSizingDisabled,
+		Renegotiation:                       c.Renegotiation,
+		KeyLogWriter:                        c.KeyLogWriter,
+		EncryptedClientHelloConfigList:      c.EncryptedClientHelloConfigList,
+		EncryptedClientHelloRejectionVerify: c.EncryptedClientHelloRejectionVerify,
+		sessionTicketKeys:                   c.sessionTicketKeys,
+		autoSessionTicketKeys:               c.autoSessionTicketKeys,
+
+		// [Psiphon]
 		InsecureSkipTimeVerify:      c.InsecureSkipTimeVerify,
-		CipherSuites:                c.CipherSuites,
-		PreferServerCipherSuites:    c.PreferServerCipherSuites,
-		SessionTicketsDisabled:      c.SessionTicketsDisabled,
-		SessionTicketKey:            c.SessionTicketKey,
-		ClientSessionCache:          c.ClientSessionCache,
-		UnwrapSession:               c.UnwrapSession,
-		WrapSession:                 c.WrapSession,
-		MinVersion:                  c.MinVersion,
-		MaxVersion:                  c.MaxVersion,
-		CurvePreferences:            c.CurvePreferences,
-		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
-		Renegotiation:               c.Renegotiation,
-		KeyLogWriter:                c.KeyLogWriter,
-		sessionTicketKeys:           c.sessionTicketKeys,
-		autoSessionTicketKeys:       c.autoSessionTicketKeys,
 	}
 }
 
@@ -1080,13 +1134,19 @@ func (c *Config) time() time.Time {
 }
 
 func (c *Config) cipherSuites() []uint16 {
-	if needFIPS() {
-		return fipsCipherSuites(c)
+	if c.CipherSuites == nil {
+		if needFIPS() {
+			return defaultCipherSuitesFIPS
+		}
+		return defaultCipherSuites()
 	}
-	if c.CipherSuites != nil {
-		return c.CipherSuites
+	if needFIPS() {
+		cipherSuites := slices.Clone(c.CipherSuites)
+		return slices.DeleteFunc(cipherSuites, func(id uint16) bool {
+			return !slices.Contains(defaultCipherSuitesFIPS, id)
+		})
 	}
-	return defaultCipherSuites
+	return c.CipherSuites
 }
 
 var supportedVersions = []uint16{
@@ -1101,14 +1161,26 @@ var supportedVersions = []uint16{
 const roleClient = true
 const roleServer = false
 
+// [Psiphon]
+// var tls10server = godebug.New("tls10server")
+
 func (c *Config) supportedVersions(isClient bool) []uint16 {
 	versions := make([]uint16, 0, len(supportedVersions))
 	for _, v := range supportedVersions {
-		if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) {
+		if needFIPS() && !slices.Contains(defaultSupportedVersionsFIPS, v) {
 			continue
 		}
-		if (c == nil || c.MinVersion == 0) &&
-			isClient && v < VersionTLS12 {
+		if (c == nil || c.MinVersion == 0) && v < VersionTLS12 {
+			// [Psiphon] BEGIN
+			// if isClient || tls10server.Value() != "1" {
+			// 	continue
+			// }
+			if isClient {
+				continue
+			}
+			// [Psiphon] END
+		}
+		if isClient && c.EncryptedClientHelloConfigList != nil && v < VersionTLS13 {
 			continue
 		}
 		if c != nil && c.MinVersion != 0 && v < c.MinVersion {
@@ -1144,20 +1216,30 @@ func supportedVersionsFromMax(maxVersion uint16) []uint16 {
 	return versions
 }
 
-var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521}
-
-func (c *Config) curvePreferences() []CurveID {
-	if needFIPS() {
-		return fipsCurvePreferences(c)
+func (c *Config) curvePreferences(version uint16) []CurveID {
+	var curvePreferences []CurveID
+	if c != nil && len(c.CurvePreferences) != 0 {
+		curvePreferences = slices.Clone(c.CurvePreferences)
+		if needFIPS() {
+			return slices.DeleteFunc(curvePreferences, func(c CurveID) bool {
+				return !slices.Contains(defaultCurvePreferencesFIPS, c)
+			})
+		}
+	} else if needFIPS() {
+		curvePreferences = slices.Clone(defaultCurvePreferencesFIPS)
+	} else {
+		curvePreferences = defaultCurvePreferences()
 	}
-	if c == nil || len(c.CurvePreferences) == 0 {
-		return defaultCurvePreferences
+	if version < VersionTLS13 {
+		return slices.DeleteFunc(curvePreferences, func(c CurveID) bool {
+			return c == x25519Kyber768Draft00
+		})
 	}
-	return c.CurvePreferences
+	return curvePreferences
 }
 
-func (c *Config) supportsCurve(curve CurveID) bool {
-	for _, cc := range c.curvePreferences() {
+func (c *Config) supportsCurve(version uint16, curve CurveID) bool {
+	for _, cc := range c.curvePreferences(version) {
 		if cc == curve {
 			return true
 		}
@@ -1179,6 +1261,15 @@ func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bo
 	return 0, false
 }
 
+// errNoCertificates should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+//   - github.com/xtls/xray-core
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname errNoCertificates
 var errNoCertificates = errors.New("tls: no certificates configured")
 
 // getCertificate returns the best certificate for the given ClientHelloInfo,
@@ -1230,9 +1321,9 @@ func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, err
 // the client that sent the ClientHello. Otherwise, it returns an error
 // describing the reason for the incompatibility.
 //
-// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate
-// callback, this method will take into account the associated Config. Note that
-// if GetConfigForClient returns a different Config, the change can't be
+// If this [ClientHelloInfo] was passed to a GetConfigForClient or GetCertificate
+// callback, this method will take into account the associated [Config]. Note that
+// if GetConfigForClient returns a different [Config], the change can't be
 // accounted for by this method.
 //
 // This function will call x509.ParseCertificate unless c.Leaf is set, which can
@@ -1316,7 +1407,7 @@ func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error {
 	}
 
 	// The only signed key exchange we support is ECDHE.
-	if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) {
+	if !supportsECDHE(config, vers, chi.SupportedCurves, chi.SupportedPoints) {
 		return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange"))
 	}
 
@@ -1337,7 +1428,7 @@ func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error {
 			}
 			var curveOk bool
 			for _, c := range chi.SupportedCurves {
-				if c == curve && config.supportsCurve(c) {
+				if c == curve && config.supportsCurve(vers, c) {
 					curveOk = true
 					break
 				}
@@ -1508,6 +1599,15 @@ type handshakeMessage interface {
 	unmarshal([]byte) bool
 }
 
+type handshakeMessageWithOriginalBytes interface {
+	handshakeMessage
+
+	// originalBytes should return the original bytes that were passed to
+	// unmarshal to create the message. If the message was not produced by
+	// unmarshal, it should return nil.
+	originalBytes() []byte
+}
+
 // lruSessionCache is a ClientSessionCache implementation that uses an LRU
 // caching strategy.
 type lruSessionCache struct {
@@ -1523,7 +1623,7 @@ type lruSessionCacheEntry struct {
 	state      *ClientSessionState
 }
 
-// NewLRUClientSessionCache returns a ClientSessionCache with the given
+// NewLRUClientSessionCache returns a [ClientSessionCache] with the given
 // capacity that uses an LRU strategy. If capacity is < 1, a default capacity
 // is used instead.
 func NewLRUClientSessionCache(capacity int) ClientSessionCache {
@@ -1572,7 +1672,7 @@ func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) {
 	c.m[sessionKey] = elem
 }
 
-// Get returns the ClientSessionState value associated with a given key. It
+// Get returns the [ClientSessionState] value associated with a given key. It
 // returns (nil, false) if no value is found.
 func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
 	c.Lock()
@@ -1595,6 +1695,14 @@ func unexpectedMessageError(wanted, got any) error {
 	return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
 }
 
+// supportedSignatureAlgorithms returns the supported signature algorithms.
+func supportedSignatureAlgorithms() []SignatureScheme {
+	if !needFIPS() {
+		return defaultSupportedSignatureAlgorithms
+	}
+	return defaultSupportedSignatureAlgorithmsFIPS
+}
+
 func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool {
 	for _, s := range supportedSignatureAlgorithms {
 		if s == sigAlg {

+ 5 - 1
vendor/github.com/Psiphon-Labs/psiphon-tls/common_string.go

@@ -1,4 +1,4 @@
-// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT.
+// Code generated by "stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT.
 
 package tls
 
@@ -71,11 +71,13 @@ func _() {
 	_ = x[CurveP384-24]
 	_ = x[CurveP521-25]
 	_ = x[X25519-29]
+	_ = x[x25519Kyber768Draft00-25497]
 }
 
 const (
 	_CurveID_name_0 = "CurveP256CurveP384CurveP521"
 	_CurveID_name_1 = "X25519"
+	_CurveID_name_2 = "X25519Kyber768Draft00"
 )
 
 var (
@@ -89,6 +91,8 @@ func (i CurveID) String() string {
 		return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]]
 	case i == 29:
 		return _CurveID_name_1
+	case i == 25497:
+		return _CurveID_name_2
 	default:
 		return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")"
 	}

+ 54 - 20
vendor/github.com/Psiphon-Labs/psiphon-tls/conn.go

@@ -15,6 +15,9 @@ import (
 	"errors"
 	"fmt"
 	"hash"
+
+	// [Psiphon]
+	// "internal/godebug"
 	"io"
 	"net"
 	"sync"
@@ -51,7 +54,9 @@ type Conn struct {
 	clientSentTicket bool // whether the client sent a session ticket or a PSK in the Client Hello successfully.
 
 	didResume        bool // whether this connection was a session resumption
+	didHRR           bool // whether a HelloRetryRequest was sent/received
 	cipherSuite      uint16
+	curveID          CurveID
 	ocspResponse     []byte   // stapled OCSP response
 	scts             [][]byte // signed certificate timestamps from server
 	peerCertificates []*x509.Certificate
@@ -72,6 +77,7 @@ type Conn struct {
 	// resumptionSecret is the resumption_master_secret for handling
 	// or sending NewSessionTicket messages.
 	resumptionSecret []byte
+	echAccepted      bool
 
 	// ticketKeys is the set of active session ticket keys for this
 	// connection. The first one is used to encrypt new tickets and
@@ -140,21 +146,21 @@ func (c *Conn) RemoteAddr() net.Addr {
 }
 
 // SetDeadline sets the read and write deadlines associated with the connection.
-// A zero value for t means Read and Write will not time out.
+// A zero value for t means [Conn.Read] and [Conn.Write] will not time out.
 // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
 func (c *Conn) SetDeadline(t time.Time) error {
 	return c.conn.SetDeadline(t)
 }
 
 // SetReadDeadline sets the read deadline on the underlying connection.
-// A zero value for t means Read will not time out.
+// A zero value for t means [Conn.Read] will not time out.
 func (c *Conn) SetReadDeadline(t time.Time) error {
 	return c.conn.SetReadDeadline(t)
 }
 
 // SetWriteDeadline sets the write deadline on the underlying connection.
-// A zero value for t means Write will not time out.
-// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
+// A zero value for t means [Conn.Write] will not time out.
+// After a [Conn.Write] has timed out, the TLS state is corrupt and all future writes will return the same error.
 func (c *Conn) SetWriteDeadline(t time.Time) error {
 	return c.conn.SetWriteDeadline(t)
 }
@@ -1057,7 +1063,7 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
 }
 
 // writeHandshakeRecord writes a handshake message to the connection and updates
-// the record layer state. If transcript is non-nil the marshalled message is
+// the record layer state. If transcript is non-nil the marshaled message is
 // written to it.
 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
 	c.out.Lock()
@@ -1124,10 +1130,22 @@ func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
 		return nil, err
 	}
 	data := c.hand.Bytes()
+
+	maxHandshakeSize := maxHandshake
+	// hasVers indicates we're past the first message, forcing someone trying to
+	// make us just allocate a large buffer to at least do the initial part of
+	// the handshake first.
+	if c.haveVers && data[0] == typeCertificate {
+		// Since certificate messages are likely to be the only messages that
+		// can be larger than maxHandshake, we use a special limit for just
+		// those messages.
+		maxHandshakeSize = maxHandshakeCertificateMsg
+	}
+
 	n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
-	if n > maxHandshake {
+	if n > maxHandshakeSize {
 		c.sendAlertLocked(alertInternalError)
-		return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
+		return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
 	}
 	if err := c.readHandshakeBytes(4 + n); err != nil {
 		return nil, err
@@ -1211,10 +1229,10 @@ var (
 
 // Write writes data to the connection.
 //
-// As Write calls Handshake, in order to prevent indefinite blocking a deadline
-// must be set for both Read and Write before Write is called when the handshake
-// has not yet completed. See SetDeadline, SetReadDeadline, and
-// SetWriteDeadline.
+// As Write calls [Conn.Handshake], in order to prevent indefinite blocking a deadline
+// must be set for both [Conn.Read] and Write before Write is called when the handshake
+// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and
+// [Conn.SetWriteDeadline].
 func (c *Conn) Write(b []byte) (int, error) {
 	// interlock with Close below
 	for {
@@ -1386,10 +1404,10 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
 
 // Read reads data from the connection.
 //
-// As Read calls Handshake, in order to prevent indefinite blocking a deadline
-// must be set for both Read and Write before Read is called when the handshake
-// has not yet completed. See SetDeadline, SetReadDeadline, and
-// SetWriteDeadline.
+// As Read calls [Conn.Handshake], in order to prevent indefinite blocking a deadline
+// must be set for both Read and [Conn.Write] before Read is called when the handshake
+// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and
+// [Conn.SetWriteDeadline].
 func (c *Conn) Read(b []byte) (int, error) {
 	if err := c.Handshake(); err != nil {
 		return 0, err
@@ -1473,7 +1491,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
 
 // CloseWrite shuts down the writing side of the connection. It should only be
 // called once the handshake has completed and does not call CloseWrite on the
-// underlying connection. Most callers should just use Close.
+// underlying connection. Most callers should just use [Conn.Close].
 func (c *Conn) CloseWrite() error {
 	if !c.isHandshakeComplete.Load() {
 		return errEarlyCloseWrite
@@ -1501,10 +1519,10 @@ func (c *Conn) closeNotify() error {
 // protocol if it has not yet been run.
 //
 // Most uses of this package need not call Handshake explicitly: the
-// first Read or Write will call it automatically.
+// first [Conn.Read] or [Conn.Write] will call it automatically.
 //
 // For control over canceling or setting a timeout on a handshake, use
-// HandshakeContext or the Dialer's DialContext method instead.
+// [Conn.HandshakeContext] or the [Dialer]'s DialContext method instead.
 //
 // In order to avoid denial of service attacks, the maximum RSA key size allowed
 // in certificates sent by either the TLS server or client is limited to 8192
@@ -1523,7 +1541,7 @@ func (c *Conn) Handshake() error {
 // connection.
 //
 // Most uses of this package need not call HandshakeContext explicitly: the
-// first Read or Write will call it automatically.
+// first [Conn.Read] or [Conn.Write] will call it automatically.
 func (c *Conn) HandshakeContext(ctx context.Context) error {
 	// Delegate to unexported method for named return
 	// without confusing documented signature.
@@ -1649,12 +1667,18 @@ func (c *Conn) ConnectionState() ConnectionState {
 	return c.connectionStateLocked()
 }
 
+// [Psiphon]
+// var tlsunsafeekm = godebug.New("tlsunsafeekm")
+
 func (c *Conn) connectionStateLocked() ConnectionState {
 	var state ConnectionState
 	state.HandshakeComplete = c.isHandshakeComplete.Load()
 	state.Version = c.vers
 	state.NegotiatedProtocol = c.clientProtocol
 	state.DidResume = c.didResume
+	state.testingOnlyDidHRR = c.didHRR
+	// c.curveID is not set on TLS 1.0–1.2 resumptions. Fix that before exposing it.
+	state.testingOnlyCurveID = c.curveID
 	state.NegotiatedProtocolIsMutual = true
 	state.ServerName = c.serverName
 	state.CipherSuite = c.cipherSuite
@@ -1670,10 +1694,20 @@ func (c *Conn) connectionStateLocked() ConnectionState {
 		}
 	}
 	if c.config.Renegotiation != RenegotiateNever {
-		state.ekm = noExportedKeyingMaterial
+		state.ekm = noEKMBecauseRenegotiation
+	} else if c.vers != VersionTLS13 && !c.extMasterSecret {
+		state.ekm = func(label string, context []byte, length int) ([]byte, error) {
+			// [Psiphon]
+			// if tlsunsafeekm.Value() == "1" {
+			// 	tlsunsafeekm.IncNonDefault()
+			// 	return c.ekm(label, context, length)
+			// }
+			return noEKMBecauseNoEMS(label, context, length)
+		}
 	} else {
 		state.ekm = c.ekm
 	}
+	state.ECHAccepted = c.echAccepted
 	return state
 }
 

+ 139 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/defaults.go

@@ -0,0 +1,139 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+	// [Psiphon]
+	// "internal/godebug"
+	"slices"
+	_ "unsafe" // for linkname
+)
+
+// Defaults are collected in this file to allow distributions to more easily patch
+// them to apply local policies.
+
+// [Psiphon]
+// var tlskyber = godebug.New("tlskyber")
+
+func defaultCurvePreferences() []CurveID {
+	// [Psiphon]
+	// if tlskyber.Value() == "0" {
+	// 	return []CurveID{X25519, CurveP256, CurveP384, CurveP521}
+	// }
+	// For now, x25519Kyber768Draft00 must always be followed by X25519.
+	// return []CurveID{x25519Kyber768Draft00, X25519, CurveP256, CurveP384, CurveP521}
+
+	// [Psiphon] Excluve X22519Kyber768Deaft00 by default
+	return []CurveID{X25519, CurveP256, CurveP384, CurveP521}
+}
+
+// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that
+// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+
+// CertificateRequest. The two fields are merged to match with TLS 1.3.
+// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc.
+var defaultSupportedSignatureAlgorithms = []SignatureScheme{
+	PSSWithSHA256,
+	ECDSAWithP256AndSHA256,
+	Ed25519,
+	PSSWithSHA384,
+	PSSWithSHA512,
+	PKCS1WithSHA256,
+	PKCS1WithSHA384,
+	PKCS1WithSHA512,
+	ECDSAWithP384AndSHA384,
+	ECDSAWithP521AndSHA512,
+	PKCS1WithSHA1,
+	ECDSAWithSHA1,
+}
+
+// [Psiphon]
+// var tlsrsakex = godebug.New("tlsrsakex")
+// var tls3des = godebug.New("tls3des")
+
+func defaultCipherSuites() []uint16 {
+	suites := slices.Clone(cipherSuitesPreferenceOrder)
+	return slices.DeleteFunc(suites, func(c uint16) bool {
+		// [Psiphon] BEGIN
+		// return disabledCipherSuites[c] ||
+		// 	tlsrsakex.Value() != "1" && rsaKexCiphers[c] ||
+		// 	tls3des.Value() != "1" && tdesCiphers[c]
+		return disabledCipherSuites[c] || rsaKexCiphers[c] || tdesCiphers[c]
+		// [Psiphon] END
+	})
+}
+
+// defaultCipherSuitesTLS13 is also the preference order, since there are no
+// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
+// cipherSuitesPreferenceOrder applies.
+//
+// defaultCipherSuitesTLS13 should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+//   - github.com/quic-go/quic-go
+//   - github.com/sagernet/quic-go
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname defaultCipherSuitesTLS13
+var defaultCipherSuitesTLS13 = []uint16{
+	TLS_AES_128_GCM_SHA256,
+	TLS_AES_256_GCM_SHA384,
+	TLS_CHACHA20_POLY1305_SHA256,
+}
+
+// defaultCipherSuitesTLS13NoAES should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+//   - github.com/quic-go/quic-go
+//   - github.com/sagernet/quic-go
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname defaultCipherSuitesTLS13NoAES
+var defaultCipherSuitesTLS13NoAES = []uint16{
+	TLS_CHACHA20_POLY1305_SHA256,
+	TLS_AES_128_GCM_SHA256,
+	TLS_AES_256_GCM_SHA384,
+}
+
+var defaultSupportedVersionsFIPS = []uint16{
+	VersionTLS12,
+}
+
+// defaultCurvePreferencesFIPS are the FIPS-allowed curves,
+// in preference order (most preferable first).
+var defaultCurvePreferencesFIPS = []CurveID{CurveP256, CurveP384, CurveP521}
+
+// defaultSupportedSignatureAlgorithmsFIPS currently are a subset of
+// defaultSupportedSignatureAlgorithms without Ed25519 and SHA-1.
+var defaultSupportedSignatureAlgorithmsFIPS = []SignatureScheme{
+	PSSWithSHA256,
+	PSSWithSHA384,
+	PSSWithSHA512,
+	PKCS1WithSHA256,
+	ECDSAWithP256AndSHA256,
+	PKCS1WithSHA384,
+	ECDSAWithP384AndSHA384,
+	PKCS1WithSHA512,
+	ECDSAWithP521AndSHA512,
+}
+
+// defaultCipherSuitesFIPS are the FIPS-allowed cipher suites.
+var defaultCipherSuitesFIPS = []uint16{
+	TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+	TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+	TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+	TLS_RSA_WITH_AES_128_GCM_SHA256,
+	TLS_RSA_WITH_AES_256_GCM_SHA384,
+}
+
+// defaultCipherSuitesTLS13FIPS are the FIPS-allowed cipher suites for TLS 1.3.
+var defaultCipherSuitesTLS13FIPS = []uint16{
+	TLS_AES_128_GCM_SHA256,
+	TLS_AES_256_GCM_SHA384,
+}

+ 284 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/ech.go

@@ -0,0 +1,284 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+	"errors"
+	"strings"
+
+	"github.com/Psiphon-Labs/psiphon-tls/internal/hpke"
+
+	"golang.org/x/crypto/cryptobyte"
+)
+
+type echCipher struct {
+	KDFID  uint16
+	AEADID uint16
+}
+
+type echExtension struct {
+	Type uint16
+	Data []byte
+}
+
+type echConfig struct {
+	raw []byte
+
+	Version uint16
+	Length  uint16
+
+	ConfigID             uint8
+	KemID                uint16
+	PublicKey            []byte
+	SymmetricCipherSuite []echCipher
+
+	MaxNameLength uint8
+	PublicName    []byte
+	Extensions    []echExtension
+}
+
+var errMalformedECHConfig = errors.New("tls: malformed ECHConfigList")
+
+// parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
+// slice of parsed ECHConfigs, in the same order they were parsed, or an error
+// if the list is malformed.
+func parseECHConfigList(data []byte) ([]echConfig, error) {
+	s := cryptobyte.String(data)
+	// Skip the length prefix
+	var length uint16
+	if !s.ReadUint16(&length) {
+		return nil, errMalformedECHConfig
+	}
+	if length != uint16(len(data)-2) {
+		return nil, errMalformedECHConfig
+	}
+	var configs []echConfig
+	for len(s) > 0 {
+		var ec echConfig
+		ec.raw = []byte(s)
+		if !s.ReadUint16(&ec.Version) {
+			return nil, errMalformedECHConfig
+		}
+		if !s.ReadUint16(&ec.Length) {
+			return nil, errMalformedECHConfig
+		}
+		if len(ec.raw) < int(ec.Length)+4 {
+			return nil, errMalformedECHConfig
+		}
+		ec.raw = ec.raw[:ec.Length+4]
+		if ec.Version != extensionEncryptedClientHello {
+			s.Skip(int(ec.Length))
+			continue
+		}
+		if !s.ReadUint8(&ec.ConfigID) {
+			return nil, errMalformedECHConfig
+		}
+		if !s.ReadUint16(&ec.KemID) {
+			return nil, errMalformedECHConfig
+		}
+		if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
+			return nil, errMalformedECHConfig
+		}
+		var cipherSuites cryptobyte.String
+		if !s.ReadUint16LengthPrefixed(&cipherSuites) {
+			return nil, errMalformedECHConfig
+		}
+		for !cipherSuites.Empty() {
+			var c echCipher
+			if !cipherSuites.ReadUint16(&c.KDFID) {
+				return nil, errMalformedECHConfig
+			}
+			if !cipherSuites.ReadUint16(&c.AEADID) {
+				return nil, errMalformedECHConfig
+			}
+			ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
+		}
+		if !s.ReadUint8(&ec.MaxNameLength) {
+			return nil, errMalformedECHConfig
+		}
+		var publicName cryptobyte.String
+		if !s.ReadUint8LengthPrefixed(&publicName) {
+			return nil, errMalformedECHConfig
+		}
+		ec.PublicName = publicName
+		var extensions cryptobyte.String
+		if !s.ReadUint16LengthPrefixed(&extensions) {
+			return nil, errMalformedECHConfig
+		}
+		for !extensions.Empty() {
+			var e echExtension
+			if !extensions.ReadUint16(&e.Type) {
+				return nil, errMalformedECHConfig
+			}
+			if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
+				return nil, errMalformedECHConfig
+			}
+			ec.Extensions = append(ec.Extensions, e)
+		}
+
+		configs = append(configs, ec)
+	}
+	return configs, nil
+}
+
+func pickECHConfig(list []echConfig) *echConfig {
+	for _, ec := range list {
+		if _, ok := hpke.SupportedKEMs[ec.KemID]; !ok {
+			continue
+		}
+		var validSCS bool
+		for _, cs := range ec.SymmetricCipherSuite {
+			if _, ok := hpke.SupportedAEADs[cs.AEADID]; !ok {
+				continue
+			}
+			if _, ok := hpke.SupportedKDFs[cs.KDFID]; !ok {
+				continue
+			}
+			validSCS = true
+			break
+		}
+		if !validSCS {
+			continue
+		}
+		if !validDNSName(string(ec.PublicName)) {
+			continue
+		}
+		var unsupportedExt bool
+		for _, ext := range ec.Extensions {
+			// If high order bit is set to 1 the extension is mandatory.
+			// Since we don't support any extensions, if we see a mandatory
+			// bit, we skip the config.
+			if ext.Type&uint16(1<<15) != 0 {
+				unsupportedExt = true
+			}
+		}
+		if unsupportedExt {
+			continue
+		}
+		return &ec
+	}
+	return nil
+}
+
+func pickECHCipherSuite(suites []echCipher) (echCipher, error) {
+	for _, s := range suites {
+		// NOTE: all of the supported AEADs and KDFs are fine, rather than
+		// imposing some sort of preference here, we just pick the first valid
+		// suite.
+		if _, ok := hpke.SupportedAEADs[s.AEADID]; !ok {
+			continue
+		}
+		if _, ok := hpke.SupportedKDFs[s.KDFID]; !ok {
+			continue
+		}
+		return s, nil
+	}
+	return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
+}
+
+func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
+	h, err := inner.marshalMsg(true)
+	if err != nil {
+		return nil, err
+	}
+	h = h[4:] // strip four byte prefix
+
+	var paddingLen int
+	if inner.serverName != "" {
+		paddingLen = max(0, maxNameLength-len(inner.serverName))
+	} else {
+		paddingLen = maxNameLength + 9
+	}
+	paddingLen = 31 - ((len(h) + paddingLen - 1) % 32)
+
+	return append(h, make([]byte, paddingLen)...), nil
+}
+
+func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payload []byte) ([]byte, error) {
+	var b cryptobyte.Builder
+	b.AddUint8(0) // outer
+	b.AddUint16(kdfID)
+	b.AddUint16(aeadID)
+	b.AddUint8(id)
+	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(encodedKey) })
+	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(payload) })
+	return b.Bytes()
+}
+
+func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echContext, useKey bool) error {
+	var encapKey []byte
+	if useKey {
+		encapKey = ech.encapsulatedKey
+	}
+	encodedInner, err := encodeInnerClientHello(inner, int(ech.config.MaxNameLength))
+	if err != nil {
+		return err
+	}
+	// NOTE: the tag lengths for all of the supported AEADs are the same (16
+	// bytes), so we have hardcoded it here. If we add support for another AEAD
+	// with a different tag length, we will need to change this.
+	encryptedLen := len(encodedInner) + 16 // AEAD tag length
+	outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen))
+	if err != nil {
+		return err
+	}
+	serializedOuter, err := outer.marshal()
+	if err != nil {
+		return err
+	}
+	serializedOuter = serializedOuter[4:] // strip the four byte prefix
+	encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner)
+	if err != nil {
+		return err
+	}
+	outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// validDNSName is a rather rudimentary check for the validity of a DNS name.
+// This is used to check if the public_name in a ECHConfig is valid when we are
+// picking a config. This can be somewhat lax because even if we pick a
+// valid-looking name, the DNS layer will later reject it anyway.
+func validDNSName(name string) bool {
+	if len(name) > 253 {
+		return false
+	}
+	labels := strings.Split(name, ".")
+	if len(labels) <= 1 {
+		return false
+	}
+	for _, l := range labels {
+		labelLen := len(l)
+		if labelLen == 0 {
+			return false
+		}
+		for i, r := range l {
+			if r == '-' && (i == 0 || i == labelLen-1) {
+				return false
+			}
+			if (r < '0' || r > '9') && (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && r != '-' {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+// ECHRejectionError is the error type returned when ECH is rejected by a remote
+// server. If the server offered a ECHConfigList to use for retries, the
+// RetryConfigList field will contain this list.
+//
+// The client may treat an ECHRejectionError with an empty set of RetryConfigs
+// as a secure signal from the server.
+type ECHRejectionError struct {
+	RetryConfigList []byte
+}
+
+func (e *ECHRejectionError) Error() string {
+	return "tls: server rejected ECH"
+}

+ 248 - 74
vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_client.go

@@ -8,7 +8,6 @@ import (
 	"bytes"
 	"context"
 	"crypto"
-	"crypto/ecdh"
 	"crypto/ecdsa"
 	"crypto/ed25519"
 	"crypto/rsa"
@@ -17,10 +16,20 @@ import (
 	"errors"
 	"fmt"
 	"hash"
+
+	// [Psiphon]
+	// "internal/godebug"
+
 	"io"
 	"net"
 	"strings"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tls/internal/hpke"
+	"github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+
+	"github.com/Psiphon-Labs/psiphon-tls/byteorder"
 )
 
 type clientHandshakeState struct {
@@ -37,46 +46,39 @@ type clientHandshakeState struct {
 
 var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme
 
-func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
+func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echContext, error) {
 	config := c.config
 	if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
-		return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
+		return nil, nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
 	}
 
 	nextProtosLength := 0
 	for _, proto := range config.NextProtos {
 		if l := len(proto); l == 0 || l > 255 {
-			return nil, nil, errors.New("tls: invalid NextProtos value")
+			return nil, nil, nil, errors.New("tls: invalid NextProtos value")
 		} else {
 			nextProtosLength += 1 + l
 		}
 	}
 	if nextProtosLength > 0xffff {
-		return nil, nil, errors.New("tls: NextProtos values too large")
+		return nil, nil, nil, errors.New("tls: NextProtos values too large")
 	}
 
 	supportedVersions := config.supportedVersions(roleClient)
 	if len(supportedVersions) == 0 {
-		return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
-	}
-
-	clientHelloVersion := config.maxSupportedVersion(roleClient)
-	// The version at the beginning of the ClientHello was capped at TLS 1.2
-	// for compatibility reasons. The supported_versions extension is used
-	// to negotiate versions now. See RFC 8446, Section 4.2.1.
-	if clientHelloVersion > VersionTLS12 {
-		clientHelloVersion = VersionTLS12
+		return nil, nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
 	}
+	maxVersion := config.maxSupportedVersion(roleClient)
 
 	hello := &clientHelloMsg{
-		vers:                         clientHelloVersion,
+		vers:                         maxVersion,
 		compressionMethods:           []uint8{compressionNone},
 		random:                       make([]byte, 32),
 		extendedMasterSecret:         true,
 		ocspStapling:                 true,
 		scts:                         true,
 		serverName:                   hostnameInSNI(config.ServerName),
-		supportedCurves:              config.curvePreferences(),
+		supportedCurves:              config.curvePreferences(maxVersion),
 		supportedPoints:              []uint8{pointFormatUncompressed},
 		secureRenegotiationSupported: true,
 		alpnProtocols:                config.NextProtos,
@@ -85,19 +87,36 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
 
 	// [Psiphon]
 	if c.config != nil {
-		hello.PRNG = c.config.ClientHelloPRNG
+
+		if c.config.ClientHelloPRNG != nil {
+			// Generate a ClientHello PRNG seed for reproducibility
+			// of the ClientHello.
+			hello.marshalerPRNGSeed = new(prng.Seed)
+			_, err := c.config.ClientHelloPRNG.Read(hello.marshalerPRNGSeed[:])
+			if err != nil {
+				return nil, nil, nil, errors.New("tls: short read from ClientHelloPRNG: " + err.Error())
+			}
+		}
+
 		if c.config.GetClientHelloRandom != nil {
 			helloRandom, err := c.config.GetClientHelloRandom()
 			if err == nil && len(helloRandom) != 32 {
 				err = errors.New("invalid length")
 			}
 			if err != nil {
-				return nil, nil, errors.New("tls: GetClientHelloRandom failed: " + err.Error())
+				return nil, nil, nil, errors.New("tls: GetClientHelloRandom failed: " + err.Error())
 			}
 			copy(hello.random, helloRandom)
 		}
 	}
 
+	// The version at the beginning of the ClientHello was capped at TLS 1.2
+	// for compatibility reasons. The supported_versions extension is used
+	// to negotiate versions now. See RFC 8446, Section 4.2.1.
+	if hello.vers > VersionTLS12 {
+		hello.vers = VersionTLS12
+	}
+
 	if c.handshakes > 0 {
 		hello.secureRenegotiation = c.clientFinished[:]
 	}
@@ -116,7 +135,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
 		}
 		// Don't advertise TLS 1.2-only cipher suites unless
 		// we're attempting TLS 1.2.
-		if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
+		if maxVersion < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
 			continue
 		}
 		hello.cipherSuites = append(hello.cipherSuites, suiteId)
@@ -128,7 +147,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
 
 		_, err := io.ReadFull(config.rand(), hello.random)
 		if err != nil {
-			return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+			return nil, nil, nil, errors.New("tls: short read from Rand: " + err.Error())
 		}
 	}
 
@@ -140,46 +159,69 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
 	if c.quic == nil {
 		hello.sessionId = make([]byte, 32)
 		if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
-			return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
+			return nil, nil, nil, errors.New("tls: short read from Rand: " + err.Error())
 		}
 	}
 
-	if hello.vers >= VersionTLS12 {
+	if maxVersion >= VersionTLS12 {
 		hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
 	}
 	if testingOnlyForceClientHelloSignatureAlgorithms != nil {
 		hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
 	}
 
-	var key *ecdh.PrivateKey
+	var keyShareKeys *keySharePrivateKeys
 	if hello.supportedVersions[0] == VersionTLS13 {
 		// Reset the list of ciphers when the client only supports TLS 1.3.
 		if len(hello.supportedVersions) == 1 {
 			hello.cipherSuites = nil
 		}
-		if needFIPS() {
-			hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13FIPS...)
-		} else if hasAESGCMHardwareSupport {
+		if hasAESGCMHardwareSupport {
 			hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
 		} else {
 			hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
 		}
 
-		curveID := config.curvePreferences()[0]
-		if _, ok := curveForCurveID(curveID); !ok {
-			return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
-		}
-		key, err = generateECDHEKey(config.rand(), curveID)
-		if err != nil {
-			return nil, nil, err
+		curveID := config.curvePreferences(maxVersion)[0]
+		keyShareKeys = &keySharePrivateKeys{curveID: curveID}
+		if curveID == x25519Kyber768Draft00 {
+			keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), X25519)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+			seed := make([]byte, mlkem768.SeedSize)
+			if _, err := io.ReadFull(config.rand(), seed); err != nil {
+				return nil, nil, nil, err
+			}
+			keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+			// For draft-tls-westerbaan-xyber768d00-03, we send both a hybrid
+			// and a standard X25519 key share, since most servers will only
+			// support the latter. We reuse the same X25519 ephemeral key for
+			// both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
+			hello.keyShares = []keyShare{
+				{group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(),
+					keyShareKeys.kyber.EncapsulationKey()...)},
+				{group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()},
+			}
+		} else {
+			if _, ok := curveForCurveID(curveID); !ok {
+				return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
+			}
+			keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), curveID)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+			hello.keyShares = []keyShare{{group: curveID, data: keyShareKeys.ecdhe.PublicKey().Bytes()}}
 		}
-		hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
 	}
 
 	if c.quic != nil {
 		p, err := c.quicGetTransportParameters()
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 		if p == nil {
 			p = []byte{}
@@ -187,7 +229,60 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
 		hello.quicTransportParameters = p
 	}
 
-	return hello, key, nil
+	var ech *echContext
+	if c.config.EncryptedClientHelloConfigList != nil {
+		if c.config.MinVersion != 0 && c.config.MinVersion < VersionTLS13 {
+			return nil, nil, nil, errors.New("tls: MinVersion must be >= VersionTLS13 if EncryptedClientHelloConfigList is populated")
+		}
+		if c.config.MaxVersion != 0 && c.config.MaxVersion <= VersionTLS12 {
+			return nil, nil, nil, errors.New("tls: MaxVersion must be >= VersionTLS13 if EncryptedClientHelloConfigList is populated")
+		}
+		echConfigs, err := parseECHConfigList(c.config.EncryptedClientHelloConfigList)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		echConfig := pickECHConfig(echConfigs)
+		if echConfig == nil {
+			return nil, nil, nil, errors.New("tls: EncryptedClientHelloConfigList contains no valid configs")
+		}
+		ech = &echContext{config: echConfig}
+		hello.encryptedClientHello = []byte{1} // indicate inner hello
+		// We need to explicitly set these 1.2 fields to nil, as we do not
+		// marshal them when encoding the inner hello, otherwise transcripts
+		// will later mismatch.
+		hello.supportedPoints = nil
+		hello.ticketSupported = false
+		hello.secureRenegotiationSupported = false
+		hello.extendedMasterSecret = false
+
+		echPK, err := hpke.ParseHPKEPublicKey(ech.config.KemID, ech.config.PublicKey)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		suite, err := pickECHCipherSuite(ech.config.SymmetricCipherSuite)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		ech.kdfID, ech.aeadID = suite.KDFID, suite.AEADID
+		info := append([]byte("tls ech\x00"), ech.config.raw...)
+		ech.encapsulatedKey, ech.hpkeContext, err = hpke.SetupSender(ech.config.KemID, suite.KDFID, suite.AEADID, echPK, info)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+	}
+
+	return hello, keyShareKeys, ech, nil
+}
+
+type echContext struct {
+	config          *echConfig
+	hpkeContext     *hpke.Sender
+	encapsulatedKey []byte
+	innerHello      *clientHelloMsg
+	innerTranscript hash.Hash
+	kdfID           uint16
+	aeadID          uint16
+	echRejected     bool
 }
 
 func (c *Conn) clientHandshake(ctx context.Context) (err error) {
@@ -199,11 +294,10 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
 	// need to be reset.
 	c.didResume = false
 
-	hello, ecdheKey, err := c.makeClientHello()
+	hello, keyShareKeys, ech, err := c.makeClientHello()
 	if err != nil {
 		return err
 	}
-	c.serverName = hello.serverName
 
 	session, earlySecret, binderKey, err := c.loadSession(hello)
 	if err != nil {
@@ -225,6 +319,31 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
 		}()
 	}
 
+	if ech != nil {
+		// Split hello into inner and outer
+		ech.innerHello = hello.clone()
+
+		// Overwrite the server name in the outer hello with the public facing
+		// name.
+		hello.serverName = string(ech.config.PublicName)
+		// Generate a new random for the outer hello.
+		hello.random = make([]byte, 32)
+		_, err = io.ReadFull(c.config.rand(), hello.random)
+		if err != nil {
+			return errors.New("tls: short read from Rand: " + err.Error())
+		}
+
+		// NOTE: we don't do PSK GREASE, in line with boringssl, it's meant to
+		// work around _possibly_ broken middleboxes, but there is little-to-no
+		// evidence that this is actually a problem.
+
+		if err := computeAndUpdateOuterECHExtension(hello, ech.innerHello, ech, true); err != nil {
+			return err
+		}
+	}
+
+	c.serverName = hello.serverName
+
 	if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
 		return err
 	}
@@ -274,17 +393,16 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
 
 	if c.vers == VersionTLS13 {
 		hs := &clientHandshakeStateTLS13{
-			c:           c,
-			ctx:         ctx,
-			serverHello: serverHello,
-			hello:       hello,
-			ecdheKey:    ecdheKey,
-			session:     session,
-			earlySecret: earlySecret,
-			binderKey:   binderKey,
+			c:            c,
+			ctx:          ctx,
+			serverHello:  serverHello,
+			hello:        hello,
+			keyShareKeys: keyShareKeys,
+			session:      session,
+			earlySecret:  earlySecret,
+			binderKey:    binderKey,
+			echContext:   ech,
 		}
-
-		// In TLS 1.3, session tickets are delivered after the handshake.
 		return hs.handshake()
 	}
 
@@ -295,12 +413,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
 		hello:       hello,
 		session:     session,
 	}
-
-	if err := hs.handshake(); err != nil {
-		return err
-	}
-
-	return nil
+	return hs.handshake()
 }
 
 func (c *Conn) loadSession(hello *clientHelloMsg) (
@@ -309,7 +422,11 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
 		return nil, nil, nil, nil
 	}
 
-	hello.ticketSupported = true
+	echInner := bytes.Equal(hello.encryptedClientHello, []byte{1})
+
+	// ticketSupported is a TLS 1.2 extension (as TLS 1.3 replaced tickets with PSK
+	// identities) and ECH requires and forces TLS 1.3.
+	hello.ticketSupported = true && !echInner
 
 	if hello.supportedVersions[0] == VersionTLS13 {
 		// Require DHE on resumption as it guarantees forward secrecy against
@@ -375,7 +492,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
 			return nil, nil, nil, nil
 		}
 
-		hello.sessionTicket = cs.ticket
+		hello.sessionTicket = session.ticket
 		return
 	}
 
@@ -403,10 +520,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
 		return nil, nil, nil, nil
 	}
 
-	if c.quic != nil && session.EarlyData {
+	if c.quic != nil {
+		if c.quic.enableSessionEvents {
+			c.quicResumeSession(session)
+		}
+
 		// For 0-RTT, the cipher suite has to match exactly, and we need to be
 		// offering the same ALPN.
-		if mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil {
+		if session.EarlyData && mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil {
 			for _, alpn := range hello.alpnProtocols {
 				if alpn == session.alpnProtocol {
 					hello.earlyData = true
@@ -419,7 +540,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
 	// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
 	ticketAge := c.config.time().Sub(time.Unix(int64(session.createdAt), 0))
 	identity := pskIdentity{
-		label:               cs.ticket,
+		label:               session.ticket,
 		obfuscatedTicketAge: uint32(ticketAge/time.Millisecond) + session.ageAdd,
 	}
 	hello.pskIdentities = []pskIdentity{identity}
@@ -429,13 +550,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (
 	earlySecret = cipherSuite.extract(session.secret, nil)
 	binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
 	transcript := cipherSuite.hash.New()
-	helloBytes, err := hello.marshalWithoutBinders()
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	transcript.Write(helloBytes)
-	pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)}
-	if err := hello.updateBinders(pskBinders); err != nil {
+	if err := computeAndUpdatePSK(hello, binderKey, transcript, cipherSuite.finishedHash); err != nil {
 		return nil, nil, nil, err
 	}
 
@@ -554,6 +669,17 @@ func (hs *clientHandshakeState) pickCipherSuite() error {
 		return errors.New("tls: server chose an unconfigured cipher suite")
 	}
 
+	// [Psiphon] BEGIN
+	// if hs.c.config.CipherSuites == nil && !needFIPS() && rsaKexCiphers[hs.suite.id] {
+	// 	tlsrsakex.Value() // ensure godebug is initialized
+	// 	tlsrsakex.IncNonDefault()
+	// }
+	// if hs.c.config.CipherSuites == nil && !needFIPS() && tdesCiphers[hs.suite.id] {
+	// 	tls3des.Value() // ensure godebug is initialized
+	// 	tls3des.IncNonDefault()
+	// }
+	// [Psiphon] END
+
 	hs.c.cipherSuite = hs.suite.id
 	return nil
 }
@@ -626,6 +752,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
 			c.sendAlert(alertUnexpectedMessage)
 			return err
 		}
+		if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ {
+			c.curveID = CurveID(byteorder.BeUint16(skx.key[1:]))
+		}
 
 		msg, err = c.readHandshake(&hs.finishedHash)
 		if err != nil {
@@ -939,13 +1068,11 @@ func (hs *clientHandshakeState) saveSessionTicket() error {
 		return nil
 	}
 
-	session, err := c.sessionState()
-	if err != nil {
-		return err
-	}
+	session := c.sessionState()
 	session.secret = hs.masterSecret
+	session.ticket = hs.ticket
 
-	cs := &ClientSessionState{ticket: hs.ticket, session: session}
+	cs := &ClientSessionState{session: session}
 	c.config.ClientSessionCache.Put(cacheKey, cs)
 	return nil
 }
@@ -970,7 +1097,19 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
 // to verify the signatures of during a TLS handshake.
 const defaultMaxRSAKeySize = 8192
 
+// [Psiphon]
+// var tlsmaxrsasize = godebug.New("tlsmaxrsasize")
+
 func checkKeySize(n int) (max int, ok bool) {
+	// [Psiphon]
+	// if v := tlsmaxrsasize.Value(); v != "" {
+	// 	if max, err := strconv.Atoi(v); err == nil {
+	// 		if (n <= max) != (n <= defaultMaxRSAKeySize) {
+	// 			tlsmaxrsasize.IncNonDefault()
+	// 		}
+	// 		return max, n <= max
+	// 	}
+	// }
 	return defaultMaxRSAKeySize, n <= defaultMaxRSAKeySize
 }
 
@@ -996,7 +1135,32 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
 		certs[i] = cert.cert
 	}
 
-	if !c.config.InsecureSkipVerify {
+	echRejected := c.config.EncryptedClientHelloConfigList != nil && !c.echAccepted
+	if echRejected {
+		if c.config.EncryptedClientHelloRejectionVerify != nil {
+			if err := c.config.EncryptedClientHelloRejectionVerify(c.connectionStateLocked()); err != nil {
+				c.sendAlert(alertBadCertificate)
+				return err
+			}
+		} else {
+			opts := x509.VerifyOptions{
+				Roots:         c.config.RootCAs,
+				CurrentTime:   c.config.time(),
+				DNSName:       c.serverName,
+				Intermediates: x509.NewCertPool(),
+			}
+
+			for _, cert := range certs[1:] {
+				opts.Intermediates.AddCert(cert)
+			}
+			var err error
+			c.verifiedChains, err = certs[0].Verify(opts)
+			if err != nil {
+				c.sendAlert(alertBadCertificate)
+				return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
+			}
+		}
+	} else if !c.config.InsecureSkipVerify {
 		opts := x509.VerifyOptions{
 			Roots:         c.config.RootCAs,
 			CurrentTime:   c.config.time(),
@@ -1026,14 +1190,14 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
 	c.activeCertHandles = activeHandles
 	c.peerCertificates = certs
 
-	if c.config.VerifyPeerCertificate != nil {
+	if c.config.VerifyPeerCertificate != nil && !echRejected {
 		if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
 			c.sendAlert(alertBadCertificate)
 			return err
 		}
 	}
 
-	if c.config.VerifyConnection != nil {
+	if c.config.VerifyConnection != nil && !echRejected {
 		if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
 			c.sendAlert(alertBadCertificate)
 			return err
@@ -1156,3 +1320,13 @@ func hostnameInSNI(name string) string {
 	}
 	return name
 }
+
+func computeAndUpdatePSK(m *clientHelloMsg, binderKey []byte, transcript hash.Hash, finishedHash func([]byte, hash.Hash) []byte) error {
+	helloBytes, err := m.marshalWithoutBinders()
+	if err != nil {
+		return err
+	}
+	transcript.Write(helloBytes)
+	pskBinders := [][]byte{finishedHash(binderKey, transcript)}
+	return m.updateBinders(pskBinders)
+}

+ 190 - 44
vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_client_tls13.go

@@ -8,20 +8,23 @@ import (
 	"bytes"
 	"context"
 	"crypto"
-	"crypto/ecdh"
 	"crypto/hmac"
 	"crypto/rsa"
+	"crypto/subtle"
 	"errors"
 	"hash"
+	"slices"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768"
 )
 
 type clientHandshakeStateTLS13 struct {
-	c           *Conn
-	ctx         context.Context
-	serverHello *serverHelloMsg
-	hello       *clientHelloMsg
-	ecdheKey    *ecdh.PrivateKey
+	c            *Conn
+	ctx          context.Context
+	serverHello  *serverHelloMsg
+	hello        *clientHelloMsg
+	keyShareKeys *keySharePrivateKeys
 
 	session     *SessionState
 	earlySecret []byte
@@ -34,13 +37,19 @@ type clientHandshakeStateTLS13 struct {
 	transcript    hash.Hash
 	masterSecret  []byte
 	trafficSecret []byte // client_application_traffic_secret_0
+
+	echContext *echContext
 }
 
-// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and,
+// handshake requires hs.c, hs.hello, hs.serverHello, hs.keyShareKeys, and,
 // optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
 func (hs *clientHandshakeStateTLS13) handshake() error {
 	c := hs.c
 
+	if needFIPS() {
+		return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
+	}
+
 	// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
 	// sections 4.1.2 and 4.1.3.
 	if c.handshakes > 0 {
@@ -49,7 +58,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
 	}
 
 	// Consistency check on the presence of a keyShare and its parameters.
-	if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 {
+	if hs.keyShareKeys == nil || hs.keyShareKeys.ecdhe == nil || len(hs.hello.keyShares) == 0 {
 		return c.sendAlert(alertInternalError)
 	}
 
@@ -63,6 +72,13 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
 		return err
 	}
 
+	if hs.echContext != nil {
+		hs.echContext.innerTranscript = hs.suite.hash.New()
+		if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil {
+			return err
+		}
+	}
+
 	if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
 		if err := hs.sendDummyChangeCipherSpec(); err != nil {
 			return err
@@ -72,6 +88,41 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
 		}
 	}
 
+	var echRetryConfigList []byte
+	if hs.echContext != nil {
+		confTranscript := cloneHash(hs.echContext.innerTranscript, hs.suite.hash)
+		confTranscript.Write(hs.serverHello.original[:30])
+		confTranscript.Write(make([]byte, 8))
+		confTranscript.Write(hs.serverHello.original[38:])
+		acceptConfirmation := hs.suite.expandLabel(
+			hs.suite.extract(hs.echContext.innerHello.random, nil),
+			"ech accept confirmation",
+			confTranscript.Sum(nil),
+			8,
+		)
+		if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.random[len(hs.serverHello.random)-8:]) == 1 {
+			hs.hello = hs.echContext.innerHello
+			c.serverName = c.config.ServerName
+			hs.transcript = hs.echContext.innerTranscript
+			c.echAccepted = true
+
+			if hs.serverHello.encryptedClientHello != nil {
+				c.sendAlert(alertUnsupportedExtension)
+				return errors.New("tls: unexpected encrypted_client_hello extension in server hello despite ECH being accepted")
+			}
+
+			if hs.hello.serverName == "" && hs.serverHello.serverNameAck {
+				c.sendAlert(alertUnsupportedExtension)
+				return errors.New("tls: unexpected server_name extension in server hello")
+			}
+		} else {
+			hs.echContext.echRejected = true
+			// If the server sent us retry configs, we'll return these to
+			// the user so they can update their Config.
+			echRetryConfigList = hs.serverHello.encryptedClientHello
+		}
+	}
+
 	if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
 		return err
 	}
@@ -105,6 +156,11 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
 		return err
 	}
 
+	if hs.echContext != nil && hs.echContext.echRejected {
+		c.sendAlert(alertECHRequired)
+		return &ECHRejectionError{echRetryConfigList}
+	}
+
 	c.isHandshakeComplete.Store(true)
 
 	return nil
@@ -196,6 +252,48 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 		return err
 	}
 
+	var isInnerHello bool
+	hello := hs.hello
+	if hs.echContext != nil {
+		chHash = hs.echContext.innerTranscript.Sum(nil)
+		hs.echContext.innerTranscript.Reset()
+		hs.echContext.innerTranscript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
+		hs.echContext.innerTranscript.Write(chHash)
+
+		if hs.serverHello.encryptedClientHello != nil {
+			if len(hs.serverHello.encryptedClientHello) != 8 {
+				hs.c.sendAlert(alertDecodeError)
+				return errors.New("tls: malformed encrypted client hello extension")
+			}
+
+			confTranscript := cloneHash(hs.echContext.innerTranscript, hs.suite.hash)
+			hrrHello := make([]byte, len(hs.serverHello.original))
+			copy(hrrHello, hs.serverHello.original)
+			hrrHello = bytes.Replace(hrrHello, hs.serverHello.encryptedClientHello, make([]byte, 8), 1)
+			confTranscript.Write(hrrHello)
+			acceptConfirmation := hs.suite.expandLabel(
+				hs.suite.extract(hs.echContext.innerHello.random, nil),
+				"hrr ech accept confirmation",
+				confTranscript.Sum(nil),
+				8,
+			)
+			if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.encryptedClientHello) == 1 {
+				hello = hs.echContext.innerHello
+				c.serverName = c.config.ServerName
+				isInnerHello = true
+				c.echAccepted = true
+			}
+		}
+
+		if err := transcriptMsg(hs.serverHello, hs.echContext.innerTranscript); err != nil {
+			return err
+		}
+	} else if hs.serverHello.encryptedClientHello != nil {
+		// Unsolicited ECH extension should be rejected
+		c.sendAlert(alertUnsupportedExtension)
+		return errors.New("tls: unexpected ECH extension in serverHello")
+	}
+
 	// The only HelloRetryRequest extensions we support are key_share and
 	// cookie, and clients must abort the handshake if the HRR would not result
 	// in any change in the ClientHello.
@@ -205,7 +303,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 	}
 
 	if hs.serverHello.cookie != nil {
-		hs.hello.cookie = hs.serverHello.cookie
+		hello.cookie = hs.serverHello.cookie
 	}
 
 	if hs.serverHello.serverShare.group != 0 {
@@ -217,21 +315,22 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 	// a group we advertised but did not send a key share for, and send a key
 	// share for it this time.
 	if curveID := hs.serverHello.selectedGroup; curveID != 0 {
-		curveOK := false
-		for _, id := range hs.hello.supportedCurves {
-			if id == curveID {
-				curveOK = true
-				break
-			}
-		}
-		if !curveOK {
+		if !slices.Contains(hello.supportedCurves, curveID) {
 			c.sendAlert(alertIllegalParameter)
 			return errors.New("tls: server selected unsupported group")
 		}
-		if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID {
+		if slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool {
+			return ks.group == curveID
+		}) {
 			c.sendAlert(alertIllegalParameter)
 			return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
 		}
+		// Note: we don't support selecting X25519Kyber768Draft00 in a HRR,
+		// because we currently only support it at all when CurvePreferences is
+		// empty, which will cause us to also send a key share for it.
+		//
+		// This will have to change once we support selecting hybrid KEMs
+		// without sending key shares for them.
 		if _, ok := curveForCurveID(curveID); !ok {
 			c.sendAlert(alertInternalError)
 			return errors.New("tls: CurvePreferences includes unsupported curve")
@@ -241,12 +340,11 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 			c.sendAlert(alertInternalError)
 			return err
 		}
-		hs.ecdheKey = key
-		hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
+		hs.keyShareKeys = &keySharePrivateKeys{curveID: curveID, ecdhe: key}
+		hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
 	}
 
-	hs.hello.raw = nil
-	if len(hs.hello.pskIdentities) > 0 {
+	if len(hello.pskIdentities) > 0 {
 		pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
 		if pskSuite == nil {
 			return c.sendAlert(alertInternalError)
@@ -254,7 +352,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 		if pskSuite.hash == hs.suite.hash {
 			// Update binders and obfuscated_ticket_age.
 			ticketAge := c.config.time().Sub(time.Unix(int64(hs.session.createdAt), 0))
-			hs.hello.pskIdentities[0].obfuscatedTicketAge = uint32(ticketAge/time.Millisecond) + hs.session.ageAdd
+			hello.pskIdentities[0].obfuscatedTicketAge = uint32(ticketAge/time.Millisecond) + hs.session.ageAdd
 
 			transcript := hs.suite.hash.New()
 			transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
@@ -262,27 +360,40 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 			if err := transcriptMsg(hs.serverHello, transcript); err != nil {
 				return err
 			}
-			helloBytes, err := hs.hello.marshalWithoutBinders()
-			if err != nil {
-				return err
-			}
-			transcript.Write(helloBytes)
-			pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
-			if err := hs.hello.updateBinders(pskBinders); err != nil {
+
+			if err := computeAndUpdatePSK(hello, hs.binderKey, transcript, hs.suite.finishedHash); err != nil {
 				return err
 			}
 		} else {
 			// Server selected a cipher suite incompatible with the PSK.
-			hs.hello.pskIdentities = nil
-			hs.hello.pskBinders = nil
+			hello.pskIdentities = nil
+			hello.pskBinders = nil
 		}
 	}
 
-	if hs.hello.earlyData {
-		hs.hello.earlyData = false
+	if hello.earlyData {
+		hello.earlyData = false
 		c.quicRejectedEarlyData()
 	}
 
+	if isInnerHello {
+		// Any extensions which have changed in hello, but are mirrored in the
+		// outer hello and compressed, need to be copied to the outer hello, so
+		// they can be properly decompressed by the server. For now, the only
+		// extension which may have changed is keyShares.
+		hs.hello.keyShares = hello.keyShares
+		hs.echContext.innerHello = hello
+		if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil {
+			return err
+		}
+
+		if err := computeAndUpdateOuterECHExtension(hs.hello, hs.echContext.innerHello, hs.echContext, false); err != nil {
+			return err
+		}
+	} else {
+		hs.hello = hello
+	}
+
 	if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
 		return err
 	}
@@ -304,6 +415,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
 		return err
 	}
 
+	c.didHRR = true
 	return nil
 }
 
@@ -329,7 +441,9 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: server did not send a key share")
 	}
-	if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID {
+	if !slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool {
+		return ks.group == hs.serverHello.serverShare.group
+	}) {
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: server selected unsupported group")
 	}
@@ -368,16 +482,37 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
 func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
 	c := hs.c
 
-	peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data)
+	ecdhePeerData := hs.serverHello.serverShare.data
+	if hs.serverHello.serverShare.group == x25519Kyber768Draft00 {
+		if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize {
+			c.sendAlert(alertIllegalParameter)
+			return errors.New("tls: invalid server key share")
+		}
+		ecdhePeerData = hs.serverHello.serverShare.data[:x25519PublicKeySize]
+	}
+	peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData)
 	if err != nil {
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: invalid server key share")
 	}
-	sharedKey, err := hs.ecdheKey.ECDH(peerKey)
+	sharedKey, err := hs.keyShareKeys.ecdhe.ECDH(peerKey)
 	if err != nil {
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: invalid server key share")
 	}
+	if hs.serverHello.serverShare.group == x25519Kyber768Draft00 {
+		if hs.keyShareKeys.kyber == nil {
+			return c.sendAlert(alertInternalError)
+		}
+		ciphertext := hs.serverHello.serverShare.data[x25519PublicKeySize:]
+		kyberShared, err := kyberDecapsulate(hs.keyShareKeys.kyber, ciphertext)
+		if err != nil {
+			c.sendAlert(alertIllegalParameter)
+			return errors.New("tls: invalid Kyber server key share")
+		}
+		sharedKey = append(sharedKey, kyberShared...)
+	}
+	c.curveID = hs.serverHello.serverShare.group
 
 	earlySecret := hs.earlySecret
 	if !hs.usingPSK {
@@ -474,6 +609,10 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
 			return errors.New("tls: server accepted 0-RTT with the wrong ALPN")
 		}
 	}
+	if hs.echContext != nil && !hs.echContext.echRejected && encryptedExtensions.echRetryConfigs != nil {
+		c.sendAlert(alertUnsupportedExtension)
+		return errors.New("tls: server sent ECH retry configs after accepting ECH")
+	}
 
 	return nil
 }
@@ -627,6 +766,13 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
 		return nil
 	}
 
+	if hs.echContext != nil && hs.echContext.echRejected {
+		if _, err := hs.c.writeHandshakeRecord(&certificateMsgTLS13{}, hs.transcript); err != nil {
+			return err
+		}
+		return nil
+	}
+
 	cert, err := c.getClientCertificate(&CertificateRequestInfo{
 		AcceptableCAs:    hs.certReq.certificateAuthorities,
 		SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
@@ -749,17 +895,17 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
 	psk := cipherSuite.expandLabel(c.resumptionSecret, "resumption",
 		msg.nonce, cipherSuite.hash.Size())
 
-	session, err := c.sessionState()
-	if err != nil {
-		c.sendAlert(alertInternalError)
-		return err
-	}
+	session := c.sessionState()
 	session.secret = psk
 	session.useBy = uint64(c.config.time().Add(lifetime).Unix())
 	session.ageAdd = msg.ageAdd
 	session.EarlyData = c.quic != nil && msg.maxEarlyData == 0xffffffff // RFC 9001, Section 4.6.1
-	cs := &ClientSessionState{ticket: msg.label, session: session}
-
+	session.ticket = msg.label
+	if c.quic != nil && c.quic.enableSessionEvents {
+		c.quicStoreSession(session)
+		return nil
+	}
+	cs := &ClientSessionState{session: session}
 	if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
 		c.config.ClientSessionCache.Put(cacheKey, cs)
 	}

File diff suppressed because it is too large
+ 368 - 263
vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_messages.go


+ 27 - 19
vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_server.go

@@ -19,6 +19,8 @@ import (
 	"io"
 	"net"
 	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tls/byteorder"
 )
 
 // serverHandshakeState contains details of a server handshake in progress.
@@ -389,6 +391,12 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
 	c.in.version = c.vers
 	c.out.version = c.vers
 
+	// [Psiphon]
+	// if c.config.MinVersion == 0 && c.vers < VersionTLS12 {
+	// 	tls10server.Value() // ensure godebug is initialized
+	// 	tls10server.IncNonDefault()
+	// }
+
 	return clientHello, nil
 }
 
@@ -463,7 +471,7 @@ func (hs *serverHandshakeState) processClientHello() error {
 		hs.hello.scts = hs.cert.SignedCertificateTimestamps
 	}
 
-	hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
+	hs.ecdheOk = supportsECDHE(c.config, c.vers, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
 
 	if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
 		// Although omitting the ec_point_formats extension is permitted, some
@@ -534,10 +542,10 @@ func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, erro
 
 // supportsECDHE returns whether ECDHE key exchanges can be used with this
 // pre-TLS 1.3 client.
-func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool {
+func supportsECDHE(c *Config, version uint16, supportedCurves []CurveID, supportedPoints []uint8) bool {
 	supportsCurve := false
 	for _, curve := range supportedCurves {
-		if c.supportsCurve(curve) {
+		if c.supportsCurve(version, curve) {
 			supportsCurve = true
 			break
 		}
@@ -587,6 +595,17 @@ func (hs *serverHandshakeState) pickCipherSuite() error {
 	}
 	c.cipherSuite = hs.suite.id
 
+	// [Psiphon] BEGIN
+	// if c.config.CipherSuites == nil && !needFIPS() && rsaKexCiphers[hs.suite.id] {
+	// 	tlsrsakex.Value() // ensure godebug is initialized
+	// 	tlsrsakex.IncNonDefault()
+	// }
+	// if c.config.CipherSuites == nil && !needFIPS() && tdesCiphers[hs.suite.id] {
+	// 	tls3des.Value() // ensure godebug is initialized
+	// 	tls3des.IncNonDefault()
+	// }
+	// [Psiphon] END
+
 	for _, id := range hs.clientHello.cipherSuites {
 		if id == TLS_FALLBACK_SCSV {
 			// The client is doing a fallback connection. See RFC 7507.
@@ -704,18 +723,6 @@ func (hs *serverHandshakeState) checkForResumption() error {
 	if !sessionState.extMasterSecret && hs.clientHello.extendedMasterSecret {
 		return nil
 	}
-
-	// [Psiphon]
-	// When using obfuscated session tickets, the client-generated session ticket
-	// state never uses EMS. ClientHellos vary in EMS support. So, in this mode,
-	// skip this check to ensure the obfuscated session tickets are not
-	// rejected.
-	if !c.config.UseObfuscatedSessionTickets {
-		if !sessionState.extMasterSecret && hs.clientHello.extendedMasterSecret {
-			return nil
-		}
-	}
-
 	if sessionState.extMasterSecret && !hs.clientHello.extendedMasterSecret {
 		// Aborting is somewhat harsh, but it's a MUST and it would indicate a
 		// weird downgrade in client capabilities.
@@ -810,6 +817,9 @@ func (hs *serverHandshakeState) doFullHandshake() error {
 		return err
 	}
 	if skx != nil {
+		if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ {
+			c.curveID = CurveID(byteorder.BeUint16(skx.key[1:]))
+		}
 		if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
 			return err
 		}
@@ -1035,10 +1045,7 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
 	c := hs.c
 	m := new(newSessionTicketMsg)
 
-	state, err := c.sessionState()
-	if err != nil {
-		return err
-	}
+	state := c.sessionState()
 	state.secret = hs.masterSecret
 	if hs.sessionState != nil {
 		// If this is re-wrapping an old key, then keep
@@ -1046,6 +1053,7 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
 		state.createdAt = hs.sessionState.createdAt
 	}
 	if c.config.WrapSession != nil {
+		var err error
 		m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
 		if err != nil {
 			return err

+ 91 - 45
vendor/github.com/Psiphon-Labs/psiphon-tls/handshake_server_tls13.go

@@ -10,13 +10,15 @@ import (
 	"crypto"
 	"crypto/hmac"
 	"crypto/rsa"
-	"encoding/binary"
 	"errors"
 	"hash"
 	"io"
+	"slices"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tls/byteorder"
+	"github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768"
 )
 
 // maxClientPSKIdentities is the number of client PSK identities the server will
@@ -47,6 +49,10 @@ type serverHandshakeStateTLS13 struct {
 func (hs *serverHandshakeStateTLS13) handshake() error {
 	c := hs.c
 
+	if needFIPS() {
+		return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
+	}
+
 	// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
 	if err := hs.processClientHello(); err != nil {
 		return err
@@ -161,9 +167,6 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
 	if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
 		preferenceList = defaultCipherSuitesTLS13NoAES
 	}
-	if needFIPS() {
-		preferenceList = defaultCipherSuitesTLS13FIPS
-	}
 	for _, suiteID := range preferenceList {
 		hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
 		if hs.suite != nil {
@@ -178,25 +181,29 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
 	hs.hello.cipherSuite = hs.suite.id
 	hs.transcript = hs.suite.hash.New()
 
-	// Pick the ECDHE group in server preference order, but give priority to
-	// groups with a key share, to avoid a HelloRetryRequest round-trip.
+	// Pick the key exchange method in server preference order, but give
+	// priority to key shares, to avoid a HelloRetryRequest round-trip.
 	var selectedGroup CurveID
 	var clientKeyShare *keyShare
-GroupSelection:
-	for _, preferredGroup := range c.config.curvePreferences() {
-		for _, ks := range hs.clientHello.keyShares {
-			if ks.group == preferredGroup {
-				selectedGroup = ks.group
-				clientKeyShare = &ks
-				break GroupSelection
+	preferredGroups := c.config.curvePreferences(c.vers)
+	for _, preferredGroup := range preferredGroups {
+		ki := slices.IndexFunc(hs.clientHello.keyShares, func(ks keyShare) bool {
+			return ks.group == preferredGroup
+		})
+		if ki != -1 {
+			clientKeyShare = &hs.clientHello.keyShares[ki]
+			selectedGroup = clientKeyShare.group
+			if !slices.Contains(hs.clientHello.supportedCurves, selectedGroup) {
+				c.sendAlert(alertIllegalParameter)
+				return errors.New("tls: client sent key share for group it does not support")
 			}
+			break
 		}
-		if selectedGroup != 0 {
-			continue
-		}
-		for _, group := range hs.clientHello.supportedCurves {
-			if group == preferredGroup {
-				selectedGroup = group
+	}
+	if selectedGroup == 0 {
+		for _, preferredGroup := range preferredGroups {
+			if slices.Contains(hs.clientHello.supportedCurves, preferredGroup) {
+				selectedGroup = preferredGroup
 				break
 			}
 		}
@@ -206,23 +213,35 @@ GroupSelection:
 		return errors.New("tls: no ECDHE curve supported by both client and server")
 	}
 	if clientKeyShare == nil {
-		if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
+		ks, err := hs.doHelloRetryRequest(selectedGroup)
+		if err != nil {
 			return err
 		}
-		clientKeyShare = &hs.clientHello.keyShares[0]
+		clientKeyShare = ks
 	}
+	c.curveID = selectedGroup
 
-	if _, ok := curveForCurveID(selectedGroup); !ok {
+	ecdhGroup := selectedGroup
+	ecdhData := clientKeyShare.data
+	if selectedGroup == x25519Kyber768Draft00 {
+		ecdhGroup = X25519
+		if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize {
+			c.sendAlert(alertIllegalParameter)
+			return errors.New("tls: invalid Kyber client key share")
+		}
+		ecdhData = ecdhData[:x25519PublicKeySize]
+	}
+	if _, ok := curveForCurveID(ecdhGroup); !ok {
 		c.sendAlert(alertInternalError)
 		return errors.New("tls: CurvePreferences includes unsupported curve")
 	}
-	key, err := generateECDHEKey(c.config.rand(), selectedGroup)
+	key, err := generateECDHEKey(c.config.rand(), ecdhGroup)
 	if err != nil {
 		c.sendAlert(alertInternalError)
 		return err
 	}
 	hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
-	peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data)
+	peerKey, err := key.Curve().NewPublicKey(ecdhData)
 	if err != nil {
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: invalid client key share")
@@ -232,6 +251,15 @@ GroupSelection:
 		c.sendAlert(alertIllegalParameter)
 		return errors.New("tls: invalid client key share")
 	}
+	if selectedGroup == x25519Kyber768Draft00 {
+		ciphertext, kyberShared, err := kyberEncapsulate(clientKeyShare.data[x25519PublicKeySize:])
+		if err != nil {
+			c.sendAlert(alertIllegalParameter)
+			return errors.New("tls: invalid Kyber client key share")
+		}
+		hs.sharedKey = append(hs.sharedKey, kyberShared...)
+		hs.hello.serverShare.data = append(hs.hello.serverShare.data, ciphertext...)
+	}
 
 	selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
 	if err != nil {
@@ -241,8 +269,15 @@ GroupSelection:
 	c.clientProtocol = selectedProto
 
 	if c.quic != nil {
+		// RFC 9001 Section 4.2: Clients MUST NOT offer TLS versions older than 1.3.
+		for _, v := range hs.clientHello.supportedVersions {
+			if v < VersionTLS13 {
+				c.sendAlert(alertProtocolVersion)
+				return errors.New("tls: client offered TLS version older than TLS 1.3")
+			}
+		}
+		// RFC 9001 Section 8.2.
 		if hs.clientHello.quicTransportParameters == nil {
-			// RFC 9001 Section 8.2.
 			c.sendAlert(alertMissingExtension)
 			return errors.New("tls: client did not send a quic_transport_parameters extension")
 		}
@@ -344,6 +379,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
 			continue
 		}
 
+		if c.quic != nil && c.quic.enableSessionEvents {
+			if err := c.quicResumeSession(sessionState); err != nil {
+				return err
+			}
+		}
+
 		hs.earlySecret = hs.suite.extract(sessionState.secret, nil)
 		binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
 		// Clone the transcript in case a HelloRetryRequest was recorded.
@@ -468,13 +509,13 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
 	return hs.c.writeChangeCipherRecord()
 }
 
-func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
+func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) (*keyShare, error) {
 	c := hs.c
 
 	// The first ClientHello gets double-hashed into the transcript upon a
 	// HelloRetryRequest. See RFC 8446, Section 4.4.1.
 	if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
-		return err
+		return nil, err
 	}
 	chHash := hs.transcript.Sum(nil)
 	hs.transcript.Reset()
@@ -492,42 +533,49 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
 	}
 
 	if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
-		return err
+		return nil, err
 	}
 
 	if err := hs.sendDummyChangeCipherSpec(); err != nil {
-		return err
+		return nil, err
 	}
 
 	// clientHelloMsg is not included in the transcript.
 	msg, err := c.readHandshake(nil)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	clientHello, ok := msg.(*clientHelloMsg)
 	if !ok {
 		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(clientHello, msg)
+		return nil, unexpectedMessageError(clientHello, msg)
 	}
 
-	if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
+	if len(clientHello.keyShares) != 1 {
 		c.sendAlert(alertIllegalParameter)
-		return errors.New("tls: client sent invalid key share in second ClientHello")
+		return nil, errors.New("tls: client didn't send one key share in second ClientHello")
+	}
+	ks := &clientHello.keyShares[0]
+
+	if ks.group != selectedGroup {
+		c.sendAlert(alertIllegalParameter)
+		return nil, errors.New("tls: client sent unexpected key share in second ClientHello")
 	}
 
 	if clientHello.earlyData {
 		c.sendAlert(alertIllegalParameter)
-		return errors.New("tls: client indicated early data in second ClientHello")
+		return nil, errors.New("tls: client indicated early data in second ClientHello")
 	}
 
 	if illegalClientHelloChange(clientHello, hs.clientHello) {
 		c.sendAlert(alertIllegalParameter)
-		return errors.New("tls: client illegally modified second ClientHello")
+		return nil, errors.New("tls: client illegally modified second ClientHello")
 	}
 
+	c.didHRR = true
 	hs.clientHello = clientHello
-	return nil
+	return ks, nil
 }
 
 // illegalClientHelloChange reports whether the two ClientHello messages are
@@ -816,10 +864,10 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
 	if !hs.shouldSendSessionTickets() {
 		return nil
 	}
-	return c.sendSessionTicket(false)
+	return c.sendSessionTicket(false, nil)
 }
 
-func (c *Conn) sendSessionTicket(earlyData bool) error {
+func (c *Conn) sendSessionTicket(earlyData bool, extra [][]byte) error {
 	suite := cipherSuiteTLS13ByID(c.cipherSuite)
 	if suite == nil {
 		return errors.New("tls: internal error: unknown cipher suite")
@@ -831,13 +879,12 @@ func (c *Conn) sendSessionTicket(earlyData bool) error {
 
 	m := new(newSessionTicketMsgTLS13)
 
-	state, err := c.sessionState()
-	if err != nil {
-		return err
-	}
+	state := c.sessionState()
 	state.secret = psk
 	state.EarlyData = earlyData
+	state.Extra = extra
 	if c.config.WrapSession != nil {
+		var err error
 		m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
 		if err != nil {
 			return err
@@ -867,11 +914,10 @@ func (c *Conn) sendSessionTicket(earlyData bool) error {
 	// The value is not stored anywhere; we never need to check the ticket age
 	// because 0-RTT is not supported.
 	ageAdd := make([]byte, 4)
-	_, err = c.config.rand().Read(ageAdd)
-	if err != nil {
+	if _, err := c.config.rand().Read(ageAdd); err != nil {
 		return err
 	}
-	m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
+	m.ageAdd = byteorder.LeUint32(ageAdd)
 
 	if earlyData {
 		// RFC 9001, Section 4.6.1

+ 21 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/internal/boring/notboring.go

@@ -0,0 +1,21 @@
+package boring
+
+import (
+	"crypto/cipher"
+	"errors"
+)
+
+const Enabled bool = false
+
+func NewGCMTLS(_ cipher.Block) (cipher.AEAD, error) {
+	return nil, errors.New("boring not implemented")
+}
+
+func NewGCMTLS13(_ cipher.Block) (cipher.AEAD, error) {
+	return nil, errors.New("boring not implemented")
+}
+
+func Unreachable() {
+	// do nothing
+}
+

+ 259 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/internal/hpke/hpke.go

@@ -0,0 +1,259 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package hpke
+
+import (
+	"crypto"
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/ecdh"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"math/bits"
+
+	"golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/crypto/hkdf"
+)
+
+// testingOnlyGenerateKey is only used during testing, to provide
+// a fixed test key to use when checking the RFC 9180 vectors.
+var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
+
+type hkdfKDF struct {
+	hash crypto.Hash
+}
+
+func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
+	labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
+	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
+	labeledIKM = append(labeledIKM, suiteID...)
+	labeledIKM = append(labeledIKM, label...)
+	labeledIKM = append(labeledIKM, inputKey...)
+	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
+}
+
+func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
+	labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
+	labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length)
+	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
+	labeledInfo = append(labeledInfo, suiteID...)
+	labeledInfo = append(labeledInfo, label...)
+	labeledInfo = append(labeledInfo, info...)
+	out := make([]byte, length)
+	n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out)
+	if err != nil || n != int(length) {
+		panic("hpke: LabeledExpand failed unexpectedly")
+	}
+	return out
+}
+
+// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
+type dhKEM struct {
+	dh  ecdh.Curve
+	kdf hkdfKDF
+
+	suiteID []byte
+	nSecret uint16
+}
+
+var SupportedKEMs = map[uint16]struct {
+	curve   ecdh.Curve
+	hash    crypto.Hash
+	nSecret uint16
+}{
+	// RFC 9180 Section 7.1
+	0x0020: {ecdh.X25519(), crypto.SHA256, 32},
+}
+
+func newDHKem(kemID uint16) (*dhKEM, error) {
+	suite, ok := SupportedKEMs[kemID]
+	if !ok {
+		return nil, errors.New("unsupported suite ID")
+	}
+	return &dhKEM{
+		dh:      suite.curve,
+		kdf:     hkdfKDF{suite.hash},
+		suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID),
+		nSecret: suite.nSecret,
+	}, nil
+}
+
+func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
+	eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
+	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
+}
+
+func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
+	var privEph *ecdh.PrivateKey
+	if testingOnlyGenerateKey != nil {
+		privEph, err = testingOnlyGenerateKey()
+	} else {
+		privEph, err = dh.dh.GenerateKey(rand.Reader)
+	}
+	if err != nil {
+		return nil, nil, err
+	}
+	dhVal, err := privEph.ECDH(pubRecipient)
+	if err != nil {
+		return nil, nil, err
+	}
+	encPubEph := privEph.PublicKey().Bytes()
+
+	encPubRecip := pubRecipient.Bytes()
+	kemContext := append(encPubEph, encPubRecip...)
+
+	return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
+}
+
+type Sender struct {
+	aead cipher.AEAD
+	kem  *dhKEM
+
+	sharedSecret []byte
+
+	suiteID []byte
+
+	key            []byte
+	baseNonce      []byte
+	exporterSecret []byte
+
+	seqNum uint128
+}
+
+var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+	return cipher.NewGCM(block)
+}
+
+var SupportedAEADs = map[uint16]struct {
+	keySize   int
+	nonceSize int
+	aead      func([]byte) (cipher.AEAD, error)
+}{
+	// RFC 9180, Section 7.3
+	0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
+	0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
+	0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
+}
+
+var SupportedKDFs = map[uint16]func() *hkdfKDF{
+	// RFC 9180, Section 7.2
+	0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
+}
+
+func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
+	suiteID := SuiteID(kemID, kdfID, aeadID)
+
+	kem, err := newDHKem(kemID)
+	if err != nil {
+		return nil, nil, err
+	}
+	pubRecipient, ok := pub.(*ecdh.PublicKey)
+	if !ok {
+		return nil, nil, errors.New("incorrect public key type")
+	}
+	sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	kdfInit, ok := SupportedKDFs[kdfID]
+	if !ok {
+		return nil, nil, errors.New("unsupported KDF id")
+	}
+	kdf := kdfInit()
+
+	aeadInfo, ok := SupportedAEADs[aeadID]
+	if !ok {
+		return nil, nil, errors.New("unsupported AEAD id")
+	}
+
+	pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
+	infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
+	ksContext := append([]byte{0}, pskIDHash...)
+	ksContext = append(ksContext, infoHash...)
+
+	secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
+
+	key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
+	baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
+	exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
+
+	aead, err := aeadInfo.aead(key)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return encapsulatedKey, &Sender{
+		kem:            kem,
+		aead:           aead,
+		sharedSecret:   sharedSecret,
+		suiteID:        suiteID,
+		key:            key,
+		baseNonce:      baseNonce,
+		exporterSecret: exporterSecret,
+	}, nil
+}
+
+func (s *Sender) nextNonce() []byte {
+	nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
+	for i := range s.baseNonce {
+		nonce[i] ^= s.baseNonce[i]
+	}
+	// Message limit is, according to the RFC, 2^95+1, which
+	// is somewhat confusing, but we do as we're told.
+	if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
+		panic("message limit reached")
+	}
+	s.seqNum = s.seqNum.addOne()
+	return nonce
+}
+
+func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
+
+	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
+	return ciphertext, nil
+}
+
+func SuiteID(kemID, kdfID, aeadID uint16) []byte {
+	suiteID := make([]byte, 0, 4+2+2+2)
+	suiteID = append(suiteID, []byte("HPKE")...)
+	suiteID = binary.BigEndian.AppendUint16(suiteID, kemID)
+	suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID)
+	suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID)
+	return suiteID
+}
+
+func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
+	kemInfo, ok := SupportedKEMs[kemID]
+	if !ok {
+		return nil, errors.New("unsupported KEM id")
+	}
+	return kemInfo.curve.NewPublicKey(bytes)
+}
+
+type uint128 struct {
+	hi, lo uint64
+}
+
+func (u uint128) addOne() uint128 {
+	lo, carry := bits.Add64(u.lo, 1, 0)
+	return uint128{u.hi + carry, lo}
+}
+
+func (u uint128) bitLen() int {
+	return bits.Len64(u.hi) + bits.Len64(u.lo)
+}
+
+func (u uint128) bytes() []byte {
+	b := make([]byte, 16)
+	binary.BigEndian.PutUint64(b[0:], u.hi)
+	binary.BigEndian.PutUint64(b[8:], u.lo)
+	return b
+}

+ 887 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768/mlkem768.go

@@ -0,0 +1,887 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package mlkem768 implements the quantum-resistant key encapsulation method
+// ML-KEM (formerly known as Kyber).
+//
+// Only the recommended ML-KEM-768 parameter set is provided.
+//
+// The version currently implemented is the one specified by [NIST FIPS 203 ipd],
+// with the unintentional transposition of the matrix A reverted to match the
+// behavior of [Kyber version 3.0]. Future versions of this package might
+// introduce backwards incompatible changes to implement changes to FIPS 203.
+//
+// [Kyber version 3.0]: https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
+// [NIST FIPS 203 ipd]: https://doi.org/10.6028/NIST.FIPS.203.ipd
+package mlkem768
+
+// This package targets security, correctness, simplicity, readability, and
+// reviewability as its primary goals. All critical operations are performed in
+// constant time.
+//
+// Variable and function names, as well as code layout, are selected to
+// facilitate reviewing the implementation against the NIST FIPS 203 ipd
+// document.
+//
+// Reviewers unfamiliar with polynomials or linear algebra might find the
+// background at https://words.filippo.io/kyber-math/ useful.
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"errors"
+
+	"github.com/Psiphon-Labs/psiphon-tls/byteorder"
+
+	"golang.org/x/crypto/sha3"
+)
+
+const (
+	// ML-KEM global constants.
+	n = 256
+	q = 3329
+
+	log2q = 12
+
+	// ML-KEM-768 parameters. The code makes assumptions based on these values,
+	// they can't be changed blindly.
+	k  = 3
+	η  = 2
+	du = 10
+	dv = 4
+
+	// encodingSizeX is the byte size of a ringElement or nttElement encoded
+	// by ByteEncode_X (FIPS 203 (DRAFT), Algorithm 4).
+	encodingSize12 = n * log2q / 8
+	encodingSize10 = n * du / 8
+	encodingSize4  = n * dv / 8
+	encodingSize1  = n * 1 / 8
+
+	messageSize       = encodingSize1
+	decryptionKeySize = k * encodingSize12
+	encryptionKeySize = k*encodingSize12 + 32
+
+	CiphertextSize       = k*encodingSize10 + encodingSize4
+	EncapsulationKeySize = encryptionKeySize
+	DecapsulationKeySize = decryptionKeySize + encryptionKeySize + 32 + 32
+	SharedKeySize        = 32
+	SeedSize             = 32 + 32
+)
+
+// A DecapsulationKey is the secret key used to decapsulate a shared key from a
+// ciphertext. It includes various precomputed values.
+type DecapsulationKey struct {
+	dk [DecapsulationKeySize]byte
+	encryptionKey
+	decryptionKey
+}
+
+// Bytes returns the extended encoding of the decapsulation key, according to
+// FIPS 203 (DRAFT).
+func (dk *DecapsulationKey) Bytes() []byte {
+	var b [DecapsulationKeySize]byte
+	copy(b[:], dk.dk[:])
+	return b[:]
+}
+
+// EncapsulationKey returns the public encapsulation key necessary to produce
+// ciphertexts.
+func (dk *DecapsulationKey) EncapsulationKey() []byte {
+	var b [EncapsulationKeySize]byte
+	copy(b[:], dk.dk[decryptionKeySize:])
+	return b[:]
+}
+
+// encryptionKey is the parsed and expanded form of a PKE encryption key.
+type encryptionKey struct {
+	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
+	A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
+}
+
+// decryptionKey is the parsed and expanded form of a PKE decryption key.
+type decryptionKey struct {
+	s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
+}
+
+// GenerateKey generates a new decapsulation key, drawing random bytes from
+// crypto/rand. The decapsulation key must be kept secret.
+func GenerateKey() (*DecapsulationKey, error) {
+	// The actual logic is in a separate function to outline this allocation.
+	dk := &DecapsulationKey{}
+	return generateKey(dk)
+}
+
+func generateKey(dk *DecapsulationKey) (*DecapsulationKey, error) {
+	var d [32]byte
+	if _, err := rand.Read(d[:]); err != nil {
+		return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
+	}
+	var z [32]byte
+	if _, err := rand.Read(z[:]); err != nil {
+		return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
+	}
+	return kemKeyGen(dk, &d, &z), nil
+}
+
+// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte
+// seed in the "d || z" form. The seed must be uniformly random.
+func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) {
+	// The actual logic is in a separate function to outline this allocation.
+	dk := &DecapsulationKey{}
+	return newKeyFromSeed(dk, seed)
+}
+
+func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) {
+	if len(seed) != SeedSize {
+		return nil, errors.New("mlkem768: invalid seed length")
+	}
+	d := (*[32]byte)(seed[:32])
+	z := (*[32]byte)(seed[32:])
+	return kemKeyGen(dk, d, z), nil
+}
+
+// NewKeyFromExtendedEncoding parses a decapsulation key from its FIPS 203
+// (DRAFT) extended encoding.
+func NewKeyFromExtendedEncoding(decapsulationKey []byte) (*DecapsulationKey, error) {
+	// The actual logic is in a separate function to outline this allocation.
+	dk := &DecapsulationKey{}
+	return newKeyFromExtendedEncoding(dk, decapsulationKey)
+}
+
+func newKeyFromExtendedEncoding(dk *DecapsulationKey, dkBytes []byte) (*DecapsulationKey, error) {
+	if len(dkBytes) != DecapsulationKeySize {
+		return nil, errors.New("mlkem768: invalid decapsulation key length")
+	}
+
+	// Note that we don't check that H(ek) matches ekPKE, as that's not
+	// specified in FIPS 203 (DRAFT). This is one reason to prefer the seed
+	// private key format.
+	dk.dk = [DecapsulationKeySize]byte(dkBytes)
+
+	dkPKE := dkBytes[:decryptionKeySize]
+	if err := parseDK(&dk.decryptionKey, dkPKE); err != nil {
+		return nil, err
+	}
+
+	ekPKE := dkBytes[decryptionKeySize : decryptionKeySize+encryptionKeySize]
+	if err := parseEK(&dk.encryptionKey, ekPKE); err != nil {
+		return nil, err
+	}
+
+	return dk, nil
+}
+
+// kemKeyGen generates a decapsulation key.
+//
+// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15, and
+// K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. The two are merged
+// to save copies and allocations.
+func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey {
+	if dk == nil {
+		dk = &DecapsulationKey{}
+	}
+
+	G := sha3.Sum512(d[:])
+	ρ, σ := G[:32], G[32:]
+
+	A := &dk.A
+	for i := byte(0); i < k; i++ {
+		for j := byte(0); j < k; j++ {
+			// Note that this is consistent with Kyber round 3, rather than with
+			// the initial draft of FIPS 203, because NIST signaled that the
+			// change was involuntary and will be reverted.
+			A[i*k+j] = sampleNTT(ρ, j, i)
+		}
+	}
+
+	var N byte
+	s := &dk.s
+	for i := range s {
+		s[i] = ntt(samplePolyCBD(σ, N))
+		N++
+	}
+	e := make([]nttElement, k)
+	for i := range e {
+		e[i] = ntt(samplePolyCBD(σ, N))
+		N++
+	}
+
+	t := &dk.t
+	for i := range t { // t = A ◦ s + e
+		t[i] = e[i]
+		for j := range s {
+			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
+		}
+	}
+
+	// dkPKE ← ByteEncode₁₂(s)
+	// ekPKE ← ByteEncode₁₂(t) || ρ
+	// ek ← ekPKE
+	// dk ← dkPKE || ek || H(ek) || z
+	dkB := dk.dk[:0]
+
+	for i := range s {
+		dkB = polyByteEncode(dkB, s[i])
+	}
+
+	for i := range t {
+		dkB = polyByteEncode(dkB, t[i])
+	}
+	dkB = append(dkB, ρ...)
+
+	H := sha3.New256()
+	H.Write(dkB[decryptionKeySize:])
+	dkB = H.Sum(dkB)
+
+	dkB = append(dkB, z[:]...)
+
+	if len(dkB) != len(dk.dk) {
+		panic("mlkem768: internal error: invalid decapsulation key size")
+	}
+
+	return dk
+}
+
+// Encapsulate generates a shared key and an associated ciphertext from an
+// encapsulation key, drawing random bytes from crypto/rand.
+// If the encapsulation key is not valid, Encapsulate returns an error.
+//
+// The shared key must be kept secret.
+func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
+	// The actual logic is in a separate function to outline this allocation.
+	var cc [CiphertextSize]byte
+	return encapsulate(&cc, encapsulationKey)
+}
+
+func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
+	if len(encapsulationKey) != EncapsulationKeySize {
+		return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
+	}
+	var m [messageSize]byte
+	if _, err := rand.Read(m[:]); err != nil {
+		return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
+	}
+	return kemEncaps(cc, encapsulationKey, &m)
+}
+
+// kemEncaps generates a shared key and an associated ciphertext.
+//
+// It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16.
+func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) {
+	if cc == nil {
+		cc = &[CiphertextSize]byte{}
+	}
+
+	H := sha3.Sum256(ek[:])
+	g := sha3.New512()
+	g.Write(m[:])
+	g.Write(H[:])
+	G := g.Sum(nil)
+	K, r := G[:SharedKeySize], G[SharedKeySize:]
+	var ex encryptionKey
+	if err := parseEK(&ex, ek[:]); err != nil {
+		return nil, nil, err
+	}
+	c = pkeEncrypt(cc, &ex, m, r)
+	return c, K, nil
+}
+
+// parseEK parses an encryption key from its encoded form.
+//
+// It implements the initial stages of K-PKE.Encrypt according to FIPS 203
+// (DRAFT), Algorithm 13.
+func parseEK(ex *encryptionKey, ekPKE []byte) error {
+	if len(ekPKE) != encryptionKeySize {
+		return errors.New("mlkem768: invalid encryption key length")
+	}
+
+	for i := range ex.t {
+		var err error
+		ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
+		if err != nil {
+			return err
+		}
+		ekPKE = ekPKE[encodingSize12:]
+	}
+	ρ := ekPKE
+
+	for i := byte(0); i < k; i++ {
+		for j := byte(0); j < k; j++ {
+			// See the note in pkeKeyGen about the order of the indices being
+			// consistent with Kyber round 3.
+			ex.A[i*k+j] = sampleNTT(ρ, j, i)
+		}
+	}
+
+	return nil
+}
+
+// pkeEncrypt encrypt a plaintext message.
+//
+// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13,
+// although the computation of t and AT is done in parseEK.
+func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
+	var N byte
+	r, e1 := make([]nttElement, k), make([]ringElement, k)
+	for i := range r {
+		r[i] = ntt(samplePolyCBD(rnd, N))
+		N++
+	}
+	for i := range e1 {
+		e1[i] = samplePolyCBD(rnd, N)
+		N++
+	}
+	e2 := samplePolyCBD(rnd, N)
+
+	u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1
+	for i := range u {
+		u[i] = e1[i]
+		for j := range r {
+			// Note that i and j are inverted, as we need the transposed of A.
+			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.A[j*k+i], r[j])))
+		}
+	}
+
+	μ := ringDecodeAndDecompress1(m)
+
+	var vNTT nttElement // t⊺ ◦ r
+	for i := range ex.t {
+		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
+	}
+	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
+
+	c := cc[:0]
+	for _, f := range u {
+		c = ringCompressAndEncode10(c, f)
+	}
+	c = ringCompressAndEncode4(c, v)
+
+	return c
+}
+
+// Decapsulate generates a shared key from a ciphertext and a decapsulation key.
+// If the ciphertext is not valid, Decapsulate returns an error.
+//
+// The shared key must be kept secret.
+func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) {
+	if len(ciphertext) != CiphertextSize {
+		return nil, errors.New("mlkem768: invalid ciphertext length")
+	}
+	c := (*[CiphertextSize]byte)(ciphertext)
+	return kemDecaps(dk, c), nil
+}
+
+// kemDecaps produces a shared key from a ciphertext.
+//
+// It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17.
+func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) {
+	h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32]
+	z := dk.dk[decryptionKeySize+encryptionKeySize+32:]
+
+	m := pkeDecrypt(&dk.decryptionKey, c)
+	g := sha3.New512()
+	g.Write(m[:])
+	g.Write(h)
+	G := g.Sum(nil)
+	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
+	J := sha3.NewShake256()
+	J.Write(z)
+	J.Write(c[:])
+	Kout := make([]byte, SharedKeySize)
+	J.Read(Kout)
+	var cc [CiphertextSize]byte
+	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)
+
+	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
+	return Kout
+}
+
+// parseDK parses a decryption key from its encoded form.
+//
+// It implements the computation of s from K-PKE.Decrypt according to FIPS 203
+// (DRAFT), Algorithm 14.
+func parseDK(dx *decryptionKey, dkPKE []byte) error {
+	if len(dkPKE) != decryptionKeySize {
+		return errors.New("mlkem768: invalid decryption key length")
+	}
+
+	for i := range dx.s {
+		f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12])
+		if err != nil {
+			return err
+		}
+		dx.s[i] = f
+		dkPKE = dkPKE[encodingSize12:]
+	}
+
+	return nil
+}
+
+// pkeDecrypt decrypts a ciphertext.
+//
+// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14,
+// although the computation of s is done in parseDK.
+func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte {
+	u := make([]ringElement, k)
+	for i := range u {
+		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
+		u[i] = ringDecodeAndDecompress10(b)
+	}
+
+	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
+	v := ringDecodeAndDecompress4(b)
+
+	var mask nttElement // s⊺ ◦ NTT(u)
+	for i := range dx.s {
+		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
+	}
+	w := polySub(v, inverseNTT(mask))
+
+	return ringCompressAndEncode1(nil, w)
+}
+
+// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
+type fieldElement uint16
+
+// fieldCheckReduced checks that a value a is < q.
+func fieldCheckReduced(a uint16) (fieldElement, error) {
+	if a >= q {
+		return 0, errors.New("unreduced field element")
+	}
+	return fieldElement(a), nil
+}
+
+// fieldReduceOnce reduces a value a < 2q.
+func fieldReduceOnce(a uint16) fieldElement {
+	x := a - q
+	// If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set.
+	x += (x >> 15) * q
+	return fieldElement(x)
+}
+
+func fieldAdd(a, b fieldElement) fieldElement {
+	x := uint16(a + b)
+	return fieldReduceOnce(x)
+}
+
+func fieldSub(a, b fieldElement) fieldElement {
+	x := uint16(a - b + q)
+	return fieldReduceOnce(x)
+}
+
+const (
+	barrettMultiplier = 5039 // 2¹² * 2¹² / q
+	barrettShift      = 24   // log₂(2¹² * 2¹²)
+)
+
+// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid
+// potentially variable-time division.
+func fieldReduce(a uint32) fieldElement {
+	quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift)
+	return fieldReduceOnce(uint16(a - quotient*q))
+}
+
+func fieldMul(a, b fieldElement) fieldElement {
+	x := uint32(a) * uint32(b)
+	return fieldReduce(x)
+}
+
+// fieldMulSub returns a * (b - c). This operation is fused to save a
+// fieldReduceOnce after the subtraction.
+func fieldMulSub(a, b, c fieldElement) fieldElement {
+	x := uint32(a) * uint32(b-c+q)
+	return fieldReduce(x)
+}
+
+// fieldAddMul returns a * b + c * d. This operation is fused to save a
+// fieldReduceOnce and a fieldReduce.
+func fieldAddMul(a, b, c, d fieldElement) fieldElement {
+	x := uint32(a) * uint32(b)
+	x += uint32(c) * uint32(d)
+	return fieldReduce(x)
+}
+
+// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to
+// FIPS 203 (DRAFT), Definition 4.5.
+func compress(x fieldElement, d uint8) uint16 {
+	// We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2
+	// rounding up (see FIPS 203 (DRAFT), Section 2.3).
+
+	// Barrett reduction produces a quotient and a remainder in the range [0, 2q),
+	// such that dividend = quotient * q + remainder.
+	dividend := uint32(x) << d // x * 2ᵈ
+	quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift)
+	remainder := dividend - quotient*q
+
+	// Since the remainder is in the range [0, 2q), not [0, q), we need to
+	// portion it into three spans for rounding.
+	//
+	//     [ 0,       q/2     ) -> round to 0
+	//     [ q/2,     q + q/2 ) -> round to 1
+	//     [ q + q/2, 2q      ) -> round to 2
+	//
+	// We can convert that to the following logic: add 1 if remainder > q/2,
+	// then add 1 again if remainder > q + q/2.
+	//
+	// Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top
+	// bit of the difference will be set.
+	quotient += (q/2 - remainder) >> 31 & 1
+	quotient += (q + q/2 - remainder) >> 31 & 1
+
+	// quotient might have overflowed at this point, so reduce it by masking.
+	var mask uint32 = (1 << d) - 1
+	return uint16(quotient & mask)
+}
+
+// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of
+// field elements, according to FIPS 203 (DRAFT), Definition 4.6.
+func decompress(y uint16, d uint8) fieldElement {
+	// We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2
+	// rounding up (see FIPS 203 (DRAFT), Section 2.3).
+
+	dividend := uint32(y) * q
+	quotient := dividend >> d // (y * q) / 2ᵈ
+
+	// The d'th least-significant bit of the dividend (the most significant bit
+	// of the remainder) is 1 for the top half of the values that divide to the
+	// same quotient, which are the ones that round up.
+	quotient += dividend >> (d - 1) & 1
+
+	// quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow.
+	return fieldElement(quotient)
+}
+
+// ringElement is a polynomial, an element of R_q, represented as an array
+// according to FIPS 203 (DRAFT), Section 2.4.
+type ringElement [n]fieldElement
+
+// polyAdd adds two ringElements or nttElements.
+func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
+	for i := range s {
+		s[i] = fieldAdd(a[i], b[i])
+	}
+	return s
+}
+
+// polySub subtracts two ringElements or nttElements.
+func polySub[T ~[n]fieldElement](a, b T) (s T) {
+	for i := range s {
+		s[i] = fieldSub(a[i], b[i])
+	}
+	return s
+}
+
+// polyByteEncode appends the 384-byte encoding of f to b.
+//
+// It implements ByteEncode₁₂, according to FIPS 203 (DRAFT), Algorithm 4.
+func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte {
+	out, B := sliceForAppend(b, encodingSize12)
+	for i := 0; i < n; i += 2 {
+		x := uint32(f[i]) | uint32(f[i+1])<<12
+		B[0] = uint8(x)
+		B[1] = uint8(x >> 8)
+		B[2] = uint8(x >> 16)
+		B = B[3:]
+	}
+	return out
+}
+
+// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that
+// all the coefficients are properly reduced. This achieves the "Modulus check"
+// step of ML-KEM Encapsulation Input Validation.
+//
+// polyByteDecode is also used in ML-KEM Decapsulation, where the input
+// validation is not required, but implicitly allowed by the specification.
+//
+// It implements ByteDecode₁₂, according to FIPS 203 (DRAFT), Algorithm 5.
+func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) {
+	if len(b) != encodingSize12 {
+		return T{}, errors.New("mlkem768: invalid encoding length")
+	}
+	var f T
+	for i := 0; i < n; i += 2 {
+		d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16
+		const mask12 = 0b1111_1111_1111
+		var err error
+		if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil {
+			return T{}, errors.New("mlkem768: invalid polynomial encoding")
+		}
+		if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil {
+			return T{}, errors.New("mlkem768: invalid polynomial encoding")
+		}
+		b = b[3:]
+	}
+	return f, nil
+}
+
+// sliceForAppend takes a slice and a requested number of bytes. It returns a
+// slice with the contents of the given slice followed by that many bytes and a
+// second slice that aliases into it and contains only the extra bytes. If the
+// original slice has sufficient capacity then no allocation is performed.
+func sliceForAppend(in []byte, n int) (head, tail []byte) {
+	if total := len(in) + n; cap(in) >= total {
+		head = in[:total]
+	} else {
+		head = make([]byte, total)
+		copy(head, in)
+	}
+	tail = head[len(in):]
+	return
+}
+
+// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s,
+// compressing one coefficients per bit.
+//
+// It implements Compress₁, according to FIPS 203 (DRAFT), Definition 4.5,
+// followed by ByteEncode₁, according to FIPS 203 (DRAFT), Algorithm 4.
+func ringCompressAndEncode1(s []byte, f ringElement) []byte {
+	s, b := sliceForAppend(s, encodingSize1)
+	for i := range b {
+		b[i] = 0
+	}
+	for i := range f {
+		b[i/8] |= uint8(compress(f[i], 1) << (i % 8))
+	}
+	return s
+}
+
+// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each
+// bit is mapped to 0 or ⌈q/2⌋.
+//
+// It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5,
+// followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6.
+func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement {
+	var f ringElement
+	for i := range f {
+		b_i := b[i/8] >> (i % 8) & 1
+		const halfQ = (q + 1) / 2        // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3
+		f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋
+	}
+	return f
+}
+
+// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s,
+// compressing two coefficients per byte.
+//
+// It implements Compress₄, according to FIPS 203 (DRAFT), Definition 4.5,
+// followed by ByteEncode₄, according to FIPS 203 (DRAFT), Algorithm 4.
+func ringCompressAndEncode4(s []byte, f ringElement) []byte {
+	s, b := sliceForAppend(s, encodingSize4)
+	for i := 0; i < n; i += 2 {
+		b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4)
+	}
+	return s
+}
+
+// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where
+// each four bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5,
+// followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6.
+func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement {
+	var f ringElement
+	for i := 0; i < n; i += 2 {
+		f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4))
+		f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4))
+	}
+	return f
+}
+
+// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s,
+// compressing four coefficients per five bytes.
+//
+// It implements Compress₁₀, according to FIPS 203 (DRAFT), Definition 4.5,
+// followed by ByteEncode₁₀, according to FIPS 203 (DRAFT), Algorithm 4.
+func ringCompressAndEncode10(s []byte, f ringElement) []byte {
+	s, b := sliceForAppend(s, encodingSize10)
+	for i := 0; i < n; i += 4 {
+		var x uint64
+		x |= uint64(compress(f[i+0], 10))
+		x |= uint64(compress(f[i+1], 10)) << 10
+		x |= uint64(compress(f[i+2], 10)) << 20
+		x |= uint64(compress(f[i+3], 10)) << 30
+		b[0] = uint8(x)
+		b[1] = uint8(x >> 8)
+		b[2] = uint8(x >> 16)
+		b[3] = uint8(x >> 24)
+		b[4] = uint8(x >> 32)
+		b = b[5:]
+	}
+	return s
+}
+
+// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where
+// each ten bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5,
+// followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6.
+func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
+	b := bb[:]
+	var f ringElement
+	for i := 0; i < n; i += 4 {
+		x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32
+		b = b[5:]
+		f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10))
+		f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10))
+		f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10))
+		f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10))
+	}
+	return f
+}
+
+// samplePolyCBD draws a ringElement from the special Dη distribution given a
+// stream of random bytes generated by the PRF function, according to FIPS 203
+// (DRAFT), Algorithm 7 and Definition 4.1.
+func samplePolyCBD(s []byte, b byte) ringElement {
+	prf := sha3.NewShake256()
+	prf.Write(s)
+	prf.Write([]byte{b})
+	B := make([]byte, 128)
+	prf.Read(B)
+
+	// SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds
+	// the first two and subtracts the last two.
+
+	var f ringElement
+	for i := 0; i < n; i += 2 {
+		b := B[i/2]
+		b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1
+		b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1
+		f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3))
+		f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7))
+	}
+	return f
+}
+
+// nttElement is an NTT representation, an element of T_q, represented as an
+// array according to FIPS 203 (DRAFT), Section 2.4.
+type nttElement [n]fieldElement
+
+// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i.
+var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175}
+
+// nttMul multiplies two nttElements.
+//
+// It implements MultiplyNTTs, according to FIPS 203 (DRAFT), Algorithm 10.
+func nttMul(f, g nttElement) nttElement {
+	var h nttElement
+	// We use i += 2 for bounds check elimination. See https://go.dev/issue/66826.
+	for i := 0; i < 256; i += 2 {
+		a0, a1 := f[i], f[i+1]
+		b0, b1 := g[i], g[i+1]
+		h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2])
+		h[i+1] = fieldAddMul(a0, b1, a1, b0)
+	}
+	return h
+}
+
+// zetas are the values ζ^BitRev7(k) mod q for each index k.
+var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154}
+
+// ntt maps a ringElement to its nttElement representation.
+//
+// It implements NTT, according to FIPS 203 (DRAFT), Algorithm 8.
+func ntt(f ringElement) nttElement {
+	k := 1
+	for len := 128; len >= 2; len /= 2 {
+		for start := 0; start < 256; start += 2 * len {
+			zeta := zetas[k]
+			k++
+			// Bounds check elimination hint.
+			f, flen := f[start:start+len], f[start+len:start+len+len]
+			for j := 0; j < len; j++ {
+				t := fieldMul(zeta, flen[j])
+				flen[j] = fieldSub(f[j], t)
+				f[j] = fieldAdd(f[j], t)
+			}
+		}
+	}
+	return nttElement(f)
+}
+
+// inverseNTT maps a nttElement back to the ringElement it represents.
+//
+// It implements NTT⁻¹, according to FIPS 203 (DRAFT), Algorithm 9.
+func inverseNTT(f nttElement) ringElement {
+	k := 127
+	for len := 2; len <= 128; len *= 2 {
+		for start := 0; start < 256; start += 2 * len {
+			zeta := zetas[k]
+			k--
+			// Bounds check elimination hint.
+			f, flen := f[start:start+len], f[start+len:start+len+len]
+			for j := 0; j < len; j++ {
+				t := f[j]
+				f[j] = fieldAdd(t, flen[j])
+				flen[j] = fieldMulSub(zeta, flen[j], t)
+			}
+		}
+	}
+	for i := range f {
+		f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q
+	}
+	return ringElement(f)
+}
+
+// sampleNTT draws a uniformly random nttElement from a stream of uniformly
+// random bytes generated by the XOF function, according to FIPS 203 (DRAFT),
+// Algorithm 6 and Definition 4.2.
+func sampleNTT(rho []byte, ii, jj byte) nttElement {
+	B := sha3.NewShake128()
+	B.Write(rho)
+	B.Write([]byte{ii, jj})
+
+	// SampleNTT essentially draws 12 bits at a time from r, interprets them in
+	// little-endian, and rejects values higher than q, until it drew 256
+	// values. (The rejection rate is approximately 19%.)
+	//
+	// To do this from a bytes stream, it draws three bytes at a time, and
+	// splits them into two uint16 appropriately masked.
+	//
+	//               r₀              r₁              r₂
+	//       |- - - - - - - -|- - - - - - - -|- - - - - - - -|
+	//
+	//               Uint16(r₀ || r₁)
+	//       |- - - - - - - - - - - - - - - -|
+	//       |- - - - - - - - - - - -|
+	//                   d₁
+	//
+	//                                Uint16(r₁ || r₂)
+	//                       |- - - - - - - - - - - - - - - -|
+	//                               |- - - - - - - - - - - -|
+	//                                           d₂
+	//
+	// Note that in little-endian, the rightmost bits are the most significant
+	// bits (dropped with a mask) and the leftmost bits are the least
+	// significant bits (dropped with a right shift).
+
+	var a nttElement
+	var j int        // index into a
+	var buf [24]byte // buffered reads from B
+	off := len(buf)  // index into buf, starts in a "buffer fully consumed" state
+	for {
+		if off >= len(buf) {
+			B.Read(buf[:])
+			off = 0
+		}
+		d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111
+		d2 := byteorder.LeUint16(buf[off+1:]) >> 4
+		off += 3
+		if d1 < q {
+			a[j] = fieldElement(d1)
+			j++
+		}
+		if j >= len(a) {
+			break
+		}
+		if d2 < q {
+			a[j] = fieldElement(d2)
+			j++
+		}
+		if j >= len(a) {
+			break
+		}
+	}
+	return a
+}

+ 4 - 4
vendor/github.com/Psiphon-Labs/psiphon-tls/key_agreement.go

@@ -16,8 +16,8 @@ import (
 	"io"
 )
 
-// a keyAgreement implements the client and server side of a TLS key agreement
-// protocol by generating and processing key exchange messages.
+// A keyAgreement implements the client and server side of a TLS 1.0–1.2 key
+// agreement protocol by generating and processing key exchange messages.
 type keyAgreement interface {
 	// On the server side, the first two methods are called in order.
 
@@ -126,7 +126,7 @@ func md5SHA1Hash(slices [][]byte) []byte {
 }
 
 // hashForServerKeyExchange hashes the given slices and returns their digest
-// using the given hash function (for >= TLS 1.2) or using a default based on
+// using the given hash function (for TLS 1.2) or using a default based on
 // the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
 // do pre-hashing, it returns the concatenation of the slices.
 func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
@@ -169,7 +169,7 @@ type ecdheKeyAgreement struct {
 func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
 	var curveID CurveID
 	for _, c := range clientHello.supportedCurves {
-		if config.supportsCurve(c) {
+		if config.supportsCurve(ka.version, c) {
 			curveID = c
 			break
 		}

+ 42 - 0
vendor/github.com/Psiphon-Labs/psiphon-tls/key_schedule.go

@@ -12,8 +12,11 @@ import (
 	"hash"
 	"io"
 
+	"github.com/Psiphon-Labs/psiphon-tls/internal/mlkem768"
+
 	"golang.org/x/crypto/cryptobyte"
 	"golang.org/x/crypto/hkdf"
+	"golang.org/x/crypto/sha3"
 )
 
 // This file contains the functions necessary to compute the TLS 1.3 key
@@ -117,6 +120,45 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript
 	}
 }
 
+type keySharePrivateKeys struct {
+	curveID CurveID
+	ecdhe   *ecdh.PrivateKey
+	kyber   *mlkem768.DecapsulationKey
+}
+
+// kyberDecapsulate implements decapsulation according to Kyber Round 3.
+func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) {
+	K, err := mlkem768.Decapsulate(dk, c)
+	if err != nil {
+		return nil, err
+	}
+	return kyberSharedSecret(K, c), nil
+}
+
+// kyberEncapsulate implements encapsulation according to Kyber Round 3.
+func kyberEncapsulate(ek []byte) (c, ss []byte, err error) {
+	c, ss, err = mlkem768.Encapsulate(ek)
+	if err != nil {
+		return nil, nil, err
+	}
+	return c, kyberSharedSecret(ss, c), nil
+}
+
+func kyberSharedSecret(K, c []byte) []byte {
+	// Package mlkem768 implements ML-KEM, which compared to Kyber removed a
+	// final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber.
+	// See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3.
+	h := sha3.NewShake256()
+	h.Write(K)
+	ch := sha3.Sum256(c)
+	h.Write(ch[:])
+	out := make([]byte, 32)
+	h.Read(out)
+	return out
+}
+
+const x25519PublicKeySize = 32
+
 // generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
 // according to RFC 8446, Section 4.2.8.2.
 func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {

+ 0 - 13
vendor/github.com/Psiphon-Labs/psiphon-tls/notboring.go

@@ -7,16 +7,3 @@
 package tls
 
 func needFIPS() bool { return false }
-
-func supportedSignatureAlgorithms() []SignatureScheme {
-	return defaultSupportedSignatureAlgorithms
-}
-
-func fipsMinVersion(c *Config) uint16          { panic("fipsMinVersion") }
-func fipsMaxVersion(c *Config) uint16          { panic("fipsMaxVersion") }
-func fipsCurvePreferences(c *Config) []CurveID { panic("fipsCurvePreferences") }
-func fipsCipherSuites(c *Config) []uint16      { panic("fipsCipherSuites") }
-
-var fipsSupportedSignatureAlgorithms []SignatureScheme
-
-var defaultCipherSuitesTLS13FIPS []uint16

+ 9 - 2
vendor/github.com/Psiphon-Labs/psiphon-tls/prf.go

@@ -252,13 +252,20 @@ func (h *finishedHash) discardHandshakeBuffer() {
 	h.buffer = nil
 }
 
-// noExportedKeyingMaterial is used as a value of
+// noEKMBecauseRenegotiation is used as a value of
 // ConnectionState.ekm when renegotiation is enabled and thus
 // we wish to fail all key-material export requests.
-func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
+func noEKMBecauseRenegotiation(label string, context []byte, length int) ([]byte, error) {
 	return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
 }
 
+// noEKMBecauseNoEMS is used as a value of ConnectionState.ekm when Extended
+// Master Secret is not negotiated and thus we wish to fail all key-material
+// export requests.
+func noEKMBecauseNoEMS(label string, context []byte, length int) ([]byte, error) {
+	return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when neither TLS 1.3 nor Extended Master Secret are negotiated; override with GODEBUG=tlsunsafeekm=1")
+}
+
 // ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
 func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
 	return func(label string, context []byte, length int) ([]byte, error) {

+ 2 - 2
vendor/github.com/Psiphon-Labs/psiphon-tls/public.go

@@ -3,7 +3,7 @@ package tls
 // [Psiphon]
 // ClientSessionState contains the state needed by clients to resume TLS sessions.
 func MakeClientSessionState(
-	Ticket []uint8,
+	Ticket []byte,
 	Vers uint16,
 	CipherSuite uint16,
 	MasterSecret []byte,
@@ -12,7 +12,6 @@ func MakeClientSessionState(
 	UseBy uint64,
 ) *ClientSessionState {
 	css := &ClientSessionState{
-		ticket: Ticket,
 		session: &SessionState{
 			version:     Vers,
 			cipherSuite: CipherSuite,
@@ -20,6 +19,7 @@ func MakeClientSessionState(
 			createdAt:   CreatedAt,
 			ageAdd:      AgeAdd,
 			useBy:       UseBy,
+			ticket:      Ticket,
 		},
 	}
 	return css

+ 91 - 12
vendor/github.com/Psiphon-Labs/psiphon-tls/quic.go

@@ -46,9 +46,16 @@ type QUICConn struct {
 	sessionTicketSent bool
 }
 
-// A QUICConfig configures a QUICConn.
+// A QUICConfig configures a [QUICConn].
 type QUICConfig struct {
 	TLSConfig *Config
+
+	// EnableSessionEvents may be set to true to enable the
+	// [QUICStoreSession] and [QUICResumeSession] events for client connections.
+	// When this event is enabled, sessions are not automatically
+	// stored in the client session cache.
+	// The application should use [QUICConn.StoreSession] to store sessions.
+	EnableSessionEvents bool
 }
 
 // A QUICEventKind is a type of operation on a QUIC connection.
@@ -87,10 +94,29 @@ const (
 	// QUICRejectedEarlyData indicates that the server rejected 0-RTT data even
 	// if we offered it. It's returned before QUICEncryptionLevelApplication
 	// keys are returned.
+	// This event only occurs on client connections.
 	QUICRejectedEarlyData
 
 	// QUICHandshakeDone indicates that the TLS handshake has completed.
 	QUICHandshakeDone
+
+	// QUICResumeSession indicates that a client is attempting to resume a previous session.
+	// [QUICEvent.SessionState] is set.
+	//
+	// For client connections, this event occurs when the session ticket is selected.
+	// For server connections, this event occurs when receiving the client's session ticket.
+	//
+	// The application may set [QUICEvent.SessionState.EarlyData] to false before the
+	// next call to [QUICConn.NextEvent] to decline 0-RTT even if the session supports it.
+	QUICResumeSession
+
+	// QUICStoreSession indicates that the server has provided state permitting
+	// the client to resume the session.
+	// [QUICEvent.SessionState] is set.
+	// The application should use [QUICConn.StoreSession] session to store the [SessionState].
+	// The application may modify the [SessionState] before storing it.
+	// This event only occurs on client connections.
+	QUICStoreSession
 )
 
 // A QUICEvent is an event occurring on a QUIC connection.
@@ -109,6 +135,9 @@ type QUICEvent struct {
 
 	// Set for QUICSetReadSecret and QUICSetWriteSecret.
 	Suite uint16
+
+	// Set for QUICResumeSession and QUICStoreSession.
+	SessionState *SessionState
 }
 
 type quicState struct {
@@ -127,12 +156,16 @@ type quicState struct {
 	cancelc  <-chan struct{} // handshake has been canceled
 	cancel   context.CancelFunc
 
+	waitingForDrain bool
+
 	// readbuf is shared between HandleData and the handshake goroutine.
 	// HandshakeCryptoData passes ownership to the handshake goroutine by
 	// reading from signalc, and reclaims ownership by reading from blockedc.
 	readbuf []byte
 
 	transportParams []byte // to send to the peer
+
+	enableSessionEvents bool
 }
 
 // QUICClient returns a new TLS client side connection using QUICTransport as the
@@ -140,7 +173,7 @@ type quicState struct {
 //
 // The config's MinVersion must be at least TLS 1.3.
 func QUICClient(config *QUICConfig) *QUICConn {
-	return newQUICConn(Client(nil, config.TLSConfig))
+	return newQUICConn(Client(nil, config.TLSConfig), config)
 }
 
 // QUICServer returns a new TLS server side connection using QUICTransport as the
@@ -148,13 +181,14 @@ func QUICClient(config *QUICConfig) *QUICConn {
 //
 // The config's MinVersion must be at least TLS 1.3.
 func QUICServer(config *QUICConfig) *QUICConn {
-	return newQUICConn(Server(nil, config.TLSConfig))
+	return newQUICConn(Server(nil, config.TLSConfig), config)
 }
 
-func newQUICConn(conn *Conn) *QUICConn {
+func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn {
 	conn.quic = &quicState{
-		signalc:  make(chan struct{}),
-		blockedc: make(chan struct{}),
+		signalc:             make(chan struct{}),
+		blockedc:            make(chan struct{}),
+		enableSessionEvents: config.EnableSessionEvents,
 	}
 	conn.quic.events = conn.quic.eventArr[:0]
 	return &QUICConn{
@@ -163,7 +197,7 @@ func newQUICConn(conn *Conn) *QUICConn {
 }
 
 // Start starts the client or server handshake protocol.
-// It may produce connection events, which may be read with NextEvent.
+// It may produce connection events, which may be read with [QUICConn.NextEvent].
 //
 // Start must be called at most once.
 func (q *QUICConn) Start(ctx context.Context) error {
@@ -182,7 +216,7 @@ func (q *QUICConn) Start(ctx context.Context) error {
 }
 
 // NextEvent returns the next event occurring on the connection.
-// It returns an event with a Kind of QUICNoEvent when no events are available.
+// It returns an event with a Kind of [QUICNoEvent] when no events are available.
 func (q *QUICConn) NextEvent() QUICEvent {
 	qs := q.conn.quic
 	if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
@@ -190,6 +224,11 @@ func (q *QUICConn) NextEvent() QUICEvent {
 		// to catch callers erroniously retaining it.
 		qs.events[last].Data[0] = 0
 	}
+	if qs.nextEvent >= len(qs.events) && qs.waitingForDrain {
+		qs.waitingForDrain = false
+		<-qs.signalc
+		<-qs.blockedc
+	}
 	if qs.nextEvent >= len(qs.events) {
 		qs.events = qs.events[:0]
 		qs.nextEvent = 0
@@ -214,7 +253,7 @@ func (q *QUICConn) Close() error {
 }
 
 // HandleData handles handshake bytes received from the peer.
-// It may produce connection events, which may be read with NextEvent.
+// It may produce connection events, which may be read with [QUICConn.NextEvent].
 func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
 	c := q.conn
 	if c.in.level != level {
@@ -255,10 +294,11 @@ func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
 type QUICSessionTicketOptions struct {
 	// EarlyData specifies whether the ticket may be used for 0-RTT.
 	EarlyData bool
+	Extra     [][]byte
 }
 
 // SendSessionTicket sends a session ticket to the client.
-// It produces connection events, which may be read with NextEvent.
+// It produces connection events, which may be read with [QUICConn.NextEvent].
 // Currently, it can only be called once.
 func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
 	c := q.conn
@@ -272,7 +312,25 @@ func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
 		return quicError(errors.New("tls: SendSessionTicket called multiple times"))
 	}
 	q.sessionTicketSent = true
-	return quicError(c.sendSessionTicket(opts.EarlyData))
+	return quicError(c.sendSessionTicket(opts.EarlyData, opts.Extra))
+}
+
+// StoreSession stores a session previously received in a QUICStoreSession event
+// in the ClientSessionCache.
+// The application may process additional events or modify the SessionState
+// before storing the session.
+func (q *QUICConn) StoreSession(session *SessionState) error {
+	c := q.conn
+	if !c.isClient {
+		return quicError(errors.New("tls: StoreSessionTicket called on the server"))
+	}
+	cacheKey := c.clientSessionCacheKey()
+	if cacheKey == "" {
+		return nil
+	}
+	cs := &ClientSessionState{session: session}
+	c.config.ClientSessionCache.Put(cacheKey, cs)
+	return nil
 }
 
 // ConnectionState returns basic TLS details about the connection.
@@ -290,7 +348,7 @@ func (q *QUICConn) TLSConnectionMetrics() ConnectionMetrics {
 // SetTransportParameters sets the transport parameters to send to the peer.
 //
 // Server connections may delay setting the transport parameters until after
-// receiving the client's transport parameters. See QUICTransportParametersRequired.
+// receiving the client's transport parameters. See [QUICTransportParametersRequired].
 func (q *QUICConn) SetTransportParameters(params []byte) {
 	if params == nil {
 		params = []byte{}
@@ -363,6 +421,27 @@ func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
 	last.Data = append(last.Data, data...)
 }
 
+func (c *Conn) quicResumeSession(session *SessionState) error {
+	c.quic.events = append(c.quic.events, QUICEvent{
+		Kind:         QUICResumeSession,
+		SessionState: session,
+	})
+	c.quic.waitingForDrain = true
+	for c.quic.waitingForDrain {
+		if err := c.quicWaitForSignal(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *Conn) quicStoreSession(session *SessionState) {
+	c.quic.events = append(c.quic.events, QUICEvent{
+		Kind:         QUICStoreSession,
+		SessionState: session,
+	})
+}
+
 func (c *Conn) quicSetTransportParameters(params []byte) {
 	c.quic.events = append(c.quic.events, QUICEvent{
 		Kind: QUICTransportParameters,

+ 11 - 7
vendor/github.com/Psiphon-Labs/psiphon-tls/ticket.go

@@ -76,7 +76,7 @@ type SessionState struct {
 	// To allow different layers in a protocol stack to share this field,
 	// applications must only append to it, not replace it, and must use entries
 	// that can be recognized even if out of order (for example, by starting
-	// with a id and version prefix).
+	// with an id and version prefix).
 	Extra [][]byte
 
 	// EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC
@@ -103,6 +103,7 @@ type SessionState struct {
 	// Client-side TLS 1.3-only fields.
 	useBy  uint64 // seconds since UNIX epoch
 	ageAdd uint32
+	ticket []byte
 }
 
 // Bytes encodes the session, including any private fields, so that it can be
@@ -324,7 +325,7 @@ func ParseSessionState(data []byte) (*SessionState, error) {
 
 // sessionState returns a partially filled-out [SessionState] with information
 // from the current connection.
-func (c *Conn) sessionState() (*SessionState, error) {
+func (c *Conn) sessionState() *SessionState {
 	return &SessionState{
 		version:           c.vers,
 		cipherSuite:       c.cipherSuite,
@@ -337,10 +338,10 @@ func (c *Conn) sessionState() (*SessionState, error) {
 		isClient:          c.isClient,
 		extMasterSecret:   c.extMasterSecret,
 		verifiedChains:    c.verifiedChains,
-	}, nil
+	}
 }
 
-// EncryptTicket encrypts a ticket with the Config's configured (or default)
+// EncryptTicket encrypts a ticket with the [Config]'s configured (or default)
 // session ticket keys. It can be used as a [Config.WrapSession] implementation.
 func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
 	ticketKeys := c.ticketKeys(nil)
@@ -431,7 +432,6 @@ func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte
 // ClientSessionState contains the state needed by a client to
 // resume a previous TLS session.
 type ClientSessionState struct {
-	ticket  []byte
 	session *SessionState
 }
 
@@ -441,7 +441,10 @@ type ClientSessionState struct {
 // It can be called by [ClientSessionCache.Put] to serialize (with
 // [SessionState.Bytes]) and store the session.
 func (cs *ClientSessionState) ResumptionState() (ticket []byte, state *SessionState, err error) {
-	return cs.ticket, cs.session, nil
+	if cs == nil || cs.session == nil {
+		return nil, nil, nil
+	}
+	return cs.session.ticket, cs.session, nil
 }
 
 // NewResumptionState returns a state value that can be returned by
@@ -450,8 +453,9 @@ func (cs *ClientSessionState) ResumptionState() (ticket []byte, state *SessionSt
 // state needs to be returned by [ParseSessionState], and the ticket and session
 // state must have been returned by [ClientSessionState.ResumptionState].
 func NewResumptionState(ticket []byte, state *SessionState) (*ClientSessionState, error) {
+	state.ticket = ticket
 	return &ClientSessionState{
-		ticket: ticket, session: state,
+		session: state,
 	}, nil
 }
 

+ 35 - 13
vendor/github.com/Psiphon-Labs/psiphon-tls/tls.go

@@ -22,6 +22,10 @@ import (
 	"encoding/pem"
 	"errors"
 	"fmt"
+
+	// [Psiphon]
+	// "internal/godebug"
+
 	"net"
 	"os"
 	"strings"
@@ -79,7 +83,7 @@ func (l *listener) Accept() (net.Conn, error) {
 }
 
 // NewListener creates a Listener which accepts connections from an inner
-// Listener and wraps each connection with Server.
+// Listener and wraps each connection with [Server].
 // The configuration config must be non-nil and must include
 // at least one certificate or else set GetCertificate.
 func NewListener(inner net.Listener, config *Config) net.Listener {
@@ -94,6 +98,7 @@ func NewListener(inner net.Listener, config *Config) net.Listener {
 // The configuration config must be non-nil and must include
 // at least one certificate or else set GetCertificate.
 func Listen(network, laddr string, config *Config) (net.Listener, error) {
+	// If this condition changes, consider updating http.Server.ServeTLS too.
 	if config == nil || len(config.Certificates) == 0 &&
 		config.GetCertificate == nil && config.GetConfigForClient == nil {
 		return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
@@ -117,10 +122,10 @@ func (timeoutError) Temporary() bool { return true }
 // handshake as a whole.
 //
 // DialWithDialer interprets a nil configuration as equivalent to the zero
-// configuration; see the documentation of Config for the defaults.
+// configuration; see the documentation of [Config] for the defaults.
 //
 // DialWithDialer uses context.Background internally; to specify the context,
-// use Dialer.DialContext with NetDialer set to the desired dialer.
+// use [Dialer.DialContext] with NetDialer set to the desired dialer.
 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
 	return dial(context.Background(), dialer, network, addr, config)
 }
@@ -197,10 +202,10 @@ type Dialer struct {
 // Dial connects to the given network address and initiates a TLS
 // handshake, returning the resulting TLS connection.
 //
-// The returned Conn, if any, will always be of type *Conn.
+// The returned [Conn], if any, will always be of type *[Conn].
 //
 // Dial uses context.Background internally; to specify the context,
-// use DialContext.
+// use [Dialer.DialContext].
 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
 	return d.DialContext(context.Background(), network, addr)
 }
@@ -220,7 +225,7 @@ func (d *Dialer) netDialer() *net.Dialer {
 // connected, any expiration of the context will not affect the
 // connection.
 //
-// The returned Conn, if any, will always be of type *Conn.
+// The returned [Conn], if any, will always be of type *[Conn].
 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
 	c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
 	if err != nil {
@@ -230,11 +235,14 @@ func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Con
 	return c, nil
 }
 
-// LoadX509KeyPair reads and parses a public/private key pair from a pair
-// of files. The files must contain PEM encoded data. The certificate file
-// may contain intermediate certificates following the leaf certificate to
-// form a certificate chain. On successful return, Certificate.Leaf will
-// be nil because the parsed form of the certificate is not retained.
+// LoadX509KeyPair reads and parses a public/private key pair from a pair of
+// files. The files must contain PEM encoded data. The certificate file may
+// contain intermediate certificates following the leaf certificate to form a
+// certificate chain. On successful return, Certificate.Leaf will be populated.
+//
+// Before Go 1.23 Certificate.Leaf was left nil, and the parsed certificate was
+// discarded. This behavior can be re-enabled by setting "x509keypairleaf=0"
+// in the GODEBUG environment variable.
 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
 	certPEMBlock, err := os.ReadFile(certFile)
 	if err != nil {
@@ -247,9 +255,15 @@ func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
 	return X509KeyPair(certPEMBlock, keyPEMBlock)
 }
 
+// [Psiphon]
+// var x509keypairleaf = godebug.New("x509keypairleaf")
+
 // X509KeyPair parses a public/private key pair from a pair of
-// PEM encoded data. On successful return, Certificate.Leaf will be nil because
-// the parsed form of the certificate is not retained.
+// PEM encoded data. On successful return, Certificate.Leaf will be populated.
+//
+// Before Go 1.23 Certificate.Leaf was left nil, and the parsed certificate was
+// discarded. This behavior can be re-enabled by setting "x509keypairleaf=0"
+// in the GODEBUG environment variable.
 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
 	fail := func(err error) (Certificate, error) { return Certificate{}, err }
 
@@ -304,6 +318,14 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
 		return fail(err)
 	}
 
+	// [Psiphon] godebug is not supported, but we
+	// populate Leaf in X509KeyPair by default.
+	// if x509keypairleaf.Value() != "0" {
+	cert.Leaf = x509Cert
+	// } else {
+	// 	x509keypairleaf.IncNonDefault()
+	// }
+
 	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
 	if err != nil {
 		return fail(err)

+ 49 - 18
vendor/github.com/Psiphon-Labs/psiphon-tls/unsafe.go

@@ -33,13 +33,15 @@ package tls
 
 import (
 	"crypto/tls"
+	"fmt"
 	"reflect"
 	"unsafe"
 )
 
 func init() {
-	if !structsEqual(&tls.ConnectionState{}, &ConnectionState{}) {
-		panic("tls.ConnectionState doesn't match")
+	err := structsEqual(&tls.ConnectionState{}, &ConnectionState{})
+	if err != nil {
+		panic(fmt.Sprintf("tls: ConnectionState is not equal to tls.ConnectionState: %v", err))
 	}
 }
 
@@ -47,7 +49,11 @@ func UnsafeFromConnectionState(ss *ConnectionState) *tls.ConnectionState {
 	return (*tls.ConnectionState)(unsafe.Pointer(ss))
 }
 
-func structsEqual(a, b interface{}) bool {
+func UnsafeToConnectionState(ss *tls.ConnectionState) *ConnectionState {
+	return (*ConnectionState)(unsafe.Pointer(ss))
+}
+
+func structsEqual(a, b interface{}) error {
 	return compare(reflect.TypeOf(a), reflect.TypeOf(b))
 }
 
@@ -56,14 +62,18 @@ func structsEqual(a, b interface{}) bool {
 // compare does not currently support Maps, Chan, UnsafePointer if reflect.DeepEqual fails.
 // Support for these types can be added if needed.
 // note that field names are still compared.
-func compare(a, b reflect.Type) bool {
+func compare(a, b reflect.Type) error {
 
 	if reflect.DeepEqual(a, b) {
-		return true
+		return nil
+	}
+
+	if isPrimitive(a) && isPrimitive(b) && a.Kind() == b.Kind() {
+		return nil
 	}
 
 	if a.Kind() != b.Kind() {
-		return false
+		return fmt.Errorf("kind mismatch: %s vs %s", a.Kind().String(), b.Kind().String())
 	}
 
 	if a.Kind() == reflect.Pointer || a.Kind() == reflect.Slice {
@@ -72,25 +82,30 @@ func compare(a, b reflect.Type) bool {
 
 	if a.Kind() == reflect.Func {
 		if a.NumIn() != b.NumIn() || a.NumOut() != b.NumOut() {
-			return false
+			return fmt.Errorf(
+				"function signature mismatch: different number of input or output parameters: NumIn: %d vs %d, NumOut: %d vs %d",
+				a.NumIn(), b.NumIn(), a.NumOut(), b.NumOut(),
+			)
 		}
 		for i_in := 0; i_in < a.NumIn(); i_in++ {
-			if !compare(a.In(i_in), b.In(i_in)) {
-				return false
+			err := compare(a.In(i_in), b.In(i_in))
+			if err != nil {
+				return fmt.Errorf("function %s input parameter mismatch at index %d: %v", a.Name(), i_in, err)
 			}
 		}
 		for i_out := 0; i_out < a.NumOut(); i_out++ {
-			if !compare(a.Out(i_out), b.Out(i_out)) {
-				return false
+			err := compare(a.Out(i_out), b.Out(i_out))
+			if err != nil {
+				return fmt.Errorf("function %s output parameter mismatch at index %d: %v", a.Name(), i_out, err)
 			}
 		}
-		return true
+		return nil
 	}
 
 	if a.Kind() == reflect.Struct {
 
 		if a.NumField() != b.NumField() {
-			return false
+			return fmt.Errorf("struct field count mismatch: %d vs %d", a.NumField(), b.NumField())
 		}
 
 		for i := 0; i < a.NumField(); i++ {
@@ -99,19 +114,35 @@ func compare(a, b reflect.Type) bool {
 
 			if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name ||
 				fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset {
-				return false
+				return fmt.Errorf("struct field mismatch at index %d: %+v vs %+v", i, fa, fb)
 			}
 
 			if !reflect.DeepEqual(fa.Type, fb.Type) {
-				if !compare(fa.Type, fb.Type) {
-					return false
+				err := compare(fa.Type, fb.Type)
+				if err != nil {
+					return fmt.Errorf("struct %s field type mismatch at index %d with name %s: %v", a.Name(), i, fa.Name, err)
 				}
 			}
 		}
 
-		return true
+		return nil
 	}
 
 	// TODO: add support for missing types
-	return false
+	return fmt.Errorf("unsupported type: %s for field %s", a.Kind().String(), a.Name())
+}
+
+// isPrimitive checks if a reflect.Type represents a primitive type
+func isPrimitive(t reflect.Type) bool {
+	switch t.Kind() {
+	case reflect.Bool,
+		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+		reflect.Float32, reflect.Float64,
+		reflect.Complex64, reflect.Complex128,
+		reflect.String:
+		return true
+	default:
+		return false
+	}
 }

+ 1 - 0
vendor/github.com/Psiphon-Labs/quic-go/.gitignore

@@ -4,6 +4,7 @@ main
 mockgen_tmp.go
 *.qtr
 *.qlog
+*.sqlog
 *.txt
 race.[0-9]*
 

+ 22 - 15
vendor/github.com/Psiphon-Labs/quic-go/.golangci.yml

@@ -1,32 +1,29 @@
-run:
-  skip-files:
-    - internal/handshake/cipher_suite.go
 linters-settings:
-  depguard:
-    rules:
-      qtls:
-        list-mode: lax
-        files:
-          - "!internal/qtls/**"
-          - "$all"
-        deny:
-          - pkg: github.com/quic-go/qtls-go1-20
-            desc: "importing qtls only allowed in internal/qtls"
   misspell:
     ignore-words:
       - ect
+  depguard:
+    rules:
+      quicvarint:
+        list-mode: strict
+        files:
+          - "**/github.com/quic-go/quic-go/quicvarint/*"
+          - "!$test"
+        allow:
+          - $gostd
 
 linters:
   disable-all: true
   enable:
     - asciicheck
+    - copyloopvar
     - depguard
     - exhaustive
-    - exportloopref
     - goimports
     - gofmt # redundant, since gofmt *should* be a no-op after gofumpt
     - gofumpt
     - gosimple
+    - govet
     - ineffassign
     - misspell
     - prealloc
@@ -35,10 +32,20 @@ linters:
     - unconvert
     - unparam
     - unused
-    - vet
 
 issues:
+  exclude-files:
+    - internal/handshake/cipher_suite.go
   exclude-rules:
     - path: internal/qtls
       linters:
         - depguard
+    - path: _test\.go
+      linters:
+        - exhaustive
+        - prealloc
+        - unparam
+    - path: _test\.go
+      text: "SA1029:"
+      linters:
+        - staticcheck

+ 7 - 203
vendor/github.com/Psiphon-Labs/quic-go/README.md

@@ -2,13 +2,15 @@
 
 <img src="docs/quic.png" width=303 height=124>
 
+[![Documentation](https://img.shields.io/badge/docs-quic--go.net-red?style=flat)](https://quic-go.net/docs/)
 [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
 [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
 [![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
 
-quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
+quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
+
+In addition to these base RFCs, it also implements the following RFCs:
 
-In addition to these base RFCs, it also implements the following RFCs: 
 * Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
 * Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
 * QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
@@ -16,207 +18,7 @@ In addition to these base RFCs, it also implements the following RFCs:
 
 Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
 
-## Using QUIC
-
-### Running a Server
-
-The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
-
-```go
-udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
-// ... error handling
-tr := quic.Transport{
-  Conn: udpConn,
-}
-ln, err := tr.Listen(tlsConf, quicConf)
-// ... error handling
-go func() {
-  for {
-    conn, err := ln.Accept()
-    // ... error handling
-    // handle the connection, usually in a new Go routine
-  }
-}()
-```
-
-The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
-
-As a shortcut,  `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
-
-```
-ln, err := quic.Listen(udpConn, tlsConf, quicConf)
-```
-
-When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
-
-### Running a Client
-
-As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
-
-```go
-ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
-defer cancel()
-conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
-// ... error handling
-```
-
-As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
-
-```go
-ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
-defer cancel()
-conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
-```
-
-Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
-
-### Using a QUIC Connection
-
-#### Accepting Streams
-
-QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
-
-Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
-
-On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
-
-```go
-for {
-  str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
-  // ... error handling
-  // handle the stream, usually in a new Go routine
-}
-```
-
-These functions return an error when the underlying QUIC connection is closed.
-
-#### Opening Streams
-
-There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
-
-Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
-
-```go
-ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-defer cancel()
-str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
-```
-
-The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
-
-```go
-str, err := conn.OpenStream()
-if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
-  // It's currently not possible to open another stream,
-  // but it might be possible later, once the peer allowed us to do so.
-}
-```
-
-These functions return an error when the underlying QUIC connection is closed.
-
-#### Using Streams
-
-Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
-
-Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
-
-In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
-
-Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
-
-A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
-
-### Configuring QUIC
-
-The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
-
-The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
-
-### Closing a Connection
-
-#### When the remote Peer closes the Connection
-
-In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
-
-* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
-* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
-* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
-* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
-* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
-* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
-
-#### Initiated by the Application
-
-A `quic.Connection` can be closed using `CloseWithError`:
-
-```go
-conn.CloseWithError(0x42, "error 0x42 occurred")
-```
-
-Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
-
-On the receiver side, this is surfaced as a `quic.ApplicationError`.
-
-### QUIC Datagrams
-
-Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
-
-QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not retransmitted.
-
-Datagrams are sent using the `SendDatagram` method on the `quic.Connection`:
-
-```go
-conn.SendDatagram([]byte("foobar"))
-```
-
-And received using `ReceiveDatagram`:
-
-```go
-msg, err := conn.ReceiveDatagram()
-```
-
-Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
-
-### QUIC Event Logging using qlog
-
-quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/), providing comprehensive insights in the internals of a QUIC connection. 
-
-qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
-
-qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
-`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
-The default qlog tracer can be used like this:
-```go
-quic.Config{
-  Tracer: qlog.DefaultTracer,
-}
-```
-
-This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
-
-
-For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
-
-## Using HTTP/3
-
-### As a server
-
-See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
-
-```go
-http.Handle("/", http.FileServer(http.Dir(wwwDir)))
-http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
-```
-
-### As a client
-
-See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
-
-```go
-http.Client{
-  Transport: &http3.RoundTripper{},
-}
-```
+Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
 
 ## Projects using quic-go
 
@@ -226,11 +28,13 @@ http.Client{
 | [algernon](https://github.com/xyproto/algernon)           | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support                                                            | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square)        |
 | [caddy](https://github.com/caddyserver/caddy/)            | Fast, multi-platform web server with automatic HTTPS                                                                                                              | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square)       |
 | [cloudflared](https://github.com/cloudflare/cloudflared)  | A tunneling daemon that proxies traffic from the Cloudflare network to your origins                                                                               | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square)  |
+| [frp](https://github.com/fatedier/frp)                    | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet                                                                   | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square)            |
 | [go-libp2p](https://github.com/libp2p/go-libp2p)          | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square)     |
 | [gost](https://github.com/go-gost/gost)                   | A simple security tunnel written in Go                                                                                                                        | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square)            |
 | [Hysteria](https://github.com/apernet/hysteria)           | A powerful, lightning fast and censorship resistant proxy                                                                                                         | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square)        |
 | [Mercure](https://github.com/dunglas/mercure)             | An open, easy, fast, reliable and battery-efficient solution for real-time communications                                                                         | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square)         |
 | [OONI Probe](https://github.com/ooni/probe-cli)           | Next generation OONI Probe. Library and CLI tool.                                                                                                                 | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square)          |
+| [reverst](https://github.com/flipt-io/reverst)            | Reverse Tunnels in Go over HTTP/3 and QUIC                                                                                                                        | ![GitHub Repo stars](https://img.shields.io/github/stars/flipt-io/reverst?style=flat-square) |
 | [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
 | [syncthing](https://github.com/syncthing/syncthing/)      | Open Source Continuous File Synchronization                                                                                                                       | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square)     |
 | [traefik](https://github.com/traefik/traefik)             | The Cloud Native Application Proxy                                                                                                                                | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square)         |

+ 1 - 149
vendor/github.com/Psiphon-Labs/quic-go/client.go

@@ -2,44 +2,13 @@ package quic
 
 import (
 	"context"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"errors"
 	"net"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
-	"github.com/Psiphon-Labs/quic-go/internal/utils"
-	"github.com/Psiphon-Labs/quic-go/logging"
 )
 
-type client struct {
-	sendConn sendConn
-
-	use0RTT bool
-
-	packetHandlers packetHandlerManager
-	onClose        func()
-
-	tlsConf *tls.Config
-	config  *Config
-
-	connIDGenerator ConnectionIDGenerator
-	srcConnID       protocol.ConnectionID
-	destConnID      protocol.ConnectionID
-
-	initialPacketNumber  protocol.PacketNumber
-	hasNegotiatedVersion bool
-	version              protocol.VersionNumber
-
-	handshakeChan chan struct{}
-
-	conn quicConn
-
-	tracer    *logging.ConnectionTracer
-	tracingID uint64
-	logger    utils.Logger
-}
-
 // make it possible to mock connection ID for initial generation in the tests
 var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
 
@@ -133,120 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
 		isSingleUse: true,
 	}, nil
 }
-
-func dial(
-	ctx context.Context,
-	conn sendConn,
-	connIDGenerator ConnectionIDGenerator,
-	packetHandlers packetHandlerManager,
-	tlsConf *tls.Config,
-	config *Config,
-	onClose func(),
-	use0RTT bool,
-) (quicConn, error) {
-	c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
-	if err != nil {
-		return nil, err
-	}
-	c.packetHandlers = packetHandlers
-
-	c.tracingID = nextConnTracingID()
-	if c.config.Tracer != nil {
-		c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
-	}
-	if c.tracer != nil && c.tracer.StartedConnection != nil {
-		c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
-	}
-	if err := c.dial(ctx); err != nil {
-		return nil, err
-	}
-	return c.conn, nil
-}
-
-func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
-	srcConnID, err := connIDGenerator.GenerateConnectionID()
-	if err != nil {
-		return nil, err
-	}
-	destConnID, err := generateConnectionIDForInitial()
-	if err != nil {
-		return nil, err
-	}
-	c := &client{
-		connIDGenerator: connIDGenerator,
-		srcConnID:       srcConnID,
-		destConnID:      destConnID,
-		sendConn:        sendConn,
-		use0RTT:         use0RTT,
-		onClose:         onClose,
-		tlsConf:         tlsConf,
-		config:          config,
-		version:         config.Versions[0],
-		handshakeChan:   make(chan struct{}),
-		logger:          utils.DefaultLogger.WithPrefix("client"),
-	}
-	return c, nil
-}
-
-func (c *client) dial(ctx context.Context) error {
-	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
-
-	c.conn = newClientConnection(
-		c.sendConn,
-		c.packetHandlers,
-		c.destConnID,
-		c.srcConnID,
-		c.connIDGenerator,
-		c.config,
-		c.tlsConf,
-		c.initialPacketNumber,
-		c.use0RTT,
-		c.hasNegotiatedVersion,
-		c.tracer,
-		c.tracingID,
-		c.logger,
-		c.version,
-	)
-	c.packetHandlers.Add(c.srcConnID, c.conn)
-
-	errorChan := make(chan error, 1)
-	recreateChan := make(chan errCloseForRecreating)
-	go func() {
-		err := c.conn.run()
-		var recreateErr *errCloseForRecreating
-		if errors.As(err, &recreateErr) {
-			recreateChan <- *recreateErr
-			return
-		}
-		if c.onClose != nil {
-			c.onClose()
-		}
-		errorChan <- err // returns as soon as the connection is closed
-	}()
-
-	// only set when we're using 0-RTT
-	// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
-	var earlyConnChan <-chan struct{}
-	if c.use0RTT {
-		earlyConnChan = c.conn.earlyConnReady()
-	}
-
-	select {
-	case <-ctx.Done():
-		c.conn.shutdown()
-		return context.Cause(ctx)
-	case err := <-errorChan:
-		return err
-	case recreateErr := <-recreateChan:
-		c.initialPacketNumber = recreateErr.nextPacketNumber
-		c.version = recreateErr.nextVersion
-		c.hasNegotiatedVersion = true
-		return c.dial(ctx)
-	case <-earlyConnChan:
-		// ready to send 0-RTT data
-		return nil
-	case <-c.conn.HandshakeComplete():
-		// handshake successfully completed
-		return nil
-	}
-}

+ 17 - 23
vendor/github.com/Psiphon-Labs/quic-go/closed_conn.go

@@ -3,8 +3,8 @@ package quic
 import (
 	"math/bits"
 	"net"
+	"sync/atomic"
 
-	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/utils"
 )
 
@@ -12,9 +12,8 @@ import (
 // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
 // with an exponential backoff.
 type closedLocalConn struct {
-	counter     uint32
-	perspective protocol.Perspective
-	logger      utils.Logger
+	counter atomic.Uint32
+	logger  utils.Logger
 
 	sendPacket func(net.Addr, packetInfo)
 }
@@ -22,43 +21,38 @@ type closedLocalConn struct {
 var _ packetHandler = &closedLocalConn{}
 
 // newClosedLocalConn creates a new closedLocalConn and runs it.
-func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
+func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
 	return &closedLocalConn{
-		sendPacket:  sendPacket,
-		perspective: pers,
-		logger:      logger,
+		sendPacket: sendPacket,
+		logger:     logger,
 	}
 }
 
 func (c *closedLocalConn) handlePacket(p receivedPacket) {
-	c.counter++
+	n := c.counter.Add(1)
 	// exponential backoff
 	// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
-	if bits.OnesCount32(c.counter) != 1 {
+	if bits.OnesCount32(n) != 1 {
 		return
 	}
-	c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
+	c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", n)
 	c.sendPacket(p.remoteAddr, p.info)
 }
 
-func (c *closedLocalConn) shutdown()                            {}
-func (c *closedLocalConn) destroy(error)                        {}
-func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
+func (c *closedLocalConn) destroy(error)                              {}
+func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
 
 // A closedRemoteConn is a connection that was closed remotely.
 // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
 // We can just ignore those packets.
-type closedRemoteConn struct {
-	perspective protocol.Perspective
-}
+type closedRemoteConn struct{}
 
 var _ packetHandler = &closedRemoteConn{}
 
-func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
-	return &closedRemoteConn{perspective: pers}
+func newClosedRemoteConn() packetHandler {
+	return &closedRemoteConn{}
 }
 
-func (s *closedRemoteConn) handlePacket(receivedPacket)          {}
-func (s *closedRemoteConn) shutdown()                            {}
-func (s *closedRemoteConn) destroy(error)                        {}
-func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
+func (c *closedRemoteConn) handlePacket(receivedPacket)                {}
+func (c *closedRemoteConn) destroy(error)                              {}
+func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}

+ 4 - 0
vendor/github.com/Psiphon-Labs/quic-go/codecov.yml

@@ -5,6 +5,10 @@ coverage:
     - interop/
     - internal/handshake/cipher_suite.go
     - internal/utils/linkedlist/linkedlist.go
+    - internal/testdata
+    - logging/connection_tracer_multiplexer.go
+    - logging/tracer_multiplexer.go
+    - testutils/
     - fuzzing/
     - metrics/
   status:

+ 11 - 12
vendor/github.com/Psiphon-Labs/quic-go/config.go

@@ -2,7 +2,6 @@ package quic
 
 import (
 	"fmt"
-	"net"
 	"time"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
@@ -40,6 +39,12 @@ func validateConfig(config *Config) error {
 	if config.MaxConnectionReceiveWindow > quicvarint.Max {
 		config.MaxConnectionReceiveWindow = quicvarint.Max
 	}
+	if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
+		config.InitialPacketSize = protocol.MinInitialPacketSize
+	}
+	if config.InitialPacketSize > protocol.MaxPacketBufferSize {
+		config.InitialPacketSize = protocol.MaxPacketBufferSize
+	}
 	// check that all QUIC versions are actually supported
 	for _, v := range config.Versions {
 		if !protocol.IsValidVersion(v) {
@@ -49,16 +54,6 @@ func validateConfig(config *Config) error {
 	return nil
 }
 
-// populateServerConfig populates fields in the quic.Config with their default values, if none are set
-// it may be called with nil
-func populateServerConfig(config *Config) *Config {
-	config = populateConfig(config)
-	if config.RequireAddressValidation == nil {
-		config.RequireAddressValidation = func(net.Addr) bool { return false }
-	}
-	return config
-}
-
 // populateConfig populates fields in the quic.Config with their default values, if none are set
 // it may be called with nil
 func populateConfig(config *Config) *Config {
@@ -105,13 +100,16 @@ func populateConfig(config *Config) *Config {
 	} else if maxIncomingUniStreams < 0 {
 		maxIncomingUniStreams = 0
 	}
+	initialPacketSize := config.InitialPacketSize
+	if initialPacketSize == 0 {
+		initialPacketSize = protocol.InitialPacketSize
+	}
 
 	return &Config{
 		GetConfigForClient:             config.GetConfigForClient,
 		Versions:                       versions,
 		HandshakeIdleTimeout:           handshakeIdleTimeout,
 		MaxIdleTimeout:                 idleTimeout,
-		RequireAddressValidation:       config.RequireAddressValidation,
 		KeepAlivePeriod:                config.KeepAlivePeriod,
 		InitialStreamReceiveWindow:     initialStreamReceiveWindow,
 		MaxStreamReceiveWindow:         maxStreamReceiveWindow,
@@ -122,6 +120,7 @@ func populateConfig(config *Config) *Config {
 		MaxIncomingUniStreams:          maxIncomingUniStreams,
 		TokenStore:                     config.TokenStore,
 		EnableDatagrams:                config.EnableDatagrams,
+		InitialPacketSize:              initialPacketSize,
 		DisablePathMTUDiscovery:        config.DisablePathMTUDiscovery,
 		Allow0RTT:                      config.Allow0RTT,
 		Tracer:                         config.Tracer,

+ 19 - 19
vendor/github.com/Psiphon-Labs/quic-go/conn_id_generator.go

@@ -15,34 +15,34 @@ type connIDGenerator struct {
 	activeSrcConnIDs        map[uint64]protocol.ConnectionID
 	initialClientDestConnID *protocol.ConnectionID // nil for the client
 
-	addConnectionID        func(protocol.ConnectionID)
-	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
-	removeConnectionID     func(protocol.ConnectionID)
-	retireConnectionID     func(protocol.ConnectionID)
-	replaceWithClosed      func([]protocol.ConnectionID, protocol.Perspective, []byte)
-	queueControlFrame      func(wire.Frame)
+	addConnectionID    func(protocol.ConnectionID)
+	statelessResetter  *statelessResetter
+	removeConnectionID func(protocol.ConnectionID)
+	retireConnectionID func(protocol.ConnectionID)
+	replaceWithClosed  func([]protocol.ConnectionID, []byte)
+	queueControlFrame  func(wire.Frame)
 }
 
 func newConnIDGenerator(
 	initialConnectionID protocol.ConnectionID,
 	initialClientDestConnID *protocol.ConnectionID, // nil for the client
 	addConnectionID func(protocol.ConnectionID),
-	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
+	statelessResetter *statelessResetter,
 	removeConnectionID func(protocol.ConnectionID),
 	retireConnectionID func(protocol.ConnectionID),
-	replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
+	replaceWithClosed func([]protocol.ConnectionID, []byte),
 	queueControlFrame func(wire.Frame),
 	generator ConnectionIDGenerator,
 ) *connIDGenerator {
 	m := &connIDGenerator{
-		generator:              generator,
-		activeSrcConnIDs:       make(map[uint64]protocol.ConnectionID),
-		addConnectionID:        addConnectionID,
-		getStatelessResetToken: getStatelessResetToken,
-		removeConnectionID:     removeConnectionID,
-		retireConnectionID:     retireConnectionID,
-		replaceWithClosed:      replaceWithClosed,
-		queueControlFrame:      queueControlFrame,
+		generator:          generator,
+		activeSrcConnIDs:   make(map[uint64]protocol.ConnectionID),
+		addConnectionID:    addConnectionID,
+		statelessResetter:  statelessResetter,
+		removeConnectionID: removeConnectionID,
+		retireConnectionID: retireConnectionID,
+		replaceWithClosed:  replaceWithClosed,
+		queueControlFrame:  queueControlFrame,
 	}
 	m.activeSrcConnIDs[0] = initialConnectionID
 	m.initialClientDestConnID = initialClientDestConnID
@@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error {
 	m.queueControlFrame(&wire.NewConnectionIDFrame{
 		SequenceNumber:      m.highestSeq + 1,
 		ConnectionID:        connID,
-		StatelessResetToken: m.getStatelessResetToken(connID),
+		StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
 	})
 	m.highestSeq++
 	return nil
@@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() {
 	}
 }
 
-func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
+func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
 	connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
 	if m.initialClientDestConnID != nil {
 		connIDs = append(connIDs, *m.initialClientDestConnID)
@@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose
 	for _, connID := range m.activeSrcConnIDs {
 		connIDs = append(connIDs, connID)
 	}
-	m.replaceWithClosed(connIDs, pers, connClose)
+	m.replaceWithClosed(connIDs, connClose)
 }

+ 22 - 0
vendor/github.com/Psiphon-Labs/quic-go/conn_id_manager.go

@@ -35,6 +35,8 @@ type connIDManager struct {
 	addStatelessResetToken    func(protocol.StatelessResetToken)
 	removeStatelessResetToken func(protocol.StatelessResetToken)
 	queueControlFrame         func(wire.Frame)
+
+	closed bool
 }
 
 func newConnIDManager(
@@ -66,6 +68,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
 }
 
 func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
+	if h.activeConnectionID.Len() == 0 {
+		return &qerr.TransportError{
+			ErrorCode:    qerr.ProtocolViolation,
+			ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use",
+		}
+	}
 	// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
 	// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
 	if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
@@ -142,6 +150,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID
 }
 
 func (h *connIDManager) updateConnectionID() {
+	h.assertNotClosed()
 	h.queueControlFrame(&wire.RetireConnectionIDFrame{
 		SequenceNumber: h.activeSequenceNumber,
 	})
@@ -160,6 +169,7 @@ func (h *connIDManager) updateConnectionID() {
 }
 
 func (h *connIDManager) Close() {
+	h.closed = true
 	if h.activeStatelessResetToken != nil {
 		h.removeStatelessResetToken(*h.activeStatelessResetToken)
 	}
@@ -176,6 +186,7 @@ func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
 
 // is called when the server provides a stateless reset token in the transport parameters
 func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
+	h.assertNotClosed()
 	if h.activeSequenceNumber != 0 {
 		panic("expected first connection ID to have sequence number 0")
 	}
@@ -203,6 +214,7 @@ func (h *connIDManager) shouldUpdateConnID() bool {
 }
 
 func (h *connIDManager) Get() protocol.ConnectionID {
+	h.assertNotClosed()
 	if h.shouldUpdateConnID() {
 		h.updateConnectionID()
 	}
@@ -212,3 +224,13 @@ func (h *connIDManager) Get() protocol.ConnectionID {
 func (h *connIDManager) SetHandshakeComplete() {
 	h.handshakeComplete = true
 }
+
+// Using the connIDManager after it has been closed can have disastrous effects:
+// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
+// leading to a memory leak of the connection struct.
+// See https://github.com/quic-go/quic-go/pull/4852 for more details.
+func (h *connIDManager) assertNotClosed() {
+	if h.closed {
+		panic("connection ID manager is closed")
+	}
+}

File diff suppressed because it is too large
+ 216 - 186
vendor/github.com/Psiphon-Labs/quic-go/connection.go


+ 168 - 0
vendor/github.com/Psiphon-Labs/quic-go/connection_logging.go

@@ -0,0 +1,168 @@
+package quic
+
+import (
+	"slices"
+
+	"github.com/Psiphon-Labs/quic-go/internal/ackhandler"
+	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+	"github.com/Psiphon-Labs/quic-go/internal/wire"
+	"github.com/Psiphon-Labs/quic-go/logging"
+)
+
+// ConvertFrame converts a wire.Frame into a logging.Frame.
+// This makes it possible for external packages to access the frames.
+// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
+func toLoggingFrame(frame wire.Frame) logging.Frame {
+	switch f := frame.(type) {
+	case *wire.AckFrame:
+		// We use a pool for ACK frames.
+		// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
+		return toLoggingAckFrame(f)
+	case *wire.CryptoFrame:
+		return &logging.CryptoFrame{
+			Offset: f.Offset,
+			Length: protocol.ByteCount(len(f.Data)),
+		}
+	case *wire.StreamFrame:
+		return &logging.StreamFrame{
+			StreamID: f.StreamID,
+			Offset:   f.Offset,
+			Length:   f.DataLen(),
+			Fin:      f.Fin,
+		}
+	case *wire.DatagramFrame:
+		return &logging.DatagramFrame{
+			Length: logging.ByteCount(len(f.Data)),
+		}
+	default:
+		return logging.Frame(frame)
+	}
+}
+
+func toLoggingAckFrame(f *wire.AckFrame) *logging.AckFrame {
+	ack := &logging.AckFrame{
+		AckRanges: slices.Clone(f.AckRanges),
+		DelayTime: f.DelayTime,
+		ECNCE:     f.ECNCE,
+		ECT0:      f.ECT0,
+		ECT1:      f.ECT1,
+	}
+	return ack
+}
+
+func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
+	// quic-go logging
+	if s.logger.Debug() {
+		p.header.Log(s.logger)
+		if p.ack != nil {
+			wire.LogFrame(s.logger, p.ack, true)
+		}
+		for _, frame := range p.frames {
+			wire.LogFrame(s.logger, frame.Frame, true)
+		}
+		for _, frame := range p.streamFrames {
+			wire.LogFrame(s.logger, frame.Frame, true)
+		}
+	}
+
+	// tracing
+	if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
+		frames := make([]logging.Frame, 0, len(p.frames))
+		for _, f := range p.frames {
+			frames = append(frames, toLoggingFrame(f.Frame))
+		}
+		for _, f := range p.streamFrames {
+			frames = append(frames, toLoggingFrame(f.Frame))
+		}
+		var ack *logging.AckFrame
+		if p.ack != nil {
+			ack = toLoggingAckFrame(p.ack)
+		}
+		s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
+	}
+}
+
+func (s *connection) logShortHeaderPacket(
+	destConnID protocol.ConnectionID,
+	ackFrame *wire.AckFrame,
+	frames []ackhandler.Frame,
+	streamFrames []ackhandler.StreamFrame,
+	pn protocol.PacketNumber,
+	pnLen protocol.PacketNumberLen,
+	kp protocol.KeyPhaseBit,
+	ecn protocol.ECN,
+	size protocol.ByteCount,
+	isCoalesced bool,
+) {
+	if s.logger.Debug() && !isCoalesced {
+		s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
+	}
+	// quic-go logging
+	if s.logger.Debug() {
+		wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
+		if ackFrame != nil {
+			wire.LogFrame(s.logger, ackFrame, true)
+		}
+		for _, f := range frames {
+			wire.LogFrame(s.logger, f.Frame, true)
+		}
+		for _, f := range streamFrames {
+			wire.LogFrame(s.logger, f.Frame, true)
+		}
+	}
+
+	// tracing
+	if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil {
+		fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
+		for _, f := range frames {
+			fs = append(fs, toLoggingFrame(f.Frame))
+		}
+		for _, f := range streamFrames {
+			fs = append(fs, toLoggingFrame(f.Frame))
+		}
+		var ack *logging.AckFrame
+		if ackFrame != nil {
+			ack = toLoggingAckFrame(ackFrame)
+		}
+		s.tracer.SentShortHeaderPacket(
+			&logging.ShortHeader{DestConnectionID: destConnID, PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp},
+			size,
+			ecn,
+			ack,
+			fs,
+		)
+	}
+}
+
+func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
+	if s.logger.Debug() {
+		// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
+		// during which we might call PackCoalescedPacket but just pack a short header packet.
+		if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
+			s.logShortHeaderPacket(
+				packet.shortHdrPacket.DestConnID,
+				packet.shortHdrPacket.Ack,
+				packet.shortHdrPacket.Frames,
+				packet.shortHdrPacket.StreamFrames,
+				packet.shortHdrPacket.PacketNumber,
+				packet.shortHdrPacket.PacketNumberLen,
+				packet.shortHdrPacket.KeyPhase,
+				ecn,
+				packet.shortHdrPacket.Length,
+				false,
+			)
+			return
+		}
+		if len(packet.longHdrPackets) > 1 {
+			s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
+		} else {
+			s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
+		}
+	}
+	for _, p := range packet.longHdrPackets {
+		s.logLongHeaderPacket(p, ecn)
+	}
+	if p := packet.shortHdrPacket; p != nil {
+		s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
+	}
+}

+ 13 - 36
vendor/github.com/Psiphon-Labs/quic-go/crypto_stream.go

@@ -2,27 +2,14 @@ package quic
 
 import (
 	"fmt"
-	"io"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/qerr"
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 )
 
-type cryptoStream interface {
-	// for receiving data
-	HandleCryptoFrame(*wire.CryptoFrame) error
-	GetCryptoData() []byte
-	Finish() error
-	// for sending data
-	io.Writer
-	HasData() bool
-	PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
-}
-
-type cryptoStreamImpl struct {
-	queue  *frameSorter
-	msgBuf []byte
+type cryptoStream struct {
+	queue frameSorter
 
 	highestOffset protocol.ByteCount
 	finished      bool
@@ -31,11 +18,11 @@ type cryptoStreamImpl struct {
 	writeBuf    []byte
 }
 
-func newCryptoStream() cryptoStream {
-	return &cryptoStreamImpl{queue: newFrameSorter()}
+func newCryptoStream() *cryptoStream {
+	return &cryptoStream{queue: *newFrameSorter()}
 }
 
-func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
+func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
 	highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
 	if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
 		return &qerr.TransportError{
@@ -56,26 +43,16 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
 		return nil
 	}
 	s.highestOffset = max(s.highestOffset, highestOffset)
-	if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
-		return err
-	}
-	for {
-		_, data, _ := s.queue.Pop()
-		if data == nil {
-			return nil
-		}
-		s.msgBuf = append(s.msgBuf, data...)
-	}
+	return s.queue.Push(f.Data, f.Offset, nil)
 }
 
 // GetCryptoData retrieves data that was received in CRYPTO frames
-func (s *cryptoStreamImpl) GetCryptoData() []byte {
-	b := s.msgBuf
-	s.msgBuf = nil
-	return b
+func (s *cryptoStream) GetCryptoData() []byte {
+	_, data, _ := s.queue.Pop()
+	return data
 }
 
-func (s *cryptoStreamImpl) Finish() error {
+func (s *cryptoStream) Finish() error {
 	if s.queue.HasMoreData() {
 		return &qerr.TransportError{
 			ErrorCode:    qerr.ProtocolViolation,
@@ -87,16 +64,16 @@ func (s *cryptoStreamImpl) Finish() error {
 }
 
 // Writes writes data that should be sent out in CRYPTO frames
-func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
+func (s *cryptoStream) Write(p []byte) (int, error) {
 	s.writeBuf = append(s.writeBuf, p...)
 	return len(p), nil
 }
 
-func (s *cryptoStreamImpl) HasData() bool {
+func (s *cryptoStream) HasData() bool {
 	return len(s.writeBuf) > 0
 }
 
-func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
+func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
 	f := &wire.CryptoFrame{Offset: s.writeOffset}
 	n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
 	f.Data = s.writeBuf[:n]

+ 23 - 28
vendor/github.com/Psiphon-Labs/quic-go/crypto_stream_manager.go

@@ -3,32 +3,22 @@ package quic
 import (
 	"fmt"
 
-	"github.com/Psiphon-Labs/quic-go/internal/handshake"
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 )
 
-type cryptoDataHandler interface {
-	HandleMessage([]byte, protocol.EncryptionLevel) error
-	NextEvent() handshake.Event
-}
-
 type cryptoStreamManager struct {
-	cryptoHandler cryptoDataHandler
-
-	initialStream   cryptoStream
-	handshakeStream cryptoStream
-	oneRTTStream    cryptoStream
+	initialStream   *cryptoStream
+	handshakeStream *cryptoStream
+	oneRTTStream    *cryptoStream
 }
 
 func newCryptoStreamManager(
-	cryptoHandler cryptoDataHandler,
-	initialStream cryptoStream,
-	handshakeStream cryptoStream,
-	oneRTTStream cryptoStream,
+	initialStream *cryptoStream,
+	handshakeStream *cryptoStream,
+	oneRTTStream *cryptoStream,
 ) *cryptoStreamManager {
 	return &cryptoStreamManager{
-		cryptoHandler:   cryptoHandler,
 		initialStream:   initialStream,
 		handshakeStream: handshakeStream,
 		oneRTTStream:    oneRTTStream,
@@ -36,7 +26,7 @@ func newCryptoStreamManager(
 }
 
 func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
-	var str cryptoStream
+	var str *cryptoStream
 	//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
 	switch encLevel {
 	case protocol.EncryptionInitial:
@@ -48,18 +38,23 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
 	default:
 		return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
 	}
-	if err := str.HandleCryptoFrame(frame); err != nil {
-		return err
-	}
-	for {
-		data := str.GetCryptoData()
-		if data == nil {
-			return nil
-		}
-		if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
-			return err
-		}
+	return str.HandleCryptoFrame(frame)
+}
+
+func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
+	var str *cryptoStream
+	//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
+	switch encLevel {
+	case protocol.EncryptionInitial:
+		str = m.initialStream
+	case protocol.EncryptionHandshake:
+		str = m.handshakeStream
+	case protocol.Encryption1RTT:
+		str = m.oneRTTStream
+	default:
+		panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel))
 	}
+	return str.GetCryptoData()
 }
 
 func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {

+ 5 - 5
vendor/github.com/Psiphon-Labs/quic-go/errors.go

@@ -50,8 +50,8 @@ type StreamError struct {
 }
 
 func (e *StreamError) Is(target error) bool {
-	_, ok := target.(*StreamError)
-	return ok
+	t, ok := target.(*StreamError)
+	return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
 }
 
 func (e *StreamError) Error() string {
@@ -64,12 +64,12 @@ func (e *StreamError) Error() string {
 
 // DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent.
 type DatagramTooLargeError struct {
-	PeerMaxDatagramFrameSize int64
+	MaxDatagramPayloadSize int64
 }
 
 func (e *DatagramTooLargeError) Is(target error) bool {
-	_, ok := target.(*DatagramTooLargeError)
-	return ok
+	t, ok := target.(*DatagramTooLargeError)
+	return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize
 }
 
 func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }

+ 185 - 84
vendor/github.com/Psiphon-Labs/quic-go/framer.go

@@ -1,53 +1,54 @@
 package quic
 
 import (
-	"errors"
+	"slices"
 	"sync"
+	"time"
 
 	"github.com/Psiphon-Labs/quic-go/internal/ackhandler"
+	"github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/utils/ringbuffer"
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 	"github.com/Psiphon-Labs/quic-go/quicvarint"
 )
 
-type framer interface {
-	HasData() bool
-
-	QueueControlFrame(wire.Frame)
-	AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
+const (
+	maxPathResponses = 256
+	maxControlFrames = 16 << 10
+)
 
-	AddActiveStream(protocol.StreamID)
-	AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
+// This is the largest possible size of a stream-related control frame
+// (which is the RESET_STREAM frame).
+const maxStreamControlFrameSize = 25
 
-	Handle0RTTRejection() error
+type streamControlFrameGetter interface {
+	getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool)
 }
 
-const maxPathResponses = 256
-
-type framerI struct {
+type framer struct {
 	mutex sync.Mutex
 
-	streamGetter streamGetter
+	activeStreams            map[protocol.StreamID]sendStreamI
+	streamQueue              ringbuffer.RingBuffer[protocol.StreamID]
+	streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
 
-	activeStreams map[protocol.StreamID]struct{}
-	streamQueue   ringbuffer.RingBuffer[protocol.StreamID]
-
-	controlFrameMutex sync.Mutex
-	controlFrames     []wire.Frame
-	pathResponses     []*wire.PathResponseFrame
+	controlFrameMutex          sync.Mutex
+	controlFrames              []wire.Frame
+	pathResponses              []*wire.PathResponseFrame
+	connFlowController         flowcontrol.ConnectionFlowController
+	queuedTooManyControlFrames bool
 }
 
-var _ framer = &framerI{}
-
-func newFramer(streamGetter streamGetter) framer {
-	return &framerI{
-		streamGetter:  streamGetter,
-		activeStreams: make(map[protocol.StreamID]struct{}),
+func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer {
+	return &framer{
+		activeStreams:            make(map[protocol.StreamID]sendStreamI),
+		streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
+		connFlowController:       connFlowController,
 	}
 }
 
-func (f *framerI) HasData() bool {
+func (f *framer) HasData() bool {
 	f.mutex.Lock()
 	hasData := !f.streamQueue.Empty()
 	f.mutex.Unlock()
@@ -56,10 +57,10 @@ func (f *framerI) HasData() bool {
 	}
 	f.controlFrameMutex.Lock()
 	defer f.controlFrameMutex.Unlock()
-	return len(f.controlFrames) > 0 || len(f.pathResponses) > 0
+	return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0
 }
 
-func (f *framerI) QueueControlFrame(frame wire.Frame) {
+func (f *framer) QueueControlFrame(frame wire.Frame) {
 	f.controlFrameMutex.Lock()
 	defer f.controlFrameMutex.Unlock()
 
@@ -73,13 +74,88 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
 		f.pathResponses = append(f.pathResponses, pr)
 		return
 	}
+	// This is a hack.
+	if len(f.controlFrames) >= maxControlFrames {
+		f.queuedTooManyControlFrames = true
+		return
+	}
 	f.controlFrames = append(f.controlFrames, frame)
 }
 
-func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
+func (f *framer) Append(
+	frames []ackhandler.Frame,
+	streamFrames []ackhandler.StreamFrame,
+	maxLen protocol.ByteCount,
+	now time.Time,
+	v protocol.Version,
+) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) {
 	f.controlFrameMutex.Lock()
-	defer f.controlFrameMutex.Unlock()
+	frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v)
+	maxLen -= controlFrameLen
+
+	var lastFrame ackhandler.StreamFrame
+	var streamFrameLen protocol.ByteCount
+	f.mutex.Lock()
+	// pop STREAM frames, until less than 128 bytes are left in the packet
+	numActiveStreams := f.streamQueue.Len()
+	for i := 0; i < numActiveStreams; i++ {
+		if protocol.MinStreamFrameSize > maxLen {
+			break
+		}
+		sf, blocked := f.getNextStreamFrame(maxLen, v)
+		if sf.Frame != nil {
+			streamFrames = append(streamFrames, sf)
+			maxLen -= sf.Frame.Length(v)
+			lastFrame = sf
+			streamFrameLen += sf.Frame.Length(v)
+		}
+		// If the stream just became blocked on stream flow control, attempt to pack the
+		// STREAM_DATA_BLOCKED into the same packet.
+		if blocked != nil {
+			l := blocked.Length(v)
+			// In case it doesn't fit, queue it for the next packet.
+			if maxLen < l {
+				f.controlFrames = append(f.controlFrames, blocked)
+				break
+			}
+			frames = append(frames, ackhandler.Frame{Frame: blocked})
+			maxLen -= l
+			controlFrameLen += l
+		}
+	}
 
+	// The only way to become blocked on connection-level flow control is by sending STREAM frames.
+	if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked {
+		blocked := &wire.DataBlockedFrame{MaximumData: offset}
+		l := blocked.Length(v)
+		// In case it doesn't fit, queue it for the next packet.
+		if maxLen >= l {
+			frames = append(frames, ackhandler.Frame{Frame: blocked})
+			controlFrameLen += l
+		} else {
+			f.controlFrames = append(f.controlFrames, blocked)
+		}
+	}
+
+	f.mutex.Unlock()
+	f.controlFrameMutex.Unlock()
+
+	if lastFrame.Frame != nil {
+		// account for the smaller size of the last STREAM frame
+		streamFrameLen -= lastFrame.Frame.Length(v)
+		lastFrame.Frame.DataLenPresent = false
+		streamFrameLen += lastFrame.Frame.Length(v)
+	}
+
+	return frames, streamFrames, controlFrameLen + streamFrameLen
+}
+
+func (f *framer) appendControlFrames(
+	frames []ackhandler.Frame,
+	maxLen protocol.ByteCount,
+	now time.Time,
+	v protocol.Version,
+) ([]ackhandler.Frame, protocol.ByteCount) {
 	var length protocol.ByteCount
 	// add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet
 	if len(f.pathResponses) > 0 {
@@ -92,6 +168,29 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
 		}
 	}
 
+	// add stream-related control frames
+	for id, str := range f.streamsWithControlFrames {
+	start:
+		remainingLen := maxLen - length
+		if remainingLen <= maxStreamControlFrameSize {
+			break
+		}
+		fr, ok, hasMore := str.getControlFrame(now)
+		if !hasMore {
+			delete(f.streamsWithControlFrames, id)
+		}
+		if !ok {
+			continue
+		}
+		frames = append(frames, fr)
+		length += fr.Frame.Length(v)
+		if hasMore {
+			// It is rare that a stream has more than one control frame to queue.
+			// We don't want to spawn another loop for just to cover that case.
+			goto start
+		}
+	}
+
 	for len(f.controlFrames) > 0 {
 		frame := f.controlFrames[len(f.controlFrames)-1]
 		frameLen := frame.Length(v)
@@ -102,72 +201,77 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
 		length += frameLen
 		f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
 	}
+
 	return frames, length
 }
 
-func (f *framerI) AddActiveStream(id protocol.StreamID) {
+// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
+// This is a hack.
+// It is easier to implement than propagating an error return value in QueueControlFrame.
+// The correct solution would be to queue frames with their respective structs.
+// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
+func (f *framer) QueuedTooManyControlFrames() bool {
+	return f.queuedTooManyControlFrames
+}
+
+func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
 	f.mutex.Lock()
 	if _, ok := f.activeStreams[id]; !ok {
 		f.streamQueue.PushBack(id)
-		f.activeStreams[id] = struct{}{}
+		f.activeStreams[id] = str
 	}
 	f.mutex.Unlock()
 }
 
-func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
-	startLen := len(frames)
-	var length protocol.ByteCount
-	f.mutex.Lock()
-	// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
-	numActiveStreams := f.streamQueue.Len()
-	for i := 0; i < numActiveStreams; i++ {
-		if protocol.MinStreamFrameSize+length > maxLen {
-			break
-		}
-		id := f.streamQueue.PopFront()
-		// This should never return an error. Better check it anyway.
-		// The stream will only be in the streamQueue, if it enqueued itself there.
-		str, err := f.streamGetter.GetOrOpenSendStream(id)
-		// The stream can be nil if it completed after it said it had data.
-		if str == nil || err != nil {
-			delete(f.activeStreams, id)
-			continue
-		}
-		remainingLen := maxLen - length
-		// For the last STREAM frame, we'll remove the DataLen field later.
-		// Therefore, we can pretend to have more bytes available when popping
-		// the STREAM frame (which will always have the DataLen set).
-		remainingLen += quicvarint.Len(uint64(remainingLen))
-		frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
-		if hasMoreData { // put the stream back in the queue (at the end)
-			f.streamQueue.PushBack(id)
-		} else { // no more data to send. Stream is not active
-			delete(f.activeStreams, id)
-		}
-		// The frame can be "nil"
-		// * if the receiveStream was canceled after it said it had data
-		// * the remaining size doesn't allow us to add another STREAM frame
-		if !ok {
-			continue
-		}
-		frames = append(frames, frame)
-		length += frame.Frame.Length(v)
+func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) {
+	f.controlFrameMutex.Lock()
+	if _, ok := f.streamsWithControlFrames[id]; !ok {
+		f.streamsWithControlFrames[id] = str
 	}
+	f.controlFrameMutex.Unlock()
+}
+
+// RemoveActiveStream is called when a stream completes.
+func (f *framer) RemoveActiveStream(id protocol.StreamID) {
+	f.mutex.Lock()
+	delete(f.activeStreams, id)
+	// We don't delete the stream from the streamQueue,
+	// since we'd have to iterate over the ringbuffer.
+	// Instead, we check if the stream is still in activeStreams when appending STREAM frames.
 	f.mutex.Unlock()
-	if len(frames) > startLen {
-		l := frames[len(frames)-1].Frame.Length(v)
-		// account for the smaller size of the last STREAM frame
-		frames[len(frames)-1].Frame.DataLenPresent = false
-		length += frames[len(frames)-1].Frame.Length(v) - l
+}
+
+func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) {
+	id := f.streamQueue.PopFront()
+	// This should never return an error. Better check it anyway.
+	// The stream will only be in the streamQueue, if it enqueued itself there.
+	str, ok := f.activeStreams[id]
+	// The stream might have been removed after being enqueued.
+	if !ok {
+		return ackhandler.StreamFrame{}, nil
 	}
-	return frames, length
+	// For the last STREAM frame, we'll remove the DataLen field later.
+	// Therefore, we can pretend to have more bytes available when popping
+	// the STREAM frame (which will always have the DataLen set).
+	maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen)))
+	frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v)
+	if hasMoreData { // put the stream back in the queue (at the end)
+		f.streamQueue.PushBack(id)
+	} else { // no more data to send. Stream is not active
+		delete(f.activeStreams, id)
+	}
+	// Note that the frame.Frame can be nil:
+	// * if the stream was canceled after it said it had data
+	// * the remaining size doesn't allow us to add another STREAM frame
+	return frame, blocked
 }
 
-func (f *framerI) Handle0RTTRejection() error {
+func (f *framer) Handle0RTTRejection() {
 	f.mutex.Lock()
 	defer f.mutex.Unlock()
-
 	f.controlFrameMutex.Lock()
+	defer f.controlFrameMutex.Unlock()
+
 	f.streamQueue.Clear()
 	for id := range f.activeStreams {
 		delete(f.activeStreams, id)
@@ -175,16 +279,13 @@ func (f *framerI) Handle0RTTRejection() error {
 	var j int
 	for i, frame := range f.controlFrames {
 		switch frame.(type) {
-		case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
-			return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
-		case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
+		case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame,
+			*wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
 			continue
 		default:
 			f.controlFrames[j] = f.controlFrames[i]
 			j++
 		}
 	}
-	f.controlFrames = f.controlFrames[:j]
-	f.controlFrameMutex.Unlock()
-	return nil
+	f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames))
 }

+ 3 - 98
vendor/github.com/Psiphon-Labs/quic-go/http3/README.md

@@ -1,104 +1,9 @@
 # HTTP/3
 
+[![Documentation](https://img.shields.io/badge/docs-quic--go.net-red?style=flat)](https://quic-go.net/docs/)
 [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go/http3)](https://pkg.go.dev/github.com/quic-go/quic-go/http3)
 
-This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
+This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
 It aims to provide feature parity with the standard library's HTTP/1.1 and HTTP/2 implementation.
 
-## Serving HTTP/3
-
-The easiest way to start an HTTP/3 server is using
-```go
-mux := http.NewServeMux()
-// ... add HTTP handlers to mux ...
-// If mux is nil, the http.DefaultServeMux is used.
-http3.ListenAndServeQUIC("0.0.0.0:443", "/path/to/cert", "/path/to/key", mux)
-```
-
-`ListenAndServeQUIC` is a convenience function. For more configurability, set up an `http3.Server` explicitly:
-```go
-server := http3.Server{
-	Handler:    mux,
-	Addr:       "0.0.0.0:443",
-	TLSConfig:  http3.ConfigureTLSConfig(&tls.Config{}), // use your tls.Config here
-	QuicConfig: &quic.Config{},
-}
-err := server.ListenAndServe()
-```
-
-The `http3.Server` provides a number of configuration options, please refer to the [documentation](https://pkg.go.dev/github.com/quic-go/quic-go/http3#Server) for a complete list. The `QuicConfig` is used to configure the underlying QUIC connection. More details can be found in the documentation of the QUIC package.
-
-It is also possible to manually set up a `quic.Transport`, and then pass the listener to the server. This is useful when you want to set configuration options on the `quic.Transport`.
-```go
-tr := quic.Transport{Conn: conn}
-tlsConf := http3.ConfigureTLSConfig(&tls.Config{})  // use your tls.Config here
-quicConf := &quic.Config{} // QUIC connection options
-server := http3.Server{}
-ln, _ := tr.ListenEarly(tlsConf, quicConf)
-server.ServeListener(ln)
-```
-
-Alternatively, it is also possible to pass fully established QUIC connections to the HTTP/3 server. This is useful if the QUIC server offers multiple ALPNs (via `NextProtos` in the `tls.Config`).
-```go
-tr := quic.Transport{Conn: conn}
-tlsConf := http3.ConfigureTLSConfig(&tls.Config{})  // use your tls.Config here
-quicConf := &quic.Config{} // QUIC connection options
-server := http3.Server{}
-// alternatively, use tr.ListenEarly to accept 0-RTT connections
-ln, _ := tr.Listen(tlsConf, quicConf)
-for {
-	c, _ := ln.Accept()
-	switch c.ConnectionState().TLS.NegotiatedProtocol {
-	case http3.NextProtoH3:
-		go server.ServeQUICConn(c) 
-        // ... handle other protocols ...  
-	}
-}
-```
-
-## Dialing HTTP/3
-
-This package provides a `http.RoundTripper` implementation that can be used on the `http.Client`:
-
-```go
-&http3.RoundTripper{
-	TLSClientConfig: &tls.Config{},  // set a TLS client config, if desired
-	QuicConfig:      &quic.Config{}, // QUIC connection options
-}
-defer roundTripper.Close()
-client := &http.Client{
-	Transport: roundTripper,
-}
-```
-
-The `http3.RoundTripper` provides a number of configuration options, please refer to the [documentation](https://pkg.go.dev/github.com/quic-go/quic-go/http3#RoundTripper) for a complete list.
-
-To use a custom `quic.Transport`, the function used to dial new QUIC connections can be configured:
-```go
-tr := quic.Transport{}
-roundTripper := &http3.RoundTripper{
-	TLSClientConfig: &tls.Config{},  // set a TLS client config, if desired 
-	QuicConfig:      &quic.Config{}, // QUIC connection options 
-	Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
-		a, err := net.ResolveUDPAddr("udp", addr)
-		if err != nil {
-			return nil, err
-		}
-		return tr.DialEarly(ctx, a, tlsConf, quicConf)
-	},
-}
-```
-
-## Using the same UDP Socket for Server and Roundtripper
-
-Since QUIC demultiplexes packets based on their connection IDs, it is possible allows running a QUIC server and client on the same UDP socket. This also works when using HTTP/3: HTTP requests can be sent from the same socket that a server is listening on.
-
-To achieve this using this package, first initialize a single `quic.Transport`, and pass a `quic.EarlyListner` obtained from that transport to `http3.Server.ServeListener`, and use the `DialEarly` function of the transport as the `Dial` function for the `http3.RoundTripper`.
-
-## QPACK
-
-HTTP/3 utilizes QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) for efficient HTTP header field compression. Our implementation, available at[quic-go/qpack](https://github.com/quic-go/qpack), provides a minimal implementation of the protocol.  
-
-While the current implementation is a fully interoperable implementation of the QPACK protocol, it only uses the static compression table. The dynamic table would allow for more effective compression of frequently transmitted header fields. This can be particularly beneficial in scenarios where headers have considerable redundancy or in high-throughput environments.
-
-If you think that your application would benefit from higher compression efficiency, or if you're interested in contributing improvements here, please let us know in [#2424](https://github.com/quic-go/quic-go/issues/2424).
+Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).

+ 70 - 73
vendor/github.com/Psiphon-Labs/quic-go/http3/body.go

@@ -2,68 +2,68 @@ package http3
 
 import (
 	"context"
+	"errors"
 	"io"
-	"net"
+	"sync"
 
 	"github.com/Psiphon-Labs/quic-go"
 )
 
-// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
-// * for the server: the http.Request.Body
-// * for the client: the http.Response.Body
-// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
-// When a stream is taken over, it's the caller's responsibility to close the stream.
-type HTTPStreamer interface {
-	HTTPStream() Stream
-}
-
-type StreamCreator interface {
-	// Context returns a context that is cancelled when the underlying connection is closed.
-	Context() context.Context
-	OpenStream() (quic.Stream, error)
-	OpenStreamSync(context.Context) (quic.Stream, error)
-	OpenUniStream() (quic.SendStream, error)
-	OpenUniStreamSync(context.Context) (quic.SendStream, error)
-	LocalAddr() net.Addr
-	RemoteAddr() net.Addr
-	ConnectionState() quic.ConnectionState
-}
-
-var _ StreamCreator = quic.Connection(nil)
-
 // A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
 // It is used by WebTransport to create WebTransport streams after a session has been established.
 type Hijacker interface {
-	StreamCreator() StreamCreator
+	Connection() Connection
 }
 
-// The body of a http.Request or http.Response.
+var errTooMuchData = errors.New("peer sent too much data")
+
+// The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response).
 type body struct {
-	str quic.Stream
+	str *stream
 
-	wasHijacked bool // set when HTTPStream is called
+	remainingContentLength int64
+	violatedContentLength  bool
+	hasContentLength       bool
 }
 
-var (
-	_ io.ReadCloser = &body{}
-	_ HTTPStreamer  = &body{}
-)
-
-func newRequestBody(str Stream) *body {
-	return &body{str: str}
+func newBody(str *stream, contentLength int64) *body {
+	b := &body{str: str}
+	if contentLength >= 0 {
+		b.hasContentLength = true
+		b.remainingContentLength = contentLength
+	}
+	return b
 }
 
-func (r *body) HTTPStream() Stream {
-	r.wasHijacked = true
-	return r.str
-}
+func (r *body) StreamID() quic.StreamID { return r.str.StreamID() }
 
-func (r *body) wasStreamHijacked() bool {
-	return r.wasHijacked
+func (r *body) checkContentLengthViolation() error {
+	if !r.hasContentLength {
+		return nil
+	}
+	if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() {
+		if !r.violatedContentLength {
+			r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
+			r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
+			r.violatedContentLength = true
+		}
+		return errTooMuchData
+	}
+	return nil
 }
 
 func (r *body) Read(b []byte) (int, error) {
+	if err := r.checkContentLengthViolation(); err != nil {
+		return 0, err
+	}
+	if r.hasContentLength {
+		b = b[:min(int64(len(b)), r.remainingContentLength)]
+	}
 	n, err := r.str.Read(b)
+	r.remainingContentLength -= int64(n)
+	if err := r.checkContentLengthViolation(); err != nil {
+		return n, err
+	}
 	return n, maybeReplaceError(err)
 }
 
@@ -72,38 +72,45 @@ func (r *body) Close() error {
 	return nil
 }
 
-type hijackableBody struct {
+type requestBody struct {
 	body
-	conn quic.Connection // only needed to implement Hijacker
+	connCtx      context.Context
+	rcvdSettings <-chan struct{}
+	getSettings  func() *Settings
+}
+
+var _ io.ReadCloser = &requestBody{}
+
+func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
+	return &requestBody{
+		body:         *newBody(str, contentLength),
+		connCtx:      connCtx,
+		rcvdSettings: rcvdSettings,
+		getSettings:  getSettings,
+	}
+}
+
+type hijackableBody struct {
+	body body
 
 	// only set for the http.Response
 	// The channel is closed when the user is done with this response:
 	// either when Read() errors, or when Close() is called.
-	reqDone       chan<- struct{}
-	reqDoneClosed bool
+	reqDone     chan<- struct{}
+	reqDoneOnce sync.Once
 }
 
-var (
-	_ Hijacker     = &hijackableBody{}
-	_ HTTPStreamer = &hijackableBody{}
-)
+var _ io.ReadCloser = &hijackableBody{}
 
-func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
+func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody {
 	return &hijackableBody{
-		body: body{
-			str: str,
-		},
+		body:    *newBody(str, contentLength),
 		reqDone: done,
-		conn:    conn,
 	}
 }
 
-func (r *hijackableBody) StreamCreator() StreamCreator {
-	return r.conn
-}
-
 func (r *hijackableBody) Read(b []byte) (int, error) {
-	n, err := r.str.Read(b)
+	n, err := r.body.Read(b)
 	if err != nil {
 		r.requestDone()
 	}
@@ -111,26 +118,16 @@ func (r *hijackableBody) Read(b []byte) (int, error) {
 }
 
 func (r *hijackableBody) requestDone() {
-	if r.reqDoneClosed || r.reqDone == nil {
-		return
-	}
 	if r.reqDone != nil {
-		close(r.reqDone)
+		r.reqDoneOnce.Do(func() {
+			close(r.reqDone)
+		})
 	}
-	r.reqDoneClosed = true
-}
-
-func (r *body) StreamID() quic.StreamID {
-	return r.str.StreamID()
 }
 
 func (r *hijackableBody) Close() error {
 	r.requestDone()
 	// If the EOF was read, CancelRead() is a no-op.
-	r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
+	r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
 	return nil
 }
-
-func (r *hijackableBody) HTTPStream() Stream {
-	return r.str
-}

+ 26 - 7
vendor/github.com/Psiphon-Labs/quic-go/http3/capsule.go

@@ -6,28 +6,47 @@ import (
 	"github.com/Psiphon-Labs/quic-go/quicvarint"
 )
 
-// CapsuleType is the type of the capsule.
+// CapsuleType is the type of the capsule
 type CapsuleType uint64
 
+// CapsuleProtocolHeader is the header value used to advertise support for the capsule protocol
+const CapsuleProtocolHeader = "Capsule-Protocol"
+
 type exactReader struct {
-	R *io.LimitedReader
+	R io.LimitedReader
 }
 
 func (r *exactReader) Read(b []byte) (int, error) {
 	n, err := r.R.Read(b)
-	if r.R.N > 0 {
+	if err == io.EOF && r.R.N > 0 {
 		return n, io.ErrUnexpectedEOF
 	}
 	return n, err
 }
 
+type countingByteReader struct {
+	io.ByteReader
+	Read int
+}
+
+func (r *countingByteReader) ReadByte() (byte, error) {
+	b, err := r.ByteReader.ReadByte()
+	if err == nil {
+		r.Read++
+	}
+	return b, err
+}
+
 // ParseCapsule parses the header of a Capsule.
-// It returns an io.LimitedReader that can be used to read the Capsule value.
+// It returns an io.Reader that can be used to read the Capsule value.
 // The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
 func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
-	ct, err := quicvarint.Read(r)
+	cbr := countingByteReader{ByteReader: r}
+	ct, err := quicvarint.Read(&cbr)
 	if err != nil {
-		if err == io.EOF {
+		// If an io.EOF is returned without consuming any bytes, return it unmodified.
+		// Otherwise, return an io.ErrUnexpectedEOF.
+		if err == io.EOF && cbr.Read > 0 {
 			return 0, nil, io.ErrUnexpectedEOF
 		}
 		return 0, nil, err
@@ -39,7 +58,7 @@ func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
 		}
 		return 0, nil, err
 	}
-	return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil
+	return CapsuleType(ct), &exactReader{R: io.LimitedReader{R: r, N: int64(l)}}, nil
 }
 
 // WriteCapsule writes a capsule

+ 198 - 326
vendor/github.com/Psiphon-Labs/quic-go/http3/client.go

@@ -5,26 +5,29 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
+	"log/slog"
 	"net/http"
-	"strconv"
-	"sync"
-	"sync/atomic"
+	"net/http/httptrace"
+	"net/textproto"
 	"time"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"github.com/Psiphon-Labs/quic-go"
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
-	"github.com/Psiphon-Labs/quic-go/internal/utils"
 	"github.com/Psiphon-Labs/quic-go/quicvarint"
 
 	"github.com/quic-go/qpack"
+
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 )
 
-// MethodGet0RTT allows a GET request to be sent using 0-RTT.
-// Note that 0-RTT data doesn't provide replay protection.
-const MethodGet0RTT = "GET_0RTT"
+const (
+	// MethodGet0RTT allows a GET request to be sent using 0-RTT.
+	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
+	MethodGet0RTT = "GET_0RTT"
+	// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
+	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
+	MethodHead0RTT = "HEAD_0RTT"
+)
 
 const (
 	defaultUserAgent              = "quic-go HTTP/3"
@@ -36,226 +39,141 @@ var defaultQuicConfig = &quic.Config{
 	KeepAlivePeriod:    10 * time.Second,
 }
 
-type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
-
-var dialAddr dialFunc = quic.DialAddrEarly
-
-type roundTripperOpts struct {
-	DisableCompression bool
-	EnableDatagram     bool
-	MaxHeaderBytes     int64
-	AdditionalSettings map[uint64]uint64
-	StreamHijacker     func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
-	UniStreamHijacker  func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
-}
-
-// client is a HTTP3 client doing requests
-type client struct {
-	tlsConf *tls.Config
-	config  *quic.Config
-	opts    *roundTripperOpts
-
-	dialOnce     sync.Once
-	dialer       dialFunc
-	handshakeErr error
-
-	requestWriter *requestWriter
+// ClientConn is an HTTP/3 client doing requests to a single remote server.
+type ClientConn struct {
+	connection
 
-	decoder *qpack.Decoder
+	// Enable support for HTTP/3 datagrams (RFC 9297).
+	// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams.
+	enableDatagrams bool
 
-	hostname string
-	conn     atomic.Pointer[quic.EarlyConnection]
+	// Additional HTTP/3 settings.
+	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
+	additionalSettings map[uint64]uint64
 
-	logger utils.Logger
-}
+	// maxResponseHeaderBytes specifies a limit on how many response bytes are
+	// allowed in the server's response header.
+	maxResponseHeaderBytes uint64
 
-var _ roundTripCloser = &client{}
+	// disableCompression, if true, prevents the Transport from requesting compression with an
+	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
+	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
+	// decoded in the Response.Body.
+	// However, if the user explicitly requested gzip it is not automatically uncompressed.
+	disableCompression bool
 
-func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
-	if conf == nil {
-		conf = defaultQuicConfig.Clone()
-	}
-	if len(conf.Versions) == 0 {
-		conf = conf.Clone()
-		conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
-	}
-	if len(conf.Versions) != 1 {
-		return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
-	}
-	if conf.MaxIncomingStreams == 0 {
-		conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
-	}
-	conf.EnableDatagrams = opts.EnableDatagram
-	logger := utils.DefaultLogger.WithPrefix("h3 client")
+	logger *slog.Logger
 
-	if tlsConf == nil {
-		tlsConf = &tls.Config{}
-	} else {
-		tlsConf = tlsConf.Clone()
-	}
-	if tlsConf.ServerName == "" {
-		sni, _, err := net.SplitHostPort(hostname)
-		if err != nil {
-			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
-			sni = hostname
-		}
-		tlsConf.ServerName = sni
-	}
-	// Replace existing ALPNs by H3
-	tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
-
-	return &client{
-		hostname:      authorityAddr("https", hostname),
-		tlsConf:       tlsConf,
-		requestWriter: newRequestWriter(logger),
-		decoder:       qpack.NewDecoder(func(hf qpack.HeaderField) {}),
-		config:        conf,
-		opts:          opts,
-		dialer:        dialer,
-		logger:        logger,
-	}, nil
+	requestWriter *requestWriter
+	decoder       *qpack.Decoder
 }
 
-func (c *client) dial(ctx context.Context) error {
-	var err error
-	var conn quic.EarlyConnection
-	if c.dialer != nil {
-		conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
+var _ http.RoundTripper = &ClientConn{}
+
+// Deprecated: SingleDestinationRoundTripper was renamed to ClientConn.
+// It can be obtained by calling NewClientConn on a Transport.
+type SingleDestinationRoundTripper = ClientConn
+
+func newClientConn(
+	conn quic.Connection,
+	enableDatagrams bool,
+	additionalSettings map[uint64]uint64,
+	streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error),
+	uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
+	maxResponseHeaderBytes int64,
+	disableCompression bool,
+	logger *slog.Logger,
+) *ClientConn {
+	c := &ClientConn{
+		enableDatagrams:    enableDatagrams,
+		additionalSettings: additionalSettings,
+		disableCompression: disableCompression,
+		logger:             logger,
+	}
+	if maxResponseHeaderBytes <= 0 {
+		c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
 	} else {
-		conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
-	}
-	if err != nil {
-		return err
-	}
-	c.conn.Store(&conn)
-
+		c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes)
+	}
+	c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
+	c.requestWriter = newRequestWriter()
+	c.connection = *newConnection(
+		conn.Context(),
+		conn,
+		c.enableDatagrams,
+		protocol.PerspectiveClient,
+		c.logger,
+		0,
+	)
 	// send the SETTINGs frame, using 0-RTT data, if possible
 	go func() {
-		if err := c.setupConn(conn); err != nil {
-			c.logger.Debugf("Setting up connection failed: %s", err)
-			conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
+		if err := c.setupConn(); err != nil {
+			if c.logger != nil {
+				c.logger.Debug("Setting up connection failed", "error", err)
+			}
+			c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
 		}
 	}()
-
-	if c.opts.StreamHijacker != nil {
-		go c.handleBidirectionalStreams(conn)
+	if streamHijacker != nil {
+		go c.handleBidirectionalStreams(streamHijacker)
 	}
-	go c.handleUnidirectionalStreams(conn)
-	return nil
+	go c.connection.handleUnidirectionalStreams(uniStreamHijacker)
+	return c
 }
 
-func (c *client) setupConn(conn quic.EarlyConnection) error {
+// OpenRequestStream opens a new request stream on the HTTP/3 connection.
+func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) {
+	return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
+}
+
+func (c *ClientConn) setupConn() error {
 	// open the control stream
-	str, err := conn.OpenUniStream()
+	str, err := c.connection.OpenUniStream()
 	if err != nil {
 		return err
 	}
 	b := make([]byte, 0, 64)
 	b = quicvarint.Append(b, streamTypeControlStream)
 	// send the SETTINGS frame
-	b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
+	b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b)
 	_, err = str.Write(b)
 	return err
 }
 
-func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
+func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) {
 	for {
-		str, err := conn.AcceptStream(context.Background())
+		str, err := c.connection.AcceptStream(context.Background())
 		if err != nil {
-			c.logger.Debugf("accepting bidirectional stream failed: %s", err)
-			return
-		}
-		go func(str quic.Stream) {
-			_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
-				return c.opts.StreamHijacker(ft, conn, str, e)
-			})
-			if err == errHijacked {
-				return
+			if c.logger != nil {
+				c.logger.Debug("accepting bidirectional stream failed", "error", err)
 			}
-			if err != nil {
-				c.logger.Debugf("error handling stream: %s", err)
-			}
-			conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
-		}(str)
-	}
-}
-
-func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
-	for {
-		str, err := conn.AcceptUniStream(context.Background())
-		if err != nil {
-			c.logger.Debugf("accepting unidirectional stream failed: %s", err)
 			return
 		}
-
-		go func(str quic.ReceiveStream) {
-			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
-			if err != nil {
-				if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
-					return
-				}
-				c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
-				return
-			}
-			// We're only interested in the control stream here.
-			switch streamType {
-			case streamTypeControlStream:
-			case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
-				// Our QPACK implementation doesn't use the dynamic table yet.
-				// TODO: check that only one stream of each type is opened.
-				return
-			case streamTypePushStream:
-				// We never increased the Push ID, so we don't expect any push streams.
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
-				return
-			default:
-				if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
-					return
-				}
-				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
+		fp := &frameParser{
+			r:    str,
+			conn: &c.connection,
+			unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
+				id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
+				return streamHijacker(ft, id, str, e)
+			},
+		}
+		go func() {
+			if _, err := fp.ParseNext(); err == errHijacked {
 				return
 			}
-			f, err := parseNextFrame(str, nil)
 			if err != nil {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
-				return
-			}
-			sf, ok := f.(*settingsFrame)
-			if !ok {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
-				return
-			}
-			if !sf.Datagram {
-				return
-			}
-			// If datagram support was enabled on our side as well as on the server side,
-			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
-			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
-			if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
+				if c.logger != nil {
+					c.logger.Debug("error handling stream", "error", err)
+				}
 			}
-		}(str)
-	}
-}
-
-func (c *client) Close() error {
-	conn := c.conn.Load()
-	if conn == nil {
-		return nil
-	}
-	return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
-}
-
-func (c *client) maxHeaderBytes() uint64 {
-	if c.opts.MaxHeaderBytes <= 0 {
-		return defaultMaxResponseHeaderBytes
+			c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
+		}()
 	}
-	return uint64(c.opts.MaxHeaderBytes)
 }
 
-// RoundTripOpt executes a request and returns a response
-func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
-	rsp, err := c.roundTripOpt(req, opt)
+// RoundTrip executes a request and returns a response
+func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
+	rsp, err := c.roundTrip(req)
 	if err != nil && req.Context().Err() != nil {
 		// if the context was canceled, return the context cancellation error
 		err = req.Context().Err()
@@ -263,34 +181,54 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
 	return rsp, err
 }
 
-func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
-	if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
-		return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
-	}
-
-	c.dialOnce.Do(func() {
-		c.handshakeErr = c.dial(req.Context())
-	})
-	if c.handshakeErr != nil {
-		return nil, c.handshakeErr
-	}
-
-	// At this point, c.conn is guaranteed to be set.
-	conn := *c.conn.Load()
-
+func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
 	// Immediately send out this request, if this is a 0-RTT request.
-	if req.Method == MethodGet0RTT {
+	switch req.Method {
+	case MethodGet0RTT:
+		// don't modify the original request
+		reqCopy := *req
+		req = &reqCopy
 		req.Method = http.MethodGet
-	} else {
+	case MethodHead0RTT:
+		// don't modify the original request
+		reqCopy := *req
+		req = &reqCopy
+		req.Method = http.MethodHead
+	default:
 		// wait for the handshake to complete
+		earlyConn, ok := c.Connection.(quic.EarlyConnection)
+		if ok {
+			select {
+			case <-earlyConn.HandshakeComplete():
+			case <-req.Context().Done():
+				return nil, req.Context().Err()
+			}
+		}
+	}
+
+	// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
+	// See section 3 of RFC 8441.
+	if isExtendedConnectRequest(req) {
+		connCtx := c.Connection.Context()
+		// wait for the server's SETTINGS frame to arrive
 		select {
-		case <-conn.HandshakeComplete():
-		case <-req.Context().Done():
-			return nil, req.Context().Err()
+		case <-c.connection.ReceivedSettings():
+		case <-connCtx.Done():
+			return nil, context.Cause(connCtx)
+		}
+		if !c.connection.Settings().EnableExtendedConnect {
+			return nil, errors.New("http3: server didn't enable Extended CONNECT")
 		}
 	}
 
-	str, err := conn.OpenStreamSync(req.Context())
+	reqDone := make(chan struct{})
+	str, err := c.connection.openRequestStream(
+		req.Context(),
+		c.requestWriter,
+		reqDone,
+		c.disableCompression,
+		c.maxResponseHeaderBytes,
+	)
 	if err != nil {
 		return nil, err
 	}
@@ -298,7 +236,6 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
 	// Request Cancellation:
 	// This go routine keeps running even after RoundTripOpt() returns.
 	// It is shut down when the application is done processing the body.
-	reqDone := make(chan struct{})
 	done := make(chan struct{})
 	go func() {
 		defer close(done)
@@ -310,31 +247,13 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
 		}
 	}()
 
-	doneChan := reqDone
-	if opt.DontCloseRequestStream {
-		doneChan = nil
-	}
-	rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
-	if rerr.err != nil { // if any error occurred
-		close(reqDone)
-		<-done
-		if rerr.streamErr != 0 { // if it was a stream error
-			str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
-		}
-		if rerr.connErr != 0 { // if it was a connection error
-			var reason string
-			if rerr.err != nil {
-				reason = rerr.err.Error()
-			}
-			conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
-		}
-		return nil, maybeReplaceError(rerr.err)
-	}
-	if opt.DontCloseRequestStream {
+	rsp, err := c.doRequest(req, str)
+	if err != nil { // if any error occurred
 		close(reqDone)
 		<-done
+		return nil, maybeReplaceError(err)
 	}
-	return rsp, maybeReplaceError(rerr.err)
+	return rsp, maybeReplaceError(err)
 }
 
 // cancelingReader reads from the io.Reader.
@@ -352,7 +271,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) {
 	return n, err
 }
 
-func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
+func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
 	defer body.Close()
 	buf := make([]byte, bodyCopyBufferSize)
 	sr := &cancelingReader{str: str, r: body}
@@ -376,21 +295,16 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength i
 	return err
 }
 
-func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
-	var requestGzip bool
-	if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
-		requestGzip = true
-	}
-	if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
-		return nil, newStreamError(ErrCodeInternalError, err)
+func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
+	trace := httptrace.ContextClientTrace(req.Context())
+	if err := str.SendRequestHeader(req); err != nil {
+		traceWroteRequest(trace, err)
+		return nil, err
 	}
-
-	if req.Body == nil && !opt.DontCloseRequestStream {
+	if req.Body == nil {
+		traceWroteRequest(trace, nil)
 		str.Close()
-	}
-
-	hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
-	if req.Body != nil {
+	} else {
 		// send the request body asynchronously
 		go func() {
 			contentLength := int64(-1)
@@ -399,92 +313,50 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
 			if req.ContentLength > 0 {
 				contentLength = req.ContentLength
 			}
-			if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
-				c.logger.Errorf("Error writing request: %s", err)
-			}
-			if !opt.DontCloseRequestStream {
-				hstr.Close()
+			err := c.sendRequestBody(str, req.Body, contentLength)
+			traceWroteRequest(trace, err)
+			if err != nil {
+				if c.logger != nil {
+					c.logger.Debug("error writing request", "error", err)
+				}
 			}
+			str.Close()
 		}()
 	}
 
-	frame, err := parseNextFrame(str, nil)
-	if err != nil {
-		return nil, newStreamError(ErrCodeFrameError, err)
-	}
-	hf, ok := frame.(*headersFrame)
-	if !ok {
-		return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
-	}
-	if hf.Length > c.maxHeaderBytes() {
-		return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
-	}
-	headerBlock := make([]byte, hf.Length)
-	if _, err := io.ReadFull(str, headerBlock); err != nil {
-		return nil, newStreamError(ErrCodeRequestIncomplete, err)
-	}
-	hfs, err := c.decoder.DecodeFull(headerBlock)
-	if err != nil {
-		// TODO: use the right error code
-		return nil, newConnError(ErrCodeGeneralProtocolError, err)
-	}
+	// copy from net/http: support 1xx responses
+	num1xx := 0               // number of informational 1xx headers received
+	const max1xxResponses = 5 // arbitrary bound on number of informational responses
 
-	res, err := responseFromHeaders(hfs)
-	if err != nil {
-		return nil, newStreamError(ErrCodeMessageError, err)
+	var res *http.Response
+	for {
+		var err error
+		res, err = str.ReadResponse()
+		if err != nil {
+			return nil, err
+		}
+		resCode := res.StatusCode
+		is1xx := 100 <= resCode && resCode <= 199
+		// treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
+		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
+		if is1xxNonTerminal {
+			num1xx++
+			if num1xx > max1xxResponses {
+				return nil, errors.New("http: too many 1xx informational responses")
+			}
+			traceGot1xxResponse(trace, resCode, textproto.MIMEHeader(res.Header))
+			if resCode == 100 {
+				traceGot100Continue(trace)
+			}
+			continue
+		}
+		break
 	}
-	connState := conn.ConnectionState().TLS
+	connState := c.connection.ConnectionState().TLS
 
 	// [Psiphon]
 	res.TLS = tls.UnsafeFromConnectionState(&connState)
 
 	res.Request = req
-	// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
-	// See section 4.1.2 of RFC 9114.
-	var httpStr Stream
-	if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
-		httpStr = newLengthLimitedStream(hstr, res.ContentLength)
-	} else {
-		httpStr = hstr
-	}
-	respBody := newResponseBody(httpStr, conn, reqDone)
-
-	// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
-	_, hasTransferEncoding := res.Header["Transfer-Encoding"]
-	isInformational := res.StatusCode >= 100 && res.StatusCode < 200
-	isNoContent := res.StatusCode == http.StatusNoContent
-	isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
-	if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
-		res.ContentLength = -1
-		if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
-			if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
-				res.ContentLength = clen64
-			}
-		}
-	}
-
-	if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
-		res.Header.Del("Content-Encoding")
-		res.Header.Del("Content-Length")
-		res.ContentLength = -1
-		res.Body = newGzipReader(respBody)
-		res.Uncompressed = true
-	} else {
-		res.Body = respBody
-	}
-
-	return res, requestError{}
-}
-
-func (c *client) HandshakeComplete() bool {
-	conn := c.conn.Load()
-	if conn == nil {
-		return false
-	}
-	select {
-	case <-(*conn).HandshakeComplete():
-		return true
-	default:
-		return false
-	}
+	return res, nil
 }

+ 329 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/conn.go

@@ -0,0 +1,329 @@
+package http3
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"log/slog"
+	"net"
+	"net/http"
+	"net/http/httptrace"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/Psiphon-Labs/quic-go"
+	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+	"github.com/Psiphon-Labs/quic-go/quicvarint"
+
+	"github.com/quic-go/qpack"
+)
+
+// Connection is an HTTP/3 connection.
+// It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream,
+// SendDatagram and ReceiveDatagram.
+type Connection interface {
+	OpenStream() (quic.Stream, error)
+	OpenStreamSync(context.Context) (quic.Stream, error)
+	OpenUniStream() (quic.SendStream, error)
+	OpenUniStreamSync(context.Context) (quic.SendStream, error)
+	LocalAddr() net.Addr
+	RemoteAddr() net.Addr
+	CloseWithError(quic.ApplicationErrorCode, string) error
+	Context() context.Context
+	ConnectionState() quic.ConnectionState
+
+	// ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received.
+	ReceivedSettings() <-chan struct{}
+	// Settings returns the settings received on this connection.
+	Settings() *Settings
+}
+
+type connection struct {
+	quic.Connection
+	ctx context.Context
+
+	perspective protocol.Perspective
+	logger      *slog.Logger
+
+	enableDatagrams bool
+
+	decoder *qpack.Decoder
+
+	streamMx sync.Mutex
+	streams  map[protocol.StreamID]*datagrammer
+
+	settings         *Settings
+	receivedSettings chan struct{}
+
+	idleTimeout time.Duration
+	idleTimer   *time.Timer
+}
+
+func newConnection(
+	ctx context.Context,
+	quicConn quic.Connection,
+	enableDatagrams bool,
+	perspective protocol.Perspective,
+	logger *slog.Logger,
+	idleTimeout time.Duration,
+) *connection {
+	c := &connection{
+		ctx:              ctx,
+		Connection:       quicConn,
+		perspective:      perspective,
+		logger:           logger,
+		idleTimeout:      idleTimeout,
+		enableDatagrams:  enableDatagrams,
+		decoder:          qpack.NewDecoder(func(hf qpack.HeaderField) {}),
+		receivedSettings: make(chan struct{}),
+		streams:          make(map[protocol.StreamID]*datagrammer),
+	}
+	if idleTimeout > 0 {
+		c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer)
+	}
+	return c
+}
+
+func (c *connection) onIdleTimer() {
+	c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout")
+}
+
+func (c *connection) clearStream(id quic.StreamID) {
+	c.streamMx.Lock()
+	defer c.streamMx.Unlock()
+
+	delete(c.streams, id)
+	if c.idleTimeout > 0 && len(c.streams) == 0 {
+		c.idleTimer.Reset(c.idleTimeout)
+	}
+}
+
+func (c *connection) openRequestStream(
+	ctx context.Context,
+	requestWriter *requestWriter,
+	reqDone chan<- struct{},
+	disableCompression bool,
+	maxHeaderBytes uint64,
+) (*requestStream, error) {
+	str, err := c.Connection.OpenStreamSync(ctx)
+	if err != nil {
+		return nil, err
+	}
+	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
+	c.streamMx.Lock()
+	c.streams[str.StreamID()] = datagrams
+	c.streamMx.Unlock()
+	qstr := newStateTrackingStream(str, c, datagrams)
+	rsp := &http.Response{}
+	hstr := newStream(qstr, c, datagrams, func(r io.Reader, l uint64) error {
+		hdr, err := c.decodeTrailers(r, l, maxHeaderBytes)
+		if err != nil {
+			return err
+		}
+		rsp.Trailer = hdr
+		return nil
+	})
+	trace := httptrace.ContextClientTrace(ctx)
+	return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes, rsp, trace), nil
+}
+
+func (c *connection) decodeTrailers(r io.Reader, l, maxHeaderBytes uint64) (http.Header, error) {
+	if l > maxHeaderBytes {
+		return nil, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", l, maxHeaderBytes)
+	}
+
+	b := make([]byte, l)
+	if _, err := io.ReadFull(r, b); err != nil {
+		return nil, err
+	}
+	fields, err := c.decoder.DecodeFull(b)
+	if err != nil {
+		return nil, err
+	}
+	return parseTrailers(fields)
+}
+
+func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) {
+	str, err := c.AcceptStream(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
+	if c.perspective == protocol.PerspectiveServer {
+		strID := str.StreamID()
+		c.streamMx.Lock()
+		c.streams[strID] = datagrams
+		if c.idleTimeout > 0 {
+			if len(c.streams) == 1 {
+				c.idleTimer.Stop()
+			}
+		}
+		c.streamMx.Unlock()
+		str = newStateTrackingStream(str, c, datagrams)
+	}
+	return str, datagrams, nil
+}
+
+func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
+	if c.idleTimer != nil {
+		c.idleTimer.Stop()
+	}
+	return c.Connection.CloseWithError(code, msg)
+}
+
+func (c *connection) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
+	var (
+		rcvdControlStr      atomic.Bool
+		rcvdQPACKEncoderStr atomic.Bool
+		rcvdQPACKDecoderStr atomic.Bool
+	)
+
+	for {
+		str, err := c.Connection.AcceptUniStream(context.Background())
+		if err != nil {
+			if c.logger != nil {
+				c.logger.Debug("accepting unidirectional stream failed", "error", err)
+			}
+			return
+		}
+
+		go func(str quic.ReceiveStream) {
+			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
+			if err != nil {
+				id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
+				if hijack != nil && hijack(StreamType(streamType), id, str, err) {
+					return
+				}
+				if c.logger != nil {
+					c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
+				}
+				return
+			}
+			// We're only interested in the control stream here.
+			switch streamType {
+			case streamTypeControlStream:
+			case streamTypeQPACKEncoderStream:
+				if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
+					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
+				}
+				// Our QPACK implementation doesn't use the dynamic table yet.
+				return
+			case streamTypeQPACKDecoderStream:
+				if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
+					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
+				}
+				// Our QPACK implementation doesn't use the dynamic table yet.
+				return
+			case streamTypePushStream:
+				switch c.perspective {
+				case protocol.PerspectiveClient:
+					// we never increased the Push ID, so we don't expect any push streams
+					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
+				case protocol.PerspectiveServer:
+					// only the server can push
+					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
+				}
+				return
+			default:
+				if hijack != nil {
+					if hijack(
+						StreamType(streamType),
+						c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
+						str,
+						nil,
+					) {
+						return
+					}
+				}
+				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
+				return
+			}
+			// Only a single control stream is allowed.
+			if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr {
+				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
+				return
+			}
+			fp := &frameParser{conn: c.Connection, r: str}
+			f, err := fp.ParseNext()
+			if err != nil {
+				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
+				return
+			}
+			sf, ok := f.(*settingsFrame)
+			if !ok {
+				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
+				return
+			}
+			c.settings = &Settings{
+				EnableDatagrams:       sf.Datagram,
+				EnableExtendedConnect: sf.ExtendedConnect,
+				Other:                 sf.Other,
+			}
+			close(c.receivedSettings)
+			if !sf.Datagram {
+				return
+			}
+			// If datagram support was enabled on our side as well as on the server side,
+			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
+			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
+			if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams {
+				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
+				return
+			}
+			go func() {
+				if err := c.receiveDatagrams(); err != nil {
+					if c.logger != nil {
+						c.logger.Debug("receiving datagrams failed", "error", err)
+					}
+				}
+			}()
+		}(str)
+	}
+}
+
+func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
+	// TODO: this creates a lot of garbage and an additional copy
+	data := make([]byte, 0, len(b)+8)
+	data = quicvarint.Append(data, uint64(streamID/4))
+	data = append(data, b...)
+	return c.Connection.SendDatagram(data)
+}
+
+func (c *connection) receiveDatagrams() error {
+	for {
+		b, err := c.Connection.ReceiveDatagram(context.Background())
+		if err != nil {
+			return err
+		}
+		quarterStreamID, n, err := quicvarint.Parse(b)
+		if err != nil {
+			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
+			return fmt.Errorf("could not read quarter stream id: %w", err)
+		}
+		if quarterStreamID > maxQuarterStreamID {
+			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
+			return fmt.Errorf("invalid quarter stream id: %w", err)
+		}
+		streamID := protocol.StreamID(4 * quarterStreamID)
+		c.streamMx.Lock()
+		dg, ok := c.streams[streamID]
+		if !ok {
+			c.streamMx.Unlock()
+			return nil
+		}
+		c.streamMx.Unlock()
+		dg.enqueue(b[n:])
+	}
+}
+
+// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
+// Settings can be optained from the Settings method after the channel was closed.
+func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
+
+// Settings returns the settings received on this connection.
+// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
+func (c *connection) Settings() *Settings { return c.settings }
+
+// Context returns the context of the underlying QUIC connection.
+func (c *connection) Context() context.Context { return c.ctx }

+ 98 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/datagram.go

@@ -0,0 +1,98 @@
+package http3
+
+import (
+	"context"
+	"sync"
+)
+
+const maxQuarterStreamID = 1<<60 - 1
+
+const streamDatagramQueueLen = 32
+
+type datagrammer struct {
+	sendDatagram func([]byte) error
+
+	hasData chan struct{}
+	queue   [][]byte // TODO: use a ring buffer
+
+	mx         sync.Mutex
+	sendErr    error
+	receiveErr error
+}
+
+func newDatagrammer(sendDatagram func([]byte) error) *datagrammer {
+	return &datagrammer{
+		sendDatagram: sendDatagram,
+		hasData:      make(chan struct{}, 1),
+	}
+}
+
+func (d *datagrammer) SetReceiveError(err error) {
+	d.mx.Lock()
+	defer d.mx.Unlock()
+
+	d.receiveErr = err
+	d.signalHasData()
+}
+
+func (d *datagrammer) SetSendError(err error) {
+	d.mx.Lock()
+	defer d.mx.Unlock()
+
+	d.sendErr = err
+}
+
+func (d *datagrammer) Send(b []byte) error {
+	d.mx.Lock()
+	sendErr := d.sendErr
+	d.mx.Unlock()
+	if sendErr != nil {
+		return sendErr
+	}
+
+	return d.sendDatagram(b)
+}
+
+func (d *datagrammer) signalHasData() {
+	select {
+	case d.hasData <- struct{}{}:
+	default:
+	}
+}
+
+func (d *datagrammer) enqueue(data []byte) {
+	d.mx.Lock()
+	defer d.mx.Unlock()
+
+	if d.receiveErr != nil {
+		return
+	}
+	if len(d.queue) >= streamDatagramQueueLen {
+		return
+	}
+	d.queue = append(d.queue, data)
+	d.signalHasData()
+}
+
+func (d *datagrammer) Receive(ctx context.Context) ([]byte, error) {
+start:
+	d.mx.Lock()
+	if len(d.queue) >= 1 {
+		data := d.queue[0]
+		d.queue = d.queue[1:]
+		d.mx.Unlock()
+		return data, nil
+	}
+	if receiveErr := d.receiveErr; receiveErr != nil {
+		d.mx.Unlock()
+		return nil, receiveErr
+	}
+	d.mx.Unlock()
+
+	select {
+	case <-ctx.Done():
+		return nil, context.Cause(ctx)
+	case <-d.hasData:
+	}
+	goto start
+}

+ 5 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/error.go

@@ -33,6 +33,11 @@ func (e *Error) Error() string {
 	return s
 }
 
+func (e *Error) Is(target error) bool {
+	t, ok := target.(*Error)
+	return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
+}
+
 func maybeReplaceError(err error) error {
 	if err == nil {
 		return nil

+ 72 - 15
vendor/github.com/Psiphon-Labs/quic-go/http3/frames.go

@@ -6,7 +6,7 @@ import (
 	"fmt"
 	"io"
 
-	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+	"github.com/Psiphon-Labs/quic-go"
 	"github.com/Psiphon-Labs/quic-go/quicvarint"
 )
 
@@ -19,13 +19,19 @@ type frame interface{}
 
 var errHijacked = errors.New("hijacked")
 
-func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
-	qr := quicvarint.NewReader(r)
+type frameParser struct {
+	r                   io.Reader
+	conn                quic.Connection
+	unknownFrameHandler unknownFrameHandlerFunc
+}
+
+func (p *frameParser) ParseNext() (frame, error) {
+	qr := quicvarint.NewReader(p.r)
 	for {
 		t, err := quicvarint.Read(qr)
 		if err != nil {
-			if unknownFrameHandler != nil {
-				hijacked, err := unknownFrameHandler(0, err)
+			if p.unknownFrameHandler != nil {
+				hijacked, err := p.unknownFrameHandler(0, err)
 				if err != nil {
 					return nil, err
 				}
@@ -36,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
 			return nil, err
 		}
 		// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
-		if t > 0xd && unknownFrameHandler != nil {
-			hijacked, err := unknownFrameHandler(FrameType(t), nil)
+		if t > 0xd && p.unknownFrameHandler != nil {
+			hijacked, err := p.unknownFrameHandler(FrameType(t), nil)
 			if err != nil {
 				return nil, err
 			}
@@ -57,11 +63,15 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
 		case 0x1:
 			return &headersFrame{Length: l}, nil
 		case 0x4:
-			return parseSettingsFrame(r, l)
+			return parseSettingsFrame(p.r, l)
 		case 0x3: // CANCEL_PUSH
 		case 0x5: // PUSH_PROMISE
-		case 0x7: // GOAWAY
+		case 0x7:
+			return parseGoAwayFrame(qr, l)
 		case 0xd: // MAX_PUSH_ID
+		case 0x2, 0x6, 0x8, 0x9:
+			p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
+			return nil, fmt.Errorf("http3: reserved frame type: %d", t)
 		}
 		// skip over unknown frames
 		if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {
@@ -88,11 +98,18 @@ func (f *headersFrame) Append(b []byte) []byte {
 	return quicvarint.Append(b, f.Length)
 }
 
-const settingDatagram = 0x33
+const (
+	// Extended CONNECT, RFC 9220
+	settingExtendedConnect = 0x8
+	// HTTP Datagrams, RFC 9297
+	settingDatagram = 0x33
+)
 
 type settingsFrame struct {
-	Datagram bool
-	Other    map[uint64]uint64 // all settings that we don't explicitly recognize
+	Datagram        bool // HTTP Datagrams, RFC 9297
+	ExtendedConnect bool // Extended CONNECT, RFC 9220
+
+	Other map[uint64]uint64 // all settings that we don't explicitly recognize
 }
 
 func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
@@ -108,7 +125,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
 	}
 	frame := &settingsFrame{}
 	b := bytes.NewReader(buf)
-	var readDatagram bool
+	var readDatagram, readExtendedConnect bool
 	for b.Len() > 0 {
 		id, err := quicvarint.Read(b)
 		if err != nil { // should not happen. We allocated the whole frame already.
@@ -120,13 +137,22 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
 		}
 
 		switch id {
+		case settingExtendedConnect:
+			if readExtendedConnect {
+				return nil, fmt.Errorf("duplicate setting: %d", id)
+			}
+			readExtendedConnect = true
+			if val != 0 && val != 1 {
+				return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val)
+			}
+			frame.ExtendedConnect = val == 1
 		case settingDatagram:
 			if readDatagram {
 				return nil, fmt.Errorf("duplicate setting: %d", id)
 			}
 			readDatagram = true
 			if val != 0 && val != 1 {
-				return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val)
+				return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val)
 			}
 			frame.Datagram = val == 1
 		default:
@@ -144,21 +170,52 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
 
 func (f *settingsFrame) Append(b []byte) []byte {
 	b = quicvarint.Append(b, 0x4)
-	var l protocol.ByteCount
+	var l int
 	for id, val := range f.Other {
 		l += quicvarint.Len(id) + quicvarint.Len(val)
 	}
 	if f.Datagram {
 		l += quicvarint.Len(settingDatagram) + quicvarint.Len(1)
 	}
+	if f.ExtendedConnect {
+		l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1)
+	}
 	b = quicvarint.Append(b, uint64(l))
 	if f.Datagram {
 		b = quicvarint.Append(b, settingDatagram)
 		b = quicvarint.Append(b, 1)
 	}
+	if f.ExtendedConnect {
+		b = quicvarint.Append(b, settingExtendedConnect)
+		b = quicvarint.Append(b, 1)
+	}
 	for id, val := range f.Other {
 		b = quicvarint.Append(b, id)
 		b = quicvarint.Append(b, val)
 	}
 	return b
 }
+
+type goAwayFrame struct {
+	StreamID quic.StreamID
+}
+
+func parseGoAwayFrame(r io.ByteReader, l uint64) (*goAwayFrame, error) {
+	frame := &goAwayFrame{}
+	cbr := countingByteReader{ByteReader: r}
+	id, err := quicvarint.Read(&cbr)
+	if err != nil {
+		return nil, err
+	}
+	if cbr.Read != int(l) {
+		return nil, errors.New("GOAWAY frame: inconsistent length")
+	}
+	frame.StreamID = quic.StreamID(id)
+	return frame, nil
+}
+
+func (f *goAwayFrame) Append(b []byte) []byte {
+	b = quicvarint.Append(b, 0x7)
+	b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID))))
+	return quicvarint.Append(b, uint64(f.StreamID))
+}

+ 80 - 18
vendor/github.com/Psiphon-Labs/quic-go/http3/headers.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"net/textproto"
 	"net/url"
 	"strconv"
 	"strings"
@@ -22,12 +23,21 @@ type header struct {
 	Status    string
 	// for Extended connect
 	Protocol string
-	// parsed and deduplicated
+	// parsed and deduplicated. -1 if no Content-Length header is sent
 	ContentLength int64
 	// all non-pseudo headers
 	Headers http.Header
 }
 
+// connection-specific header fields must not be sent on HTTP/3
+var invalidHeaderFields = [...]string{
+	"connection",
+	"keep-alive",
+	"proxy-connection",
+	"transfer-encoding",
+	"upgrade",
+}
+
 func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
 	hdr := header{Headers: make(http.Header, len(headers))}
 	var readFirstRegularHeader, readContentLength bool
@@ -73,6 +83,14 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
 			if !httpguts.ValidHeaderFieldName(h.Name) {
 				return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
 			}
+			for _, invalidField := range invalidHeaderFields {
+				if h.Name == invalidField {
+					return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
+				}
+			}
+			if h.Name == "te" && h.Value != "trailers" {
+				return header{}, fmt.Errorf("invalid TE header field value: %q", h.Value)
+			}
 			readFirstRegularHeader = true
 			switch h.Name {
 			case "content-length":
@@ -89,6 +107,7 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
 			}
 		}
 	}
+	hdr.ContentLength = -1
 	if len(contentLengthStr) > 0 {
 		// use ParseUint instead of ParseInt, so that parsing fails on negative values
 		cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
@@ -101,6 +120,17 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
 	return hdr, nil
 }
 
+func parseTrailers(headers []qpack.HeaderField) (http.Header, error) {
+	h := make(http.Header, len(headers))
+	for _, field := range headers {
+		if field.IsPseudo() {
+			return nil, fmt.Errorf("http3: received pseudo header in trailer: %s", field.Name)
+		}
+		h.Add(field.Name, field.Value)
+	}
+	return h, nil
+}
+
 func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) {
 	hdr, err := parseHeaders(headerFields, true)
 	if err != nil {
@@ -126,9 +156,14 @@ func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error)
 		return nil, errors.New(":path, :authority and :method must not be empty")
 	}
 
+	if !isExtendedConnected && len(hdr.Protocol) > 0 {
+		return nil, errors.New(":protocol must be empty")
+	}
+
 	var u *url.URL
 	var requestURI string
-	var protocol string
+
+	protocol := "HTTP/3.0"
 
 	if isConnect {
 		u = &url.URL{}
@@ -137,15 +172,14 @@ func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error)
 			if err != nil {
 				return nil, err
 			}
+			protocol = hdr.Protocol
 		} else {
 			u.Path = hdr.Path
 		}
 		u.Scheme = hdr.Scheme
 		u.Host = hdr.Authority
 		requestURI = hdr.Authority
-		protocol = hdr.Protocol
 	} else {
-		protocol = "HTTP/3.0"
 		u, err = url.ParseRequestURI(hdr.Path)
 		if err != nil {
 			return nil, fmt.Errorf("invalid content length: %w", err)
@@ -167,32 +201,60 @@ func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error)
 	}, nil
 }
 
-func hostnameFromRequest(req *http.Request) string {
-	if req.URL != nil {
-		return req.URL.Host
+func hostnameFromURL(url *url.URL) string {
+	if url != nil {
+		return url.Host
 	}
 	return ""
 }
 
-func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) {
+// updateResponseFromHeaders sets up http.Response as an HTTP/3 response,
+// using the decoded qpack header filed.
+// It is only called for the HTTP header (and not the HTTP trailer).
+// It takes an http.Response as an argument to allow the caller to set the trailer later on.
+func updateResponseFromHeaders(rsp *http.Response, headerFields []qpack.HeaderField) error {
 	hdr, err := parseHeaders(headerFields, false)
 	if err != nil {
-		return nil, err
+		return err
 	}
 	if hdr.Status == "" {
-		return nil, errors.New("missing status field")
-	}
-	rsp := &http.Response{
-		Proto:         "HTTP/3.0",
-		ProtoMajor:    3,
-		Header:        hdr.Headers,
-		ContentLength: hdr.ContentLength,
+		return errors.New("missing status field")
 	}
+	rsp.Proto = "HTTP/3.0"
+	rsp.ProtoMajor = 3
+	rsp.Header = hdr.Headers
+	processTrailers(rsp)
+	rsp.ContentLength = hdr.ContentLength
+
 	status, err := strconv.Atoi(hdr.Status)
 	if err != nil {
-		return nil, fmt.Errorf("invalid status code: %w", err)
+		return fmt.Errorf("invalid status code: %w", err)
 	}
 	rsp.StatusCode = status
 	rsp.Status = hdr.Status + " " + http.StatusText(status)
-	return rsp, nil
+	return nil
+}
+
+// processTrailers initializes the rsp.Trailer map, and adds keys for every announced header value.
+// The Trailer header is removed from the http.Response.Header map.
+// It handles both duplicate as well as comma-separated values for the Trailer header.
+// For example:
+//
+//	Trailer: Trailer1, Trailer2
+//	Trailer: Trailer3
+//
+// Will result in a http.Response.Trailer map containing the keys "Trailer1", "Trailer2", "Trailer3".
+func processTrailers(rsp *http.Response) {
+	rawTrailers, ok := rsp.Header["Trailer"]
+	if !ok {
+		return
+	}
+
+	rsp.Trailer = make(http.Header)
+	for _, rawVal := range rawTrailers {
+		for _, val := range strings.Split(rawVal, ",") {
+			rsp.Trailer[http.CanonicalHeaderKey(textproto.TrimString(val))] = nil
+		}
+	}
+	delete(rsp.Header, "Trailer")
 }

+ 211 - 40
vendor/github.com/Psiphon-Labs/quic-go/http3/http_stream.go

@@ -1,54 +1,101 @@
 package http3
 
 import (
+	"context"
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptrace"
 
 	"github.com/Psiphon-Labs/quic-go"
+	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+
+	"github.com/quic-go/qpack"
 )
 
-// A Stream is a HTTP/3 stream.
+// A Stream is an HTTP/3 request stream.
 // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
-type Stream quic.Stream
+type Stream interface {
+	quic.Stream
+
+	SendDatagram([]byte) error
+	ReceiveDatagram(context.Context) ([]byte, error)
+}
+
+// A RequestStream is an HTTP/3 request stream.
+// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
+type RequestStream interface {
+	Stream
+
+	// SendRequestHeader sends the HTTP request.
+	// It is invalid to call it more than once.
+	// It is invalid to call it after Write has been called.
+	SendRequestHeader(req *http.Request) error
+
+	// ReadResponse reads the HTTP response from the stream.
+	// It is invalid to call it more than once.
+	// It doesn't set Response.Request and Response.TLS.
+	// It is invalid to call it after Read has been called.
+	ReadResponse() (*http.Response, error)
+}
 
-// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
-// from the QUIC stream, it writes to and reads from the HTTP stream.
 type stream struct {
 	quic.Stream
+	conn *connection
 
-	buf []byte
+	buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers
 
-	onFrameError          func()
 	bytesRemainingInFrame uint64
+
+	datagrams *datagrammer
+
+	parseTrailer  func(io.Reader, uint64) error
+	parsedTrailer bool
 }
 
 var _ Stream = &stream{}
 
-func newStream(str quic.Stream, onFrameError func()) *stream {
+func newStream(str quic.Stream, conn *connection, datagrams *datagrammer, parseTrailer func(io.Reader, uint64) error) *stream {
 	return &stream{
 		Stream:       str,
-		onFrameError: onFrameError,
-		buf:          make([]byte, 0, 16),
+		conn:         conn,
+		buf:          make([]byte, 16),
+		datagrams:    datagrams,
+		parseTrailer: parseTrailer,
 	}
 }
 
 func (s *stream) Read(b []byte) (int, error) {
+	fp := &frameParser{
+		r:    s.Stream,
+		conn: s.conn,
+	}
 	if s.bytesRemainingInFrame == 0 {
 	parseLoop:
 		for {
-			frame, err := parseNextFrame(s.Stream, nil)
+			frame, err := fp.ParseNext()
 			if err != nil {
 				return 0, err
 			}
 			switch f := frame.(type) {
-			case *headersFrame:
-				// skip HEADERS frames
-				continue
 			case *dataFrame:
+				if s.parsedTrailer {
+					return 0, errors.New("DATA frame received after trailers")
+				}
 				s.bytesRemainingInFrame = f.Length
 				break parseLoop
+			case *headersFrame:
+				if s.conn.perspective == protocol.PerspectiveServer {
+					continue
+				}
+				if s.parsedTrailer {
+					return 0, errors.New("additional HEADERS frame received after trailers")
+				}
+				s.parsedTrailer = true
+				return 0, s.parseTrailer(s.Stream, f.Length)
 			default:
-				s.onFrameError()
+				s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
 				// parseNextFrame skips over unknown frame types
 				// Therefore, this condition is only entered when we parsed another known frame type.
 				return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
@@ -80,44 +127,168 @@ func (s *stream) Write(b []byte) (int, error) {
 	return s.Stream.Write(b)
 }
 
-var errTooMuchData = errors.New("peer sent too much data")
+func (s *stream) writeUnframed(b []byte) (int, error) {
+	return s.Stream.Write(b)
+}
+
+func (s *stream) StreamID() protocol.StreamID {
+	return s.Stream.StreamID()
+}
 
-type lengthLimitedStream struct {
+// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
+// from the QUIC stream, it writes to and reads from the HTTP stream.
+type requestStream struct {
 	*stream
-	contentLength int64
-	read          int64
-	resetStream   bool
+
+	responseBody io.ReadCloser // set by ReadResponse
+
+	decoder            *qpack.Decoder
+	requestWriter      *requestWriter
+	maxHeaderBytes     uint64
+	reqDone            chan<- struct{}
+	disableCompression bool
+	response           *http.Response
+	trace              *httptrace.ClientTrace
+
+	sentRequest   bool
+	requestedGzip bool
+	isConnect     bool
+	firstByte     bool
 }
 
-var _ Stream = &lengthLimitedStream{}
+var _ RequestStream = &requestStream{}
 
-func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
-	return &lengthLimitedStream{
-		stream:        str,
-		contentLength: contentLength,
+func newRequestStream(
+	str *stream,
+	requestWriter *requestWriter,
+	reqDone chan<- struct{},
+	decoder *qpack.Decoder,
+	disableCompression bool,
+	maxHeaderBytes uint64,
+	rsp *http.Response,
+	trace *httptrace.ClientTrace,
+) *requestStream {
+	return &requestStream{
+		stream:             str,
+		requestWriter:      requestWriter,
+		reqDone:            reqDone,
+		decoder:            decoder,
+		disableCompression: disableCompression,
+		maxHeaderBytes:     maxHeaderBytes,
+		response:           rsp,
+		trace:              trace,
 	}
 }
 
-func (s *lengthLimitedStream) checkContentLengthViolation() error {
-	if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
-		if !s.resetStream {
-			s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
-			s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
-			s.resetStream = true
-		}
-		return errTooMuchData
+func (s *requestStream) Read(b []byte) (int, error) {
+	if s.responseBody == nil {
+		return 0, errors.New("http3: invalid use of RequestStream.Read: need to call ReadResponse first")
 	}
-	return nil
+	return s.responseBody.Read(b)
 }
 
-func (s *lengthLimitedStream) Read(b []byte) (int, error) {
-	if err := s.checkContentLengthViolation(); err != nil {
-		return 0, err
+func (s *requestStream) SendRequestHeader(req *http.Request) error {
+	if s.sentRequest {
+		return errors.New("http3: invalid duplicate use of SendRequestHeader")
+	}
+	if !s.disableCompression && req.Method != http.MethodHead &&
+		req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
+		s.requestedGzip = true
+	}
+	s.isConnect = req.Method == http.MethodConnect
+	s.sentRequest = true
+	return s.requestWriter.WriteRequestHeader(s.Stream, req, s.requestedGzip)
+}
+
+func (s *requestStream) ReadResponse() (*http.Response, error) {
+	fp := &frameParser{
+		conn: s.conn,
+		r: &tracingReader{
+			Reader: s.Stream,
+			first:  &s.firstByte,
+			trace:  s.trace,
+		},
+	}
+	frame, err := fp.ParseNext()
+	if err != nil {
+		s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
+		s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
+		return nil, fmt.Errorf("http3: parsing frame failed: %w", err)
 	}
-	n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)])
-	s.read += int64(n)
-	if err := s.checkContentLengthViolation(); err != nil {
-		return n, err
+	hf, ok := frame.(*headersFrame)
+	if !ok {
+		s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
+		return nil, errors.New("http3: expected first frame to be a HEADERS frame")
+	}
+	if hf.Length > s.maxHeaderBytes {
+		s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
+		s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
+		return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes)
+	}
+	headerBlock := make([]byte, hf.Length)
+	if _, err := io.ReadFull(s.Stream, headerBlock); err != nil {
+		s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
+		s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
+		return nil, fmt.Errorf("http3: failed to read response headers: %w", err)
+	}
+	hfs, err := s.decoder.DecodeFull(headerBlock)
+	if err != nil {
+		// TODO: use the right error code
+		s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "")
+		return nil, fmt.Errorf("http3: failed to decode response headers: %w", err)
+	}
+	res := s.response
+	if err := updateResponseFromHeaders(res, hfs); err != nil {
+		s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
+		s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
+		return nil, fmt.Errorf("http3: invalid response: %w", err)
+	}
+
+	// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
+	// See section 4.1.2 of RFC 9114.
+	respBody := newResponseBody(s.stream, res.ContentLength, s.reqDone)
+
+	// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
+	isInformational := res.StatusCode >= 100 && res.StatusCode < 200
+	isNoContent := res.StatusCode == http.StatusNoContent
+	isSuccessfulConnect := s.isConnect && res.StatusCode >= 200 && res.StatusCode < 300
+	if (isInformational || isNoContent || isSuccessfulConnect) && res.ContentLength == -1 {
+		res.ContentLength = 0
+	}
+	if s.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
+		res.Header.Del("Content-Encoding")
+		res.Header.Del("Content-Length")
+		res.ContentLength = -1
+		s.responseBody = newGzipReader(respBody)
+		res.Uncompressed = true
+	} else {
+		s.responseBody = respBody
+	}
+	res.Body = s.responseBody
+	return res, nil
+}
+
+func (s *stream) SendDatagram(b []byte) error {
+	// TODO: reject if datagrams are not negotiated (yet)
+	return s.datagrams.Send(b)
+}
+
+func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
+	// TODO: reject if datagrams are not negotiated (yet)
+	return s.datagrams.Receive(ctx)
+}
+
+type tracingReader struct {
+	io.Reader
+	first *bool
+	trace *httptrace.ClientTrace
+}
+
+func (r *tracingReader) Read(b []byte) (int, error) {
+	n, err := r.Reader.Read(b)
+	if n > 0 && r.first != nil && !*r.first {
+		traceGotFirstResponseByte(r.trace)
+		*r.first = true
 	}
 	return n, err
 }

+ 48 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/ip_addr.go

@@ -0,0 +1,48 @@
+package http3
+
+import (
+	"net"
+	"strings"
+)
+
+// An addrList represents a list of network endpoint addresses.
+// Copy from [net.addrList] and change type from [net.Addr] to [net.IPAddr]
+type addrList []net.IPAddr
+
+// isIPv4 reports whether addr contains an IPv4 address.
+func isIPv4(addr net.IPAddr) bool {
+	return addr.IP.To4() != nil
+}
+
+// isNotIPv4 reports whether addr does not contain an IPv4 address.
+func isNotIPv4(addr net.IPAddr) bool { return !isIPv4(addr) }
+
+// forResolve returns the most appropriate address in address for
+// a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr.
+// IPv4 is preferred, unless addr contains an IPv6 literal.
+func (addrs addrList) forResolve(network, addr string) net.IPAddr {
+	var want6 bool
+	switch network {
+	case "ip":
+		// IPv6 literal (addr does NOT contain a port)
+		want6 = strings.ContainsRune(addr, ':')
+	case "tcp", "udp":
+		// IPv6 literal. (addr contains a port, so look for '[')
+		want6 = strings.ContainsRune(addr, '[')
+	}
+	if want6 {
+		return addrs.first(isNotIPv4)
+	}
+	return addrs.first(isIPv4)
+}
+
+// first returns the first address which satisfies strategy, or if
+// none do, then the first address of any kind.
+func (addrs addrList) first(strategy func(net.IPAddr) bool) net.IPAddr {
+	for _, addr := range addrs {
+		if strategy(addr) {
+			return addr
+		}
+	}
+	return addrs[0]
+}

+ 2 - 2
vendor/github.com/Psiphon-Labs/quic-go/http3/mockgen.go

@@ -2,7 +2,7 @@
 
 package http3
 
-//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\"  -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser"
-type RoundTripCloser = roundTripCloser
+//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -mock_names=TestClientConnInterface=MockClientConn  -package http3 -destination mock_clientconn_test.go github.com/quic-go/quic-go/http3 TestClientConnInterface"
+type TestClientConnInterface = clientConn
 
 //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener"

+ 20 - 18
vendor/github.com/Psiphon-Labs/quic-go/http3/request_writer.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"net/http/httptrace"
 	"strconv"
 	"strings"
 	"sync"
@@ -15,9 +16,8 @@ import (
 	"golang.org/x/net/http2/hpack"
 	"golang.org/x/net/idna"
 
-	"github.com/Psiphon-Labs/quic-go"
-	"github.com/Psiphon-Labs/quic-go/internal/utils"
 	"github.com/quic-go/qpack"
+	"github.com/Psiphon-Labs/quic-go"
 )
 
 const bodyCopyBufferSize = 8 * 1024
@@ -26,17 +26,14 @@ type requestWriter struct {
 	mutex     sync.Mutex
 	encoder   *qpack.Encoder
 	headerBuf *bytes.Buffer
-
-	logger utils.Logger
 }
 
-func newRequestWriter(logger utils.Logger) *requestWriter {
+func newRequestWriter() *requestWriter {
 	headerBuf := &bytes.Buffer{}
 	encoder := qpack.NewEncoder(headerBuf)
 	return &requestWriter{
 		encoder:   encoder,
 		headerBuf: headerBuf,
-		logger:    logger,
 	}
 }
 
@@ -46,8 +43,12 @@ func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, g
 	if err := w.writeHeaders(buf, req, gzip); err != nil {
 		return err
 	}
-	_, err := str.Write(buf.Bytes())
-	return err
+	if _, err := str.Write(buf.Bytes()); err != nil {
+		return err
+	}
+	trace := httptrace.ContextClientTrace(req.Context())
+	traceWroteHeaders(trace)
+	return nil
 }
 
 func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
@@ -69,6 +70,10 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool)
 	return err
 }
 
+func isExtendedConnectRequest(req *http.Request) bool {
+	return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
+}
+
 // copied from net/transport.go
 // Modified to support Extended CONNECT:
 // Contrary to what the godoc for the http.Request says,
@@ -87,7 +92,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
 	}
 
 	// http.NewRequest sets this field to HTTP/1.1
-	isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
+	isExtendedConnect := isExtendedConnectRequest(req)
 
 	var path string
 	if req.Method != http.MethodConnect || isExtendedConnect {
@@ -198,16 +203,16 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
 	// 	return errRequestHeaderListSize
 	// }
 
-	// trace := httptrace.ContextClientTrace(req.Context())
-	// traceHeaders := traceHasWroteHeaderField(trace)
+	trace := httptrace.ContextClientTrace(req.Context())
+	traceHeaders := traceHasWroteHeaderField(trace)
 
 	// Header list size is ok. Write the headers.
 	enumerateHeaders(func(name, value string) {
 		name = strings.ToLower(name)
 		w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
-		// if traceHeaders {
-		// 	traceWroteHeaderField(trace, name, value)
-		// }
+		if traceHeaders {
+			traceWroteHeaderField(trace, name, value)
+		}
 	})
 
 	return nil
@@ -215,13 +220,10 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
 
 // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
 // and returns a host:port. The port 443 is added if needed.
-func authorityAddr(scheme string, authority string) (addr string) {
+func authorityAddr(authority string) (addr string) {
 	host, port, err := net.SplitHostPort(authority)
 	if err != nil { // authority didn't have a port
 		port = "443"
-		if scheme == "http" {
-			port = "80"
-		}
 		host = authority
 	}
 	if a, err := idna.ToASCII(host); err == nil {

+ 244 - 107
vendor/github.com/Psiphon-Labs/quic-go/http3/response_writer.go

@@ -1,96 +1,73 @@
 package http3
 
 import (
-	"bufio"
 	"bytes"
 	"fmt"
+	"log/slog"
 	"net/http"
+	"net/textproto"
 	"strconv"
 	"strings"
 	"time"
 
-	"github.com/Psiphon-Labs/quic-go"
-	"github.com/Psiphon-Labs/quic-go/internal/utils"
-
 	"github.com/quic-go/qpack"
+	"golang.org/x/net/http/httpguts"
 )
 
+// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented the http.Response.Body.
+// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
+// When a stream is taken over, it's the caller's responsibility to close the stream.
+type HTTPStreamer interface {
+	HTTPStream() Stream
+}
+
 // The maximum length of an encoded HTTP/3 frame header is 16:
 // The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
 const frameHeaderLen = 16
 
-// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
-type headerWriter struct {
-	str     quic.Stream
-	header  http.Header
-	status  int // status code passed to WriteHeader
-	written bool
+const maxSmallResponseSize = 4096
 
-	logger utils.Logger
-}
+type responseWriter struct {
+	str *stream
 
-// writeHeader encodes and flush header to the stream
-func (hw *headerWriter) writeHeader() error {
-	var headers bytes.Buffer
-	enc := qpack.NewEncoder(&headers)
-	enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
+	conn     Connection
+	header   http.Header
+	trailers map[string]struct{}
+	buf      []byte
+	status   int // status code passed to WriteHeader
 
-	for k, v := range hw.header {
-		for index := range v {
-			enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
-		}
-	}
+	// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
+	// and automatically add the Content-Length header
+	smallResponseBuf []byte
 
-	buf := make([]byte, 0, frameHeaderLen+headers.Len())
-	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
-	hw.logger.Infof("Responding with %d", hw.status)
-	buf = append(buf, headers.Bytes()...)
+	contentLen     int64 // if handler set valid Content-Length header
+	numWritten     int64 // bytes written
+	headerComplete bool  // set once WriteHeader is called with a status code >= 200
+	headerWritten  bool  // set once the response header has been serialized to the stream
+	isHead         bool
+	trailerWritten bool // set once the response trailers has been serialized to the stream
 
-	_, err := hw.str.Write(buf)
-	return err
-}
-
-// first Write will trigger flushing header
-func (hw *headerWriter) Write(p []byte) (int, error) {
-	if !hw.written {
-		if err := hw.writeHeader(); err != nil {
-			return 0, err
-		}
-		hw.written = true
-	}
-	return hw.str.Write(p)
-}
-
-type responseWriter struct {
-	*headerWriter
-	conn        quic.Connection
-	bufferedStr *bufio.Writer
-	buf         []byte
+	hijacked bool // set on HTTPStream is called
 
-	contentLen    int64 // if handler set valid Content-Length header
-	numWritten    int64 // bytes written
-	headerWritten bool
-	isHead        bool
+	logger *slog.Logger
 }
 
 var (
 	_ http.ResponseWriter = &responseWriter{}
 	_ http.Flusher        = &responseWriter{}
 	_ Hijacker            = &responseWriter{}
+	_ HTTPStreamer        = &responseWriter{}
 )
 
-func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
-	hw := &headerWriter{
+func newResponseWriter(str *stream, conn Connection, isHead bool, logger *slog.Logger) *responseWriter {
+	return &responseWriter{
 		str:    str,
+		conn:   conn,
 		header: http.Header{},
+		buf:    make([]byte, frameHeaderLen),
+		isHead: isHead,
 		logger: logger,
 	}
-	return &responseWriter{
-		headerWriter: hw,
-		buf:          make([]byte, frameHeaderLen),
-		conn:         conn,
-		bufferedStr:  bufio.NewWriter(hw),
-	}
 }
 
 func (w *responseWriter) Header() http.Header {
@@ -98,7 +75,7 @@ func (w *responseWriter) Header() http.Header {
 }
 
 func (w *responseWriter) WriteHeader(status int) {
-	if w.headerWritten {
+	if w.headerComplete {
 		return
 	}
 
@@ -106,51 +83,55 @@ func (w *responseWriter) WriteHeader(status int) {
 	if status < 100 || status > 999 {
 		panic(fmt.Sprintf("invalid WriteHeader code %v", status))
 	}
+	w.status = status
 
-	if status >= 200 {
-		w.headerWritten = true
-		// Add Date header.
-		// This is what the standard library does.
-		// Can be disabled by setting the Date header to nil.
-		if _, ok := w.header["Date"]; !ok {
-			w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
-		}
-		// Content-Length checking
-		// use ParseUint instead of ParseInt, as negative values are invalid
-		if clen := w.header.Get("Content-Length"); clen != "" {
-			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
-				w.contentLen = int64(cl)
-			} else {
-				// emit a warning for malformed Content-Length and remove it
-				w.logger.Errorf("Malformed Content-Length %s", clen)
-				w.header.Del("Content-Length")
+	// immediately write 1xx headers
+	if status < 200 {
+		w.writeHeader(status)
+		return
+	}
+
+	// We're done with headers once we write a status >= 200.
+	w.headerComplete = true
+	// Add Date header.
+	// This is what the standard library does.
+	// Can be disabled by setting the Date header to nil.
+	if _, ok := w.header["Date"]; !ok {
+		w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
+	}
+	// Content-Length checking
+	// use ParseUint instead of ParseInt, as negative values are invalid
+	if clen := w.header.Get("Content-Length"); clen != "" {
+		if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
+			w.contentLen = int64(cl)
+		} else {
+			// emit a warning for malformed Content-Length and remove it
+			logger := w.logger
+			if logger == nil {
+				logger = slog.Default()
 			}
+			logger.Error("Malformed Content-Length", "value", clen)
+			w.header.Del("Content-Length")
 		}
 	}
-	w.status = status
+}
 
-	if !w.headerWritten {
-		w.writeHeader()
+func (w *responseWriter) sniffContentType(p []byte) {
+	// If no content type, apply sniffing algorithm to body.
+	// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shouldn't do sniffing.
+	_, haveType := w.header["Content-Type"]
+
+	// If the Content-Encoding was set and is non-blank, we shouldn't sniff the body.
+	hasCE := w.header.Get("Content-Encoding") != ""
+	if !hasCE && !haveType && len(p) > 0 {
+		w.header.Set("Content-Type", http.DetectContentType(p))
 	}
 }
 
 func (w *responseWriter) Write(p []byte) (int, error) {
 	bodyAllowed := bodyAllowedForStatus(w.status)
-	if !w.headerWritten {
-		// If body is not allowed, we don't need to (and we can't) sniff the content type.
-		if bodyAllowed {
-			// If no content type, apply sniffing algorithm to body.
-			// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
-			_, haveType := w.header["Content-Type"]
-
-			// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
-			// we shouldn't sniff the body.
-			hasTE := w.header.Get("Transfer-Encoding") != ""
-			hasCE := w.header.Get("Content-Encoding") != ""
-			if !hasCE && !haveType && !hasTE && len(p) > 0 {
-				w.header.Set("Content-Type", http.DetectContentType(p))
-			}
-		}
+	if !w.headerComplete {
+		w.sniffContentType(p)
 		w.WriteHeader(http.StatusOK)
 		bodyAllowed = true
 	}
@@ -167,36 +148,192 @@ func (w *responseWriter) Write(p []byte) (int, error) {
 		return len(p), nil
 	}
 
-	df := &dataFrame{Length: uint64(len(p))}
+	if !w.headerWritten {
+		// Buffer small responses.
+		// This allows us to automatically set the Content-Length field.
+		if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
+			w.smallResponseBuf = append(w.smallResponseBuf, p...)
+			return len(p), nil
+		}
+	}
+	return w.doWrite(p)
+}
+
+func (w *responseWriter) doWrite(p []byte) (int, error) {
+	if !w.headerWritten {
+		w.sniffContentType(w.smallResponseBuf)
+		if err := w.writeHeader(w.status); err != nil {
+			return 0, maybeReplaceError(err)
+		}
+		w.headerWritten = true
+	}
+
+	l := uint64(len(w.smallResponseBuf) + len(p))
+	if l == 0 {
+		return 0, nil
+	}
+	df := &dataFrame{Length: l}
 	w.buf = w.buf[:0]
 	w.buf = df.Append(w.buf)
-	if _, err := w.bufferedStr.Write(w.buf); err != nil {
+	if _, err := w.str.writeUnframed(w.buf); err != nil {
 		return 0, maybeReplaceError(err)
 	}
-	n, err := w.bufferedStr.Write(p)
-	return n, maybeReplaceError(err)
+	if len(w.smallResponseBuf) > 0 {
+		if _, err := w.str.writeUnframed(w.smallResponseBuf); err != nil {
+			return 0, maybeReplaceError(err)
+		}
+		w.smallResponseBuf = nil
+	}
+	var n int
+	if len(p) > 0 {
+		var err error
+		n, err = w.str.writeUnframed(p)
+		if err != nil {
+			return n, maybeReplaceError(err)
+		}
+	}
+	return n, nil
+}
+
+func (w *responseWriter) writeHeader(status int) error {
+	var headers bytes.Buffer
+	enc := qpack.NewEncoder(&headers)
+	if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}); err != nil {
+		return err
+	}
+
+	// Handle trailer fields
+	if vals, ok := w.header["Trailer"]; ok {
+		for _, val := range vals {
+			for _, trailer := range strings.Split(val, ",") {
+				// We need to convert to the canonical header key value here because this will be called when using
+				// headers.Add or headers.Set.
+				trailer = textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(trailer))
+				w.declareTrailer(trailer)
+			}
+		}
+	}
+
+	for k, v := range w.header {
+		if _, excluded := w.trailers[k]; excluded {
+			continue
+		}
+		// Ignore "Trailer:" prefixed headers
+		if strings.HasPrefix(k, http.TrailerPrefix) {
+			continue
+		}
+		for index := range v {
+			if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
+				return err
+			}
+		}
+	}
+
+	buf := make([]byte, 0, frameHeaderLen+headers.Len())
+	buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
+	buf = append(buf, headers.Bytes()...)
+
+	_, err := w.str.writeUnframed(buf)
+	return err
 }
 
 func (w *responseWriter) FlushError() error {
-	if !w.headerWritten {
+	if !w.headerComplete {
 		w.WriteHeader(http.StatusOK)
 	}
-	if !w.written {
-		if err := w.writeHeader(); err != nil {
-			return maybeReplaceError(err)
-		}
-		w.written = true
+	_, err := w.doWrite(nil)
+	return err
+}
+
+func (w *responseWriter) flushTrailers() {
+	if w.trailerWritten {
+		return
+	}
+	if err := w.writeTrailers(); err != nil {
+		w.logger.Debug("could not write trailers", "error", err)
 	}
-	return w.bufferedStr.Flush()
 }
 
 func (w *responseWriter) Flush() {
 	if err := w.FlushError(); err != nil {
-		w.logger.Errorf("could not flush to stream: %s", err.Error())
+		if w.logger != nil {
+			w.logger.Debug("could not flush to stream", "error", err)
+		}
+	}
+}
+
+// declareTrailer adds a trailer to the trailer list, while also validating that the trailer has a
+// valid name.
+func (w *responseWriter) declareTrailer(k string) {
+	if !httpguts.ValidTrailerHeader(k) {
+		// Forbidden by RFC 9110, section 6.5.1.
+		w.logger.Debug("ignoring invalid trailer", slog.String("header", k))
+		return
+	}
+	if w.trailers == nil {
+		w.trailers = make(map[string]struct{})
+	}
+	w.trailers[k] = struct{}{}
+}
+
+// hasNonEmptyTrailers checks to see if there are any trailers with an actual
+// value set. This is possible by adding trailers to the "Trailers" header
+// but never actually setting those names as trailers in the course of handling
+// the request. In that case, this check may save us some allocations.
+func (w *responseWriter) hasNonEmptyTrailers() bool {
+	for trailer := range w.trailers {
+		if _, ok := w.header[trailer]; ok {
+			return true
+		}
+	}
+	return false
+}
+
+// writeTrailers will write trailers to the stream if there are any.
+func (w *responseWriter) writeTrailers() error {
+	// promote headers added via "Trailer:" convention as trailers, these can be added after
+	// streaming the status/headers have been written.
+	for k := range w.header {
+		// Handle "Trailer:" prefix
+		if strings.HasPrefix(k, http.TrailerPrefix) {
+			w.declareTrailer(k)
+		}
+	}
+
+	if !w.hasNonEmptyTrailers() {
+		return nil
 	}
+
+	var b bytes.Buffer
+	enc := qpack.NewEncoder(&b)
+	for trailer := range w.trailers {
+		trailerName := strings.ToLower(strings.TrimPrefix(trailer, http.TrailerPrefix))
+		if vals, ok := w.header[trailer]; ok {
+			for _, val := range vals {
+				if err := enc.WriteField(qpack.HeaderField{Name: trailerName, Value: val}); err != nil {
+					return err
+				}
+			}
+		}
+	}
+
+	buf := make([]byte, 0, frameHeaderLen+b.Len())
+	buf = (&headersFrame{Length: uint64(b.Len())}).Append(buf)
+	buf = append(buf, b.Bytes()...)
+	_, err := w.str.writeUnframed(buf)
+	w.trailerWritten = true
+	return err
 }
 
-func (w *responseWriter) StreamCreator() StreamCreator {
+func (w *responseWriter) HTTPStream() Stream {
+	w.hijacked = true
+	w.Flush()
+	return w.str
+}
+
+func (w *responseWriter) wasStreamHijacked() bool { return w.hijacked }
+
+func (w *responseWriter) Connection() Connection {
 	return w.conn
 }
 

+ 0 - 303
vendor/github.com/Psiphon-Labs/quic-go/http3/roundtrip.go

@@ -1,303 +0,0 @@
-package http3
-
-import (
-	"context"
-	"errors"
-	"fmt"
-	"io"
-	"net"
-	"net/http"
-	"strings"
-	"sync"
-	"sync/atomic"
-
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
-	"golang.org/x/net/http/httpguts"
-
-	"github.com/Psiphon-Labs/quic-go"
-)
-
-type roundTripCloser interface {
-	RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
-	HandshakeComplete() bool
-	io.Closer
-}
-
-type roundTripCloserWithCount struct {
-	roundTripCloser
-	useCount atomic.Int64
-}
-
-// RoundTripper implements the http.RoundTripper interface
-type RoundTripper struct {
-	mutex sync.Mutex
-
-	// DisableCompression, if true, prevents the Transport from
-	// requesting compression with an "Accept-Encoding: gzip"
-	// request header when the Request contains no existing
-	// Accept-Encoding value. If the Transport requests gzip on
-	// its own and gets a gzipped response, it's transparently
-	// decoded in the Response.Body. However, if the user
-	// explicitly requested gzip it is not automatically
-	// uncompressed.
-	DisableCompression bool
-
-	// TLSClientConfig specifies the TLS configuration to use with
-	// tls.Client. If nil, the default configuration is used.
-	TLSClientConfig *tls.Config
-
-	// QuicConfig is the quic.Config used for dialing new connections.
-	// If nil, reasonable default values will be used.
-	QuicConfig *quic.Config
-
-	// Enable support for HTTP/3 datagrams.
-	// If set to true, QuicConfig.EnableDatagram will be set.
-	// See https://datatracker.ietf.org/doc/html/rfc9297.
-	EnableDatagrams bool
-
-	// Additional HTTP/3 settings.
-	// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
-	AdditionalSettings map[uint64]uint64
-
-	// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
-	// It is called right after parsing the frame type.
-	// If parsing the frame type fails, the error is passed to the callback.
-	// In that case, the frame type will not be set.
-	// Callers can either ignore the frame and return control of the stream back to HTTP/3
-	// (by returning hijacked false).
-	// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
-	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
-
-	// When set, this callback is called for unknown unidirectional stream of unknown stream type.
-	// If parsing the stream type fails, the error is passed to the callback.
-	// In that case, the stream type will not be set.
-	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
-
-	// Dial specifies an optional dial function for creating QUIC
-	// connections for requests.
-	// If Dial is nil, a UDPConn will be created at the first request
-	// and will be reused for subsequent connections to other servers.
-	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
-
-	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
-	// allowed in the server's response header.
-	// Zero means to use a default limit.
-	MaxResponseHeaderBytes int64
-
-	newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
-	clients   map[string]*roundTripCloserWithCount
-	transport *quic.Transport
-}
-
-// RoundTripOpt are options for the Transport.RoundTripOpt method.
-type RoundTripOpt struct {
-	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
-	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
-	OnlyCachedConn bool
-	// DontCloseRequestStream controls whether the request stream is closed after sending the request.
-	// If set, context cancellations have no effect after the response headers are received.
-	DontCloseRequestStream bool
-}
-
-var (
-	_ http.RoundTripper = &RoundTripper{}
-	_ io.Closer         = &RoundTripper{}
-)
-
-// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
-var ErrNoCachedConn = errors.New("http3: no cached connection was available")
-
-// RoundTripOpt is like RoundTrip, but takes options.
-func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
-	if req.URL == nil {
-		closeRequestBody(req)
-		return nil, errors.New("http3: nil Request.URL")
-	}
-	if req.URL.Scheme != "https" {
-		closeRequestBody(req)
-		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
-	}
-	if req.URL.Host == "" {
-		closeRequestBody(req)
-		return nil, errors.New("http3: no Host in request URL")
-	}
-	if req.Header == nil {
-		closeRequestBody(req)
-		return nil, errors.New("http3: nil Request.Header")
-	}
-	for k, vv := range req.Header {
-		if !httpguts.ValidHeaderFieldName(k) {
-			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
-		}
-		for _, v := range vv {
-			if !httpguts.ValidHeaderFieldValue(v) {
-				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
-			}
-		}
-	}
-
-	if req.Method != "" && !validMethod(req.Method) {
-		closeRequestBody(req)
-		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
-	}
-
-	hostname := authorityAddr("https", hostnameFromRequest(req))
-	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
-	if err != nil {
-		return nil, err
-	}
-	defer cl.useCount.Add(-1)
-	rsp, err := cl.RoundTripOpt(req, opt)
-	if err != nil {
-		r.removeClient(hostname)
-		if isReused {
-			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
-				return r.RoundTripOpt(req, opt)
-			}
-		}
-	}
-	return rsp, err
-}
-
-// RoundTrip does a round trip.
-func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
-	return r.RoundTripOpt(req, RoundTripOpt{})
-}
-
-func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
-
-	if r.clients == nil {
-		r.clients = make(map[string]*roundTripCloserWithCount)
-	}
-
-	client, ok := r.clients[hostname]
-	if !ok {
-		if onlyCached {
-			return nil, false, ErrNoCachedConn
-		}
-		var err error
-		newCl := newClient
-		if r.newClient != nil {
-			newCl = r.newClient
-		}
-		dial := r.Dial
-		if dial == nil {
-			if r.transport == nil {
-				udpConn, err := net.ListenUDP("udp", nil)
-				if err != nil {
-					return nil, false, err
-				}
-				r.transport = &quic.Transport{Conn: udpConn}
-			}
-			dial = r.makeDialer()
-		}
-		c, err := newCl(
-			hostname,
-			r.TLSClientConfig,
-			&roundTripperOpts{
-				EnableDatagram:     r.EnableDatagrams,
-				DisableCompression: r.DisableCompression,
-				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
-				StreamHijacker:     r.StreamHijacker,
-				UniStreamHijacker:  r.UniStreamHijacker,
-				AdditionalSettings: r.AdditionalSettings,
-			},
-			r.QuicConfig,
-			dial,
-		)
-		if err != nil {
-			return nil, false, err
-		}
-		client = &roundTripCloserWithCount{roundTripCloser: c}
-		r.clients[hostname] = client
-	} else if client.HandshakeComplete() {
-		isReused = true
-	}
-	client.useCount.Add(1)
-	return client, isReused, nil
-}
-
-func (r *RoundTripper) removeClient(hostname string) {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
-	if r.clients == nil {
-		return
-	}
-	delete(r.clients, hostname)
-}
-
-// Close closes the QUIC connections that this RoundTripper has used.
-// It also closes the underlying UDPConn if it is not nil.
-func (r *RoundTripper) Close() error {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
-	for _, client := range r.clients {
-		if err := client.Close(); err != nil {
-			return err
-		}
-	}
-	r.clients = nil
-	if r.transport != nil {
-		if err := r.transport.Close(); err != nil {
-			return err
-		}
-		if err := r.transport.Conn.Close(); err != nil {
-			return err
-		}
-		r.transport = nil
-	}
-	return nil
-}
-
-func closeRequestBody(req *http.Request) {
-	if req.Body != nil {
-		req.Body.Close()
-	}
-}
-
-func validMethod(method string) bool {
-	/*
-				     Method         = "OPTIONS"                ; Section 9.2
-		   		                    | "GET"                    ; Section 9.3
-		   		                    | "HEAD"                   ; Section 9.4
-		   		                    | "POST"                   ; Section 9.5
-		   		                    | "PUT"                    ; Section 9.6
-		   		                    | "DELETE"                 ; Section 9.7
-		   		                    | "TRACE"                  ; Section 9.8
-		   		                    | "CONNECT"                ; Section 9.9
-		   		                    | extension-method
-		   		   extension-method = token
-		   		     token          = 1*<any CHAR except CTLs or separators>
-	*/
-	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
-}
-
-// copied from net/http/http.go
-func isNotToken(r rune) bool {
-	return !httpguts.IsTokenRune(r)
-}
-
-// makeDialer makes a QUIC dialer using r.udpConn.
-func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
-	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
-		udpAddr, err := net.ResolveUDPAddr("udp", addr)
-		if err != nil {
-			return nil, err
-		}
-		return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
-	}
-}
-
-func (r *RoundTripper) CloseIdleConnections() {
-	r.mutex.Lock()
-	defer r.mutex.Unlock()
-	for hostname, client := range r.clients {
-		if client.useCount.Load() == 0 {
-			client.Close()
-			delete(r.clients, hostname)
-		}
-	}
-}

+ 298 - 218
vendor/github.com/Psiphon-Labs/quic-go/http3/server.go

@@ -2,22 +2,22 @@ package http3
 
 import (
 	"context"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"net"
 	"net/http"
 	"runtime"
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"github.com/Psiphon-Labs/quic-go"
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
-	"github.com/Psiphon-Labs/quic-go/internal/utils"
 	"github.com/Psiphon-Labs/quic-go/quicvarint"
 
 	"github.com/quic-go/qpack"
@@ -31,7 +31,6 @@ var (
 	quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
 		return quic.ListenAddrEarly(addr, tlsConf, config)
 	}
-	errPanicked = errors.New("panicked")
 )
 
 // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
@@ -47,6 +46,8 @@ const (
 	streamTypeQPACKDecoderStream = 3
 )
 
+const goawayTimeout = 5 * time.Second
+
 // A QUICEarlyListener listens for incoming QUIC connections.
 type QUICEarlyListener interface {
 	Accept(context.Context) (quic.EarlyConnection, error)
@@ -56,7 +57,7 @@ type QUICEarlyListener interface {
 
 var _ QUICEarlyListener = &quic.EarlyListener{}
 
-func versionToALPN(v protocol.VersionNumber) string {
+func versionToALPN(v protocol.Version) string {
 	//nolint:exhaustive // These are all the versions we care about.
 	switch v {
 	case protocol.Version1, protocol.Version2:
@@ -78,7 +79,7 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
 			// determine the ALPN from the QUIC version used
 			proto := NextProtoH3
 			val := ch.Context().Value(quic.QUICVersionContextKey)
-			if v, ok := val.(quic.VersionNumber); ok {
+			if v, ok := val.(quic.Version); ok {
 				proto = versionToALPN(v)
 			}
 			config := tlsConf
@@ -96,6 +97,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
 			if config == nil {
 				return nil, nil
 			}
+			// Workaround for https://github.com/golang/go/issues/60506.
+			// This initializes the session tickets _before_ cloning the config.
+			_, _ = config.DecryptTicket(nil, tls.ConnectionState{})
+
 			config = config.Clone()
 			config.NextProtos = []string{proto}
 			return config, nil
@@ -127,20 +132,6 @@ var ServerContextKey = &contextKey{"http3-server"}
 // than its string representation.
 var RemoteAddrContextKey = &contextKey{"remote-addr"}
 
-type requestError struct {
-	err       error
-	streamErr ErrCode
-	connErr   ErrCode
-}
-
-func newStreamError(code ErrCode, err error) requestError {
-	return requestError{err: err, streamErr: code}
-}
-
-func newConnError(code ErrCode, err error) requestError {
-	return requestError{err: err, connErr: code}
-}
-
 // listenerInfo contains info about specific listener added with addListener
 type listenerInfo struct {
 	port int // 0 means that no info about port is available
@@ -157,10 +148,10 @@ type Server struct {
 	//
 	// Otherwise, if Port is not set and underlying QUIC listeners do not
 	// have valid port numbers, the port part is used in Alt-Svc headers set
-	// with SetQuicHeaders.
+	// with SetQUICHeaders.
 	Addr string
 
-	// Port is used in Alt-Svc response headers set with SetQuicHeaders. If
+	// Port is used in Alt-Svc response headers set with SetQUICHeaders. If
 	// needed Port can be manually set when the Server is created.
 	//
 	// This is useful when a Layer 4 firewall is redirecting UDP traffic and
@@ -172,20 +163,18 @@ type Server struct {
 	// set for ListenAndServe and Serve methods.
 	TLSConfig *tls.Config
 
-	// QuicConfig provides the parameters for QUIC connection created with
-	// Serve. If nil, it uses reasonable default values.
+	// QUICConfig provides the parameters for QUIC connection created with Serve.
+	// If nil, it uses reasonable default values.
 	//
-	// Configured versions are also used in Alt-Svc response header set with
-	// SetQuicHeaders.
-	QuicConfig *quic.Config
+	// Configured versions are also used in Alt-Svc response header set with SetQUICHeaders.
+	QUICConfig *quic.Config
 
 	// Handler is the HTTP request handler to use. If not set, defaults to
 	// http.NotFound.
 	Handler http.Handler
 
-	// EnableDatagrams enables support for HTTP/3 datagrams.
-	// If set to true, QuicConfig.EnableDatagram will be set.
-	// See https://datatracker.ietf.org/doc/html/rfc9297.
+	// EnableDatagrams enables support for HTTP/3 datagrams (RFC 9297).
+	// If set to true, QUICConfig.EnableDatagrams will be set.
 	EnableDatagrams bool
 
 	// MaxHeaderBytes controls the maximum number of bytes the server will
@@ -195,7 +184,7 @@ type Server struct {
 	MaxHeaderBytes int
 
 	// AdditionalSettings specifies additional HTTP/3 settings.
-	// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
+	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
 	AdditionalSettings map[uint64]uint64
 
 	// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
@@ -205,33 +194,50 @@ type Server struct {
 	// Callers can either ignore the frame and return control of the stream back to HTTP/3
 	// (by returning hijacked false).
 	// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
-	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
+	StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
 
 	// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
 	// If parsing the stream type fails, the error is passed to the callback.
 	// In that case, the stream type will not be set.
-	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
+	UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
+
+	// IdleTimeout specifies how long until idle clients connection should be
+	// closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer
+	// like PING frames are not considered.
+	// If zero or negative, there is no timeout.
+	IdleTimeout time.Duration
 
-	// ConnContext optionally specifies a function that modifies
-	// the context used for a new connection c. The provided ctx
-	// has a ServerContextKey value.
+	// ConnContext optionally specifies a function that modifies the context used for a new connection c.
+	// The provided ctx has a ServerContextKey value.
 	ConnContext func(ctx context.Context, c quic.Connection) context.Context
 
+	Logger *slog.Logger
+
 	mutex     sync.RWMutex
 	listeners map[*QUICEarlyListener]listenerInfo
 
-	closed bool
+	closed           bool
+	closeCtx         context.Context    // canceled when the server is closed
+	closeCancel      context.CancelFunc // cancels the closeCtx
+	graceCtx         context.Context    // canceled when the server is closed or gracefully closed
+	graceCancel      context.CancelFunc // cancels the graceCtx
+	connCount        atomic.Int64
+	connHandlingDone chan struct{}
 
 	altSvcHeader string
-
-	logger utils.Logger
 }
 
 // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
 //
 // If s.Addr is blank, ":https" is used.
 func (s *Server) ListenAndServe() error {
-	return s.serveConn(s.TLSConfig, nil)
+	ln, err := s.setupListenerForConn(s.TLSConfig, nil)
+	if err != nil {
+		return err
+	}
+	defer s.removeListener(&ln)
+
+	return s.serveListener(ln)
 }
 
 // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
@@ -246,27 +252,53 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
 	}
 	// We currently only use the cert-related stuff from tls.Config,
 	// so we don't need to make a full copy.
-	config := &tls.Config{
-		Certificates: certs,
+	ln, err := s.setupListenerForConn(&tls.Config{Certificates: certs}, nil)
+	if err != nil {
+		return err
 	}
-	return s.serveConn(config, nil)
+	defer s.removeListener(&ln)
+
+	return s.serveListener(ln)
 }
 
 // Serve an existing UDP connection.
 // It is possible to reuse the same connection for outgoing connections.
 // Closing the server does not close the connection.
 func (s *Server) Serve(conn net.PacketConn) error {
-	return s.serveConn(s.TLSConfig, conn)
+	ln, err := s.setupListenerForConn(s.TLSConfig, conn)
+	if err != nil {
+		return err
+	}
+	defer s.removeListener(&ln)
+
+	return s.serveListener(ln)
+}
+
+// init initializes the contexts used for shutting down the server.
+// It must be called with the mutex held.
+func (s *Server) init() {
+	if s.closeCtx == nil {
+		s.closeCtx, s.closeCancel = context.WithCancel(context.Background())
+		s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx)
+	}
+	s.connHandlingDone = make(chan struct{}, 1)
+}
+
+func (s *Server) decreaseConnCount() {
+	if s.connCount.Add(-1) == 0 && s.graceCtx.Err() != nil {
+		close(s.connHandlingDone)
+	}
 }
 
 // ServeQUICConn serves a single QUIC connection.
 func (s *Server) ServeQUICConn(conn quic.Connection) error {
 	s.mutex.Lock()
-	if s.logger == nil {
-		s.logger = utils.DefaultLogger.WithPrefix("server")
-	}
+	s.init()
 	s.mutex.Unlock()
 
+	s.connCount.Add(1)
+	defer s.decreaseConnCount()
+
 	return s.handleConn(conn)
 }
 
@@ -276,21 +308,34 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error {
 // Closing the server does close the listener.
 // ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
 func (s *Server) ServeListener(ln QUICEarlyListener) error {
+	s.mutex.Lock()
 	if err := s.addListener(&ln); err != nil {
+		s.mutex.Unlock()
 		return err
 	}
+	s.mutex.Unlock()
 	defer s.removeListener(&ln)
+
+	return s.serveListener(ln)
+}
+
+func (s *Server) serveListener(ln QUICEarlyListener) error {
 	for {
-		conn, err := ln.Accept(context.Background())
-		if err == quic.ErrServerClosed {
+		conn, err := ln.Accept(s.graceCtx)
+		// server closed
+		if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil {
 			return http.ErrServerClosed
 		}
 		if err != nil {
 			return err
 		}
+		s.connCount.Add(1)
 		go func() {
+			defer s.decreaseConnCount()
 			if err := s.handleConn(conn); err != nil {
-				s.logger.Debugf("handling connection failed: %s", err)
+				if s.Logger != nil {
+					s.Logger.Debug("handling connection failed", "error", err)
+				}
 			}
 		}()
 	}
@@ -298,29 +343,29 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error {
 
 var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
 
-func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
+func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (QUICEarlyListener, error) {
 	if tlsConf == nil {
-		return errServerWithoutTLSConfig
-	}
-
-	s.mutex.Lock()
-	closed := s.closed
-	s.mutex.Unlock()
-	if closed {
-		return http.ErrServerClosed
+		return nil, errServerWithoutTLSConfig
 	}
 
 	baseConf := ConfigureTLSConfig(tlsConf)
-	quicConf := s.QuicConfig
+	quicConf := s.QUICConfig
 	if quicConf == nil {
 		quicConf = &quic.Config{Allow0RTT: true}
 	} else {
-		quicConf = s.QuicConfig.Clone()
+		quicConf = s.QUICConfig.Clone()
 	}
 	if s.EnableDatagrams {
 		quicConf.EnableDatagrams = true
 	}
 
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+	closed := s.closed
+	if closed {
+		return nil, http.ErrServerClosed
+	}
+
 	var ln QUICEarlyListener
 	var err error
 	if conn == nil {
@@ -333,9 +378,12 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
 		ln, err = quicListen(conn, baseConf, quicConf)
 	}
 	if err != nil {
-		return err
+		return nil, err
 	}
-	return s.ServeListener(ln)
+	if err := s.addListener(&ln); err != nil {
+		return nil, err
+	}
+	return ln, nil
 }
 
 func extractPort(addr string) (int, error) {
@@ -360,8 +408,8 @@ func (s *Server) generateAltSvcHeader() {
 
 	// This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed.
 	supportedVersions := protocol.SupportedVersions
-	if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 {
-		supportedVersions = s.QuicConfig.Versions
+	if s.QUICConfig != nil && len(s.QUICConfig.Versions) > 0 {
+		supportedVersions = s.QUICConfig.Versions
 	}
 
 	// keep track of which have been seen so we don't yield duplicate values
@@ -411,24 +459,23 @@ func (s *Server) generateAltSvcHeader() {
 // call trackListener via Serve and can track+defer untrack the same pointer to
 // local variable there. We never need to compare a Listener from another caller.
 func (s *Server) addListener(l *QUICEarlyListener) error {
-	s.mutex.Lock()
-	defer s.mutex.Unlock()
-
 	if s.closed {
 		return http.ErrServerClosed
 	}
-	if s.logger == nil {
-		s.logger = utils.DefaultLogger.WithPrefix("server")
-	}
 	if s.listeners == nil {
 		s.listeners = make(map[*QUICEarlyListener]listenerInfo)
 	}
+	s.init()
 
 	laddr := (*l).Addr()
 	if port, err := extractPort(laddr.String()); err == nil {
 		s.listeners[l] = listenerInfo{port}
 	} else {
-		s.logger.Errorf("Unable to extract port from listener %s, will not be announced using SetQuicHeaders: %s", laddr, err)
+		logger := s.Logger
+		if logger == nil {
+			logger = slog.Default()
+		}
+		logger.Error("Unable to extract port from listener, will not be announced using SetQUICHeaders", "local addr", laddr, "error", err)
 		s.listeners[l] = listenerInfo{}
 	}
 	s.generateAltSvcHeader()
@@ -442,113 +489,97 @@ func (s *Server) removeListener(l *QUICEarlyListener) {
 	s.generateAltSvcHeader()
 }
 
+// handleConn handles the HTTP/3 exchange on a QUIC connection.
+// It blocks until all HTTP handlers for all streams have returned.
 func (s *Server) handleConn(conn quic.Connection) error {
-	decoder := qpack.NewDecoder(nil)
-
-	// send a SETTINGS frame
-	str, err := conn.OpenUniStream()
+	// open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later
+	// when the server is gracefully closed
+	ctrlStr, err := conn.OpenUniStream()
 	if err != nil {
 		return fmt.Errorf("opening the control stream failed: %w", err)
 	}
 	b := make([]byte, 0, 64)
 	b = quicvarint.Append(b, streamTypeControlStream) // stream type
-	b = (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Append(b)
-	str.Write(b)
+	b = (&settingsFrame{
+		Datagram:        s.EnableDatagrams,
+		ExtendedConnect: true,
+		Other:           s.AdditionalSettings,
+	}).Append(b)
+	ctrlStr.Write(b)
+
+	ctx := conn.Context()
+	ctx = context.WithValue(ctx, ServerContextKey, s)
+	ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
+	ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
+	if s.ConnContext != nil {
+		ctx = s.ConnContext(ctx, conn)
+		if ctx == nil {
+			panic("http3: ConnContext returned nil")
+		}
+	}
 
-	go s.handleUnidirectionalStreams(conn)
+	hconn := newConnection(
+		ctx,
+		conn,
+		s.EnableDatagrams,
+		protocol.PerspectiveServer,
+		s.Logger,
+		s.IdleTimeout,
+	)
+	go hconn.handleUnidirectionalStreams(s.UniStreamHijacker)
 
+	var nextStreamID quic.StreamID
+	var wg sync.WaitGroup
+	var handleErr error
 	// Process all requests immediately.
 	// It's the client's responsibility to decide which requests are eligible for 0-RTT.
 	for {
-		str, err := conn.AcceptStream(context.Background())
+		str, datagrams, err := hconn.acceptStream(s.graceCtx)
 		if err != nil {
-			var appErr *quic.ApplicationError
-			if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
-				return nil
+			// server (not gracefully) closed, close the connection immediately
+			if s.closeCtx.Err() != nil {
+				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
+				handleErr = http.ErrServerClosed
+				break
 			}
-			return fmt.Errorf("accepting stream failed: %w", err)
-		}
-		go func() {
-			rerr := s.handleRequest(conn, str, decoder, func() {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
-			})
-			if rerr.err == errHijacked {
-				return
-			}
-			if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
-				s.logger.Debugf("Handling request failed: %s", err)
-				if rerr.streamErr != 0 {
-					str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
-				}
-				if rerr.connErr != 0 {
-					var reason string
-					if rerr.err != nil {
-						reason = rerr.err.Error()
-					}
-					conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
+
+			// gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end
+			// new requests will be rejected and shouldn't be sent
+			if s.graceCtx.Err() != nil {
+				b = (&goAwayFrame{StreamID: nextStreamID}).Append(b[:0])
+				// set a deadline to send the GOAWAY frame
+				ctrlStr.SetWriteDeadline(time.Now().Add(goawayTimeout))
+				ctrlStr.Write(b)
+
+				select {
+				case <-hconn.Context().Done():
+					// we expect the client to eventually close the connection after receiving the GOAWAY
+				case <-s.closeCtx.Done():
+					// close the connection after graceful period
+					conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
 				}
-				return
+				handleErr = http.ErrServerClosed
+				break
 			}
-			str.Close()
-		}()
-	}
-}
 
-func (s *Server) handleUnidirectionalStreams(conn quic.Connection) {
-	for {
-		str, err := conn.AcceptUniStream(context.Background())
-		if err != nil {
-			s.logger.Debugf("accepting unidirectional stream failed: %s", err)
-			return
+			var appErr *quic.ApplicationError
+			if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) {
+				handleErr = fmt.Errorf("accepting stream failed: %w", err)
+			}
+			break
 		}
 
-		go func(str quic.ReceiveStream) {
-			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
-			if err != nil {
-				if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
-					return
-				}
-				s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
-				return
-			}
-			// We're only interested in the control stream here.
-			switch streamType {
-			case streamTypeControlStream:
-			case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
-				// Our QPACK implementation doesn't use the dynamic table yet.
-				// TODO: check that only one stream of each type is opened.
-				return
-			case streamTypePushStream: // only the server can push
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
-				return
-			default:
-				if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
-					return
-				}
-				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
-				return
-			}
-			f, err := parseNextFrame(str, nil)
-			if err != nil {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
-				return
-			}
-			sf, ok := f.(*settingsFrame)
-			if !ok {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
-				return
-			}
-			if !sf.Datagram {
-				return
-			}
-			// If datagram support was enabled on our side as well as on the client side,
-			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
-			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
-			if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams {
-				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
-			}
-		}(str)
+		nextStreamID = str.StreamID() + 4
+		wg.Add(1)
+		go func() {
+			// handleRequest will return once the request has been handled,
+			// or the underlying connection is closed
+			defer wg.Done()
+			s.handleRequest(hconn, str, datagrams, hconn.decoder)
+		}()
 	}
+	wg.Wait()
+	return handleErr
 }
 
 func (s *Server) maxHeaderBytes() uint64 {
@@ -558,37 +589,54 @@ func (s *Server) maxHeaderBytes() uint64 {
 	return uint64(s.MaxHeaderBytes)
 }
 
-func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
+func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *datagrammer, decoder *qpack.Decoder) {
 	var ufh unknownFrameHandlerFunc
 	if s.StreamHijacker != nil {
-		ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
+		ufh = func(ft FrameType, e error) (processed bool, err error) {
+			return s.StreamHijacker(
+				ft,
+				conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
+				str,
+				e,
+			)
+		}
 	}
-	frame, err := parseNextFrame(str, ufh)
+	fp := &frameParser{conn: conn, r: str, unknownFrameHandler: ufh}
+	frame, err := fp.ParseNext()
 	if err != nil {
-		if err == errHijacked {
-			return requestError{err: errHijacked}
+		if !errors.Is(err, errHijacked) {
+			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
+			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
 		}
-		return newStreamError(ErrCodeRequestIncomplete, err)
+		return
 	}
 	hf, ok := frame.(*headersFrame)
 	if !ok {
-		return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
+		conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame")
+		return
 	}
 	if hf.Length > s.maxHeaderBytes() {
-		return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
+		str.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
+		str.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
+		return
 	}
 	headerBlock := make([]byte, hf.Length)
 	if _, err := io.ReadFull(str, headerBlock); err != nil {
-		return newStreamError(ErrCodeRequestIncomplete, err)
+		str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
+		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
+		return
 	}
 	hfs, err := decoder.DecodeFull(headerBlock)
 	if err != nil {
 		// TODO: use the right error code
-		return newConnError(ErrCodeGeneralProtocolError, err)
+		conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "expected first frame to be a HEADERS frame")
+		return
 	}
 	req, err := requestFromHeaders(hfs)
 	if err != nil {
-		return newStreamError(ErrCodeMessageError, err)
+		str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
+		str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
+		return
 	}
 
 	connState := conn.ConnectionState().TLS
@@ -600,41 +648,29 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
 
 	// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
 	// See section 4.1.2 of RFC 9114.
-	var httpStr Stream
+	contentLength := int64(-1)
 	if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
-		httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength)
-	} else {
-		httpStr = newStream(str, onFrameError)
+		contentLength = req.ContentLength
 	}
-	body := newRequestBody(httpStr)
+	hstr := newStream(str, conn, datagrams, nil)
+	body := newRequestBody(hstr, contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings)
 	req.Body = body
 
-	if s.logger.Debug() {
-		s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
-	} else {
-		s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
+	if s.Logger != nil {
+		s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
 	}
 
-	ctx := str.Context()
-	ctx = context.WithValue(ctx, ServerContextKey, s)
-	ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
-	ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
-	if s.ConnContext != nil {
-		ctx = s.ConnContext(ctx, conn)
-		if ctx == nil {
-			panic("http3: ConnContext returned nil")
-		}
-	}
+	ctx, cancel := context.WithCancel(conn.Context())
 	req = req.WithContext(ctx)
-	r := newResponseWriter(str, conn, s.logger)
-	if req.Method == http.MethodHead {
-		r.isHead = true
-	}
+	context.AfterFunc(str.Context(), cancel)
+
+	r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger)
 	handler := s.Handler
 	if handler == nil {
 		handler = http.DefaultServeMux
 	}
 
+	// It's the client's responsibility to decide which requests are eligible for 0-RTT.
 	var panicked bool
 	func() {
 		defer func() {
@@ -647,43 +683,58 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
 				const size = 64 << 10
 				buf := make([]byte, size)
 				buf = buf[:runtime.Stack(buf, false)]
-				s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
+				logger := s.Logger
+				if logger == nil {
+					logger = slog.Default()
+				}
+				logger.Error("http: panic serving", "arg", p, "trace", string(buf))
 			}
 		}()
 		handler.ServeHTTP(r, req)
 	}()
 
-	if body.wasStreamHijacked() {
-		return requestError{err: errHijacked}
+	if r.wasStreamHijacked() {
+		return
 	}
 
 	// only write response when there is no panic
 	if !panicked {
 		// response not written to the client yet, set Content-Length
-		if !r.written {
+		if !r.headerWritten {
 			if _, haveCL := r.header["Content-Length"]; !haveCL {
 				r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
 			}
 		}
 		r.Flush()
+		r.flushTrailers()
 	}
-	// If the EOF was read by the handler, CancelRead() is a no-op.
-	str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
 
 	// abort the stream when there is a panic
 	if panicked {
-		return newStreamError(ErrCodeInternalError, errPanicked)
+		str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
+		str.CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
+		return
 	}
-	return requestError{}
+
+	// If the EOF was read by the handler, CancelRead() is a no-op.
+	str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
+
+	str.Close()
 }
 
 // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
 // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
+// It is the caller's responsibility to close any connection passed to ServeQUICConn.
 func (s *Server) Close() error {
 	s.mutex.Lock()
 	defer s.mutex.Unlock()
 
 	s.closed = true
+	// server is never used
+	if s.closeCtx == nil {
+		return nil
+	}
+	s.closeCancel()
 
 	var err error
 	for ln := range s.listeners {
@@ -691,38 +742,67 @@ func (s *Server) Close() error {
 			err = cerr
 		}
 	}
+	if s.connCount.Load() == 0 {
+		return err
+	}
+	// wait for all connections to be closed
+	<-s.connHandlingDone
 	return err
 }
 
-// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
-// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
-func (s *Server) CloseGracefully(timeout time.Duration) error {
-	// TODO: implement
-	return nil
+// Shutdown shuts down the server gracefully.
+// The server sends a GOAWAY frame first, then or for all running requests to complete.
+// Shutdown in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
+func (s *Server) Shutdown(ctx context.Context) error {
+	s.mutex.Lock()
+	s.closed = true
+	// server is never used
+	if s.closeCtx == nil {
+		s.mutex.Unlock()
+		return nil
+	}
+	s.graceCancel()
+	s.mutex.Unlock()
+
+	if s.connCount.Load() == 0 {
+		return s.Close()
+	}
+	select {
+	case <-s.connHandlingDone: // all connections were closed
+		// When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection
+		// once all requests were successfully handled...
+		return s.Close()
+	case <-ctx.Done():
+		// ... however, clients handling long-lived requests (and misbehaving clients),
+		// might not do so before the context is cancelled.
+		// In this case, we close the server, which closes all existing connections
+		// (expect those passed to ServeQUICConn).
+		_ = s.Close()
+		return ctx.Err()
+	}
 }
 
-// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found
+// ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found
 // for Alt-Svc to announce. This can happen if listening on a PacketConn without a port
 // (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr.
 var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr")
 
-// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
-// The values set by default advertise all of the ports the server is listening on, but can be
-// changed to a specific port by setting Server.Port before launching the serverr.
+// SetQUICHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
+// The values set by default advertise all the ports the server is listening on, but can be
+// changed to a specific port by setting Server.Port before launching the server.
 // If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used
 // to extract the port, if specified.
 // For example, a server launched using ListenAndServe on an address with port 443 would set:
 //
-//	Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
-func (s *Server) SetQuicHeaders(hdr http.Header) error {
+//	Alt-Svc: h3=":443"; ma=2592000
+func (s *Server) SetQUICHeaders(hdr http.Header) error {
 	s.mutex.RLock()
 	defer s.mutex.RUnlock()
 
 	if s.altSvcHeader == "" {
 		return ErrNoAltSvcPort
 	}
-	// use the map directly to avoid constant canonicalization
-	// since the key is already canonicalized
+	// use the map directly to avoid constant canonicalization since the key is already canonicalized
 	hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader)
 	return nil
 }
@@ -738,11 +818,11 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er
 	return server.ListenAndServeTLS(certFile, keyFile)
 }
 
-// ListenAndServe listens on the given network address for both TLS/TCP and QUIC
+// ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC
 // connections in parallel. It returns if one of the two returns an error.
 // http.DefaultServeMux is used when handler is nil.
 // The correct Alt-Svc headers for QUIC are set.
-func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
+func ListenAndServeTLS(addr, certFile, keyFile string, handler http.Handler) error {
 	// Load certs
 	var err error
 	certs := make([]tls.Certificate, 1)
@@ -784,7 +864,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
 	qErr := make(chan error, 1)
 	go func() {
 		hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			quicServer.SetQuicHeaders(w.Header())
+			quicServer.SetQUICHeaders(w.Header())
 			handler.ServeHTTP(w, r)
 		}))
 	}()

+ 116 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/state_tracking_stream.go

@@ -0,0 +1,116 @@
+package http3
+
+import (
+	"context"
+	"errors"
+	"os"
+	"sync"
+
+	"github.com/Psiphon-Labs/quic-go"
+)
+
+var _ quic.Stream = &stateTrackingStream{}
+
+// stateTrackingStream is an implementation of quic.Stream that delegates
+// to an underlying stream
+// it takes care of proxying send and receive errors onto an implementation of
+// the errorSetter interface (intended to be occupied by a datagrammer)
+// it is also responsible for clearing the stream based on its ID from its
+// parent connection, this is done through the streamClearer interface when
+// both the send and receive sides are closed
+type stateTrackingStream struct {
+	quic.Stream
+
+	mx      sync.Mutex
+	sendErr error
+	recvErr error
+
+	clearer streamClearer
+	setter  errorSetter
+}
+
+type streamClearer interface {
+	clearStream(quic.StreamID)
+}
+
+type errorSetter interface {
+	SetSendError(error)
+	SetReceiveError(error)
+}
+
+func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
+	t := &stateTrackingStream{
+		Stream:  s,
+		clearer: clearer,
+		setter:  setter,
+	}
+
+	context.AfterFunc(s.Context(), func() {
+		t.closeSend(context.Cause(s.Context()))
+	})
+
+	return t
+}
+
+func (s *stateTrackingStream) closeSend(e error) {
+	s.mx.Lock()
+	defer s.mx.Unlock()
+
+	// clear the stream the first time both the send
+	// and receive are finished
+	if s.sendErr == nil {
+		if s.recvErr != nil {
+			s.clearer.clearStream(s.StreamID())
+		}
+
+		s.setter.SetSendError(e)
+		s.sendErr = e
+	}
+}
+
+func (s *stateTrackingStream) closeReceive(e error) {
+	s.mx.Lock()
+	defer s.mx.Unlock()
+
+	// clear the stream the first time both the send
+	// and receive are finished
+	if s.recvErr == nil {
+		if s.sendErr != nil {
+			s.clearer.clearStream(s.StreamID())
+		}
+
+		s.setter.SetReceiveError(e)
+		s.recvErr = e
+	}
+}
+
+func (s *stateTrackingStream) Close() error {
+	s.closeSend(errors.New("write on closed stream"))
+	return s.Stream.Close()
+}
+
+func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
+	s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
+	s.Stream.CancelWrite(e)
+}
+
+func (s *stateTrackingStream) Write(b []byte) (int, error) {
+	n, err := s.Stream.Write(b)
+	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
+		s.closeSend(err)
+	}
+	return n, err
+}
+
+func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
+	s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
+	s.Stream.CancelRead(e)
+}
+
+func (s *stateTrackingStream) Read(b []byte) (int, error) {
+	n, err := s.Stream.Read(b)
+	if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
+		s.closeReceive(err)
+	}
+	return n, err
+}

+ 110 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/trace.go

@@ -0,0 +1,110 @@
+package http3
+
+import (
+	"net"
+	"net/http/httptrace"
+	"net/textproto"
+	"time"
+
+	tls "github.com/Psiphon-Labs/psiphon-tls"
+
+	"github.com/Psiphon-Labs/quic-go"
+)
+
+func traceGetConn(trace *httptrace.ClientTrace, hostPort string) {
+	if trace != nil && trace.GetConn != nil {
+		trace.GetConn(hostPort)
+	}
+}
+
+// fakeConn is a wrapper for quic.EarlyConnection
+// because the quic connection does not implement net.Conn.
+type fakeConn struct {
+	conn quic.EarlyConnection
+}
+
+func (c *fakeConn) Close() error                       { panic("connection operation prohibited") }
+func (c *fakeConn) Read(p []byte) (int, error)         { panic("connection operation prohibited") }
+func (c *fakeConn) Write(p []byte) (int, error)        { panic("connection operation prohibited") }
+func (c *fakeConn) SetDeadline(t time.Time) error      { panic("connection operation prohibited") }
+func (c *fakeConn) SetReadDeadline(t time.Time) error  { panic("connection operation prohibited") }
+func (c *fakeConn) SetWriteDeadline(t time.Time) error { panic("connection operation prohibited") }
+func (c *fakeConn) RemoteAddr() net.Addr               { return c.conn.RemoteAddr() }
+func (c *fakeConn) LocalAddr() net.Addr                { return c.conn.LocalAddr() }
+
+func traceGotConn(trace *httptrace.ClientTrace, conn quic.EarlyConnection, reused bool) {
+	if trace != nil && trace.GotConn != nil {
+		trace.GotConn(httptrace.GotConnInfo{
+			Conn:   &fakeConn{conn: conn},
+			Reused: reused,
+		})
+	}
+}
+
+func traceGotFirstResponseByte(trace *httptrace.ClientTrace) {
+	if trace != nil && trace.GotFirstResponseByte != nil {
+		trace.GotFirstResponseByte()
+	}
+}
+
+func traceGot1xxResponse(trace *httptrace.ClientTrace, code int, header textproto.MIMEHeader) {
+	if trace != nil && trace.Got1xxResponse != nil {
+		trace.Got1xxResponse(code, header)
+	}
+}
+
+func traceGot100Continue(trace *httptrace.ClientTrace) {
+	if trace != nil && trace.Got100Continue != nil {
+		trace.Got100Continue()
+	}
+}
+
+func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
+	return trace != nil && trace.WroteHeaderField != nil
+}
+
+func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
+	if trace != nil && trace.WroteHeaderField != nil {
+		trace.WroteHeaderField(k, []string{v})
+	}
+}
+
+func traceWroteHeaders(trace *httptrace.ClientTrace) {
+	if trace != nil && trace.WroteHeaders != nil {
+		trace.WroteHeaders()
+	}
+}
+
+func traceWroteRequest(trace *httptrace.ClientTrace, err error) {
+	if trace != nil && trace.WroteRequest != nil {
+		trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
+	}
+}
+
+func traceConnectStart(trace *httptrace.ClientTrace, network, addr string) {
+	if trace != nil && trace.ConnectStart != nil {
+		trace.ConnectStart(network, addr)
+	}
+}
+
+func traceConnectDone(trace *httptrace.ClientTrace, network, addr string, err error) {
+	if trace != nil && trace.ConnectDone != nil {
+		trace.ConnectDone(network, addr, err)
+	}
+}
+
+func traceTLSHandshakeStart(trace *httptrace.ClientTrace) {
+	if trace != nil && trace.TLSHandshakeStart != nil {
+		trace.TLSHandshakeStart()
+	}
+}
+
+func traceTLSHandshakeDone(trace *httptrace.ClientTrace, state tls.ConnectionState, err error) {
+	if trace != nil && trace.TLSHandshakeDone != nil {
+
+		// [Psiphon]
+		state := *tls.UnsafeFromConnectionState(&state)
+
+		trace.TLSHandshakeDone(state, err)
+	}
+}

+ 466 - 0
vendor/github.com/Psiphon-Labs/quic-go/http3/transport.go

@@ -0,0 +1,466 @@
+package http3
+
+import (
+	"context"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"net"
+	"net/http"
+	"net/http/httptrace"
+	"strings"
+	"sync"
+	"sync/atomic"
+
+	"golang.org/x/net/http/httpguts"
+
+	"github.com/Psiphon-Labs/quic-go"
+	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+)
+
+// Settings are HTTP/3 settings that apply to the underlying connection.
+type Settings struct {
+	// Support for HTTP/3 datagrams (RFC 9297)
+	EnableDatagrams bool
+	// Extended CONNECT, RFC 9220
+	EnableExtendedConnect bool
+	// Other settings, defined by the application
+	Other map[uint64]uint64
+}
+
+// RoundTripOpt are options for the Transport.RoundTripOpt method.
+type RoundTripOpt struct {
+	// OnlyCachedConn controls whether the Transport may create a new QUIC connection.
+	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
+	OnlyCachedConn bool
+}
+
+type clientConn interface {
+	OpenRequestStream(context.Context) (RequestStream, error)
+	RoundTrip(*http.Request) (*http.Response, error)
+}
+
+type roundTripperWithCount struct {
+	cancel     context.CancelFunc
+	dialing    chan struct{} // closed as soon as quic.Dial(Early) returned
+	dialErr    error
+	conn       quic.EarlyConnection
+	clientConn clientConn
+
+	useCount atomic.Int64
+}
+
+func (r *roundTripperWithCount) Close() error {
+	r.cancel()
+	<-r.dialing
+	if r.conn != nil {
+		return r.conn.CloseWithError(0, "")
+	}
+	return nil
+}
+
+// Transport implements the http.RoundTripper interface
+type Transport struct {
+	// TLSClientConfig specifies the TLS configuration to use with
+	// tls.Client. If nil, the default configuration is used.
+	TLSClientConfig *tls.Config
+
+	// QUICConfig is the quic.Config used for dialing new connections.
+	// If nil, reasonable default values will be used.
+	QUICConfig *quic.Config
+
+	// Dial specifies an optional dial function for creating QUIC
+	// connections for requests.
+	// If Dial is nil, a UDPConn will be created at the first request
+	// and will be reused for subsequent connections to other servers.
+	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
+
+	// Enable support for HTTP/3 datagrams (RFC 9297).
+	// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
+	EnableDatagrams bool
+
+	// Additional HTTP/3 settings.
+	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
+	AdditionalSettings map[uint64]uint64
+
+	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
+	// allowed in the server's response header.
+	// Zero means to use a default limit.
+	MaxResponseHeaderBytes int64
+
+	// DisableCompression, if true, prevents the Transport from requesting compression with an
+	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
+	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
+	// decoded in the Response.Body.
+	// However, if the user explicitly requested gzip it is not automatically uncompressed.
+	DisableCompression bool
+
+	StreamHijacker    func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
+	UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
+
+	Logger *slog.Logger
+
+	mutex sync.Mutex
+
+	initOnce sync.Once
+	initErr  error
+
+	newClientConn func(quic.EarlyConnection) clientConn
+
+	clients   map[string]*roundTripperWithCount
+	transport *quic.Transport
+}
+
+var (
+	_ http.RoundTripper = &Transport{}
+	_ io.Closer         = &Transport{}
+)
+
+// Deprecated: RoundTripper was renamed to Transport.
+type RoundTripper = Transport
+
+// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
+var ErrNoCachedConn = errors.New("http3: no cached connection was available")
+
+func (t *Transport) init() error {
+	if t.newClientConn == nil {
+		t.newClientConn = func(conn quic.EarlyConnection) clientConn {
+			return newClientConn(
+				conn,
+				t.EnableDatagrams,
+				t.AdditionalSettings,
+				t.StreamHijacker,
+				t.UniStreamHijacker,
+				t.MaxResponseHeaderBytes,
+				t.DisableCompression,
+				t.Logger,
+			)
+		}
+	}
+	if t.QUICConfig == nil {
+		t.QUICConfig = defaultQuicConfig.Clone()
+		t.QUICConfig.EnableDatagrams = t.EnableDatagrams
+	}
+	if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
+		return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
+	}
+	if len(t.QUICConfig.Versions) == 0 {
+		t.QUICConfig = t.QUICConfig.Clone()
+		t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
+	}
+	if len(t.QUICConfig.Versions) != 1 {
+		return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
+	}
+	if t.QUICConfig.MaxIncomingStreams == 0 {
+		t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
+	}
+	return nil
+}
+
+// RoundTripOpt is like RoundTrip, but takes options.
+func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
+	rsp, err := t.roundTripOpt(req, opt)
+	if err != nil {
+		if req.Body != nil {
+			req.Body.Close()
+		}
+		return nil, err
+	}
+	return rsp, nil
+}
+
+func (t *Transport) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
+	t.initOnce.Do(func() { t.initErr = t.init() })
+	if t.initErr != nil {
+		return nil, t.initErr
+	}
+
+	if req.URL == nil {
+		return nil, errors.New("http3: nil Request.URL")
+	}
+	if req.URL.Scheme != "https" {
+		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
+	}
+	if req.URL.Host == "" {
+		return nil, errors.New("http3: no Host in request URL")
+	}
+	if req.Header == nil {
+		return nil, errors.New("http3: nil Request.Header")
+	}
+	if req.Method != "" && !validMethod(req.Method) {
+		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
+	}
+	for k, vv := range req.Header {
+		if !httpguts.ValidHeaderFieldName(k) {
+			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
+		}
+		for _, v := range vv {
+			if !httpguts.ValidHeaderFieldValue(v) {
+				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
+			}
+		}
+	}
+
+	trace := httptrace.ContextClientTrace(req.Context())
+	hostname := authorityAddr(hostnameFromURL(req.URL))
+	traceGetConn(trace, hostname)
+	cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
+	if err != nil {
+		return nil, err
+	}
+
+	select {
+	case <-cl.dialing:
+	case <-req.Context().Done():
+		return nil, context.Cause(req.Context())
+	}
+
+	if cl.dialErr != nil {
+		t.removeClient(hostname)
+		return nil, cl.dialErr
+	}
+	traceGotConn(trace, cl.conn, isReused)
+	defer cl.useCount.Add(-1)
+	rsp, err := cl.clientConn.RoundTrip(req)
+	if err != nil {
+		// request aborted due to context cancellation
+		select {
+		case <-req.Context().Done():
+			return nil, err
+		default:
+		}
+
+		// Retry the request on a new connection if:
+		// 1. it was sent on a reused connection,
+		// 2. this connection is now closed,
+		// 3. and the error is a timeout error.
+		select {
+		case <-cl.conn.Context().Done():
+			t.removeClient(hostname)
+			if isReused {
+				var nerr net.Error
+				if errors.As(err, &nerr) && nerr.Timeout() {
+					return t.RoundTripOpt(req, opt)
+				}
+			}
+			return nil, err
+		default:
+			return nil, err
+		}
+	}
+	return rsp, nil
+}
+
+// RoundTrip does a round trip.
+func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
+	return t.RoundTripOpt(req, RoundTripOpt{})
+}
+
+func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+
+	if t.clients == nil {
+		t.clients = make(map[string]*roundTripperWithCount)
+	}
+
+	cl, ok := t.clients[hostname]
+	if !ok {
+		if onlyCached {
+			return nil, false, ErrNoCachedConn
+		}
+		ctx, cancel := context.WithCancel(ctx)
+		cl = &roundTripperWithCount{
+			dialing: make(chan struct{}),
+			cancel:  cancel,
+		}
+		go func() {
+			defer close(cl.dialing)
+			defer cancel()
+			conn, rt, err := t.dial(ctx, hostname)
+			if err != nil {
+				cl.dialErr = err
+				return
+			}
+			cl.conn = conn
+			cl.clientConn = rt
+		}()
+		t.clients[hostname] = cl
+	}
+	select {
+	case <-cl.dialing:
+		if cl.dialErr != nil {
+			delete(t.clients, hostname)
+			return nil, false, cl.dialErr
+		}
+		select {
+		case <-cl.conn.HandshakeComplete():
+			isReused = true
+		default:
+		}
+	default:
+	}
+	cl.useCount.Add(1)
+	return cl, isReused, nil
+}
+
+func (t *Transport) dial(ctx context.Context, hostname string) (quic.EarlyConnection, clientConn, error) {
+	var tlsConf *tls.Config
+	if t.TLSClientConfig == nil {
+		tlsConf = &tls.Config{}
+	} else {
+		tlsConf = t.TLSClientConfig.Clone()
+	}
+	if tlsConf.ServerName == "" {
+		sni, _, err := net.SplitHostPort(hostname)
+		if err != nil {
+			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
+			sni = hostname
+		}
+		tlsConf.ServerName = sni
+	}
+	// Replace existing ALPNs by H3
+	tlsConf.NextProtos = []string{versionToALPN(t.QUICConfig.Versions[0])}
+
+	dial := t.Dial
+	if dial == nil {
+		if t.transport == nil {
+			udpConn, err := net.ListenUDP("udp", nil)
+			if err != nil {
+				return nil, nil, err
+			}
+			t.transport = &quic.Transport{Conn: udpConn}
+		}
+		dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
+			network := "udp"
+			udpAddr, err := t.resolveUDPAddr(ctx, network, addr)
+			if err != nil {
+				return nil, err
+			}
+			trace := httptrace.ContextClientTrace(ctx)
+			traceConnectStart(trace, network, udpAddr.String())
+			traceTLSHandshakeStart(trace)
+			conn, err := t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
+			var state tls.ConnectionState
+			if conn != nil {
+				state = conn.ConnectionState().TLS
+			}
+			traceTLSHandshakeDone(trace, state, err)
+			traceConnectDone(trace, network, udpAddr.String(), err)
+			return conn, err
+		}
+	}
+	conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
+	if err != nil {
+		return nil, nil, err
+	}
+	return conn, t.newClientConn(conn), nil
+}
+
+func (t *Transport) resolveUDPAddr(ctx context.Context, network, addr string) (*net.UDPAddr, error) {
+	host, portStr, err := net.SplitHostPort(addr)
+	if err != nil {
+		return nil, err
+	}
+	port, err := net.LookupPort(network, portStr)
+	if err != nil {
+		return nil, err
+	}
+	resolver := net.DefaultResolver
+	ipAddrs, err := resolver.LookupIPAddr(ctx, host)
+	if err != nil {
+		return nil, err
+	}
+	addrs := addrList(ipAddrs)
+	ip := addrs.forResolve(network, addr)
+	return &net.UDPAddr{IP: ip.IP, Port: port, Zone: ip.Zone}, nil
+}
+
+func (t *Transport) removeClient(hostname string) {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	if t.clients == nil {
+		return
+	}
+	delete(t.clients, hostname)
+}
+
+// NewClientConn creates a new HTTP/3 client connection on top of a QUIC connection.
+// Most users should use RoundTrip instead of creating a connection directly.
+// Specifically, it is not needed to perform GET, POST, HEAD and CONNECT requests.
+//
+// Obtaining a ClientConn is only needed for more advanced use cases, such as
+// using Extended CONNECT for WebTransport or the various MASQUE protocols.
+func (t *Transport) NewClientConn(conn quic.Connection) *ClientConn {
+	return newClientConn(
+		conn,
+		t.EnableDatagrams,
+		t.AdditionalSettings,
+		t.StreamHijacker,
+		t.UniStreamHijacker,
+		t.MaxResponseHeaderBytes,
+		t.DisableCompression,
+		t.Logger,
+	)
+}
+
+// Close closes the QUIC connections that this Transport has used.
+func (t *Transport) Close() error {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	for _, cl := range t.clients {
+		if err := cl.Close(); err != nil {
+			return err
+		}
+	}
+	t.clients = nil
+	if t.transport != nil {
+		if err := t.transport.Close(); err != nil {
+			return err
+		}
+		if err := t.transport.Conn.Close(); err != nil {
+			return err
+		}
+		t.transport = nil
+	}
+	return nil
+}
+
+func validMethod(method string) bool {
+	/*
+				     Method         = "OPTIONS"                ; Section 9.2
+		   		                    | "GET"                    ; Section 9.3
+		   		                    | "HEAD"                   ; Section 9.4
+		   		                    | "POST"                   ; Section 9.5
+		   		                    | "PUT"                    ; Section 9.6
+		   		                    | "DELETE"                 ; Section 9.7
+		   		                    | "TRACE"                  ; Section 9.8
+		   		                    | "CONNECT"                ; Section 9.9
+		   		                    | extension-method
+		   		   extension-method = token
+		   		     token          = 1*<any CHAR except CTLs or separators>
+	*/
+	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
+}
+
+// copied from net/http/http.go
+func isNotToken(r rune) bool {
+	return !httpguts.IsTokenRune(r)
+}
+
+// CloseIdleConnections closes any QUIC connections in the transport's pool that are currently idle.
+// An idle connection is one that was previously used for requests but is now sitting unused.
+// This method does not interrupt any connections currently in use.
+// It also does not affect connections obtained via NewClientConn.
+func (t *Transport) CloseIdleConnections() {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	for hostname, cl := range t.clients {
+		if cl.useCount.Load() == 0 {
+			cl.Close()
+			delete(t.clients, hostname)
+		}
+	}
+}

+ 50 - 33
vendor/github.com/Psiphon-Labs/quic-go/interface.go

@@ -18,8 +18,8 @@ import (
 // The StreamID is the ID of a QUIC stream.
 type StreamID = protocol.StreamID
 
-// A VersionNumber is a QUIC version number.
-type VersionNumber = protocol.VersionNumber
+// A Version is a QUIC version number.
+type Version = protocol.Version
 
 const (
 	// Version1 is RFC 9000
@@ -55,8 +55,13 @@ var Err0RTTRejected = errors.New("0-RTT rejected")
 // ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
 // It is set on the Connection.Context() context,
 // as well as on the context passed to logging.Tracer.NewConnectionTracer.
+// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
 var ConnectionTracingKey = connTracingCtxKey{}
 
+// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
+// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
+type ConnectionTracingID uint64
+
 type connTracingCtxKey struct{}
 
 // QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
@@ -82,8 +87,8 @@ type ReceiveStream interface {
 	// Read reads data from the stream.
 	// Read can be made to time out and return a net.Error with Timeout() == true
 	// after a fixed time limit; see SetDeadline and SetReadDeadline.
-	// If the stream was canceled by the peer, the error implements the StreamError
-	// interface, and Canceled() == true.
+	// If the stream was canceled by the peer, the error is a StreamError and
+	// Remote == true.
 	// If the connection was closed due to a timeout, the error satisfies
 	// the net.Error interface, and Timeout() will be true.
 	io.Reader
@@ -95,7 +100,6 @@ type ReceiveStream interface {
 	// SetReadDeadline sets the deadline for future Read calls and
 	// any currently-blocked Read call.
 	// A zero value for t means Read will not time out.
-
 	SetReadDeadline(t time.Time) error
 }
 
@@ -106,8 +110,8 @@ type SendStream interface {
 	// Write writes data to the stream.
 	// Write can be made to time out and return a net.Error with Timeout() == true
 	// after a fixed time limit; see SetDeadline and SetWriteDeadline.
-	// If the stream was canceled by the peer, the error implements the StreamError
-	// interface, and Canceled() == true.
+	// If the stream was canceled by the peer, the error is a StreamError and
+	// Remote == true.
 	// If the connection was closed due to a timeout, the error satisfies
 	// the net.Error interface, and Timeout() will be true.
 	io.Writer
@@ -119,7 +123,9 @@ type SendStream interface {
 	// CancelWrite aborts sending on this stream.
 	// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
 	// Write will unblock immediately, and future calls to Write will fail.
-	// When called multiple times or after closing the stream it is a no-op.
+	// When called multiple times it is a no-op.
+	// When called after Close, it aborts delivery. Note that there is no guarantee if
+	// the peer will receive the FIN or the reset first.
 	CancelWrite(StreamErrorCode)
 	// The Context is canceled as soon as the write-side of the stream is closed.
 	// This happens when Close() or CancelWrite() is called, or when the peer
@@ -141,7 +147,7 @@ type SendStream interface {
 // * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
 // * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
 // * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
-// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
+// * StatelessResetError: when we receive a stateless reset
 // * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
 type Connection interface {
 	// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@@ -154,25 +160,29 @@ type Connection interface {
 	AcceptUniStream(context.Context) (ReceiveStream, error)
 	// OpenStream opens a new bidirectional QUIC stream.
 	// There is no signaling to the peer about new streams:
-	// The peer can only accept the stream after data has been sent on the stream.
-	// If the error is non-nil, it satisfies the net.Error interface.
-	// When reaching the peer's stream limit, err.Temporary() will be true.
-	// If the connection was closed due to a timeout, Timeout() will be true.
+	// The peer can only accept the stream after data has been sent on the stream,
+	// or the stream has been reset or closed.
+	// When reaching the peer's stream limit, it is not possible to open a new stream until the
+	// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
 	OpenStream() (Stream, error)
 	// OpenStreamSync opens a new bidirectional QUIC stream.
 	// It blocks until a new stream can be opened.
-	// If the error is non-nil, it satisfies the net.Error interface.
-	// If the connection was closed due to a timeout, Timeout() will be true.
+	// There is no signaling to the peer about new streams:
+	// The peer can only accept the stream after data has been sent on the stream,
+	// or the stream has been reset or closed.
 	OpenStreamSync(context.Context) (Stream, error)
 	// OpenUniStream opens a new outgoing unidirectional QUIC stream.
-	// If the error is non-nil, it satisfies the net.Error interface.
-	// When reaching the peer's stream limit, Temporary() will be true.
-	// If the connection was closed due to a timeout, Timeout() will be true.
+	// There is no signaling to the peer about new streams:
+	// The peer can only accept the stream after data has been sent on the stream,
+	// or the stream has been reset or closed.
+	// When reaching the peer's stream limit, it is not possible to open a new stream until the
+	// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
 	OpenUniStream() (SendStream, error)
 	// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
 	// It blocks until a new stream can be opened.
-	// If the error is non-nil, it satisfies the net.Error interface.
-	// If the connection was closed due to a timeout, Timeout() will be true.
+	// There is no signaling to the peer about new streams:
+	// The peer can only accept the stream after data has been sent on the stream,
+	// or the stream has been reset or closed.
 	OpenUniStreamSync(context.Context) (SendStream, error)
 	// LocalAddr returns the local address.
 	LocalAddr() net.Addr
@@ -217,7 +227,7 @@ type EarlyConnection interface {
 	// however the client's identity is only verified once the handshake completes.
 	HandshakeComplete() <-chan struct{}
 
-	NextConnection() Connection
+	NextConnection(context.Context) (Connection, error)
 }
 
 // StatelessResetKey is a key used to derive stateless reset tokens.
@@ -262,7 +272,7 @@ type Config struct {
 	GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
 	// The QUIC versions that can be negotiated.
 	// If not set, it uses all versions available.
-	Versions []VersionNumber
+	Versions []Version
 	// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
 	// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
 	// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
@@ -274,11 +284,6 @@ type Config struct {
 	// If the timeout is exceeded, the connection is closed.
 	// If this value is zero, the timeout is set to 30 seconds.
 	MaxIdleTimeout time.Duration
-	// RequireAddressValidation determines if a QUIC Retry packet is sent.
-	// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
-	// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
-	// If not set, every client is forced to prove its remote address.
-	RequireAddressValidation func(net.Addr) bool
 	// The TokenStore stores tokens received from the server.
 	// Tokens are used to skip address validation on future connection attempts.
 	// The key used to store tokens is the ServerName from the tls.Config, if set
@@ -325,10 +330,15 @@ type Config struct {
 	// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
 	// every half of MaxIdleTimeout, whichever is smaller).
 	KeepAlivePeriod time.Duration
+	// InitialPacketSize is the initial size of packets sent.
+	// It is usually not necessary to manually set this value,
+	// since Path MTU discovery very quickly finds the path's MTU.
+	// If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake.
+	// Values below 1200 are invalid.
+	InitialPacketSize uint16
 	// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
 	// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
 	// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
-	// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
 	DisablePathMTUDiscovery bool
 	// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
 	// Only valid for the server.
@@ -372,23 +382,30 @@ type Config struct {
 	ServerMaxPacketSizeAdjustment func(net.Addr) int
 }
 
+// ClientHelloInfo contains information about an incoming connection attempt.
 type ClientHelloInfo struct {
+	// RemoteAddr is the remote address on the Initial packet.
+	// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
 	RemoteAddr net.Addr
+	// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
+	// Note that the Retry mechanism costs one network roundtrip,
+	// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
+	AddrVerified bool
 }
 
 // ConnectionState records basic details about a QUIC connection
 type ConnectionState struct {
 	// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
 	TLS tls.ConnectionState
-	// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
-	// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
-	// If datagram support was negotiated, datagrams can be sent and received using the
-	// SendDatagram and ReceiveDatagram methods on the Connection.
+	// SupportsDatagrams indicates whether the peer advertised support for QUIC datagrams (RFC 9221).
+	// When true, datagrams can be sent using the Connection's SendDatagram method.
+	// This is a unilateral declaration by the peer - receiving datagrams is only possible if
+	// datagram support was enabled locally via Config.EnableDatagrams.
 	SupportsDatagrams bool
 	// Used0RTT says if 0-RTT resumption was used.
 	Used0RTT bool
 	// Version is the QUIC version of the QUIC connection.
-	Version VersionNumber
+	Version Version
 	// GSO says if generic segmentation offload is used
 	GSO bool
 

+ 1 - 1
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/ackhandler.go

@@ -20,5 +20,5 @@ func NewAckHandler(
 	logger utils.Logger,
 ) (SentPacketHandler, ReceivedPacketHandler) {
 	sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
-	return sph, newReceivedPacketHandler(sph, rttStats, logger)
+	return sph, newReceivedPacketHandler(sph, logger)
 }

+ 6 - 7
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/interfaces.go

@@ -14,10 +14,9 @@ type SentPacketHandler interface {
 	// ReceivedAck processes an ACK frame.
 	// It does not store a copy of the frame.
 	ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
-	ReceivedBytes(protocol.ByteCount)
-	DropPackets(protocol.EncryptionLevel)
-	ResetForRetry(rcvTime time.Time) error
-	SetHandshakeConfirmed()
+	ReceivedBytes(_ protocol.ByteCount, rcvTime time.Time)
+	DropPackets(_ protocol.EncryptionLevel, rcvTime time.Time)
+	ResetForRetry(rcvTime time.Time)
 
 	// The SendMode determines if and what kind of packets can be sent.
 	SendMode(now time.Time) SendMode
@@ -34,12 +33,12 @@ type SentPacketHandler interface {
 	PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
 
 	GetLossDetectionTimeout() time.Time
-	OnLossDetectionTimeout() error
+	OnLossDetectionTimeout(now time.Time) error
 }
 
 type sentPacketTracker interface {
 	GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
-	ReceivedPacket(protocol.EncryptionLevel)
+	ReceivedPacket(_ protocol.EncryptionLevel, rcvTime time.Time)
 }
 
 // ReceivedPacketHandler handles ACKs needed to send for incoming packets
@@ -49,5 +48,5 @@ type ReceivedPacketHandler interface {
 	DropPackets(protocol.EncryptionLevel)
 
 	GetAlarmTimeout() time.Time
-	GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
+	GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame
 }

+ 14 - 31
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_handler.go

@@ -14,23 +14,19 @@ type receivedPacketHandler struct {
 
 	initialPackets   *receivedPacketTracker
 	handshakePackets *receivedPacketTracker
-	appDataPackets   *receivedPacketTracker
+	appDataPackets   appDataReceivedPacketTracker
 
 	lowest1RTTPacket protocol.PacketNumber
 }
 
 var _ ReceivedPacketHandler = &receivedPacketHandler{}
 
-func newReceivedPacketHandler(
-	sentPackets sentPacketTracker,
-	rttStats *utils.RTTStats,
-	logger utils.Logger,
-) ReceivedPacketHandler {
+func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
 	return &receivedPacketHandler{
 		sentPackets:      sentPackets,
-		initialPackets:   newReceivedPacketTracker(rttStats, logger),
-		handshakePackets: newReceivedPacketTracker(rttStats, logger),
-		appDataPackets:   newReceivedPacketTracker(rttStats, logger),
+		initialPackets:   newReceivedPacketTracker(),
+		handshakePackets: newReceivedPacketTracker(),
+		appDataPackets:   *newAppDataReceivedPacketTracker(logger),
 		lowest1RTTPacket: protocol.InvalidPacketNumber,
 	}
 }
@@ -42,7 +38,7 @@ func (h *receivedPacketHandler) ReceivedPacket(
 	rcvTime time.Time,
 	ackEliciting bool,
 ) error {
-	h.sentPackets.ReceivedPacket(encLevel)
+	h.sentPackets.ReceivedPacket(encLevel, rcvTime)
 	switch encLevel {
 	case protocol.EncryptionInitial:
 		return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
@@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
 }
 
 func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
-	var initialAlarm, handshakeAlarm time.Time
-	if h.initialPackets != nil {
-		initialAlarm = h.initialPackets.GetAlarmTimeout()
-	}
-	if h.handshakePackets != nil {
-		handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
-	}
-	oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
-	return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
+	return h.appDataPackets.GetAlarmTimeout()
 }
 
-func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
-	var ack *wire.AckFrame
+func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame {
 	//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
 	switch encLevel {
 	case protocol.EncryptionInitial:
 		if h.initialPackets != nil {
-			ack = h.initialPackets.GetAckFrame(onlyIfQueued)
+			return h.initialPackets.GetAckFrame()
 		}
+		return nil
 	case protocol.EncryptionHandshake:
 		if h.handshakePackets != nil {
-			ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
+			return h.handshakePackets.GetAckFrame()
 		}
+		return nil
 	case protocol.Encryption1RTT:
-		// 0-RTT packets can't contain ACK frames
-		return h.appDataPackets.GetAckFrame(onlyIfQueued)
+		return h.appDataPackets.GetAckFrame(now, onlyIfQueued)
 	default:
+		// 0-RTT packets can't contain ACK frames
 		return nil
 	}
-	// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
-	// Set it to 0 in order to save bytes.
-	if ack != nil {
-		ack.DelayTime = 0
-	}
-	return ack
 }
 
 func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {

+ 47 - 57
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_history.go

@@ -1,10 +1,9 @@
 package ackhandler
 
 import (
-	"sync"
+	"slices"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
-	list "github.com/Psiphon-Labs/quic-go/internal/utils/linkedlist"
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 )
 
@@ -14,25 +13,17 @@ type interval struct {
 	End   protocol.PacketNumber
 }
 
-var intervalElementPool sync.Pool
-
-func init() {
-	intervalElementPool = *list.NewPool[interval]()
-}
-
 // The receivedPacketHistory stores if a packet number has already been received.
 // It generates ACK ranges which can be used to assemble an ACK frame.
 // It does not store packet contents.
 type receivedPacketHistory struct {
-	ranges *list.List[interval]
+	ranges []interval // maximum length: protocol.MaxNumAckRanges
 
 	deletedBelow protocol.PacketNumber
 }
 
 func newReceivedPacketHistory() *receivedPacketHistory {
-	return &receivedPacketHistory{
-		ranges: list.NewWithPool[interval](&intervalElementPool),
-	}
+	return &receivedPacketHistory{}
 }
 
 // ReceivedPacket registers a packet with PacketNumber p and updates the ranges
@@ -41,58 +32,54 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /*
 	if p < h.deletedBelow {
 		return false
 	}
+
 	isNew := h.addToRanges(p)
-	h.maybeDeleteOldRanges()
+	// Delete old ranges, if we're tracking too many of them.
+	// This is a DoS defense against a peer that sends us too many gaps.
+	if len(h.ranges) > protocol.MaxNumAckRanges {
+		h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
+	}
 	return isNew
 }
 
 func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
-	if h.ranges.Len() == 0 {
-		h.ranges.PushBack(interval{Start: p, End: p})
+	if len(h.ranges) == 0 {
+		h.ranges = append(h.ranges, interval{Start: p, End: p})
 		return true
 	}
 
-	for el := h.ranges.Back(); el != nil; el = el.Prev() {
+	for i := len(h.ranges) - 1; i >= 0; i-- {
 		// p already included in an existing range. Nothing to do here
-		if p >= el.Value.Start && p <= el.Value.End {
+		if p >= h.ranges[i].Start && p <= h.ranges[i].End {
 			return false
 		}
 
-		if el.Value.End == p-1 { // extend a range at the end
-			el.Value.End = p
+		if h.ranges[i].End == p-1 { // extend a range at the end
+			h.ranges[i].End = p
 			return true
 		}
-		if el.Value.Start == p+1 { // extend a range at the beginning
-			el.Value.Start = p
+		if h.ranges[i].Start == p+1 { // extend a range at the beginning
+			h.ranges[i].Start = p
 
-			prev := el.Prev()
-			if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
-				prev.Value.End = el.Value.End
-				h.ranges.Remove(el)
+			if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
+				h.ranges[i-1].End = h.ranges[i].End
+				h.ranges = slices.Delete(h.ranges, i, i+1)
 			}
 			return true
 		}
 
-		// create a new range at the end
-		if p > el.Value.End {
-			h.ranges.InsertAfter(interval{Start: p, End: p}, el)
+		// create a new range after the current one
+		if p > h.ranges[i].End {
+			h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
 			return true
 		}
 	}
 
 	// create a new range at the beginning
-	h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
+	h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
 	return true
 }
 
-// Delete old ranges, if we're tracking more than 500 of them.
-// This is a DoS defense against a peer that sends us too many gaps.
-func (h *receivedPacketHistory) maybeDeleteOldRanges() {
-	for h.ranges.Len() > protocol.MaxNumAckRanges {
-		h.ranges.Remove(h.ranges.Front())
-	}
-}
-
 // DeleteBelow deletes all entries below (but not including) p
 func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
 	if p < h.deletedBelow {
@@ -100,37 +87,39 @@ func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
 	}
 	h.deletedBelow = p
 
-	nextEl := h.ranges.Front()
-	for el := h.ranges.Front(); nextEl != nil; el = nextEl {
-		nextEl = el.Next()
+	if len(h.ranges) == 0 {
+		return
+	}
 
-		if el.Value.End < p { // delete a whole range
-			h.ranges.Remove(el)
-		} else if p > el.Value.Start && p <= el.Value.End {
-			el.Value.Start = p
-			return
+	idx := -1
+	for i := 0; i < len(h.ranges); i++ {
+		if h.ranges[i].End < p { // delete a whole range
+			idx = i
+		} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
+			h.ranges[i].Start = p
+			break
 		} else { // no ranges affected. Nothing to do
-			return
+			break
 		}
 	}
+	if idx >= 0 {
+		h.ranges = slices.Delete(h.ranges, 0, idx+1)
+	}
 }
 
 // AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
 func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
-	if h.ranges.Len() > 0 {
-		for el := h.ranges.Back(); el != nil; el = el.Prev() {
-			ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
-		}
+	for i := len(h.ranges) - 1; i >= 0; i-- {
+		ackRanges = append(ackRanges, wire.AckRange{Smallest: h.ranges[i].Start, Largest: h.ranges[i].End})
 	}
 	return ackRanges
 }
 
 func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
 	ackRange := wire.AckRange{}
-	if h.ranges.Len() > 0 {
-		r := h.ranges.Back().Value
-		ackRange.Smallest = r.Start
-		ackRange.Largest = r.End
+	if len(h.ranges) > 0 {
+		ackRange.Smallest = h.ranges[len(h.ranges)-1].Start
+		ackRange.Largest = h.ranges[len(h.ranges)-1].End
 	}
 	return ackRange
 }
@@ -139,11 +128,12 @@ func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber)
 	if p < h.deletedBelow {
 		return true
 	}
-	for el := h.ranges.Back(); el != nil; el = el.Prev() {
-		if p > el.Value.End {
+	// Iterating over the slices is faster than using a binary search (using slices.BinarySearchFunc).
+	for i := len(h.ranges) - 1; i >= 0; i-- {
+		if p > h.ranges[i].End {
 			return false
 		}
-		if p <= el.Value.End && p >= el.Value.Start {
+		if p <= h.ranges[i].End && p >= h.ranges[i].Start {
 			return true
 		}
 	}

+ 91 - 67
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/received_packet_tracker.go

@@ -9,40 +9,19 @@ import (
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 )
 
-// number of ack-eliciting packets received before sending an ack.
-const packetsBeforeAck = 2
-
+// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
+// Every received packet is acknowledged immediately.
 type receivedPacketTracker struct {
-	largestObserved         protocol.PacketNumber
-	ignoreBelow             protocol.PacketNumber
-	largestObservedRcvdTime time.Time
-	ect0, ect1, ecnce       uint64
+	ect0, ect1, ecnce uint64
 
-	packetHistory *receivedPacketHistory
-
-	maxAckDelay time.Duration
-	rttStats    *utils.RTTStats
+	packetHistory receivedPacketHistory
 
+	lastAck   *wire.AckFrame
 	hasNewAck bool // true as soon as we received an ack-eliciting new packet
-	ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
-
-	ackElicitingPacketsReceivedSinceLastAck int
-	ackAlarm                                time.Time
-	lastAck                                 *wire.AckFrame
-
-	logger utils.Logger
 }
 
-func newReceivedPacketTracker(
-	rttStats *utils.RTTStats,
-	logger utils.Logger,
-) *receivedPacketTracker {
-	return &receivedPacketTracker{
-		packetHistory: newReceivedPacketHistory(),
-		maxAckDelay:   protocol.MaxAckDelay,
-		rttStats:      rttStats,
-		logger:        logger,
-	}
+func newReceivedPacketTracker() *receivedPacketTracker {
+	return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
 }
 
 func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
@@ -50,12 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
 		return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
 	}
 
-	isMissing := h.isMissing(pn)
-	if pn >= h.largestObserved {
-		h.largestObserved = pn
-		h.largestObservedRcvdTime = rcvTime
-	}
-
 	//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
 	switch ecn {
 	case protocol.ECT0:
@@ -65,13 +38,82 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
 	case protocol.ECNCE:
 		h.ecnce++
 	}
-
 	if !ackEliciting {
 		return nil
 	}
-
 	h.hasNewAck = true
+	return nil
+}
+
+func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
+	if !h.hasNewAck {
+		return nil
+	}
+
+	// This function always returns the same ACK frame struct, filled with the most recent values.
+	ack := h.lastAck
+	if ack == nil {
+		ack = &wire.AckFrame{}
+	}
+	ack.Reset()
+	ack.ECT0 = h.ect0
+	ack.ECT1 = h.ect1
+	ack.ECNCE = h.ecnce
+	ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
+
+	h.lastAck = ack
+	h.hasNewAck = false
+	return ack
+}
+
+func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
+	return h.packetHistory.IsPotentiallyDuplicate(pn)
+}
+
+// number of ack-eliciting packets received before sending an ACK
+const packetsBeforeAck = 2
+
+// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
+// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached.
+type appDataReceivedPacketTracker struct {
+	receivedPacketTracker
+
+	largestObservedRcvdTime time.Time
+
+	largestObserved protocol.PacketNumber
+	ignoreBelow     protocol.PacketNumber
+
+	maxAckDelay time.Duration
+	ackQueued   bool // true if we need send a new ACK
+
+	ackElicitingPacketsReceivedSinceLastAck int
+	ackAlarm                                time.Time
+
+	logger utils.Logger
+}
+
+func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
+	h := &appDataReceivedPacketTracker{
+		receivedPacketTracker: *newReceivedPacketTracker(),
+		maxAckDelay:           protocol.MaxAckDelay,
+		logger:                logger,
+	}
+	return h
+}
+
+func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
+	if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
+		return err
+	}
+	if pn >= h.largestObserved {
+		h.largestObserved = pn
+		h.largestObservedRcvdTime = rcvTime
+	}
+	if !ackEliciting {
+		return nil
+	}
 	h.ackElicitingPacketsReceivedSinceLastAck++
+	isMissing := h.isMissing(pn)
 	if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
 		h.ackQueued = true
 		h.ackAlarm = time.Time{} // cancel the ack alarm
@@ -88,7 +130,7 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
 
 // IgnoreBelow sets a lower limit for acknowledging packets.
 // Packets with packet numbers smaller than p will not be acked.
-func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
+func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
 	if pn <= h.ignoreBelow {
 		return
 	}
@@ -100,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
 }
 
 // isMissing says if a packet was reported missing in the last ACK.
-func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
+func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
 	if h.lastAck == nil || p < h.ignoreBelow {
 		return false
 	}
 	return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
 }
 
-func (h *receivedPacketTracker) hasNewMissingPackets() bool {
+func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
 	if h.lastAck == nil {
 		return false
 	}
@@ -115,7 +157,7 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
 	return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
 }
 
-func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
+func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
 	// always acknowledge the first packet
 	if h.lastAck == nil {
 		h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
@@ -124,7 +166,7 @@ func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn pro
 
 	// Send an ACK if this packet was reported missing in an ACK sent before.
 	// Ack decimation with reordering relies on the timer to send an ACK, but if
-	// missing packets we reported in the previous ack, send an ACK immediately.
+	// missing packets we reported in the previous ACK, send an ACK immediately.
 	if wasMissing {
 		if h.logger.Debug() {
 			h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
@@ -154,42 +196,24 @@ func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn pro
 	return false
 }
 
-func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
-	if !h.hasNewAck {
-		return nil
-	}
-	now := time.Now()
-	if onlyIfQueued {
-		if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
+func (h *appDataReceivedPacketTracker) GetAckFrame(now time.Time, onlyIfQueued bool) *wire.AckFrame {
+	if onlyIfQueued && !h.ackQueued {
+		if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
 			return nil
 		}
-		if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
+		if h.logger.Debug() && !h.ackAlarm.IsZero() {
 			h.logger.Debugf("Sending ACK because the ACK timer expired.")
 		}
 	}
-
-	// This function always returns the same ACK frame struct, filled with the most recent values.
-	ack := h.lastAck
+	ack := h.receivedPacketTracker.GetAckFrame()
 	if ack == nil {
-		ack = &wire.AckFrame{}
+		return nil
 	}
-	ack.Reset()
 	ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
-	ack.ECT0 = h.ect0
-	ack.ECT1 = h.ect1
-	ack.ECNCE = h.ecnce
-	ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
-
-	h.lastAck = ack
-	h.ackAlarm = time.Time{}
 	h.ackQueued = false
-	h.hasNewAck = false
+	h.ackAlarm = time.Time{}
 	h.ackElicitingPacketsReceivedSinceLastAck = 0
 	return ack
 }
 
-func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
-
-func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
-	return h.packetHistory.IsPotentiallyDuplicate(pn)
-}
+func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }

+ 76 - 89
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/sent_packet_handler.go

@@ -28,7 +28,7 @@ const (
 )
 
 type packetNumberSpace struct {
-	history *sentPacketHistory
+	history sentPacketHistory
 	pns     packetNumberGenerator
 
 	lossTime                   time.Time
@@ -38,21 +38,27 @@ type packetNumberSpace struct {
 	largestSent  protocol.PacketNumber
 }
 
-func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
+func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *packetNumberSpace {
 	var pns packetNumberGenerator
-	if skipPNs {
+	if isAppData {
 		pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
 	} else {
 		pns = newSequentialPacketNumberGenerator(initialPN)
 	}
 	return &packetNumberSpace{
-		history:      newSentPacketHistory(),
+		history:      *newSentPacketHistory(isAppData),
 		pns:          pns,
 		largestSent:  protocol.InvalidPacketNumber,
 		largestAcked: protocol.InvalidPacketNumber,
 	}
 }
 
+type alarmTimer struct {
+	Time            time.Time
+	TimerType       logging.TimerType
+	EncryptionLevel protocol.EncryptionLevel
+}
+
 type sentPacketHandler struct {
 	initialPackets   *packetNumberSpace
 	handshakePackets *packetNumberSpace
@@ -90,7 +96,7 @@ type sentPacketHandler struct {
 	numProbesToSend int
 
 	// The alarm timeout
-	alarm time.Time
+	alarm alarmTimer
 
 	enableECN  bool
 	ecnTracker ecnHandler
@@ -155,7 +161,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
 	}
 }
 
-func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
+func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now time.Time) {
 	// The server won't await address validation after the handshake is confirmed.
 	// This applies even if we didn't receive an ACK for a Handshake packet.
 	if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
@@ -179,6 +185,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
 	case protocol.EncryptionInitial:
 		h.initialPackets = nil
 	case protocol.EncryptionHandshake:
+		// Dropping the handshake packet number space means that the handshake is confirmed,
+		// see section 4.9.2 of RFC 9001.
+		h.handshakeConfirmed = true
 		h.handshakePackets = nil
 	case protocol.Encryption0RTT:
 		// This function is only called when 0-RTT is rejected,
@@ -202,21 +211,21 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
 	h.ptoCount = 0
 	h.numProbesToSend = 0
 	h.ptoMode = SendNone
-	h.setLossDetectionTimer()
+	h.setLossDetectionTimer(now)
 }
 
-func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
+func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t time.Time) {
 	wasAmplificationLimit := h.isAmplificationLimited()
 	h.bytesReceived += n
 	if wasAmplificationLimit && !h.isAmplificationLimited() {
-		h.setLossDetectionTimer()
+		h.setLossDetectionTimer(t)
 	}
 }
 
-func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
+func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t time.Time) {
 	if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
 		h.peerAddressValidated = true
-		h.setLossDetectionTimer()
+		h.setLossDetectionTimer(t)
 	}
 }
 
@@ -269,7 +278,7 @@ func (h *sentPacketHandler) SentPacket(
 	if !isAckEliciting {
 		pnSpace.history.SentNonAckElicitingPacket(pn)
 		if !h.peerCompletedAddressValidation {
-			h.setLossDetectionTimer()
+			h.setLossDetectionTimer(t)
 		}
 		return
 	}
@@ -289,7 +298,7 @@ func (h *sentPacketHandler) SentPacket(
 	if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
 		h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
 	}
-	h.setLossDetectionTimer()
+	h.setLossDetectionTimer(t)
 }
 
 func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
@@ -322,7 +331,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
 		h.peerCompletedAddressValidation = true
 		h.logger.Debugf("Peer doesn't await address validation any longer.")
 		// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
-		h.setLossDetectionTimer()
+		h.setLossDetectionTimer(rcvTime)
 	}
 
 	priorInFlight := h.bytesInFlight
@@ -338,7 +347,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
 			if encLevel == protocol.Encryption1RTT {
 				ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay())
 			}
-			h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
+			h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay)
 			if h.logger.Debug() {
 				h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
 			}
@@ -387,7 +396,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
 		h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
 	}
 
-	h.setLossDetectionTimer()
+	h.setLossDetectionTimer(rcvTime)
 	return acked1RTTPacket, nil
 }
 
@@ -498,14 +507,14 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration
 }
 
 // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
-func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
+func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
 	// We only send application data probe packets once the handshake is confirmed,
 	// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
 	if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
 		if h.peerCompletedAddressValidation {
 			return
 		}
-		t := time.Now().Add(h.getScaledPTO(false))
+		t := now.Add(h.getScaledPTO(false))
 		if h.initialPackets != nil {
 			return t, protocol.EncryptionInitial, true
 		}
@@ -545,61 +554,53 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
 	return false
 }
 
-func (h *sentPacketHandler) hasOutstandingPackets() bool {
-	return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
-}
-
-func (h *sentPacketHandler) setLossDetectionTimer() {
+func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) {
 	oldAlarm := h.alarm // only needed in case tracing is enabled
-	lossTime, encLevel := h.getLossTimeAndSpace()
-	if !lossTime.IsZero() {
-		// Early retransmit timer or time loss detection.
-		h.alarm = lossTime
-		if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
-			h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
+	newAlarm := h.lossDetectionTime(now)
+	h.alarm = newAlarm
+
+	if newAlarm.Time.IsZero() && !oldAlarm.Time.IsZero() {
+		h.logger.Debugf("Canceling loss detection timer.")
+		if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
+			h.tracer.LossTimerCanceled()
 		}
-		return
 	}
 
-	// Cancel the alarm if amplification limited.
+	if h.tracer != nil && h.tracer.SetLossTimer != nil && newAlarm != oldAlarm {
+		h.tracer.SetLossTimer(newAlarm.TimerType, newAlarm.EncryptionLevel, newAlarm.Time)
+	}
+}
+
+func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
+	// cancel the alarm if no packets are outstanding
+	if h.peerCompletedAddressValidation &&
+		!h.hasOutstandingCryptoPackets() && !h.appDataPackets.history.HasOutstandingPackets() {
+		return alarmTimer{}
+	}
+
+	// cancel the alarm if amplification limited
 	if h.isAmplificationLimited() {
-		h.alarm = time.Time{}
-		if !oldAlarm.IsZero() {
-			h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
-			if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
-				h.tracer.LossTimerCanceled()
-			}
-		}
-		return
+		return alarmTimer{}
 	}
 
-	// Cancel the alarm if no packets are outstanding
-	if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
-		h.alarm = time.Time{}
-		if !oldAlarm.IsZero() {
-			h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
-			if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
-				h.tracer.LossTimerCanceled()
-			}
+	// early retransmit timer or time loss detection
+	lossTime, encLevel := h.getLossTimeAndSpace()
+	if !lossTime.IsZero() {
+		return alarmTimer{
+			Time:            lossTime,
+			TimerType:       logging.TimerTypeACK,
+			EncryptionLevel: encLevel,
 		}
-		return
 	}
 
-	// PTO alarm
-	ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
+	ptoTime, encLevel, ok := h.getPTOTimeAndSpace(now)
 	if !ok {
-		if !oldAlarm.IsZero() {
-			h.alarm = time.Time{}
-			h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
-			if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
-				h.tracer.LossTimerCanceled()
-			}
-		}
-		return
+		return alarmTimer{}
 	}
-	h.alarm = ptoTime
-	if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
-		h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
+	return alarmTimer{
+		Time:            ptoTime,
+		TimerType:       logging.TimerTypePTO,
+		EncryptionLevel: encLevel,
 	}
 }
 
@@ -623,7 +624,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
 		}
 
 		var packetLost bool
-		if p.SendTime.Before(lostSendTime) {
+		if !p.SendTime.After(lostSendTime) {
 			packetLost = true
 			if !p.skippedPacket {
 				if h.logger.Debug() {
@@ -669,8 +670,8 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
 	})
 }
 
-func (h *sentPacketHandler) OnLossDetectionTimeout() error {
-	defer h.setLossDetectionTimer()
+func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
+	defer h.setLossDetectionTimer(now)
 	earliestLossTime, encLevel := h.getLossTimeAndSpace()
 	if !earliestLossTime.IsZero() {
 		if h.logger.Debug() {
@@ -680,13 +681,13 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
 			h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
 		}
 		// Early retransmit or time loss detection
-		return h.detectLostPackets(time.Now(), encLevel)
+		return h.detectLostPackets(now, encLevel)
 	}
 
 	// PTO
-	// When all outstanding are acknowledged, the alarm is canceled in
-	// setLossDetectionTimer. This doesn't reset the timer in the session though.
-	// When OnAlarm is called, we therefore need to make sure that there are
+	// When all outstanding are acknowledged, the alarm is canceled in setLossDetectionTimer.
+	// However, there's no way to reset the timer in the connection.
+	// When OnLossDetectionTimeout is called, we therefore need to make sure that there are
 	// actually packets outstanding.
 	if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
 		h.ptoCount++
@@ -701,7 +702,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
 		return nil
 	}
 
-	_, encLevel, ok := h.getPTOTimeAndSpace()
+	_, encLevel, ok := h.getPTOTimeAndSpace(now)
 	if !ok {
 		return nil
 	}
@@ -739,7 +740,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
 }
 
 func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
-	return h.alarm
+	return h.alarm.Time
 }
 
 func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
@@ -756,7 +757,7 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
 	pnSpace := h.getPacketNumberSpace(encLevel)
 	pn := pnSpace.pns.Peek()
 	// See section 17.1 of RFC 9000.
-	return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
+	return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked)
 }
 
 func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
@@ -864,7 +865,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
 	p.Frames = nil
 }
 
-func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
+func (h *sentPacketHandler) ResetForRetry(now time.Time) {
 	h.bytesInFlight = 0
 	var firstPacketSendTime time.Time
 	h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
@@ -890,7 +891,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
 	// Otherwise, we don't know which Initial the Retry was sent in response to.
 	if h.ptoCount == 0 {
 		// Don't set the RTT to a value lower than 5ms here.
-		h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
+		h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0)
 		if h.logger.Debug() {
 			h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
 		}
@@ -901,28 +902,14 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
 	h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
 	h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
 	oldAlarm := h.alarm
-	h.alarm = time.Time{}
+	h.alarm = alarmTimer{}
 	if h.tracer != nil {
 		if h.tracer.UpdatedPTOCount != nil {
 			h.tracer.UpdatedPTOCount(0)
 		}
-		if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil {
+		if !oldAlarm.Time.IsZero() && h.tracer.LossTimerCanceled != nil {
 			h.tracer.LossTimerCanceled()
 		}
 	}
 	h.ptoCount = 0
-	return nil
-}
-
-func (h *sentPacketHandler) SetHandshakeConfirmed() {
-	if h.initialPackets != nil {
-		panic("didn't drop initial correctly")
-	}
-	if h.handshakePackets != nil {
-		panic("didn't drop handshake correctly")
-	}
-	h.handshakeConfirmed = true
-	// We don't send PTOs for application data packets before the handshake completes.
-	// Make sure the timer is armed now, if necessary.
-	h.setLossDetectionTimer()
 }

+ 8 - 3
vendor/github.com/Psiphon-Labs/quic-go/internal/ackhandler/sent_packet_history.go

@@ -14,11 +14,16 @@ type sentPacketHistory struct {
 	highestPacketNumber protocol.PacketNumber
 }
 
-func newSentPacketHistory() *sentPacketHistory {
-	return &sentPacketHistory{
-		packets:             make([]*packet, 0, 32),
+func newSentPacketHistory(isAppData bool) *sentPacketHistory {
+	h := &sentPacketHistory{
 		highestPacketNumber: protocol.InvalidPacketNumber,
 	}
+	if isAppData {
+		h.packets = make([]*packet, 0, 32)
+	} else {
+		h.packets = make([]*packet, 0, 6)
+	}
+	return h
 }
 
 func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {

+ 4 - 4
vendor/github.com/Psiphon-Labs/quic-go/internal/congestion/cubic.go

@@ -17,11 +17,11 @@ import (
 // 1024*1024^3 (first 1024 is from 0.100^3)
 // where 0.100 is 100 ms which is the scaling round trip time.
 const (
-	cubeScale                                    = 40
-	cubeCongestionWindowScale                    = 410
-	cubeFactor                protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
+	cubeScale                 = 40
+	cubeCongestionWindowScale = 410
+	cubeFactor                = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
 	// TODO: when re-enabling cubic, make sure to use the actual packet size here
-	maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
+	maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
 )
 
 const defaultNumConnections = 1

+ 1 - 1
vendor/github.com/Psiphon-Labs/quic-go/internal/congestion/cubic_sender.go

@@ -12,7 +12,7 @@ import (
 const (
 	// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
 	// Used in QUIC for congestion window computations in bytes.
-	initialMaxDatagramSize     = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
+	initialMaxDatagramSize     = protocol.ByteCount(protocol.InitialPacketSize)
 	maxBurstPackets            = 3
 	renoBeta                   = 0.7 // Reno backoff factor.
 	minCongestionWindowPackets = 2

+ 8 - 12
vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/base_flow_controller.go

@@ -36,7 +36,7 @@ type baseFlowController struct {
 // For every offset, it only returns true once.
 // If it is blocked, the offset is returned.
 func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
-	if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
+	if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
 		return false, 0
 	}
 	c.lastBlockedAt = c.sendWindow
@@ -48,13 +48,15 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
 }
 
 // UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
-func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
+func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
 	if offset > c.sendWindow {
 		c.sendWindow = offset
+		return true
 	}
+	return false
 }
 
-func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
+func (c *baseFlowController) SendWindowSize() protocol.ByteCount {
 	// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
 	if c.bytesSent > c.sendWindow {
 		return 0
@@ -64,11 +66,6 @@ func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
 
 // needs to be called with locked mutex
 func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
-	// pretend we sent a WindowUpdate when reading the first byte
-	// this way auto-tuning of the window size already works for the first WindowUpdate
-	if c.bytesRead == 0 {
-		c.startNewAutoTuningEpoch(time.Now())
-	}
 	c.bytesRead += n
 }
 
@@ -80,19 +77,19 @@ func (c *baseFlowController) hasWindowUpdate() bool {
 
 // getWindowUpdate updates the receive window, if necessary
 // it returns the new offset
-func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
+func (c *baseFlowController) getWindowUpdate(now time.Time) protocol.ByteCount {
 	if !c.hasWindowUpdate() {
 		return 0
 	}
 
-	c.maybeAdjustWindowSize()
+	c.maybeAdjustWindowSize(now)
 	c.receiveWindow = c.bytesRead + c.receiveWindowSize
 	return c.receiveWindow
 }
 
 // maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
 // For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
-func (c *baseFlowController) maybeAdjustWindowSize() {
+func (c *baseFlowController) maybeAdjustWindowSize(now time.Time) {
 	bytesReadInEpoch := c.bytesRead - c.epochStartOffset
 	// don't do anything if less than half the window has been consumed
 	if bytesReadInEpoch <= c.receiveWindowSize/2 {
@@ -104,7 +101,6 @@ func (c *baseFlowController) maybeAdjustWindowSize() {
 	}
 
 	fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
-	now := time.Now()
 	if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
 		// window is consumed too fast, try to increase the window size
 		newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize)

+ 30 - 29
vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/connection_flow_controller.go

@@ -12,8 +12,6 @@ import (
 
 type connectionFlowController struct {
 	baseFlowController
-
-	queueWindowUpdate func()
 }
 
 var _ ConnectionFlowController = &connectionFlowController{}
@@ -23,11 +21,10 @@ var _ ConnectionFlowController = &connectionFlowController{}
 func NewConnectionFlowController(
 	receiveWindow protocol.ByteCount,
 	maxReceiveWindow protocol.ByteCount,
-	queueWindowUpdate func(),
 	allowWindowIncrease func(size protocol.ByteCount) bool,
 	rttStats *utils.RTTStats,
 	logger utils.Logger,
-) ConnectionFlowController {
+) *connectionFlowController {
 	return &connectionFlowController{
 		baseFlowController: baseFlowController{
 			rttStats:             rttStats,
@@ -37,20 +34,20 @@ func NewConnectionFlowController(
 			allowWindowIncrease:  allowWindowIncrease,
 			logger:               logger,
 		},
-		queueWindowUpdate: queueWindowUpdate,
 	}
 }
 
-func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
-	return c.baseFlowController.sendWindowSize()
-}
-
 // IncrementHighestReceived adds an increment to the highestReceived value
-func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
+func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now time.Time) error {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
 
+	// If this is the first frame received on this connection, start flow-control auto-tuning.
+	if c.highestReceived == 0 {
+		c.startNewAutoTuningEpoch(now)
+	}
 	c.highestReceived += increment
+
 	if c.checkFlowControlViolation() {
 		return &qerr.TransportError{
 			ErrorCode:    qerr.FlowControlError,
@@ -60,44 +57,47 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
 	return nil
 }
 
-func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
+func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
 	c.baseFlowController.addBytesRead(n)
-	shouldQueueWindowUpdate := c.hasWindowUpdate()
-	c.mutex.Unlock()
-	if shouldQueueWindowUpdate {
-		c.queueWindowUpdate()
-	}
+	return c.baseFlowController.hasWindowUpdate()
 }
 
-func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
+func (c *connectionFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
 	oldWindowSize := c.receiveWindowSize
-	offset := c.baseFlowController.getWindowUpdate()
-	if oldWindowSize < c.receiveWindowSize {
+	offset := c.baseFlowController.getWindowUpdate(now)
+	if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
 		c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
 	}
-	c.mutex.Unlock()
 	return offset
 }
 
 // EnsureMinimumWindowSize sets a minimum window size
 // it should make sure that the connection-level window is increased when a stream-level window grows
-func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
+func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now time.Time) {
 	c.mutex.Lock()
-	if inc > c.receiveWindowSize {
-		c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
-		newSize := min(inc, c.maxReceiveWindowSize)
-		if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
-			c.receiveWindowSize = newSize
+	defer c.mutex.Unlock()
+
+	if inc <= c.receiveWindowSize {
+		return
+	}
+	newSize := min(inc, c.maxReceiveWindowSize)
+	if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
+		c.receiveWindowSize = newSize
+		if c.logger.Debug() {
+			c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize)
 		}
-		c.startNewAutoTuningEpoch(time.Now())
 	}
-	c.mutex.Unlock()
+	c.startNewAutoTuningEpoch(now)
 }
 
 // Reset rests the flow controller. This happens when 0-RTT is rejected.
-// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
+// All stream data is invalidated, it's as if we had never opened a stream and never sent any data.
 // At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
 func (c *connectionFlowController) Reset() error {
 	c.mutex.Lock()
@@ -108,5 +108,6 @@ func (c *connectionFlowController) Reset() error {
 	}
 	c.bytesSent = 0
 	c.lastBlockedAt = 0
+	c.sendWindow = 0
 	return nil
 }

+ 16 - 11
vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/interface.go

@@ -1,42 +1,47 @@
 package flowcontrol
 
-import "github.com/Psiphon-Labs/quic-go/internal/protocol"
+import (
+	"time"
+
+	"github.com/Psiphon-Labs/quic-go/internal/protocol"
+)
 
 type flowController interface {
 	// for sending
 	SendWindowSize() protocol.ByteCount
-	UpdateSendWindow(protocol.ByteCount)
+	UpdateSendWindow(protocol.ByteCount) (updated bool)
 	AddBytesSent(protocol.ByteCount)
 	// for receiving
-	AddBytesRead(protocol.ByteCount)
-	GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
-	IsNewlyBlocked() (bool, protocol.ByteCount)
+	GetWindowUpdate(time.Time) protocol.ByteCount // returns 0 if no update is necessary
 }
 
 // A StreamFlowController is a flow controller for a QUIC stream.
 type StreamFlowController interface {
 	flowController
-	// for receiving
-	// UpdateHighestReceived should be called when a new highest offset is received
+	AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool)
+	// UpdateHighestReceived is called when a new highest offset is received
 	// final has to be to true if this is the final offset of the stream,
 	// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
-	UpdateHighestReceived(offset protocol.ByteCount, final bool) error
-	// Abandon should be called when reading from the stream is aborted early,
+	UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error
+	// Abandon is called when reading from the stream is aborted early,
 	// and there won't be any further calls to AddBytesRead.
 	Abandon()
+	IsNewlyBlocked() bool
 }
 
 // The ConnectionFlowController is the flow controller for the connection.
 type ConnectionFlowController interface {
 	flowController
+	AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool)
 	Reset() error
+	IsNewlyBlocked() (bool, protocol.ByteCount)
 }
 
 type connectionFlowControllerI interface {
 	ConnectionFlowController
 	// The following two methods are not supposed to be called from outside this packet, but are needed internally
 	// for sending
-	EnsureMinimumWindowSize(protocol.ByteCount)
+	EnsureMinimumWindowSize(protocol.ByteCount, time.Time)
 	// for receiving
-	IncrementHighestReceived(protocol.ByteCount) error
+	IncrementHighestReceived(protocol.ByteCount, time.Time) error
 }

+ 29 - 24
vendor/github.com/Psiphon-Labs/quic-go/internal/flowcontrol/stream_flow_controller.go

@@ -2,6 +2,7 @@ package flowcontrol
 
 import (
 	"fmt"
+	"time"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/qerr"
@@ -13,8 +14,6 @@ type streamFlowController struct {
 
 	streamID protocol.StreamID
 
-	queueWindowUpdate func()
-
 	connection connectionFlowControllerI
 
 	receivedFinalOffset bool
@@ -29,14 +28,12 @@ func NewStreamFlowController(
 	receiveWindow protocol.ByteCount,
 	maxReceiveWindow protocol.ByteCount,
 	initialSendWindow protocol.ByteCount,
-	queueWindowUpdate func(protocol.StreamID),
 	rttStats *utils.RTTStats,
 	logger utils.Logger,
 ) StreamFlowController {
 	return &streamFlowController{
-		streamID:          streamID,
-		connection:        cfc.(connectionFlowControllerI),
-		queueWindowUpdate: func() { queueWindowUpdate(streamID) },
+		streamID:   streamID,
+		connection: cfc.(connectionFlowControllerI),
 		baseFlowController: baseFlowController{
 			rttStats:             rttStats,
 			receiveWindow:        receiveWindow,
@@ -49,7 +46,7 @@ func NewStreamFlowController(
 }
 
 // UpdateHighestReceived updates the highestReceived value, if the offset is higher.
-func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
+func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error {
 	// If the final offset for this stream is already known, check for consistency.
 	if c.receivedFinalOffset {
 		// If we receive another final offset, check that it's the same.
@@ -74,9 +71,8 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
 	if offset == c.highestReceived {
 		return nil
 	}
-	// A higher offset was received before.
-	// This can happen due to reordering.
-	if offset <= c.highestReceived {
+	// A higher offset was received before. This can happen due to reordering.
+	if offset < c.highestReceived {
 		if final {
 			return &qerr.TransportError{
 				ErrorCode:    qerr.FinalSizeError,
@@ -86,31 +82,35 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
 		return nil
 	}
 
+	// If this is the first frame received for this stream, start flow-control auto-tuning.
+	if c.highestReceived == 0 {
+		c.startNewAutoTuningEpoch(now)
+	}
 	increment := offset - c.highestReceived
 	c.highestReceived = offset
+
 	if c.checkFlowControlViolation() {
 		return &qerr.TransportError{
 			ErrorCode:    qerr.FlowControlError,
 			ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
 		}
 	}
-	return c.connection.IncrementHighestReceived(increment)
+	return c.connection.IncrementHighestReceived(increment, now)
 }
 
-func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
+func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) {
 	c.mutex.Lock()
 	c.baseFlowController.addBytesRead(n)
-	shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
+	hasStreamWindowUpdate = c.shouldQueueWindowUpdate()
 	c.mutex.Unlock()
-	if shouldQueueWindowUpdate {
-		c.queueWindowUpdate()
-	}
-	c.connection.AddBytesRead(n)
+	hasConnWindowUpdate = c.connection.AddBytesRead(n)
+	return
 }
 
 func (c *streamFlowController) Abandon() {
 	c.mutex.Lock()
 	unread := c.highestReceived - c.bytesRead
+	c.bytesRead = c.highestReceived
 	c.mutex.Unlock()
 	if unread > 0 {
 		c.connection.AddBytesRead(unread)
@@ -123,27 +123,32 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
 }
 
 func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
-	return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
+	return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize())
+}
+
+func (c *streamFlowController) IsNewlyBlocked() bool {
+	blocked, _ := c.baseFlowController.IsNewlyBlocked()
+	return blocked
 }
 
 func (c *streamFlowController) shouldQueueWindowUpdate() bool {
 	return !c.receivedFinalOffset && c.hasWindowUpdate()
 }
 
-func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
+func (c *streamFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
 	// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
 	if c.receivedFinalOffset {
 		return 0
 	}
 
-	// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
 	oldWindowSize := c.receiveWindowSize
-	offset := c.baseFlowController.getWindowUpdate()
+	offset := c.baseFlowController.getWindowUpdate(now)
 	if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
-		c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
-		c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
+		c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize)
+		c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now)
 	}
-	c.mutex.Unlock()
 	return offset
 }

+ 18 - 21
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/aead.go

@@ -1,13 +1,12 @@
 package handshake
 
 import (
-	"crypto/cipher"
 	"encoding/binary"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 )
 
-func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
+func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
 	keyLabel := hkdfLabelKeyV1
 	ivLabel := hkdfLabelIVV1
 	if v == protocol.Version2 {
@@ -20,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb
 }
 
 type longHeaderSealer struct {
-	aead            cipher.AEAD
+	aead            *xorNonceAEAD
 	headerProtector headerProtector
-
-	// use a single slice to avoid allocations
-	nonceBuf []byte
+	nonceBuf        [8]byte
 }
 
 var _ LongHeaderSealer = &longHeaderSealer{}
 
-func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
+func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
+	if aead.NonceSize() != 8 {
+		panic("unexpected nonce size")
+	}
 	return &longHeaderSealer{
 		aead:            aead,
 		headerProtector: headerProtector,
-		nonceBuf:        make([]byte, aead.NonceSize()),
 	}
 }
 
 func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
-	binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
-	// The AEAD we're using here will be the qtls.aeadAESGCM13.
-	// It uses the nonce provided here and XOR it with the IV.
-	return s.aead.Seal(dst, s.nonceBuf, src, ad)
+	binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
+	return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
 }
 
 func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
@@ -53,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int {
 }
 
 type longHeaderOpener struct {
-	aead            cipher.AEAD
+	aead            *xorNonceAEAD
 	headerProtector headerProtector
 	highestRcvdPN   protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
 
-	// use a single slice to avoid allocations
-	nonceBuf []byte
+	// use a single array to avoid allocations
+	nonceBuf [8]byte
 }
 
 var _ LongHeaderOpener = &longHeaderOpener{}
 
-func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
+func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
+	if aead.NonceSize() != 8 {
+		panic("unexpected nonce size")
+	}
 	return &longHeaderOpener{
 		aead:            aead,
 		headerProtector: headerProtector,
-		nonceBuf:        make([]byte, aead.NonceSize()),
 	}
 }
 
@@ -76,10 +75,8 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire
 }
 
 func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
-	binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
-	// The AEAD we're using here will be the qtls.aeadAESGCM13.
-	// It uses the nonce provided here and XOR it with the IV.
-	dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
+	binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
+	dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
 	if err == nil {
 		o.highestRcvdPN = max(o.highestRcvdPN, pn)
 	} else {

+ 4 - 5
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/cipher_suite.go

@@ -4,9 +4,8 @@ import (
 	"crypto"
 	"crypto/aes"
 	"crypto/cipher"
-	"fmt"
-
 	tls "github.com/Psiphon-Labs/psiphon-tls"
+	"fmt"
 
 	"golang.org/x/crypto/chacha20poly1305"
 )
@@ -19,7 +18,7 @@ type cipherSuite struct {
 	ID     uint16
 	Hash   crypto.Hash
 	KeyLen int
-	AEAD   func(key, nonceMask []byte) cipher.AEAD
+	AEAD   func(key, nonceMask []byte) *xorNonceAEAD
 }
 
 func (s cipherSuite) IVLen() int { return aeadNonceLength }
@@ -37,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite {
 	}
 }
 
-func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
+func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
 	if len(nonceMask) != aeadNonceLength {
 		panic("tls: internal error: wrong nonce length")
 	}
@@ -55,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
 	return ret
 }
 
-func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
+func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
 	if len(nonceMask) != aeadNonceLength {
 		panic("tls: internal error: wrong nonce length")
 	}

+ 28 - 57
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/crypto_setup.go

@@ -1,7 +1,6 @@
 package handshake
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -34,7 +33,7 @@ type cryptoSetup struct {
 
 	events []Event
 
-	version protocol.VersionNumber
+	version protocol.Version
 
 	ourParams  *wire.TransportParameters
 	peerParams *wire.TransportParameters
@@ -83,7 +82,7 @@ func NewCryptoSetupClient(
 	rttStats *utils.RTTStats,
 	tracer *logging.ConnectionTracer,
 	logger utils.Logger,
-	version protocol.VersionNumber,
+	version protocol.Version,
 ) CryptoSetup {
 
 	// [Psiphon]
@@ -133,7 +132,7 @@ func NewCryptoSetupServer(
 	rttStats *utils.RTTStats,
 	tracer *logging.ConnectionTracer,
 	logger utils.Logger,
-	version protocol.VersionNumber,
+	version protocol.Version,
 ) CryptoSetup {
 	cs := newCryptoSetup(
 		connID,
@@ -146,44 +145,12 @@ func NewCryptoSetupServer(
 	)
 	cs.allow0RTT = allow0RTT
 
-	quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
-	qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
-	addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
-
-	cs.tlsConf = quicConf.TLSConfig
-	cs.conn = tls.QUICServer(quicConf)
-
+	tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
+	cs.tlsConf = tlsConf
+	cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
 	return cs
 }
 
-// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
-// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
-// that allows the caller to get the local and the remote address.
-func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
-	if conf.GetConfigForClient != nil {
-		gcfc := conf.GetConfigForClient
-		conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
-			info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
-			c, err := gcfc(info)
-			if c != nil {
-				c = c.Clone()
-				// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
-				c.MinVersion = tls.VersionTLS13
-				// We're returning a tls.Config here, so we need to apply this recursively.
-				addConnToClientHelloInfo(c, localAddr, remoteAddr)
-			}
-			return c, err
-		}
-	}
-	if conf.GetCertificate != nil {
-		gc := conf.GetCertificate
-		conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
-			info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
-			return gc(info)
-		}
-	}
-}
-
 func newCryptoSetup(
 	connID protocol.ConnectionID,
 	tp *wire.TransportParameters,
@@ -191,7 +158,7 @@ func newCryptoSetup(
 	tracer *logging.ConnectionTracer,
 	logger utils.Logger,
 	perspective protocol.Perspective,
-	version protocol.VersionNumber,
+	version protocol.Version,
 ) *cryptoSetup {
 	initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
 	if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
@@ -226,8 +193,8 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
 	return h.aead.SetLargestAcked(pn)
 }
 
-func (h *cryptoSetup) StartHandshake() error {
-	err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
+func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
+	err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
 	if err != nil {
 		return wrapError(err)
 	}
@@ -284,14 +251,17 @@ func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLev
 }
 
 func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
+	//nolint:exhaustive
+	// Go 1.23 added new 0-RTT events, see https://github.com/quic-go/quic-go/issues/4272.
+	// We will start using these events when dropping support for Go 1.22.
 	switch ev.Kind {
 	case tls.QUICNoEvent:
 		return true, nil
 	case tls.QUICSetReadSecret:
-		h.SetReadKey(ev.Level, ev.Suite, ev.Data)
+		h.setReadKey(ev.Level, ev.Suite, ev.Data)
 		return false, nil
 	case tls.QUICSetWriteSecret:
-		h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
+		h.setWriteKey(ev.Level, ev.Suite, ev.Data)
 		return false, nil
 	case tls.QUICTransportParameters:
 		return false, h.handleTransportParameters(ev.Data)
@@ -308,7 +278,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
 		h.handshakeComplete()
 		return false, nil
 	default:
-		return false, fmt.Errorf("unexpected event: %d", ev.Kind)
+		// Unknown events should be ignored.
+		// crypto/tls will ensure that this is safe to do.
+		// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
+		return false, nil
 	}
 }
 
@@ -360,25 +333,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
 	return false
 }
 
-func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
-	r := bytes.NewReader(data)
-	ver, err := quicvarint.Read(r)
+func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
+	ver, l, err := quicvarint.Parse(b)
 	if err != nil {
 		return 0, nil, err
 	}
+	b = b[l:]
 	if ver != clientSessionStateRevision {
 		return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
 	}
-	rttEncoded, err := quicvarint.Read(r)
+	rttEncoded, l, err := quicvarint.Parse(b)
 	if err != nil {
 		return 0, nil, err
 	}
+	b = b[l:]
 	rtt := time.Duration(rttEncoded) * time.Microsecond
 	if !earlyData {
 		return rtt, nil, nil
 	}
 	var tp wire.TransportParameters
-	if err := tp.UnmarshalFromSessionTicket(r); err != nil {
+	if err := tp.UnmarshalFromSessionTicket(b); err != nil {
 		return 0, nil, err
 	}
 	return rtt, &tp, nil
@@ -398,9 +372,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
 // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
 // It is only valid for the server.
 func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
-	if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
-		EarlyData: h.allow0RTT,
-	}); err != nil {
+	if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
 		// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
 		// We can't check h.tlsConfig here, since the actual config might have been obtained from
 		// the GetConfigForClient callback.
@@ -461,7 +433,7 @@ func (h *cryptoSetup) rejected0RTT() {
 	}
 }
 
-func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
+func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
 	suite := getCipherSuite(suiteID)
 	//nolint:exhaustive // The TLS stack doesn't export Initial keys.
 	switch el {
@@ -500,7 +472,7 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
 	}
 }
 
-func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
+func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
 	suite := getCipherSuite(suiteID)
 	//nolint:exhaustive // The TLS stack doesn't export Initial keys.
 	switch el {
@@ -683,8 +655,7 @@ func (h *cryptoSetup) TLSConnectionMetrics() tls.ConnectionMetrics {
 
 
 func wrapError(err error) error {
-	// alert 80 is an internal error
-	if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
+	if alertErr := tls.AlertError(0); errors.As(err, &alertErr) {
 		return qerr.NewLocalCryptoError(uint8(alertErr), err)
 	}
 	return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}

+ 5 - 7
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/header_protector.go

@@ -3,11 +3,10 @@ package handshake
 import (
 	"crypto/aes"
 	"crypto/cipher"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"encoding/binary"
 	"fmt"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"golang.org/x/crypto/chacha20"
 
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
@@ -18,14 +17,14 @@ type headerProtector interface {
 	DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
 }
 
-func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
+func hkdfHeaderProtectionLabel(v protocol.Version) string {
 	if v == protocol.Version2 {
 		return "quicv2 hp"
 	}
 	return "quic hp"
 }
 
-func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
+func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
 	hkdfLabel := hkdfHeaderProtectionLabel(v)
 	switch suite.ID {
 	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
@@ -38,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b
 }
 
 type aesHeaderProtector struct {
-	mask         []byte
+	mask         [16]byte // AES always has a 16 byte block size
 	block        cipher.Block
 	isLongHeader bool
 }
@@ -53,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade
 	}
 	return &aesHeaderProtector{
 		block:        block,
-		mask:         make([]byte, block.BlockSize()),
 		isLongHeader: isLongHeader,
 	}
 }
@@ -70,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by
 	if len(sample) != len(p.mask) {
 		panic("invalid sample size")
 	}
-	p.block.Encrypt(p.mask, sample)
+	p.block.Encrypt(p.mask[:], sample)
 	if p.isLongHeader {
 		*firstByte ^= p.mask[0] & 0xf
 	} else {

+ 1 - 1
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/hkdf.go

@@ -7,7 +7,7 @@ import (
 	"golang.org/x/crypto/hkdf"
 )
 
-// hkdfExpandLabel HKDF expands a label.
+// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
 // Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
 // hkdfExpandLabel in the standard library.
 func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {

+ 4 - 5
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/initial_aead.go

@@ -2,7 +2,6 @@ package handshake
 
 import (
 	"crypto"
-
 	tls "github.com/Psiphon-Labs/psiphon-tls"
 
 	"golang.org/x/crypto/hkdf"
@@ -22,7 +21,7 @@ const (
 	hkdfLabelIVV2  = "quicv2 iv"
 )
 
-func getSalt(v protocol.VersionNumber) []byte {
+func getSalt(v protocol.Version) []byte {
 	if v == protocol.Version2 {
 		return quicSaltV2
 	}
@@ -32,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte {
 var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
 
 // NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
-func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
+func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
 	clientSecret, serverSecret := computeSecrets(connID, v)
 	var mySecret, otherSecret []byte
 	if pers == protocol.PerspectiveClient {
@@ -52,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
 		newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
 }
 
-func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
+func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
 	initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
 	clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
 	serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
 	return
 }
 
-func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
+func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
 	keyLabel := hkdfLabelKeyV1
 	ivLabel := hkdfLabelIVV1
 	if v == protocol.Version2 {

+ 26 - 3
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/interface.go

@@ -1,12 +1,12 @@
 package handshake
 
 import (
+	"context"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"errors"
 	"io"
 	"time"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/wire"
 )
@@ -83,6 +83,29 @@ const (
 	EventHandshakeComplete
 )
 
+func (k EventKind) String() string {
+	switch k {
+	case EventNoEvent:
+		return "EventNoEvent"
+	case EventWriteInitialData:
+		return "EventWriteInitialData"
+	case EventWriteHandshakeData:
+		return "EventWriteHandshakeData"
+	case EventReceivedReadKeys:
+		return "EventReceivedReadKeys"
+	case EventDiscard0RTTKeys:
+		return "EventDiscard0RTTKeys"
+	case EventReceivedTransportParameters:
+		return "EventReceivedTransportParameters"
+	case EventRestoredTransportParameters:
+		return "EventRestoredTransportParameters"
+	case EventHandshakeComplete:
+		return "EventHandshakeComplete"
+	default:
+		return "Unknown EventKind"
+	}
+}
+
 // Event is a handshake event.
 type Event struct {
 	Kind                EventKind
@@ -92,7 +115,7 @@ type Event struct {
 
 // CryptoSetup handles the handshake and protecting / unprotecting packets
 type CryptoSetup interface {
-	StartHandshake() error
+	StartHandshake(context.Context) error
 	io.Closer
 	ChangeConnectionID(protocol.ConnectionID)
 	GetSessionTicket() ([]byte, error)

+ 9 - 6
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/retry.go

@@ -10,16 +10,13 @@ import (
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 )
 
+// Instead of using an init function, the AEADs are created lazily.
+// For more details see https://github.com/quic-go/quic-go/issues/4894.
 var (
 	retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
 	retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
 )
 
-func init() {
-	retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
-	retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
-}
-
 func initAEAD(key [16]byte) cipher.AEAD {
 	aes, err := aes.NewCipher(key[:])
 	if err != nil {
@@ -40,7 +37,7 @@ var (
 )
 
 // GetRetryIntegrityTag calculates the integrity tag on a Retry packet
-func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
+func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
 	retryMutex.Lock()
 	defer retryMutex.Unlock()
 
@@ -52,8 +49,14 @@ func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, ve
 	var tag [16]byte
 	var sealed []byte
 	if version == protocol.Version2 {
+		if retryAEADv2 == nil {
+			retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
+		}
 		sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
 	} else {
+		if retryAEADv1 == nil {
+			retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
+		}
 		sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
 	}
 	if len(sealed) != 16 {

+ 6 - 6
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/session_ticket.go

@@ -1,7 +1,6 @@
 package handshake
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 	"time"
@@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
 }
 
 func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
-	r := bytes.NewReader(b)
-	rev, err := quicvarint.Read(r)
+	rev, l, err := quicvarint.Parse(b)
 	if err != nil {
 		return errors.New("failed to read session ticket revision")
 	}
+	b = b[l:]
 	if rev != sessionTicketRevision {
 		return fmt.Errorf("unknown session ticket revision: %d", rev)
 	}
-	rtt, err := quicvarint.Read(r)
+	rtt, l, err := quicvarint.Parse(b)
 	if err != nil {
 		return errors.New("failed to read RTT")
 	}
+	b = b[l:]
 	if using0RTT {
 		var tp wire.TransportParameters
-		if err := tp.UnmarshalFromSessionTicket(r); err != nil {
+		if err := tp.UnmarshalFromSessionTicket(b); err != nil {
 			return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
 		}
 		t.Parameters = &tp
-	} else if r.Len() > 0 {
+	} else if len(b) > 0 {
 		return fmt.Errorf("the session ticket has more bytes than expected")
 	}
 	t.RTT = time.Duration(rtt) * time.Microsecond

+ 1 - 1
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/token_generator.go

@@ -46,7 +46,7 @@ type TokenGenerator struct {
 
 // NewTokenGenerator initializes a new TokenGenerator
 func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
-	return &TokenGenerator{tokenProtector: newTokenProtector(key)}
+	return &TokenGenerator{tokenProtector: *newTokenProtector(key)}
 }
 
 // NewRetryToken generates a new token for a Retry for a given source address

+ 6 - 14
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/token_protector.go

@@ -14,28 +14,20 @@ import (
 // TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
 type TokenProtectorKey [32]byte
 
-// TokenProtector is used to create and verify a token
-type tokenProtector interface {
-	// NewToken creates a new token
-	NewToken([]byte) ([]byte, error)
-	// DecodeToken decodes a token
-	DecodeToken([]byte) ([]byte, error)
-}
-
 const tokenNonceSize = 32
 
 // tokenProtector is used to create and verify a token
-type tokenProtectorImpl struct {
+type tokenProtector struct {
 	key TokenProtectorKey
 }
 
 // newTokenProtector creates a source for source address tokens
-func newTokenProtector(key TokenProtectorKey) tokenProtector {
-	return &tokenProtectorImpl{key: key}
+func newTokenProtector(key TokenProtectorKey) *tokenProtector {
+	return &tokenProtector{key: key}
 }
 
 // NewToken encodes data into a new token.
-func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
+func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
 	var nonce [tokenNonceSize]byte
 	if _, err := rand.Read(nonce[:]); err != nil {
 		return nil, err
@@ -48,7 +40,7 @@ func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
 }
 
 // DecodeToken decodes a token.
-func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
+func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
 	if len(p) < tokenNonceSize {
 		return nil, fmt.Errorf("token too short: %d", len(p))
 	}
@@ -60,7 +52,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
 	return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
 }
 
-func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
+func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
 	h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
 	key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
 	if _, err := io.ReadFull(h, key); err != nil {

+ 4 - 5
vendor/github.com/Psiphon-Labs/quic-go/internal/handshake/updatable_aead.go

@@ -3,12 +3,11 @@ package handshake
 import (
 	"crypto"
 	"crypto/cipher"
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"encoding/binary"
 	"fmt"
 	"time"
 
-	tls "github.com/Psiphon-Labs/psiphon-tls"
-
 	"github.com/Psiphon-Labs/quic-go/internal/protocol"
 	"github.com/Psiphon-Labs/quic-go/internal/qerr"
 	"github.com/Psiphon-Labs/quic-go/internal/utils"
@@ -60,7 +59,7 @@ type updatableAEAD struct {
 
 	tracer  *logging.ConnectionTracer
 	logger  utils.Logger
-	version protocol.VersionNumber
+	version protocol.Version
 
 	// use a single slice to avoid allocations
 	nonceBuf []byte
@@ -71,7 +70,7 @@ var (
 	_ ShortHeaderSealer = &updatableAEAD{}
 )
 
-func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
+func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
 	return &updatableAEAD{
 		firstPacketNumber:       protocol.InvalidPacketNumber,
 		largestAcked:            protocol.InvalidPacketNumber,
@@ -134,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
 
 // SetWriteKey sets the write key.
 // For the client, this function is called after SetReadKey.
-// For the server, this function is called before SetWriteKey.
+// For the server, this function is called before SetReadKey.
 func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
 	a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
 	a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)

+ 0 - 50
vendor/github.com/Psiphon-Labs/quic-go/internal/logutils/frame.go

@@ -1,50 +0,0 @@
-package logutils
-
-import (
-	"github.com/Psiphon-Labs/quic-go/internal/protocol"
-	"github.com/Psiphon-Labs/quic-go/internal/wire"
-	"github.com/Psiphon-Labs/quic-go/logging"
-)
-
-// ConvertFrame converts a wire.Frame into a logging.Frame.
-// This makes it possible for external packages to access the frames.
-// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
-func ConvertFrame(frame wire.Frame) logging.Frame {
-	switch f := frame.(type) {
-	case *wire.AckFrame:
-		// We use a pool for ACK frames.
-		// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
-		return ConvertAckFrame(f)
-	case *wire.CryptoFrame:
-		return &logging.CryptoFrame{
-			Offset: f.Offset,
-			Length: protocol.ByteCount(len(f.Data)),
-		}
-	case *wire.StreamFrame:
-		return &logging.StreamFrame{
-			StreamID: f.StreamID,
-			Offset:   f.Offset,
-			Length:   f.DataLen(),
-			Fin:      f.Fin,
-		}
-	case *wire.DatagramFrame:
-		return &logging.DatagramFrame{
-			Length: logging.ByteCount(len(f.Data)),
-		}
-	default:
-		return logging.Frame(frame)
-	}
-}
-
-func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame {
-	ranges := make([]wire.AckRange, 0, len(f.AckRanges))
-	ranges = append(ranges, f.AckRanges...)
-	ack := &logging.AckFrame{
-		AckRanges: ranges,
-		DelayTime: f.DelayTime,
-		ECNCE:     f.ECNCE,
-		ECT0:      f.ECT0,
-		ECT1:      f.ECT1,
-	}
-	return ack
-}

+ 15 - 1
vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/connection_id.go

@@ -57,12 +57,26 @@ func ParseConnectionID(b []byte) ConnectionID {
 	return c
 }
 
-// [Psiphon]
+// [Psiphon] SECTION BEGIN
+
+// // GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
+// // It uses a length randomly chosen between 8 and 20 bytes.
+// func GenerateConnectionIDForInitial() (ConnectionID, error) {
+// 	r := make([]byte, 1)
+// 	if _, err := rand.Read(r); err != nil {
+// 		return ConnectionID{}, err
+// 	}
+// 	l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
+// 	return GenerateConnectionID(l)
+// }
+
 // GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
 func GenerateConnectionIDForInitial() (ConnectionID, error) {
 	return GenerateConnectionID(MinConnectionIDLenInitial)
 }
 
+// [Psiphon] SECTION END
+
 // ReadConnectionID reads a connection ID of length len from the given io.Reader.
 // It returns io.EOF if there are not enough bytes to read.
 func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {

+ 23 - 45
vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/packet_number.go

@@ -21,58 +21,36 @@ const (
 	PacketNumberLen4 PacketNumberLen = 4
 )
 
-// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
-func DecodePacketNumber(
-	packetNumberLength PacketNumberLen,
-	lastPacketNumber PacketNumber,
-	wirePacketNumber PacketNumber,
-) PacketNumber {
-	var epochDelta PacketNumber
-	switch packetNumberLength {
-	case PacketNumberLen1:
-		epochDelta = PacketNumber(1) << 8
-	case PacketNumberLen2:
-		epochDelta = PacketNumber(1) << 16
-	case PacketNumberLen3:
-		epochDelta = PacketNumber(1) << 24
-	case PacketNumberLen4:
-		epochDelta = PacketNumber(1) << 32
+// DecodePacketNumber calculates the packet number based its length and the last seen packet number
+// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3.
+func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber {
+	expected := largest + 1
+	win := PacketNumber(1 << (length * 8))
+	hwin := win / 2
+	mask := win - 1
+	candidate := (expected & ^mask) | truncated
+	if candidate <= expected-hwin && candidate < 1<<62-win {
+		return candidate + win
 	}
-	epoch := lastPacketNumber & ^(epochDelta - 1)
-	var prevEpochBegin PacketNumber
-	if epoch > epochDelta {
-		prevEpochBegin = epoch - epochDelta
+	if candidate > expected+hwin && candidate >= win {
+		return candidate - win
 	}
-	nextEpochBegin := epoch + epochDelta
-	return closestTo(
-		lastPacketNumber+1,
-		epoch+wirePacketNumber,
-		closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
-	)
+	return candidate
 }
 
-func closestTo(target, a, b PacketNumber) PacketNumber {
-	if delta(target, a) < delta(target, b) {
-		return a
-	}
-	return b
-}
-
-func delta(a, b PacketNumber) PacketNumber {
-	if a < b {
-		return b - a
-	}
-	return a - b
-}
-
-// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
+// PacketNumberLengthForHeader gets the length of the packet number for the public header
 // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
-func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
-	diff := uint64(packetNumber - leastUnacked)
-	if diff < (1 << (16 - 1)) {
+func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen {
+	var numUnacked PacketNumber
+	if largestAcked == InvalidPacketNumber {
+		numUnacked = pn + 1
+	} else {
+		numUnacked = pn - largestAcked
+	}
+	if numUnacked < 1<<(16-1) {
 		return PacketNumberLen2
 	}
-	if diff < (1 << (24 - 1)) {
+	if numUnacked < 1<<(24-1) {
 		return PacketNumberLen3
 	}
 	return PacketNumberLen4

+ 7 - 12
vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/params.go

@@ -3,16 +3,13 @@ package protocol
 import "time"
 
 // DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
-const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
+const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB
 
 // DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
-const DesiredSendBufferSize = (1 << 20) * 2 // 2 MB
+const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB
 
-// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
-const InitialPacketSizeIPv4 = 1252
-
-// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
-const InitialPacketSizeIPv6 = 1232
+// InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used.
+const InitialPacketSize = 1280
 
 // MaxCongestionWindowPackets is the maximum congestion window in packet.
 const MaxCongestionWindowPackets = 10000
@@ -105,10 +102,6 @@ const DefaultIdleTimeout = 30 * time.Second
 // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
 const DefaultHandshakeIdleTimeout = 5 * time.Second
 
-// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
-// It should be shorter than the time that NATs clear their mapping.
-const MaxKeepAliveInterval = 20 * time.Second
-
 // RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE.
 // after this time all information about the old connection will be deleted
 const RetiredConnectionIDDeleteTimeout = 5 * time.Second
@@ -139,9 +132,11 @@ const MaxNumAckRanges = 32
 // Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth.
 const MinPacingDelay = time.Millisecond
 
-// [Psiphon]
 // DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
 // if no other value is configured.
+// const DefaultConnectionIDLength = 4
+
+// [Psiphon]
 const DefaultConnectionIDLength = 8
 
 // MaxActiveConnectionIDs is the number of connection IDs that we're storing.

+ 2 - 2
vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/perspective.go

@@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective {
 func (p Perspective) String() string {
 	switch p {
 	case PerspectiveServer:
-		return "Server"
+		return "server"
 	case PerspectiveClient:
-		return "Client"
+		return "client"
 	default:
 		return "invalid perspective"
 	}

+ 34 - 25
vendor/github.com/Psiphon-Labs/quic-go/internal/protocol/version.go

@@ -1,14 +1,17 @@
 package protocol
 
 import (
-	"crypto/rand"
 	"encoding/binary"
 	"fmt"
 	"math"
+	"sync"
+	"time"
+
+	"golang.org/x/exp/rand"
 )
 
-// VersionNumber is a version number as int
-type VersionNumber uint32
+// Version is a version number as int
+type Version uint32
 
 // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
 const (
@@ -18,22 +21,22 @@ const (
 
 // The version numbers, making grepping easier
 const (
-	VersionUnknown VersionNumber = math.MaxUint32
-	versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
-	Version1       VersionNumber = 0x1
-	Version2       VersionNumber = 0x6b3343cf
+	VersionUnknown Version = math.MaxUint32
+	versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
+	Version1       Version = 0x1
+	Version2       Version = 0x6b3343cf
 )
 
 // SupportedVersions lists the versions that the server supports
 // must be in sorted descending order
-var SupportedVersions = []VersionNumber{Version1, Version2}
+var SupportedVersions = []Version{Version1, Version2}
 
 // IsValidVersion says if the version is known to quic-go
-func IsValidVersion(v VersionNumber) bool {
+func IsValidVersion(v Version) bool {
 	return v == Version1 || IsSupportedVersion(SupportedVersions, v)
 }
 
-func (vn VersionNumber) String() string {
+func (vn Version) String() string {
 	//nolint:exhaustive
 	switch vn {
 	case VersionUnknown:
@@ -52,16 +55,16 @@ func (vn VersionNumber) String() string {
 	}
 }
 
-func (vn VersionNumber) isGQUIC() bool {
+func (vn Version) isGQUIC() bool {
 	return vn > gquicVersion0 && vn <= maxGquicVersion
 }
 
-func (vn VersionNumber) toGQUICVersion() int {
+func (vn Version) toGQUICVersion() int {
 	return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
 }
 
 // IsSupportedVersion returns true if the server supports this version
-func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
+func IsSupportedVersion(supported []Version, v Version) bool {
 	for _, t := range supported {
 		if t == v {
 			return true
@@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
 // ours is a slice of versions that we support, sorted by our preference (descending)
 // theirs is a slice of versions offered by the peer. The order does not matter.
 // The bool returned indicates if a matching version was found.
-func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
+func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
 	for _, ourVer := range ours {
 		for _, theirVer := range theirs {
 			if ourVer == theirVer {
@@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool)
 	return 0, false
 }
 
-// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
-func generateReservedVersion() VersionNumber {
-	b := make([]byte, 4)
-	_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
-	return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
+var (
+	versionNegotiationMx   sync.Mutex
+	versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
+)
+
+// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
+func generateReservedVersion() Version {
+	var b [4]byte
+	_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
+	return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
 }
 
-// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
-func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
-	b := make([]byte, 1)
-	_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
-	randPos := int(b[0]) % (len(supported) + 1)
-	greased := make([]VersionNumber, len(supported)+1)
+// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
+// It doesn't modify the supported slice.
+func GetGreasedVersions(supported []Version) []Version {
+	versionNegotiationMx.Lock()
+	defer versionNegotiationMx.Unlock()
+	randPos := rand.Intn(len(supported) + 1)
+	greased := make([]Version, len(supported)+1)
 	copy(greased, supported[:randPos])
 	greased[randPos] = generateReservedVersion()
 	copy(greased[randPos+1:], supported[randPos:])

+ 1 - 2
vendor/github.com/Psiphon-Labs/quic-go/internal/qerr/error_codes.go

@@ -1,9 +1,8 @@
 package qerr
 
 import (
-	"fmt"
-
 	tls "github.com/Psiphon-Labs/psiphon-tls"
+	"fmt"
 )
 
 // TransportErrorCode is a QUIC transport error.

+ 25 - 30
vendor/github.com/Psiphon-Labs/quic-go/internal/qerr/errors.go

@@ -48,21 +48,16 @@ func (e *TransportError) Error() string {
 	return str + ": " + msg
 }
 
-func (e *TransportError) Is(target error) bool {
-	return target == net.ErrClosed
-}
+func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} }
 
-func (e *TransportError) Unwrap() error {
-	return e.error
+func (e *TransportError) Is(target error) bool {
+	t, ok := target.(*TransportError)
+	return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote
 }
 
 // An ApplicationErrorCode is an application-defined error code.
 type ApplicationErrorCode uint64
 
-func (e *ApplicationError) Is(target error) bool {
-	return target == net.ErrClosed
-}
-
 // A StreamErrorCode is an error code used to cancel streams.
 type StreamErrorCode uint64
 
@@ -81,53 +76,53 @@ func (e *ApplicationError) Error() string {
 	return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
 }
 
+func (e *ApplicationError) Unwrap() error { return net.ErrClosed }
+
+func (e *ApplicationError) Is(target error) bool {
+	t, ok := target.(*ApplicationError)
+	return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
+}
+
 type IdleTimeoutError struct{}
 
 var _ error = &IdleTimeoutError{}
 
-func (e *IdleTimeoutError) Timeout() bool        { return true }
-func (e *IdleTimeoutError) Temporary() bool      { return false }
-func (e *IdleTimeoutError) Error() string        { return "timeout: no recent network activity" }
-func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
+func (e *IdleTimeoutError) Timeout() bool   { return true }
+func (e *IdleTimeoutError) Temporary() bool { return false }
+func (e *IdleTimeoutError) Error() string   { return "timeout: no recent network activity" }
+func (e *IdleTimeoutError) Unwrap() error   { return net.ErrClosed }
 
 type HandshakeTimeoutError struct{}
 
 var _ error = &HandshakeTimeoutError{}
 
-func (e *HandshakeTimeoutError) Timeout() bool        { return true }
-func (e *HandshakeTimeoutError) Temporary() bool      { return false }
-func (e *HandshakeTimeoutError) Error() string        { return "timeout: handshake did not complete in time" }
-func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
+func (e *HandshakeTimeoutError) Timeout() bool   { return true }
+func (e *HandshakeTimeoutError) Temporary() bool { return false }
+func (e *HandshakeTimeoutError) Error() string   { return "timeout: handshake did not complete in time" }
+func (e *HandshakeTimeoutError) Unwrap() error   { return net.ErrClosed }
 
 // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
 type VersionNegotiationError struct {
-	Ours   []protocol.VersionNumber
-	Theirs []protocol.VersionNumber
+	Ours   []protocol.Version
+	Theirs []protocol.Version
 }
 
 func (e *VersionNegotiationError) Error() string {
 	return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
 }
 
-func (e *VersionNegotiationError) Is(target error) bool {
-	return target == net.ErrClosed
-}
+func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed }
 
 // A StatelessResetError occurs when we receive a stateless reset.
-type StatelessResetError struct {
-	Token protocol.StatelessResetToken
-}
+type StatelessResetError struct{}
 
 var _ net.Error = &StatelessResetError{}
 
 func (e *StatelessResetError) Error() string {
-	return fmt.Sprintf("received a stateless reset with token %x", e.Token)
-}
-
-func (e *StatelessResetError) Is(target error) bool {
-	return target == net.ErrClosed
+	return "received a stateless reset"
 }
 
+func (e *StatelessResetError) Unwrap() error   { return net.ErrClosed }
 func (e *StatelessResetError) Timeout() bool   { return false }
 func (e *StatelessResetError) Temporary() bool { return true }
 

+ 0 - 12
vendor/github.com/Psiphon-Labs/quic-go/internal/qtls/cipher_suite.go

@@ -1,24 +1,12 @@
 package qtls
 
 import (
-	"crypto"
-	"crypto/cipher"
 	"fmt"
 	"unsafe"
 
 	tls "github.com/Psiphon-Labs/psiphon-tls"
 )
 
-type cipherSuiteTLS13 struct {
-	ID     uint16
-	KeyLen int
-	AEAD   func(key, fixedNonce []byte) cipher.AEAD
-	Hash   crypto.Hash
-}
-
-//go:linkname cipherSuiteTLS13ByID github.com/Psiphon-Labs/psiphon-tls.cipherSuiteTLS13ByID
-func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
-
 //go:linkname cipherSuitesTLS13 github.com/Psiphon-Labs/psiphon-tls.cipherSuitesTLS13
 var cipherSuitesTLS13 []unsafe.Pointer
 

+ 10 - 2
vendor/github.com/Psiphon-Labs/quic-go/internal/qtls/client_session_cache.go

@@ -2,9 +2,11 @@ package qtls
 
 import (
 	tls "github.com/Psiphon-Labs/psiphon-tls"
+	"sync"
 )
 
 type clientSessionCache struct {
+	mx      sync.Mutex
 	getData func(earlyData bool) []byte
 	setData func(data []byte, earlyData bool) (allowEarlyData bool)
 	wrapped tls.ClientSessionCache
@@ -12,7 +14,10 @@ type clientSessionCache struct {
 
 var _ tls.ClientSessionCache = &clientSessionCache{}
 
-func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
+func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
+	c.mx.Lock()
+	defer c.mx.Unlock()
+
 	if cs == nil {
 		c.wrapped.Put(key, nil)
 		return
@@ -32,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
 	c.wrapped.Put(key, newCS)
 }
 
-func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
+func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
+	c.mx.Lock()
+	defer c.mx.Unlock()
+
 	cs, ok := c.wrapped.Get(key)
 	if !ok || cs == nil {
 		return cs, ok

Some files were not shown because too many files changed in this diff