reader.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. /*
  2. Copyright 2025 Psiphon Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package udsipc
  14. import (
  15. "bufio"
  16. "context"
  17. "encoding/binary"
  18. "errors"
  19. "fmt"
  20. "io"
  21. "net"
  22. "sync"
  23. "sync/atomic"
  24. "syscall"
  25. "time"
  26. )
  // maxPooledMessageSize is the largest message length, in bytes, that is read
  // into a pooled buffer. Longer messages get a dedicated allocation instead.
  // Message buffers are pooled to reduce GC pressure.
  const maxPooledMessageSize = 4 << 10 // 4 KiB.
  29. // messageBuffer wraps a fixed-size array to enable pooling without heap allocation.
  30. type messageBuffer struct {
  31. data [maxPooledMessageSize]byte
  32. }
  33. // messageBufferPool pools messageBuffer instances to reduce allocations.
  34. // nolint: gochecknoglobals // Pools are package-global for efficiency.
  35. var messageBufferPool = sync.Pool{
  36. New: func() any {
  37. return &messageBuffer{}
  38. },
  39. }
  // bufioReaderPool recycles *bufio.Reader values to reduce allocations.
  // Readers are created without an underlying source; callers must Reset them
  // onto their own source after taking one from the pool.
  // nolint: gochecknoglobals // Pools are package-global for efficiency.
  var bufioReaderPool = sync.Pool{
  	New: func() any {
  		detached := bufio.NewReader(nil)

  		return detached
  	},
  }
  // Sentinel errors exposed by this package. Compare against them with errors.Is.
  var (
  	// ErrInvalidLengthPrefix reports a malformed message length prefix.
  	ErrInvalidLengthPrefix = errors.New("invalid length prefix")

  	// ErrConnectionClosed reports that a connection is closed.
  	ErrConnectionClosed = errors.New("connection closed")

  	// ErrHandlerFailed reports a failure in a message handler.
  	ErrHandlerFailed = errors.New("handler failed")

  	// ErrHandlerNil is returned by NewReader when it is given a nil handler.
  	ErrHandlerNil = errors.New("handler cannot be nil")

  	// ErrMaxAcceptErrorsTooLarge reports a maxAcceptErrors setting above 63.
  	// NOTE(review): presumably returned by a reader option; confirm where.
  	ErrMaxAcceptErrorsTooLarge = errors.New("maxAcceptErrors must be <= 63 to prevent overflow")

  	// ErrInvalidSocketPath is wrapped by NewReader when the socket path is
  	// empty or longer than the supported maximum.
  	ErrInvalidSocketPath = errors.New("invalid socket path")
  )
  // MessageHandler is the callback invoked with the payload of each received
  // message. A non-nil return value marks the message as failed.
  //
  // The data slice is only valid for the duration of the call: it may be backed
  // by a pooled buffer that is handed back for reuse as soon as the handler
  // returns. A handler MUST NOT keep a reference to the slice; if it needs the
  // bytes afterwards it MUST copy them.
  type MessageHandler func(data []byte) error
  61. // Reader receives length-prefixed messages via Unix domain socket.
  62. // nolint: govet
  63. type Reader struct {
  64. handler MessageHandler
  65. onError ErrorCallback
  66. systemd *SystemdManager
  67. listener net.Listener
  68. socketPath string
  69. shutdownStart chan struct{} // Signals running→stopping transition.
  70. shutdownForced chan struct{} // Signals stopping→stopped forcefully.
  71. maxMessageSize uint64
  72. receivedCount uint64 // Successfully processed messages.
  73. connectionCount uint64 // Total connections accepted.
  74. errorCount uint64 // Handler or protocol errors.
  75. inactivityTimeout time.Duration
  76. wg sync.WaitGroup
  77. closeOnce sync.Once
  78. readBufferSize uint32 // Size of kernel read buffer (SO_RCVBUF).
  79. maxAcceptErrors int
  80. }
  81. // NewReader creates a new reader with optional systemd integration.
  82. // nolint: gocognit
  83. func NewReader(handler MessageHandler, fallbackSocketPath string, opts ...ReaderOption) (*Reader, error) {
  84. if handler == nil {
  85. return nil, ErrHandlerNil
  86. }
  87. if fallbackSocketPath == "" {
  88. return nil, fmt.Errorf("%w: empty path", ErrInvalidSocketPath)
  89. }
  90. // nolint: mnd // Default values.
  91. r := &Reader{
  92. handler: handler,
  93. maxMessageSize: 10 << 20, // 10MB.
  94. inactivityTimeout: 10 * time.Second,
  95. maxAcceptErrors: 10,
  96. readBufferSize: 256 << 10, // 256KB.
  97. shutdownStart: make(chan struct{}),
  98. shutdownForced: make(chan struct{}),
  99. }
  100. for _, opt := range opts {
  101. if err := opt(r); err != nil {
  102. return nil, fmt.Errorf("failed to apply option: %w", err)
  103. }
  104. }
  105. systemd, err := NewSystemdManager()
  106. if err != nil {
  107. return nil, fmt.Errorf("failed to set up systemd manager: %w", err)
  108. }
  109. r.systemd = systemd
  110. r.socketPath = ResolveSocketPath(systemd, fallbackSocketPath)
  111. if len(r.socketPath) > MaxSocketPathLength() {
  112. return nil, fmt.Errorf("%w: socket path too long: %s", ErrInvalidSocketPath, r.socketPath)
  113. }
  114. // Try to get systemd-provided listener first, falling back to creating one directly.
  115. r.listener = systemd.GetSystemdListener()
  116. if r.listener == nil {
  117. if err = EnsureSocketDir(r.socketPath); err != nil {
  118. return nil, fmt.Errorf("failed to create socket directory: %w", err)
  119. }
  120. if err = CleanupSocket(r.socketPath); err != nil {
  121. return nil, fmt.Errorf("failed to clean up previous socket: %w", err)
  122. }
  123. r.listener, err = net.Listen("unix", r.socketPath) // nolint: noctx
  124. if err != nil {
  125. return nil, fmt.Errorf("failed to listen on socket: %w", err)
  126. }
  127. }
  128. if r.readBufferSize > 0 {
  129. if unixListener, ok := r.listener.(*net.UnixListener); ok {
  130. // Set read buffer on the listening socket.
  131. if file, err := unixListener.File(); err == nil { //nolint: govet // Safely shadowed error.
  132. defer file.Close()
  133. fd := int(file.Fd())
  134. // Use syscall to set SO_RCVBUF on the listening socket.
  135. //
  136. // As per: https://www.man7.org/linux/man-pages/man7/unix.7.html,
  137. // setting SO_RCVBUF has no effect on streaming UDS sockets on Linux.
  138. // > The SO_SNDBUF socket option does have an effect for UNIX domain
  139. // > sockets, but the SO_RCVBUF option does not. For datagram sockets,
  140. // > the SO_SNDBUF value imposes an upper limit on the size of outgoing
  141. // > datagrams.
  142. //
  143. // As per: https://man.freebsd.org/cgi/man.cgi?setsockopt(2),
  144. // setting SO_RCVBUF does set the buffer size for input on BSD.
  145. // An assumption is made that other BSDs (and derivatives like Darwin)
  146. // will have the same behavior as FreeBSD.
  147. // > SO_SNDBUF and SO_RCVBUF are options to adjust the normal buffer sizes
  148. // > allocated for output and input buffers, respectively. The buffer size
  149. // > may be increased for high-volume connections, or may be decreased to
  150. // > limit the possible backlog of incoming data. The system places an ab-
  151. // > solute maximum on these values, which is accessible through the
  152. // > sysctl(3) MIB variable "kern.ipc.maxsockbuf".
  153. //
  154. // This syscall safely no-ops on Linux sockets, so no platform
  155. // detection logic or conditional calling is necessary.
  156. _ = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, int(r.readBufferSize))
  157. }
  158. }
  159. }
  160. return r, nil
  161. }
  162. // GetMetrics returns current counter values and connection info.
  163. func (r *Reader) GetMetrics() (uint64, uint64, uint64) {
  164. return atomic.LoadUint64(&r.receivedCount),
  165. atomic.LoadUint64(&r.connectionCount),
  166. atomic.LoadUint64(&r.errorCount)
  167. }
  168. // Start begins listening for connections.
  169. func (r *Reader) Start() error {
  170. if r.systemd.IsSystemd() {
  171. if err := r.systemd.NotifyReady(); err != nil {
  172. return fmt.Errorf("failed to notify systemd ready socket: %w", err)
  173. }
  174. }
  175. r.wg.Add(1)
  176. go r.run()
  177. return nil
  178. }
  179. // Stop shuts down the reader gracefully, allowing in-flight messages to complete
  180. // until the provided context is cancelled or expires. Subsequent calls return nil.
  181. func (r *Reader) Stop(ctx context.Context) error {
  182. var err error
  183. r.closeOnce.Do(func() {
  184. close(r.shutdownStart)
  185. // Unix domain socket Accept() doesn't seem to respect SetDeadline.
  186. // Force the blocked Accept() to return by connecting to ourselves.
  187. if r.listener != nil {
  188. go func() {
  189. //nolint: mnd // Brief delay to ensure r.shutdownStart channel is processed first.
  190. time.Sleep(10 * time.Millisecond)
  191. if conn, dialErr := net.Dial("unix", r.socketPath); dialErr == nil { // nolint: noctx
  192. _ = conn.Close()
  193. }
  194. }()
  195. }
  196. // Monitor context and abort drain if context is cancelled or expires.
  197. stopComplete := make(chan struct{})
  198. go func() {
  199. select {
  200. case <-ctx.Done():
  201. // Context cancelled or expired - force immediate shutdown.
  202. close(r.shutdownForced)
  203. case <-stopComplete:
  204. // Clean shutdown completed before context cancellation/expiration.
  205. }
  206. }()
  207. // Wait for all goroutines to finish before closing the listener.
  208. // This prevents a race condition where SetDeadline() is called
  209. // on an invalid file descriptor (as warned in os.File.Fd docs).
  210. r.wg.Wait()
  211. // Signal context monitor that we're done.
  212. close(stopComplete)
  213. if r.systemd.IsSystemd() {
  214. // r.systemd.Close will close the listener internally.
  215. // The file lifecycle of systemd managed sockets is handled by
  216. // systemd itself, so we don't have to remove the socket file.
  217. if systemdErr := r.systemd.Close(); systemdErr != nil {
  218. err = errors.Join(err, systemdErr)
  219. }
  220. } else {
  221. if r.listener != nil {
  222. err = r.listener.Close()
  223. }
  224. if cleanupErr := CleanupSocket(r.socketPath); cleanupErr != nil {
  225. err = errors.Join(err, cleanupErr)
  226. }
  227. }
  228. })
  229. return err
  230. }
  231. // run is the main accept loop.
  232. func (r *Reader) run() {
  233. defer r.wg.Done()
  234. consecutiveErrors := 0
  235. for {
  236. conn, err := r.listener.Accept()
  237. if err != nil {
  238. select {
  239. case <-r.shutdownStart:
  240. return
  241. default:
  242. consecutiveErrors++
  243. if consecutiveErrors > r.maxAcceptErrors {
  244. if r.onError != nil {
  245. r.onError(err, "too many consecutive failures in accept loop")
  246. }
  247. return
  248. }
  249. // nolint: mnd // Fixed 100ms sleep to prevent busy looping on Accept errors
  250. time.Sleep(100 * time.Millisecond)
  251. }
  252. continue
  253. }
  254. // Reset error count on successful accept.
  255. consecutiveErrors = 0
  256. atomic.AddUint64(&r.connectionCount, 1)
  257. // Check for shutdown after successful accept as well.
  258. select {
  259. case <-r.shutdownStart:
  260. _ = conn.Close()
  261. return
  262. default:
  263. }
  264. r.wg.Add(1)
  265. go r.handleConnection(conn)
  266. }
  267. }
  268. // handleConnection processes length-prefixed messages from a connection.
  269. // nolint: gocognit,funlen
  270. func (r *Reader) handleConnection(conn net.Conn) {
  271. defer r.wg.Done()
  272. defer conn.Close() // nolint: errcheck // Nothing to do with this error.
  273. if r.readBufferSize > 0 {
  274. if unixConn, ok := conn.(*net.UnixConn); ok {
  275. // Optimize read buffer for this connection.
  276. _ = unixConn.SetReadBuffer(int(r.readBufferSize))
  277. }
  278. }
  279. // Get pooled bufio.Reader and reset it for this connection.
  280. reader, _ := bufioReaderPool.Get().(*bufio.Reader)
  281. reader.Reset(conn)
  282. defer bufioReaderPool.Put(reader)
  283. draining := false
  284. for {
  285. select {
  286. case <-r.shutdownStart:
  287. draining = true
  288. case <-r.shutdownForced:
  289. // Forced shutdown - exit immediately without processing further messages.
  290. // IMPORTANT: This cannot interrupt an already-executing handler. If the handler
  291. // is blocking (e.g., in time.Sleep or blocking I/O), this goroutine will wait
  292. // for it to complete before returning. An open connection that continues to
  293. // write messages while the handler is blocked will cause this goroutine to
  294. // remain blocked indefinitely until the handler completes or the connection
  295. // closes. To prevent this, ensure handlers are responsive and don't block
  296. // indefinitely, or ensure clients close connections promptly on shutdown.
  297. return
  298. default:
  299. }
  300. // Set read deadline based on shutdown state.
  301. // Potential errors are ignored because there is nothing to do with them.
  302. var deadline time.Time
  303. if !draining {
  304. // Normal operation - set inactivity timeout to close idle connections.
  305. deadline = time.Now().Add(r.inactivityTimeout)
  306. } else {
  307. // Draining - add a short inactivity timeout to allow continued reading of
  308. // data from the socket while draining (using time.Now() is too fast).
  309. // nolint: mnd
  310. deadline = time.Now().Add(time.Millisecond)
  311. }
  312. _ = conn.SetReadDeadline(deadline)
  313. length, err := binary.ReadUvarint(reader)
  314. if err != nil {
  315. if errors.Is(err, io.EOF) {
  316. // Client closed the connection.
  317. return
  318. }
  319. var netErr net.Error
  320. if errors.As(err, &netErr) && netErr.Timeout() {
  321. // Close idle connections after inactivity timeout.
  322. return
  323. }
  324. // Neither the client closing the connection nor closing idle
  325. // connections after they reach timeout should incremenet the
  326. // error count, so only do it once those checks have happened.
  327. atomic.AddUint64(&r.errorCount, 1)
  328. if r.onError != nil {
  329. r.onError(err, "failed to read the length prefix")
  330. }
  331. return
  332. }
  333. if length > r.maxMessageSize {
  334. atomic.AddUint64(&r.errorCount, 1)
  335. if r.onError != nil {
  336. r.onError(fmt.Errorf("invalid message size: %d: exceeds limit: %d", length, r.maxMessageSize), "message too large")
  337. }
  338. return
  339. }
  340. if length < 1 {
  341. // Empty messages aren't sent by this package's writer, but
  342. // if we receive one, there is no need to treat it as an error.
  343. // There is nothing else to be done with this message.
  344. continue
  345. }
  346. // Read message bytes with known length.
  347. var message []byte
  348. var msgBuf *messageBuffer
  349. if length <= maxPooledMessageSize {
  350. // Use pooled buffer for small messages.
  351. msgBuf, _ = messageBufferPool.Get().(*messageBuffer)
  352. message = msgBuf.data[:length]
  353. } else {
  354. // Fall back to heap allocation for large messages.
  355. message = make([]byte, length)
  356. }
  357. if _, err := io.ReadFull(reader, message); err != nil { //nolint: govet // Safely shadowed error.
  358. if msgBuf != nil {
  359. messageBufferPool.Put(msgBuf)
  360. }
  361. atomic.AddUint64(&r.errorCount, 1)
  362. if r.onError != nil {
  363. r.onError(err, "failed to read the complete message")
  364. }
  365. return
  366. }
  367. if err := r.handler(message); err != nil { //nolint: govet // Safely shadowed error.
  368. if msgBuf != nil {
  369. messageBufferPool.Put(msgBuf)
  370. }
  371. atomic.AddUint64(&r.errorCount, 1)
  372. if r.onError != nil {
  373. r.onError(err, "handler failed to process a message")
  374. }
  375. // Don't close connection for handler errors on individual messages.
  376. continue
  377. }
  378. // Return buffer to pool after successful processing.
  379. if msgBuf != nil {
  380. messageBufferPool.Put(msgBuf)
  381. }
  382. atomic.AddUint64(&r.receivedCount, 1)
  383. }
  384. }