handshake_cache.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "sync"
  6. "github.com/pion/dtls/v2/pkg/crypto/prf"
  7. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  8. )
// handshakeCacheItem is one handshake message recorded by push, together with
// the metadata that handshakeCachePullRule values are matched against.
type handshakeCacheItem struct {
	typ             handshake.Type // handshake message type
	isClient        bool           // sender side as given to push (true = client)
	epoch           uint16         // epoch the message was pushed under
	messageSequence uint16         // used to pick the newest entry when several match a rule
	data            []byte         // private copy of the raw bytes passed to push
}
// handshakeCachePullRule selects cached messages by type, epoch and sender
// side. Field order matters: callers build this struct with positional
// composite literals.
type handshakeCachePullRule struct {
	typ      handshake.Type
	epoch    uint16
	isClient bool
	// optional marks a rule whose absence fullPullMap tolerates; pull does
	// not consult it and simply returns nil for unmatched rules.
	optional bool
}
// handshakeCache records handshake messages in the order they are pushed so
// they can later be looked up by rule.
type handshakeCache struct {
	cache []*handshakeCacheItem // pushed messages, oldest first
	mu    sync.Mutex            // guards cache; taken by push, pull and fullPullMap
}
  26. func newHandshakeCache() *handshakeCache {
  27. return &handshakeCache{}
  28. }
  29. func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
  30. h.mu.Lock()
  31. defer h.mu.Unlock()
  32. h.cache = append(h.cache, &handshakeCacheItem{
  33. data: append([]byte{}, data...),
  34. epoch: epoch,
  35. messageSequence: messageSequence,
  36. typ: typ,
  37. isClient: isClient,
  38. })
  39. }
  40. // returns a list handshakes that match the requested rules
  41. // the list will contain null entries for rules that can't be satisfied
  42. // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies)
  43. func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem {
  44. h.mu.Lock()
  45. defer h.mu.Unlock()
  46. out := make([]*handshakeCacheItem, len(rules))
  47. for i, r := range rules {
  48. for _, c := range h.cache {
  49. if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
  50. switch {
  51. case out[i] == nil:
  52. out[i] = c
  53. case out[i].messageSequence < c.messageSequence:
  54. out[i] = c
  55. }
  56. }
  57. }
  58. }
  59. return out
  60. }
  61. // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
  62. func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) {
  63. h.mu.Lock()
  64. defer h.mu.Unlock()
  65. ci := make(map[handshake.Type]*handshakeCacheItem)
  66. for _, r := range rules {
  67. var item *handshakeCacheItem
  68. for _, c := range h.cache {
  69. if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
  70. switch {
  71. case item == nil:
  72. item = c
  73. case item.messageSequence < c.messageSequence:
  74. item = c
  75. }
  76. }
  77. }
  78. if !r.optional && item == nil {
  79. // Missing mandatory message.
  80. return startSeq, nil, false
  81. }
  82. ci[r.typ] = item
  83. }
  84. out := make(map[handshake.Type]handshake.Message)
  85. seq := startSeq
  86. for _, r := range rules {
  87. t := r.typ
  88. i := ci[t]
  89. if i == nil {
  90. continue
  91. }
  92. var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm
  93. if cipherSuite != nil {
  94. keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm()
  95. }
  96. rawHandshake := &handshake.Handshake{
  97. KeyExchangeAlgorithm: keyExchangeAlgorithm,
  98. }
  99. if err := rawHandshake.Unmarshal(i.data); err != nil {
  100. return startSeq, nil, false
  101. }
  102. if uint16(seq) != rawHandshake.Header.MessageSequence {
  103. // There is a gap. Some messages are not arrived.
  104. return startSeq, nil, false
  105. }
  106. seq++
  107. out[t] = rawHandshake.Message
  108. }
  109. return seq, out, true
  110. }
  111. // pullAndMerge calls pull and then merges the results, ignoring any null entries
  112. func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
  113. merged := []byte{}
  114. for _, p := range h.pull(rules...) {
  115. if p != nil {
  116. merged = append(merged, p.data...)
  117. }
  118. }
  119. return merged
  120. }
  121. // sessionHash returns the session hash for Extended Master Secret support
  122. // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
  123. func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
  124. merged := []byte{}
  125. // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
  126. handshakeBuffer := h.pull(
  127. handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false},
  128. handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false},
  129. handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false},
  130. handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false},
  131. handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false},
  132. handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false},
  133. handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false},
  134. handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false},
  135. )
  136. for _, p := range handshakeBuffer {
  137. if p == nil {
  138. continue
  139. }
  140. merged = append(merged, p.data...)
  141. }
  142. for _, a := range additional {
  143. merged = append(merged, a...)
  144. }
  145. hash := hf()
  146. if _, err := hash.Write(merged); err != nil {
  147. return []byte{}, err
  148. }
  149. return hash.Sum(nil), nil
  150. }