Просмотр исходного кода

Drastically simplify UDP implementation (under 200 lines)

- Remove overly complex udpConn methods and udpConnID struct
- Simplify to single map keyed by source only
- Remove HandleConnection UDP check (UDP handled separately)
- Inline packet writing into udpWriter
- Remove redundant helper methods
- Total new code: 141 lines (vs 264 before)

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
543d2ffcaf
1 измененных файлов с 103 добавлено и 223 удалено
  1. 103 223
      proxy/tun/handler.go

+ 103 - 223
proxy/tun/handler.go

@@ -28,76 +28,25 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
-// udpConnID represents a UDP connection identifier
-type udpConnID struct {
-	src  net.Destination
-	dest net.Destination
-}
-
-// udpConn represents a UDP connection for packet handling
 type udpConn struct {
-	lastActivityTime int64 // in seconds
-	reader           buf.Reader
-	writer           buf.Writer
-	output           func([]byte, net.Destination) (int, error)
-	remote           net.Addr
-	local            net.Addr
-	done             *done.Instance
-	inactive         bool
-	cancel           context.CancelFunc
-}
-
-func (c *udpConn) updateActivity() {
-	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
-}
-
-func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
-	mb, err := c.reader.ReadMultiBuffer()
-	if err != nil {
-		return nil, err
-	}
-	c.updateActivity()
-	return mb, nil
-}
-
-func (c *udpConn) Write(data []byte) (int, error) {
-	n, err := c.output(data, net.Destination{})
-	if err == nil {
-		c.updateActivity()
-	}
-	return n, err
+	lastActive int64
+	reader     buf.Reader
+	writer     buf.Writer
+	done       *done.Instance
+	cancel     context.CancelFunc
 }
 
-func (c *udpConn) Close() error {
-	if c.cancel != nil {
-		c.cancel()
-	}
-	common.Must(c.done.Close())
-	common.Must(common.Close(c.writer))
-	return nil
-}
-
-func (c *udpConn) RemoteAddr() net.Addr              { return c.remote }
-func (c *udpConn) LocalAddr() net.Addr               { return c.local }
-func (c *udpConn) Read([]byte) (int, error)         { return 0, errors.New("not supported") }
-func (*udpConn) SetDeadline(time.Time) error        { return nil }
-func (*udpConn) SetReadDeadline(time.Time) error    { return nil }
-func (*udpConn) SetWriteDeadline(time.Time) error   { return nil }
-
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
-	sync.RWMutex
-	
+	sync.Mutex
 	ctx           context.Context
 	config        *Config
 	stack         Stack
 	policyManager policy.Manager
 	dispatcher    routing.Dispatcher
 	cone          bool
-	
-	// UDP connection management
-	udpConns   map[udpConnID]*udpConn
-	udpChecker *task.Periodic
+	udpConns      map[net.Destination]*udpConn
+	udpChecker    *task.Periodic
 }
 
 // ConnectionHandler interface with the only method that stack is going to push new connections to
