| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
- import (
- "context"
- "crypto/rand"
- "github.com/pion/dtls/v2/pkg/crypto/elliptic"
- "github.com/pion/dtls/v2/pkg/protocol"
- "github.com/pion/dtls/v2/pkg/protocol/alert"
- "github.com/pion/dtls/v2/pkg/protocol/extension"
- "github.com/pion/dtls/v2/pkg/protocol/handshake"
- )
- func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
- seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite,
- handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
- )
- if !ok {
- // No valid message received. Keep reading
- return 0, nil, nil
- }
- state.handshakeRecvSequence = seq
- var clientHello *handshake.MessageClientHello
- // Validate type
- if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
- }
- if !clientHello.Version.Equal(protocol.Version1_2) {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
- }
- state.remoteRandom = clientHello.Random
- cipherSuites := []CipherSuite{}
- for _, id := range clientHello.CipherSuiteIDs {
- if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
- cipherSuites = append(cipherSuites, c)
- }
- }
- if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
- }
- for _, val := range clientHello.Extensions {
- switch e := val.(type) {
- case *extension.SupportedEllipticCurves:
- if len(e.EllipticCurves) == 0 {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
- }
- state.namedCurve = e.EllipticCurves[0]
- case *extension.UseSRTP:
- profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
- if !ok {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
- }
- state.setSRTPProtectionProfile(profile)
- case *extension.UseExtendedMasterSecret:
- if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
- state.extendedMasterSecret = true
- }
- case *extension.ServerName:
- state.serverName = e.ServerName // remote server name
- case *extension.ALPN:
- state.peerSupportedProtocols = e.ProtocolNameList
- }
- }
- if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
- }
- if state.localKeypair == nil {
- var err error
- state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
- if err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
- }
- }
- nextFlight := flight2
- if cfg.insecureSkipHelloVerify {
- nextFlight = flight4
- }
- return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight)
- }
- func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) {
- if len(sessionID) > 0 && cfg.sessionStore != nil {
- if s, err := cfg.sessionStore.Get(sessionID); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- } else if s.ID != nil {
- cfg.log.Tracef("[handshake] resume session: %x", sessionID)
- state.SessionID = sessionID
- state.masterSecret = s.Secret
- if err := state.initCipherSuite(); err != nil {
- return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
- }
- clientRandom := state.localRandom.MarshalFixed()
- cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
- return flight4b, nil, nil
- }
- }
- return next, nil, nil
- }
- func flight0Generate(_ context.Context, _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
- // Initialize
- if !cfg.insecureSkipHelloVerify {
- state.cookie = make([]byte, cookieLength)
- if _, err := rand.Read(state.cookie); err != nil {
- return nil, nil, err
- }
- }
- var zeroEpoch uint16
- state.localEpoch.Store(zeroEpoch)
- state.remoteEpoch.Store(zeroEpoch)
- state.namedCurve = defaultNamedCurve
- if err := state.localRandom.Populate(); err != nil {
- return nil, nil, err
- }
- return nil, nil, nil
- }
|