Преглед на файлове

Routing: Reduce peak memory usage (#5488)

https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3711430369

For https://github.com/XTLS/Xray-core/issues/4422
Hossin Asaadi преди 5 месеца
родител
ревизия
c715154309
променени са 4 файла, в които са добавени 235 реда и са изтрити 115 реда
  1. 43 1
      app/router/condition.go
  2. 99 2
      app/router/config.go
  3. 52 0
      common/platform/filesystem/asset_tools.go
  4. 41 112
      infra/conf/router.go

+ 43 - 1
app/router/condition.go

@@ -307,6 +307,48 @@ func (m *AttributeMatcher) Apply(ctx routing.Context) bool {
 	return m.Match(attributes)
 }
 
+// GeoAttributeMatcher is a single attribute condition that can be tested
+// against a geosite domain entry.
+type GeoAttributeMatcher interface {
+	// Match reports whether the given domain satisfies the condition.
+	Match(*Domain) bool
+}
+
+// GeoBooleanMatcher is an attribute key. It matches any domain that carries
+// an attribute whose key is exactly equal to it.
+type GeoBooleanMatcher string
+
+// Match reports whether at least one attribute of domain has the key m.
+func (m GeoBooleanMatcher) Match(domain *Domain) bool {
+	want := string(m)
+	attrs := domain.Attribute
+	for i := range attrs {
+		if attrs[i].Key == want {
+			return true
+		}
+	}
+	return false
+}
+
+// GeoAttributeList is a set of attribute matchers that must all hold.
+type GeoAttributeList struct {
+	Matcher []GeoAttributeMatcher
+}
+
+// Match reports whether domain satisfies every matcher in the list. A list
+// with no matchers accepts every domain.
+func (al *GeoAttributeList) Match(domain *Domain) bool {
+	matchers := al.Matcher
+	for i := range matchers {
+		if matchers[i].Match(domain) {
+			continue
+		}
+		return false
+	}
+	return true
+}
+
+// IsEmpty reports whether the list holds no matchers.
+func (al *GeoAttributeList) IsEmpty() bool {
+	return len(al.Matcher) < 1
+}
+
+// ParseAttrs builds a GeoAttributeList with one GeoBooleanMatcher per entry
+// of attrs; each attribute name is lower-cased before it is stored.
+func ParseAttrs(attrs []string) *GeoAttributeList {
+	list := &GeoAttributeList{}
+	for i := range attrs {
+		key := GeoBooleanMatcher(strings.ToLower(attrs[i]))
+		list.Matcher = append(list.Matcher, key)
+	}
+	return list
+}
+
 type ProcessNameMatcher struct {
 	names []string
 }
@@ -343,4 +385,4 @@ func (m *ProcessNameMatcher) Apply(ctx routing.Context) bool {
 		}
 	}
 	return false
-}
+}

+ 99 - 2
app/router/config.go

@@ -3,11 +3,14 @@ package router
 import (
 	"context"
 	"regexp"
+	"runtime"
 	"strings"
 
 	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/platform/filesystem"
 	"github.com/xtls/xray-core/features/outbound"
 	"github.com/xtls/xray-core/features/routing"
+	"google.golang.org/protobuf/proto"
 )
 
 type Rule struct {
@@ -73,7 +76,15 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if len(rr.Geoip) > 0 {
-		cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target)
+		geoip := rr.Geoip
+		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+			var err error
+			geoip, err = getGeoIPList(rr.Geoip)
+			if err != nil {
+				return nil, errors.New("failed to build geoip from mmap").Base(err)
+			}
+		}
+		cond, err := NewIPMatcher(geoip, MatcherAsType_Target)
 		if err != nil {
 			return nil, err
 		}
@@ -98,7 +109,16 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	}
 
 	if len(rr.Domain) > 0 {
-		matcher, err := NewMphMatcherGroup(rr.Domain)
+		domains := rr.Domain
+		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+			var err error
+			domains, err = getDomainList(rr.Domain)
+			if err != nil {
+				return nil, errors.New("failed to build domains from mmap").Base(err)
+			}
+		}
+
+		matcher, err := NewMphMatcherGroup(domains)
 		if err != nil {
 			return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
 		}
