Преглед изворни кода

Proxy: Tun: FullCone NAT: extract udp connection handler into separate file
Split handler/gvisor/udp connection handler as much as possible, reducing cross dependencies between handler and gVisor version of stack

Owersun пре 5 месеци
родитељ
комит
882fbd9a0c
3 измењених фајлова са 243 додато и 192 уклоњено
  1. 0 36
      proxy/tun/handler.go
  2. 27 156
      proxy/tun/stack_gvisor.go
  3. 216 0
      proxy/tun/udp_connection.go

+ 0 - 36
proxy/tun/handler.go

@@ -2,9 +2,6 @@ package tun
 
 
 import (
 import (
 	"context"
 	"context"
-	"sync"
-	"sync/atomic"
-	"time"
 
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/buf"
@@ -13,8 +10,6 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/done"
-	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/features/routing"
@@ -22,24 +17,13 @@ import (
 	"github.com/xtls/xray-core/transport/internet/stat"
 	"github.com/xtls/xray-core/transport/internet/stat"
 )
 )
 
 
-type udpConn struct {
-	lastActive atomic.Int64
-	reader     buf.Reader
-	writer     buf.Writer
-	done       *done.Instance
-	cancel     context.CancelFunc
-}
-
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
 type Handler struct {
-	sync.Mutex
 	ctx           context.Context
 	ctx           context.Context
 	config        *Config
 	config        *Config
 	stack         Stack
 	stack         Stack
 	policyManager policy.Manager
 	policyManager policy.Manager
 	dispatcher    routing.Dispatcher
 	dispatcher    routing.Dispatcher
-	udpConns      map[net.Destination]*udpConn
-	udpChecker    *task.Periodic
 }
 }
 
 
 // ConnectionHandler interface with the only method that stack is going to push new connections to
 // ConnectionHandler interface with the only method that stack is going to push new connections to
@@ -54,24 +38,6 @@ func (t *Handler) policy() policy.Session {
 	return t.policyManager.ForLevel(t.config.UserLevel)
 	return t.policyManager.ForLevel(t.config.UserLevel)
 }
 }
 
 
-func (t *Handler) cleanupUDP() error {
-	t.Lock()
-	defer t.Unlock()
-	if len(t.udpConns) == 0 {
-		return errors.New("no connections")
-	}
-	now := time.Now().Unix()
-	for src, conn := range t.udpConns {
-		if now-conn.lastActive.Load() > 300 {
-			conn.cancel()
-			common.Must(conn.done.Close())
-			common.Must(common.Close(conn.writer))
-			delete(t.udpConns, src)
-		}
-	}
-	return nil
-}
-
 // Init the Handler instance with necessary parameters
 // Init the Handler instance with necessary parameters
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
 	var err error
 	var err error
@@ -79,8 +45,6 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
 	t.ctx = core.ToBackgroundDetachedContext(ctx)
 	t.ctx = core.ToBackgroundDetachedContext(ctx)
 	t.policyManager = pm
 	t.policyManager = pm
 	t.dispatcher = dispatcher
 	t.dispatcher = dispatcher
-	t.udpConns = make(map[net.Destination]*udpConn)
-	t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP}
 
 
 	tunName := t.config.Name
 	tunName := t.config.Name
 	tunOptions := TunOptions{
 	tunOptions := TunOptions{

+ 27 - 156
proxy/tun/stack_gvisor.go

@@ -4,20 +4,11 @@ import (
 	"context"
 	"context"
 	"time"
 	"time"
 
 
-	"github.com/xtls/xray-core/common"
-	"github.com/xtls/xray-core/common/buf"
-	c "github.com/xtls/xray-core/common/ctx"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/protocol"
-	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/done"
-	"github.com/xtls/xray-core/transport"
-	"github.com/xtls/xray-core/transport/pipe"
 	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
-	"gvisor.dev/gvisor/pkg/tcpip/checksum"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -110,10 +101,34 @@ func (t *stackGVisor) Start() error {
 	})
 	})
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
 
