send_stream.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. package quic
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/Psiphon-Labs/quic-go/internal/ackhandler"
  8. "github.com/Psiphon-Labs/quic-go/internal/flowcontrol"
  9. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  10. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  11. "github.com/Psiphon-Labs/quic-go/internal/utils"
  12. "github.com/Psiphon-Labs/quic-go/internal/wire"
  13. )
  14. type sendStreamI interface {
  15. SendStream
  16. handleStopSendingFrame(*wire.StopSendingFrame)
  17. hasData() bool
  18. popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool)
  19. closeForShutdown(error)
  20. updateSendWindow(protocol.ByteCount)
  21. }
  22. type sendStream struct {
  23. mutex sync.Mutex
  24. numOutstandingFrames int64
  25. retransmissionQueue []*wire.StreamFrame
  26. ctx context.Context
  27. ctxCancel context.CancelFunc
  28. streamID protocol.StreamID
  29. sender streamSender
  30. writeOffset protocol.ByteCount
  31. cancelWriteErr error
  32. closeForShutdownErr error
  33. closedForShutdown bool // set when CloseForShutdown() is called
  34. finishedWriting bool // set once Close() is called
  35. canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
  36. finSent bool // set when a STREAM_FRAME with FIN bit has been sent
  37. completed bool // set when this stream has been reported to the streamSender as completed
  38. dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
  39. nextFrame *wire.StreamFrame
  40. writeChan chan struct{}
  41. deadline time.Time
  42. flowController flowcontrol.StreamFlowController
  43. version protocol.VersionNumber
  44. }
  45. var (
  46. _ SendStream = &sendStream{}
  47. _ sendStreamI = &sendStream{}
  48. )
  49. func newSendStream(
  50. streamID protocol.StreamID,
  51. sender streamSender,
  52. flowController flowcontrol.StreamFlowController,
  53. version protocol.VersionNumber,
  54. ) *sendStream {
  55. s := &sendStream{
  56. streamID: streamID,
  57. sender: sender,
  58. flowController: flowController,
  59. writeChan: make(chan struct{}, 1),
  60. version: version,
  61. }
  62. s.ctx, s.ctxCancel = context.WithCancel(context.Background())
  63. return s
  64. }
  65. func (s *sendStream) StreamID() protocol.StreamID {
  66. return s.streamID // same for receiveStream and sendStream
  67. }
  68. func (s *sendStream) Write(p []byte) (int, error) {
  69. s.mutex.Lock()
  70. defer s.mutex.Unlock()
  71. if s.finishedWriting {
  72. return 0, fmt.Errorf("write on closed stream %d", s.streamID)
  73. }
  74. if s.canceledWrite {
  75. return 0, s.cancelWriteErr
  76. }
  77. if s.closeForShutdownErr != nil {
  78. return 0, s.closeForShutdownErr
  79. }
  80. if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
  81. return 0, errDeadline
  82. }
  83. if len(p) == 0 {
  84. return 0, nil
  85. }
  86. s.dataForWriting = p
  87. var (
  88. deadlineTimer *utils.Timer
  89. bytesWritten int
  90. notifiedSender bool
  91. )
  92. for {
  93. var copied bool
  94. var deadline time.Time
  95. // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
  96. // which can the be popped the next time we assemble a packet.
  97. // This allows us to return Write() when all data but x bytes have been sent out.
  98. // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
  99. // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
  100. if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
  101. if s.nextFrame == nil {
  102. f := wire.GetStreamFrame()
  103. f.Offset = s.writeOffset
  104. f.StreamID = s.streamID
  105. f.DataLenPresent = true
  106. f.Data = f.Data[:len(s.dataForWriting)]
  107. copy(f.Data, s.dataForWriting)
  108. s.nextFrame = f
  109. } else {
  110. l := len(s.nextFrame.Data)
  111. s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
  112. copy(s.nextFrame.Data[l:], s.dataForWriting)
  113. }
  114. s.dataForWriting = nil
  115. bytesWritten = len(p)
  116. copied = true
  117. } else {
  118. bytesWritten = len(p) - len(s.dataForWriting)
  119. deadline = s.deadline
  120. if !deadline.IsZero() {
  121. if !time.Now().Before(deadline) {
  122. s.dataForWriting = nil
  123. return bytesWritten, errDeadline
  124. }
  125. if deadlineTimer == nil {
  126. deadlineTimer = utils.NewTimer()
  127. defer deadlineTimer.Stop()
  128. }
  129. deadlineTimer.Reset(deadline)
  130. }
  131. if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
  132. break
  133. }
  134. }
  135. s.mutex.Unlock()
  136. if !notifiedSender {
  137. s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
  138. notifiedSender = true
  139. }
  140. if copied {
  141. s.mutex.Lock()
  142. break
  143. }
  144. if deadline.IsZero() {
  145. <-s.writeChan
  146. } else {
  147. select {
  148. case <-s.writeChan:
  149. case <-deadlineTimer.Chan():
  150. deadlineTimer.SetRead()
  151. }
  152. }
  153. s.mutex.Lock()
  154. }
  155. if bytesWritten == len(p) {
  156. return bytesWritten, nil
  157. }
  158. if s.closeForShutdownErr != nil {
  159. return bytesWritten, s.closeForShutdownErr
  160. } else if s.cancelWriteErr != nil {
  161. return bytesWritten, s.cancelWriteErr
  162. }
  163. return bytesWritten, nil
  164. }
  165. func (s *sendStream) canBufferStreamFrame() bool {
  166. var l protocol.ByteCount
  167. if s.nextFrame != nil {
  168. l = s.nextFrame.DataLen()
  169. }
  170. return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
  171. }
  172. // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
  173. // maxBytes is the maximum length this frame (including frame header) will have.
  174. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) {
  175. s.mutex.Lock()
  176. f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes)
  177. if f != nil {
  178. s.numOutstandingFrames++
  179. }
  180. s.mutex.Unlock()
  181. if f == nil {
  182. return nil, hasMoreData
  183. }
  184. return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData
  185. }
  186. func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) {
  187. if s.canceledWrite || s.closeForShutdownErr != nil {
  188. return nil, false
  189. }
  190. if len(s.retransmissionQueue) > 0 {
  191. f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes)
  192. if f != nil || hasMoreRetransmissions {
  193. if f == nil {
  194. return nil, true
  195. }
  196. // We always claim that we have more data to send.
  197. // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
  198. return f, true
  199. }
  200. }
  201. if len(s.dataForWriting) == 0 && s.nextFrame == nil {
  202. if s.finishedWriting && !s.finSent {
  203. s.finSent = true
  204. return &wire.StreamFrame{
  205. StreamID: s.streamID,
  206. Offset: s.writeOffset,
  207. DataLenPresent: true,
  208. Fin: true,
  209. }, false
  210. }
  211. return nil, false
  212. }
  213. sendWindow := s.flowController.SendWindowSize()
  214. if sendWindow == 0 {
  215. if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
  216. s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
  217. StreamID: s.streamID,
  218. MaximumStreamData: offset,
  219. })
  220. return nil, false
  221. }
  222. return nil, true
  223. }
  224. f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow)
  225. if dataLen := f.DataLen(); dataLen > 0 {
  226. s.writeOffset += f.DataLen()
  227. s.flowController.AddBytesSent(f.DataLen())
  228. }
  229. f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
  230. if f.Fin {
  231. s.finSent = true
  232. }
  233. return f, hasMoreData
  234. }
  235. func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) {
  236. if s.nextFrame != nil {
  237. nextFrame := s.nextFrame
  238. s.nextFrame = nil
  239. maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version))
  240. if nextFrame.DataLen() > maxDataLen {
  241. s.nextFrame = wire.GetStreamFrame()
  242. s.nextFrame.StreamID = s.streamID
  243. s.nextFrame.Offset = s.writeOffset + maxDataLen
  244. s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
  245. s.nextFrame.DataLenPresent = true
  246. copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
  247. nextFrame.Data = nextFrame.Data[:maxDataLen]
  248. } else {
  249. s.signalWrite()
  250. }
  251. return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
  252. }
  253. f := wire.GetStreamFrame()
  254. f.Fin = false
  255. f.StreamID = s.streamID
  256. f.Offset = s.writeOffset
  257. f.DataLenPresent = true
  258. f.Data = f.Data[:0]
  259. hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow)
  260. if len(f.Data) == 0 && !f.Fin {
  261. f.PutBack()
  262. return nil, hasMoreData
  263. }
  264. return f, hasMoreData
  265. }
  266. func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool {
  267. maxDataLen := f.MaxDataLen(maxBytes, s.version)
  268. if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
  269. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  270. }
  271. s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow))
  272. return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
  273. }
  274. func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) {
  275. f := s.retransmissionQueue[0]
  276. newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version)
  277. if needsSplit {
  278. return newFrame, true
  279. }
  280. s.retransmissionQueue = s.retransmissionQueue[1:]
  281. return f, len(s.retransmissionQueue) > 0
  282. }
  283. func (s *sendStream) hasData() bool {
  284. s.mutex.Lock()
  285. hasData := len(s.dataForWriting) > 0
  286. s.mutex.Unlock()
  287. return hasData
  288. }
  289. func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
  290. if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
  291. f.Data = f.Data[:len(s.dataForWriting)]
  292. copy(f.Data, s.dataForWriting)
  293. s.dataForWriting = nil
  294. s.signalWrite()
  295. return
  296. }
  297. f.Data = f.Data[:maxBytes]
  298. copy(f.Data, s.dataForWriting)
  299. s.dataForWriting = s.dataForWriting[maxBytes:]
  300. if s.canBufferStreamFrame() {
  301. s.signalWrite()
  302. }
  303. }
  304. func (s *sendStream) frameAcked(f wire.Frame) {
  305. f.(*wire.StreamFrame).PutBack()
  306. s.mutex.Lock()
  307. if s.canceledWrite {
  308. s.mutex.Unlock()
  309. return
  310. }
  311. s.numOutstandingFrames--
  312. if s.numOutstandingFrames < 0 {
  313. panic("numOutStandingFrames negative")
  314. }
  315. newlyCompleted := s.isNewlyCompleted()
  316. s.mutex.Unlock()
  317. if newlyCompleted {
  318. s.sender.onStreamCompleted(s.streamID)
  319. }
  320. }
  321. func (s *sendStream) isNewlyCompleted() bool {
  322. completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
  323. if completed && !s.completed {
  324. s.completed = true
  325. return true
  326. }
  327. return false
  328. }
  329. func (s *sendStream) queueRetransmission(f wire.Frame) {
  330. sf := f.(*wire.StreamFrame)
  331. sf.DataLenPresent = true
  332. s.mutex.Lock()
  333. if s.canceledWrite {
  334. s.mutex.Unlock()
  335. return
  336. }
  337. s.retransmissionQueue = append(s.retransmissionQueue, sf)
  338. s.numOutstandingFrames--
  339. if s.numOutstandingFrames < 0 {
  340. panic("numOutStandingFrames negative")
  341. }
  342. s.mutex.Unlock()
  343. s.sender.onHasStreamData(s.streamID)
  344. }
  345. func (s *sendStream) Close() error {
  346. s.mutex.Lock()
  347. if s.closedForShutdown {
  348. s.mutex.Unlock()
  349. return nil
  350. }
  351. if s.canceledWrite {
  352. s.mutex.Unlock()
  353. return fmt.Errorf("close called for canceled stream %d", s.streamID)
  354. }
  355. s.ctxCancel()
  356. s.finishedWriting = true
  357. s.mutex.Unlock()
  358. s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
  359. return nil
  360. }
  361. func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
  362. s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
  363. }
  364. // must be called after locking the mutex
  365. func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) {
  366. s.mutex.Lock()
  367. if s.canceledWrite {
  368. s.mutex.Unlock()
  369. return
  370. }
  371. s.ctxCancel()
  372. s.canceledWrite = true
  373. s.cancelWriteErr = writeErr
  374. s.numOutstandingFrames = 0
  375. s.retransmissionQueue = nil
  376. newlyCompleted := s.isNewlyCompleted()
  377. s.mutex.Unlock()
  378. s.signalWrite()
  379. s.sender.queueControlFrame(&wire.ResetStreamFrame{
  380. StreamID: s.streamID,
  381. FinalSize: s.writeOffset,
  382. ErrorCode: errorCode,
  383. })
  384. if newlyCompleted {
  385. s.sender.onStreamCompleted(s.streamID)
  386. }
  387. }
  388. func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
  389. s.mutex.Lock()
  390. hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
  391. s.mutex.Unlock()
  392. s.flowController.UpdateSendWindow(limit)
  393. if hasStreamData {
  394. s.sender.onHasStreamData(s.streamID)
  395. }
  396. }
  397. func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
  398. s.cancelWriteImpl(frame.ErrorCode, &StreamError{
  399. StreamID: s.streamID,
  400. ErrorCode: frame.ErrorCode,
  401. })
  402. }
  403. func (s *sendStream) Context() context.Context {
  404. return s.ctx
  405. }
  406. func (s *sendStream) SetWriteDeadline(t time.Time) error {
  407. s.mutex.Lock()
  408. s.deadline = t
  409. s.mutex.Unlock()
  410. s.signalWrite()
  411. return nil
  412. }
  413. // CloseForShutdown closes a stream abruptly.
  414. // It makes Write unblock (and return the error) immediately.
  415. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
  416. func (s *sendStream) closeForShutdown(err error) {
  417. s.mutex.Lock()
  418. s.ctxCancel()
  419. s.closedForShutdown = true
  420. s.closeForShutdownErr = err
  421. s.mutex.Unlock()
  422. s.signalWrite()
  423. }
  424. // signalWrite performs a non-blocking send on the writeChan
  425. func (s *sendStream) signalWrite() {
  426. select {
  427. case s.writeChan <- struct{}{}:
  428. default:
  429. }
  430. }