Просмотр исходного кода

Simplify UDP packet writing using Route.WritePacket

Replace manual IP header construction with gVisor's Route API:
- Use Stack.FindRoute() to create proper route
- Use Route.WritePacket() with NetworkHeaderParams
- Let gVisor handle IP header construction
- Simpler and more maintainable code

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
8e2d358564
1 измененных файлов с 38 добавлено и 76 удалено
  1. 38 76
      proxy/tun/handler.go

+ 38 - 76
proxy/tun/handler.go

@@ -212,91 +212,53 @@ func (t *Handler) cleanupUDPConns() error {
 
 // writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
 func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
-	// Build UDP+IP packet with proper headers using gVisor's header builders
-	
-	// Determine IP version
-	var ipHdrLen, udpHdrLen int
-	isIPv4 := dest.Address.Family().IsIPv4()
-	
-	if isIPv4 {
-		ipHdrLen = header.IPv4MinimumSize
+	// Create a route from dest (our local) to source (remote)
+	var netProto tcpip.NetworkProtocolNumber
+	if dest.Address.Family().IsIPv4() {
+		netProto = header.IPv4ProtocolNumber
 	} else {
-		ipHdrLen = header.IPv6MinimumSize
+		netProto = header.IPv6ProtocolNumber
 	}
-	udpHdrLen = header.UDPMinimumSize
 	
-	totalLen := ipHdrLen + udpHdrLen + len(data)
-	packet := make([]byte, totalLen)
+	route, err := ipStack.FindRoute(
+		defaultNIC,
+		tcpip.AddrFromSlice(dest.Address.IP()),
+		tcpip.AddrFromSlice(source.Address.IP()),
+		netProto,
+		false,
+	)
+	if err != nil {
+		return 0, errors.New("failed to find route: " + err.String())
+	}
+	defer route.Release()
+	
+	// Create packet buffer with UDP payload
+	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		ReserveHeaderBytes: header.UDPMinimumSize,
+		Payload:            buffer.MakeWithData(data),
+	})
+	defer pkt.DecRef()
 	
 	// Build UDP header
-	udpHeader := header.UDP(packet[ipHdrLen:])
+	udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+	length := uint16(pkt.Size())
 	udpHeader.Encode(&header.UDPFields{
-		SrcPort: uint16(dest.Port),  // Source is the original destination
-		DstPort: uint16(source.Port), // Destination is the original source
-		Length:  uint16(udpHdrLen + len(data)),
+		SrcPort: uint16(dest.Port),
+		DstPort: uint16(source.Port),
+		Length:  length,
 	})
 	
-	// Copy payload
-	copy(packet[ipHdrLen+udpHdrLen:], data)
-	
-	// Build IP header and calculate checksums
-	if isIPv4 {
-		ipv4Header := header.IPv4(packet)
-		ipv4Header.Encode(&header.IPv4Fields{
-			TOS:         0,
-			TotalLength: uint16(totalLen),
-			ID:          0,
-			Flags:       0,
-			FragmentOffset: 0,
-			TTL:         64,
-			Protocol:    uint8(header.UDPProtocolNumber),
-			SrcAddr:     tcpip.AddrFromSlice(dest.Address.IP()),
-			DstAddr:     tcpip.AddrFromSlice(source.Address.IP()),
-		})
-		ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
-		
-		// Calculate UDP checksum
-		xsum := header.PseudoHeaderChecksum(
-			header.UDPProtocolNumber,
-			tcpip.AddrFromSlice(dest.Address.IP()),
-			tcpip.AddrFromSlice(source.Address.IP()),
-			uint16(udpHdrLen+len(data)),
-		)
-		xsum = checksum.Checksum(data, xsum)
-		udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
-	} else {
-		ipv6Header := header.IPv6(packet)
-		ipv6Header.Encode(&header.IPv6Fields{
-			TrafficClass:  0,
-			FlowLabel:     0,
-			PayloadLength: uint16(udpHdrLen + len(data)),
-			TransportProtocol: header.UDPProtocolNumber,
-			HopLimit:      64,
-			SrcAddr:       tcpip.AddrFromSlice(dest.Address.IP()),
-			DstAddr:       tcpip.AddrFromSlice(source.Address.IP()),
-		})
-		
-		// Calculate UDP checksum for IPv6
-		xsum := header.PseudoHeaderChecksum(
-			header.UDPProtocolNumber,
-			tcpip.AddrFromSlice(dest.Address.IP()),
-			tcpip.AddrFromSlice(source.Address.IP()),
-			uint16(udpHdrLen+len(data)),
-		)
-		xsum = checksum.Checksum(data, xsum)
-		udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
-	}
-	
-	// Write packet to stack
-	var proto tcpip.NetworkProtocolNumber
-	if isIPv4 {
-		proto = header.IPv4ProtocolNumber
-	} else {
-		proto = header.IPv6ProtocolNumber
-	}
+	// Calculate checksum
+	xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length)
+	xsum = checksum.Checksum(data, xsum)
+	udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
 	
-	buf := buffer.MakeWithData(packet)
-	if err := ipStack.WriteRawPacket(defaultNIC, proto, buf); err != nil {
+	// Write packet through route
+	if err := route.WritePacket(stack.NetworkHeaderParams{
+		Protocol: header.UDPProtocolNumber,
+		TTL:      64,
+		TOS:      0,
+	}, pkt); err != nil {
 		return 0, errors.New("failed to write packet: " + err.String())
 	}