Просмотр исходного кода

Simplify packet building by reducing code duplication

- Extract common UDP header building logic
- Consolidate IP address conversion (srcIP, dstIP)
- Move netProto determination earlier to reduce duplication
- Remove redundant IPv4Fields zero values (TOS, ID, Flags, etc.)
- Remove redundant IPv6Fields zero values (TrafficClass, FlowLabel)
- Reduce from ~95 lines to ~60 lines in WriteMultiBuffer
- Maintain identical functionality with cleaner code structure

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
8b9d631692
1 измененных файлов с 34 добавлено и 67 удалено
  1. 34 67
      proxy/tun/handler.go

+ 34 - 67
proxy/tun/handler.go

@@ -157,91 +157,58 @@ func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 
 		payload := b.Bytes()
 		udpLen := header.UDPMinimumSize + len(payload)
+		srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
+		dstIP := tcpip.AddrFromSlice(w.src.Address.IP())
+
+		// Build packet with appropriate IP header size
+		isIPv4 := w.src.Address.Family().IsIPv4()
+		ipHdrSize := header.IPv6MinimumSize
+		netProto := header.IPv6ProtocolNumber
+		if isIPv4 {
+			ipHdrSize = header.IPv4MinimumSize
+			netProto = header.IPv4ProtocolNumber
+		}
 
-		// Build complete packet with IP and UDP headers
-		var pkt *stack.PacketBuffer
-		if w.src.Address.Family().IsIPv4() {
-			// IPv4 packet
-			pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
-				ReserveHeaderBytes: header.IPv4MinimumSize + header.UDPMinimumSize,
-				Payload:            buffer.MakeWithData(payload),
-			})
-			
-			// Build UDP header
-			udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-			udpHdr.Encode(&header.UDPFields{
-				SrcPort: uint16(srcAddr.Port),
-				DstPort: uint16(w.src.Port),
-				Length:  uint16(udpLen),
-			})
-			
-			// Calculate UDP checksum
-			xsum := header.PseudoHeaderChecksum(
-				header.UDPProtocolNumber,
-				tcpip.AddrFromSlice(srcAddr.Address.IP()),
-				tcpip.AddrFromSlice(w.src.Address.IP()),
-				uint16(udpLen),
-			)
-			udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-			
-			// Build IPv4 header
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
+			Payload:            buffer.MakeWithData(payload),
+		})
+
+		// Build UDP header
+		udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+		udpHdr.Encode(&header.UDPFields{
+			SrcPort: uint16(srcAddr.Port),
+			DstPort: uint16(w.src.Port),
+			Length:  uint16(udpLen),
+		})
+
+		// Calculate and set UDP checksum
+		xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
+		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
+
+		// Build IP header
+		if isIPv4 {
 			ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
 			ipHdr.Encode(&header.IPv4Fields{
-				TOS:         0,
 				TotalLength: uint16(header.IPv4MinimumSize + udpLen),
-				ID:          0,
-				Flags:       0,
-				FragmentOffset: 0,
 				TTL:         64,
 				Protocol:    uint8(header.UDPProtocolNumber),
-				SrcAddr:     tcpip.AddrFromSlice(srcAddr.Address.IP()),
-				DstAddr:     tcpip.AddrFromSlice(w.src.Address.IP()),
+				SrcAddr:     srcIP,
+				DstAddr:     dstIP,
 			})
 			ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
 		} else {
-			// IPv6 packet
-			pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
-				ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize,
-				Payload:            buffer.MakeWithData(payload),
-			})
-			
-			// Build UDP header
-			udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-			udpHdr.Encode(&header.UDPFields{
-				SrcPort: uint16(srcAddr.Port),
-				DstPort: uint16(w.src.Port),
-				Length:  uint16(udpLen),
-			})
-			
-			// Calculate UDP checksum
-			xsum := header.PseudoHeaderChecksum(
-				header.UDPProtocolNumber,
-				tcpip.AddrFromSlice(srcAddr.Address.IP()),
-				tcpip.AddrFromSlice(w.src.Address.IP()),
-				uint16(udpLen),
-			)
-			udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-			
-			// Build IPv6 header
 			ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
 			ipHdr.Encode(&header.IPv6Fields{
-				TrafficClass:      0,
-				FlowLabel:         0,
 				PayloadLength:     uint16(udpLen),
 				TransportProtocol: header.UDPProtocolNumber,
 				HopLimit:          64,
-				SrcAddr:           tcpip.AddrFromSlice(srcAddr.Address.IP()),
-				DstAddr:           tcpip.AddrFromSlice(w.src.Address.IP()),
+				SrcAddr:           srcIP,
+				DstAddr:           dstIP,
 			})
 		}
 
 		// Write raw packet to network stack
-		netProto := header.IPv4ProtocolNumber
-		if !w.src.Address.Family().IsIPv4() {
-			netProto = header.IPv6ProtocolNumber
-		}
-		
-		// Get packet data and write as raw packet
 		views := pkt.AsSlices()
 		var data []byte
 		for _, view := range views {