فهرست منبع

MUX: Prevent goroutine leak (#5110)

patterniha 9 ماه پیش
والد
کامیت
9f5dcb1591
7فایلهای تغییر یافته به همراه109 افزوده شده و 33 حذف شده
  1. 22 2
      app/reverse/bridge.go
  2. 9 1
      app/reverse/portal.go
  3. 12 10
      common/mux/client.go
  4. 50 12
      common/mux/server.go
  5. 11 3
      common/mux/session.go
  6. 2 2
      common/mux/session_test.go
  7. 3 3
      proxy/proxy.go

+ 22 - 2
app/reverse/bridge.go

@@ -9,6 +9,7 @@ import (
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport"
@@ -53,6 +54,9 @@ func (b *Bridge) cleanup() {
 		if w.IsActive() {
 		if w.IsActive() {
 			activeWorkers = append(activeWorkers, w)
 			activeWorkers = append(activeWorkers, w)
 		}
 		}
+		if w.Closed() {
+			w.Timer.SetTimeout(0)
+		}
 	}
 	}
 
 
 	if len(activeWorkers) != len(b.workers) {
 	if len(activeWorkers) != len(b.workers) {
@@ -98,6 +102,7 @@ type BridgeWorker struct {
 	Worker     *mux.ServerWorker
 	Worker     *mux.ServerWorker
 	Dispatcher routing.Dispatcher
 	Dispatcher routing.Dispatcher
 	State      Control_State
 	State      Control_State
+	Timer      *signal.ActivityTimer
 }
 }
 
 
 func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
 func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
@@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
 	}
 	}
 	w.Worker = worker
 	w.Worker = worker
 
 
+	terminate := func() {
+		worker.Close()
+	}
+	w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second)
 	return w, nil
 	return w, nil
 }
 }
 
 
@@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool {
 	return w.State == Control_ACTIVE && !w.Worker.Closed()
 	return w.State == Control_ACTIVE && !w.Worker.Closed()
 }
 }
 
 
+func (w *BridgeWorker) Closed() bool {
+	return w.Worker.Closed()
+}
+
 func (w *BridgeWorker) Connections() uint32 {
 func (w *BridgeWorker) Connections() uint32 {
 	return w.Worker.ActiveConnections()
 	return w.Worker.ActiveConnections()
 }
 }
