فهرست منبع

Add mutex protection for WireGuard routing info

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 ماه پیش
والد
کامیت
06dee57f8a
1فایلهای تغییر یافته به همراه17 افزوده شده و 8 حذف شده
  1. 17 8
      proxy/wireguard/server.go

+ 17 - 8
proxy/wireguard/server.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	goerrors "errors"
 	"io"
+	"sync"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -26,6 +27,7 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0)
 type Server struct {
 	bindServer *netBindServer
 
+	infoMu        sync.RWMutex
 	info          routingInfo
 	policyManager policy.Manager
 }
@@ -78,12 +80,14 @@ func (*Server) Network() []net.Network {
 
 // Process implements proxy.Inbound.
 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
+	s.infoMu.Lock()
 	s.info = routingInfo{
 		ctx:        ctx,
 		dispatcher: dispatcher,
 		inboundTag: session.InboundFromContext(ctx),
 		contentTag: session.ContentFromContext(ctx),
 	}
+	s.infoMu.Unlock()
 
 	ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
 	if err != nil {
@@ -120,18 +124,23 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 }
 
 func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
-	if s.info.dispatcher == nil {
-		errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
+	// Make a thread-safe copy of routing info
+	s.infoMu.RLock()
+	info := s.info
+	s.infoMu.RUnlock()
+
+	if info.dispatcher == nil {
+		errors.LogError(info.ctx, "unexpected: dispatcher == nil")
 		return
 	}
 	defer conn.Close()
 
-	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
+	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx))
 	sid := session.NewID()
 	ctx = c.ContextWithID(ctx, sid)
 	inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
-	if s.info.inboundTag != nil {
-		inbound = *s.info.inboundTag
+	if info.inboundTag != nil {
+		inbound = *info.inboundTag
 	}
 	inbound.Name = "wireguard"
 	inbound.CanSpliceCopy = 3
@@ -141,8 +150,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 	// Currently we have no way to link to the original source address
 	inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
 	ctx = session.ContextWithInbound(ctx, &inbound)
-	if s.info.contentTag != nil {
-		ctx = session.ContextWithContent(ctx, s.info.contentTag)
+	if info.contentTag != nil {
+		ctx = session.ContextWithContent(ctx, info.contentTag)
 	}
 	ctx = session.SubContextFromMuxInbound(ctx)
 
@@ -156,7 +165,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 		Reason: "",
 	})
 
-	link, err := s.info.dispatcher.Dispatch(ctx, dest)
+	link, err := info.dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 		errors.LogErrorInner(ctx, err, "dispatch connection")
 	}