flight0handler.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "crypto/rand"
  7. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  8. "github.com/pion/dtls/v2/pkg/protocol"
  9. "github.com/pion/dtls/v2/pkg/protocol/alert"
  10. "github.com/pion/dtls/v2/pkg/protocol/extension"
  11. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  12. )
  13. func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  14. seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite,
  15. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  16. )
  17. if !ok {
  18. // No valid message received. Keep reading
  19. return 0, nil, nil
  20. }
  21. state.handshakeRecvSequence = seq
  22. var clientHello *handshake.MessageClientHello
  23. // Validate type
  24. if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
  25. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  26. }
  27. if !clientHello.Version.Equal(protocol.Version1_2) {
  28. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  29. }
  30. state.remoteRandom = clientHello.Random
  31. cipherSuites := []CipherSuite{}
  32. for _, id := range clientHello.CipherSuiteIDs {
  33. if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
  34. cipherSuites = append(cipherSuites, c)
  35. }
  36. }
  37. if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
  38. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
  39. }
  40. for _, val := range clientHello.Extensions {
  41. switch e := val.(type) {
  42. case *extension.SupportedEllipticCurves:
  43. if len(e.EllipticCurves) == 0 {
  44. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
  45. }
  46. state.namedCurve = e.EllipticCurves[0]
  47. case *extension.UseSRTP:
  48. profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
  49. if !ok {
  50. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
  51. }
  52. state.setSRTPProtectionProfile(profile)
  53. case *extension.UseExtendedMasterSecret:
  54. if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
  55. state.extendedMasterSecret = true
  56. }
  57. case *extension.ServerName:
  58. state.serverName = e.ServerName // remote server name
  59. case *extension.ALPN:
  60. state.peerSupportedProtocols = e.ProtocolNameList
  61. }
  62. }
  63. if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
  64. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
  65. }
  66. if state.localKeypair == nil {
  67. var err error
  68. state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
  69. if err != nil {
  70. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  71. }
  72. }
  73. nextFlight := flight2
  74. if cfg.insecureSkipHelloVerify {
  75. nextFlight = flight4
  76. }
  77. return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight)
  78. }
  79. func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) {
  80. if len(sessionID) > 0 && cfg.sessionStore != nil {
  81. if s, err := cfg.sessionStore.Get(sessionID); err != nil {
  82. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  83. } else if s.ID != nil {
  84. cfg.log.Tracef("[handshake] resume session: %x", sessionID)
  85. state.SessionID = sessionID
  86. state.masterSecret = s.Secret
  87. if err := state.initCipherSuite(); err != nil {
  88. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  89. }
  90. clientRandom := state.localRandom.MarshalFixed()
  91. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  92. return flight4b, nil, nil
  93. }
  94. }
  95. return next, nil, nil
  96. }
  97. func flight0Generate(_ context.Context, _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  98. // Initialize
  99. if !cfg.insecureSkipHelloVerify {
  100. state.cookie = make([]byte, cookieLength)
  101. if _, err := rand.Read(state.cookie); err != nil {
  102. return nil, nil, err
  103. }
  104. }
  105. var zeroEpoch uint16
  106. state.localEpoch.Store(zeroEpoch)
  107. state.remoteEpoch.Store(zeroEpoch)
  108. state.namedCurve = defaultNamedCurve
  109. if err := state.localRandom.Populate(); err != nil {
  110. return nil, nil, err
  111. }
  112. return nil, nil, nil
  113. }