stream.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. package marionette
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "go.uber.org/zap"
  12. )
var (
	// ErrStreamClosed is returned when enqueuing cells or writing data to a closed stream.
	// Dequeuing cells and reading data remain available until pending data is exhausted.
	ErrStreamClosed = errors.New("marionette: stream closed")

	// ErrWriteTooLarge is returned when a Write() is larger than the buffer.
	ErrWriteTooLarge = errors.New("marionette: write too large")
)
  20. // Ensure type implements interface.
  21. var _ net.Conn = &Stream{}
// Stream represents a readable and writable connection for plaintext data.
// Data is injected into the stream using cells which provide ordering and payload data.
// Implements the net.Conn interface.
type Stream struct {
	mu sync.RWMutex

	id   int // stream identifier
	rseq int // next expected incoming cell sequence
	wseq int // next outgoing cell sequence

	// Read-side close management.
	ronce       sync.Once
	readClosed  bool
	readClosing chan struct{}

	// Write-side close management.
	wonce        sync.Once
	writeClosed  bool
	writeClosing chan struct{}

	// Notification when write-side has been closed (set once the EOS cell
	// has been dequeued for the peer).
	writeCloseNotified       bool
	writeCloseNotifiedNotify chan struct{}

	// Local & remote addresses for net.Conn implementation.
	localAddr  net.Addr
	remoteAddr net.Addr

	// Read & write buffer queues & notification.
	rbuf, wbuf []byte        // buffer pending processing
	rqueue     []*Cell       // cells processed from read buffer
	rnotify    chan struct{} // notification when read buffer changed
	wnotify    chan struct{} // notification when write buffer changed
	modTime    time.Time     // last change to read or write
	onWrite    func()        // callback when the write buffer changes; NOTE(review): never invoked in this file — confirm callers elsewhere
	// Stream verbosely logs to trace writer when set.
	TraceWriter io.Writer
}
  54. // NewStream returns a new stream with a givenZ
  55. func NewStream(id int) *Stream {
  56. return &Stream{
  57. id: id,
  58. rbuf: make([]byte, 0, MaxCellLength),
  59. wbuf: make([]byte, 0, MaxCellLength),
  60. readClosing: make(chan struct{}),
  61. writeClosing: make(chan struct{}),
  62. rnotify: make(chan struct{}),
  63. wnotify: make(chan struct{}),
  64. modTime: time.Now(),
  65. writeCloseNotifiedNotify: make(chan struct{}),
  66. }
  67. }
  68. // ID returns the stream id.
  69. func (s *Stream) ID() int { return s.id }
  70. // ModTime returns the last time a cell was added or removed from the stream.
  71. func (s *Stream) ModTime() time.Time {
  72. s.mu.RLock()
  73. defer s.mu.RUnlock()
  74. return s.modTime
  75. }
  76. // ReadNotify returns a channel that receives a notification when a new read is available.
  77. func (s *Stream) ReadNotify() <-chan struct{} {
  78. s.mu.RLock()
  79. defer s.mu.RUnlock()
  80. return s.rnotify
  81. }
// notifyRead closes the previous read notify channel and creates a new one.
// This provides a broadcast to all parties waiting on ReadNotify.
// Callers must hold s.mu.
func (s *Stream) notifyRead() {
	if s.TraceWriter != nil {
		fmt.Fprintf(s.TraceWriter, "[notifyRead]")
	}
	close(s.rnotify)
	s.rnotify = make(chan struct{})
}
// Read reads up to len(b) bytes from the stream, blocking until data is
// available or the read side has been closed. Implements io.Reader.
// Returns io.EOF once the read side is closed and all buffered bytes and
// queued cells have been consumed.
func (s *Stream) Read(b []byte) (n int, err error) {
	if s.TraceWriter != nil {
		s.TraceWriter.Write([]byte("[Read]"))
	}

	for {
		// Attempt to read from the buffer. Exit if bytes read or error.
		s.mu.Lock()
		if n, err = s.read(b); n != 0 || err != nil {
			s.mu.Unlock()
			return n, err
		} else if n == 0 && len(s.rqueue) == 0 && s.readClosed {
			// Closed with nothing buffered or queued: release the buffer
			// and report end-of-stream.
			s.rbuf = nil
			s.mu.Unlock()
			return 0, io.EOF
		}

		// Capture the current notify channel before unlocking so a
		// notifyRead() that fires after Unlock is not missed by the select.
		notify := s.rnotify
		// Drain any in-order queued cells into the now-empty read buffer.
		s.processReadQueue()
		s.mu.Unlock()

		// Wait for notification of new read buffer bytes.
		select {
		case <-s.readClosing:
		case <-notify:
		}
	}
}
  115. // read reads available bytes from read buffer to b.
  116. func (s *Stream) read(b []byte) (n int, err error) {
  117. if len(s.rbuf) == 0 {
  118. return 0, nil
  119. }
  120. // Copy bytes to caller.
  121. n = len(b)
  122. if n > len(s.rbuf) {
  123. n = len(s.rbuf)
  124. }
  125. copy(b, s.rbuf)
  126. // Remove bytes from buffer.
  127. copy(s.rbuf, s.rbuf[n:])
  128. s.rbuf = s.rbuf[:len(s.rbuf)-n]
  129. return n, nil
  130. }
  131. // ReadBufferLen returns the number of bytes in the read buffer.
  132. func (s *Stream) ReadBufferLen() int {
  133. s.mu.RLock()
  134. defer s.mu.RUnlock()
  135. return len(s.rbuf)
  136. }
