flight1handler.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  7. "github.com/pion/dtls/v2/pkg/protocol"
  8. "github.com/pion/dtls/v2/pkg/protocol/alert"
  9. "github.com/pion/dtls/v2/pkg/protocol/extension"
  10. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  11. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  12. )
  13. func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  14. // HelloVerifyRequest can be skipped by the server,
  15. // so allow ServerHello during flight1 also
  16. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  17. handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
  18. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
  19. )
  20. if !ok {
  21. // No valid message received. Keep reading
  22. return 0, nil, nil
  23. }
  24. if _, ok := msgs[handshake.TypeServerHello]; ok {
  25. // Flight1 and flight2 were skipped.
  26. // Parse as flight3.
  27. return flight3Parse(ctx, c, state, cache, cfg)
  28. }
  29. if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
  30. // DTLS 1.2 clients must not assume that the server will use the protocol version
  31. // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
  32. if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
  33. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  34. }
  35. state.cookie = append([]byte{}, h.Cookie...)
  36. state.handshakeRecvSequence = seq
  37. return flight3, nil, nil
  38. }
  39. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  40. }
  41. func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  42. var zeroEpoch uint16
  43. state.localEpoch.Store(zeroEpoch)
  44. state.remoteEpoch.Store(zeroEpoch)
  45. state.namedCurve = defaultNamedCurve
  46. state.cookie = nil
  47. if err := state.localRandom.Populate(); err != nil {
  48. return nil, nil, err
  49. }
  50. if state.isClient && cfg.customClientHelloRandom != nil {
  51. state.localRandom.RandomBytes = cfg.customClientHelloRandom()
  52. }
  53. extensions := []extension.Extension{
  54. &extension.SupportedSignatureAlgorithms{
  55. SignatureHashAlgorithms: cfg.localSignatureSchemes,
  56. },
  57. &extension.RenegotiationInfo{
  58. RenegotiatedConnection: 0,
  59. },
  60. }
  61. var setEllipticCurveCryptographyClientHelloExtensions bool
  62. for _, c := range cfg.localCipherSuites {
  63. if c.ECC() {
  64. setEllipticCurveCryptographyClientHelloExtensions = true
  65. break
  66. }
  67. }
  68. if setEllipticCurveCryptographyClientHelloExtensions {
  69. extensions = append(extensions, []extension.Extension{
  70. &extension.SupportedEllipticCurves{
  71. EllipticCurves: cfg.ellipticCurves,
  72. },
  73. &extension.SupportedPointFormats{
  74. PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  75. },
  76. }...)
  77. }
  78. if len(cfg.localSRTPProtectionProfiles) > 0 {
  79. extensions = append(extensions, &extension.UseSRTP{
  80. ProtectionProfiles: cfg.localSRTPProtectionProfiles,
  81. })
  82. }
  83. if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  84. cfg.extendedMasterSecret == RequireExtendedMasterSecret {
  85. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  86. Supported: true,
  87. })
  88. }
  89. if len(cfg.serverName) > 0 {
  90. extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
  91. }
  92. if len(cfg.supportedProtocols) > 0 {
  93. extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
  94. }
  95. if cfg.sessionStore != nil {
  96. cfg.log.Tracef("[handshake] try to resume session")
  97. if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil {
  98. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  99. } else if s.ID != nil {
  100. cfg.log.Tracef("[handshake] get saved session: %x", s.ID)
  101. state.SessionID = s.ID
  102. state.masterSecret = s.Secret
  103. }
  104. }
  105. return []*packet{
  106. {
  107. record: &recordlayer.RecordLayer{
  108. Header: recordlayer.Header{
  109. Version: protocol.Version1_2,
  110. },
  111. Content: &handshake.Handshake{
  112. Message: &handshake.MessageClientHello{
  113. Version: protocol.Version1_2,
  114. SessionID: state.SessionID,
  115. Cookie: state.cookie,
  116. Random: state.localRandom,
  117. CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
  118. CompressionMethods: defaultCompressionMethods(),
  119. Extensions: extensions,
  120. },
  121. },
  122. },
  123. },
  124. }, nil, nil
  125. }