Răsfoiți Sursa

WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)

Fixes https://github.com/XTLS/Xray-core/pull/5554

Fixes https://github.com/XTLS/Xray-core/issues/4760
LjhAUMEM 2 luni în urmă
părinte
comite
8aacdbd71b

+ 2 - 2
common/log/logger.go

@@ -36,7 +36,7 @@ type serverityLogger struct {
 func NewLogger(logWriterCreator WriterCreator) Handler {
 	return &generalLogger{
 		creator: logWriterCreator,
-		buffer:  make(chan Message, 16),
+		buffer:  make(chan Message, 128),
 		access:  semaphore.New(1),
 		done:    done.New(),
 	}
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
 	w := CreateStdoutLogWriter()
 	g := &generalLogger{
 		creator: w,
-		buffer:  make(chan Message, 16),
+		buffer:  make(chan Message, 128),
 		access:  semaphore.New(1),
 		done:    done.New(),
 	}

+ 40 - 45
proxy/wireguard/bind.go

@@ -2,27 +2,23 @@ package wireguard
 
 import (
 	"context"
-	"errors"
+	gonet "net"
 	"net/netip"
+	"runtime"
 	"strconv"
-	"sync"
 
 	"golang.zx2c4.com/wireguard/conn"
+	"golang.zx2c4.com/wireguard/device"
 
+	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/features/dns"
 	"github.com/xtls/xray-core/transport/internet"
 )
 
 type netReadInfo struct {
-	// status
-	waiter sync.WaitGroup
-	// param
-	buff []byte
-	// result
-	bytes    int
+	buff     []byte
 	endpoint conn.Endpoint
-	err      error
 }
 
 // reduce duplicated code
@@ -32,6 +28,7 @@ type netBind struct {
 
 	workers   int
 	readQueue chan *netReadInfo
+	closedCh  chan struct{}
 }
 
 // SetMark implements conn.Bind
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
 
 // Open implements conn.Bind
 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
-	bind.readQueue = make(chan *netReadInfo)
+	bind.closedCh = make(chan struct{})
+	errors.LogDebug(context.Background(), "bind opened")
 
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
-		defer func() {
-			if r := recover(); r != nil {
-				n = 0
-				err = errors.New("channel closed")
-			}
-		}()
-
-		r, ok := <-bind.readQueue
-		if !ok {
-			return 0, errors.New("channel closed")
+		select {
+		case r := <-bind.readQueue:
+			sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
+			return 1, nil
+		case <-bind.closedCh:
+			errors.LogDebug(context.Background(), "recv func closed")
+			return 0, gonet.ErrClosed
 		}
-
-		copy(bufs[0], r.buff[:r.bytes])
-		sizes[0], eps[0] = r.bytes, r.endpoint
-		r.waiter.Done()
-		return 1, r.err
 	}
 	workers := bind.workers
+	if workers <= 0 {
+		workers = runtime.NumCPU()
+	}
 	if workers <= 0 {
 		workers = 1
 	}
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 
 // Close implements conn.Bind
 func (bind *netBind) Close() error {
-	if bind.readQueue != nil {
-		close(bind.readQueue)
+	errors.LogDebug(context.Background(), "bind closed")
+	if bind.closedCh != nil {
+		close(bind.closedCh)
 	}
 	return nil
 }
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
-		defer func() {
-			_ = recover() // handle send on closed channel
-		}()
+	go func() {
 		for {
-			buff := make([]byte, 1700)
-			i, err := c.Read(buff)
+			buff := make([]byte, device.MaxMessageSize)
+			n, err := c.Read(buff)
+
+			if err != nil {
+				endpoint.conn = nil
+				c.Close()
+				return
+			}
 
-			if i > 3 {
+			if n > 3 {
 				buff[1] = 0
 				buff[2] = 0
 				buff[3] = 0
 			}
 
-			r := &netReadInfo{
-				buff:     buff,
-				bytes:    i,
+			select {
+			case bind.readQueue <- &netReadInfo{
+				buff:     buff[:n],
 				endpoint: endpoint,
-				err:      err,
-			}
-			r.waiter.Add(1)
-			readQueue <- r
-			r.waiter.Wait()
-			if err != nil {
+			}:
+			case <-bind.closedCh:
 				endpoint.conn = nil
+				c.Close()
 				return
 			}
 		}
-	}(bind.readQueue, endpoint)
+	}()
 
 	return nil
 }
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
 	}
 
 	if nend.conn == nil {
-		return errors.New("connection not open yet")
+		errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
+		return errors.New("peer closed")
 	}
 
 	for _, buff := range buff {

+ 2 - 1
proxy/wireguard/client.go

@@ -121,7 +121,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
 				IPv4Enable: h.hasIPv4,
 				IPv6Enable: h.hasIPv6,
 			},
-			workers: int(h.conf.NumWorkers),
+			workers:   int(h.conf.NumWorkers),
+			readQueue: make(chan *netReadInfo),
 		},
 		ctx:      ctx,
 		dialer:   dialer,

+ 23 - 15
proxy/wireguard/server.go

@@ -2,8 +2,6 @@ package wireguard
 
 import (
 	"context"
-	goerrors "errors"
-	"io"
 
 	"github.com/xtls/xray-core/common/buf"
 	c "github.com/xtls/xray-core/common/ctx"
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
 					IPv4Enable: hasIPv4,
 					IPv6Enable: hasIPv6,
 				},
+				workers:   int(conf.NumWorkers),
+				readQueue: make(chan *netReadInfo),
 			},
 		},
 		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 
 	reader := buf.NewPacketReader(conn)
 	for {
-		mpayload, err := reader.ReadMultiBuffer()
+		mb, err := reader.ReadMultiBuffer()
 		if err != nil {
+			nep.conn = nil
+			buf.ReleaseMulti(mb)
 			return err
 		}
 
-		for _, payload := range mpayload {
-			v, ok := <-s.bindServer.readQueue
-			if !ok {
-				return nil
+		for i, b := range mb {
+			buff := b.Bytes()
+
+			if b.Len() > 3 {
+				buff[1] = 0
+				buff[2] = 0
+				buff[3] = 0
 			}
-			i, err := payload.Read(v.buff)
 
-			v.bytes = i
-			v.endpoint = nep
-			v.err = err
-			v.waiter.Done()
-			if err != nil && goerrors.Is(err, io.EOF) {
+			select {
+			case s.bindServer.readQueue <- &netReadInfo{
+				buff:     buff,
+				endpoint: nep,
+			}:
+			case <-s.bindServer.closedCh:
 				nep.conn = nil
-				return nil
+				buf.ReleaseMulti(mb[i:])
+				return errors.New("bind closed")
 			}
 		}
 	}
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
 	// Currently we have no way to link to the original source address
 	inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
 	ctx = session.ContextWithInbound(ctx, &inbound)
+	content := new(session.Content)
 	if s.info.contentTag != nil {
-		ctx = session.ContextWithContent(ctx, s.info.contentTag)
+		content.SniffingRequest = s.info.contentTag.SniffingRequest
 	}
+	ctx = session.ContextWithContent(ctx, content)
 	ctx = session.SubContextFromMuxInbound(ctx)
 
 	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{

+ 0 - 17
proxy/wireguard/wireguard.go

@@ -8,25 +8,8 @@ import (
 	"strings"
 
 	"github.com/xtls/xray-core/common"
-	"github.com/xtls/xray-core/common/log"
-	"golang.zx2c4.com/wireguard/device"
 )
 
-var wgLogger = &device.Logger{
-	Verbosef: func(format string, args ...any) {
-		log.Record(&log.GeneralMessage{
-			Severity: log.Severity_Debug,
-			Content:  fmt.Sprintf(format, args...),
-		})
-	},
-	Errorf: func(format string, args ...any) {
-		log.Record(&log.GeneralMessage{
-			Severity: log.Severity_Error,
-			Content:  fmt.Sprintf(format, args...),
-		})
-	},
-}
-
 func init() {
 	common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		deviceConfig := config.(*DeviceConfig)