Эх сурвалжийг харах

WireGuard: Implement UDP FullCone NAT (#5833)

Fixes https://github.com/XTLS/Xray-core/issues/5601

---------

Co-authored-by: RPRX <[email protected]>
LjhAUMEM 2 сар өмнө
parent
commit
67a71adad1

+ 1 - 1
common/xudp/xudp.go

@@ -53,7 +53,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
 		return
 	}
 	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
-		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
+		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun" || inbound.Name == "wireguard") {
 		h := blake3.New(8, BaseKey)
 		h.Write([]byte(inbound.Source.String()))
 		copy(globalID[:], h.Sum(nil))

+ 1 - 1
go.mod

@@ -26,7 +26,7 @@ require (
 	golang.org/x/sync v0.20.0
 	golang.org/x/sys v0.42.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
-	golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
+	golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
 	google.golang.org/grpc v1.79.3
 	google.golang.org/protobuf v1.36.11
 	gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0

+ 2 - 0
go.sum

@@ -131,6 +131,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
+golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
+golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
 gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
 gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
 google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=

+ 1 - 1
infra/conf/wireguard.go

@@ -130,7 +130,7 @@ func ParseWireGuardKey(str string) (string, error) {
 		return "", errors.New("key must not be empty")
 	}
 
-	if len(str)%2 == 0 {
+	if len(str) == 64 {
 		_, err = hex.DecodeString(str)
 		if err == nil {
 			return str, nil

+ 36 - 0
proxy/wireguard/client.go

@@ -227,6 +227,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		}
 		defer conn.Close()
 
+		conn = &udpConnClient{
+			Conn: conn,
+			dest: destination,
+		}
+
 		requestFunc = func() error {
 			defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
 			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
@@ -336,3 +341,34 @@ func (h *Handler) createIPCRequest() string {
 
 	return request.String()[:request.Len()]
 }
+
+type udpConnClient struct {
+	net.Conn
+	dest net.Destination
+}
+
+func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	b := buf.New()
+	b.Resize(0, buf.Size)
+	n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes())
+	if err != nil {
+		b.Release()
+		return nil, err
+	}
+	if addr == nil { // should never hit
+		addr = c.dest.RawNetAddr()
+	}
+	b.Resize(0, int32(n))
+
+	b.UDP = &net.Destination{
+		Address: net.IPAddress(addr.(*net.UDPAddr).IP),
+		Port:    net.Port(addr.(*net.UDPAddr).Port),
+		Network: net.Network_UDP,
+	}
+
+	return buf.MultiBuffer{b}, nil
+}
+
+func (c *udpConnClient) Write(p []byte) (int, error) {
+	return c.Conn.(net.PacketConn).WriteTo(p, c.dest.RawNetAddr())
+}

+ 18 - 21
proxy/wireguard/gvisortun/tun.go

@@ -31,6 +31,7 @@ type netTun struct {
 	ep             *channel.Endpoint
 	stack          *stack.Stack
 	events         chan tun.Event
+	notifyHandle   *channel.NotificationHandle
 	incomingPacket chan *buffer.View
 	mtu            int
 	hasV4, hasV6   bool
@@ -48,12 +49,17 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
 	dev := &netTun{
 		ep:             channel.New(1024, uint32(mtu), ""),
 		stack:          stack.New(opts),
-		events:         make(chan tun.Event, 1),
+		events:         make(chan tun.Event, 10),
 		incomingPacket: make(chan *buffer.View),
 		mtu:            mtu,
 	}
-	dev.ep.AddNotify(dev)
-	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
+	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+	tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+	if tcpipErr != nil {
+		return nil, nil, dev.stack, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+	}
+	dev.notifyHandle = dev.ep.AddNotify(dev)
+	tcpipErr = dev.stack.CreateNIC(1, dev.ep)
 	if tcpipErr != nil {
 		return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
 	}
@@ -90,20 +96,10 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
 		dev.stack.SetSpoofing(1, true)
 	}
 
-	opt := tcpip.CongestionControlOption("cubic")
-	if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
-		return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
-	}
-
 	dev.events <- tun.EventUp
 	return dev, (*Net)(dev), dev.stack, nil
 }
 
