send_stream.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. package quic
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/Psiphon-Labs/quic-go/internal/ackhandler"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/utils"
  12. "github.com/Psiphon-Labs/quic-go/internal/wire"
  13. )
// sendStreamI is the internal view of a send stream: the public SendStream API
// plus the hooks the rest of the package uses to drive the stream.
type sendStreamI interface {
	SendStream
	// handleStopSendingFrame processes a STOP_SENDING frame from the peer by
	// canceling the send side with the frame's error code.
	handleStopSendingFrame(*wire.StopSendingFrame)
	// hasData reports whether a pending Write still has unconsumed data.
	hasData() bool
	// popStreamFrame returns the next STREAM frame to send (at most maxBytes long,
	// including the frame header). ok is false when no frame was produced;
	// hasMore reports whether more data is still waiting to be sent.
	popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
	// closeForShutdown aborts the stream without notifying the peer.
	closeForShutdown(error)
	// updateSendWindow passes a new flow-control limit to the stream's flow controller.
	updateSendWindow(protocol.ByteCount)
}
// sendStream implements the sending half of a QUIC stream.
// mutex protects the mutable state; streamID, sender, writeChan, writeOnce,
// flowController, ctx and ctxCancel are assigned once in newSendStream.
type sendStream struct {
	mutex sync.Mutex
	// Number of STREAM frames handed out by popStreamFrame that have not yet been
	// reported as acked (OnAcked) or lost (OnLost).
	numOutstandingFrames int64
	// Frames reported lost by OnLost, waiting to be sent again.
	retransmissionQueue []*wire.StreamFrame
	// ctx is canceled by Close, by a write cancellation and by closeForShutdown.
	ctx       context.Context
	ctxCancel context.CancelCauseFunc
	streamID  protocol.StreamID
	sender    streamSender
	// Stream offset of the next byte of new data; advanced as frames are popped.
	writeOffset protocol.ByteCount
	// Set by cancelWriteImpl (local CancelWrite or peer STOP_SENDING).
	cancelWriteErr error
	// Set by closeForShutdown.
	closeForShutdownErr error
	finishedWriting     bool   // set once Close() is called
	finSent             bool   // set when a STREAM_FRAME with FIN bit has been sent
	completed           bool   // set when this stream has been reported to the streamSender as completed
	dataForWriting      []byte // during a Write() call, this slice is the part of p that still needs to be sent out
	// Data already copied out of a Write call, waiting to be popped as one STREAM frame.
	nextFrame *wire.StreamFrame
	// Capacity-1 channel used by signalWrite to wake up a blocked Write.
	writeChan chan struct{}
	// Capacity-1 channel used as a semaphore so that only one Write runs at a time.
	writeOnce chan struct{}
	// Write deadline set via SetWriteDeadline; the zero value means no deadline.
	deadline       time.Time
	flowController flowcontrol.StreamFlowController
}
  43. var (
  44. _ SendStream = &sendStream{}
  45. _ sendStreamI = &sendStream{}
  46. )
  47. func newSendStream(
  48. streamID protocol.StreamID,
  49. sender streamSender,
  50. flowController flowcontrol.StreamFlowController,
  51. ) *sendStream {
  52. s := &sendStream{
  53. streamID: streamID,
  54. sender: sender,
  55. flowController: flowController,
  56. writeChan: make(chan struct{}, 1),
  57. writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
  58. }
  59. s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
  60. return s
  61. }
  62. func (s *sendStream) StreamID() protocol.StreamID {
  63. return s.streamID // same for receiveStream and sendStream
  64. }