// Write appends b to the write buffer. This method will continue to try until
// the entire byte slice is written atomically to the buffer, blocking while
// there is insufficient space. Implements io.Writer.
// Returns ErrStreamClosed if the write side has been closed, or
// ErrWriteTooLarge if b can never fit in the buffer.
func (s *Stream) Write(b []byte) (n int, err error) {
	if s.TraceWriter != nil {
		fmt.Fprintf(s.TraceWriter, "[Write] len=%d", len(b))
	}

	for {
		// Attempt to write to write buffer.
		// If no room available then wait for write buffer to change and try again.
		s.mu.Lock()
		if s.writeClosed {
			s.mu.Unlock()
			return 0, ErrStreamClosed
		} else if n, err = s.write(b); n != 0 || err != nil {
			// Wake any party waiting on the write buffer (e.g. Dequeue callers).
			s.notifyWrite()
			s.mu.Unlock()
			return n, err
		}

		// Capture the current notify channel before unlocking so a
		// notifyWrite() that fires after Unlock is not missed by the select.
		notify := s.wnotify
		s.mu.Unlock()

		// Wait for a change in the write buffer.
		select {
		case <-s.writeClosing:
		case <-notify:
		}
	}
}
  164. // write atomically writes b to the write buffer.
  165. // Returns ErrWriteTooLarge if b is larger than write buffer capacity.
  166. // Returns n=0 and no error if there is not enough space to write all of b.
  167. func (s *Stream) write(b []byte) (n int, err error) {
  168. if len(b) > cap(s.wbuf) {
  169. return 0, ErrWriteTooLarge
  170. } else if len(b) > cap(s.wbuf)-len(s.wbuf) {
  171. return 0, nil // not enough space
  172. }
  173. // Copy bytes to the end of the write buffer.
  174. s.wbuf = s.wbuf[:len(s.wbuf)+len(b)]
  175. copy(s.wbuf[len(s.wbuf)-len(b):], b)
  176. return len(b), nil
  177. }
  178. // WriteNotify returns a channel that receives a notification when a new write is available.
  179. func (s *Stream) WriteNotify() <-chan struct{} {
  180. s.mu.RLock()
  181. defer s.mu.RUnlock()
  182. return s.wnotify
  183. }
// notifyWrite closes previous write notify channel and creates a new one.
// This provides a broadcast for all interested parties.
// Callers must hold s.mu.
func (s *Stream) notifyWrite() {
	if s.TraceWriter != nil {
		fmt.Fprintf(s.TraceWriter, "[notifyWrite]")
	}
	close(s.wnotify)
	s.wnotify = make(chan struct{})
}
  193. // WriteBufferLen returns the number of bytes in the write buffer.
  194. func (s *Stream) WriteBufferLen() int {
  195. s.mu.RLock()
  196. defer s.mu.RUnlock()
  197. return len(s.wbuf)
  198. }
