flight5handler.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "context"
  7. "crypto"
  8. "crypto/x509"
  9. "github.com/pion/dtls/v2/pkg/crypto/prf"
  10. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  11. "github.com/pion/dtls/v2/pkg/protocol"
  12. "github.com/pion/dtls/v2/pkg/protocol/alert"
  13. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  14. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  15. )
  16. func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  17. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  18. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  19. )
  20. if !ok {
  21. // No valid message received. Keep reading
  22. return 0, nil, nil
  23. }
  24. var finished *handshake.MessageFinished
  25. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  26. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  27. }
  28. plainText := cache.pullAndMerge(
  29. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  30. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  31. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  32. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  33. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  34. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  35. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  36. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  37. handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
  38. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  39. )
  40. expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  41. if err != nil {
  42. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  43. }
  44. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  45. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  46. }
  47. if len(state.SessionID) > 0 {
  48. s := Session{
  49. ID: state.SessionID,
  50. Secret: state.masterSecret,
  51. }
  52. cfg.log.Tracef("[handshake] save new session: %x", s.ID)
  53. if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
  54. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  55. }
  56. }
  57. return flight5, nil, nil
  58. }
  59. func flight5Generate(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
  60. var privateKey crypto.PrivateKey
  61. var pkts []*packet
  62. if state.remoteRequestedCertificate {
  63. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite,
  64. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false})
  65. if !ok {
  66. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
  67. }
  68. reqInfo := CertificateRequestInfo{}
  69. if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
  70. reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames
  71. } else {
  72. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
  73. }
  74. certificate, err := cfg.getClientCertificate(&reqInfo)
  75. if err != nil {
  76. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
  77. }
  78. if certificate == nil {
  79. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain
  80. }
  81. if certificate.Certificate != nil {
  82. privateKey = certificate.PrivateKey
  83. }
  84. pkts = append(pkts,
  85. &packet{
  86. record: &recordlayer.RecordLayer{
  87. Header: recordlayer.Header{
  88. Version: protocol.Version1_2,
  89. },
  90. Content: &handshake.Handshake{
  91. Message: &handshake.MessageCertificate{
  92. Certificate: certificate.Certificate,
  93. },
  94. },
  95. },
  96. })
  97. }
  98. clientKeyExchange := &handshake.MessageClientKeyExchange{}
  99. if cfg.localPSKCallback == nil {
  100. clientKeyExchange.PublicKey = state.localKeypair.PublicKey
  101. } else {
  102. clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
  103. }
  104. if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 {
  105. clientKeyExchange.PublicKey = state.localKeypair.PublicKey
  106. }
  107. pkts = append(pkts,
  108. &packet{
  109. record: &recordlayer.RecordLayer{
  110. Header: recordlayer.Header{
  111. Version: protocol.Version1_2,
  112. },
  113. Content: &handshake.Handshake{
  114. Message: clientKeyExchange,
  115. },
  116. },
  117. })
  118. serverKeyExchangeData := cache.pullAndMerge(
  119. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  120. )
  121. serverKeyExchange := &handshake.MessageServerKeyExchange{}
  122. // handshakeMessageServerKeyExchange is optional for PSK
  123. if len(serverKeyExchangeData) == 0 {
  124. alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
  125. if err != nil {
  126. return nil, alertPtr, err
  127. }
  128. } else {
  129. rawHandshake := &handshake.Handshake{
  130. KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(),
  131. }
  132. err := rawHandshake.Unmarshal(serverKeyExchangeData)
  133. if err != nil {
  134. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
  135. }
  136. switch h := rawHandshake.Message.(type) {
  137. case *handshake.MessageServerKeyExchange:
  138. serverKeyExchange = h
  139. default:
  140. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
  141. }
  142. }
  143. // Append not-yet-sent packets
  144. merged := []byte{}
  145. seqPred := uint16(state.handshakeSendSequence)
  146. for _, p := range pkts {
  147. h, ok := p.record.Content.(*handshake.Handshake)
  148. if !ok {
  149. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
  150. }
  151. h.Header.MessageSequence = seqPred
  152. seqPred++
  153. raw, err := h.Marshal()
  154. if err != nil {
  155. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  156. }
  157. merged = append(merged, raw...)
  158. }
  159. if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
  160. return nil, alertPtr, err
  161. }
  162. // If the client has sent a certificate with signing ability, a digitally-signed
  163. // CertificateVerify message is sent to explicitly verify possession of the
  164. // private key in the certificate.
  165. if state.remoteRequestedCertificate && privateKey != nil {
  166. plainText := append(cache.pullAndMerge(
  167. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  168. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  169. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  170. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  171. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  172. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  173. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  174. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  175. ), merged...)
  176. // Find compatible signature scheme
  177. signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
  178. if err != nil {
  179. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
  180. }
  181. certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
  182. if err != nil {
  183. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  184. }
  185. state.localCertificatesVerify = certVerify
  186. p := &packet{
  187. record: &recordlayer.RecordLayer{
  188. Header: recordlayer.Header{
  189. Version: protocol.Version1_2,
  190. },
  191. Content: &handshake.Handshake{
  192. Message: &handshake.MessageCertificateVerify{
  193. HashAlgorithm: signatureHashAlgo.Hash,
  194. SignatureAlgorithm: signatureHashAlgo.Signature,
  195. Signature: state.localCertificatesVerify,
  196. },
  197. },
  198. },
  199. }
  200. pkts = append(pkts, p)
  201. h, ok := p.record.Content.(*handshake.Handshake)
  202. if !ok {
  203. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
  204. }
  205. h.Header.MessageSequence = seqPred
  206. // seqPred++ // this is the last use of seqPred
  207. raw, err := h.Marshal()
  208. if err != nil {
  209. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  210. }
  211. merged = append(merged, raw...)
  212. }
  213. pkts = append(pkts,
  214. &packet{
  215. record: &recordlayer.RecordLayer{
  216. Header: recordlayer.Header{
  217. Version: protocol.Version1_2,
  218. },
  219. Content: &protocol.ChangeCipherSpec{},
  220. },
  221. })
  222. if len(state.localVerifyData) == 0 {
  223. plainText := cache.pullAndMerge(
  224. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  225. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  226. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  227. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  228. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  229. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  230. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  231. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  232. handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
  233. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  234. )
  235. var err error
  236. state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
  237. if err != nil {
  238. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  239. }
  240. }
  241. pkts = append(pkts,
  242. &packet{
  243. record: &recordlayer.RecordLayer{
  244. Header: recordlayer.Header{
  245. Version: protocol.Version1_2,
  246. Epoch: 1,
  247. },
  248. Content: &handshake.Handshake{
  249. Message: &handshake.MessageFinished{
  250. VerifyData: state.localVerifyData,
  251. },
  252. },
  253. },
  254. shouldEncrypt: true,
  255. resetLocalSequenceNumber: true,
  256. })
  257. return pkts, nil, nil
  258. }
  259. func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
  260. if state.cipherSuite.IsInitialized() {
  261. return nil, nil //nolint
  262. }
  263. clientRandom := state.localRandom.MarshalFixed()
  264. serverRandom := state.remoteRandom.MarshalFixed()
  265. var err error
  266. if state.extendedMasterSecret {
  267. var sessionHash []byte
  268. sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
  269. if err != nil {
  270. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  271. }
  272. state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
  273. if err != nil {
  274. return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  275. }
  276. } else {
  277. state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
  278. if err != nil {
  279. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  280. }
  281. }
  282. if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  283. // Verify that the pair of hash algorithm and signiture is listed.
  284. var validSignatureScheme bool
  285. for _, ss := range cfg.localSignatureSchemes {
  286. if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
  287. validSignatureScheme = true
  288. break
  289. }
  290. }
  291. if !validSignatureScheme {
  292. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
  293. }
  294. expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
  295. if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
  296. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  297. }
  298. var chains [][]*x509.Certificate
  299. if !cfg.insecureSkipVerify {
  300. if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
  301. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  302. }
  303. }
  304. if cfg.verifyPeerCertificate != nil {
  305. if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
  306. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  307. }
  308. }
  309. }
  310. if cfg.verifyConnection != nil {
  311. if err = cfg.verifyConnection(state.clone()); err != nil {
  312. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  313. }
  314. }
  315. if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
  316. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  317. }
  318. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  319. return nil, nil //nolint
  320. }