@@ -167,3 +187,80 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
 		return nil, errors.New("unrecognized balancer type")
 	}
 }
+
+// getGeoIPList expands the given GeoIP entries into entries decoded from geo
+// asset files.
+//
+// An entry with an empty CountryCode is passed through unchanged. For any
+// other entry the record is read from "geoip.dat", or, when CountryCode
+// splits into exactly two parts on "_", from the lower-cased first part. The
+// complete CountryCode is the key looked up inside the file.
+//
+// An error is returned when the asset cannot be read, when the code is not
+// present in the asset, or when the record cannot be decoded.
+func getGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
+	geoipList := make([]*GeoIP, 0, len(ips))
+	for _, ip := range ips {
+		if ip.CountryCode == "" {
+			geoipList = append(geoipList, ip)
+			continue
+		}
+
+		fileName := "geoip.dat"
+		if val := strings.Split(ip.CountryCode, "_"); len(val) == 2 {
+			fileName = strings.ToLower(val[0])
+		}
+		bs, err := filesystem.ReadAsset(fileName)
+		if err != nil {
+			return nil, errors.New("failed to load file: ", fileName).Base(err)
+		}
+
+		// Find yields nil for an unknown code. Decoding nil would succeed and
+		// produce an empty entry, so the miss has to be reported here.
+		bs = filesystem.Find(bs, []byte(ip.CountryCode))
+		if bs == nil {
+			return nil, errors.New("code not found in ", fileName, ": ", ip.CountryCode)
+		}
+
+		geoip := new(GeoIP)
+		if err := proto.Unmarshal(bs, geoip); err != nil {
+			return nil, errors.New("failed Unmarshal :").Base(err)
+		}
+		// NOTE(review): the decoded entry replaces ip entirely, so any field
+		// set on ip other than CountryCode is not carried over — confirm that
+		// no such field is needed downstream.
+		geoipList = append(geoipList, geoip)
+	}
+	return geoipList, nil
+}
+
+// getDomainList expands domain entries whose Value references a geosite list
+// into the domains stored in that list.
+//
+// A Value is split on "_". With fewer than two parts the entry is passed
+// through unchanged. Otherwise the first part is the asset file name and the
+// second the list code; when there are exactly three parts the third is a
+// comma-separated attribute list, and only domains carrying all of those
+// attributes are kept.
+//
+// An error is returned when the asset cannot be read, when the list is not
+// present in the asset, or when the record cannot be decoded.
+//
+// NOTE(review): every Value containing "_" is treated as a list reference,
+// including a literal domain or pattern that merely contains an underscore —
+// confirm callers never pass such values.
+func getDomainList(domains []*Domain) ([]*Domain, error) {
+	domainList := make([]*Domain, 0, len(domains))
+	for _, domain := range domains {
+		val := strings.Split(domain.Value, "_")
+		if len(val) < 2 {
+			domainList = append(domainList, domain)
+			continue
+		}
+
+		fileName, code := val[0], val[1]
+		bs, err := filesystem.ReadAsset(fileName)
+		if err != nil {
+			return nil, errors.New("failed to load file: ", fileName).Base(err)
+		}
+
+		// Find yields nil for an unknown code. Decoding nil would succeed and
+		// produce an empty list, so the miss has to be reported here.
+		bs = filesystem.Find(bs, []byte(code))
+		if bs == nil {
+			return nil, errors.New("list not found in ", fileName, ": ", code)
+		}
+
+		var geosite GeoSite
+		if err := proto.Unmarshal(bs, &geosite); err != nil {
+			return nil, errors.New("failed Unmarshal :").Base(err)
+		}
+
+		sites := geosite.Domain
+		if len(val) == 3 {
+			if attrs := ParseAttrs(strings.Split(val[2], ",")); !attrs.IsEmpty() {
+				// The decoded slice is owned by this iteration, so it is
+				// filtered in place instead of allocating a second slice.
+				kept := sites[:0]
+				for _, site := range sites {
+					if attrs.Match(site) {
+						kept = append(kept, site)
+					}
+				}
+				sites = kept
+			}
+		}
+
+		domainList = append(domainList, sites...)
+	}
+	return domainList, nil
+}

