Selaa lähdekoodia

Implement custom UDP packet handler with 2-tuple routing

- Replace UDP Forwarder with custom HandlePacket function
- Implement UDP connection management grouped by source 2-tuple
- Use gVisor header builders to construct return packets with custom source addresses
- Add 5-minute idle timeout for UDP connections
- Support FullCone NAT by aggregating packets from same source

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 kuukautta sitten
vanhempi
sitoutus
d720cc4ad5
2 muutettua tiedostoa jossa 358 lisäystä ja 25 poistoa
  1. 354 0
      proxy/tun/handler.go
  2. 4 25
      proxy/tun/stack_gvisor.go

+ 354 - 0
proxy/tun/handler.go

@@ -2,6 +2,9 @@ package tun
 
 import (
 	"context"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -10,21 +13,116 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal/done"
+	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
+	"github.com/xtls/xray-core/transport/pipe"
+	"gvisor.dev/gvisor/pkg/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
+// udpConnID represents a UDP connection identifier
+type udpConnID struct {
+	src  net.Destination
+	dest net.Destination
+}
+
+// udpConn represents a UDP connection for packet handling
+type udpConn struct {
+	lastActivityTime int64 // in seconds
+	reader           buf.Reader
+	writer           buf.Writer
+	output           func([]byte, net.Destination) (int, error)
+	remote           net.Addr
+	local            net.Addr
+	done             *done.Instance
+	inactive         bool
+	cancel           context.CancelFunc
+}
+
+func (c *udpConn) setInactive() {
+	c.inactive = true
+}
+
+func (c *udpConn) updateActivity() {
+	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
+}
+
+// ReadMultiBuffer implements buf.Reader
+func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	mb, err := c.reader.ReadMultiBuffer()
+	if err != nil {
+		return nil, err
+	}
+	c.updateActivity()
+	return mb, nil
+}
+
+func (c *udpConn) Read(buf []byte) (int, error) {
+	panic("not implemented")
+}
+
+// Write implements io.Writer
+func (c *udpConn) Write(data []byte) (int, error) {
+	// Extract destination from the first buffer if available
+	// For now, write with empty destination (will be filled by output function)
+	n, err := c.output(data, net.Destination{})
+	if err == nil {
+		c.updateActivity()
+	}
+	return n, err
+}
+
+func (c *udpConn) Close() error {
+	if c.cancel != nil {
+		c.cancel()
+	}
+	common.Must(c.done.Close())
+	common.Must(common.Close(c.writer))
+	return nil
+}
+
+func (c *udpConn) RemoteAddr() net.Addr {
+	return c.remote
+}
+
+func (c *udpConn) LocalAddr() net.Addr {
+	return c.local
+}
+
+func (*udpConn) SetDeadline(time.Time) error {
+	return nil
+}
+
+func (*udpConn) SetReadDeadline(time.Time) error {
+	return nil
+}
+
+func (*udpConn) SetWriteDeadline(time.Time) error {
+	return nil
+}
+
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
+	sync.RWMutex
+	
 	ctx           context.Context
 	config        *Config
 	stack         Stack
 	policyManager policy.Manager
 	dispatcher    routing.Dispatcher
 	cone          bool
+	
+	// UDP connection management
+	udpConns   map[udpConnID]*udpConn
+	udpChecker *task.Periodic
 }
 
 // ConnectionHandler interface with the only method that stack is going to push new connections to
@@ -40,6 +138,255 @@ func (t *Handler) policy() policy.Session {
 	return p
 }
 