@@ -109,177 +58,123 @@ type ConnectionHandler interface {
 var _ ConnectionHandler = (*Handler)(nil)
 
 func (t *Handler) policy() policy.Session {
-	p := t.policyManager.ForLevel(t.config.UserLevel)
-	return p
-}
-
-// getUDPConn gets or creates a UDP connection for the given source and destination
-func (t *Handler) getUDPConn(source, dest net.Destination, ipStack *stack.Stack) (*udpConn, bool) {
-	t.Lock()
-	defer t.Unlock()
-	
-	id := udpConnID{
-		src: source,
-	}
-	if !t.cone {
-		id.dest = dest
-	}
-	
-	if conn, found := t.udpConns[id]; found && !conn.done.Done() {
-		conn.updateActivity()
-		return conn, true
-	}
-	
-	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
-	conn := &udpConn{
-		reader: pReader,
-		writer: pWriter,
-		output: func(data []byte, returnDest net.Destination) (int, error) {
-			// Write UDP packet back to the stack with proper source address
-			return t.writeUDPPacket(ipStack, data, returnDest, source)
-		},
-		remote: &net.UDPAddr{
-			IP:   source.Address.IP(),
-			Port: int(source.Port),
-		},
-		local: &net.UDPAddr{
-			IP:   dest.Address.IP(),
-			Port: int(dest.Port),
-		},
-		done: done.New(),
-	}
-	
-	t.udpConns[id] = conn
-	
-	conn.updateActivity()
-	return conn, false
-}
-
-// removeUDPConn removes a UDP connection
-func (t *Handler) removeUDPConn(id udpConnID) {
-	t.Lock()
-	delete(t.udpConns, id)
-	t.Unlock()
+	return t.policyManager.ForLevel(t.config.UserLevel)
 }
 
-// cleanupUDPConns removes inactive UDP connections
-func (t *Handler) cleanupUDPConns() error {
+func (t *Handler) cleanupUDP() error {
 	t.Lock()
 	defer t.Unlock()
-	
 	if len(t.udpConns) == 0 {
-		return errors.New("no active connections")
+		return errors.New("no connections")
 	}
-	
-	nowSec := time.Now().Unix()
-	for id, conn := range t.udpConns {
-		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 && !conn.inactive {
-			conn.inactive = true
-			conn.Close()
-			delete(t.udpConns, id)
+	now := time.Now().Unix()
+	for src, conn := range t.udpConns {
+		if now-atomic.LoadInt64(&conn.lastActive) > 300 {
+			conn.cancel()
+			common.Must(conn.done.Close())
+			common.Must(common.Close(conn.writer))
+			delete(t.udpConns, src)
 		}
 	}
-	
 	return nil
 }
 
-// writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
-func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
-	netProto := header.IPv4ProtocolNumber
-	if !dest.Address.Family().IsIPv4() {
-		netProto = header.IPv6ProtocolNumber
-	}
-	
-	route, err := ipStack.FindRoute(defaultNIC, tcpip.AddrFromSlice(dest.Address.IP()), tcpip.AddrFromSlice(source.Address.IP()), netProto, false)
-	if err != nil {
-		return 0, errors.New("failed to find route: " + err.String())
-	}
-	defer route.Release()
-	
-	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		ReserveHeaderBytes: header.UDPMinimumSize,
-		Payload:            buffer.MakeWithData(data),
-	})
-	defer pkt.DecRef()
-	
-	length := uint16(pkt.Size())
-	udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-	udpHeader.Encode(&header.UDPFields{
-		SrcPort: uint16(dest.Port),
-		DstPort: uint16(source.Port),
-		Length:  length,
-	})
-	
-	xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length)
-	udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(data, xsum)))
-	
-	if err := route.WritePacket(stack.NetworkHeaderParams{
-		Protocol: header.UDPProtocolNumber,
-		TTL:      64,
-		TOS:      0,
-	}, pkt); err != nil {
-		return 0, errors.New("failed to write packet: " + err.String())
-	}
-	
-	return len(data), nil
-}
-
-// HandleUDPPacket processes a raw UDP packet from gVisor
 func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
-	source := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+	src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
 	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
-	
 	data := pkt.Data().AsRange().ToSlice()
 	if len(data) == 0 {
 		return
 	}
 	
-	conn, existing := t.getUDPConn(source, dest, ipStack)
-	
-	b := buf.New()
-	b.Write(data)
-	b.UDP = &dest
-	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
-	
-	if !existing {
-		t.Lock()
+	t.Lock()
+	conn, found := t.udpConns[src]
+	if !found {
+		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
+		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
+		t.udpConns[src] = conn
 		if t.udpChecker != nil && len(t.udpConns) == 1 {
 			common.Must(t.udpChecker.Start())
 		}
 		t.Unlock()
 		
-		go t.handleUDPConn(conn, source, dest)
+		go func() {
+			ctx, cancel := context.WithCancel(t.ctx)
+			conn.cancel = cancel
+			defer func() {
+				cancel()
+				t.Lock()
+				delete(t.udpConns, src)
+				t.Unlock()
+				common.Must(conn.done.Close())
+				common.Must(common.Close(conn.writer))
+			}()
+			
+			ctx = c.ContextWithID(ctx, session.NewID())
+			ctx = session.ContextWithInbound(ctx, &session.Inbound{
+				Name: "tun", Source: src,
+				User: &protocol.MemoryUser{Level: t.config.UserLevel},
+			})
+			
+			t.dispatcher.DispatchLink(ctx, dest, &transport.Link{
+				Reader: conn.reader,
+				Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
+			})
+		}()
+	} else {
+		atomic.StoreInt64(&conn.lastActive, time.Now().Unix())
+		t.Unlock()
 	}
+	
+	b := buf.New()
+	b.Write(data)
+	b.UDP = &dest
+	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
 }
 
