|
|
@@ -28,76 +28,25 @@ import (
|
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
|
)
|
|
|
|
|
|
-// udpConnID represents a UDP connection identifier
|
|
|
-type udpConnID struct {
|
|
|
- src net.Destination
|
|
|
- dest net.Destination
|
|
|
-}
|
|
|
-
|
|
|
-// udpConn represents a UDP connection for packet handling
|
|
|
type udpConn struct {
|
|
|
- lastActivityTime int64 // in seconds
|
|
|
- reader buf.Reader
|
|
|
- writer buf.Writer
|
|
|
- output func([]byte, net.Destination) (int, error)
|
|
|
- remote net.Addr
|
|
|
- local net.Addr
|
|
|
- done *done.Instance
|
|
|
- inactive bool
|
|
|
- cancel context.CancelFunc
|
|
|
-}
|
|
|
-
|
|
|
-func (c *udpConn) updateActivity() {
|
|
|
- atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
|
|
|
-}
|
|
|
-
|
|
|
-func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
|
|
- mb, err := c.reader.ReadMultiBuffer()
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- c.updateActivity()
|
|
|
- return mb, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (c *udpConn) Write(data []byte) (int, error) {
|
|
|
- n, err := c.output(data, net.Destination{})
|
|
|
- if err == nil {
|
|
|
- c.updateActivity()
|
|
|
- }
|
|
|
- return n, err
|
|
|
+ lastActive int64
|
|
|
+ reader buf.Reader
|
|
|
+ writer buf.Writer
|
|
|
+ done *done.Instance
|
|
|
+ cancel context.CancelFunc
|
|
|
}
|
|
|
|
|
|
-func (c *udpConn) Close() error {
|
|
|
- if c.cancel != nil {
|
|
|
- c.cancel()
|
|
|
- }
|
|
|
- common.Must(c.done.Close())
|
|
|
- common.Must(common.Close(c.writer))
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (c *udpConn) RemoteAddr() net.Addr { return c.remote }
|
|
|
-func (c *udpConn) LocalAddr() net.Addr { return c.local }
|
|
|
-func (c *udpConn) Read([]byte) (int, error) { return 0, errors.New("not supported") }
|
|
|
-func (*udpConn) SetDeadline(time.Time) error { return nil }
|
|
|
-func (*udpConn) SetReadDeadline(time.Time) error { return nil }
|
|
|
-func (*udpConn) SetWriteDeadline(time.Time) error { return nil }
|
|
|
-
|
|
|
// Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
|
|
|
type Handler struct {
|
|
|
- sync.RWMutex
|
|
|
-
|
|
|
+ sync.Mutex
|
|
|
ctx context.Context
|
|
|
config *Config
|
|
|
stack Stack
|
|
|
policyManager policy.Manager
|
|
|
dispatcher routing.Dispatcher
|
|
|
cone bool
|
|
|
-
|
|
|
- // UDP connection management
|
|
|
- udpConns map[udpConnID]*udpConn
|
|
|
- udpChecker *task.Periodic
|
|
|
+ udpConns map[net.Destination]*udpConn
|
|
|
+ udpChecker *task.Periodic
|
|
|
}
|
|
|
|
|
|
// ConnectionHandler interface with the only method that stack is going to push new connections to
|
|
|
@@ -109,177 +58,123 @@ type ConnectionHandler interface {
|
|
|
var _ ConnectionHandler = (*Handler)(nil)
|
|
|
|
|
|
func (t *Handler) policy() policy.Session {
|
|
|
- p := t.policyManager.ForLevel(t.config.UserLevel)
|
|
|
- return p
|
|
|
-}
|
|
|
-
|
|
|
-// getUDPConn gets or creates a UDP connection for the given source and destination
|
|
|
-func (t *Handler) getUDPConn(source, dest net.Destination, ipStack *stack.Stack) (*udpConn, bool) {
|
|
|
- t.Lock()
|
|
|
- defer t.Unlock()
|
|
|
-
|
|
|
- id := udpConnID{
|
|
|
- src: source,
|
|
|
- }
|
|
|
- if !t.cone {
|
|
|
- id.dest = dest
|
|
|
- }
|
|
|
-
|
|
|
- if conn, found := t.udpConns[id]; found && !conn.done.Done() {
|
|
|
- conn.updateActivity()
|
|
|
- return conn, true
|
|
|
- }
|
|
|
-
|
|
|
- pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
|
|
|
- conn := &udpConn{
|
|
|
- reader: pReader,
|
|
|
- writer: pWriter,
|
|
|
- output: func(data []byte, returnDest net.Destination) (int, error) {
|
|
|
- // Write UDP packet back to the stack with proper source address
|
|
|
- return t.writeUDPPacket(ipStack, data, returnDest, source)
|
|
|
- },
|
|
|
- remote: &net.UDPAddr{
|
|
|
- IP: source.Address.IP(),
|
|
|
- Port: int(source.Port),
|
|
|
- },
|
|
|
- local: &net.UDPAddr{
|
|
|
- IP: dest.Address.IP(),
|
|
|
- Port: int(dest.Port),
|
|
|
- },
|
|
|
- done: done.New(),
|
|
|
- }
|
|
|
-
|
|
|
- t.udpConns[id] = conn
|
|
|
-
|
|
|
- conn.updateActivity()
|
|
|
- return conn, false
|
|
|
-}
|
|
|
-
|
|
|
-// removeUDPConn removes a UDP connection
|
|
|
-func (t *Handler) removeUDPConn(id udpConnID) {
|
|
|
- t.Lock()
|
|
|
- delete(t.udpConns, id)
|
|
|
- t.Unlock()
|
|
|
+ return t.policyManager.ForLevel(t.config.UserLevel)
|
|
|
}
|
|
|
|
|
|
-// cleanupUDPConns removes inactive UDP connections
|
|
|
-func (t *Handler) cleanupUDPConns() error {
|
|
|
+func (t *Handler) cleanupUDP() error {
|
|
|
t.Lock()
|
|
|
defer t.Unlock()
|
|
|
-
|
|
|
if len(t.udpConns) == 0 {
|
|
|
- return errors.New("no active connections")
|
|
|
+ return errors.New("no connections")
|
|
|
}
|
|
|
-
|
|
|
- nowSec := time.Now().Unix()
|
|
|
- for id, conn := range t.udpConns {
|
|
|
- if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 && !conn.inactive {
|
|
|
- conn.inactive = true
|
|
|
- conn.Close()
|
|
|
- delete(t.udpConns, id)
|
|
|
+ now := time.Now().Unix()
|
|
|
+ for src, conn := range t.udpConns {
|
|
|
+ if now-atomic.LoadInt64(&conn.lastActive) > 300 {
|
|
|
+ conn.cancel()
|
|
|
+ common.Must(conn.done.Close())
|
|
|
+ common.Must(common.Close(conn.writer))
|
|
|
+ delete(t.udpConns, src)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
|
|
|
-func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
|
|
|
- netProto := header.IPv4ProtocolNumber
|
|
|
- if !dest.Address.Family().IsIPv4() {
|
|
|
- netProto = header.IPv6ProtocolNumber
|
|
|
- }
|
|
|
-
|
|
|
- route, err := ipStack.FindRoute(defaultNIC, tcpip.AddrFromSlice(dest.Address.IP()), tcpip.AddrFromSlice(source.Address.IP()), netProto, false)
|
|
|
- if err != nil {
|
|
|
- return 0, errors.New("failed to find route: " + err.String())
|
|
|
- }
|
|
|
- defer route.Release()
|
|
|
-
|
|
|
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
|
- ReserveHeaderBytes: header.UDPMinimumSize,
|
|
|
- Payload: buffer.MakeWithData(data),
|
|
|
- })
|
|
|
- defer pkt.DecRef()
|
|
|
-
|
|
|
- length := uint16(pkt.Size())
|
|
|
- udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
|
|
- udpHeader.Encode(&header.UDPFields{
|
|
|
- SrcPort: uint16(dest.Port),
|
|
|
- DstPort: uint16(source.Port),
|
|
|
- Length: length,
|
|
|
- })
|
|
|
-
|
|
|
- xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length)
|
|
|
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(data, xsum)))
|
|
|
-
|
|
|
- if err := route.WritePacket(stack.NetworkHeaderParams{
|
|
|
- Protocol: header.UDPProtocolNumber,
|
|
|
- TTL: 64,
|
|
|
- TOS: 0,
|
|
|
- }, pkt); err != nil {
|
|
|
- return 0, errors.New("failed to write packet: " + err.String())
|
|
|
- }
|
|
|
-
|
|
|
- return len(data), nil
|
|
|
-}
|
|
|
-
|
|
|
-// HandleUDPPacket processes a raw UDP packet from gVisor
|
|
|
func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
|
|
|
- source := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
|
|
+ src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
|
|
dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
|
|
|
-
|
|
|
data := pkt.Data().AsRange().ToSlice()
|
|
|
if len(data) == 0 {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- conn, existing := t.getUDPConn(source, dest, ipStack)
|
|
|
-
|
|
|
- b := buf.New()
|
|
|
- b.Write(data)
|
|
|
- b.UDP = &dest
|
|
|
- conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
|
|
|
-
|
|
|
- if !existing {
|
|
|
- t.Lock()
|
|
|
+ t.Lock()
|
|
|
+ conn, found := t.udpConns[src]
|
|
|
+ if !found {
|
|
|
+ reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
|
|
|
+ conn = &udpConn{reader: reader, writer: writer, done: done.New()}
|
|
|
+ t.udpConns[src] = conn
|
|
|
if t.udpChecker != nil && len(t.udpConns) == 1 {
|
|
|
common.Must(t.udpChecker.Start())
|
|
|
}
|
|
|
t.Unlock()
|
|
|
|
|
|
- go t.handleUDPConn(conn, source, dest)
|
|
|
+ go func() {
|
|
|
+ ctx, cancel := context.WithCancel(t.ctx)
|
|
|
+ conn.cancel = cancel
|
|
|
+ defer func() {
|
|
|
+ cancel()
|
|
|
+ t.Lock()
|
|
|
+ delete(t.udpConns, src)
|
|
|
+ t.Unlock()
|
|
|
+ common.Must(conn.done.Close())
|
|
|
+ common.Must(common.Close(conn.writer))
|
|
|
+ }()
|
|
|
+
|
|
|
+ ctx = c.ContextWithID(ctx, session.NewID())
|
|
|
+ ctx = session.ContextWithInbound(ctx, &session.Inbound{
|
|
|
+ Name: "tun", Source: src,
|
|
|
+ User: &protocol.MemoryUser{Level: t.config.UserLevel},
|
|
|
+ })
|
|
|
+
|
|
|
+ t.dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
|
|
+ Reader: conn.reader,
|
|
|
+ Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
|
|
|
+ })
|
|
|
+ }()
|
|
|
+ } else {
|
|
|
+ atomic.StoreInt64(&conn.lastActive, time.Now().Unix())
|
|
|
+ t.Unlock()
|
|
|
}
|
|
|
+
|
|
|
+ b := buf.New()
|
|
|
+ b.Write(data)
|
|
|
+ b.UDP = &dest
|
|
|
+ conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
|
|
|
}
|
|
|
|
|
|
-func (t *Handler) handleUDPConn(conn *udpConn, source, dest net.Destination) {
|
|
|
- connID := udpConnID{src: source}
|
|
|
- if !t.cone {
|
|
|
- connID.dest = dest
|
|
|
- }
|
|
|
-
|
|
|
- ctx, cancel := context.WithCancel(t.ctx)
|
|
|
- conn.cancel = cancel
|
|
|
- ctx = c.ContextWithID(ctx, session.NewID())
|
|
|
- ctx = session.ContextWithInbound(ctx, &session.Inbound{
|
|
|
- Name: "tun",
|
|
|
- Source: source,
|
|
|
- User: &protocol.MemoryUser{Level: t.config.UserLevel},
|
|
|
- })
|
|
|
- ctx = session.SubContextFromMuxInbound(ctx)
|
|
|
-
|
|
|
- if err := t.dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
|
|
- Reader: conn.reader,
|
|
|
- Writer: buf.NewWriter(conn),
|
|
|
- }); err != nil {
|
|
|
- errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
|
|
|
- }
|
|
|
-
|
|
|
- conn.Close()
|
|
|
- if !conn.inactive {
|
|
|
- conn.inactive = true
|
|
|
- t.removeUDPConn(connID)
|
|
|
+type udpWriter struct {
|
|
|
+ stack *stack.Stack
|
|
|
+ src net.Destination
|
|
|
+ dest net.Destination
|
|
|
+}
|
|
|
+
|
|
|
+func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
|
|
+ for _, b := range mb {
|
|
|
+ if b.UDP != nil {
|
|
|
+ w.src = *b.UDP
|
|
|
+ }
|
|
|
+
|
|
|
+ netProto := header.IPv4ProtocolNumber
|
|
|
+ if !w.src.Address.Family().IsIPv4() {
|
|
|
+ netProto = header.IPv6ProtocolNumber
|
|
|
+ }
|
|
|
+
|
|
|
+ route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false)
|
|
|
+ if err != nil {
|
|
|
+ b.Release()
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
|
+ ReserveHeaderBytes: header.UDPMinimumSize,
|
|
|
+ Payload: buffer.MakeWithData(b.Bytes()),
|
|
|
+ })
|
|
|
+
|
|
|
+ length := uint16(pkt.Size())
|
|
|
+ udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
|
|
+ udpHeader.Encode(&header.UDPFields{
|
|
|
+ SrcPort: uint16(w.src.Port),
|
|
|
+ DstPort: uint16(w.dest.Port),
|
|
|
+ Length: length,
|
|
|
+ })
|
|
|
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(b.Bytes(), route.PseudoHeaderChecksum(header.UDPProtocolNumber, length))))
|
|
|
+
|
|
|
+ route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt)
|
|
|
+ pkt.DecRef()
|
|
|
+ route.Release()
|
|
|
+ b.Release()
|
|
|
}
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// Init the Handler instance with necessary parameters
|
|
|
@@ -290,13 +185,8 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
|
|
|
t.policyManager = pm
|
|
|
t.dispatcher = dispatcher
|
|
|
t.cone = ctx.Value("cone").(bool)
|
|
|
-
|
|
|
- // Initialize UDP connection manager
|
|
|
- t.udpConns = make(map[udpConnID]*udpConn)
|
|
|
- t.udpChecker = &task.Periodic{
|
|
|
- Interval: time.Minute,
|
|
|
- Execute: t.cleanupUDPConns,
|
|
|
- }
|
|
|
+ t.udpConns = make(map[net.Destination]*udpConn)
|
|
|
+ t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP}
|
|
|
|
|
|
tunName := t.config.Name
|
|
|
tunOptions := TunOptions{
|
|
|
@@ -357,20 +247,10 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
|
|
|
ctx = session.ContextWithInbound(ctx, &inbound)
|
|
|
ctx = session.SubContextFromMuxInbound(ctx)
|
|
|
|
|
|
- var link *transport.Link
|
|
|
- if destination.Network == net.Network_UDP {
|
|
|
- // For UDP, use PacketReader to preserve packet boundaries
|
|
|
- link = &transport.Link{
|
|
|
- Reader: buf.NewPacketReader(conn),
|
|
|
- Writer: buf.NewWriter(conn),
|
|
|
- }
|
|
|
- } else {
|
|
|
- link = &transport.Link{
|
|
|
- Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
|
|
|
- Writer: buf.NewWriter(conn),
|
|
|
- }
|
|
|
+ link := &transport.Link{
|
|
|
+ Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
|
|
|
+ Writer: buf.NewWriter(conn),
|
|
|
}
|
|
|
-
|
|
|
if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
|
|
|
errors.LogError(ctx, errors.New("connection closed").Base(err))
|
|
|
return
|