flight1handler.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package dtls
  2. import (
  3. "context"
  4. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  5. "github.com/pion/dtls/v2/pkg/protocol"
  6. "github.com/pion/dtls/v2/pkg/protocol/alert"
  7. "github.com/pion/dtls/v2/pkg/protocol/extension"
  8. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  9. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  10. )
  11. func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  12. // HelloVerifyRequest can be skipped by the server,
  13. // so allow ServerHello during flight1 also
  14. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  15. handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
  16. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
  17. )
  18. if !ok {
  19. // No valid message received. Keep reading
  20. return 0, nil, nil
  21. }
  22. if _, ok := msgs[handshake.TypeServerHello]; ok {
  23. // Flight1 and flight2 were skipped.
  24. // Parse as flight3.
  25. return flight3Parse(ctx, c, state, cache, cfg)
  26. }
  27. if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
  28. // DTLS 1.2 clients must not assume that the server will use the protocol version
  29. // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
  30. if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
  31. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  32. }
  33. state.cookie = append([]byte{}, h.Cookie...)
  34. state.handshakeRecvSequence = seq
  35. return flight3, nil, nil
  36. }
  37. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  38. }
  39. func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  40. var zeroEpoch uint16
  41. state.localEpoch.Store(zeroEpoch)
  42. state.remoteEpoch.Store(zeroEpoch)
  43. state.namedCurve = defaultNamedCurve
  44. state.cookie = nil
  45. if err := state.localRandom.Populate(); err != nil {
  46. return nil, nil, err
  47. }
  48. extensions := []extension.Extension{
  49. &extension.SupportedSignatureAlgorithms{
  50. SignatureHashAlgorithms: cfg.localSignatureSchemes,
  51. },
  52. &extension.RenegotiationInfo{
  53. RenegotiatedConnection: 0,
  54. },
  55. }
  56. var setEllipticCurveCryptographyClientHelloExtensions bool
  57. for _, c := range cfg.localCipherSuites {
  58. if c.ECC() {
  59. setEllipticCurveCryptographyClientHelloExtensions = true
  60. break
  61. }
  62. }
  63. if setEllipticCurveCryptographyClientHelloExtensions {
  64. extensions = append(extensions, []extension.Extension{
  65. &extension.SupportedEllipticCurves{
  66. EllipticCurves: cfg.ellipticCurves,
  67. },
  68. &extension.SupportedPointFormats{
  69. PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  70. },
  71. }...)
  72. }
  73. if len(cfg.localSRTPProtectionProfiles) > 0 {
  74. extensions = append(extensions, &extension.UseSRTP{
  75. ProtectionProfiles: cfg.localSRTPProtectionProfiles,
  76. })
  77. }
  78. if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  79. cfg.extendedMasterSecret == RequireExtendedMasterSecret {
  80. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  81. Supported: true,
  82. })
  83. }
  84. if len(cfg.serverName) > 0 {
  85. extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
  86. }
  87. if len(cfg.supportedProtocols) > 0 {
  88. extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
  89. }
  90. if cfg.sessionStore != nil {
  91. cfg.log.Tracef("[handshake] try to resume session")
  92. if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil {
  93. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  94. } else if s.ID != nil {
  95. cfg.log.Tracef("[handshake] get saved session: %x", s.ID)
  96. state.SessionID = s.ID
  97. state.masterSecret = s.Secret
  98. }
  99. }
  100. return []*packet{
  101. {
  102. record: &recordlayer.RecordLayer{
  103. Header: recordlayer.Header{
  104. Version: protocol.Version1_2,
  105. },
  106. Content: &handshake.Handshake{
  107. Message: &handshake.MessageClientHello{
  108. Version: protocol.Version1_2,
  109. SessionID: state.SessionID,
  110. Cookie: state.cookie,
  111. Random: state.localRandom,
  112. CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
  113. CompressionMethods: defaultCompressionMethods(),
  114. Extensions: extensions,
  115. },
  116. },
  117. },
  118. },
  119. }, nil, nil
  120. }