@@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
 	for {
 	for {
 		mb, err := reader.ReadMultiBuffer()
 		mb, err := reader.ReadMultiBuffer()
 		if err != nil {
 		if err != nil {
-			break
+			if w.Closed() {
+				w.Timer.SetTimeout(0)
+			} else {
+				w.Timer.SetTimeout(24 * time.Hour)
+			}
+			return
 		}
 		}
+		w.Timer.Update()
 		for _, b := range mb {
 		for _, b := range mb {
 			var ctl Control
 			var ctl Control
 			if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
 			if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
 				errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
 				errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
-				break
+				w.Timer.SetTimeout(0)
+				return
 			}
 			}
 			if ctl.State != w.State {
 			if ctl.State != w.State {
 				w.State = ctl.State
 				w.State = ctl.State

+ 9 - 1
app/reverse/portal.go

@@ -12,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/serial"
 	"github.com/xtls/xray-core/common/serial"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal"
 	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/features/outbound"
 	"github.com/xtls/xray-core/features/outbound"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport"
@@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error {
 	for _, w := range p.workers {
 	for _, w := range p.workers {
 		if !w.Closed() {
 		if !w.Closed() {
 			activeWorkers = append(activeWorkers, w)
 			activeWorkers = append(activeWorkers, w)
+		} else {
+			w.timer.SetTimeout(0)
 		}
 		}
 	}
 	}
 
 
@@ -225,6 +228,7 @@ type PortalWorker struct {
 	reader   buf.Reader
 	reader   buf.Reader
 	draining bool
 	draining bool
 	counter  uint32
 	counter  uint32
+	timer    *signal.ActivityTimer
 }
 }
 
 
 func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
 func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
@@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
 	if !f {
 	if !f {
 		return nil, errors.New("unable to dispatch control connection")
 		return nil, errors.New("unable to dispatch control connection")
 	}
 	}
+	terminate := func() {
+		client.Close()
+	}
 	w := &PortalWorker{
 	w := &PortalWorker{
 		client: client,
 		client: client,
 		reader: downlinkReader,
 		reader: downlinkReader,
 		writer: uplinkWriter,
 		writer: uplinkWriter,
+		timer:  signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
 	}
 	}
 	w.control = &task.Periodic{
 	w.control = &task.Periodic{
 		Execute:  w.heartbeat,
 		Execute:  w.heartbeat,
@@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error {
 		msg.State = Control_DRAIN
 		msg.State = Control_DRAIN
 
 
 		defer func() {
 		defer func() {
-			w.client.GetTimer().Reset(time.Second * 16)
 			common.Close(w.writer)
 			common.Close(w.writer)
 			common.Interrupt(w.reader)
 			common.Interrupt(w.reader)
 			w.writer = nil
 			w.writer = nil
@@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error {
 		b, err := proto.Marshal(msg)
 		b, err := proto.Marshal(msg)
 		common.Must(err)
 		common.Must(err)
 		mb := buf.MergeBytes(nil, b)
 		mb := buf.MergeBytes(nil, b)
+		w.timer.Update()
 		return w.writer.WriteMultiBuffer(mb)
 		return w.writer.WriteMultiBuffer(mb)
 	}
 	}
 	return nil
 	return nil

+ 12 - 10
common/mux/client.go

@@ -219,14 +219,16 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
 	return m.done.Wait()
 	return m.done.Wait()
 }
 }
 
 
-func (m *ClientWorker) GetTimer() *time.Ticker {
-	return m.timer
+func (m *ClientWorker) Close() error {
+	return m.done.Close()
 }
 }
 
 
 func (m *ClientWorker) monitor() {
 func (m *ClientWorker) monitor() {
 	defer m.timer.Stop()
 	defer m.timer.Stop()
 
 
 	for {
 	for {
+		checkSize := m.sessionManager.Size()
+		checkCount := m.sessionManager.Count()
 		select {
 		select {
 		case <-m.done.Wait():
 		case <-m.done.Wait():
 			m.sessionManager.Close()
 			m.sessionManager.Close()
@@ -234,8 +236,7 @@ func (m *ClientWorker) monitor() {
 			common.Interrupt(m.link.Reader)
 			common.Interrupt(m.link.Reader)
 			return
 			return
 		case <-m.timer.C:
 		case <-m.timer.C:
-			size := m.sessionManager.Size()
-			if size == 0 && m.sessionManager.CloseIfNoSession() {
+			if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
 				common.Must(m.done.Close())
 				common.Must(m.done.Close())
 			}
 			}
 		}
 		}
@@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
 	return nil
 	return nil
 }
 }
 
 
-func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
+func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	outbounds := session.OutboundsFromContext(ctx)
 	outbounds := session.OutboundsFromContext(ctx)
 	ob := outbounds[len(outbounds)-1]
 	ob := outbounds[len(outbounds)-1]
 	transferType := protocol.TransferTypeStream
 	transferType := protocol.TransferTypeStream
@@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
 	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
 	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
 	defer s.Close(false)
 	defer s.Close(false)
 	defer writer.Close()
 	defer writer.Close()
-	defer timer.Reset(time.Second * 16)
 
 
 	errors.LogInfo(ctx, "dispatching request to ", ob.Target)
 	errors.LogInfo(ctx, "dispatching request to ", ob.Target)
 	if err := writeFirstPayload(s.input, writer); err != nil {
 	if err := writeFirstPayload(s.input, writer); err != nil {
@@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
 	}
 	}
 	s.input = link.Reader
 	s.input = link.Reader
 	s.output = link.Writer
 	s.output = link.Writer
-	if _, ok := link.Reader.(*pipe.Reader); ok {
-		go fetchInput(ctx, s, m.link.Writer, m.timer)
-	} else {
-		fetchInput(ctx, s, m.link.Writer, m.timer)
+	go fetchInput(ctx, s, m.link.Writer)
+	if _, ok := link.Reader.(*pipe.Reader); !ok {
+		select {
+		case <-ctx.Done():
+		case <-s.done.Wait():
+		}
 	}
 	}
 	return true
 	return true
 }
 }

+ 50 - 12
common/mux/server.go

@@ -3,6 +3,7 @@ package mux
 import (
 import (
 	"context"
 	"context"
 	"io"
 	"io"
+	"time"
 
 
 	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/app/dispatcher"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
@@ -12,6 +13,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport"
@@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
 		return s.dispatcher.DispatchLink(ctx, dest, link)
 		return s.dispatcher.DispatchLink(ctx, dest, link)
 	}
 	}
 	link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
 	link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
-	_, err := NewServerWorker(ctx, s.dispatcher, link)
-	return err
+	worker, err := NewServerWorker(ctx, s.dispatcher, link)
+	if err != nil {
+		return err
+	}
+	select {
+	case <-ctx.Done():
+	case <-worker.done.Wait():
+	}
+	return nil
 }
 }
 
 
 // Start implements common.Runnable.
 // Start implements common.Runnable.
@@ -81,6 +90,8 @@ type ServerWorker struct {
 	dispatcher     routing.Dispatcher
 	dispatcher     routing.Dispatcher
 	link           *transport.Link
 	link           *transport.Link
 	sessionManager *SessionManager
 	sessionManager *SessionManager
+	done           *done.Instance
+	timer          *time.Ticker
 }
 }
 
 
 func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
 func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
@@ -88,15 +99,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
 		dispatcher:     d,
 		dispatcher:     d,
 		link:           link,
 		link:           link,
 		sessionManager: NewSessionManager(),
 		sessionManager: NewSessionManager(),
+		done:           done.New(),
+		timer:          time.NewTicker(60 * time.Second),
 	}
 	}
 	if inbound := session.InboundFromContext(ctx); inbound != nil {
 	if inbound := session.InboundFromContext(ctx); inbound != nil {
 		inbound.CanSpliceCopy = 3
 		inbound.CanSpliceCopy = 3
 	}
 	}
-	if _, ok := link.Reader.(*pipe.Reader); ok {
-		go worker.run(ctx)
-	} else {
-		worker.run(ctx)
-	}
+	go worker.run(ctx)
+	go worker.monitor()
 	return worker, nil
 	return worker, nil
 }
 }
 
 
