| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
import (
	"bytes"
	"context"
	"crypto/subtle"
	"errors"

	"github.com/pion/dtls/v2/internal/ciphersuite/types"
	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
	"github.com/pion/dtls/v2/pkg/crypto/prf"
	"github.com/pion/dtls/v2/pkg/protocol"
	"github.com/pion/dtls/v2/pkg/protocol/alert"
	"github.com/pion/dtls/v2/pkg/protocol/extension"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
- func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
- // Clients may receive multiple HelloVerifyRequest messages with different cookies.
- // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
- // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
- seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
- )
- if ok {
- if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
- // DTLS 1.2 clients must not assume that the server will use the protocol version
- // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
- if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
- }
- state.cookie = append([]byte{}, h.Cookie...)
- state.handshakeRecvSequence = seq
- return flight3, nil, nil
- }
- }
- _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
- )
- if !ok {
- // Don't have enough messages. Keep reading
- return 0, nil, nil
- }
- if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk {
- if !h.Version.Equal(protocol.Version1_2) {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
- }
- for _, v := range h.Extensions {
- switch e := v.(type) {
- case *extension.UseSRTP:
- profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
- if !found {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
- }
- state.setSRTPProtectionProfile(profile)
- case *extension.UseExtendedMasterSecret:
- if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
- state.extendedMasterSecret = true
- }
- case *extension.ALPN:
- if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
- }
- state.NegotiatedProtocol = e.ProtocolNameList[0]
- }
- }
- if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
- }
- if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
- }
- remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
- if remoteCipherSuite == nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
- }
- selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
- if !found {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
- }
- state.cipherSuite = selectedCipherSuite
- state.remoteRandom = h.Random
- cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
- if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) {
- return handleResumption(ctx, c, state, cache, cfg)
- }
- if len(state.SessionID) > 0 {
- cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
- if err := cfg.sessionStore.Del(state.SessionID); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- }
- if cfg.sessionStore == nil {
- state.SessionID = []byte{}
- } else {
- state.SessionID = h.SessionID
- }
- state.masterSecret = []byte{}
- }
- if cfg.localPSKCallback != nil {
- seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
- handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
- )
- } else {
- seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
- handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
- handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
- handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
- )
- }
- if !ok {
- // Don't have enough messages. Keep reading
- return 0, nil, nil
- }
- state.handshakeRecvSequence = seq
- if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
- state.PeerCertificates = h.Certificate
- } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
- }
- if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
- alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
- if err != nil {
- return 0, alertPtr, err
- }
- }
- if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
- state.remoteRequestedCertificate = true
- }
- return flight5, nil, nil
- }
- func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
- if err := state.initCipherSuite(); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- // Now, encrypted packets can be handled
- if err := c.handleQueuedPackets(ctx); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
- )
- if !ok {
- // No valid message received. Keep reading
- return 0, nil, nil
- }
- var finished *handshake.MessageFinished
- if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
- }
- plainText := cache.pullAndMerge(
- handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
- handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
- )
- expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
- if err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
- }
- clientRandom := state.localRandom.MarshalFixed()
- cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
- return flight5b, nil, nil
- }
- func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
- var err error
- if state.cipherSuite == nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
- }
- if cfg.localPSKCallback != nil {
- var psk []byte
- if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- state.IdentityHint = h.IdentityHint
- switch state.cipherSuite.KeyExchangeAlgorithm() {
- case types.KeyExchangeAlgorithmPsk:
- state.preMasterSecret = prf.PSKPreMasterSecret(psk)
- case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
- if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
- if err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- default:
- return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
- }
- } else {
- if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
- return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- }
- return nil, nil //nolint:nilnil
- }
- func flight3Generate(_ context.Context, _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
- // [Psiphon]
- // With SetDTLSInsecureSkipHelloVerify set, this should never be called,
- // so handshake randomization is not implemented here.
- return nil, nil, errors.New("unexpected flight3Generate call")
- extensions := []extension.Extension{
- &extension.SupportedSignatureAlgorithms{
- SignatureHashAlgorithms: cfg.localSignatureSchemes,
- },
- &extension.RenegotiationInfo{
- RenegotiatedConnection: 0,
- },
- }
- if state.namedCurve != 0 {
- extensions = append(extensions, []extension.Extension{
- &extension.SupportedEllipticCurves{
- EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
- },
- &extension.SupportedPointFormats{
- PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
- },
- }...)
- }
- if len(cfg.localSRTPProtectionProfiles) > 0 {
- extensions = append(extensions, &extension.UseSRTP{
- ProtectionProfiles: cfg.localSRTPProtectionProfiles,
- })
- }
- if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
- cfg.extendedMasterSecret == RequireExtendedMasterSecret {
- extensions = append(extensions, &extension.UseExtendedMasterSecret{
- Supported: true,
- })
- }
- if len(cfg.serverName) > 0 {
- extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
- }
- if len(cfg.supportedProtocols) > 0 {
- extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
- }
- return []*packet{
- {
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Version: protocol.Version1_2,
- },
- Content: &handshake.Handshake{
- Message: &handshake.MessageClientHello{
- Version: protocol.Version1_2,
- SessionID: state.SessionID,
- Cookie: state.cookie,
- Random: state.localRandom,
- CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
- CompressionMethods: defaultCompressionMethods(),
- Extensions: extensions,
- },
- },
- },
- },
- }, nil, nil
- }
|