-	// Use custom UDP packet handler instead of forwarder for FullCone NAT
+	// Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support
+	udpForwarder := newUdpConnectionHandler(t.ctx, t.handler, func(p []byte) {
+		// extract network protocol from the packet
+		var networkProtocol tcpip.NetworkProtocolNumber
+		switch header.IPVersion(p) {
+		case header.IPv4Version:
+			networkProtocol = header.IPv4ProtocolNumber
+		case header.IPv6Version:
+			networkProtocol = header.IPv6ProtocolNumber
+		default:
+			// discard packet with unknown network version
+			return
+		}
+
+		ipStack.WriteRawPacket(defaultNIC, networkProtocol, buffer.MakeWithData(p))
+	})
 	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
 	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
-		t.handler.HandleUDPPacket(id, pkt, ipStack)
-		return true
+		data := pkt.Data().AsRange().ToSlice()
+		if len(data) == 0 {
+			return false
+		}
+		// source/destination of the packet we process as incoming, on gVisor side are Remote/Local
+		// in other terms, src is the side behind tun, dst is the side behind gVisor
+		// this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement
+		src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+		dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
+
+		return udpForwarder.HandlePacket(src, dst, data)
 	})
 	})
 
 
 	t.stack = ipStack
 	t.stack = ipStack
@@ -193,147 +208,3 @@ func createStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
 
 
 	return gStack, nil
 	return gStack, nil
 }
 }
-
-// HandleUDPPacket handles incoming UDP packets for FullCone NAT
-func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
-	src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
-	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
-	data := pkt.Data().AsRange().ToSlice()
-	if len(data) == 0 {
-		return
-	}
-
-	t.Lock()
-	conn, found := t.udpConns[src]
-	if !found {
-		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
-		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
-		t.udpConns[src] = conn
-		if t.udpChecker != nil && len(t.udpConns) == 1 {
-			common.Must(t.udpChecker.Start())
-		}
-		t.Unlock()
-
-		go func() {
-			ctx, cancel := context.WithCancel(t.ctx)
-			conn.cancel = cancel
-			defer func() {
-				cancel()
-				t.Lock()
-				delete(t.udpConns, src)
-				t.Unlock()
-				common.Must(conn.done.Close())
-				common.Must(common.Close(conn.writer))
-			}()
-
-			inbound := &session.Inbound{
-				Name:          "tun",
-				Source:        src,
-				CanSpliceCopy: 1,
-				User:          &protocol.MemoryUser{Level: t.config.UserLevel},
-			}
-			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
-			ctx = session.SubContextFromMuxInbound(ctx)
-			link := &transport.Link{
-				Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
-				Writer: &udpWriter{stack: ipStack, src: dest, dest: src},
-			}
-			t.dispatcher.DispatchLink(ctx, dest, link)
-		}()
-	} else {
-		conn.lastActive.Store(time.Now().Unix())
-		t.Unlock()
-	}
-
-	b := buf.New()
-	b.Write(data)
-	b.UDP = &dest
-	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
-}
-
-type udpWriter struct {
-	stack *stack.Stack
-	src   net.Destination
-	dest  net.Destination
-}
-
-func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	for _, b := range mb {
-		// Use b.UDP as source if available, otherwise use w.dest
-		srcAddr := w.dest
-		if b.UDP != nil {
-			srcAddr = *b.UDP
-		}
-
-		// Validate address family matches
-		if srcAddr.Address.Family() != w.src.Address.Family() {
-			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
-			b.Release()
-			continue
-		}
-
-		payload := b.Bytes()
-		udpLen := header.UDPMinimumSize + len(payload)
-		srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
-		dstIP := tcpip.AddrFromSlice(w.src.Address.IP())
-
-		// Build packet with appropriate IP header size
-		isIPv4 := w.src.Address.Family().IsIPv4()
-		ipHdrSize := header.IPv6MinimumSize
-		netProto := header.IPv6ProtocolNumber
-		if isIPv4 {
-			ipHdrSize = header.IPv4MinimumSize
-			netProto = header.IPv4ProtocolNumber
-		}
-
-		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-			ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
-			Payload:            buffer.MakeWithData(payload),
-		})
-
-		// Build UDP header
-		udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-		udpHdr.Encode(&header.UDPFields{
-			SrcPort: uint16(srcAddr.Port),
-			DstPort: uint16(w.src.Port),
-			Length:  uint16(udpLen),
-		})
-
-		// Calculate and set UDP checksum
-		xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
-		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-
-		// Build IP header
-		if isIPv4 {
-			ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
-			ipHdr.Encode(&header.IPv4Fields{
-				TotalLength: uint16(header.IPv4MinimumSize + udpLen),
-				TTL:         64,
-				Protocol:    uint8(header.UDPProtocolNumber),
-				SrcAddr:     srcIP,
-				DstAddr:     dstIP,
-			})
-			ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
-		} else {
-			ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
-			ipHdr.Encode(&header.IPv6Fields{
-				PayloadLength:     uint16(udpLen),
-				TransportProtocol: header.UDPProtocolNumber,
-				HopLimit:          64,
-				SrcAddr:           srcIP,
-				DstAddr:           dstIP,
-			})
-		}
-
-		// Write raw packet to network stack
-		views := pkt.AsSlices()
-		var data []byte
-		for _, view := range views {
-			data = append(data, view...)
-		}
-		w.stack.WriteRawPacket(defaultNIC, netProto, buffer.MakeWithData(data))
-		pkt.DecRef()
-		b.Release()
-	}
-	return nil
-}