+// getUDPConn gets or creates a UDP connection for the given source and destination
+func (t *Handler) getUDPConn(source, dest net.Destination, ipStack *stack.Stack) (*udpConn, bool) {
+	t.Lock()
+	defer t.Unlock()
+	
+	id := udpConnID{
+		src: source,
+	}
+	if !t.cone {
+		id.dest = dest
+	}
+	
+	if conn, found := t.udpConns[id]; found && !conn.done.Done() {
+		conn.updateActivity()
+		return conn, true
+	}
+	
+	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
+	conn := &udpConn{
+		reader: pReader,
+		writer: pWriter,
+		output: func(data []byte, returnDest net.Destination) (int, error) {
+			// Write UDP packet back to the stack with proper source address
+			return t.writeUDPPacket(ipStack, data, returnDest, source)
+		},
+		remote: &net.UDPAddr{
+			IP:   source.Address.IP(),
+			Port: int(source.Port),
+		},
+		local: &net.UDPAddr{
+			IP:   dest.Address.IP(),
+			Port: int(dest.Port),
+		},
+		done: done.New(),
+	}
+	
+	if t.udpConns == nil {
+		t.udpConns = make(map[udpConnID]*udpConn)
+	}
+	t.udpConns[id] = conn
+	
+	conn.updateActivity()
+	return conn, false
+}
+
+// removeUDPConn removes a UDP connection
+func (t *Handler) removeUDPConn(id udpConnID) {
+	t.Lock()
+	delete(t.udpConns, id)
+	t.Unlock()
+}
+
+// cleanupUDPConns removes inactive UDP connections
+func (t *Handler) cleanupUDPConns() error {
+	nowSec := time.Now().Unix()
+	t.Lock()
+	defer t.Unlock()
+	
+	if len(t.udpConns) == 0 {
+		return errors.New("no more connections. stopping...")
+	}
+	
+	for id, conn := range t.udpConns {
+		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 { // 5 minutes
+			if !conn.inactive {
+				conn.setInactive()
+				conn.Close()
+				delete(t.udpConns, id)
+			}
+		}
+	}
+	
+	if len(t.udpConns) == 0 {
+		return errors.New("no more connections. stopping...")
+	}
+	
+	return nil
+}
+
+// writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
+func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
+	// Build UDP+IP packet with proper headers using gVisor's header builders
+	
+	// Determine IP version
+	var ipHdrLen, udpHdrLen int
+	isIPv4 := dest.Address.Family().IsIPv4()
+	
+	if isIPv4 {
+		ipHdrLen = header.IPv4MinimumSize
+	} else {
+		ipHdrLen = header.IPv6MinimumSize
+	}
+	udpHdrLen = header.UDPMinimumSize
+	
+	totalLen := ipHdrLen + udpHdrLen + len(data)
+	packet := make([]byte, totalLen)
+	
+	// Build UDP header
+	udpHeader := header.UDP(packet[ipHdrLen:])
+	udpHeader.Encode(&header.UDPFields{
+		SrcPort: uint16(dest.Port),  // Source is the original destination
+		DstPort: uint16(source.Port), // Destination is the original source
+		Length:  uint16(udpHdrLen + len(data)),
+	})
+	
+	// Copy payload
+	copy(packet[ipHdrLen+udpHdrLen:], data)
+	
+	// Build IP header and calculate checksums
+	if isIPv4 {
+		ipv4Header := header.IPv4(packet)
+		ipv4Header.Encode(&header.IPv4Fields{
+			TOS:         0,
+			TotalLength: uint16(totalLen),
+			ID:          0,
+			Flags:       0,
+			FragmentOffset: 0,
+			TTL:         64,
+			Protocol:    uint8(header.UDPProtocolNumber),
+			SrcAddr:     tcpip.AddrFromSlice(dest.Address.IP()),
+			DstAddr:     tcpip.AddrFromSlice(source.Address.IP()),
+		})
+		ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
+		
+		// Calculate UDP checksum
+		xsum := header.PseudoHeaderChecksum(
+			header.UDPProtocolNumber,
+			tcpip.AddrFromSlice(dest.Address.IP()),
+			tcpip.AddrFromSlice(source.Address.IP()),
+			uint16(udpHdrLen+len(data)),
+		)
+		xsum = checksum.Checksum(data, xsum)
+		udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+	} else {
+		ipv6Header := header.IPv6(packet)
+		ipv6Header.Encode(&header.IPv6Fields{
+			TrafficClass:  0,
+			FlowLabel:     0,
+			PayloadLength: uint16(udpHdrLen + len(data)),
+			TransportProtocol: header.UDPProtocolNumber,
+			HopLimit:      64,
+			SrcAddr:       tcpip.AddrFromSlice(dest.Address.IP()),
+			DstAddr:       tcpip.AddrFromSlice(source.Address.IP()),
+		})
+		
+		// Calculate UDP checksum for IPv6
+		xsum := header.PseudoHeaderChecksum(
+			header.UDPProtocolNumber,
+			tcpip.AddrFromSlice(dest.Address.IP()),
+			tcpip.AddrFromSlice(source.Address.IP()),
+			uint16(udpHdrLen+len(data)),
+		)
+		xsum = checksum.Checksum(data, xsum)
+		udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+	}
+	
+	// Write packet to stack
+	var proto tcpip.NetworkProtocolNumber
+	if isIPv4 {
+		proto = header.IPv4ProtocolNumber
+	} else {
+		proto = header.IPv6ProtocolNumber
+	}
+	
+	buf := buffer.MakeWithData(packet)
+	if err := ipStack.WriteRawPacket(defaultNIC, proto, buf); err != nil {
+		return 0, errors.New("failed to write packet: " + err.String())
+	}
+	
+	return len(data), nil
+}
+
+// HandleUDPPacket processes a raw UDP packet from gVisor
+func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
+	// Extract packet information
+	source := net.UDPDestination(
+		net.IPAddress(id.RemoteAddress.AsSlice()),
+		net.Port(id.RemotePort),
+	)
+	dest := net.UDPDestination(
+		net.IPAddress(id.LocalAddress.AsSlice()),
+		net.Port(id.LocalPort),
+	)
+	
+	// Extract UDP payload
+	data := pkt.Data().AsRange().ToSlice()
+	if len(data) == 0 {
+		return
+	}
+	
+	// Get or create connection for this source
+	conn, existing := t.getUDPConn(source, dest, ipStack)
+	
+	// Create buffer and set UDP destination
+	b := buf.New()
+	b.Write(data)
+	b.UDP = &dest
+	
+	// Write to connection pipe
+	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
+	
+	if !existing {
+		// Start checker for cleanup
+		if t.udpChecker != nil {
+			common.Must(t.udpChecker.Start())
+		}
+		
+		// Start handling this connection
+		go func() {
+			connID := udpConnID{
+				src: source,
+			}
+			if !t.cone {
+				connID.dest = dest
+			}
+			
+			ctx, cancel := context.WithCancel(t.ctx)
+			conn.cancel = cancel
+			sid := session.NewID()
+			ctx = c.ContextWithID(ctx, sid)
+			
+			inbound := session.Inbound{}
+			inbound.Name = "tun"
+			inbound.Source = source
+			inbound.User = &protocol.MemoryUser{
+				Level: t.config.UserLevel,
+			}
+			
+			ctx = session.ContextWithInbound(ctx, &inbound)
+			ctx = session.SubContextFromMuxInbound(ctx)
+			
+			link := &transport.Link{
+				Reader: conn.reader,
+				Writer: buf.NewWriter(conn),
+			}
+			
+			if err := t.dispatcher.DispatchLink(ctx, dest, link); err != nil {
+				errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
+			}
+			
+			conn.Close()
+			if !conn.inactive {
+				conn.setInactive()
+				t.removeUDPConn(connID)
+			}
+		}()
+	}
+}
+
 // Init the Handler instance with necessary parameters
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
 	var err error
