| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- // Package mux multiplexes packets on a single socket (RFC7983)
- package mux
- import (
- "errors"
- "io"
- "net"
- "sync"
- "github.com/pion/ice/v2"
- "github.com/pion/logging"
- "github.com/pion/transport/v2/packetio"
- )
// maxBufferSize is the per-Endpoint cap, in bytes, on data held in that
// Endpoint's packetio.Buffer (1,000,000 bytes, i.e. 1 MB). When the cap is
// reached, further writes to the buffer fail with an error instead of growing it.
const maxBufferSize = 1_000_000
- // Config collects the arguments to mux.Mux construction into
- // a single structure
- type Config struct {
- Conn net.Conn
- BufferSize int
- LoggerFactory logging.LoggerFactory
- }
- // Mux allows multiplexing
- type Mux struct {
- lock sync.RWMutex
- nextConn net.Conn
- endpoints map[*Endpoint]MatchFunc
- bufferSize int
- closedCh chan struct{}
- log logging.LeveledLogger
- }
- // NewMux creates a new Mux
- func NewMux(config Config) *Mux {
- m := &Mux{
- nextConn: config.Conn,
- endpoints: make(map[*Endpoint]MatchFunc),
- bufferSize: config.BufferSize,
- closedCh: make(chan struct{}),
- log: config.LoggerFactory.NewLogger("mux"),
- }
- go m.readLoop()
- return m
- }
- // NewEndpoint creates a new Endpoint
- func (m *Mux) NewEndpoint(f MatchFunc) *Endpoint {
- e := &Endpoint{
- mux: m,
- buffer: packetio.NewBuffer(),
- }
- // Set a maximum size of the buffer in bytes.
- e.buffer.SetLimitSize(maxBufferSize)
- m.lock.Lock()
- m.endpoints[e] = f
- m.lock.Unlock()
- return e
- }
- // RemoveEndpoint removes an endpoint from the Mux
- func (m *Mux) RemoveEndpoint(e *Endpoint) {
- m.lock.Lock()
- defer m.lock.Unlock()
- delete(m.endpoints, e)
- }
- // Close closes the Mux and all associated Endpoints.
- func (m *Mux) Close() error {
- m.lock.Lock()
- for e := range m.endpoints {
- if err := e.close(); err != nil {
- m.lock.Unlock()
- return err
- }
- delete(m.endpoints, e)
- }
- m.lock.Unlock()
- err := m.nextConn.Close()
- if err != nil {
- return err
- }
- // Wait for readLoop to end
- <-m.closedCh
- return nil
- }
// readLoop reads from nextConn into a single reusable buffer and dispatches
// each successful read. It ends the loop as follows:
//   - io.EOF and ice.ErrClosed end it silently.
//   - io.ErrShortBuffer and packetio.ErrTimeout are logged and the read is
//     retried.
//   - Any other read error, or any dispatch error, is logged and ends it.
//
// closedCh is closed when the loop exits.
func (m *Mux) readLoop() {
	defer func() {
		close(m.closedCh)
	}()

	scratch := make([]byte, m.bufferSize)
	for {
		n, err := m.nextConn.Read(scratch)
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, ice.ErrClosed) {
				return
			}
			if errors.Is(err, io.ErrShortBuffer) || errors.Is(err, packetio.ErrTimeout) {
				m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error())
				continue
			}
			m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error())
			return
		}

		if dispatchErr := m.dispatch(scratch[:n]); dispatchErr != nil {
			m.log.Errorf("mux: ending readLoop dispatch error %s", dispatchErr.Error())
			return
		}
	}
}
// dispatch writes one packet into the buffer of an Endpoint whose MatchFunc
// accepts it. Endpoints are kept in a map, so if several MatchFuncs accept the
// same packet, which Endpoint receives it is unspecified.
//
// The packet is logged and dropped in two cases: no Endpoint matches it, or
// the matching Endpoint's buffer is full (packetio.ErrFull). Any other buffer
// write error is returned.
func (m *Mux) dispatch(buf []byte) error {
	var target *Endpoint

	m.lock.Lock()
	for candidate, matches := range m.endpoints {
		if !matches(buf) {
			continue
		}
		target = candidate
		break
	}
	m.lock.Unlock()

	if target == nil {
		if len(buf) == 0 {
			m.log.Warnf("Warning: mux: no endpoint for zero length packet")
		} else {
			m.log.Warnf("Warning: mux: no endpoint for packet starting with %d", buf[0])
		}
		return nil
	}

	if _, err := target.buffer.Write(buf); err != nil {
		// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
		if errors.Is(err, packetio.ErrFull) {
			m.log.Infof("mux: endpoint buffer is full, dropping packet")
			return nil
		}
		return err
	}
	return nil
}
|