// flight0handler.go (4.4 KB)
  1. package dtls
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  6. "github.com/pion/dtls/v2/pkg/protocol"
  7. "github.com/pion/dtls/v2/pkg/protocol/alert"
  8. "github.com/pion/dtls/v2/pkg/protocol/extension"
  9. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  10. )
  11. func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  12. seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite,
  13. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  14. )
  15. if !ok {
  16. // No valid message received. Keep reading
  17. return 0, nil, nil
  18. }
  19. state.handshakeRecvSequence = seq
  20. var clientHello *handshake.MessageClientHello
  21. // Validate type
  22. if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
  23. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  24. }
  25. if !clientHello.Version.Equal(protocol.Version1_2) {
  26. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  27. }
  28. state.remoteRandom = clientHello.Random
  29. cipherSuites := []CipherSuite{}
  30. for _, id := range clientHello.CipherSuiteIDs {
  31. if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil {
  32. cipherSuites = append(cipherSuites, c)
  33. }
  34. }
  35. if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok {
  36. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
  37. }
  38. for _, val := range clientHello.Extensions {
  39. switch e := val.(type) {
  40. case *extension.SupportedEllipticCurves:
  41. if len(e.EllipticCurves) == 0 {
  42. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves
  43. }
  44. state.namedCurve = e.EllipticCurves[0]
  45. case *extension.UseSRTP:
  46. profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
  47. if !ok {
  48. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile
  49. }
  50. state.srtpProtectionProfile = profile
  51. case *extension.UseExtendedMasterSecret:
  52. if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
  53. state.extendedMasterSecret = true
  54. }
  55. case *extension.ServerName:
  56. state.serverName = e.ServerName // remote server name
  57. case *extension.ALPN:
  58. state.peerSupportedProtocols = e.ProtocolNameList
  59. }
  60. }
  61. if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
  62. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
  63. }
  64. if state.localKeypair == nil {
  65. var err error
  66. state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve)
  67. if err != nil {
  68. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  69. }
  70. }
  71. nextFlight := flight2
  72. if cfg.insecureSkipHelloVerify {
  73. nextFlight = flight4
  74. }
  75. return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight)
  76. }
  77. func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) {
  78. if len(sessionID) > 0 && cfg.sessionStore != nil {
  79. if s, err := cfg.sessionStore.Get(sessionID); err != nil {
  80. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  81. } else if s.ID != nil {
  82. cfg.log.Tracef("[handshake] resume session: %x", sessionID)
  83. state.SessionID = sessionID
  84. state.masterSecret = s.Secret
  85. if err := state.initCipherSuite(); err != nil {
  86. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  87. }
  88. clientRandom := state.localRandom.MarshalFixed()
  89. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  90. return flight4b, nil, nil
  91. }
  92. }
  93. return next, nil, nil
  94. }
  95. func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  96. // Initialize
  97. if !cfg.insecureSkipHelloVerify {
  98. state.cookie = make([]byte, cookieLength)
  99. if _, err := rand.Read(state.cookie); err != nil {
  100. return nil, nil, err
  101. }
  102. }
  103. var zeroEpoch uint16
  104. state.localEpoch.Store(zeroEpoch)
  105. state.remoteEpoch.Store(zeroEpoch)
  106. state.namedCurve = defaultNamedCurve
  107. if err := state.localRandom.Populate(); err != nil {
  108. return nil, nil, err
  109. }
  110. return nil, nil, nil
  111. }