flight4handler.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "crypto/rand"
  7. "crypto/x509"
  8. "github.com/pion/dtls/v2/internal/ciphersuite"
  9. "github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
  10. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  11. "github.com/pion/dtls/v2/pkg/crypto/prf"
  12. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  13. "github.com/pion/dtls/v2/pkg/protocol"
  14. "github.com/pion/dtls/v2/pkg/protocol/alert"
  15. "github.com/pion/dtls/v2/pkg/protocol/extension"
  16. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  17. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  18. inproxy_dtls "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/inproxy/dtls"
  19. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  20. )
  // flight4Parse processes the client's second flight on the server side.
// That flight is an optional Certificate, the ClientKeyExchange, an optional
// CertificateVerify and, once the cipher suite keys exist, the client's
// Finished message (pulled from epoch cfg.initialEpoch+1).
//
// Return values:
//   - (0, nil, nil) when more messages are needed ("keep reading");
//   - (0, alert, err) on a fatal failure (err may be nil for type mismatches);
//   - (flight6, nil, nil) once Finished has been received and the configured
//     client-authentication policy and verifyConnection callback are satisfied.
//
// Side effects on state: PeerCertificates, SessionID (cleared when the client
// sends a certificate), peerCertificatesVerified, IdentityHint (PSK suites),
// masterSecret, the cipher suite keys (via Init) and handshakeRecvSequence.
// When state.SessionID is non-empty the session is saved to cfg.sessionStore,
// and the master secret is passed to cfg.writeKeyLog.
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
	// Pull the client's Certificate, ClientKeyExchange and CertificateVerify
	// starting at the next expected receive sequence. NOTE(review): the last two
	// rule fields look like (isClient, optional) — Certificate and
	// CertificateVerify are pulled as optional; confirm against
	// handshakeCachePullRule.
	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}

	// Validate type
	var clientKeyExchange *handshake.MessageClientKeyExchange
	if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
		state.PeerCertificates = h.Certificate
		// If the client offers its certificate, just disable session resumption.
		// Otherwise, we have to store the certificate identification and expire time.
		// And we have to check whether this certificate expired, revoked or changed.
		//
		// https://curl.se/docs/CVE-2016-5419.html
		state.SessionID = nil
	}

	if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
		// A CertificateVerify is only meaningful if a client certificate was sent.
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
		}

		// The data the CertificateVerify signature is checked against: every
		// handshake message from ClientHello up to and including the client's
		// ClientKeyExchange (CertificateVerify itself is not included).
		plainText := cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		)

		// Verify that the pair of hash algorithm and signature is listed.
		var validSignatureScheme bool
		for _, ss := range cfg.localSignatureSchemes {
			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
				validSignatureScheme = true
				break
			}
		}
		if !validSignatureScheme {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
		}

		if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
		var chains [][]*x509.Certificate
		var err error
		var verified bool
		// Chain verification against cfg.clientCAs runs only when clientAuth
		// compares >= VerifyClientCertIfGiven. NOTE(review): this relies on the
		// ordering of the ClientAuthType constants — confirm.
		if cfg.clientAuth >= VerifyClientCertIfGiven {
			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
			verified = true
		}
		// The user callback runs even when chain verification was skipped;
		// chains is nil in that case.
		if cfg.verifyPeerCertificate != nil {
			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		state.peerCertificatesVerified = verified
	} else if state.PeerCertificates != nil {
		// A certificate was received, but we haven't seen a CertificateVerify
		// keep reading until we receive one
		return 0, nil, nil
	}

	// Derive the master secret and initialize the cipher suite keys, only once.
	if !state.cipherSuite.IsInitialized() {
		serverRandom := state.localRandom.MarshalFixed()
		clientRandom := state.remoteRandom.MarshalFixed()

		var err error
		var preMasterSecret []byte
		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
			// Look up the PSK for the identity hint the client sent.
			var psk []byte
			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
			state.IdentityHint = clientKeyExchange.IdentityHint
			switch state.cipherSuite.KeyExchangeAlgorithm() {
			case CipherSuiteKeyExchangeAlgorithmPsk:
				preMasterSecret = prf.PSKPreMasterSecret(psk)
			case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe):
				if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
				}
			default:
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite
			}
		} else {
			// Non-PSK suites: combine the client's public key with our local
			// key pair. A malformed client key is reported as IllegalParameter.
			preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
			}
		}

		if state.extendedMasterSecret {
			// Extended master secret (RFC 7627): derive from the session hash
			// of the handshake transcript instead of the two randoms.
			var sessionHash []byte
			sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}

			state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		} else {
			state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		}

		// NOTE(review): the trailing false presumably selects the server side
		// (isClient) — confirm against CipherSuite.Init.
		if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
	}

	// Persist the session when a session ID is still set (it is cleared above
	// if the client sent a certificate).
	if len(state.SessionID) > 0 {
		s := Session{
			ID:     state.SessionID,
			Secret: state.masterSecret,
		}
		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
		if err := cfg.sessionStore.Set(state.SessionID, s); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	// Now, encrypted packets can be handled
	if err := c.handleQueuedPackets(ctx); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	// The client's Finished is pulled from the next epoch. handshakeRecvSequence
	// is only advanced after it is available, so a "keep reading" return here
	// re-pulls the earlier messages on the next call.
	seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}
	state.handshakeRecvSequence = seq

	// NOTE(review): Finished is only type-checked here; its verify_data is not
	// compared in this function — confirm it is validated elsewhere.
	if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	// Anonymous suites skip the client-certificate policy checks below.
	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
		if cfg.verifyConnection != nil {
			if err := cfg.verifyConnection(state.clone()); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		return flight6, nil, nil
	}

	// Enforce the configured client-authentication policy.
	switch cfg.clientAuth {
	case RequireAnyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
	case VerifyClientCertIfGiven:
		if state.PeerCertificates != nil && !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case RequireAndVerifyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
		if !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case NoClientCert, RequestClientCert:
		// go to flight6
	}
	if cfg.verifyConnection != nil {
		if err := cfg.verifyConnection(state.clone()); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
	}
	return flight6, nil, nil
}
  200. func flight4Generate(ctx context.Context, c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  201. extensions := []extension.Extension{&extension.RenegotiationInfo{
  202. RenegotiatedConnection: 0,
  203. }}
  204. if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  205. cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
  206. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  207. Supported: true,
  208. })
  209. }
  210. if state.getSRTPProtectionProfile() != 0 {
  211. extensions = append(extensions, &extension.UseSRTP{
  212. ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
  213. })
  214. }
  215. if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  216. extensions = append(extensions, &extension.SupportedPointFormats{
  217. PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  218. })
  219. }
  220. selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
  221. if err != nil {
  222. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
  223. }
  224. if selectedProto != "" {
  225. extensions = append(extensions, &extension.ALPN{
  226. ProtocolNameList: []string{selectedProto},
  227. })
  228. state.NegotiatedProtocol = selectedProto
  229. }
  230. var pkts []*packet
  231. cipherSuiteID := uint16(state.cipherSuite.ID())
  232. if cfg.sessionStore != nil {
  233. state.SessionID = make([]byte, sessionLength)
  234. if _, err := rand.Read(state.SessionID); err != nil {
  235. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  236. }
  237. }
  238. // [Psiphon]
  239. // Randomize ServerHello
  240. seed, err := inproxy_dtls.GetDTLSSeed(ctx)
  241. if err != nil {
  242. return nil, nil, err
  243. }
  244. if seed != nil {
  245. PRNG := prng.NewPRNGWithSeed(seed)
  246. PRNG.Shuffle(len(extensions), func(i, j int) {
  247. extensions[i], extensions[j] = extensions[j], extensions[i]
  248. })
  249. }
  250. pkts = append(pkts, &packet{
  251. record: &recordlayer.RecordLayer{
  252. Header: recordlayer.Header{
  253. Version: protocol.Version1_2,
  254. },
  255. Content: &handshake.Handshake{
  256. Message: &handshake.MessageServerHello{
  257. Version: protocol.Version1_2,
  258. Random: state.localRandom,
  259. SessionID: state.SessionID,
  260. CipherSuiteID: &cipherSuiteID,
  261. CompressionMethod: defaultCompressionMethods()[0],
  262. Extensions: extensions,
  263. },
  264. },
  265. },
  266. })
  267. switch {
  268. case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
  269. certificate, err := cfg.getCertificate(&ClientHelloInfo{
  270. ServerName: state.serverName,
  271. CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()},
  272. // [Psiphon]
  273. // Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
  274. RandomBytes: state.remoteRandom.RandomBytes,
  275. })
  276. if err != nil {
  277. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
  278. }
  279. pkts = append(pkts, &packet{
  280. record: &recordlayer.RecordLayer{
  281. Header: recordlayer.Header{
  282. Version: protocol.Version1_2,
  283. },
  284. Content: &handshake.Handshake{
  285. Message: &handshake.MessageCertificate{
  286. Certificate: certificate.Certificate,
  287. },
  288. },
  289. },
  290. })
  291. serverRandom := state.localRandom.MarshalFixed()
  292. clientRandom := state.remoteRandom.MarshalFixed()
  293. // Find compatible signature scheme
  294. signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
  295. if err != nil {
  296. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
  297. }
  298. signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
  299. if err != nil {
  300. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  301. }
  302. state.localKeySignature = signature
  303. pkts = append(pkts, &packet{
  304. record: &recordlayer.RecordLayer{
  305. Header: recordlayer.Header{
  306. Version: protocol.Version1_2,
  307. },
  308. Content: &handshake.Handshake{
  309. Message: &handshake.MessageServerKeyExchange{
  310. EllipticCurveType: elliptic.CurveTypeNamedCurve,
  311. NamedCurve: state.namedCurve,
  312. PublicKey: state.localKeypair.PublicKey,
  313. HashAlgorithm: signatureHashAlgo.Hash,
  314. SignatureAlgorithm: signatureHashAlgo.Signature,
  315. Signature: state.localKeySignature,
  316. },
  317. },
  318. },
  319. })
  320. if cfg.clientAuth > NoClientCert {
  321. // An empty list of certificateAuthorities signals to
  322. // the client that it may send any certificate in response
  323. // to our request. When we know the CAs we trust, then
  324. // we can send them down, so that the client can choose
  325. // an appropriate certificate to give to us.
  326. var certificateAuthorities [][]byte
  327. if cfg.clientCAs != nil {
  328. // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty.
  329. certificateAuthorities = cfg.clientCAs.Subjects()
  330. }
  331. pkts = append(pkts, &packet{
  332. record: &recordlayer.RecordLayer{
  333. Header: recordlayer.Header{
  334. Version: protocol.Version1_2,
  335. },
  336. Content: &handshake.Handshake{
  337. Message: &handshake.MessageCertificateRequest{
  338. CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
  339. SignatureHashAlgorithms: cfg.localSignatureSchemes,
  340. CertificateAuthoritiesNames: certificateAuthorities,
  341. },
  342. },
  343. },
  344. })
  345. }
  346. case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe):
  347. // To help the client in selecting which identity to use, the server
  348. // can provide a "PSK identity hint" in the ServerKeyExchange message.
  349. // If no hint is provided and cipher suite doesn't use elliptic curve,
  350. // the ServerKeyExchange message is omitted.
  351. //
  352. // https://tools.ietf.org/html/rfc4279#section-2
  353. srvExchange := &handshake.MessageServerKeyExchange{
  354. IdentityHint: cfg.localPSKIdentityHint,
  355. }
  356. if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) {
  357. srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve
  358. srvExchange.NamedCurve = state.namedCurve
  359. srvExchange.PublicKey = state.localKeypair.PublicKey
  360. }
  361. pkts = append(pkts, &packet{
  362. record: &recordlayer.RecordLayer{
  363. Header: recordlayer.Header{
  364. Version: protocol.Version1_2,
  365. },
  366. Content: &handshake.Handshake{
  367. Message: srvExchange,
  368. },
  369. },
  370. })
  371. }
  372. pkts = append(pkts, &packet{
  373. record: &recordlayer.RecordLayer{
  374. Header: recordlayer.Header{
  375. Version: protocol.Version1_2,
  376. },
  377. Content: &handshake.Handshake{
  378. Message: &handshake.MessageServerHelloDone{},
  379. },
  380. },
  381. })
  382. return pkts, nil, nil
  383. }