// Enqueue pushes a cell's payload on to the stream if it is the next sequence.
// Out of sequence cells are added to the queue and are read after earlier cells.
// Cells whose sequence is below the current read sequence are logged and
// dropped as duplicates. Always returns nil.
func (s *Stream) Enqueue(cell *Cell) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.TraceWriter != nil {
		fmt.Fprintf(s.TraceWriter, "[Enqueue] seq=%d rseq=%d", cell.SequenceID, s.rseq)
	}

	// If sequence is a duplicate then ignore it.
	// NOTE(review): a redelivered cell with sequence >= rseq that is already
	// waiting in rqueue is not detected here and would be queued twice —
	// confirm the sender never redelivers pending sequences.
	if cell.SequenceID < s.rseq {
		s.logger().Info("duplicate cell sequence",
			zap.Int("local", s.rseq),
			zap.Int("remote", cell.SequenceID))
		return nil // duplicate cell
	}

	// Add to queue & sort so the lowest sequence is at the front.
	s.rqueue = append(s.rqueue, cell)
	sort.Slice(s.rqueue, func(i, j int) bool { return s.rqueue[i].Compare(s.rqueue[j]) == -1 })

	// Process read queue to convert cells in the queue to bytes on the read buffer.
	s.processReadQueue()

	s.modTime = time.Now()
	return nil
}
// processReadQueue deserializes cells in the read queue and writes the bytes to
// the read buffer. Queue processing stops when the next cell does not match the
// next expected sequence or if there is not enough room left in the read buffer.
// An EOS cell drops any remaining queued cells and closes the read side.
// Callers must hold s.mu.
func (s *Stream) processReadQueue() {
	// Read all consecutive cells onto the buffer.
	var notify bool
	for len(s.rqueue) > 0 {
		cell := s.rqueue[0]
		if cell.SequenceID != s.rseq {
			break // out-of-order
		} else if len(cell.Payload) > cap(s.rbuf)-len(s.rbuf) {
			break // not enough space on buffer
		}

		// Extend buffer and copy cell payload.
		s.rbuf = s.rbuf[:len(s.rbuf)+len(cell.Payload)]
		copy(s.rbuf[len(s.rbuf)-len(cell.Payload):], cell.Payload)
		notify = true

		// Shift cell off queue (clearing the slot so the cell can be GC'd)
		// and increment the expected sequence.
		s.rqueue[0] = nil
		s.rqueue = s.rqueue[1:]
		s.rseq++

		// If this is the end of the stream then close out reads.
		if cell.Type == CellTypeEOS {
			if s.TraceWriter != nil {
				fmt.Fprintf(s.TraceWriter, "[eos:recv] seq=%d rseq=%d qlen=%d rbuf=%d", cell.SequenceID, s.rseq, len(s.rqueue), len(s.rbuf))
			}
			s.rqueue = nil
			s.closeRead()
		}
	}

	// Notify waiting readers that the read buffer changed.
	if notify {
		s.notifyRead()
	}
}
// Dequeue reads up to n bytes from the write buffer and encodes it as a cell.
// When n is zero, the cell is sized to the entire pending write buffer plus
// the cell header. Returns nil if the peer has already been notified of
// end-of-stream. Once the write side is closed and the buffer is drained,
// a single CellTypeEOS cell is emitted and the stream is marked notified.
//
// NOTE(review): the n == 0 derived size (len(wbuf)+CellHeaderSize) is not
// capped at MaxCellLength — confirm callers tolerate cells slightly above
// MaxCellLength when the buffer is full.
func (s *Stream) Dequeue(n int) *Cell {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.TraceWriter != nil {
		fmt.Fprintf(s.TraceWriter, "[Dequeue] n=%d", n)
	}

	// Exit immediately if stream has already notified that its writes are closed.
	if s.writeCloseNotified {
		return nil
	}

	// Determine the amount of data to read.
	if n == 0 {
		n = len(s.wbuf) + CellHeaderSize
	} else if n > MaxCellLength {
		n = MaxCellLength
	}

	// Determine next sequence.
	sequenceID := s.wseq
	s.wseq++

	s.modTime = time.Now()

	// End stream if there's no more data and it's marked as closed.
	if len(s.wbuf) == 0 && s.writeClosed {
		if s.TraceWriter != nil {
			fmt.Fprintf(s.TraceWriter, "[eos:send] seq=%d", sequenceID)
		}
		// Record and broadcast that the peer has been told about the close.
		s.writeCloseNotified = true
		close(s.writeCloseNotifiedNotify)
		return NewCell(s.id, sequenceID, n, CellTypeEOS)
	}

	// Build cell.
	cell := NewCell(s.id, sequenceID, n, CellTypeNormal)

	// Determine payload size; a short buffer yields a short payload.
	payloadN := n - CellHeaderSize
	if payloadN > len(s.wbuf) {
		payloadN = len(s.wbuf)
	}

	// Copy buffer to payload.
	if payloadN > 0 {
		cell.Payload = make([]byte, payloadN)
		copy(cell.Payload, s.wbuf[:payloadN])

		// Remove payload bytes from buffer, preserving its capacity.
		remaining := len(s.wbuf) - payloadN
		copy(s.wbuf[:remaining], s.wbuf[payloadN:len(s.wbuf)])
		s.wbuf = s.wbuf[:remaining]

		// Send notification that write buffer has changed.
		s.notifyWrite()
	}

	return cell
}
// Close marks the stream as closed for writes. The server will close the read side.
// Implements io.Closer / net.Conn.
func (s *Stream) Close() error {
	return s.CloseWrite()
}
  311. // CloseWrite marks the stream as closed for writes.
  312. func (s *Stream) CloseWrite() error {
  313. s.mu.Lock()
  314. defer s.mu.Unlock()
  315. s.closeWrite()
  316. return nil
  317. }