@@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
 	s.Close(false)
 	s.Close(false)
 }
 }
 
 
+func (w *ServerWorker) monitor() {
+	defer w.timer.Stop()
+
+	for {
+		checkSize := w.sessionManager.Size()
+		checkCount := w.sessionManager.Count()
+		select {
+		case <-w.done.Wait():
+			w.sessionManager.Close()
+			common.Interrupt(w.link.Writer)
+			common.Interrupt(w.link.Reader)
+			return
+		case <-w.timer.C:
+			if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
+				common.Must(w.done.Close())
+			}
+		}
+	}
+}
+
 func (w *ServerWorker) ActiveConnections() uint32 {
 func (w *ServerWorker) ActiveConnections() uint32 {
 	return uint32(w.sessionManager.Size())
 	return uint32(w.sessionManager.Size())
 }
 }
 
 
 func (w *ServerWorker) Closed() bool {
 func (w *ServerWorker) Closed() bool {
-	return w.sessionManager.Closed()
+	return w.done.Done()
+}
+
+func (w *ServerWorker) WaitClosed() <-chan struct{} {
+	return w.done.Wait()
+}
+
+func (w *ServerWorker) Close() error {
+	return w.done.Close()
 }
 }
 
 
 func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
 func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
@@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
 }
 }
 
 
 func (w *ServerWorker) run(ctx context.Context) {
 func (w *ServerWorker) run(ctx context.Context) {
-	reader := &buf.BufferedReader{Reader: w.link.Reader}
+	defer func() {
+		common.Must(w.done.Close())
+	}()
 
 
-	defer w.sessionManager.Close()
-	defer common.Interrupt(w.link.Reader)
-	defer common.Interrupt(w.link.Writer)
+	reader := &buf.BufferedReader{Reader: w.link.Reader}
 
 
 	for {
 	for {
 		select {
 		select {

+ 11 - 3
common/mux/session.go

@@ -12,6 +12,7 @@ import (
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/transport/pipe"
 	"github.com/xtls/xray-core/transport/pipe"
 )
 )
 
 
@@ -53,7 +54,7 @@ func (m *SessionManager) Count() int {
 func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
 func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
 	m.Lock()
 	m.Lock()
 	defer m.Unlock()
 	defer m.Unlock()
-	
+
 	MaxConcurrency := int(Strategy.MaxConcurrency)
 	MaxConcurrency := int(Strategy.MaxConcurrency)
 	MaxConnection := uint16(Strategy.MaxConnection)
 	MaxConnection := uint16(Strategy.MaxConnection)
 
 
@@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
 	s := &Session{
 	s := &Session{
 		ID:     m.count,
 		ID:     m.count,
 		parent: m,
 		parent: m,
+		done:   done.New(),
 	}
 	}
 	m.sessions[s.ID] = s
 	m.sessions[s.ID] = s
 	return s
 	return s
@@ -115,7 +117,7 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
 	return s, found
 	return s, found
 }
 }
 
 
-func (m *SessionManager) CloseIfNoSession() bool {
+func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
 	m.Lock()
 	m.Lock()
 	defer m.Unlock()
 	defer m.Unlock()
 
 
@@ -123,11 +125,13 @@ func (m *SessionManager) CloseIfNoSession() bool {
 		return true
 		return true
 	}
 	}
 
 
-	if len(m.sessions) != 0 {
+	if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
 		return false
 		return false
 	}
 	}
 
 
 	m.closed = true
 	m.closed = true
+
+	m.sessions = nil
 	return true
 	return true
 }
 }
 
 
@@ -157,6 +161,7 @@ type Session struct {
 	ID           uint16
 	ID           uint16
 	transferType protocol.TransferType
 	transferType protocol.TransferType
 	closed       bool
 	closed       bool
+	done         *done.Instance
 	XUDP         *XUDP
 	XUDP         *XUDP
 }
 }
 
 
