// flight4bhandler.go (4.5 KB)

  1. package dtls
  2. import (
  3. "bytes"
  4. "context"
  5. "github.com/pion/dtls/v2/pkg/crypto/prf"
  6. "github.com/pion/dtls/v2/pkg/protocol"
  7. "github.com/pion/dtls/v2/pkg/protocol/alert"
  8. "github.com/pion/dtls/v2/pkg/protocol/extension"
  9. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  10. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  11. )
  12. func flight4bParse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  13. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  14. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
  15. )
  16. if !ok {
  17. // No valid message received. Keep reading
  18. return 0, nil, nil
  19. }
  20. var finished *handshake.MessageFinished
  21. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  22. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  23. }
  24. plainText := cache.pullAndMerge(
  25. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  26. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  27. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  28. )
  29. expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  30. if err != nil {
  31. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  32. }
  33. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  34. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  35. }
  36. // Other party may re-transmit the last flight. Keep state to be flight4b.
  37. return flight4b, nil, nil
  38. }
  39. func flight4bGenerate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  40. var pkts []*packet
  41. extensions := []extension.Extension{&extension.RenegotiationInfo{
  42. RenegotiatedConnection: 0,
  43. }}
  44. if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  45. cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
  46. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  47. Supported: true,
  48. })
  49. }
  50. if state.srtpProtectionProfile != 0 {
  51. extensions = append(extensions, &extension.UseSRTP{
  52. ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
  53. })
  54. }
  55. selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
  56. if err != nil {
  57. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
  58. }
  59. if selectedProto != "" {
  60. extensions = append(extensions, &extension.ALPN{
  61. ProtocolNameList: []string{selectedProto},
  62. })
  63. state.NegotiatedProtocol = selectedProto
  64. }
  65. cipherSuiteID := uint16(state.cipherSuite.ID())
  66. serverHello := &handshake.Handshake{
  67. Message: &handshake.MessageServerHello{
  68. Version: protocol.Version1_2,
  69. Random: state.localRandom,
  70. SessionID: state.SessionID,
  71. CipherSuiteID: &cipherSuiteID,
  72. CompressionMethod: defaultCompressionMethods()[0],
  73. Extensions: extensions,
  74. },
  75. }
  76. serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
  77. if len(state.localVerifyData) == 0 {
  78. plainText := cache.pullAndMerge(
  79. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  80. )
  81. raw, err := serverHello.Marshal()
  82. if err != nil {
  83. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  84. }
  85. plainText = append(plainText, raw...)
  86. state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  87. if err != nil {
  88. return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  89. }
  90. }
  91. pkts = append(pkts,
  92. &packet{
  93. record: &recordlayer.RecordLayer{
  94. Header: recordlayer.Header{
  95. Version: protocol.Version1_2,
  96. },
  97. Content: serverHello,
  98. },
  99. },
  100. &packet{
  101. record: &recordlayer.RecordLayer{
  102. Header: recordlayer.Header{
  103. Version: protocol.Version1_2,
  104. },
  105. Content: &protocol.ChangeCipherSpec{},
  106. },
  107. },
  108. &packet{
  109. record: &recordlayer.RecordLayer{
  110. Header: recordlayer.Header{
  111. Version: protocol.Version1_2,
  112. Epoch: 1,
  113. },
  114. Content: &handshake.Handshake{
  115. Message: &handshake.MessageFinished{
  116. VerifyData: state.localVerifyData,
  117. },
  118. },
  119. },
  120. shouldEncrypt: true,
  121. resetLocalSequenceNumber: true,
  122. },
  123. )
  124. return pkts, nil, nil
  125. }