Просмотр исходного кода

Move gVisor-specific UDP code to stack_gvisor.go

- Move HandleUDPPacket function to stack_gvisor.go
- Move udpWriter struct and WriteMultiBuffer method to stack_gvisor.go
- Keep handler.go clean of gVisor implementation details
- Add necessary imports (buffer, checksum, pipe, session, etc.) to stack_gvisor.go
- Maintains identical functionality with better code organization

Co-authored-by: Fangliding <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
71fb2af26e
2 измененных файлов с 154 добавлено и 147 удалено
  1. 0 147
      proxy/tun/handler.go
  2. 154 0
      proxy/tun/stack_gvisor.go

+ 0 - 147
proxy/tun/handler.go

@@ -20,12 +20,6 @@ import (
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
-	"github.com/xtls/xray-core/transport/pipe"
-	"gvisor.dev/gvisor/pkg/buffer"
-	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/checksum"
-	"gvisor.dev/gvisor/pkg/tcpip/header"
-	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
 type udpConn struct {
@@ -78,148 +72,7 @@ func (t *Handler) cleanupUDP() error {
 	return nil
 }
 
-func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
-	src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
-	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
-	data := pkt.Data().AsRange().ToSlice()
-	if len(data) == 0 {
-		return
-	}
-
-	t.Lock()
-	conn, found := t.udpConns[src]
-	if !found {
-		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
-		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
-		t.udpConns[src] = conn
-		if t.udpChecker != nil && len(t.udpConns) == 1 {
-			common.Must(t.udpChecker.Start())
-		}
-		t.Unlock()
-
-		go func() {
-			ctx, cancel := context.WithCancel(t.ctx)
-			conn.cancel = cancel
-			defer func() {
-				cancel()
-				t.Lock()
-				delete(t.udpConns, src)
-				t.Unlock()
-				common.Must(conn.done.Close())
-				common.Must(common.Close(conn.writer))
-			}()
-
-			inbound := &session.Inbound{
-				Name:          "tun",
-				Source:        src,
-				CanSpliceCopy: 1,
-				User:          &protocol.MemoryUser{Level: t.config.UserLevel},
-			}
-			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
-			ctx = session.SubContextFromMuxInbound(ctx)
-			link := &transport.Link{
-				Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
-				Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
-			}
-			t.dispatcher.DispatchLink(ctx, dest, link)
-		}()
-	} else {
-		conn.lastActive.Store(time.Now().Unix())
-		t.Unlock()
-	}
-
-	b := buf.New()
-	b.Write(data)
-	b.UDP = &dest
-	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
-}
-
-type udpWriter struct {
-	stack *stack.Stack
-	src   net.Destination
-	dest  net.Destination
-}
-
-func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	for _, b := range mb {
-		// Use b.UDP as source if available, otherwise use w.dest
-		srcAddr := w.dest
-		if b.UDP != nil {
-			srcAddr = *b.UDP
-		}
-
-		// Validate address family matches
-		if srcAddr.Address.Family() != w.src.Address.Family() {
-			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
-			b.Release()
-			continue
-		}
-
-		payload := b.Bytes()
-		udpLen := header.UDPMinimumSize + len(payload)
-		srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
-		dstIP := tcpip.AddrFromSlice(w.src.Address.IP())
-
-		// Build packet with appropriate IP header size
-		isIPv4 := w.src.Address.Family().IsIPv4()
-		ipHdrSize := header.IPv6MinimumSize
-		netProto := header.IPv6ProtocolNumber
-		if isIPv4 {
-			ipHdrSize = header.IPv4MinimumSize
-			netProto = header.IPv4ProtocolNumber
-		}
-
-		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-			ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
-			Payload:            buffer.MakeWithData(payload),
-		})
-
-		// Build UDP header
-		udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-		udpHdr.Encode(&header.UDPFields{
-			SrcPort: uint16(srcAddr.Port),
-			DstPort: uint16(w.src.Port),
-			Length:  uint16(udpLen),
-		})
 
