| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
- /*
- Copyright 2025 Psiphon Inc.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package udsipc
- import (
- "bufio"
- "context"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "syscall"
- "time"
- )
// maxPooledMessageSize is the largest message, in bytes, that is read into a
// pooled buffer. Longer messages get a fresh allocation instead.
const maxPooledMessageSize = 4096

// messageBuffer is the unit stored in messageBufferPool. The pool holds
// *messageBuffer values (a pointer to a fixed-size array) rather than []byte,
// so returning a buffer to the pool does not box a slice header into the
// pool's interface-typed storage.
type messageBuffer struct {
	data [maxPooledMessageSize]byte
}

// Package-level pools that let connection handlers reuse message buffers and
// buffered readers instead of allocating them for every message/connection.
// nolint: gochecknoglobals // Pools are package-global for efficiency.
var (
	messageBufferPool = sync.Pool{
		New: func() any { return new(messageBuffer) },
	}
	bufioReaderPool = sync.Pool{
		New: func() any { return bufio.NewReader(nil) },
	}
)
// Sentinel errors for this package. NewReader wraps ErrInvalidSocketPath with
// extra context, so compare against these with errors.Is rather than ==.
var (
	ErrInvalidLengthPrefix = errors.New("invalid length prefix")
	ErrConnectionClosed    = errors.New("connection closed")
	ErrHandlerFailed       = errors.New("handler failed")

	// ErrHandlerNil is returned by NewReader when handler is nil.
	ErrHandlerNil = errors.New("handler cannot be nil")

	// ErrMaxAcceptErrorsTooLarge rejects a maxAcceptErrors value above 63.
	// NOTE(review): presumably returned by the option that sets
	// maxAcceptErrors (not visible here) — confirm.
	ErrMaxAcceptErrorsTooLarge = errors.New("maxAcceptErrors must be <= 63 to prevent overflow")

	// ErrInvalidSocketPath is wrapped by NewReader when the fallback socket
	// path is empty or the resolved path exceeds MaxSocketPathLength.
	ErrInvalidSocketPath = errors.New("invalid socket path")
)

// MessageHandler is invoked by a Reader once for every non-empty message it
// receives. Returning a non-nil error counts and reports the failure, but the
// connection stays open for further messages.
//
// The data slice may be backed by a pooled buffer that is reused as soon as
// the handler returns. Handlers MUST NOT retain it; copy the bytes if they are
// needed after returning.
type MessageHandler func(data []byte) error
- // Reader receives length-prefixed messages via Unix domain socket.
- // nolint: govet
- type Reader struct {
- handler MessageHandler
- onError ErrorCallback
- systemd *SystemdManager
- listener net.Listener
- socketPath string
- shutdownStart chan struct{} // Signals running→stopping transition.
- shutdownForced chan struct{} // Signals stopping→stopped forcefully.
- maxMessageSize uint64
- receivedCount uint64 // Successfully processed messages.
- connectionCount uint64 // Total connections accepted.
- errorCount uint64 // Handler or protocol errors.
- inactivityTimeout time.Duration
- wg sync.WaitGroup
- closeOnce sync.Once
- readBufferSize uint32 // Size of kernel read buffer (SO_RCVBUF).
- maxAcceptErrors int
- }
- // NewReader creates a new reader with optional systemd integration.
- // nolint: gocognit
- func NewReader(handler MessageHandler, fallbackSocketPath string, opts ...ReaderOption) (*Reader, error) {
- if handler == nil {
- return nil, ErrHandlerNil
- }
- if fallbackSocketPath == "" {
- return nil, fmt.Errorf("%w: empty path", ErrInvalidSocketPath)
- }
- // nolint: mnd // Default values.
- r := &Reader{
- handler: handler,
- maxMessageSize: 10 << 20, // 10MB.
- inactivityTimeout: 10 * time.Second,
- maxAcceptErrors: 10,
- readBufferSize: 256 << 10, // 256KB.
- shutdownStart: make(chan struct{}),
- shutdownForced: make(chan struct{}),
- }
- for _, opt := range opts {
- if err := opt(r); err != nil {
- return nil, fmt.Errorf("failed to apply option: %w", err)
- }
- }
- systemd, err := NewSystemdManager()
- if err != nil {
- return nil, fmt.Errorf("failed to set up systemd manager: %w", err)
- }
- r.systemd = systemd
- r.socketPath = ResolveSocketPath(systemd, fallbackSocketPath)
- if len(r.socketPath) > MaxSocketPathLength() {
- return nil, fmt.Errorf("%w: socket path too long: %s", ErrInvalidSocketPath, r.socketPath)
- }
- // Try to get systemd-provided listener first, falling back to creating one directly.
- r.listener = systemd.GetSystemdListener()
- if r.listener == nil {
- if err = EnsureSocketDir(r.socketPath); err != nil {
- return nil, fmt.Errorf("failed to create socket directory: %w", err)
- }
- if err = CleanupSocket(r.socketPath); err != nil {
- return nil, fmt.Errorf("failed to clean up previous socket: %w", err)
- }
- r.listener, err = net.Listen("unix", r.socketPath) // nolint: noctx
- if err != nil {
- return nil, fmt.Errorf("failed to listen on socket: %w", err)
- }
- }
- if r.readBufferSize > 0 {
- if unixListener, ok := r.listener.(*net.UnixListener); ok {
- // Set read buffer on the listening socket.
- if file, err := unixListener.File(); err == nil { //nolint: govet // Safely shadowed error.
- defer file.Close()
- fd := int(file.Fd())
- // Use syscall to set SO_RCVBUF on the listening socket.
- //
- // As per: https://www.man7.org/linux/man-pages/man7/unix.7.html,
- // setting SO_RCVBUF has no effect on streaming UDS sockets on Linux.
- // > The SO_SNDBUF socket option does have an effect for UNIX domain
- // > sockets, but the SO_RCVBUF option does not. For datagram sockets,
- // > the SO_SNDBUF value imposes an upper limit on the size of outgoing
- // > datagrams.
- //
- // As per: https://man.freebsd.org/cgi/man.cgi?setsockopt(2),
- // setting SO_RCVBUF does set the buffer size for input on BSD.
- // An assumption is made that other BSDs (and derivatives like Darwin)
- // will have the same behavior as FreeBSD.
- // > SO_SNDBUF and SO_RCVBUF are options to adjust the normal buffer sizes
- // > allocated for output and input buffers, respectively. The buffer size
- // > may be increased for high-volume connections, or may be decreased to
- // > limit the possible backlog of incoming data. The system places an ab-
- // > solute maximum on these values, which is accessible through the
- // > sysctl(3) MIB variable "kern.ipc.maxsockbuf".
- //
- // This syscall safely no-ops on Linux sockets, so no platform
- // detection logic or conditional calling is necessary.
- _ = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, int(r.readBufferSize))
- }
- }
- }
- return r, nil
- }
- // GetMetrics returns current counter values and connection info.
- func (r *Reader) GetMetrics() (uint64, uint64, uint64) {
- return atomic.LoadUint64(&r.receivedCount),
- atomic.LoadUint64(&r.connectionCount),
- atomic.LoadUint64(&r.errorCount)
- }
- // Start begins listening for connections.
- func (r *Reader) Start() error {
- if r.systemd.IsSystemd() {
- if err := r.systemd.NotifyReady(); err != nil {
- return fmt.Errorf("failed to notify systemd ready socket: %w", err)
- }
- }
- r.wg.Add(1)
- go r.run()
- return nil
- }
- // Stop shuts down the reader gracefully, allowing in-flight messages to complete
- // until the provided context is cancelled or expires. Subsequent calls return nil.
- func (r *Reader) Stop(ctx context.Context) error {
- var err error
- r.closeOnce.Do(func() {
- close(r.shutdownStart)
- // Unix domain socket Accept() doesn't seem to respect SetDeadline.
- // Force the blocked Accept() to return by connecting to ourselves.
- if r.listener != nil {
- go func() {
- //nolint: mnd // Brief delay to ensure r.shutdownStart channel is processed first.
- time.Sleep(10 * time.Millisecond)
- if conn, dialErr := net.Dial("unix", r.socketPath); dialErr == nil { // nolint: noctx
- _ = conn.Close()
- }
- }()
- }
- // Monitor context and abort drain if context is cancelled or expires.
- stopComplete := make(chan struct{})
- go func() {
- select {
- case <-ctx.Done():
- // Context cancelled or expired - force immediate shutdown.
- close(r.shutdownForced)
- case <-stopComplete:
- // Clean shutdown completed before context cancellation/expiration.
- }
- }()
- // Wait for all goroutines to finish before closing the listener.
- // This prevents a race condition where SetDeadline() is called
- // on an invalid file descriptor (as warned in os.File.Fd docs).
- r.wg.Wait()
- // Signal context monitor that we're done.
- close(stopComplete)
- if r.systemd.IsSystemd() {
- // r.systemd.Close will close the listener internally.
- // The file lifecycle of systemd managed sockets is handled by
- // systemd itself, so we don't have to remove the socket file.
- if systemdErr := r.systemd.Close(); systemdErr != nil {
- err = errors.Join(err, systemdErr)
- }
- } else {
- if r.listener != nil {
- err = r.listener.Close()
- }
- if cleanupErr := CleanupSocket(r.socketPath); cleanupErr != nil {
- err = errors.Join(err, cleanupErr)
- }
- }
- })
- return err
- }
- // run is the main accept loop.
- func (r *Reader) run() {
- defer r.wg.Done()
- consecutiveErrors := 0
- for {
- conn, err := r.listener.Accept()
- if err != nil {
- select {
- case <-r.shutdownStart:
- return
- default:
- consecutiveErrors++
- if consecutiveErrors > r.maxAcceptErrors {
- if r.onError != nil {
- r.onError(err, "too many consecutive failures in accept loop")
- }
- return
- }
- // nolint: mnd // Fixed 100ms sleep to prevent busy looping on Accept errors
- time.Sleep(100 * time.Millisecond)
- }
- continue
- }
- // Reset error count on successful accept.
- consecutiveErrors = 0
- atomic.AddUint64(&r.connectionCount, 1)
- // Check for shutdown after successful accept as well.
- select {
- case <-r.shutdownStart:
- _ = conn.Close()
- return
- default:
- }
- r.wg.Add(1)
- go r.handleConnection(conn)
- }
- }
- // handleConnection processes length-prefixed messages from a connection.
- // nolint: gocognit,funlen
- func (r *Reader) handleConnection(conn net.Conn) {
- defer r.wg.Done()
- defer conn.Close() // nolint: errcheck // Nothing to do with this error.
- if r.readBufferSize > 0 {
- if unixConn, ok := conn.(*net.UnixConn); ok {
- // Optimize read buffer for this connection.
- _ = unixConn.SetReadBuffer(int(r.readBufferSize))
- }
- }
- // Get pooled bufio.Reader and reset it for this connection.
- reader, _ := bufioReaderPool.Get().(*bufio.Reader)
- reader.Reset(conn)
- defer bufioReaderPool.Put(reader)
- draining := false
- for {
- select {
- case <-r.shutdownStart:
- draining = true
- case <-r.shutdownForced:
- // Forced shutdown - exit immediately without processing further messages.
- // IMPORTANT: This cannot interrupt an already-executing handler. If the handler
- // is blocking (e.g., in time.Sleep or blocking I/O), this goroutine will wait
- // for it to complete before returning. An open connection that continues to
- // write messages while the handler is blocked will cause this goroutine to
- // remain blocked indefinitely until the handler completes or the connection
- // closes. To prevent this, ensure handlers are responsive and don't block
- // indefinitely, or ensure clients close connections promptly on shutdown.
- return
- default:
- }
- // Set read deadline based on shutdown state.
- // Potential errors are ignored because there is nothing to do with them.
- var deadline time.Time
- if !draining {
- // Normal operation - set inactivity timeout to close idle connections.
- deadline = time.Now().Add(r.inactivityTimeout)
- } else {
- // Draining - add a short inactivity timeout to allow continued reading of
- // data from the socket while draining (using time.Now() is too fast).
- // nolint: mnd
- deadline = time.Now().Add(time.Millisecond)
- }
- _ = conn.SetReadDeadline(deadline)
- length, err := binary.ReadUvarint(reader)
- if err != nil {
- if errors.Is(err, io.EOF) {
- // Client closed the connection.
- return
- }
- var netErr net.Error
- if errors.As(err, &netErr) && netErr.Timeout() {
- // Close idle connections after inactivity timeout.
- return
- }
- // Neither the client closing the connection nor closing idle
- // connections after they reach timeout should incremenet the
- // error count, so only do it once those checks have happened.
- atomic.AddUint64(&r.errorCount, 1)
- if r.onError != nil {
- r.onError(err, "failed to read the length prefix")
- }
- return
- }
- if length > r.maxMessageSize {
- atomic.AddUint64(&r.errorCount, 1)
- if r.onError != nil {
- r.onError(fmt.Errorf("invalid message size: %d: exceeds limit: %d", length, r.maxMessageSize), "message too large")
- }
- return
- }
- if length < 1 {
- // Empty messages aren't sent by this package's writer, but
- // if we receive one, there is no need to treat it as an error.
- // There is nothing else to be done with this message.
- continue
- }
- // Read message bytes with known length.
- var message []byte
- var msgBuf *messageBuffer
- if length <= maxPooledMessageSize {
- // Use pooled buffer for small messages.
- msgBuf, _ = messageBufferPool.Get().(*messageBuffer)
- message = msgBuf.data[:length]
- } else {
- // Fall back to heap allocation for large messages.
- message = make([]byte, length)
- }
- if _, err := io.ReadFull(reader, message); err != nil { //nolint: govet // Safely shadowed error.
- if msgBuf != nil {
- messageBufferPool.Put(msgBuf)
- }
- atomic.AddUint64(&r.errorCount, 1)
- if r.onError != nil {
- r.onError(err, "failed to read the complete message")
- }
- return
- }
- if err := r.handler(message); err != nil { //nolint: govet // Safely shadowed error.
- if msgBuf != nil {
- messageBufferPool.Put(msgBuf)
- }
- atomic.AddUint64(&r.errorCount, 1)
- if r.onError != nil {
- r.onError(err, "handler failed to process a message")
- }
- // Don't close connection for handler errors on individual messages.
- continue
- }
- // Return buffer to pool after successful processing.
- if msgBuf != nil {
- messageBufferPool.Put(msgBuf)
- }
- atomic.AddUint64(&r.receivedCount, 1)
- }
- }
|