Преглед изворни кода

TUN inbound: Make udp_fullcone pure side effect free udp connection (#5526)

https://github.com/XTLS/Xray-core/pull/5526#issue-3804306341

---------

Co-authored-by: RPRX <[email protected]>
Owersun пре 5 месеци
родитељ
комит
7726fbece0
3 измењених фајлова са 156 додато и 199 уклоњено
  1. 3 2
      proxy/tun/handler.go
  2. 65 15
      proxy/tun/stack_gvisor.go
  3. 88 182
      proxy/tun/udp_fullcone.go

+ 3 - 2
proxy/tun/handler.go

@@ -105,11 +105,12 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
 	sid := session.NewID()
 	ctx := c.ContextWithID(t.ctx, sid)
 
+	source := net.DestinationFromAddr(conn.RemoteAddr())
 	inbound := session.Inbound{
 		Name:          "tun",
 		Tag:           t.tag,
 		CanSpliceCopy: 3,
-		Source:        net.DestinationFromAddr(conn.RemoteAddr()),
+		Source:        source,
 		User: &protocol.MemoryUser{
 			Level: t.config.UserLevel,
 		},
@@ -127,7 +128,7 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
 		Status: log.AccessAccepted,
 		Reason: "",
 	})
-	errors.LogInfo(ctx, "processing TCP from ", conn.RemoteAddr(), " to ", destination)
+	errors.LogInfo(ctx, "processing from ", source, " to ", destination)
 
 	link := &transport.Link{
 		Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},

+ 65 - 15
proxy/tun/stack_gvisor.go