@@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
 		return nil
 		return nil
 	}
 	}
 	s.closed = true
 	s.closed = true
+	if s.done != nil {
+		s.done.Close()
+	}
 	if s.XUDP == nil {
 	if s.XUDP == nil {
 		common.Interrupt(s.input)
 		common.Interrupt(s.input)
 		common.Close(s.output)
 		common.Close(s.output)

+ 2 - 2
common/mux/session_test.go

@@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
 	m := NewSessionManager()
 	m := NewSessionManager()
 	s := m.Allocate(&ClientStrategy{})
 	s := m.Allocate(&ClientStrategy{})
 
 
-	if m.CloseIfNoSession() {
+	if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
 		t.Error("able to close")
 		t.Error("able to close")
 	}
 	}
 	m.Remove(false, s.ID)
 	m.Remove(false, s.ID)
-	if !m.CloseIfNoSession() {
+	if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
 		t.Error("not able to close")
 		t.Error("not able to close")
 	}
 	}
 }
 }

+ 3 - 3
proxy/proxy.go

@@ -678,10 +678,10 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
 			errors.LogInfo(ctx, "CopyRawConn splice")
 			errors.LogInfo(ctx, "CopyRawConn splice")
 			statWriter, _ := writer.(*dispatcher.SizeStatWriter)
 			statWriter, _ := writer.(*dispatcher.SizeStatWriter)
 			//runtime.Gosched() // necessary
 			//runtime.Gosched() // necessary
-			time.Sleep(time.Millisecond)    // without this, there will be a rare ssl error for freedom splice
-			timer.SetTimeout(8 * time.Hour) // prevent leak, just in case
+			time.Sleep(time.Millisecond)     // without this, there will be a rare ssl error for freedom splice
+			timer.SetTimeout(24 * time.Hour) // prevent leak, just in case
 			if inTimer != nil {
 			if inTimer != nil {
-				inTimer.SetTimeout(8 * time.Hour)
+				inTimer.SetTimeout(24 * time.Hour)
 			}
 			}
 			w, err := tc.ReadFrom(readerConn)
 			w, err := tc.ReadFrom(readerConn)
 			if readCounter != nil {
 			if readCounter != nil {