+ 52 - 0
common/platform/filesystem/asset_tools.go

@@ -0,0 +1,52 @@
+package filesystem
+
+// DecodeVarint reads one base-128 varint from the start of buf. It returns
+// the decoded value and the number of bytes consumed. It returns (0, 0) when
+// buf ends before the terminating byte, or when the encoding is too long to
+// fit a 64-bit value.
+func DecodeVarint(buf []byte) (x uint64, n int) {
+	var shift uint
+	for n < len(buf) && shift < 64 {
+		b := buf[n]
+		n++
+		x |= uint64(b&0x7F) << shift
+		if b < 0x80 {
+			// Continuation bit clear: this was the last byte.
+			return x, n
+		}
+		shift += 7
+	}
+
+	// Either the input was truncated or the number does not fit in 64 bits.
+	return 0, 0
+}
+
+// Find scans data as a sequence of records and returns the body of the first
+// record whose leading name equals code, or nil when there is none.
+//
+// Each record is read as one tag byte, a varint body length, and the body.
+// Inside the body the name is expected as one tag byte, a single length byte,
+// and the name bytes. The returned slice aliases data.
+//
+// An empty code, or truncated or otherwise malformed input, yields nil.
+func Find(data, code []byte) []byte {
+	codeL := len(code)
+	if codeL == 0 {
+		return nil
+	}
+	for {
+		dataL := len(data)
+		if dataL < 2 {
+			return nil
+		}
+		x, y := DecodeVarint(data[1:])
+		if x == 0 && y == 0 {
+			return nil
+		}
+		headL := 1 + y
+		dataL -= headL
+		// Compare as uint64 so an oversized length can neither wrap to a
+		// negative int nor slip past the bound check.
+		if x > uint64(dataL) {
+			return nil
+		}
+		bodyL := int(x)
+		data = data[headL:]
+		// Only inspect the name when the body is long enough to hold it;
+		// this keeps every index below inside the current record.
+		if bodyL >= 2+codeL && int(data[1]) == codeL && string(data[2:2+codeL]) == string(code) {
+			return data[:bodyL]
+		}
+		if dataL == bodyL {
+			return nil
+		}
+		data = data[bodyL:]
+	}
+}

+ 41 - 112
infra/conf/router.go

