flight4bhandler.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "context"
  7. "github.com/pion/dtls/v2/pkg/crypto/prf"
  8. "github.com/pion/dtls/v2/pkg/protocol"
  9. "github.com/pion/dtls/v2/pkg/protocol/alert"
  10. "github.com/pion/dtls/v2/pkg/protocol/extension"
  11. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  12. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  13. )
  14. func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  15. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  16. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  17. )
  18. if !ok {
  19. // No valid message received. Keep reading
  20. return 0, nil, nil
  21. }
  22. var finished *handshake.MessageFinished
  23. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  24. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  25. }
  26. plainText := cache.pullAndMerge(
  27. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  28. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  29. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  30. )
  31. expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  32. if err != nil {
  33. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  34. }
  35. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  36. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  37. }
  38. // Other party may re-transmit the last flight. Keep state to be flight4b.
  39. return flight4b, nil, nil
  40. }
  41. func flight4bGenerate(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  42. var pkts []*packet
  43. extensions := []extension.Extension{&extension.RenegotiationInfo{
  44. RenegotiatedConnection: 0,
  45. }}
  46. if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  47. cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
  48. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  49. Supported: true,
  50. })
  51. }
  52. if state.getSRTPProtectionProfile() != 0 {
  53. extensions = append(extensions, &extension.UseSRTP{
  54. ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
  55. })
  56. }
  57. selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
  58. if err != nil {
  59. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
  60. }
  61. if selectedProto != "" {
  62. extensions = append(extensions, &extension.ALPN{
  63. ProtocolNameList: []string{selectedProto},
  64. })
  65. state.NegotiatedProtocol = selectedProto
  66. }
  67. cipherSuiteID := uint16(state.cipherSuite.ID())
  68. serverHello := &handshake.Handshake{
  69. Message: &handshake.MessageServerHello{
  70. Version: protocol.Version1_2,
  71. Random: state.localRandom,
  72. SessionID: state.SessionID,
  73. CipherSuiteID: &cipherSuiteID,
  74. CompressionMethod: defaultCompressionMethods()[0],
  75. Extensions: extensions,
  76. },
  77. }
  78. serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
  79. if len(state.localVerifyData) == 0 {
  80. plainText := cache.pullAndMerge(
  81. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  82. )
  83. raw, err := serverHello.Marshal()
  84. if err != nil {
  85. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  86. }
  87. plainText = append(plainText, raw...)
  88. state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  89. if err != nil {
  90. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  91. }
  92. }
  93. pkts = append(pkts,
  94. &packet{
  95. record: &recordlayer.RecordLayer{
  96. Header: recordlayer.Header{
  97. Version: protocol.Version1_2,
  98. },
  99. Content: serverHello,
  100. },
  101. },
  102. &packet{
  103. record: &recordlayer.RecordLayer{
  104. Header: recordlayer.Header{
  105. Version: protocol.Version1_2,
  106. },
  107. Content: &protocol.ChangeCipherSpec{},
  108. },
  109. },
  110. &packet{
  111. record: &recordlayer.RecordLayer{
  112. Header: recordlayer.Header{
  113. Version: protocol.Version1_2,
  114. Epoch: 1,
  115. },
  116. Content: &handshake.Handshake{
  117. Message: &handshake.MessageFinished{
  118. VerifyData: state.localVerifyData,
  119. },
  120. },
  121. },
  122. shouldEncrypt: true,
  123. resetLocalSequenceNumber: true,
  124. },
  125. )
  126. return pkts, nil, nil
  127. }