Просмотр исходного кода

Use b.UDP as source in return packets with proper routing

- Use b.UDP (actual response origin) as source address when available
- Build route from srcAddr (b.UDP or w.dest) to w.src (original client)
- Set UDP header with srcAddr.Port as source, w.src.Port as destination
- Validates address family matches w.src instead of w.dest
- Properly implements NAT with custom source addresses

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
06ac8c1799
1 измененных файлов с 13 добавлено и 6 удалено
  1. 13 6
      proxy/tun/handler.go

+ 13 - 6
proxy/tun/handler.go

@@ -142,9 +142,15 @@ type udpWriter struct {
 
 func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	for _, b := range mb {
-		// Validate return packet address family matches expected destination
-		if b.UDP != nil && b.UDP.Address.Family() != w.dest.Address.Family() {
-			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.dest.Address.Family(), ", got ", b.UDP.Address.Family())
+		// Use b.UDP as source if available, otherwise use w.dest
+		srcAddr := w.dest
+		if b.UDP != nil {
+			srcAddr = *b.UDP
+		}
+
+		// Validate address family matches
+		if srcAddr.Address.Family() != w.src.Address.Family() {
+			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
 			b.Release()
 			continue
 		}
@@ -154,10 +160,11 @@ func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 			netProto = header.IPv6ProtocolNumber
 		}
 
+		// Build route from actual response source to original client
 		route, err := w.stack.FindRoute(
 			defaultNIC,
+			tcpip.AddrFromSlice(srcAddr.Address.IP()),
 			tcpip.AddrFromSlice(w.src.Address.IP()),
-			tcpip.AddrFromSlice(w.dest.Address.IP()),
 			netProto,
 			false,
 		)
@@ -172,8 +179,8 @@ func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		})
 		udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
 		udp.Encode(&header.UDPFields{
-			SrcPort: uint16(w.src.Port),
-			DstPort: uint16(w.dest.Port),
+			SrcPort: uint16(srcAddr.Port),
+			DstPort: uint16(w.src.Port),
 			Length:  uint16(pkt.Size()),
 		})
 		xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size()))