conn.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package marionette
  2. import (
  3. "io"
  4. "net"
  5. "strings"
  6. "sync"
  7. )
  // BufferedConn is a net.Conn wrapper whose background goroutine keeps
// reading from the connection into an internal buffer.
//
// Callers look at buffered bytes with Peek and discard consumed bytes with
// Seek. This lets them wait until a complete cell has arrived before decoding
// it. The buffer has a fixed capacity chosen at construction, so cells larger
// than that capacity are not supported.
type BufferedConn struct {
	net.Conn

	// mu guards buf and err, which are shared with the reader goroutine.
	mu  sync.RWMutex
	buf []byte // bytes read from Conn and not yet discarded via Seek
	err error  // terminal error recorded by the reader goroutine

	// closing is closed exactly once (guarded by once) when Close is called.
	closing chan struct{}
	once    sync.Once

	// Wake-up signals exchanged between the reader goroutine and callers.
	seekNotify  chan struct{} // signalled after Seek frees space in buf
	writeNotify chan struct{} // signalled after buf grows or err is set
}
  26. // NewBufferedConn returns a new BufferConn wrapping conn, sized to bufferSize.
  27. func NewBufferedConn(conn net.Conn, bufferSize int) *BufferedConn {
  28. c := &BufferedConn{
  29. Conn: conn,
  30. buf: make([]byte, 0, bufferSize*2),
  31. closing: make(chan struct{}, 0),
  32. seekNotify: make(chan struct{}, 1),
  33. writeNotify: make(chan struct{}, 1),
  34. }
  35. go c.monitor()
  36. return c
  37. }
  38. // Close closes the connection.
  39. func (conn *BufferedConn) Close() error {
  40. conn.once.Do(func() { close(conn.closing) })
  41. return conn.Conn.Close()
  42. }
  43. // Append adds b to the end of the buffer, under lock.
  44. func (conn *BufferedConn) Append(b []byte) {
  45. conn.mu.Lock()
  46. defer conn.mu.Unlock()
  47. copy(conn.buf[len(conn.buf):len(conn.buf)+len(b)], b)
  48. conn.buf = conn.buf[:len(conn.buf)+len(b)]
  49. }
  50. // Read is unavailable for BufferedConn.
  51. func (conn *BufferedConn) Read(p []byte) (int, error) {
  52. panic("BufferedConn.Read(): unavailable, use Peek/Seek")
  53. }
  54. // Peek returns the first n bytes of the read buffer.
  55. // If n is -1 then returns any available data after attempting a read.
  56. func (conn *BufferedConn) Peek(n int, blocking bool) ([]byte, error) {
  57. for {
  58. // Read buffer & error from monitor under read lock.
  59. conn.mu.RLock()
  60. buf, err := conn.buf, conn.err
  61. conn.mu.RUnlock()
  62. // Return any data that exists in the buffer.
  63. switch n {
  64. case -1:
  65. if len(buf) > 0 {
  66. return buf, nil
  67. } else if err != nil {
  68. return nil, err
  69. }
  70. default:
  71. if n <= len(buf) {
  72. return buf[:n], nil
  73. } else if isEOFError(err) {
  74. return buf, io.EOF
  75. } else if err != nil {
  76. return buf, err
  77. }
  78. }
  79. // Exit immediately if we are not blocking.
  80. if !blocking {
  81. return buf, err
  82. }
  83. // Wait for a new write or error from the monitor.
  84. <-conn.writeNotify
  85. }
  86. }
  87. // Seek moves the buffer forward a given number of bytes.
  88. // This implementation only supports io.SeekCurrent.
  89. func (conn *BufferedConn) Seek(offset int64, whence int) (int64, error) {
  90. assert(whence == io.SeekCurrent)
  91. conn.mu.Lock()
  92. defer conn.mu.Unlock()
  93. assert(offset <= int64(len(conn.buf)))
  94. b := conn.buf[offset:]
  95. conn.buf = conn.buf[:len(b)]
  96. copy(conn.buf, b)
  97. conn.notifySeek()
  98. return 0, nil
  99. }
  100. // monitor runs in a separate goroutine and continually reads to the buffer.
  101. func (conn *BufferedConn) monitor() {
  102. conn.mu.RLock()
  103. buf := make([]byte, cap(conn.buf))
  104. conn.mu.RUnlock()
  105. for {
  106. // Ensure connection is not closed.
  107. select {
  108. case <-conn.closing:
  109. return
  110. default:
  111. }
  112. // Determine remaining space on buffer.
  113. // If no capacity remains then wait for seek or connection close.
  114. conn.mu.RLock()
  115. capacity := cap(conn.buf) - len(conn.buf)
  116. conn.mu.RUnlock()
  117. if capacity == 0 {
  118. select {
  119. case <-conn.closing:
  120. return
  121. case <-conn.seekNotify:
  122. continue
  123. }
  124. }
  125. // Attempt to read next bytes from connection.
  126. n, err := conn.Conn.Read(buf[:capacity])
  127. // Append bytes to connection buffer.
  128. if n > 0 {
  129. conn.Append(buf[:n])
  130. conn.notifyWrite()
  131. }
  132. // If an error occurred then save on connection and exit.
  133. if err != nil && !isTemporaryError(err) {
  134. conn.mu.Lock()
  135. conn.err = err
  136. conn.mu.Unlock()
  137. conn.notifyWrite()
  138. return
  139. }
  140. }
  141. }
  142. // notifySeek performs a non-blocking send to the seekNotify channel.
  143. func (conn *BufferedConn) notifySeek() {
  144. select {
  145. case conn.seekNotify <- struct{}{}:
  146. default:
  147. }
  148. }
  149. // notifyWrite performs a non-blocking send to the seekWrite channel.
  150. func (conn *BufferedConn) notifyWrite() {
  151. select {
  152. case conn.writeNotify <- struct{}{}:
  153. default:
  154. }
  155. }
  // isTimeoutError reports whether err implements Timeout() bool and that
// method returns true. A nil error is never a timeout.
func isTimeoutError(err error) bool {
	// The comma-ok assertion yields ok == false for a nil interface.
	t, ok := err.(interface{ Timeout() bool })
	return ok && t.Timeout()
}
  // isTemporaryError reports whether err implements Temporary() bool and that
// method returns true. A nil error is never temporary.
//
// NOTE(review): net.Error's Temporary method is deprecated in the standard
// library because "temporary" is not well defined; consider narrowing the
// retry policy to timeouts.
func isTemporaryError(err error) bool {
	// The comma-ok assertion yields ok == false for a nil interface.
	t, ok := err.(interface{ Temporary() bool })
	return ok && t.Temporary()
}
  // isEOFError reports whether err means the peer has gone away. This is
// either a clean io.EOF or an error whose text contains
// "connection reset by peer". Peek uses it to report both cases to callers
// as io.EOF.
func isEOFError(err error) bool {
	if err == nil {
		return false
	}
	// io.EOF is checked explicitly so the predicate matches its name; the
	// reset case is detected by message text, which also works on errors
	// that wrap the underlying syscall error.
	return err == io.EOF || strings.Contains(err.Error(), "connection reset by peer")
}