@@ -48,6 +395,13 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
 	t.policyManager = pm
 	t.dispatcher = dispatcher
 	t.cone = ctx.Value("cone").(bool)
+	
+	// Initialize UDP connection manager
+	t.udpConns = make(map[udpConnID]*udpConn)
+	t.udpChecker = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  t.cleanupUDPConns,
+	}
 
 	tunName := t.config.Name
 	tunOptions := TunOptions{

+ 4 - 25
proxy/tun/stack_gvisor.go

@@ -100,32 +100,11 @@ func (t *stackGVisor) Start() error {
 	})
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
-	udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
-		go func(r *udp.ForwarderRequest) {
-			var wq waiter.Queue
-			var id = r.ID()
-
-			ep, err := r.CreateEndpoint(&wq)
-			if err != nil {
-				errors.LogError(t.ctx, err.String())
-				return
-			}
-
-			options := ep.SocketOptions()
-			options.SetReuseAddress(true)
-			options.SetReusePort(true)
-
-			t.handler.HandleConnection(
-				gonet.NewUDPConn(&wq, ep),
-				// local address on the gVisor side is connection destination
-				net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)),
-			)
-
-			// close the socket
-			ep.Close()
-		}(r)
+	// Use custom UDP packet handler instead of forwarder for FullCone NAT
+	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+		t.handler.HandleUDPPacket(id, pkt, ipStack)
+		return true
 	})
-	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
 
 	t.stack = ipStack
 	t.endpoint = linkEndpoint