-// BatchSize implements tun.Device
-func (tun *netTun) BatchSize() int {
-	return 1
-}
-
 // Name implements tun.Device
 func (tun *netTun) Name() (string, error) {
 	return "go", nil
@@ -120,7 +116,6 @@ func (tun *netTun) Events() <-chan tun.Event {
 }
 
 // Read implements tun.Device
-
 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
 	view, ok := <-tun.incomingPacket
 	if !ok {
@@ -169,20 +164,16 @@ func (tun *netTun) WriteNotify() {
 	tun.incomingPacket <- view
 }
 
-// Flush  implements tun.Device
-func (tun *netTun) Flush() error {
-	return nil
-}
-
 // Close implements tun.Device
 func (tun *netTun) Close() error {
 	tun.closeOnce.Do(func() {
 		tun.stack.RemoveNIC(1)
+		tun.stack.Close()
+		tun.ep.RemoveNotify(tun.notifyHandle)
+		tun.ep.Close()
 
 		close(tun.events)
 
-		tun.ep.Close()
-
 		close(tun.incomingPacket)
 	})
 	return nil
@@ -193,6 +184,11 @@ func (tun *netTun) MTU() (int, error) {
 	return tun.mtu, nil
 }
 
+// BatchSize implements tun.Device
+func (tun *netTun) BatchSize() int {
+	return 1
+}
+
 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
 	var protoNumber tcpip.NetworkProtocolNumber
 	if endpoint.Addr().Is4() {
@@ -224,6 +220,7 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
 		var addr tcpip.FullAddress
 		addr, pn = convertToFullAddr(raddr)
 		rfa = &addr
+		rfa = nil // do not ep connect
 	}
 	return gonet.DialUDP(net.stack, lfa, rfa, pn)
 }

+ 13 - 39
proxy/wireguard/server.go

@@ -5,19 +5,17 @@ import (
 	goerrors "errors"
 	"io"
 
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
 	c "github.com/xtls/xray-core/common/ctx"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
-	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
 )
 
@@ -31,10 +29,10 @@ type Server struct {
 }
 
 type routingInfo struct {
-	ctx         context.Context
-	dispatcher  routing.Dispatcher
-	inboundTag  *session.Inbound
-	contentTag  *session.Content
+	ctx        context.Context
+	dispatcher routing.Dispatcher
+	inboundTag *session.Inbound
+	contentTag *session.Content
 }
 
 func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
@@ -124,7 +122,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 		errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
 		return
 	}
-	defer conn.Close()
 
 	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
 	sid := session.NewID()
@@ -146,9 +143,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 	}
 	ctx = session.SubContextFromMuxInbound(ctx)
 
-	plcy := s.policyManager.ForLevel(0)
-	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
-
 	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 		From:   nullDestination,
 		To:     dest,
@@ -156,35 +150,15 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 		Reason: "",
 	})
 
-	link, err := s.info.dispatcher.Dispatch(ctx, dest)
-	if err != nil {
-		errors.LogErrorInner(ctx, err, "dispatch connection")
-	}
-	defer cancel()
-
-	requestDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport all TCP request").Base(err)
-		}
-
-		return nil
-	}
-
-	responseDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
-		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transport all TCP response").Base(err)
-		}
+	err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
+		Reader: buf.NewReader(conn),
+		Writer: buf.NewWriter(conn),
+	})
 
-		return nil
+	if err != nil {
+		errors.LogInfoInner(ctx, err, "connection ends")
 	}
 
-	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
-	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
-		common.Interrupt(link.Reader)
-		common.Interrupt(link.Writer)
-		errors.LogDebugInner(ctx, err, "connection ends")
-		return
-	}
+	cancel()
+	conn.Close()
 }

+ 218 - 36
proxy/wireguard/tun.go

