Просмотр исходного кода

Minimize code changes: keep existing structure, only change data flow direction

Changed from pull-based (ReceiveFunc sends request, goroutine fills) to
push-based (goroutine reads and pushes, ReceiveFunc receives) while
keeping the same channel and data structures.

The core fix: each endpoint's goroutine now owns reading from its own
connection and pushes data with correct endpoint identity, instead of
competing for shared read requests.

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
f2d5e3357d
2 измененных файлов с 30 добавлено и 33 удалено
  1. 20 25
      proxy/wireguard/bind.go
  2. 10 8
      proxy/wireguard/server.go

+ 20 - 25
proxy/wireguard/bind.go

@@ -14,18 +14,12 @@ import (
 	"github.com/xtls/xray-core/transport/internet"
 )
 
-const udpBufferSize = 1700 // max MTU for WireGuard
-
-var bufferPool = sync.Pool{
-	New: func() any {
-		return make([]byte, udpBufferSize)
-	},
-}
-
-// netReadInfo holds the result of a read operation from a specific endpoint
 type netReadInfo struct {
+	// status
+	waiter sync.WaitGroup
+	// param
+	buff []byte
 	// result
-	buff     []byte
 	bytes    int
 	endpoint conn.Endpoint
 	err      error
@@ -36,8 +30,8 @@ type netBind struct {
 	dns       dns.Client
 	dnsOption dns.IPOption
 
-	workers      int
-	responseRecv chan *netReadInfo // responses from all endpoints flow through here
+	workers   int
+	readQueue chan *netReadInfo
 }
 
 // SetMark implements conn.Bind
@@ -85,7 +79,7 @@ func (bind *netBind) BatchSize() int {
 
 // Open implements conn.Bind
 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
-	bind.responseRecv = make(chan *netReadInfo)
+	bind.readQueue = make(chan *netReadInfo)
 
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
 		defer func() {
@@ -95,16 +89,14 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 			}
 		}()
 
-		r, ok := <-bind.responseRecv
+		r, ok := <-bind.readQueue
 		if !ok {
 			return 0, errors.New("channel closed")
 		}
 
 		copy(bufs[0], r.buff[:r.bytes])
-		sizes[0] = r.bytes
-		eps[0] = r.endpoint
-		// Return buffer to pool
-		bufferPool.Put(r.buff)
+		sizes[0], eps[0] = r.bytes, r.endpoint
+		r.waiter.Done()
 		return 1, r.err
 	}
 	workers := bind.workers
@@ -121,8 +113,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 
 // Close implements conn.Bind
 func (bind *netBind) Close() error {
-	if bind.responseRecv != nil {
-		close(bind.responseRecv)
+	if bind.readQueue != nil {
+		close(bind.readQueue)
 	}
 	return nil
 }
@@ -142,12 +134,12 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
+	go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
 		defer func() {
-			_ = recover() // gracefully handle send on closed channel
+			_ = recover() // handle send on closed channel
 		}()
 		for {
-			buff := bufferPool.Get().([]byte)
+			buff := make([]byte, 1700)
 			i, err := c.Read(buff)
 
 			if i > 3 {
@@ -156,18 +148,21 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 				buff[3] = 0
 			}
 
-			responseRecv <- &netReadInfo{
+			r := &netReadInfo{
 				buff:     buff,
 				bytes:    i,
 				endpoint: endpoint,
 				err:      err,
 			}
+			r.waiter.Add(1)
+			readQueue <- r
+			r.waiter.Wait()
 			if err != nil {
 				endpoint.conn = nil
 				return
 			}
 		}
-	}(bind.responseRecv, endpoint, c)
+	}(bind.readQueue, endpoint)
 
 	return nil
 }

+ 10 - 8
proxy/wireguard/server.go

@@ -101,18 +101,20 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 		}
 
 		for _, payload := range mpayload {
-			data := bufferPool.Get().([]byte)
-			n, err := payload.Read(data)
+			buff := make([]byte, payload.Len())
+			i, err := payload.Read(buff)
 
-			select {
-			case s.bindServer.responseRecv <- &netReadInfo{
-				buff:     data,
-				bytes:    n,
+			r := &netReadInfo{
+				buff:     buff,
+				bytes:    i,
 				endpoint: nep,
 				err:      err,
-			}:
+			}
+			r.waiter.Add(1)
+			select {
+			case s.bindServer.readQueue <- r:
+				r.waiter.Wait()
 			case <-ctx.Done():
-				bufferPool.Put(data) // Return buffer if not sent
 				return nil
 			}
 			if err != nil && goerrors.Is(err, io.EOF) {