fragment_buffer.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "github.com/pion/dtls/v2/pkg/protocol"
  6. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  7. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  8. )
  9. // 2 megabytes
  10. const fragmentBufferMaxSize = 2000000
  11. type fragment struct {
  12. recordLayerHeader recordlayer.Header
  13. handshakeHeader handshake.Header
  14. data []byte
  15. }
  16. type fragmentBuffer struct {
  17. // map of MessageSequenceNumbers that hold slices of fragments
  18. cache map[uint16][]*fragment
  19. currentMessageSequenceNumber uint16
  20. }
  21. func newFragmentBuffer() *fragmentBuffer {
  22. return &fragmentBuffer{cache: map[uint16][]*fragment{}}
  23. }
  24. // current total size of buffer
  25. func (f *fragmentBuffer) size() int {
  26. size := 0
  27. for i := range f.cache {
  28. for j := range f.cache[i] {
  29. size += len(f.cache[i][j].data)
  30. }
  31. }
  32. return size
  33. }
  34. // Attempts to push a DTLS packet to the fragmentBuffer
  35. // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
  36. // when an error returns it is fatal, and the DTLS connection should be stopped
  37. func (f *fragmentBuffer) push(buf []byte) (bool, error) {
  38. if f.size()+len(buf) >= fragmentBufferMaxSize {
  39. return false, errFragmentBufferOverflow
  40. }
  41. frag := new(fragment)
  42. if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
  43. return false, err
  44. }
  45. // fragment isn't a handshake, we don't need to handle it
  46. if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
  47. return false, nil
  48. }
  49. for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
  50. if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
  51. return false, err
  52. }
  53. if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
  54. f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
  55. }
  56. // end index should be the length of handshake header but if the handshake
  57. // was fragmented, we should keep them all
  58. end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
  59. if size := len(buf); end > size {
  60. end = size
  61. }
  62. // Discard all headers, when rebuilding the packet we will re-build
  63. frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
  64. f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
  65. buf = buf[end:]
  66. }
  67. return true, nil
  68. }
  69. func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
  70. frags, ok := f.cache[f.currentMessageSequenceNumber]
  71. if !ok {
  72. return nil, 0
  73. }
  74. // Go doesn't support recursive lambdas
  75. var appendMessage func(targetOffset uint32) bool
  76. rawMessage := []byte{}
  77. appendMessage = func(targetOffset uint32) bool {
  78. for _, f := range frags {
  79. if f.handshakeHeader.FragmentOffset == targetOffset {
  80. fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
  81. if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 {
  82. if !appendMessage(fragmentEnd) {
  83. return false
  84. }
  85. }
  86. rawMessage = append(f.data, rawMessage...)
  87. return true
  88. }
  89. }
  90. return false
  91. }
  92. // Recursively collect up
  93. if !appendMessage(0) {
  94. return nil, 0
  95. }
  96. firstHeader := frags[0].handshakeHeader
  97. firstHeader.FragmentOffset = 0
  98. firstHeader.FragmentLength = firstHeader.Length
  99. rawHeader, err := firstHeader.Marshal()
  100. if err != nil {
  101. return nil, 0
  102. }
  103. messageEpoch := frags[0].recordLayerHeader.Epoch
  104. delete(f.cache, f.currentMessageSequenceNumber)
  105. f.currentMessageSequenceNumber++
  106. return append(rawHeader, rawMessage...), messageEpoch
  107. }