@@ -3,6 +3,7 @@ package wireguard
 import (
 	"context"
 	"fmt"
+	"io"
 	"net/netip"
 	"runtime"
 	"strconv"
@@ -10,12 +11,17 @@ import (
 	"sync"
 	"time"
 
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
+	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 	"gvisor.dev/gvisor/pkg/waiter"
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
 
 func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
 	out := &gvisorNet{}
-	tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
+	tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
 	if err != nil {
 		return nil, err
 	}
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
 		// handler is only used for promiscuous mode
 		// capture all packets and send to handler
 
-		tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
+		tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
 			go func(r *tcp.ForwarderRequest) {
-				var (
-					wq waiter.Queue
-					id = r.ID()
-				)
+				var wq waiter.Queue
+				var id = r.ID()
 
-				// Perform a TCP three-way handshake.
 				ep, err := r.CreateEndpoint(&wq)
 				if err != nil {
 					errors.LogError(context.Background(), err.String())
 					r.Complete(true)
 					return
 				}
-				r.Complete(false)
-				defer ep.Close()
 
-				// enable tcp keep-alive to prevent hanging connections
-				ep.SocketOptions().SetKeepAlive(true)
+				options := ep.SocketOptions()
+				options.SetKeepAlive(false)
+				options.SetReuseAddress(true)
+				options.SetReusePort(true)
 
-				// local address is actually destination
 				handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
+
+				ep.Close()
+				r.Complete(false)
 			}(r)
 		})
-		stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
-
-		udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
-			go func(r *udp.ForwarderRequest) {
-				var (
-					wq waiter.Queue
-					id = r.ID()
-				)
-
-				ep, err := r.CreateEndpoint(&wq)
-				if err != nil {
-					errors.LogError(context.Background(), err.String())
-					return
-				}
-				defer ep.Close()
+		gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
-				// prevents hanging connections and ensure timely release
-				ep.SocketOptions().SetLinger(tcpip.LingerOption{
-					Enabled: true,
-					Timeout: 15 * time.Second,
-				})
-
-				handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
-			}(r)
+		manager := &udpManager{
+			stack:   gstack,
+			handler: handler,
+			m:       make(map[string]*udpConn),
+		}
 
+		gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+			data := pkt.Clone().Data().AsRange().ToSlice()
+			// if len(data) == 0 {
+			// 	return false
+			// }
+			src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+			dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
+			manager.feed(src, dst, data)
 			return true
 		})
-		stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
 	}
 
 	out.tun, out.net = tun, n
 	return out, nil
 }
