upload_queue.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package splithttp
// upload_queue is a specialized priority queue plus a channel, used to reorder
// generic packets by their sequence number.
  4. import (
  5. "container/heap"
  6. "io"
  7. "sync"
  8. "github.com/xtls/xray-core/common/errors"
  9. )
  10. type Packet struct {
  11. Payload []byte
  12. Seq uint64
  13. }
  14. type uploadQueue struct {
  15. pushedPackets chan Packet
  16. writeCloseMutex sync.Mutex
  17. heap uploadHeap
  18. nextSeq uint64
  19. closed bool
  20. maxPackets int
  21. }
  22. func NewUploadQueue(maxPackets int) *uploadQueue {
  23. return &uploadQueue{
  24. pushedPackets: make(chan Packet, maxPackets),
  25. heap: uploadHeap{},
  26. nextSeq: 0,
  27. closed: false,
  28. maxPackets: maxPackets,
  29. }
  30. }
  31. func (h *uploadQueue) Push(p Packet) error {
  32. h.writeCloseMutex.Lock()
  33. defer h.writeCloseMutex.Unlock()
  34. if h.closed {
  35. return errors.New("splithttp packet queue closed")
  36. }
  37. h.pushedPackets <- p
  38. return nil
  39. }
  40. func (h *uploadQueue) Close() error {
  41. h.writeCloseMutex.Lock()
  42. defer h.writeCloseMutex.Unlock()
  43. if !h.closed {
  44. h.closed = true
  45. close(h.pushedPackets)
  46. }
  47. return nil
  48. }
  49. func (h *uploadQueue) Read(b []byte) (int, error) {
  50. if h.closed {
  51. return 0, io.EOF
  52. }
  53. if len(h.heap) == 0 {
  54. packet, more := <-h.pushedPackets
  55. if !more {
  56. return 0, io.EOF
  57. }
  58. heap.Push(&h.heap, packet)
  59. }
  60. for len(h.heap) > 0 {
  61. packet := heap.Pop(&h.heap).(Packet)
  62. n := 0
  63. if packet.Seq == h.nextSeq {
  64. copy(b, packet.Payload)
  65. n = min(len(b), len(packet.Payload))
  66. if n < len(packet.Payload) {
  67. // partial read
  68. packet.Payload = packet.Payload[n:]
  69. heap.Push(&h.heap, packet)
  70. } else {
  71. h.nextSeq = packet.Seq + 1
  72. }
  73. return n, nil
  74. }
  75. // misordered packet
  76. if packet.Seq > h.nextSeq {
  77. if len(h.heap) > h.maxPackets {
  78. // the "reassembly buffer" is too large, and we want to
  79. // constrain memory usage somehow. let's tear down the
  80. // connection, and hope the application retries.
  81. return 0, errors.New("packet queue is too large")
  82. }
  83. heap.Push(&h.heap, packet)
  84. packet2, more := <-h.pushedPackets
  85. if !more {
  86. return 0, io.EOF
  87. }
  88. heap.Push(&h.heap, packet2)
  89. }
  90. }
  91. return 0, nil
  92. }
  93. // heap code directly taken from https://pkg.go.dev/container/heap
  94. type uploadHeap []Packet
  95. func (h uploadHeap) Len() int { return len(h) }
  96. func (h uploadHeap) Less(i, j int) bool { return h[i].Seq < h[j].Seq }
  97. func (h uploadHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  98. func (h *uploadHeap) Push(x any) {
  99. // Push and Pop use pointer receivers because they modify the slice's length,
  100. // not just its contents.
  101. *h = append(*h, x.(Packet))
  102. }
  103. func (h *uploadHeap) Pop() any {
  104. old := *h
  105. n := len(old)
  106. x := old[n-1]
  107. *h = old[0 : n-1]
  108. return x
  109. }