// closeWrite marks the write side closed, broadcasts the closure exactly
// once via writeClosing, and wakes write-buffer waiters.
// Callers must hold s.mu.
func (s *Stream) closeWrite() {
	s.writeClosed = true
	s.wonce.Do(func() { close(s.writeClosing) })
	s.notifyWrite()
}
  323. // CloseRead marks the stream as closed for reads.
  324. func (s *Stream) CloseRead() error {
  325. s.mu.Lock()
  326. defer s.mu.Unlock()
  327. s.closeRead()
  328. return nil
  329. }
// closeRead marks the read side closed and broadcasts the closure exactly
// once via readClosing. Callers must hold s.mu.
func (s *Stream) closeRead() {
	s.readClosed = true
	s.ronce.Do(func() { close(s.readClosing) })
}
  334. // Closed returns true if the stream has been closed.
  335. func (s *Stream) Closed() bool {
  336. s.mu.RLock()
  337. defer s.mu.RUnlock()
  338. return s.readClosed && s.writeClosed
  339. }
  340. // ReadClosed returns true if the stream has been closed for reads.
  341. func (s *Stream) ReadClosed() bool {
  342. s.mu.RLock()
  343. defer s.mu.RUnlock()
  344. return s.readClosed
  345. }
// ReadCloseNotify returns a channel that is closed when the stream has been
// closed for reading.
func (s *Stream) ReadCloseNotify() <-chan struct{} { return s.readClosing }
  348. // WriteClosed returns true if the stream has been requested to be closed for writes.
  349. func (s *Stream) WriteClosed() bool {
  350. s.mu.RLock()
  351. defer s.mu.RUnlock()
  352. return s.writeClosed
  353. }
// WriteCloseNotify returns a channel that is closed when the stream has been
// closed for writing.
func (s *Stream) WriteCloseNotify() <-chan struct{} { return s.writeClosing }
  356. // WriteCloseNotified returns true if the stream has notified the peer connection of the end of stream.
  357. func (s *Stream) WriteCloseNotified() bool {
  358. s.mu.RLock()
  359. defer s.mu.RUnlock()
  360. return s.writeCloseNotified
  361. }
  362. func (s *Stream) WriteCloseNotifiedNotify() <-chan struct{} { return s.writeCloseNotifiedNotify }
  363. // ReadWriteCloseNotified returns true if the stream is closed for read and write and has been notified.
  364. func (s *Stream) ReadWriteCloseNotified() bool {
  365. s.mu.RLock()
  366. defer s.mu.RUnlock()
  367. return s.readClosed && s.writeCloseNotified
  368. }
  369. // LocalAddr returns the local address. Implements net.Conn.
  370. func (c *Stream) LocalAddr() net.Addr { return c.localAddr }
  371. // RemoteAddr returns the remote address. Implements net.Conn.
  372. func (c *Stream) RemoteAddr() net.Addr { return c.remoteAddr }
  373. // SetDeadline is a no-op. Implements net.Conn.
  374. func (c *Stream) SetDeadline(t time.Time) error { return nil }
  375. // SetReadDeadline is a no-op. Implements net.Conn.
  376. func (c *Stream) SetReadDeadline(t time.Time) error { return nil }
  377. // SetWriteDeadline is a no-op. Implements net.Conn.
  378. func (c *Stream) SetWriteDeadline(t time.Time) error { return nil }
// logger returns the package logger annotated with this stream's id.
func (s *Stream) logger() *zap.Logger {
	return Logger.With(zap.Int("stream_id", s.id))
}
// streamExpVar is a wrapper for stream to generate expvar data.
// A *Stream is converted to *streamExpVar so String() can reuse the
// stream's mutex and unexported fields.
type streamExpVar Stream

// String returns JSON representation of the expvar data.
// The marshal error is ignored; the payload contains only ints.
func (s *streamExpVar) String() string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	buf, _ := json.Marshal(streamExpVarJSON{
		Rseq:   s.rseq,
		Wseq:   s.wseq,
		Rbuf:   len(s.rbuf),
		Wbuf:   len(s.wbuf),
		Rqueue: len(s.rqueue),
	})
	return string(buf)
}
// streamExpVarJSON is the JSON representation of a stream in expvar.
type streamExpVarJSON struct {
	Rseq   int `json:"rseq"`   // next expected read sequence
	Wseq   int `json:"wseq"`   // next write sequence
	Rbuf   int `json:"rbuf"`   // read buffer length in bytes
	Wbuf   int `json:"wbuf"`   // write buffer length in bytes
	Rqueue int `json:"rqueue"` // number of queued (out-of-order) cells
}