Ver código fonte

Dokodemo: Recycle inactive fakeudp connections

风扇滑翔翼 7 meses atrás
pai
commit
26b246fa92
1 arquivos alterados com 66 adições e 19 exclusões
  1. 66 19
      proxy/dokodemo/dokodemo.go

+ 66 - 19
proxy/dokodemo/dokodemo.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/buf"
@@ -12,6 +13,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/routing"
@@ -176,7 +178,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 			if err != nil {
 				return err
 			}
-			writer = NewPacketWriter(pConn, &dest, mark, back)
+			writer = NewPacketWriter(ctx, pConn, &dest, mark, back)
 			defer writer.(*PacketWriter).Close() // close fake UDP conns
 		}
 	}
@@ -190,22 +192,35 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 	return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process()
 }
 
-func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {
+func NewPacketWriter(ctx context.Context, conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {
+	ctx, cancel := context.WithCancel(ctx)
 	writer := &PacketWriter{
-		conn:  conn,
-		conns: make(map[net.Destination]net.PacketConn),
-		mark:  mark,
-		back:  back,
+		ctx:    ctx,
+		cancel: cancel,
+		conn:   conn,
+		conns:  utils.NewTypedSyncMap[net.Destination, *PacketConnTimeWrapper](),
+		inactiveConns: make([]*PacketConnTimeWrapper, 0),
+		mark:   mark,
+		back:   back,
 	}
-	writer.conns[*d] = conn
+	timedconn := &PacketConnTimeWrapper{
+		PacketConn:   conn,
+		lastUsedTime: time.Now(),
+		isMainConn:   true,
+	}
+	writer.conns.Store(*d, timedconn)
+	go writer.cleanInactiveConns()
 	return writer
 }
 
 type PacketWriter struct {
-	conn  net.PacketConn
-	conns map[net.Destination]net.PacketConn
-	mark  int
-	back  *net.UDPAddr
+	ctx    context.Context
+	cancel context.CancelFunc
+	conn   net.PacketConn
+	conns  *utils.TypedSyncMap[net.Destination, *PacketConnTimeWrapper]
+	inactiveConns []*PacketConnTimeWrapper
+	mark   int
+	back   *net.UDPAddr
 }
 
 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
@@ -217,26 +232,30 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		}
 		var err error
 		if b.UDP != nil && b.UDP.Address.Family().IsIP() {
-			conn := w.conns[*b.UDP]
+			conn, _ := w.conns.Load(*b.UDP)
 			if conn == nil {
-				conn, err = FakeUDP(
+				fakeudpconn, err := FakeUDP(
 					&net.UDPAddr{
 						IP:   b.UDP.Address.IP(),
 						Port: int(b.UDP.Port),
 					},
 					w.mark,
 				)
+				conn = &PacketConnTimeWrapper{
+					PacketConn:   fakeudpconn,
+					lastUsedTime: time.Now(),
+				}
 				if err != nil {
 					errors.LogInfo(context.Background(), err.Error())
 					b.Release()
 					continue
 				}
-				w.conns[*b.UDP] = conn
+				w.conns.Store(*b.UDP, conn)
 			}
 			_, err = conn.WriteTo(b.Bytes(), w.back)
 			if err != nil {
 				errors.LogInfo(context.Background(), err.Error())
-				w.conns[*b.UDP] = nil
+				w.conns.Delete(*b.UDP)
 				conn.Close()
 			}
 			b.Release()
@@ -253,10 +272,38 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 }
 
 func (w *PacketWriter) Close() error {
-	for _, conn := range w.conns {
-		if conn != nil {
-			conn.Close()
+	w.cancel()
+	w.conns.Range(func(key net.Destination, conn *PacketConnTimeWrapper) bool {
+		common.CloseIfExists(conn)
+		return true
+	})
+	return nil
+}
+
+func (w *PacketWriter) cleanInactiveConns() {
+	ticker := time.NewTicker(60 * time.Second)
+	defer ticker.Stop()
+	select {
+	case <-ticker.C:
+		if len(w.inactiveConns) > 0 {
+			for _, conn := range w.inactiveConns {
+				common.CloseIfExists(conn)
+			}
 		}
+		w.conns.Range(func(key net.Destination, conn *PacketConnTimeWrapper) bool {
+			if conn != nil && !conn.isMainConn && time.Since(conn.lastUsedTime) > 120*time.Second {
+				w.conns.Delete(key)
+			}
+			w.inactiveConns = append(w.inactiveConns, conn)
+			return true
+		})
+	case <-w.ctx.Done():
+		return
 	}
-	return nil
+}
+
+type PacketConnTimeWrapper struct {
+	net.PacketConn
+	lastUsedTime time.Time
+	isMainConn   bool
 }