Jelajahi Sumber

Fix WireGuard multi-peer issue by changing to push-based architecture

In the original implementation, all peer connections shared a single
readQueue channel where ReceiveFunc workers would request buffers.
This caused responses from different peers to be incorrectly associated
with the wrong endpoint when multiple peers were active.

The fix changes to a push-based architecture where:
- Each endpoint's read goroutine pushes responses directly to a shared
  responseRecv channel along with the correct endpoint identity
- The ReceiveFunc workers simply receive from this channel
- This ensures each response is correctly associated with its source endpoint

Fixes issue #4507: multi wg peers outbound only one established

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 bulan lalu
induk
melakukan
e3606dea4c
2 mengubah file dengan 41 tambahan dan 38 penghapusan
  1. 30 30
      proxy/wireguard/bind.go
  2. 11 8
      proxy/wireguard/server.go

+ 30 - 30
proxy/wireguard/bind.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"errors"
 	"net/netip"
 	"net/netip"
 	"strconv"
 	"strconv"
-	"sync"
 
 
 	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/conn"
 
 
@@ -14,12 +13,10 @@ import (
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet"
 )
 )
 
 
+// netReadInfo holds the result of a read operation from a specific endpoint
 type netReadInfo struct {
 type netReadInfo struct {
-	// status
-	waiter sync.WaitGroup
-	// param
-	buff []byte
 	// result
 	// result
+	buff     []byte
 	bytes    int
 	bytes    int
 	endpoint conn.Endpoint
 	endpoint conn.Endpoint
 	err      error
 	err      error
@@ -30,8 +27,8 @@ type netBind struct {
 	dns       dns.Client
 	dns       dns.Client
 	dnsOption dns.IPOption
 	dnsOption dns.IPOption
 
 
-	workers   int
-	readQueue chan *netReadInfo
+	workers      int
+	responseRecv chan *netReadInfo // responses from all endpoints flow through here
 }
 }
 
 
 // SetMark implements conn.Bind
 // SetMark implements conn.Bind
@@ -79,7 +76,7 @@ func (bind *netBind) BatchSize() int {
 
 
 // Open implements conn.Bind
 // Open implements conn.Bind
 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
-	bind.readQueue = make(chan *netReadInfo)
+	bind.responseRecv = make(chan *netReadInfo)
 
 
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
 		defer func() {
 		defer func() {
@@ -89,13 +86,14 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 			}
 			}
 		}()
 		}()
 
 
-		r := &netReadInfo{
-			buff: bufs[0],
+		r, ok := <-bind.responseRecv
+		if !ok {
+			return 0, errors.New("channel closed")
 		}
 		}
-		r.waiter.Add(1)
-		bind.readQueue <- r
-		r.waiter.Wait() // wait read goroutine done, or we will miss the result
-		sizes[0], eps[0] = r.bytes, r.endpoint
+
+		copy(bufs[0], r.buff[:r.bytes])
+		sizes[0] = r.bytes
+		eps[0] = r.endpoint
 		return 1, r.err
 		return 1, r.err
 	}
 	}
 	workers := bind.workers
 	workers := bind.workers
@@ -112,8 +110,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 
 
 // Close implements conn.Bind
 // Close implements conn.Bind
 func (bind *netBind) Close() error {
 func (bind *netBind) Close() error {
-	if bind.readQueue != nil {
-		close(bind.readQueue)
+	if bind.responseRecv != nil {
+		close(bind.responseRecv)
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -133,30 +131,32 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	}
 	endpoint.conn = c
 	endpoint.conn = c
 
 
-	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
+	go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
+		defer func() {
+			_ = recover() // gracefully handle send on closed channel
+		}()
 		for {
 		for {
-			v, ok := <-readQueue
-			if !ok {
-				return
-			}
-			i, err := c.Read(v.buff)
+			buff := make([]byte, 1700) // max MTU for WireGuard
+			i, err := c.Read(buff)
 
 
 			if i > 3 {
 			if i > 3 {
-				v.buff[1] = 0
-				v.buff[2] = 0
-				v.buff[3] = 0
+				buff[1] = 0
+				buff[2] = 0
+				buff[3] = 0
 			}
 			}
 
 
-			v.bytes = i
-			v.endpoint = endpoint
-			v.err = err
-			v.waiter.Done()
+			responseRecv <- &netReadInfo{
+				buff:     buff,
+				bytes:    i,
+				endpoint: endpoint,
+				err:      err,
+			}
 			if err != nil {
 			if err != nil {
 				endpoint.conn = nil
 				endpoint.conn = nil
 				return
 				return
 			}
 			}
 		}
 		}
-	}(bind.readQueue, endpoint)
+	}(bind.responseRecv, endpoint, c)
 
 
 	return nil
 	return nil
 }
 }

+ 11 - 8
proxy/wireguard/server.go

@@ -101,16 +101,19 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 		}
 		}
 
 
 		for _, payload := range mpayload {
 		for _, payload := range mpayload {
-			v, ok := <-s.bindServer.readQueue
-			if !ok {
+			data := make([]byte, payload.Len())
+			n, err := payload.Read(data)
+
+			select {
+			case s.bindServer.responseRecv <- &netReadInfo{
+				buff:     data,
+				bytes:    n,
+				endpoint: nep,
+				err:      err,
+			}:
+			case <-ctx.Done():
 				return nil
 				return nil
 			}
 			}
-			i, err := payload.Read(v.buff)
-
-			v.bytes = i
-			v.endpoint = nep
-			v.err = err
-			v.waiter.Done()
 			if err != nil && goerrors.Is(err, io.EOF) {
 			if err != nil && goerrors.Is(err, io.EOF) {
 				nep.conn = nil
 				nep.conn = nil
 				return nil
 				return nil