stream_set.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package marionette
  2. import (
  3. "expvar"
  4. "fmt"
  5. "io"
  6. "math/rand"
  7. "os"
  8. "path/filepath"
  9. "strconv"
  10. "sync"
  11. "time"
  12. "go.uber.org/zap"
  13. )
  14. const (
  15. // StreamCloseTimeout is the amount of time before an idle read-closed or
  16. // write-closed stream is reaped by a monitoring goroutine.
  17. StreamCloseTimeout = 5 * time.Second
  18. )
  19. // evStreams is a global expvar variable for tracking open streams.
  20. var evStreams = expvar.NewInt("streams")
  21. // StreamSet represents a multiplexer for a set of streams for a connection.
  22. type StreamSet struct {
  23. mu sync.RWMutex
  24. streams map[int]*Stream // streams by id
  25. streamIDs []int // cached list of all stream ids
  26. wnotify chan struct{} // notification of write changes
  27. // Close management
  28. closing chan struct{}
  29. once sync.Once
  30. wg sync.WaitGroup
  31. // Callback executed when a new stream is created.
  32. OnNewStream func(*Stream)
  33. // Directory for storing stream traces.
  34. TracePath string
  35. }
  36. // NewStreamSet returns a new instance of StreamSet.
  37. func NewStreamSet() *StreamSet {
  38. ss := &StreamSet{
  39. streams: make(map[int]*Stream),
  40. closing: make(chan struct{}),
  41. wnotify: make(chan struct{}),
  42. }
  43. return ss
  44. }
  45. // Close closes all streams in the set.
  46. func (ss *StreamSet) Close() (err error) {
  47. ss.mu.Lock()
  48. for _, stream := range ss.streams {
  49. if e := stream.CloseWrite(); e != nil && err == nil {
  50. err = e
  51. } else if e := stream.CloseRead(); e != nil && err == nil {
  52. err = e
  53. }
  54. }
  55. ss.mu.Unlock()
  56. ss.once.Do(func() { close(ss.closing) })
  57. ss.wg.Wait()
  58. return err
  59. }
  60. // monitorStream checks a stream until its read & write channels are closed
  61. // and then removes the stream from the set.
  62. func (ss *StreamSet) monitorStream(stream *Stream) {
  63. readCloseNotify := stream.ReadCloseNotify()
  64. writeCloseNotifiedNotify := stream.WriteCloseNotifiedNotify()
  65. var timeout <-chan time.Time
  66. LOOP:
  67. for {
  68. // Wait until stream closed state is changed or the set is closed.
  69. select {
  70. case <-ss.closing:
  71. break LOOP
  72. case <-timeout:
  73. break LOOP
  74. case <-readCloseNotify:
  75. readCloseNotify = nil
  76. timeout = time.After(StreamCloseTimeout)
  77. case <-writeCloseNotifiedNotify:
  78. writeCloseNotifiedNotify = nil
  79. timeout = time.After(StreamCloseTimeout)
  80. }
  81. // If stream is completely closed then remove from the set.
  82. if stream.ReadWriteCloseNotified() {
  83. break
  84. }
  85. }
  86. // Ensure both sides are closed.
  87. stream.CloseRead()
  88. stream.CloseWrite()
  89. ss.mu.Lock()
  90. ss.remove(stream)
  91. ss.mu.Unlock()
  92. }
  93. // Stream returns a stream by id.
  94. func (ss *StreamSet) Stream(id int) *Stream {
  95. ss.mu.Lock()
  96. defer ss.mu.Unlock()
  97. return ss.streams[id]
  98. }
  99. // Streams returns a list of streams.
  100. func (ss *StreamSet) Streams() []*Stream {
  101. ss.mu.Lock()
  102. defer ss.mu.Unlock()
  103. streams := make([]*Stream, 0, len(ss.streams))
  104. for _, stream := range ss.streams {
  105. streams = append(streams, stream)
  106. }
  107. return streams
  108. }
  109. // Create returns a new stream with a random stream id.
  110. func (ss *StreamSet) Create() *Stream {
  111. ss.mu.Lock()
  112. defer ss.mu.Unlock()
  113. return ss.create(0)
  114. }
  115. func (ss *StreamSet) create(id int) *Stream {
  116. if id == 0 {
  117. id = int(rand.Int31() + 1)
  118. }
  119. stream := NewStream(id)
  120. // Create a per-stream log if trace path is specified.
  121. if ss.TracePath != "" {
  122. path := filepath.Join(ss.TracePath, strconv.Itoa(id))
  123. if err := os.MkdirAll(ss.TracePath, 0777); err != nil {
  124. Logger.Warn("cannot create trace directory", zap.Error(err))
  125. } else if w, err := os.Create(path); err != nil {
  126. Logger.Warn("cannot create trace file", zap.Error(err))
  127. } else {
  128. fmt.Fprintf(w, "# STREAM %d\n\n", id)
  129. stream.TraceWriter = &timestampWriter{Writer: w}
  130. }
  131. stream.TraceWriter.Write([]byte("[create]"))
  132. }
  133. // Add stream to set.
  134. ss.streams[stream.id] = stream
  135. ss.streamIDs = append(ss.streamIDs, stream.id)
  136. // Add to global counter.
  137. evStreams.Add(1)
  138. // Monitor each stream closing in a separate goroutine.
  139. ss.wg.Add(1)
  140. go func() { defer ss.wg.Done(); ss.monitorStream(stream) }()
  141. // Monitor read/write changes in a separate goroutine per stream.
  142. ss.wg.Add(1)
  143. go func() { defer ss.wg.Done(); ss.handleStream(stream) }()
  144. // Execute callback, if exists.
  145. if ss.OnNewStream != nil {
  146. ss.OnNewStream(stream)
  147. }
  148. return stream
  149. }
  150. // remove removes stream from the set and decrements open stream count.
  151. // This must be called under lock.
  152. func (ss *StreamSet) remove(stream *Stream) {
  153. streamID := stream.ID()
  154. evStreams.Add(-1)
  155. if stream.TraceWriter != nil {
  156. stream.TraceWriter.Write([]byte("[remove]"))
  157. if traceWriter, ok := stream.TraceWriter.(io.Closer); ok {
  158. traceWriter.Close()
  159. }
  160. }
  161. delete(ss.streams, streamID)
  162. for i, id := range ss.streamIDs {
  163. if id == streamID {
  164. ss.streamIDs = append(ss.streamIDs[:i], ss.streamIDs[i+1:]...)
  165. }
  166. }
  167. }
  168. // Enqueue pushes a cell onto a stream's read queue.
  169. // If the stream doesn't exist then it is created.
  170. func (ss *StreamSet) Enqueue(cell *Cell) error {
  171. ss.mu.Lock()
  172. defer ss.mu.Unlock()
  173. // Ignore empty cells.
  174. if cell.StreamID == 0 {
  175. return nil
  176. }
  177. // Create or find stream and enqueue cell.
  178. stream := ss.streams[cell.StreamID]
  179. if stream == nil {
  180. stream = ss.create(cell.StreamID)
  181. }
  182. return stream.Enqueue(cell)
  183. }
  184. // Dequeue returns a cell containing data for a random stream's write buffer.
  185. func (ss *StreamSet) Dequeue(n int) *Cell {
  186. ss.mu.Lock()
  187. defer ss.mu.Unlock()
  188. // Choose a random stream with data.
  189. var stream *Stream
  190. for _, i := range rand.Perm(len(ss.streamIDs)) {
  191. s := ss.streams[ss.streamIDs[i]]
  192. if s.WriteBufferLen() > 0 || s.WriteClosed() {
  193. stream = s
  194. break
  195. }
  196. }
  197. // If there is no stream with data then send an empty
  198. if stream == nil {
  199. return nil
  200. }
  201. // Generate cell from stream.
  202. return stream.Dequeue(n)
  203. }
  204. // WriteNotify returns a channel that receives a notification when a new write is available.
  205. func (ss *StreamSet) WriteNotify() <-chan struct{} {
  206. ss.mu.RLock()
  207. defer ss.mu.RUnlock()
  208. return ss.wnotify
  209. }
  210. // notifyWrite closes previous write notification channel and creates a new one.
  211. // This provides a broadcast to all interested parties.
  212. func (ss *StreamSet) notifyWrite() {
  213. ss.mu.Lock()
  214. close(ss.wnotify)
  215. ss.wnotify = make(chan struct{})
  216. ss.mu.Unlock()
  217. }
  218. // handleStream continually monitors write changes for stream.
  219. func (ss *StreamSet) handleStream(stream *Stream) {
  220. notify := stream.WriteNotify()
  221. ss.notifyWrite()
  222. for {
  223. select {
  224. case <-notify:
  225. notify = stream.WriteNotify()
  226. ss.notifyWrite()
  227. case <-stream.WriteCloseNotify():
  228. ss.notifyWrite()
  229. return
  230. }
  231. }
  232. }
  233. // timestampWriter wraps a writer and prepends a timestamp & appends a newline to every write.
  234. type timestampWriter struct {
  235. Writer io.Writer
  236. }
  237. func (w *timestampWriter) Write(p []byte) (n int, err error) {
  238. return fmt.Fprintf(w.Writer, "%s %s\n", time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), p)
  239. }