-		// Calculate and set UDP checksum
-		xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
-		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-
-		// Build IP header
-		if isIPv4 {
-			ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
-			ipHdr.Encode(&header.IPv4Fields{
-				TotalLength: uint16(header.IPv4MinimumSize + udpLen),
-				TTL:         64,
-				Protocol:    uint8(header.UDPProtocolNumber),
-				SrcAddr:     srcIP,
-				DstAddr:     dstIP,
-			})
-			ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
-		} else {
-			ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
-			ipHdr.Encode(&header.IPv6Fields{
-				PayloadLength:     uint16(udpLen),
-				TransportProtocol: header.UDPProtocolNumber,
-				HopLimit:          64,
-				SrcAddr:           srcIP,
-				DstAddr:           dstIP,
-			})
-		}
-
-		// Write raw packet to network stack
-		views := pkt.AsSlices()
-		var data []byte
-		for _, view := range views {
-			data = append(data, view...)
-		}
-		w.stack.WriteRawPacket(defaultNIC, netProto, buffer.MakeWithData(data))
-		pkt.DecRef()
-		b.Release()
-	}
-	return nil
-}
 
 // Init the Handler instance with necessary parameters
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {

+ 154 - 0
proxy/tun/stack_gvisor.go

@@ -4,10 +4,20 @@ import (
 	"context"
 	"time"
 
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	c "github.com/xtls/xray-core/common/ctx"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal/done"
+	"github.com/xtls/xray-core/transport"
+	"github.com/xtls/xray-core/transport/pipe"
+	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -183,3 +193,147 @@ func createStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
 
 	return gStack, nil
 }
+
+// HandleUDPPacket handles incoming UDP packets for FullCone NAT
+func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
+src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
+data := pkt.Data().AsRange().ToSlice()
+if len(data) == 0 {
+return
+}
+
+t.Lock()
+conn, found := t.udpConns[src]
+if !found {
+reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
+conn = &udpConn{reader: reader, writer: writer, done: done.New()}
+t.udpConns[src] = conn
+if t.udpChecker != nil && len(t.udpConns) == 1 {
+common.Must(t.udpChecker.Start())
+}
+t.Unlock()
+
+go func() {
+ctx, cancel := context.WithCancel(t.ctx)
+conn.cancel = cancel
+defer func() {
+cancel()
+t.Lock()
+delete(t.udpConns, src)
+t.Unlock()
+common.Must(conn.done.Close())
+common.Must(common.Close(conn.writer))
+}()
+
+inbound := &session.Inbound{
+Name:          "tun",
+Source:        src,
+CanSpliceCopy: 1,
+User:          &protocol.MemoryUser{Level: t.config.UserLevel},
+}
+ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
+ctx = session.SubContextFromMuxInbound(ctx)
+link := &transport.Link{
+Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
+Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
+}
+t.dispatcher.DispatchLink(ctx, dest, link)
+}()
+} else {
+conn.lastActive.Store(time.Now().Unix())
+t.Unlock()
+}
+
+b := buf.New()
+b.Write(data)
+b.UDP = &dest
+conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
+}
+
+type udpWriter struct {
+stack *stack.Stack
+src   net.Destination
+dest  net.Destination
+}
+
+func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+for _, b := range mb {
+// Use b.UDP as source if available, otherwise use w.dest
+srcAddr := w.dest
+if b.UDP != nil {
+srcAddr = *b.UDP
+}
+
+// Validate address family matches
+if srcAddr.Address.Family() != w.src.Address.Family() {
+errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
+b.Release()
+continue
+}
+
+payload := b.Bytes()
+udpLen := header.UDPMinimumSize + len(payload)
+srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
+dstIP := tcpip.AddrFromSlice(w.src.Address.IP())
+
+// Build packet with appropriate IP header size
+isIPv4 := w.src.Address.Family().IsIPv4()
+ipHdrSize := header.IPv6MinimumSize
+netProto := header.IPv6ProtocolNumber
+if isIPv4 {
+ipHdrSize = header.IPv4MinimumSize
+netProto = header.IPv4ProtocolNumber
+}
+
+pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
+Payload:            buffer.MakeWithData(payload),
+})
+
+// Build UDP header
+udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+udpHdr.Encode(&header.UDPFields{
+SrcPort: uint16(srcAddr.Port),
+DstPort: uint16(w.src.Port),
+Length:  uint16(udpLen),
+})
+
+// Calculate and set UDP checksum
+xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
+udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
+
+// Build IP header
+if isIPv4 {
+ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ipHdr.Encode(&header.IPv4Fields{
+TotalLength: uint16(header.IPv4MinimumSize + udpLen),
+TTL:         64,
+Protocol:    uint8(header.UDPProtocolNumber),
+SrcAddr:     srcIP,
+DstAddr:     dstIP,
+})
+ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
+} else {
+ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ipHdr.Encode(&header.IPv6Fields{
+PayloadLength:     uint16(udpLen),
+TransportProtocol: header.UDPProtocolNumber,
+HopLimit:          64,
+SrcAddr:           srcIP,
+DstAddr:           dstIP,
+})
+}
+
+// Write raw packet to network stack
+views := pkt.AsSlices()
+var data []byte
+for _, view := range views {
+data = append(data, view...)
+}
+w.stack.WriteRawPacket(defaultNIC, netProto, buffer.MakeWithData(data))
+pkt.DecRef()
+b.Release()
+}
+return nil
+}