@@ -9,6 +9,7 @@ import (
 	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -102,21 +103,7 @@ func (t *stackGVisor) Start() error {
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
 	// Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support
-	udpForwarder := newUdpConnectionHandler(t.ctx, t.handler, func(p []byte) {
-		// extract network protocol from the packet
-		var networkProtocol tcpip.NetworkProtocolNumber
-		switch header.IPVersion(p) {
-		case header.IPv4Version:
-			networkProtocol = header.IPv4ProtocolNumber
-		case header.IPv6Version:
-			networkProtocol = header.IPv6ProtocolNumber
-		default:
-			// discard packet with unknown network version
-			return
-		}
-
-		ipStack.WriteRawPacket(defaultNIC, networkProtocol, buffer.MakeWithData(p))
-	})
+	udpForwarder := newUdpConnectionHandler(t.handler.HandleConnection, t.writeRawUDPPacket)
 	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
 		data := pkt.Data().AsRange().ToSlice()
 		if len(data) == 0 {
@@ -137,6 +124,69 @@ func (t *stackGVisor) Start() error {
 	return nil
 }
 
+func (t *stackGVisor) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
+	udpLen := header.UDPMinimumSize + len(payload)
+	srcIP := tcpip.AddrFromSlice(src.Address.IP())
+	dstIP := tcpip.AddrFromSlice(dst.Address.IP())
+
+	// build packet with appropriate IP header size
+	isIPv4 := dst.Address.Family().IsIPv4()
+	ipHdrSize := header.IPv6MinimumSize
+	ipProtocol := header.IPv6ProtocolNumber
+	if isIPv4 {
+		ipHdrSize = header.IPv4MinimumSize
+		ipProtocol = header.IPv4ProtocolNumber
+	}
+
+	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
+		Payload:            buffer.MakeWithData(payload),
+	})
+	defer pkt.DecRef()
+
+	// Build UDP header
+	udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+	udpHdr.Encode(&header.UDPFields{
+		SrcPort: uint16(src.Port),
+		DstPort: uint16(dst.Port),
+		Length:  uint16(udpLen),
+	})
+
+	// Calculate and set UDP checksum
+	xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
+	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
+
+	// Build IP header
+	if isIPv4 {
+		ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+		ipHdr.Encode(&header.IPv4Fields{
+			TotalLength: uint16(header.IPv4MinimumSize + udpLen),
+			TTL:         64,
+			Protocol:    uint8(header.UDPProtocolNumber),
+			SrcAddr:     srcIP,
+			DstAddr:     dstIP,
+		})
+		ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
+	} else {
+		ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+		ipHdr.Encode(&header.IPv6Fields{
+			PayloadLength:     uint16(udpLen),
+			TransportProtocol: header.UDPProtocolNumber,
+			HopLimit:          64,
+			SrcAddr:           srcIP,
+			DstAddr:           dstIP,
+		})
+	}
+
+	// dispatch the packet
+	err := t.stack.WriteRawPacket(defaultNIC, ipProtocol, buffer.MakeWithView(pkt.ToView()))
+	if err != nil {
+		return errors.New("failed to write raw udp packet back to stack", err)
+	}
+
+	return nil
+}
+
 // Close is called by Handler to shut down the stack
 func (t *stackGVisor) Close() error {
 	if t.stack == nil {

+ 88 - 182
proxy/tun/udp_fullcone.go

@@ -1,228 +1,134 @@
 package tun
 
 import (
-	"context"
+	"io"
 	"sync"
-	"sync/atomic"
-	"time"
 
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
-	c "github.com/xtls/xray-core/common/ctx"
-	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/protocol"
-	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal/done"
-	"github.com/xtls/xray-core/common/task"
-	"github.com/xtls/xray-core/transport"
-	"github.com/xtls/xray-core/transport/pipe"
-	"gvisor.dev/gvisor/pkg/buffer"
-	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/checksum"
-	"gvisor.dev/gvisor/pkg/tcpip/header"
-	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
-// udp connection abstraction
-type udpConn struct {
-	lastActive atomic.Int64
-	reader     buf.Reader
-	writer     buf.Writer
-	done       *done.Instance
-	cancel     context.CancelFunc
-}
-
 // sub-handler specifically for udp connections under main handler
 type udpConnectionHandler struct {
 	sync.Mutex
-	ctx         context.Context
-	handler     *Handler
-	udpConns    map[net.Destination]*udpConn
-	udpChecker  *task.Periodic
-	writePacket func(p []byte)
+
+	udpConns map[net.Destination]*udpConn
+
+	handleConnection func(conn net.Conn, dest net.Destination)
+	writePacket      func(data []byte, src net.Destination, dst net.Destination) error
 }
 
-func newUdpConnectionHandler(ctx context.Context, h *Handler, writePacket func(p []byte)) *udpConnectionHandler {
+func newUdpConnectionHandler(handleConnection func(conn net.Conn, dest net.Destination), writePacket func(data []byte, src net.Destination, dst net.Destination) error) *udpConnectionHandler {
 	handler := &udpConnectionHandler{
-		ctx:         ctx,
-		handler:     h,
-		udpConns:    make(map[net.Destination]*udpConn),
-		writePacket: writePacket,
+		udpConns:         make(map[net.Destination]*udpConn),
+		handleConnection: handleConnection,
+		writePacket:      writePacket,
 	}
 
-	handler.udpChecker = &task.Periodic{Interval: time.Minute, Execute: handler.cleanupUDP}
-	handler.udpChecker.Start()
-
 	return handler
 }
 
-func (u *udpConnectionHandler) cleanupUDP() error {
-	u.Lock()
-	defer u.Unlock()
-	if len(u.udpConns) == 0 {
-		return errors.New("no connections")
-	}
-	now := time.Now().Unix()
-	for src, conn := range u.udpConns {
-		if now-conn.lastActive.Load() > 300 {
-			conn.cancel()
-			common.Must(conn.done.Close())
-			common.Must(common.Close(conn.writer))
-			delete(u.udpConns, src)
-		}
-	}
-	return nil
-}
-
 // HandlePacket handles UDP packets coming from tun, to forward to the dispatcher
-// this custom handler support FullCone NAT of returning packets, binding connection only by the source port
+// this custom handler support FullCone NAT of returning packets, binding connection only by the source addr:port
 func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool {
 	u.Lock()
 	conn, found := u.udpConns[src]
 	if !found {
-		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
-		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
+		egress := make(chan []byte, 16)
+		conn = &udpConn{handler: u, egress: egress, src: src, dst: dst}
 		u.udpConns[src] = conn
-		u.Unlock()
-
-		go func() {
-			ctx, cancel := context.WithCancel(u.ctx)
-			conn.cancel = cancel
-			defer func() {
-				cancel()
-				u.Lock()
-				delete(u.udpConns, src)
-				u.Unlock()
-				common.Must(conn.done.Close())
-				common.Must(common.Close(conn.writer))
-			}()
-
-			inbound := &session.Inbound{
-				Name:          "tun",
-				Tag:           u.handler.tag,
-				Source:        src,
-				CanSpliceCopy: 3,
-				User:          &protocol.MemoryUser{Level: u.handler.config.UserLevel},
-			}
-			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
-			ctx = session.ContextWithContent(ctx, &session.Content{
-				SniffingRequest: u.handler.sniffingRequest,
-			})
-			ctx = session.SubContextFromMuxInbound(ctx)
-			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
-				From:   src,
-				To:     dst,
-				Status: log.AccessAccepted,
-				Reason: "",
-			})
-			errors.LogInfo(ctx, "processing UDP from ", src, " to ", dst)
-			link := &transport.Link{
-				Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
-				// reverse source and destination, indicating the packets to write are going in the other
-				// direction (written back to tun) and should have reversed addressing
-				Writer: &udpWriter{handler: u, src: dst, dst: src},
-			}
-			_ = u.handler.dispatcher.DispatchLink(ctx, dst, link)
-		}()
-	} else {
-		conn.lastActive.Store(time.Now().Unix())
-		u.Unlock()
+
+		go u.handleConnection(conn, dst)
 	}
+	u.Unlock()
 
-	b := buf.New()
-	b.Write(data)
-	b.UDP = &dst
-	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
+	// send packet data to the egress channel, if it has buffer, or discard
+	select {
+	case conn.egress <- data:
+	default:
+	}
 
 	return true
 }
 
-type udpWriter struct {
+func (u *udpConnectionHandler) connectionFinished(src net.Destination) {
+	u.Lock()
+	conn, found := u.udpConns[src]
+	if found {
+		delete(u.udpConns, src)
+		close(conn.egress)
+	}
+	u.Unlock()
+}
+
+// udp connection abstraction
+type udpConn struct {
+	net.Conn
+	buf.Writer
+
 	handler *udpConnectionHandler
-	// address in the side of stack, where packet will be coming from
-	src net.Destination
-	// address on the side of tun, where packet will be destined to
-	dst net.Destination
+
+	egress chan []byte
+	src    net.Destination
+	dst    net.Destination
 }
 
-func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	for _, b := range mb {
-		// use captured in the dispatched packet source address b.UDP as source, if available,
-		// otherwise use captured in the writer source w.src
-		srcAddr := w.src
-		if b.UDP != nil {
-			srcAddr = *b.UDP
-		}
+// Read packets from the connection
+func (c *udpConn) Read(p []byte) (int, error) {
+	data, ok := <-c.egress
+	if !ok {
+		return 0, io.EOF
+	}
 
-		// validate address family matches
-		if srcAddr.Address.Family() != w.src.Address.Family() {
-			errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
-			b.Release()
-			continue
-		}
+	n := copy(p, data)
+	return n, nil
+}
+
+// Write returning packets back
+func (c *udpConn) Write(p []byte) (int, error) {
+	// sending packets back mean sending payload with source/destination reversed
+	err := c.handler.writePacket(p, c.dst, c.src)
+	if err != nil {
+		return 0, nil
+	}
+
+	return len(p), nil
+}
 
-		payload := b.Bytes()
-		udpLen := header.UDPMinimumSize + len(payload)
-		srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
-		dstIP := tcpip.AddrFromSlice(w.dst.Address.IP())
+func (c *udpConn) Close() error {
+	c.handler.connectionFinished(c.src)
 
-		// build packet with appropriate IP header size
-		isIPv4 := srcAddr.Address.Family().IsIPv4()
-		ipHdrSize := header.IPv6MinimumSize
-		if isIPv4 {
-			ipHdrSize = header.IPv4MinimumSize
+	return nil
+}
+
+func (c *udpConn) LocalAddr() net.Addr {
+	return &net.UDPAddr{IP: c.dst.Address.IP(), Port: int(c.dst.Port.Value())}
+}
+
+func (c *udpConn) RemoteAddr() net.Addr {
+	return &net.UDPAddr{IP: c.src.Address.IP(), Port: int(c.src.Port.Value())}
+}
+
+// Write returning packets back
+func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for _, b := range mb {
+		dst := c.dst
+		if b.UDP != nil {
+			dst = *b.UDP
 		}
 
-		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-			ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
-			Payload:            buffer.MakeWithData(payload),
-		})
-
-		// Build UDP header
-		udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-		udpHdr.Encode(&header.UDPFields{
-			SrcPort: uint16(srcAddr.Port),
-			DstPort: uint16(w.dst.Port),
-			Length:  uint16(udpLen),
-		})
-
-		// Calculate and set UDP checksum
-		xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
-		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-
-		// Build IP header
-		if isIPv4 {
-			ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
-			ipHdr.Encode(&header.IPv4Fields{
-				TotalLength: uint16(header.IPv4MinimumSize + udpLen),
-				TTL:         64,
-				Protocol:    uint8(header.UDPProtocolNumber),
-				SrcAddr:     srcIP,
-				DstAddr:     dstIP,
-			})
-			ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
-		} else {
-			ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
-			ipHdr.Encode(&header.IPv6Fields{
-				PayloadLength:     uint16(udpLen),
-				TransportProtocol: header.UDPProtocolNumber,
-				HopLimit:          64,
-				SrcAddr:           srcIP,
-				DstAddr:           dstIP,
-			})
+		// validate address family matches between buffer packet and the connection
+		if dst.Address.Family() != c.dst.Address.Family() {
+			continue
 		}
 
-		// Write raw packet to network stack
-		views := pkt.AsSlices()
-		var data []byte
-		for _, view := range views {
-			data = append(data, view...)
+		// sending packets back mean sending payload with source/destination reversed
+		err := c.handler.writePacket(b.Bytes(), dst, c.src)
+		if err != nil {
+			// udp doesn't guarantee delivery, so in any failure we just continue to the next packet
+			continue
 		}
-		w.handler.writePacket(data)
-		pkt.DecRef()
-		b.Release()
 	}
+
 	return nil
 }