Просмотр исходного кода

Further simplify code structure

- Consolidate udpConn methods into single-line implementations
- Remove setInactive method, use direct field access
- Simplify writeUDPPacket variable declarations
- Extract handleUDPConn as separate method for clarity
- Reduce cleanupUDPConns redundancy
- Inline struct initializations in HandleUDPPacket

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
551c17d991
1 измененных файлов с 50 добавлено и 112 удалено
  1. 50 112
      proxy/tun/handler.go

+ 50 - 112
proxy/tun/handler.go

@@ -47,15 +47,10 @@ type udpConn struct {
 	cancel           context.CancelFunc
 }
 
-func (c *udpConn) setInactive() {
-	c.inactive = true
-}
-
 func (c *udpConn) updateActivity() {
 	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
 }
 
-// ReadMultiBuffer implements buf.Reader
 func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	mb, err := c.reader.ReadMultiBuffer()
 	if err != nil {
@@ -65,14 +60,7 @@ func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	return mb, nil
 }
 
-func (c *udpConn) Read(buf []byte) (int, error) {
-	return 0, errors.New("Read not supported, use ReadMultiBuffer instead")
-}
-
-// Write implements io.Writer
 func (c *udpConn) Write(data []byte) (int, error) {
-	// Extract destination from the first buffer if available
-	// For now, write with empty destination (will be filled by output function)
 	n, err := c.output(data, net.Destination{})
 	if err == nil {
 		c.updateActivity()
@@ -89,25 +77,12 @@ func (c *udpConn) Close() error {
 	return nil
 }
 
-func (c *udpConn) RemoteAddr() net.Addr {
-	return c.remote
-}
-
-func (c *udpConn) LocalAddr() net.Addr {
-	return c.local
-}
-
-func (*udpConn) SetDeadline(time.Time) error {
-	return nil
-}
-
-func (*udpConn) SetReadDeadline(time.Time) error {
-	return nil
-}
-
-func (*udpConn) SetWriteDeadline(time.Time) error {
-	return nil
-}
+func (c *udpConn) RemoteAddr() net.Addr              { return c.remote }
+func (c *udpConn) LocalAddr() net.Addr               { return c.local }
+func (c *udpConn) Read([]byte) (int, error)         { return 0, errors.New("not supported") }
+func (*udpConn) SetDeadline(time.Time) error        { return nil }
+func (*udpConn) SetReadDeadline(time.Time) error    { return nil }
+func (*udpConn) SetWriteDeadline(time.Time) error   { return nil }
 
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
@@ -189,21 +164,19 @@ func (t *Handler) removeUDPConn(id udpConnID) {
 
 // cleanupUDPConns removes inactive UDP connections
 func (t *Handler) cleanupUDPConns() error {
-	nowSec := time.Now().Unix()
 	t.Lock()
 	defer t.Unlock()
 	
 	if len(t.udpConns) == 0 {
-		return errors.New("UDP connection cleanup stopped: no active connections remaining")
+		return errors.New("no active connections")
 	}
 	
+	nowSec := time.Now().Unix()
 	for id, conn := range t.udpConns {
-		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 { // 5 minutes
-			if !conn.inactive {
-				conn.setInactive()
-				conn.Close()
-				delete(t.udpConns, id)
-			}
+		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 && !conn.inactive {
+			conn.inactive = true
+			conn.Close()
+			delete(t.udpConns, id)
 		}
 	}
 	
@@ -212,48 +185,34 @@ func (t *Handler) cleanupUDPConns() error {
 
 // writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
 func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
-	// Create a route from dest (our local) to source (remote)
-	var netProto tcpip.NetworkProtocolNumber
-	if dest.Address.Family().IsIPv4() {
-		netProto = header.IPv4ProtocolNumber
-	} else {
+	netProto := header.IPv4ProtocolNumber
+	if !dest.Address.Family().IsIPv4() {
 		netProto = header.IPv6ProtocolNumber
 	}
 	
-	route, err := ipStack.FindRoute(
-		defaultNIC,
-		tcpip.AddrFromSlice(dest.Address.IP()),
-		tcpip.AddrFromSlice(source.Address.IP()),
-		netProto,
-		false,
-	)
+	route, err := ipStack.FindRoute(defaultNIC, tcpip.AddrFromSlice(dest.Address.IP()), tcpip.AddrFromSlice(source.Address.IP()), netProto, false)
 	if err != nil {
 		return 0, errors.New("failed to find route: " + err.String())
 	}
 	defer route.Release()
 	
-	// Create packet buffer with UDP payload
 	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
 		ReserveHeaderBytes: header.UDPMinimumSize,
 		Payload:            buffer.MakeWithData(data),
 	})
 	defer pkt.DecRef()
 	
-	// Build UDP header
-	udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
 	length := uint16(pkt.Size())
+	udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
 	udpHeader.Encode(&header.UDPFields{
 		SrcPort: uint16(dest.Port),
 		DstPort: uint16(source.Port),
 		Length:  length,
 	})
 	
-	// Calculate checksum
 	xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length)
-	xsum = checksum.Checksum(data, xsum)
-	udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+	udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(data, xsum)))
 	
-	// Write packet through route
 	if err := route.WritePacket(stack.NetworkHeaderParams{
 		Protocol: header.UDPProtocolNumber,
 		TTL:      64,
@@ -267,80 +226,59 @@ func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source
 
 // HandleUDPPacket processes a raw UDP packet from gVisor
 func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
-	// Extract packet information
-	source := net.UDPDestination(
-		net.IPAddress(id.RemoteAddress.AsSlice()),
-		net.Port(id.RemotePort),
-	)
-	dest := net.UDPDestination(
-		net.IPAddress(id.LocalAddress.AsSlice()),
-		net.Port(id.LocalPort),
-	)
+	source := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
 	
-	// Extract UDP payload
 	data := pkt.Data().AsRange().ToSlice()
 	if len(data) == 0 {
 		return
 	}
 	
-	// Get or create connection for this source
 	conn, existing := t.getUDPConn(source, dest, ipStack)
 	
-	// Create buffer and set UDP destination
 	b := buf.New()
 	b.Write(data)
 	b.UDP = &dest
-	
-	// Write to connection pipe
 	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
 	
 	if !existing {
-		// Start checker for cleanup (only once)
 		t.Lock()
 		if t.udpChecker != nil && len(t.udpConns) == 1 {
 			common.Must(t.udpChecker.Start())
 		}
 		t.Unlock()
 		
-		// Start handling this connection
-		go func() {
-			connID := udpConnID{
-				src: source,
-			}
-			if !t.cone {
-				connID.dest = dest
-			}
-			
-			ctx, cancel := context.WithCancel(t.ctx)
-			conn.cancel = cancel
-			sid := session.NewID()
-			ctx = c.ContextWithID(ctx, sid)
-			
-			inbound := session.Inbound{}
-			inbound.Name = "tun"
-			inbound.Source = source
-			inbound.User = &protocol.MemoryUser{
-				Level: t.config.UserLevel,
-			}
-			
-			ctx = session.ContextWithInbound(ctx, &inbound)
-			ctx = session.SubContextFromMuxInbound(ctx)
-			
-			link := &transport.Link{
-				Reader: conn.reader,
-				Writer: buf.NewWriter(conn),
-			}
-			
-			if err := t.dispatcher.DispatchLink(ctx, dest, link); err != nil {
-				errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
-			}
-			
-			conn.Close()
-			if !conn.inactive {
-				conn.setInactive()
-				t.removeUDPConn(connID)
-			}
-		}()
+		go t.handleUDPConn(conn, source, dest)
+	}
+}
+
+func (t *Handler) handleUDPConn(conn *udpConn, source, dest net.Destination) {
+	connID := udpConnID{src: source}
+	if !t.cone {
+		connID.dest = dest
+	}
+	
+	ctx, cancel := context.WithCancel(t.ctx)
+	conn.cancel = cancel
+	ctx = c.ContextWithID(ctx, session.NewID())
+	ctx = session.ContextWithInbound(ctx, &session.Inbound{
+		Name:   "tun",
+		Source: source,
+		User:   &protocol.MemoryUser{Level: t.config.UserLevel},
+	})
+	ctx = session.SubContextFromMuxInbound(ctx)
+	
+	if err := t.dispatcher.DispatchLink(ctx, dest, &transport.Link{
+		Reader: conn.reader,
+		Writer: buf.NewWriter(conn),
+	}); err != nil {
+		errors.LogError(ctx, errors.New("UDP connection ended").Base(err))
+	}
+	
+	conn.Close()
+	if !conn.inactive {
+		conn.inactive = true
+		t.removeUDPConn(connID)
 	}
 }