// Write queues p for sending on the stream. It blocks until all of p has been either
// packed into STREAM frames by popStreamFrame or copied into the buffered frame
// s.nextFrame, or until the stream is canceled, closed for shutdown, or the write
// deadline passes. It returns the number of bytes of p that were accepted, together
// with the error that stopped the write early (if any).
func (s *sendStream) Write(p []byte) (int, error) {
	// Concurrent use of Write is not permitted (and doesn't make any sense),
	// but sometimes people do it anyway.
	// Make sure that we only execute one call at any given time to avoid hard to debug failures.
	s.writeOnce <- struct{}{}
	defer func() { <-s.writeOnce }()
	// The mutex is held on every return path: the loop below releases it only while
	// notifying the sender / waiting, and re-acquires it before continuing or breaking.
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if s.finishedWriting {
		return 0, fmt.Errorf("write on closed stream %d", s.streamID)
	}
	if s.cancelWriteErr != nil {
		return 0, s.cancelWriteErr
	}
	if s.closeForShutdownErr != nil {
		return 0, s.closeForShutdownErr
	}
	if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
		return 0, errDeadline
	}
	if len(p) == 0 {
		return 0, nil
	}
	// Expose p to popStreamFrame (via getDataForWriting), which consumes it from the front.
	s.dataForWriting = p
	var (
		deadlineTimer  *utils.Timer
		bytesWritten   int
		notifiedSender bool
	)
	for {
		var copied bool
		var deadline time.Time
		// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
		// which can then be popped the next time we assemble a packet.
		// This allows us to return Write() when all data but x bytes have been sent out.
		// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
		// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
		if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
			if s.nextFrame == nil {
				// Start a new buffered frame at the current write offset.
				f := wire.GetStreamFrame()
				f.Offset = s.writeOffset
				f.StreamID = s.streamID
				f.DataLenPresent = true
				f.Data = f.Data[:len(s.dataForWriting)]
				copy(f.Data, s.dataForWriting)
				s.nextFrame = f
			} else {
				// Append the remaining data to the frame that is already buffered.
				l := len(s.nextFrame.Data)
				s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
				copy(s.nextFrame.Data[l:], s.dataForWriting)
			}
			s.dataForWriting = nil
			bytesWritten = len(p)
			copied = true
		} else {
			// Whatever popStreamFrame has consumed from p so far counts as written.
			bytesWritten = len(p) - len(s.dataForWriting)
			deadline = s.deadline
			if !deadline.IsZero() {
				if !time.Now().Before(deadline) {
					s.dataForWriting = nil
					return bytesWritten, errDeadline
				}
				// One timer per Write call, re-armed on every iteration because
				// SetWriteDeadline may change s.deadline while we are waiting.
				if deadlineTimer == nil {
					deadlineTimer = utils.NewTimer()
					defer deadlineTimer.Stop()
				}
				deadlineTimer.Reset(deadline)
			}
			if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
				break
			}
		}
		s.mutex.Unlock()
		if !notifiedSender {
			s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
			notifiedSender = true
		}
		if copied {
			s.mutex.Lock()
			break
		}
		// Wait for a signalWrite (data consumed, cancellation, shutdown or a new
		// deadline) or, if a deadline is set, for the deadline timer to fire.
		if deadline.IsZero() {
			<-s.writeChan
		} else {
			select {
			case <-s.writeChan:
			case <-deadlineTimer.Chan():
				deadlineTimer.SetRead()
			}
		}
		s.mutex.Lock()
	}
	if bytesWritten == len(p) {
		return bytesWritten, nil
	}
	// The write stopped early; shutdown takes precedence over cancellation.
	if s.closeForShutdownErr != nil {
		return bytesWritten, s.closeForShutdownErr
	} else if s.cancelWriteErr != nil {
		return bytesWritten, s.cancelWriteErr
	}
	return bytesWritten, nil
}
  167. func (s *sendStream) canBufferStreamFrame() bool {
  168. var l protocol.ByteCount
  169. if s.nextFrame != nil {
  170. l = s.nextFrame.DataLen()
  171. }
  172. return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
  173. }
  174. // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
  175. // maxBytes is the maximum length this frame (including frame header) will have.
  176. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
  177. s.mutex.Lock()
  178. f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
  179. if f != nil {
  180. s.numOutstandingFrames++
  181. }
  182. s.mutex.Unlock()
  183. if f == nil {
  184. return ackhandler.StreamFrame{}, false, hasMoreData
  185. }
  186. return ackhandler.StreamFrame{
  187. Frame: f,
  188. Handler: (*sendStreamAckHandler)(s),
  189. }, true, hasMoreData
  190. }
  191. func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
  192. if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
  193. return nil, false
  194. }
  195. if len(s.retransmissionQueue) > 0 {
  196. f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
  197. if f != nil || hasMoreRetransmissions {
  198. if f == nil {
  199. return nil, true
  200. }
  201. // We always claim that we have more data to send.
  202. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
  203. return f, true
  204. }
  205. }
  206. if len(s.dataForWriting) == 0 && s.nextFrame == nil {
  207. if s.finishedWriting && !s.finSent {
  208. s.finSent = true
  209. return &wire.StreamFrame{
  210. StreamID: s.streamID,
  211. Offset: s.writeOffset,
  212. DataLenPresent: true,
  213. Fin: true,
  214. }, false
  215. }
  216. return nil, false
  217. }
  218. sendWindow := s.flowController.SendWindowSize()
  219. if sendWindow == 0 {
  220. if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
  221. s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
  222. StreamID: s.streamID,
  223. MaximumStreamData: offset,
  224. })
  225. return nil, false
  226. }
  227. return nil, true
  228. }
  229. f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
  230. if dataLen := f.DataLen(); dataLen > 0 {
  231. s.writeOffset += f.DataLen()
  232. s.flowController.AddBytesSent(f.DataLen())
  233. }
  234. f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
  235. if f.Fin {
  236. s.finSent = true
  237. }
  238. return f, hasMoreData
  239. }
  240. func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) {
  241. if s.nextFrame != nil {
  242. nextFrame := s.nextFrame
  243. s.nextFrame = nil
  244. maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
  245. if nextFrame.DataLen() > maxDataLen {
  246. s.nextFrame = wire.GetStreamFrame()
  247. s.nextFrame.StreamID = s.streamID
  248. s.nextFrame.Offset = s.writeOffset + maxDataLen
  249. s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
  250. s.nextFrame.DataLenPresent = true
  251. copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
  252. nextFrame.Data = nextFrame.Data[:maxDataLen]
  253. } else {
  254. s.signalWrite()
  255. }
  256. return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
  257. }
  258. f := wire.GetStreamFrame()
  259. f.Fin = false
  260. f.StreamID = s.streamID
  261. f.Offset = s.writeOffset
  262. f.DataLenPresent = true
  263. f.Data = f.Data[:0]
  264. hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v)
  265. if len(f.Data) == 0 && !f.Fin {
  266. f.PutBack()
  267. return nil, hasMoreData
  268. }
  269. return f, hasMoreData
  270. }
  271. func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool {
  272. maxDataLen := f.MaxDataLen(maxBytes, v)
  273. if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
  274. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  275. }
  276. s.getDataForWriting(f, min(maxDataLen, sendWindow))
  277. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  278. }
  279. func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) {
  280. f := s.retransmissionQueue[0]
  281. newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
  282. if needsSplit {
  283. return newFrame, true
  284. }
  285. s.retransmissionQueue = s.retransmissionQueue[1:]
  286. return f, len(s.retransmissionQueue) > 0
  287. }
  288. func (s *sendStream) hasData() bool {
  289. s.mutex.Lock()
  290. hasData := len(s.dataForWriting) > 0
  291. s.mutex.Unlock()
  292. return hasData
  293. }
  294. func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
  295. if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
  296. f.Data = f.Data[:len(s.dataForWriting)]
  297. copy(f.Data, s.dataForWriting)
  298. s.dataForWriting = nil
  299. s.signalWrite()
  300. return
  301. }
  302. f.Data = f.Data[:maxBytes]
  303. copy(f.Data, s.dataForWriting)
  304. s.dataForWriting = s.dataForWriting[maxBytes:]
  305. if s.canBufferStreamFrame() {
  306. s.signalWrite()
  307. }
  308. }
  309. func (s *sendStream) isNewlyCompleted() bool {
  310. completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
  311. if completed && !s.completed {
  312. s.completed = true
  313. return true
  314. }
  315. return false
  316. }
  317. func (s *sendStream) Close() error {
  318. s.mutex.Lock()
  319. if s.closeForShutdownErr != nil {
  320. s.mutex.Unlock()
  321. return nil
  322. }
  323. if s.cancelWriteErr != nil {
  324. s.mutex.Unlock()
  325. return fmt.Errorf("close called for canceled stream %d", s.streamID)
  326. }
  327. s.ctxCancel(nil)
  328. s.finishedWriting = true
  329. s.mutex.Unlock()
  330. s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
  331. return nil
  332. }
  333. func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
  334. s.cancelWriteImpl(errorCode, false)
  335. }
  336. // must be called after locking the mutex
  337. func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
  338. s.mutex.Lock()
  339. if s.cancelWriteErr != nil {
  340. s.mutex.Unlock()
  341. return
  342. }
  343. s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
  344. s.ctxCancel(s.cancelWriteErr)
  345. s.numOutstandingFrames = 0
  346. s.retransmissionQueue = nil
  347. newlyCompleted := s.isNewlyCompleted()
  348. s.mutex.Unlock()
  349. s.signalWrite()
  350. s.sender.queueControlFrame(&wire.ResetStreamFrame{
  351. StreamID: s.streamID,
  352. FinalSize: s.writeOffset,
  353. ErrorCode: errorCode,
  354. })
  355. if newlyCompleted {
  356. s.sender.onStreamCompleted(s.streamID)
  357. }
  358. }
  359. func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
  360. s.mutex.Lock()
  361. hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
  362. s.mutex.Unlock()
  363. s.flowController.UpdateSendWindow(limit)
  364. if hasStreamData {
  365. s.sender.onHasStreamData(s.streamID)
  366. }
  367. }
  368. func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
  369. s.cancelWriteImpl(frame.ErrorCode, true)
  370. }
  371. func (s *sendStream) Context() context.Context {
  372. return s.ctx
  373. }
  374. func (s *sendStream) SetWriteDeadline(t time.Time) error {
  375. s.mutex.Lock()
  376. s.deadline = t
  377. s.mutex.Unlock()
  378. s.signalWrite()
  379. return nil
  380. }
  381. // CloseForShutdown closes a stream abruptly.
  382. // It makes Write unblock (and return the error) immediately.
  383. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
  384. func (s *sendStream) closeForShutdown(err error) {
  385. s.mutex.Lock()
  386. s.ctxCancel(err)
  387. s.closeForShutdownErr = err
  388. s.mutex.Unlock()
  389. s.signalWrite()
  390. }
  391. // signalWrite performs a non-blocking send on the writeChan
  392. func (s *sendStream) signalWrite() {
  393. select {
  394. case s.writeChan <- struct{}{}:
  395. default:
  396. }
  397. }
  398. type sendStreamAckHandler sendStream
  399. var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
  400. func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
  401. sf := f.(*wire.StreamFrame)
  402. sf.PutBack()
  403. s.mutex.Lock()
  404. if s.cancelWriteErr != nil {
  405. s.mutex.Unlock()
  406. return
  407. }
  408. s.numOutstandingFrames--
  409. if s.numOutstandingFrames < 0 {
  410. panic("numOutStandingFrames negative")
  411. }
  412. newlyCompleted := (*sendStream)(s).isNewlyCompleted()
  413. s.mutex.Unlock()
  414. if newlyCompleted {
  415. s.sender.onStreamCompleted(s.streamID)
  416. }
  417. }
  418. func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
  419. sf := f.(*wire.StreamFrame)
  420. s.mutex.Lock()
  421. if s.cancelWriteErr != nil {
  422. s.mutex.Unlock()
  423. return
  424. }
  425. sf.DataLenPresent = true
  426. s.retransmissionQueue = append(s.retransmissionQueue, sf)
  427. s.numOutstandingFrames--
  428. if s.numOutstandingFrames < 0 {
  429. panic("numOutStandingFrames negative")
  430. }
  431. s.mutex.Unlock()
  432. s.sender.onHasStreamData(s.streamID)
  433. }