-func (t *Handler) handleUDPConn(conn *udpConn, source, dest net.Destination) {
-	connID := udpConnID{src: source}
-	if !t.cone {
-		connID.dest = dest
-	}
-	
-	ctx, cancel := context.WithCancel(t.ctx)
-	conn.cancel = cancel
-	ctx = c.ContextWithID(ctx, session.NewID())
-	ctx = session.ContextWithInbound(ctx, &session.Inbound{
-		Name:   "tun",
-		Source: source,
-		User:   &protocol.MemoryUser{Level: t.config.UserLevel},
-	})
-	ctx = session.SubContextFromMuxInbound(ctx)
-	
-	if err := t.dispatcher.DispatchLink(ctx, dest, &transport.Link{
-		Reader: conn.reader,
-		Writer: buf.NewWriter(conn),
-	}); err != nil {
-		errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
-	}
-	
-	conn.Close()
-	if !conn.inactive {
-		conn.inactive = true
-		t.removeUDPConn(connID)
+type udpWriter struct {
+	stack *stack.Stack
+	src   net.Destination
+	dest  net.Destination
+}
+
+func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for _, b := range mb {
+		if b.UDP != nil {
+			w.src = *b.UDP
+		}
+		
+		netProto := header.IPv4ProtocolNumber
+		if !w.src.Address.Family().IsIPv4() {
+			netProto = header.IPv6ProtocolNumber
+		}
+		
+		route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false)
+		if err != nil {
+			b.Release()
+			continue
+		}
+		
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			ReserveHeaderBytes: header.UDPMinimumSize,
+			Payload:            buffer.MakeWithData(b.Bytes()),
+		})
+		
+		length := uint16(pkt.Size())
+		udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+		udpHeader.Encode(&header.UDPFields{
+			SrcPort: uint16(w.src.Port),
+			DstPort: uint16(w.dest.Port),
+			Length:  length,
+		})
+		udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(b.Bytes(), route.PseudoHeaderChecksum(header.UDPProtocolNumber, length))))
+		
+		route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt)
+		pkt.DecRef()
+		route.Release()
+		b.Release()
 	}
+	return nil
 }
 
 // Init the Handler instance with necessary parameters
@@ -290,13 +185,8 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
 	t.policyManager = pm
 	t.dispatcher = dispatcher
 	t.cone = ctx.Value("cone").(bool)
-	
-	// Initialize UDP connection manager
-	t.udpConns = make(map[udpConnID]*udpConn)
-	t.udpChecker = &task.Periodic{
-		Interval: time.Minute,
-		Execute:  t.cleanupUDPConns,
-	}
+	t.udpConns = make(map[net.Destination]*udpConn)
+	t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP}
 
 	tunName := t.config.Name
 	tunOptions := TunOptions{
@@ -357,20 +247,10 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
 	ctx = session.ContextWithInbound(ctx, &inbound)
 	ctx = session.SubContextFromMuxInbound(ctx)
 
-	var link *transport.Link
-	if destination.Network == net.Network_UDP {
-		// For UDP, use PacketReader to preserve packet boundaries
-		link = &transport.Link{
-			Reader: buf.NewPacketReader(conn),
-			Writer: buf.NewWriter(conn),
-		}
-	} else {
-		link = &transport.Link{
-			Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
-			Writer: buf.NewWriter(conn),
-		}
+	link := &transport.Link{
+		Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
+		Writer: buf.NewWriter(conn),
 	}
-
 	if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
 		errors.LogError(ctx, errors.New("connection closed").Base(err))
 		return