+
+type udpManager struct {
+	stack   *stack.Stack
+	handler func(dest net.Destination, conn net.Conn)
+	m       map[string]*udpConn
+	mutex   sync.RWMutex
+}
+
+func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
+	m.mutex.RLock()
+	uc, ok := m.m[src.NetAddr()]
+	if ok {
+		select {
+		case uc.ch <- data:
+		default:
+		}
+		m.mutex.RUnlock()
+		return
+	}
+	m.mutex.RUnlock()
+
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	uc, ok = m.m[src.NetAddr()]
+	if !ok {
+		uc = &udpConn{
+			ch:  make(chan []byte, 1024),
+			src: src,
+			dst: dst,
+		}
+		uc.writeFunc = m.writeRawUDPPacket
+		uc.closeFunc = func() {
+			m.mutex.Lock()
+			m.close(uc)
+			m.mutex.Unlock()
+		}
+		m.m[src.NetAddr()] = uc
+		go m.handler(dst, uc)
+	}
+
+	select {
+	case uc.ch <- data:
+	default:
+	}
+}
+
+func (m *udpManager) close(uc *udpConn) {
+	if !uc.closed {
+		uc.closed = true
+		close(uc.ch)
+		delete(m.m, uc.src.NetAddr())
+	}
+}
+
+func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
+	udpLen := header.UDPMinimumSize + len(payload)
+	srcIP := tcpip.AddrFromSlice(src.Address.IP())
+	dstIP := tcpip.AddrFromSlice(dst.Address.IP())
+
+	// build packet with appropriate IP header size
+	isIPv4 := dst.Address.Family().IsIPv4()
+	ipHdrSize := header.IPv6MinimumSize
+	ipProtocol := header.IPv6ProtocolNumber
+	if isIPv4 {
+		ipHdrSize = header.IPv4MinimumSize
+		ipProtocol = header.IPv4ProtocolNumber
+	}
+
+	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
+		Payload:            buffer.MakeWithData(payload),
+	})
+	defer pkt.DecRef()
+
+	// Build UDP header
+	udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+	udpHdr.Encode(&header.UDPFields{
+		SrcPort: uint16(src.Port),
+		DstPort: uint16(dst.Port),
+		Length:  uint16(udpLen),
+	})
+
+	// Calculate and set UDP checksum
+	xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
+	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
+
+	// Build IP header
+	if isIPv4 {
+		ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+		ipHdr.Encode(&header.IPv4Fields{
+			TotalLength: uint16(header.IPv4MinimumSize + udpLen),
+			TTL:         64,
+			Protocol:    uint8(header.UDPProtocolNumber),
+			SrcAddr:     srcIP,
+			DstAddr:     dstIP,
+		})
+		ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
+	} else {
+		ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+		ipHdr.Encode(&header.IPv6Fields{
+			PayloadLength:     uint16(udpLen),
+			TransportProtocol: header.UDPProtocolNumber,
+			HopLimit:          64,
+			SrcAddr:           srcIP,
+			DstAddr:           dstIP,
+		})
+	}
+
+	// dispatch the packet
+	err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
+	if err != nil {
+		return errors.New("failed to write raw udp packet back to stack err ", err)
+	}
+
+	return nil
+}
+
+type udpConn struct {
+	ch        chan []byte
+	src       net.Destination
+	dst       net.Destination
+	writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
+	closeFunc func()
+	closed    bool
+}
+
+func (c *udpConn) Read(p []byte) (int, error) {
+	b, ok := <-c.ch
+	if !ok {
+		return 0, io.EOF
+	}
+	n := copy(p, b)
+	if n != len(b) {
+		return 0, io.ErrShortBuffer
+	}
+	return n, nil
+}
+
+func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	for i, b := range mb {
+		dst := c.dst
+		if b.UDP != nil {
+			dst = *b.UDP
+		}
+		err := c.writeFunc(b.Bytes(), dst, c.src)
+		if err != nil {
+			buf.ReleaseMulti(mb[i:])
+			return err
+		}
+		b.Release()
+	}
+	return nil
+}
+
+func (c *udpConn) Write(p []byte) (int, error) {
+	err := c.writeFunc(p, c.dst, c.src)
+	if err != nil {
+		return 0, err
+	}
+	return len(p), nil
+}
+
+func (c *udpConn) Close() error {
+	c.closeFunc()
+	return nil
+}
+
+func (c *udpConn) LocalAddr() net.Addr {
+	return c.src.RawNetAddr() // fake
+}
+
+func (c *udpConn) RemoteAddr() net.Addr {
+	return c.src.RawNetAddr() // src
+}
+
+func (c *udpConn) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *udpConn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *udpConn) SetWriteDeadline(t time.Time) error {
+	return nil
+}

+ 29 - 22
transport/internet/hysteria/hub.go

@@ -100,32 +100,39 @@ func (m *udpSessionManagerServer) run() {
 func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
 	m.mutex.RLock()
 	udpConn, ok := m.m[id]
+	if ok {
+		select {
+		case udpConn.ch <- d:
+		default:
+		}
+		m.mutex.RUnlock()
+		return
+	}
 	m.mutex.RUnlock()
 
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	udpConn, ok = m.m[id]
 	if !ok {
-		m.mutex.Lock()
-		udpConn, ok = m.m[id]
-		if !ok {
-			udpConn = &InterUdpConn{
-				conn:   m.conn,
-				local:  m.conn.LocalAddr(),
-				remote: m.conn.RemoteAddr(),
-
-				id:   id,
-				ch:   make(chan []byte, udpMessageChanSize),
-				last: time.Now(),
-
-				user: m.user,
-			}
-			udpConn.closeFunc = func() {
-				m.mutex.Lock()
-				defer m.mutex.Unlock()
-				m.close(udpConn)
-			}
-			m.m[id] = udpConn
-			m.addConn(udpConn)
+		udpConn = &InterUdpConn{
+			conn:   m.conn,
+			local:  m.conn.LocalAddr(),
+			remote: m.conn.RemoteAddr(),
+
+			id:   id,
+			ch:   make(chan []byte, udpMessageChanSize),
+			last: time.Now(),
+
+			user: m.user,
+		}
+		udpConn.closeFunc = func() {
+			m.mutex.Lock()
+			m.close(udpConn)
+			m.mutex.Unlock()
 		}
-		m.mutex.Unlock()
+		m.m[id] = udpConn
+		m.addConn(udpConn)
 	}
 
 	select {