Просмотр исходного кода

Fix WireGuard multi-peer issue by changing to push-based architecture

In the original implementation, all peer connections shared a single
readQueue channel where ReceiveFunc workers would request buffers.
This caused responses from different peers to be incorrectly associated
with the wrong endpoint when multiple peers were active.

The fix changes to a push-based architecture where:
- Each endpoint's read goroutine pushes responses directly to a shared
  responseRecv channel along with the correct endpoint identity
- The ReceiveFunc workers simply receive from this channel
- This ensures each response is correctly associated with its source endpoint

Fixes issue #4507: multi wg peers outbound only one established

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
e3606dea4c
2 измененных файлов с 41 добавлено и 38 удалено
  1. 30 30
      proxy/wireguard/bind.go
  2. 11 8
      proxy/wireguard/server.go

+ 30 - 30
proxy/wireguard/bind.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"net/netip"
 	"strconv"
-	"sync"
 
 	"golang.zx2c4.com/wireguard/conn"
 
@@ -14,12 +13,10 @@ import (
 	"github.com/xtls/xray-core/transport/internet"
 )
 
+// netReadInfo holds the result of a read operation from a specific endpoint
 type netReadInfo struct {
-	// status
-	waiter sync.WaitGroup
-	// param
-	buff []byte
 	// result
+	buff     []byte
 	bytes    int
 	endpoint conn.Endpoint
 	err      error
@@ -30,8 +27,8 @@ type netBind struct {
 	dns       dns.Client
 	dnsOption dns.IPOption
 
-	workers   int
-	readQueue chan *netReadInfo
+	workers      int
+	responseRecv chan *netReadInfo // responses from all endpoints flow through here
 }
 
 // SetMark implements conn.Bind
@@ -79,7 +76,7 @@ func (bind *netBind) BatchSize() int {
 
 // Open implements conn.Bind
 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
-	bind.readQueue = make(chan *netReadInfo)
+	bind.responseRecv = make(chan *netReadInfo)
 
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
 		defer func() {
@@ -89,13 +86,14 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 			}
 		}()
 
-		r := &netReadInfo{
-			buff: bufs[0],
+		r, ok := <-bind.responseRecv
+		if !ok {
+			return 0, errors.New("channel closed")
 		}
-		r.waiter.Add(1)
-		bind.readQueue <- r
-		r.waiter.Wait() // wait read goroutine done, or we will miss the result
-		sizes[0], eps[0] = r.bytes, r.endpoint
+
+		copy(bufs[0], r.buff[:r.bytes])
+		sizes[0] = r.bytes
+		eps[0] = r.endpoint
 		return 1, r.err
 	}
 	workers := bind.workers
@@ -112,8 +110,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 
 // Close implements conn.Bind
 func (bind *netBind) Close() error {
-	if bind.readQueue != nil {
-		close(bind.readQueue)
+	if bind.responseRecv != nil {
+		close(bind.responseRecv)
 	}
 	return nil
 }
@@ -133,30 +131,32 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
+	go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
+		defer func() {
+			_ = recover() // gracefully handle send on closed channel
+		}()
 		for {
-			v, ok := <-readQueue
-			if !ok {
-				return
-			}
-			i, err := c.Read(v.buff)
+			buff := make([]byte, 1700) // max MTU for WireGuard
+			i, err := c.Read(buff)
 
 			if i > 3 {
-				v.buff[1] = 0
-				v.buff[2] = 0
-				v.buff[3] = 0
+				buff[1] = 0
+				buff[2] = 0
+				buff[3] = 0
 			}
 
-			v.bytes = i
-			v.endpoint = endpoint
-			v.err = err
-			v.waiter.Done()
+			responseRecv <- &netReadInfo{
+				buff:     buff,
+				bytes:    i,
+				endpoint: endpoint,
+				err:      err,
+			}
 			if err != nil {
 				endpoint.conn = nil
 				return
 			}
 		}
-	}(bind.readQueue, endpoint)
+	}(bind.responseRecv, endpoint, c)
 
 	return nil
 }

+ 11 - 8
proxy/wireguard/server.go

@@ -101,16 +101,19 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 		}
 
 		for _, payload := range mpayload {
-			v, ok := <-s.bindServer.readQueue
-			if !ok {
+			data := make([]byte, payload.Len())
+			n, err := payload.Read(data)
+
+			select {
+			case s.bindServer.responseRecv <- &netReadInfo{
+				buff:     data,
+				bytes:    n,
+				endpoint: nep,
+				err:      err,
+			}:
+			case <-ctx.Done():
 				return nil
 			}
-			i, err := payload.Read(v.buff)
-
-			v.bytes = i
-			v.endpoint = nep
-			v.err = err
-			v.waiter.Done()
 			if err != nil && goerrors.Is(err, io.EOF) {
 				nep.conn = nil
 				return nil