+ 216 - 0
proxy/tun/udp_connection.go

@@ -0,0 +1,216 @@
+package tun
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	c "github.com/xtls/xray-core/common/ctx"
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal/done"
+	"github.com/xtls/xray-core/common/task"
+	"github.com/xtls/xray-core/transport"
+	"github.com/xtls/xray-core/transport/pipe"
+	"gvisor.dev/gvisor/pkg/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// udp connection abstraction
+type udpConn struct {
+	lastActive atomic.Int64
+	reader     buf.Reader
+	writer     buf.Writer
+	done       *done.Instance
+	cancel     context.CancelFunc
+}
+
+// sub-handler specifically for udp connections under main handler
+type udpConnectionHandler struct {
+	sync.Mutex
+	ctx         context.Context
+	handler     *Handler
+	udpConns    map[net.Destination]*udpConn
+	udpChecker  *task.Periodic
+	writePacket func(p []byte)
+}
+
+func newUdpConnectionHandler(ctx context.Context, h *Handler, writePacket func(p []byte)) *udpConnectionHandler {
+	handler := &udpConnectionHandler{
+		ctx:         ctx,
+		handler:     h,
+		udpConns:    make(map[net.Destination]*udpConn),
+		writePacket: writePacket,
+	}
+
+	handler.udpChecker = &task.Periodic{Interval: time.Minute, Execute: handler.cleanupUDP}
+	handler.udpChecker.Start()
+
+	return handler
+}
+
+func (u *udpConnectionHandler) cleanupUDP() error {
+	u.Lock()
+	defer u.Unlock()
+	if len(u.udpConns) == 0 {
+		return errors.New("no connections")
+	}
+	now := time.Now().Unix()
+	for src, conn := range u.udpConns {
+		if now-conn.lastActive.Load() > 300 {
+			conn.cancel()
+			common.Must(conn.done.Close())
+			common.Must(common.Close(conn.writer))
+			delete(u.udpConns, src)
+		}
+	}
+	return nil
+}
+
+// HandlePacket handles UDP packets coming from tun, to forward to the dispatcher
+// this custom handler support FullCone NAT of returning packets, binding connection only by the source port
+func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool {
+	u.Lock()
+	conn, found := u.udpConns[src]
+	if !found {
+		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
+		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
+		u.udpConns[src] = conn
+		u.Unlock()
+
+		go func() {
+			ctx, cancel := context.WithCancel(u.ctx)
+			conn.cancel = cancel
+			defer func() {
+				cancel()
+				u.Lock()
+				delete(u.udpConns, src)
+				u.Unlock()
+				common.Must(conn.done.Close())
+				common.Must(common.Close(conn.writer))
+			}()
+
+			inbound := &session.Inbound{
+				Name:          "tun",
+				Source:        src,
+				CanSpliceCopy: 1,
+				User:          &protocol.MemoryUser{Level: u.handler.config.UserLevel},
+			}
+			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
+			ctx = session.SubContextFromMuxInbound(ctx)
+			link := &transport.Link{
+				Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
+				// reverse source and destination, indicating the packets to write are going in the other
+				// direction (written back to tun) and should have reversed addressing
+				Writer: &udpWriter{handler: u, src: dst, dst: src},
+			}
+			_ = u.handler.dispatcher.DispatchLink(ctx, dst, link)
+		}()
+	} else {
+		conn.lastActive.Store(time.Now().Unix())
+		u.Unlock()
+	}
+
+	b := buf.New()
+	b.Write(data)
+	b.UDP = &dst
+	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
+
+	return true
+}
+
+type udpWriter struct {
+	handler *udpConnectionHandler
+	// address in the side of stack, where packet will be coming from
+	src net.Destination
+	// address on the side of tun, where packet will be destined to
+	dst net.Destination
+}
+
+func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for _, b := range mb {
+		// use captured in the dispatched packet source address b.UDP as source, if available,
+		// otherwise use captured in the writer source w.src
+		srcAddr := w.src
+		if b.UDP != nil {
+			srcAddr = *b.UDP
+		}
+
+		// validate address family matches
+		if srcAddr.Address.Family() != w.src.Address.Family() {
+			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
+			b.Release()
+			continue
+		}
+
+		payload := b.Bytes()
+		udpLen := header.UDPMinimumSize + len(payload)
+		srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
+		dstIP := tcpip.AddrFromSlice(w.dst.Address.IP())
+
+		// build packet with appropriate IP header size
+		isIPv4 := srcAddr.Address.Family().IsIPv4()
+		ipHdrSize := header.IPv6MinimumSize
+		if isIPv4 {
+			ipHdrSize = header.IPv4MinimumSize
+		}
+
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
+			Payload:            buffer.MakeWithData(payload),
+		})
+
+		// Build UDP header
+		udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+		udpHdr.Encode(&header.UDPFields{
+			SrcPort: uint16(srcAddr.Port),
+			DstPort: uint16(w.dst.Port),
+			Length:  uint16(udpLen),
+		})
+
+		// Calculate and set UDP checksum
+		xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
+		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
+
+		// Build IP header
+		if isIPv4 {
+			ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+			ipHdr.Encode(&header.IPv4Fields{
+				TotalLength: uint16(header.IPv4MinimumSize + udpLen),
+				TTL:         64,
+				Protocol:    uint8(header.UDPProtocolNumber),
+				SrcAddr:     srcIP,
+				DstAddr:     dstIP,
+			})
+			ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
+		} else {
+			ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+			ipHdr.Encode(&header.IPv6Fields{
+				PayloadLength:     uint16(udpLen),
+				TransportProtocol: header.UDPProtocolNumber,
+				HopLimit:          64,
+				SrcAddr:           srcIP,
+				DstAddr:           dstIP,
+			})
+		}
+
+		// Write raw packet to network stack
+		views := pkt.AsSlices()
+		var data []byte
+		for _, view := range views {
+			data = append(data, view...)
+		}
+		w.handler.writePacket(data)
+		pkt.DecRef()
+		b.Release()
+	}
+	return nil
+}