flight5handler.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. package dtls
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto"
  6. "crypto/x509"
  7. "github.com/pion/dtls/v2/pkg/crypto/prf"
  8. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  9. "github.com/pion/dtls/v2/pkg/protocol"
  10. "github.com/pion/dtls/v2/pkg/protocol/alert"
  11. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  12. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  13. )
  14. func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  15. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  16. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  17. )
  18. if !ok {
  19. // No valid message received. Keep reading
  20. return 0, nil, nil
  21. }
  22. var finished *handshake.MessageFinished
  23. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  24. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  25. }
  26. plainText := cache.pullAndMerge(
  27. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  28. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  29. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  30. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  31. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  32. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  33. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  34. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  35. handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
  36. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  37. )
  38. expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  39. if err != nil {
  40. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  41. }
  42. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  43. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  44. }
  45. if len(state.SessionID) > 0 {
  46. s := Session{
  47. ID: state.SessionID,
  48. Secret: state.masterSecret,
  49. }
  50. cfg.log.Tracef("[handshake] save new session: %x", s.ID)
  51. if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
  52. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  53. }
  54. }
  55. return flight5, nil, nil
  56. }
  57. func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
  58. var privateKey crypto.PrivateKey
  59. var pkts []*packet
  60. if state.remoteRequestedCertificate {
  61. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite,
  62. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false})
  63. if !ok {
  64. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
  65. }
  66. reqInfo := CertificateRequestInfo{}
  67. if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
  68. reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames
  69. } else {
  70. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
  71. }
  72. certificate, err := cfg.getClientCertificate(&reqInfo)
  73. if err != nil {
  74. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
  75. }
  76. if certificate == nil {
  77. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain
  78. }
  79. if certificate.Certificate != nil {
  80. privateKey = certificate.PrivateKey
  81. }
  82. pkts = append(pkts,
  83. &packet{
  84. record: &recordlayer.RecordLayer{
  85. Header: recordlayer.Header{
  86. Version: protocol.Version1_2,
  87. },
  88. Content: &handshake.Handshake{
  89. Message: &handshake.MessageCertificate{
  90. Certificate: certificate.Certificate,
  91. },
  92. },
  93. },
  94. })
  95. }
  96. clientKeyExchange := &handshake.MessageClientKeyExchange{}
  97. if cfg.localPSKCallback == nil {
  98. clientKeyExchange.PublicKey = state.localKeypair.PublicKey
  99. } else {
  100. clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
  101. }
  102. if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 {
  103. clientKeyExchange.PublicKey = state.localKeypair.PublicKey
  104. }
  105. pkts = append(pkts,
  106. &packet{
  107. record: &recordlayer.RecordLayer{
  108. Header: recordlayer.Header{
  109. Version: protocol.Version1_2,
  110. },
  111. Content: &handshake.Handshake{
  112. Message: clientKeyExchange,
  113. },
  114. },
  115. })
  116. serverKeyExchangeData := cache.pullAndMerge(
  117. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  118. )
  119. serverKeyExchange := &handshake.MessageServerKeyExchange{}
  120. // handshakeMessageServerKeyExchange is optional for PSK
  121. if len(serverKeyExchangeData) == 0 {
  122. alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
  123. if err != nil {
  124. return nil, alertPtr, err
  125. }
  126. } else {
  127. rawHandshake := &handshake.Handshake{
  128. KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(),
  129. }
  130. err := rawHandshake.Unmarshal(serverKeyExchangeData)
  131. if err != nil {
  132. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
  133. }
  134. switch h := rawHandshake.Message.(type) {
  135. case *handshake.MessageServerKeyExchange:
  136. serverKeyExchange = h
  137. default:
  138. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
  139. }
  140. }
  141. // Append not-yet-sent packets
  142. merged := []byte{}
  143. seqPred := uint16(state.handshakeSendSequence)
  144. for _, p := range pkts {
  145. h, ok := p.record.Content.(*handshake.Handshake)
  146. if !ok {
  147. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
  148. }
  149. h.Header.MessageSequence = seqPred
  150. seqPred++
  151. raw, err := h.Marshal()
  152. if err != nil {
  153. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  154. }
  155. merged = append(merged, raw...)
  156. }
  157. if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
  158. return nil, alertPtr, err
  159. }
  160. // If the client has sent a certificate with signing ability, a digitally-signed
  161. // CertificateVerify message is sent to explicitly verify possession of the
  162. // private key in the certificate.
  163. if state.remoteRequestedCertificate && privateKey != nil {
  164. plainText := append(cache.pullAndMerge(
  165. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  166. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  167. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  168. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  169. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  170. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  171. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  172. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  173. ), merged...)
  174. // Find compatible signature scheme
  175. signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
  176. if err != nil {
  177. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
  178. }
  179. certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
  180. if err != nil {
  181. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  182. }
  183. state.localCertificatesVerify = certVerify
  184. p := &packet{
  185. record: &recordlayer.RecordLayer{
  186. Header: recordlayer.Header{
  187. Version: protocol.Version1_2,
  188. },
  189. Content: &handshake.Handshake{
  190. Message: &handshake.MessageCertificateVerify{
  191. HashAlgorithm: signatureHashAlgo.Hash,
  192. SignatureAlgorithm: signatureHashAlgo.Signature,
  193. Signature: state.localCertificatesVerify,
  194. },
  195. },
  196. },
  197. }
  198. pkts = append(pkts, p)
  199. h, ok := p.record.Content.(*handshake.Handshake)
  200. if !ok {
  201. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
  202. }
  203. h.Header.MessageSequence = seqPred
  204. // seqPred++ // this is the last use of seqPred
  205. raw, err := h.Marshal()
  206. if err != nil {
  207. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  208. }
  209. merged = append(merged, raw...)
  210. }
  211. pkts = append(pkts,
  212. &packet{
  213. record: &recordlayer.RecordLayer{
  214. Header: recordlayer.Header{
  215. Version: protocol.Version1_2,
  216. },
  217. Content: &protocol.ChangeCipherSpec{},
  218. },
  219. })
  220. if len(state.localVerifyData) == 0 {
  221. plainText := cache.pullAndMerge(
  222. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  223. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  224. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
  225. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  226. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
  227. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  228. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
  229. handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
  230. handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
  231. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  232. )
  233. var err error
  234. state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
  235. if err != nil {
  236. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  237. }
  238. }
  239. pkts = append(pkts,
  240. &packet{
  241. record: &recordlayer.RecordLayer{
  242. Header: recordlayer.Header{
  243. Version: protocol.Version1_2,
  244. Epoch: 1,
  245. },
  246. Content: &handshake.Handshake{
  247. Message: &handshake.MessageFinished{
  248. VerifyData: state.localVerifyData,
  249. },
  250. },
  251. },
  252. shouldEncrypt: true,
  253. resetLocalSequenceNumber: true,
  254. })
  255. return pkts, nil, nil
  256. }
  257. func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
  258. if state.cipherSuite.IsInitialized() {
  259. return nil, nil //nolint
  260. }
  261. clientRandom := state.localRandom.MarshalFixed()
  262. serverRandom := state.remoteRandom.MarshalFixed()
  263. var err error
  264. if state.extendedMasterSecret {
  265. var sessionHash []byte
  266. sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
  267. if err != nil {
  268. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  269. }
  270. state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
  271. if err != nil {
  272. return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  273. }
  274. } else {
  275. state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
  276. if err != nil {
  277. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  278. }
  279. }
  280. if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  281. // Verify that the pair of hash algorithm and signiture is listed.
  282. var validSignatureScheme bool
  283. for _, ss := range cfg.localSignatureSchemes {
  284. if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
  285. validSignatureScheme = true
  286. break
  287. }
  288. }
  289. if !validSignatureScheme {
  290. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
  291. }
  292. expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
  293. if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
  294. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  295. }
  296. var chains [][]*x509.Certificate
  297. if !cfg.insecureSkipVerify {
  298. if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
  299. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  300. }
  301. }
  302. if cfg.verifyPeerCertificate != nil {
  303. if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
  304. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  305. }
  306. }
  307. }
  308. if cfg.verifyConnection != nil {
  309. if err = cfg.verifyConnection(state.clone()); err != nil {
  310. return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
  311. }
  312. }
  313. if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
  314. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  315. }
  316. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  317. return nil, nil //nolint
  318. }