Răsfoiți Sursa

Remove redundant stats in mux and bridge dispatcher (#5466)

Fixes https://github.com/XTLS/Xray-core/issues/5446
yuhan6665 5 luni în urmă
părinte
comite
a54e1f2be4

+ 7 - 6
app/dispatcher/default.go

@@ -196,7 +196,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran
 	return inboundLink, outboundLink
 }
 
-func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link {
+func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager stats.Manager, link *transport.Link) *transport.Link {
 	sessionInbound := session.InboundFromContext(ctx)
 	var user *protocol.MemoryUser
 	if sessionInbound != nil {
@@ -206,16 +206,16 @@ func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link)
 	link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader}
 
 	if user != nil && len(user.Email) > 0 {
-		p := d.policy.ForLevel(user.Level)
+		p := policyManager.ForLevel(user.Level)
 		if p.Stats.UserUplink {
 			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
-			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
+			if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil {
 				link.Reader.(*buf.TimeoutWrapperReader).Counter = c
 			}
 		}
 		if p.Stats.UserDownlink {
 			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
-			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
+			if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil {
 				link.Writer = &SizeStatWriter{
 					Counter: c,
 					Writer:  link.Writer,
@@ -224,7 +224,7 @@ func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link)
 		}
 		if p.Stats.UserOnline {
 			name := "user>>>" + user.Email + ">>>online"
-			if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
+			if om, _ := stats.GetOrRegisterOnlineMap(statsManager, name); om != nil {
 				sessionInbounds := session.InboundFromContext(ctx)
 				userIP := sessionInbounds.Source.Address.String()
 				om.AddIP(userIP)
@@ -357,7 +357,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		content = new(session.Content)
 		ctx = session.ContextWithContent(ctx, content)
 	}
-	outbound = d.WrapLink(ctx, outbound)
+	outbound = WrapLink(ctx, d.policy, d.stats, outbound)
 	sniffingRequest := content.SniffingRequest
 	if !sniffingRequest.Enabled {
 		d.routedDispatch(ctx, outbound, destination)
@@ -449,6 +449,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
 	}
 	return contentResult, contentErr
 }
+
 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
 	outbounds := session.OutboundsFromContext(ctx)
 	ob := outbounds[len(outbounds)-1]

+ 0 - 4
app/reverse/bridge.go

@@ -229,10 +229,6 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l
 		}
 		return w.Dispatcher.DispatchLink(ctx, dest, link)
 	}
-
-	if d, ok := w.Dispatcher.(routing.WrapLinkDispatcher); ok {
-		link = d.WrapLink(ctx, link)
-	}
 	w.handleInternalConn(link)
 
 	return nil

+ 0 - 3
common/mux/server.go

@@ -63,9 +63,6 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
 	if dest.Address != muxCoolAddress {
 		return s.dispatcher.DispatchLink(ctx, dest, link)
 	}
-	if d, ok := s.dispatcher.(routing.WrapLinkDispatcher); ok {
-		link = d.WrapLink(ctx, link)
-	}
 	worker, err := NewServerWorker(ctx, s.dispatcher, link)
 	if err != nil {
 		return err

+ 0 - 6
features/routing/dispatcher.go

@@ -26,9 +26,3 @@ type Dispatcher interface {
 func DispatcherType() interface{} {
 	return (*Dispatcher)(nil)
 }
-
-// Just for type assertion
-type WrapLinkDispatcher interface {
-	Dispatcher
-	WrapLink(ctx context.Context, link *transport.Link) *transport.Link
-}

+ 9 - 12
proxy/vless/inbound/inbound.go

@@ -12,6 +12,7 @@ import (
 	"time"
 	"unsafe"
 
+	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/app/reverse"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -31,6 +32,7 @@ import (
 	"github.com/xtls/xray-core/features/outbound"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/features/stats"
 	"github.com/xtls/xray-core/proxy"
 	"github.com/xtls/xray-core/proxy/vless"
 	"github.com/xtls/xray-core/proxy/vless/encoding"
@@ -72,10 +74,11 @@ func init() {
 type Handler struct {
 	inboundHandlerManager  feature_inbound.Manager
 	policyManager          policy.Manager
+	stats                  stats.Manager
 	validator              vless.Validator
 	decryption             *encryption.ServerInstance
 	outboundHandlerManager outbound.Manager
-	wrapLink               func(ctx context.Context, link *transport.Link) *transport.Link
+	defaultDispatcher      routing.Dispatcher
 	ctx                    context.Context
 	fallbacks              map[string]map[string]map[string]*Fallback // or nil
 	// regexps               map[string]*regexp.Regexp       // or nil
@@ -84,16 +87,13 @@ type Handler struct {
 // New creates a new VLess inbound handler.
 func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
 	v := core.MustFromContext(ctx)
-	var wrapLinkFunc func(ctx context.Context, link *transport.Link) *transport.Link
-	if dispatcher, ok := v.GetFeature(routing.DispatcherType()).(routing.WrapLinkDispatcher); ok {
-		wrapLinkFunc = dispatcher.WrapLink
-	}
 	handler := &Handler{
 		inboundHandlerManager:  v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
 		policyManager:          v.GetFeature(policy.ManagerType()).(policy.Manager),
+		stats:                  v.GetFeature(stats.ManagerType()).(stats.Manager),
 		validator:              validator,
 		outboundHandlerManager: v.GetFeature(outbound.ManagerType()).(outbound.Manager),
-		wrapLink:               wrapLinkFunc,
+		defaultDispatcher:      v.GetFeature(routing.DispatcherType()).(routing.Dispatcher),
 		ctx:                    ctx,
 	}
 
@@ -264,7 +264,7 @@ func (*Handler) Network() []net.Network {
 }
 
 // Process implements proxy.Inbound.Process().
-func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
+func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatch routing.Dispatcher) error {
 	iConn := stat.TryUnwrapStatsConn(connection)
 
 	if h.decryption != nil {
@@ -623,13 +623,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 		if err != nil {
 			return err
 		}
-		if h.wrapLink == nil {
-			return errors.New("VLESS reverse must have a dispatcher that implemented routing.WrapLinkDispatcher")
-		}
-		return r.NewMux(ctx, h.wrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter}))
+		return r.NewMux(ctx, dispatcher.WrapLink(ctx, h.policyManager, h.stats, &transport.Link{Reader: clientReader, Writer: clientWriter}))
 	}
 
-	if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{
+	if err := dispatch.DispatchLink(ctx, request.Destination(), &transport.Link{
 		Reader: clientReader,
 		Writer: clientWriter},
 	); err != nil {