@@ -203,17 +203,23 @@ func loadFile(file string) ([]byte, error) {
 func loadIP(file, code string) ([]*router.CIDR, error) {
 	index := file + ":" + code
 	if IPCache[index] == nil {
-		bs, err := loadFile(file)
-		if err != nil {
-			return nil, errors.New("failed to load file: ", file).Base(err)
-		}
-		bs = find(bs, []byte(code))
-		if bs == nil {
-			return nil, errors.New("code not found in ", file, ": ", code)
-		}
 		var geoip router.GeoIP
-		if err := proto.Unmarshal(bs, &geoip); err != nil {
-			return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
+
+		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+			// don't pass code because the country code is already in the top-level router.GeoIP
+			geoip = router.GeoIP{Cidr: []*router.CIDR{}}
+		} else {
+			bs, err := loadFile(file)
+			if err != nil {
+				return nil, errors.New("failed to load file: ", file).Base(err)
+			}
+			bs = filesystem.Find(bs, []byte(code))
+			if bs == nil {
+				return nil, errors.New("code not found in ", file, ": ", code)
+			}
+			if err := proto.Unmarshal(bs, &geoip); err != nil {
+				return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
+			}
 		}
 		defer runtime.GC()     // or debug.FreeOSMemory()
 		return geoip.Cidr, nil // do not cache geoip
@@ -225,115 +231,33 @@ func loadIP(file, code string) ([]*router.CIDR, error) {
 func loadSite(file, code string) ([]*router.Domain, error) {
 	index := file + ":" + code
 	if SiteCache[index] == nil {
-		bs, err := loadFile(file)
-		if err != nil {
-			return nil, errors.New("failed to load file: ", file).Base(err)
-		}
-		bs = find(bs, []byte(code))
-		if bs == nil {
-			return nil, errors.New("list not found in ", file, ": ", code)
-		}
 		var geosite router.GeoSite
-		if err := proto.Unmarshal(bs, &geosite); err != nil {
-			return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
-		}
-		defer runtime.GC()         // or debug.FreeOSMemory()
-		return geosite.Domain, nil // do not cache geosite
-		SiteCache[index] = &geosite
-	}
-	return SiteCache[index].Domain, nil
-}
 
-func DecodeVarint(buf []byte) (x uint64, n int) {
-	for shift := uint(0); shift < 64; shift += 7 {
-		if n >= len(buf) {
-			return 0, 0
-		}
-		b := uint64(buf[n])
-		n++
-		x |= (b & 0x7F) << shift
-		if (b & 0x80) == 0 {
-			return x, n
-		}
-	}
+		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+			// pass "file_code" so the optimized matcher can be built later
+			domain := router.Domain{Value: file + "_" + code}
+			geosite = router.GeoSite{Domain: []*router.Domain{&domain}}
 
-	// The number is too large to represent in a 64-bit value.
-	return 0, 0
-}
+		} else {
 
-func find(data, code []byte) []byte {
-	codeL := len(code)
-	if codeL == 0 {
-		return nil
-	}
-	for {
-		dataL := len(data)
-		if dataL < 2 {
-			return nil
-		}
-		x, y := DecodeVarint(data[1:])
-		if x == 0 && y == 0 {
-			return nil
-		}
-		headL, bodyL := 1+y, int(x)
-		dataL -= headL
-		if dataL < bodyL {
-			return nil
-		}
-		data = data[headL:]
-		if int(data[1]) == codeL {
-			for i := 0; i < codeL && data[2+i] == code[i]; i++ {
-				if i+1 == codeL {
-					return data[:bodyL]
-				}
+			bs, err := loadFile(file)
+			if err != nil {
+				return nil, errors.New("failed to load file: ", file).Base(err)
+			}
+			bs = filesystem.Find(bs, []byte(code))
+			if bs == nil {
+				return nil, errors.New("list not found in ", file, ": ", code)
+			}
+			if err := proto.Unmarshal(bs, &geosite); err != nil {
+				return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
 			}
 		}
-		if dataL == bodyL {
-			return nil
-		}
-		data = data[bodyL:]
-	}
-}
-
-type AttributeMatcher interface {
-	Match(*router.Domain) bool
-}
-
-type BooleanMatcher string
-
-func (m BooleanMatcher) Match(domain *router.Domain) bool {
-	for _, attr := range domain.Attribute {
-		if attr.Key == string(m) {
-			return true
-		}
-	}
-	return false
-}
-
-type AttributeList struct {
-	matcher []AttributeMatcher
-}
-
-func (al *AttributeList) Match(domain *router.Domain) bool {
-	for _, matcher := range al.matcher {
-		if !matcher.Match(domain) {
-			return false
-		}
-	}
-	return true
-}
 
-func (al *AttributeList) IsEmpty() bool {
-	return len(al.matcher) == 0
-}
-
-func parseAttrs(attrs []string) *AttributeList {
-	al := new(AttributeList)
-	for _, attr := range attrs {
-		lc := strings.ToLower(attr)
-		al.matcher = append(al.matcher, BooleanMatcher(lc))
+		defer runtime.GC()         // or debug.FreeOSMemory()
+		return geosite.Domain, nil // do not cache geosite
+		SiteCache[index] = &geosite
 	}
-	return al
+	return SiteCache[index].Domain, nil
 }
 
 func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) {
@@ -342,7 +266,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
 		return nil, errors.New("empty site")
 	}
 	country := strings.ToUpper(parts[0])
-	attrs := parseAttrs(parts[1:])
+	attrs := router.ParseAttrs(parts[1:])
 	domains, err := loadSite(file, country)
 	if err != nil {
 		return nil, err
@@ -352,6 +276,11 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
 		return domains, nil
 	}
 
+	if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+		domains[0].Value = domains[0].Value + "_" + strings.Join(parts[1:], ",")
+		return domains, nil
+	}
+
 	filteredDomains := make([]*router.Domain, 0, len(domains))
 	for _, domain := range domains {
 		if attrs.Match(domain) {