writer.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. /*
  2. Copyright 2025 Psiphon Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package udsipc
  14. import (
  15. "context"
  16. "encoding/binary"
  17. "errors"
  18. "fmt"
  19. "net"
  20. "sync"
  21. "sync/atomic"
  22. "time"
  23. )
// Sentinel errors exposed by this package. Several of them are delivered
// joined with an underlying cause, so callers should match with errors.Is.
var (
	// ErrBackpressure is joined onto a write error that reports a timeout.
	ErrBackpressure = errors.New("backpressure detected")
	// ErrNoConsumer is joined onto write-path errors that are not timeouts,
	// and is also reported when a write is attempted without a connection.
	ErrNoConsumer = errors.New("no consumer")
	// ErrBufferFull is returned by Writer.WriteMessage when the send queue is
	// full and the message was dropped.
	ErrBufferFull = errors.New("send buffer full")
	// ErrNotConnected is reported, joined with ErrNoConsumer, when a write is
	// attempted while no connection is established.
	ErrNotConnected = errors.New("not connected")
	// ErrInvalidTimeout is not referenced in this file; presumably returned by
	// option validation for non-positive timeouts — TODO confirm.
	ErrInvalidTimeout = errors.New("timeout must be positive")
	// ErrInvalidBufferSize is not referenced in this file; presumably returned
	// by option validation for unusable buffer sizes — TODO confirm.
	ErrInvalidBufferSize = errors.New("invalid buffer size")
)
// Pre-allocated joined errors for hot path error conditions to reduce allocations.
var (
	// errNoConsumerNotConnected is returned when a write is attempted while
	// the Writer has no connection. Because it is built with errors.Join, it
	// matches both ErrNoConsumer and ErrNotConnected under errors.Is.
	errNoConsumerNotConnected = errors.Join(ErrNoConsumer, ErrNotConnected)
)
  36. // lengthPrefixPool pools byte slices for length prefix decoding to reduce allocations.
  37. // nolint: gochecknoglobals // Pools are package-global for efficiency.
  38. var lengthPrefixPool = sync.Pool{
  39. New: func() any {
  40. b := make([]byte, 0, binary.MaxVarintLen64)
  41. return &b
  42. },
  43. }
  44. // vectoredBufferPool pools net.Buffers slices to reduce allocations.
  45. // nolint: gochecknoglobals // Pools are package-global for efficiency.
  46. var vectoredBufferPool = sync.Pool{
  47. New: func() any {
  48. buffers := make(net.Buffers, 2) //nolint: mnd // We always only need 2 (length, data).
  49. return &buffers
  50. },
  51. }
  52. // Writer writes varint length prefixed byte slices
  53. // to a Unix Domain Socket (UDS) with a small internal buffer,
  54. // backpressure detection, and lost consumer detection.
  55. // If the consumer is unavailable for long enough that the buffer
  56. // fills, new messages will be discarded (instead of blocking).
  57. // nolint: govet
  58. type Writer struct {
  59. onError ErrorCallback
  60. send chan []byte
  61. conn net.Conn
  62. socketPath string
  63. shutdownStart chan struct{} // Signals running→stopping transition.
  64. shutdownComplete chan struct{} // Signals stopping→stopped gracefully.
  65. shutdownForced chan struct{} // Signals stopping→stopped forcefully.
  66. sentCount uint64 // Successfully sent to consumer.
  67. droppedCount uint64 // Dropped due to queue full.
  68. failedCount uint64 // Failed due to connection issues.
  69. writeTimeout time.Duration
  70. dialTimeout time.Duration
  71. maxBackoff time.Duration
  72. wg sync.WaitGroup
  73. closeOnce sync.Once
  74. writeBufferSize uint32 // Size of kernel write buffer (SO_SNDBUF).
  75. }
  76. // NewWriter creates a pointer to a newly initialized Writer.
  77. func NewWriter(socketPath string, opts ...WriterOption) (*Writer, error) {
  78. if socketPath == "" {
  79. return nil, fmt.Errorf("%w: empty path", ErrInvalidSocketPath)
  80. }
  81. if len(socketPath) > MaxSocketPathLength() {
  82. return nil, fmt.Errorf("%w: socket path too long: %s", ErrInvalidSocketPath, socketPath)
  83. }
  84. // nolint: mnd // Default values.
  85. w := &Writer{
  86. writeTimeout: time.Second,
  87. dialTimeout: time.Second,
  88. maxBackoff: 10 * time.Second,
  89. socketPath: socketPath,
  90. writeBufferSize: 256 * 1024, // 256KB.
  91. send: make(chan []byte, 10_000),
  92. shutdownStart: make(chan struct{}),
  93. shutdownComplete: make(chan struct{}),
  94. shutdownForced: make(chan struct{}),
  95. }
  96. for _, opt := range opts {
  97. if err := opt(w); err != nil {
  98. return nil, fmt.Errorf("error applying option: %w", err)
  99. }
  100. }
  101. return w, nil
  102. }
  103. // WriteMessage queues a message for sending, dropping messages and returning
  104. // ErrBufferFull when the queue is full (instead of blocking).
  105. // Callers MUST NOT modify the data slice after calling WriteMessage. The slice
  106. // will be retained for potential retries on write failure. If the caller needs
  107. // to reuse or modify the slice, they must pass a copy.
  108. func (w *Writer) WriteMessage(data []byte) error {
  109. if len(data) < 1 {
  110. return nil
  111. }
  112. select {
  113. case w.send <- data:
  114. // Queued successfully.
  115. default:
  116. // Queue full - message dropped.
  117. atomic.AddUint64(&w.droppedCount, 1)
  118. return ErrBufferFull
  119. }
  120. return nil
  121. }
  122. // GetMetrics returns current counter values and queue depth.
  123. func (w *Writer) GetMetrics() (uint64, uint64, uint64, int) {
  124. return atomic.LoadUint64(&w.sentCount),
  125. atomic.LoadUint64(&w.droppedCount),
  126. atomic.LoadUint64(&w.failedCount),
  127. len(w.send)
  128. }
// GetSocketPath returns the Unix domain socket path this Writer dials,
// exactly as stored on the Writer.
func (w *Writer) GetSocketPath() string {
	return w.socketPath
}
// Start begins the sender loop by launching run in a new goroutine and
// returns immediately.
// NOTE(review): every run closes shutdownComplete when it exits, so a second
// Start would lead to a double close — this looks intended to be called
// exactly once per Writer; confirm against callers.
func (w *Writer) Start() {
	// Register with the WaitGroup before spawning so Stop's Wait cannot miss
	// the goroutine.
	w.wg.Add(1)
	go w.run()
}
  138. // Stop attempts to shut down gracefully until it either finishes
  139. // draining all writes, or the passed context is cancelled or expires.
  140. // Subsequent calls return nil.
  141. func (w *Writer) Stop(ctx context.Context) error {
  142. var err error
  143. w.closeOnce.Do(func() {
  144. close(w.shutdownStart) // Signal run() to begin shutdown
  145. // Wait for either graceful completion or context cancellation/expiration
  146. select {
  147. case <-w.shutdownComplete: // Clean shutdown - all buffered messages drained
  148. case <-ctx.Done(): // Forced shutdown - context cancelled or expired
  149. close(w.shutdownForced) // Force run() to exit drain phase immediately
  150. err = fmt.Errorf("graceful shutdown timeout, forcing unclean shutdown: %w", ctx.Err())
  151. }
  152. // Always wait for goroutine cleanup regardless of how we exited the select
  153. w.wg.Wait()
  154. // Close connection after goroutine cleanup
  155. if w.conn != nil {
  156. if closeErr := w.conn.Close(); closeErr != nil && err == nil {
  157. err = fmt.Errorf("failed to close connection: %w", closeErr)
  158. }
  159. }
  160. })
  161. return err
  162. }
  163. // writeLengthPrefixedData writes length-prefixed data to the socket.
  164. func (w *Writer) writeLengthPrefixedData(data []byte) error {
  165. if w.conn == nil {
  166. return errNoConsumerNotConnected
  167. }
  168. lengthPrefixBuf, _ := lengthPrefixPool.Get().(*[]byte)
  169. *lengthPrefixBuf = (*lengthPrefixBuf)[:0] // Clear previous data.
  170. *lengthPrefixBuf = (*lengthPrefixBuf)[:binary.MaxVarintLen64] // Ensure sufficient length for PutUvarint.
  171. defer lengthPrefixPool.Put(lengthPrefixBuf)
  172. lengthPrefixSize := binary.PutUvarint(*lengthPrefixBuf, uint64(len(data)))
  173. // Use vectored I/O to write prefix + data in single syscall.
  174. buffersPtr, _ := vectoredBufferPool.Get().(*net.Buffers)
  175. defer vectoredBufferPool.Put(buffersPtr)
  176. buffers := *buffersPtr
  177. buffers[0] = (*lengthPrefixBuf)[:lengthPrefixSize]
  178. buffers[1] = data
  179. deadline := time.Now().Add(w.writeTimeout)
  180. if err := w.conn.SetWriteDeadline(deadline); err != nil {
  181. return errors.Join(ErrNoConsumer, err)
  182. }
  183. if _, err := buffers.WriteTo(w.conn); err != nil {
  184. return w.classifyWriteError(err)
  185. }
  186. return nil
  187. }
  188. // classifyWriteError categorizes write errors.
  189. // Timeouts while writing return a backpressure error.
  190. // All other errors are classified as having no consumer.
  191. func (w *Writer) classifyWriteError(err error) error {
  192. var netErr net.Error
  193. if errors.As(err, &netErr) && netErr.Timeout() {
  194. return errors.Join(ErrBackpressure, err)
  195. }
  196. return errors.Join(ErrNoConsumer, err)
  197. }
  198. // run is the main sender loop.
  199. func (w *Writer) run() {
  200. defer w.wg.Done()
  201. // Signal graceful shutdown completion
  202. defer close(w.shutdownComplete)
  203. // Phase 1: Normal operations.
  204. retryMsg := w.processMessages()
  205. // Phase 2: Graceful drain of remaining buffered messages.
  206. w.drainQueuedWrites(retryMsg)
  207. }
  208. // processMessages handles normal operation including connection management and message processing.
  209. // Returns any pending retry message that should be attempted during drain phase.
  210. // nolint: gocognit
  211. func (w *Writer) processMessages() []byte {
  212. backoff := time.Second
  213. var retryMsgOnReconnect []byte
  214. for {
  215. // Make sure we're connected.
  216. if w.conn == nil { // nolint: nestif
  217. if err := w.connect(); err != nil {
  218. if w.onError != nil {
  219. w.onError(err, "failed to connect")
  220. }
  221. select {
  222. case <-time.After(backoff):
  223. backoff = min(backoff*2, w.maxBackoff) //nolint: mnd // Exponential backoff.
  224. continue
  225. case <-w.shutdownStart:
  226. return retryMsgOnReconnect // Move to draining buffered writes phase.
  227. }
  228. }
  229. // Reset the timeout to 1 second, which could be larger than
  230. // the expected minimum, but strikes a balance between fast
  231. // reconnections and hammering a dead endpoint repeatedly.
  232. backoff = time.Second
  233. // If we've previously failed to write a message, it will be stored
  234. // in retryMsgOnReconnect and a write should be immediately attempted
  235. // with this message upon successful reconnect. Subsequent failures
  236. // should continue to trigger reconnections (since failing to
  237. // reconnect repeatedly will eventually hit the maximum backoff time
  238. // and result in a different error pathway.
  239. if retryMsgOnReconnect != nil {
  240. if err := w.sendRetryMessage(retryMsgOnReconnect, "write failure after reconnect"); err != nil {
  241. w.closeConn()
  242. continue
  243. }
  244. retryMsgOnReconnect = nil
  245. }
  246. }
  247. // Process messages.
  248. select {
  249. case data := <-w.send:
  250. if err := w.writeLengthPrefixedData(data); err != nil {
  251. atomic.AddUint64(&w.failedCount, 1)
  252. if w.onError != nil {
  253. w.onError(err, "write failure")
  254. }
  255. w.closeConn()
  256. // Buffer the failed message for retry on reconnect.
  257. // Note: We rely on the WriteMessage API contract that callers
  258. // do not modify the slice after passing it to WriteMessage.
  259. retryMsgOnReconnect = data
  260. } else {
  261. atomic.AddUint64(&w.sentCount, 1)
  262. }
  263. case <-w.shutdownStart:
  264. return retryMsgOnReconnect // Move to draining buffered writes phase.
  265. }
  266. }
  267. }
  268. // sendRetryMessage attempts to send a buffered retry message, updating metrics accordingly.
  269. // Returns error if write failed. Caller is responsible for connection management.
  270. func (w *Writer) sendRetryMessage(data []byte, context string) error {
  271. if err := w.writeLengthPrefixedData(data); err != nil {
  272. atomic.AddUint64(&w.failedCount, 1)
  273. if w.onError != nil {
  274. w.onError(err, context)
  275. }
  276. return err
  277. }
  278. atomic.AddUint64(&w.sentCount, 1)
  279. return nil
  280. }
  281. // drainQueuedWrites handles graceful shutdown by draining remaining buffered messages.
  282. func (w *Writer) drainQueuedWrites(retryMsgOnReconnect []byte) {
  283. // If there's a pending retry message from normal operation, attempt to send it first.
  284. if retryMsgOnReconnect != nil {
  285. if err := w.sendRetryMessage(
  286. retryMsgOnReconnect, "write failure during drain (retry message)",
  287. ); err != nil {
  288. w.closeConn()
  289. }
  290. }
  291. for {
  292. select {
  293. case data := <-w.send:
  294. // Continue processing buffered messages during drain.
  295. if err := w.writeLengthPrefixedData(data); err != nil {
  296. atomic.AddUint64(&w.failedCount, 1)
  297. if w.onError != nil {
  298. w.onError(err, "write failure during drain")
  299. }
  300. w.closeConn()
  301. } else {
  302. atomic.AddUint64(&w.sentCount, 1)
  303. }
  304. case <-w.shutdownForced:
  305. // Forced shutdown - exit immediately without draining more.
  306. return
  307. default:
  308. // No more messages to drain - clean shutdown complete.
  309. if len(w.send) == 0 {
  310. return
  311. }
  312. // While there is a small risk this code could create a short busy loop condition
  313. // in the case where data is in the buffered channel but not yet available to be
  314. // selected, no explicit sleep or yield is needed since in Go 1.14+ the scheduler
  315. // can preempt busy loops itself when needed.
  316. }
  317. }
  318. }
  319. // connect establishes connection to the socket.
  320. func (w *Writer) connect() error {
  321. conn, err := net.DialTimeout("unix", w.socketPath, w.dialTimeout) //nolint: noctx
  322. if err != nil {
  323. return fmt.Errorf("failed to dial socket: %s: %w", w.socketPath, err)
  324. }
  325. if w.writeBufferSize > 0 {
  326. if unixConn, ok := conn.(*net.UnixConn); ok {
  327. // Increase write buffer to reduce kernel allocation overhead.
  328. // Don't fail connection for buffer optimization errors,
  329. // this could happen in restricted environments.
  330. _ = unixConn.SetWriteBuffer(int(w.writeBufferSize))
  331. }
  332. }
  333. w.conn = conn
  334. return nil
  335. }
  336. // closeConn closes current connection.
  337. func (w *Writer) closeConn() {
  338. if w.conn != nil {
  339. _ = w.conn.Close()
  340. w.conn = nil
  341. }
  342. }