| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- package marionette
- import (
- "expvar"
- "fmt"
- "io"
- "math/rand"
- "os"
- "path/filepath"
- "strconv"
- "sync"
- "time"
- "go.uber.org/zap"
- )
- const (
- // StreamCloseTimeout is the amount of time before an idle read-closed or
- // write-closed stream is reaped by a monitoring goroutine.
- StreamCloseTimeout = 5 * time.Second
- )
- // evStreams is a global expvar variable for tracking open streams.
- var evStreams = expvar.NewInt("streams")
- // StreamSet represents a multiplexer for a set of streams for a connection.
- type StreamSet struct {
- mu sync.RWMutex
- streams map[int]*Stream // streams by id
- streamIDs []int // cached list of all stream ids
- wnotify chan struct{} // notification of write changes
- // Close management
- closing chan struct{}
- once sync.Once
- wg sync.WaitGroup
- // Callback executed when a new stream is created.
- OnNewStream func(*Stream)
- // Directory for storing stream traces.
- TracePath string
- }
- // NewStreamSet returns a new instance of StreamSet.
- func NewStreamSet() *StreamSet {
- ss := &StreamSet{
- streams: make(map[int]*Stream),
- closing: make(chan struct{}),
- wnotify: make(chan struct{}),
- }
- return ss
- }
- // Close closes all streams in the set.
- func (ss *StreamSet) Close() (err error) {
- ss.mu.Lock()
- for _, stream := range ss.streams {
- if e := stream.CloseWrite(); e != nil && err == nil {
- err = e
- } else if e := stream.CloseRead(); e != nil && err == nil {
- err = e
- }
- }
- ss.mu.Unlock()
- ss.once.Do(func() { close(ss.closing) })
- ss.wg.Wait()
- return err
- }
- // monitorStream checks a stream until its read & write channels are closed
- // and then removes the stream from the set.
- func (ss *StreamSet) monitorStream(stream *Stream) {
- readCloseNotify := stream.ReadCloseNotify()
- writeCloseNotifiedNotify := stream.WriteCloseNotifiedNotify()
- var timeout <-chan time.Time
- LOOP:
- for {
- // Wait until stream closed state is changed or the set is closed.
- select {
- case <-ss.closing:
- break LOOP
- case <-timeout:
- break LOOP
- case <-readCloseNotify:
- readCloseNotify = nil
- timeout = time.After(StreamCloseTimeout)
- case <-writeCloseNotifiedNotify:
- writeCloseNotifiedNotify = nil
- timeout = time.After(StreamCloseTimeout)
- }
- // If stream is completely closed then remove from the set.
- if stream.ReadWriteCloseNotified() {
- break
- }
- }
- // Ensure both sides are closed.
- stream.CloseRead()
- stream.CloseWrite()
- ss.mu.Lock()
- ss.remove(stream)
- ss.mu.Unlock()
- }
- // Stream returns a stream by id.
- func (ss *StreamSet) Stream(id int) *Stream {
- ss.mu.Lock()
- defer ss.mu.Unlock()
- return ss.streams[id]
- }
- // Streams returns a list of streams.
- func (ss *StreamSet) Streams() []*Stream {
- ss.mu.Lock()
- defer ss.mu.Unlock()
- streams := make([]*Stream, 0, len(ss.streams))
- for _, stream := range ss.streams {
- streams = append(streams, stream)
- }
- return streams
- }
- // Create returns a new stream with a random stream id.
- func (ss *StreamSet) Create() *Stream {
- ss.mu.Lock()
- defer ss.mu.Unlock()
- return ss.create(0)
- }
- func (ss *StreamSet) create(id int) *Stream {
- if id == 0 {
- id = int(rand.Int31() + 1)
- }
- stream := NewStream(id)
- // Create a per-stream log if trace path is specified.
- if ss.TracePath != "" {
- path := filepath.Join(ss.TracePath, strconv.Itoa(id))
- if err := os.MkdirAll(ss.TracePath, 0777); err != nil {
- Logger.Warn("cannot create trace directory", zap.Error(err))
- } else if w, err := os.Create(path); err != nil {
- Logger.Warn("cannot create trace file", zap.Error(err))
- } else {
- fmt.Fprintf(w, "# STREAM %d\n\n", id)
- stream.TraceWriter = ×tampWriter{Writer: w}
- }
- stream.TraceWriter.Write([]byte("[create]"))
- }
- // Add stream to set.
- ss.streams[stream.id] = stream
- ss.streamIDs = append(ss.streamIDs, stream.id)
- // Add to global counter.
- evStreams.Add(1)
- // Monitor each stream closing in a separate goroutine.
- ss.wg.Add(1)
- go func() { defer ss.wg.Done(); ss.monitorStream(stream) }()
- // Monitor read/write changes in a separate goroutine per stream.
- ss.wg.Add(1)
- go func() { defer ss.wg.Done(); ss.handleStream(stream) }()
- // Execute callback, if exists.
- if ss.OnNewStream != nil {
- ss.OnNewStream(stream)
- }
- return stream
- }
- // remove removes stream from the set and decrements open stream count.
- // This must be called under lock.
- func (ss *StreamSet) remove(stream *Stream) {
- streamID := stream.ID()
- evStreams.Add(-1)
- if stream.TraceWriter != nil {
- stream.TraceWriter.Write([]byte("[remove]"))
- if traceWriter, ok := stream.TraceWriter.(io.Closer); ok {
- traceWriter.Close()
- }
- }
- delete(ss.streams, streamID)
- for i, id := range ss.streamIDs {
- if id == streamID {
- ss.streamIDs = append(ss.streamIDs[:i], ss.streamIDs[i+1:]...)
- }
- }
- }
- // Enqueue pushes a cell onto a stream's read queue.
- // If the stream doesn't exist then it is created.
- func (ss *StreamSet) Enqueue(cell *Cell) error {
- ss.mu.Lock()
- defer ss.mu.Unlock()
- // Ignore empty cells.
- if cell.StreamID == 0 {
- return nil
- }
- // Create or find stream and enqueue cell.
- stream := ss.streams[cell.StreamID]
- if stream == nil {
- stream = ss.create(cell.StreamID)
- }
- return stream.Enqueue(cell)
- }
- // Dequeue returns a cell containing data for a random stream's write buffer.
- func (ss *StreamSet) Dequeue(n int) *Cell {
- ss.mu.Lock()
- defer ss.mu.Unlock()
- // Choose a random stream with data.
- var stream *Stream
- for _, i := range rand.Perm(len(ss.streamIDs)) {
- s := ss.streams[ss.streamIDs[i]]
- if s.WriteBufferLen() > 0 || s.WriteClosed() {
- stream = s
- break
- }
- }
- // If there is no stream with data then send an empty
- if stream == nil {
- return nil
- }
- // Generate cell from stream.
- return stream.Dequeue(n)
- }
- // WriteNotify returns a channel that receives a notification when a new write is available.
- func (ss *StreamSet) WriteNotify() <-chan struct{} {
- ss.mu.RLock()
- defer ss.mu.RUnlock()
- return ss.wnotify
- }
- // notifyWrite closes previous write notification channel and creates a new one.
- // This provides a broadcast to all interested parties.
- func (ss *StreamSet) notifyWrite() {
- ss.mu.Lock()
- close(ss.wnotify)
- ss.wnotify = make(chan struct{})
- ss.mu.Unlock()
- }
- // handleStream continually monitors write changes for stream.
- func (ss *StreamSet) handleStream(stream *Stream) {
- notify := stream.WriteNotify()
- ss.notifyWrite()
- for {
- select {
- case <-notify:
- notify = stream.WriteNotify()
- ss.notifyWrite()
- case <-stream.WriteCloseNotify():
- ss.notifyWrite()
- return
- }
- }
- }
- // timestampWriter wraps a writer and prepends a timestamp & appends a newline to every write.
- type timestampWriter struct {
- Writer io.Writer
- }
- func (w *timestampWriter) Write(p []byte) (n int, err error) {
- return fmt.Fprintf(w.Writer, "%s %s\n", time.Now().UTC().Format("2006-01-02T15:04:05.000Z"), p)
- }
|