| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
- import (
- "bytes"
- "context"
- "crypto"
- "crypto/x509"
- "github.com/pion/dtls/v2/pkg/crypto/prf"
- "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
- "github.com/pion/dtls/v2/pkg/protocol"
- "github.com/pion/dtls/v2/pkg/protocol/alert"
- "github.com/pion/dtls/v2/pkg/protocol/handshake"
- "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
- )
- func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
- _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
- )
- if !ok {
- // No valid message received. Keep reading
- return 0, nil, nil
- }
- var finished *handshake.MessageFinished
- if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
- }
- plainText := cache.pullAndMerge(
- handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
- )
- expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
- if err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
- }
- if len(state.SessionID) > 0 {
- s := Session{
- ID: state.SessionID,
- Secret: state.masterSecret,
- }
- cfg.log.Tracef("[handshake] save new session: %x", s.ID)
- if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- }
- return flight5, nil, nil
- }
- func flight5Generate(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
- var privateKey crypto.PrivateKey
- var pkts []*packet
- if state.remoteRequestedCertificate {
- _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false})
- if !ok {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
- }
- reqInfo := CertificateRequestInfo{}
- if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
- reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames
- } else {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
- }
- certificate, err := cfg.getClientCertificate(&reqInfo)
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
- }
- if certificate == nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain
- }
- if certificate.Certificate != nil {
- privateKey = certificate.PrivateKey
- }
- pkts = append(pkts,
- &packet{
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- },
- Content: &handshake.Handshake{
- Message: &handshake.MessageCertificate{
- Certificate: certificate.Certificate,
- },
- },
- },
- })
- }
- clientKeyExchange := &handshake.MessageClientKeyExchange{}
- if cfg.localPSKCallback == nil {
- clientKeyExchange.PublicKey = state.localKeypair.PublicKey
- } else {
- clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
- }
- if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 {
- clientKeyExchange.PublicKey = state.localKeypair.PublicKey
- }
- pkts = append(pkts,
- &packet{
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- },
- Content: &handshake.Handshake{
- Message: clientKeyExchange,
- },
- },
- })
- serverKeyExchangeData := cache.pullAndMerge(
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
- )
- serverKeyExchange := &handshake.MessageServerKeyExchange{}
- // handshakeMessageServerKeyExchange is optional for PSK
- if len(serverKeyExchangeData) == 0 {
- alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{})
- if err != nil {
- return nil, alertPtr, err
- }
- } else {
- rawHandshake := &handshake.Handshake{
- KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(),
- }
- err := rawHandshake.Unmarshal(serverKeyExchangeData)
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
- }
- switch h := rawHandshake.Message.(type) {
- case *handshake.MessageServerKeyExchange:
- serverKeyExchange = h
- default:
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
- }
- }
- // Append not-yet-sent packets
- merged := []byte{}
- seqPred := uint16(state.handshakeSendSequence)
- for _, p := range pkts {
- h, ok := p.record.Content.(*handshake.Handshake)
- if !ok {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
- }
- h.Header.MessageSequence = seqPred
- seqPred++
- raw, err := h.Marshal()
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- merged = append(merged, raw...)
- }
- if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
- return nil, alertPtr, err
- }
- // If the client has sent a certificate with signing ability, a digitally-signed
- // CertificateVerify message is sent to explicitly verify possession of the
- // private key in the certificate.
- if state.remoteRequestedCertificate && privateKey != nil {
- plainText := append(cache.pullAndMerge(
- handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
- ), merged...)
- // Find compatible signature scheme
- signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey)
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
- }
- certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash)
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- state.localCertificatesVerify = certVerify
- p := &packet{
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- },
- Content: &handshake.Handshake{
- Message: &handshake.MessageCertificateVerify{
- HashAlgorithm: signatureHashAlgo.Hash,
- SignatureAlgorithm: signatureHashAlgo.Signature,
- Signature: state.localCertificatesVerify,
- },
- },
- },
- }
- pkts = append(pkts, p)
- h, ok := p.record.Content.(*handshake.Handshake)
- if !ok {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
- }
- h.Header.MessageSequence = seqPred
- // seqPred++ // this is the last use of seqPred
- raw, err := h.Marshal()
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- merged = append(merged, raw...)
- }
- pkts = append(pkts,
- &packet{
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- },
- Content: &protocol.ChangeCipherSpec{},
- },
- })
- if len(state.localVerifyData) == 0 {
- plainText := cache.pullAndMerge(
- handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
- )
- var err error
- state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc())
- if err != nil {
- return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- }
- pkts = append(pkts,
- &packet{
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- Epoch: 1,
- },
- Content: &handshake.Handshake{
- Message: &handshake.MessageFinished{
- VerifyData: state.localVerifyData,
- },
- },
- },
- shouldEncrypt: true,
- resetLocalSequenceNumber: true,
- })
- return pkts, nil, nil
- }
- func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
- if state.cipherSuite.IsInitialized() {
- return nil, nil //nolint
- }
- clientRandom := state.localRandom.MarshalFixed()
- serverRandom := state.remoteRandom.MarshalFixed()
- var err error
- if state.extendedMasterSecret {
- var sessionHash []byte
- sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
- if err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
- if err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
- }
- } else {
- state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
- if err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- }
- if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
- // Verify that the pair of hash algorithm and signiture is listed.
- var validSignatureScheme bool
- for _, ss := range cfg.localSignatureSchemes {
- if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
- validSignatureScheme = true
- break
- }
- }
- if !validSignatureScheme {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
- }
- expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
- if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
- }
- var chains [][]*x509.Certificate
- if !cfg.insecureSkipVerify {
- if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
- }
- }
- if cfg.verifyPeerCertificate != nil {
- if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
- }
- }
- }
- if cfg.verifyConnection != nil {
- if err = cfg.verifyConnection(state.clone()); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
- }
- }
- if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
- return nil, nil //nolint
- }
|