send_stream.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. package quic
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/Psiphon-Labs/quic-go/internal/ackhandler"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/utils"
  12. "github.com/Psiphon-Labs/quic-go/internal/wire"
  13. )
  14. type sendStreamI interface {
  15. SendStream
  16. handleStopSendingFrame(*wire.StopSendingFrame)
  17. hasData() bool
  18. popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool)
  19. closeForShutdown(error)
  20. updateSendWindow(protocol.ByteCount)
  21. }
  22. type sendStream struct {
  23. mutex sync.Mutex
  24. numOutstandingFrames int64
  25. retransmissionQueue []*wire.StreamFrame
  26. ctx context.Context
  27. ctxCancel context.CancelFunc
  28. streamID protocol.StreamID
  29. sender streamSender
  30. writeOffset protocol.ByteCount
  31. cancelWriteErr error
  32. closeForShutdownErr error
  33. closedForShutdown bool // set when CloseForShutdown() is called
  34. finishedWriting bool // set once Close() is called
  35. canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
  36. finSent bool // set when a STREAM_FRAME with FIN bit has been sent
  37. completed bool // set when this stream has been reported to the streamSender as completed
  38. dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
  39. nextFrame *wire.StreamFrame
  40. writeChan chan struct{}
  41. writeOnce chan struct{}
  42. deadline time.Time
  43. flowController flowcontrol.StreamFlowController
  44. version protocol.VersionNumber
  45. }
  46. var (
  47. _ SendStream = &sendStream{}
  48. _ sendStreamI = &sendStream{}
  49. )
  50. func newSendStream(
  51. streamID protocol.StreamID,
  52. sender streamSender,
  53. flowController flowcontrol.StreamFlowController,
  54. version protocol.VersionNumber,
  55. ) *sendStream {
  56. s := &sendStream{
  57. streamID: streamID,
  58. sender: sender,
  59. flowController: flowController,
  60. writeChan: make(chan struct{}, 1),
  61. writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
  62. version: version,
  63. }
  64. s.ctx, s.ctxCancel = context.WithCancel(context.Background())
  65. return s
  66. }
  67. func (s *sendStream) StreamID() protocol.StreamID {
  68. return s.streamID // same for receiveStream and sendStream
  69. }
  70. func (s *sendStream) Write(p []byte) (int, error) {
  71. // Concurrent use of Write is not permitted (and doesn't make any sense),
  72. // but sometimes people do it anyway.
  73. // Make sure that we only execute one call at any given time to avoid hard to debug failures.
  74. s.writeOnce <- struct{}{}
  75. defer func() { <-s.writeOnce }()
  76. s.mutex.Lock()
  77. defer s.mutex.Unlock()
  78. if s.finishedWriting {
  79. return 0, fmt.Errorf("write on closed stream %d", s.streamID)
  80. }
  81. if s.canceledWrite {
  82. return 0, s.cancelWriteErr
  83. }
  84. if s.closeForShutdownErr != nil {
  85. return 0, s.closeForShutdownErr
  86. }
  87. if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
  88. return 0, errDeadline
  89. }
  90. if len(p) == 0 {
  91. return 0, nil
  92. }
  93. s.dataForWriting = p
  94. var (
  95. deadlineTimer *utils.Timer
  96. bytesWritten int
  97. notifiedSender bool
  98. )
  99. for {
  100. var copied bool
  101. var deadline time.Time
  102. // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
  103. // which can the be popped the next time we assemble a packet.
  104. // This allows us to return Write() when all data but x bytes have been sent out.
  105. // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
  106. // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
  107. if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
  108. if s.nextFrame == nil {
  109. f := wire.GetStreamFrame()
  110. f.Offset = s.writeOffset
  111. f.StreamID = s.streamID
  112. f.DataLenPresent = true
  113. f.Data = f.Data[:len(s.dataForWriting)]
  114. copy(f.Data, s.dataForWriting)
  115. s.nextFrame = f
  116. } else {
  117. l := len(s.nextFrame.Data)
  118. s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
  119. copy(s.nextFrame.Data[l:], s.dataForWriting)
  120. }
  121. s.dataForWriting = nil
  122. bytesWritten = len(p)
  123. copied = true
  124. } else {
  125. bytesWritten = len(p) - len(s.dataForWriting)
  126. deadline = s.deadline
  127. if !deadline.IsZero() {
  128. if !time.Now().Before(deadline) {
  129. s.dataForWriting = nil
  130. return bytesWritten, errDeadline
  131. }
  132. if deadlineTimer == nil {
  133. deadlineTimer = utils.NewTimer()
  134. defer deadlineTimer.Stop()
  135. }
  136. deadlineTimer.Reset(deadline)
  137. }
  138. if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
  139. break
  140. }
  141. }
  142. s.mutex.Unlock()
  143. if !notifiedSender {
  144. s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
  145. notifiedSender = true
  146. }
  147. if copied {
  148. s.mutex.Lock()
  149. break
  150. }
  151. if deadline.IsZero() {
  152. <-s.writeChan
  153. } else {
  154. select {
  155. case <-s.writeChan:
  156. case <-deadlineTimer.Chan():
  157. deadlineTimer.SetRead()
  158. }
  159. }
  160. s.mutex.Lock()
  161. }
  162. if bytesWritten == len(p) {
  163. return bytesWritten, nil
  164. }
  165. if s.closeForShutdownErr != nil {
  166. return bytesWritten, s.closeForShutdownErr
  167. } else if s.cancelWriteErr != nil {
  168. return bytesWritten, s.cancelWriteErr
  169. }
  170. return bytesWritten, nil
  171. }
  172. func (s *sendStream) canBufferStreamFrame() bool {
  173. var l protocol.ByteCount
  174. if s.nextFrame != nil {
  175. l = s.nextFrame.DataLen()
  176. }
  177. return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
  178. }
  179. // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
  180. // maxBytes is the maximum length this frame (including frame header) will have.
  181. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) {
  182. s.mutex.Lock()
  183. f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes)
  184. if f != nil {
  185. s.numOutstandingFrames++
  186. }
  187. s.mutex.Unlock()
  188. if f == nil {
  189. return nil, hasMoreData
  190. }
  191. return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData
  192. }
  193. func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
  194. if s.canceledWrite || s.closeForShutdownErr != nil {
  195. return nil, false
  196. }
  197. if len(s.retransmissionQueue) > 0 {
  198. f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes)
  199. if f != nil || hasMoreRetransmissions {
  200. if f == nil {
  201. return nil, true
  202. }
  203. // We always claim that we have more data to send.
  204. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
  205. return f, true
  206. }
  207. }
  208. if len(s.dataForWriting) == 0 && s.nextFrame == nil {
  209. if s.finishedWriting && !s.finSent {
  210. s.finSent = true
  211. return &wire.StreamFrame{
  212. StreamID: s.streamID,
  213. Offset: s.writeOffset,
  214. DataLenPresent: true,
  215. Fin: true,
  216. }, false
  217. }
  218. return nil, false
  219. }
  220. sendWindow := s.flowController.SendWindowSize()
  221. if sendWindow == 0 {
  222. if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
  223. s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
  224. StreamID: s.streamID,
  225. MaximumStreamData: offset,
  226. })
  227. return nil, false
  228. }
  229. return nil, true
  230. }
  231. f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow)
  232. if dataLen := f.DataLen(); dataLen > 0 {
  233. s.writeOffset += f.DataLen()
  234. s.flowController.AddBytesSent(f.DataLen())
  235. }
  236. f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
  237. if f.Fin {
  238. s.finSent = true
  239. }
  240. return f, hasMoreData
  241. }
  242. func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) {
  243. if s.nextFrame != nil {
  244. nextFrame := s.nextFrame
  245. s.nextFrame = nil
  246. maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version))
  247. if nextFrame.DataLen() > maxDataLen {
  248. s.nextFrame = wire.GetStreamFrame()
  249. s.nextFrame.StreamID = s.streamID
  250. s.nextFrame.Offset = s.writeOffset + maxDataLen
  251. s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
  252. s.nextFrame.DataLenPresent = true
  253. copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
  254. nextFrame.Data = nextFrame.Data[:maxDataLen]
  255. } else {
  256. s.signalWrite()
  257. }
  258. return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
  259. }
  260. f := wire.GetStreamFrame()
  261. f.Fin = false
  262. f.StreamID = s.streamID
  263. f.Offset = s.writeOffset
  264. f.DataLenPresent = true
  265. f.Data = f.Data[:0]
  266. hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow)
  267. if len(f.Data) == 0 && !f.Fin {
  268. f.PutBack()
  269. return nil, hasMoreData
  270. }
  271. return f, hasMoreData
  272. }
  273. func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool {
  274. maxDataLen := f.MaxDataLen(maxBytes, s.version)
  275. if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
  276. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  277. }
  278. s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow))
  279. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  280. }
  281. func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) {
  282. f := s.retransmissionQueue[0]
  283. newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version)
  284. if needsSplit {
  285. return newFrame, true
  286. }
  287. s.retransmissionQueue = s.retransmissionQueue[1:]
  288. return f, len(s.retransmissionQueue) > 0
  289. }
  290. func (s *sendStream) hasData() bool {
  291. s.mutex.Lock()
  292. hasData := len(s.dataForWriting) > 0
  293. s.mutex.Unlock()
  294. return hasData
  295. }
  296. func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
  297. if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
  298. f.Data = f.Data[:len(s.dataForWriting)]
  299. copy(f.Data, s.dataForWriting)
  300. s.dataForWriting = nil
  301. s.signalWrite()
  302. return
  303. }
  304. f.Data = f.Data[:maxBytes]
  305. copy(f.Data, s.dataForWriting)
  306. s.dataForWriting = s.dataForWriting[maxBytes:]
  307. if s.canBufferStreamFrame() {
  308. s.signalWrite()
  309. }
  310. }
  311. func (s *sendStream) frameAcked(f wire.Frame) {
  312. f.(*wire.StreamFrame).PutBack()
  313. s.mutex.Lock()
  314. if s.canceledWrite {
  315. s.mutex.Unlock()
  316. return
  317. }
  318. s.numOutstandingFrames--
  319. if s.numOutstandingFrames < 0 {
  320. panic("numOutStandingFrames negative")
  321. }
  322. newlyCompleted := s.isNewlyCompleted()
  323. s.mutex.Unlock()
  324. if newlyCompleted {
  325. s.sender.onStreamCompleted(s.streamID)
  326. }
  327. }
  328. func (s *sendStream) isNewlyCompleted() bool {
  329. completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
  330. if completed && !s.completed {
  331. s.completed = true
  332. return true
  333. }
  334. return false
  335. }
  336. func (s *sendStream) queueRetransmission(f wire.Frame) {
  337. sf := f.(*wire.StreamFrame)
  338. sf.DataLenPresent = true
  339. s.mutex.Lock()
  340. if s.canceledWrite {
  341. s.mutex.Unlock()
  342. return
  343. }
  344. s.retransmissionQueue = append(s.retransmissionQueue, sf)
  345. s.numOutstandingFrames--
  346. if s.numOutstandingFrames < 0 {
  347. panic("numOutStandingFrames negative")
  348. }
  349. s.mutex.Unlock()
  350. s.sender.onHasStreamData(s.streamID)
  351. }
  352. func (s *sendStream) Close() error {
  353. s.mutex.Lock()
  354. if s.closedForShutdown {
  355. s.mutex.Unlock()
  356. return nil
  357. }
  358. if s.canceledWrite {
  359. s.mutex.Unlock()
  360. return fmt.Errorf("close called for canceled stream %d", s.streamID)
  361. }
  362. s.ctxCancel()
  363. s.finishedWriting = true
  364. s.mutex.Unlock()
  365. s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
  366. return nil
  367. }
  368. func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
  369. s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
  370. }
  371. // must be called after locking the mutex
  372. func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) {
  373. s.mutex.Lock()
  374. if s.canceledWrite {
  375. s.mutex.Unlock()
  376. return
  377. }
  378. s.ctxCancel()
  379. s.canceledWrite = true
  380. s.cancelWriteErr = writeErr
  381. s.numOutstandingFrames = 0
  382. s.retransmissionQueue = nil
  383. newlyCompleted := s.isNewlyCompleted()
  384. s.mutex.Unlock()
  385. s.signalWrite()
  386. s.sender.queueControlFrame(&wire.ResetStreamFrame{
  387. StreamID: s.streamID,
  388. FinalSize: s.writeOffset,
  389. ErrorCode: errorCode,
  390. })
  391. if newlyCompleted {
  392. s.sender.onStreamCompleted(s.streamID)
  393. }
  394. }
  395. func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
  396. s.mutex.Lock()
  397. hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
  398. s.mutex.Unlock()
  399. s.flowController.UpdateSendWindow(limit)
  400. if hasStreamData {
  401. s.sender.onHasStreamData(s.streamID)
  402. }
  403. }
  404. func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
  405. s.cancelWriteImpl(frame.ErrorCode, &StreamError{
  406. StreamID: s.streamID,
  407. ErrorCode: frame.ErrorCode,
  408. })
  409. }
  410. func (s *sendStream) Context() context.Context {
  411. return s.ctx
  412. }
  413. func (s *sendStream) SetWriteDeadline(t time.Time) error {
  414. s.mutex.Lock()
  415. s.deadline = t
  416. s.mutex.Unlock()
  417. s.signalWrite()
  418. return nil
  419. }
  420. // CloseForShutdown closes a stream abruptly.
  421. // It makes Write unblock (and return the error) immediately.
  422. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
  423. func (s *sendStream) closeForShutdown(err error) {
  424. s.mutex.Lock()
  425. s.ctxCancel()
  426. s.closedForShutdown = true
  427. s.closeForShutdownErr = err
  428. s.mutex.Unlock()
  429. s.signalWrite()
  430. }
  431. // signalWrite performs a non-blocking send on the writeChan
  432. func (s *sendStream) signalWrite() {
  433. select {
  434. case s.writeChan <- struct{}{}:
  435. default:
  436. }
  437. }