소스 검색

Geodat: Reduce peak memory usage (#5581)

Fixes https://github.com/XTLS/Xray-core/commit/5f7474120f523ad1e36174481e0b16c3446cc29c
Meow 4 달 전
부모
커밋
9a04eecaf9
7개의 변경된 파일117개의 추가작업 그리고 94개의 파일을 삭제
  1. 4 1
      app/dns/hosts.go
  2. 9 1
      app/dns/nameserver.go
  3. 2 1
      app/router/condition.go
  4. 2 1
      app/router/condition_geoip.go
  5. 9 0
      app/router/config.go
  6. 4 0
      common/platform/filesystem/file.go
  7. 87 90
      infra/conf/router.go

+ 4 - 1
app/dns/hosts.go

@@ -2,6 +2,7 @@ package dns
 
 
 import (
 import (
 	"context"
 	"context"
+	"runtime"
 	"strconv"
 	"strconv"
 
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
@@ -24,7 +25,9 @@ func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) {
 		matchers: g,
 		matchers: g,
 	}
 	}
 
 
-	for _, mapping := range hosts {
+	defer runtime.GC()
+	for i, mapping := range hosts {
+		hosts[i] = nil
 		matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
 		matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
 		if err != nil {
 		if err != nil {
 			errors.LogErrorInner(context.Background(), err, "failed to create domain matcher, ignore domain rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")
 			errors.LogErrorInner(context.Background(), err, "failed to create domain matcher, ignore domain rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")

+ 9 - 1
app/dns/nameserver.go

@@ -3,6 +3,7 @@ package dns
 import (
 import (
 	"context"
 	"context"
 	"net/url"
 	"net/url"
+	"runtime"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -131,7 +132,8 @@ func NewClient(
 		var rules []string
 		var rules []string
 		ruleCurr := 0
 		ruleCurr := 0
 		ruleIter := 0
 		ruleIter := 0
-		for _, domain := range ns.PrioritizedDomain {
+		for i, domain := range ns.PrioritizedDomain {
+			ns.PrioritizedDomain[i] = nil
 			domainRule, err := toStrMatcher(domain.Type, domain.Domain)
 			domainRule, err := toStrMatcher(domain.Type, domain.Domain)
 			if err != nil {
 			if err != nil {
 				errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
 				errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
@@ -154,6 +156,8 @@ func NewClient(
 			}
 			}
 			updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
 			updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
 		}
 		}
+		ns.PrioritizedDomain = nil
+		runtime.GC()
 
 
 		// Establish expected IPs
 		// Establish expected IPs
 		var expectedMatcher router.GeoIPMatcher
 		var expectedMatcher router.GeoIPMatcher
@@ -162,6 +166,8 @@ func NewClient(
 			if err != nil {
 			if err != nil {
 				return errors.New("failed to create expected ip matcher").Base(err).AtWarning()
 				return errors.New("failed to create expected ip matcher").Base(err).AtWarning()
 			}
 			}
+			ns.ExpectedGeoip = nil
+			runtime.GC()
 		}
 		}
 
 
 		// Establish unexpected IPs
 		// Establish unexpected IPs
@@ -171,6 +177,8 @@ func NewClient(
 			if err != nil {
 			if err != nil {
 				return errors.New("failed to create unexpected ip matcher").Base(err).AtWarning()
 				return errors.New("failed to create unexpected ip matcher").Base(err).AtWarning()
 			}
 			}
+			ns.UnexpectedGeoip = nil
+			runtime.GC()
 		}
 		}
 
 
 		if len(clientIP) > 0 {
 		if len(clientIP) > 0 {

+ 2 - 1
app/router/condition.go

@@ -57,7 +57,8 @@ type DomainMatcher struct {
 
 
 func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
 func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
 	g := strmatcher.NewMphMatcherGroup()
 	g := strmatcher.NewMphMatcherGroup()
-	for _, d := range domains {
+	for i, d := range domains {
+		domains[i] = nil
 		matcherType, f := matcherTypeMap[d.Type]
 		matcherType, f := matcherTypeMap[d.Type]
 		if !f {
 		if !f {
 			errors.LogError(context.Background(), "ignore unsupported domain type ", d.Type, " of rule ", d.Value)
 			errors.LogError(context.Background(), "ignore unsupported domain type ", d.Type, " of rule ", d.Value)

+ 2 - 1
app/router/condition_geoip.go

@@ -822,7 +822,8 @@ func (f *GeoIPSetFactory) Create(cidrGroups ...[]*CIDR) (*GeoIPSet, error) {
 	var ipv4Builder, ipv6Builder netipx.IPSetBuilder
 	var ipv4Builder, ipv6Builder netipx.IPSetBuilder
 
 
 	for _, cidrGroup := range cidrGroups {
 	for _, cidrGroup := range cidrGroups {
-		for _, cidrEntry := range cidrGroup {
+		for i, cidrEntry := range cidrGroup {
+			cidrGroup[i] = nil
 			ipBytes := cidrEntry.GetIp()
 			ipBytes := cidrEntry.GetIp()
 			prefixLen := int(cidrEntry.GetPrefix())
 			prefixLen := int(cidrEntry.GetPrefix())
 
 

+ 9 - 0
app/router/config.go

@@ -3,6 +3,7 @@ package router
 import (
 import (
 	"context"
 	"context"
 	"regexp"
 	"regexp"
+	"runtime"
 	"strings"
 	"strings"
 
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
@@ -78,6 +79,8 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 		conds.Add(cond)
 		conds.Add(cond)
+		rr.Geoip = nil
+		runtime.GC()
 	}
 	}
 
 
 	if len(rr.SourceGeoip) > 0 {
 	if len(rr.SourceGeoip) > 0 {
@@ -86,6 +89,8 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 		conds.Add(cond)
 		conds.Add(cond)
+		rr.SourceGeoip = nil
+		runtime.GC()
 	}
 	}
 
 
 	if len(rr.LocalGeoip) > 0 {
 	if len(rr.LocalGeoip) > 0 {
@@ -95,6 +100,8 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		}
 		}
 		conds.Add(cond)
 		conds.Add(cond)
 		errors.LogWarning(context.Background(), "Due to some limitations, in UDP connections, localIP is always equal to listen interface IP, so \"localIP\" rule condition does not work properly on UDP inbound connections that listen on all interfaces")
 		errors.LogWarning(context.Background(), "Due to some limitations, in UDP connections, localIP is always equal to listen interface IP, so \"localIP\" rule condition does not work properly on UDP inbound connections that listen on all interfaces")
+		rr.LocalGeoip = nil
+		runtime.GC()
 	}
 	}
 
 
 	if len(rr.Domain) > 0 {
 	if len(rr.Domain) > 0 {
@@ -104,6 +111,8 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		}
 		}
 		errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
 		errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
 		conds.Add(matcher)
 		conds.Add(matcher)
+		rr.Domain = nil
+		runtime.GC()
 	}
 	}
 
 
 	if len(rr.Process) > 0 {
 	if len(rr.Process) > 0 {

+ 4 - 0
common/platform/filesystem/file.go

@@ -29,6 +29,10 @@ func ReadAsset(file string) ([]byte, error) {
 	return ReadFile(platform.GetAssetLocation(file))
 	return ReadFile(platform.GetAssetLocation(file))
 }
 }
 
 
+func OpenAsset(file string) (io.ReadCloser, error) {
+	return NewFileReader(platform.GetAssetLocation(file))
+}
+
 func ReadCert(file string) ([]byte, error) {
 func ReadCert(file string) ([]byte, error) {
 	if filepath.IsAbs(file) {
 	if filepath.IsAbs(file) {
 		return ReadFile(file)
 		return ReadFile(file)

+ 87 - 90
infra/conf/router.go

@@ -1,7 +1,10 @@
 package conf
 package conf
 
 
 import (
 import (
+	"bufio"
+	"bytes"
 	"encoding/json"
 	"encoding/json"
+	"io"
 	"runtime"
 	"runtime"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -102,7 +105,7 @@ func (c *RouterConfig) Build() (*router.Config, error) {
 	}
 	}
 
 
 	for _, rawRule := range rawRuleList {
 	for _, rawRule := range rawRuleList {
-		rule, err := ParseRule(rawRule)
+		rule, err := parseRule(rawRule)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -125,7 +128,7 @@ type RouterRule struct {
 	BalancerTag string `json:"balancerTag"`
 	BalancerTag string `json:"balancerTag"`
 }
 }
 
 
-func ParseIP(s string) (*router.CIDR, error) {
+func parseIP(s string) (*router.CIDR, error) {
 	var addr, mask string
 	var addr, mask string
 	i := strings.Index(s, "/")
 	i := strings.Index(s, "/")
 	if i < 0 {
 	if i < 0 {
@@ -173,125 +176,119 @@ func ParseIP(s string) (*router.CIDR, error) {
 	}
 	}
 }
 }
 
 
-func loadGeoIP(code string) ([]*router.CIDR, error) {
-	return loadIP("geoip.dat", code)
-}
-
-var (
-	FileCache = make(map[string][]byte)
-	IPCache   = make(map[string]*router.GeoIP)
-	SiteCache = make(map[string]*router.GeoSite)
-)
-
-func loadFile(file string) ([]byte, error) {
-	if FileCache[file] == nil {
-		bs, err := filesystem.ReadAsset(file)
-		if err != nil {
-			return nil, errors.New("failed to open file: ", file).Base(err)
-		}
-		if len(bs) == 0 {
-			return nil, errors.New("empty file: ", file)
-		}
-		// Do not cache file, may save RAM when there
-		// are many files, but consume CPU each time.
-		return bs, nil
-		FileCache[file] = bs
+func loadFile(file, code string) ([]byte, error) {
+	runtime.GC()
+	r, err := filesystem.OpenAsset(file)
+	defer r.Close()
+	if err != nil {
+		return nil, errors.New("failed to open file: ", file).Base(err)
 	}
 	}
-	return FileCache[file], nil
+	bs := find(r, []byte(code))
+	if bs == nil {
+		return nil, errors.New("code not found in ", file, ": ", code)
+	}
+	return bs, nil
 }
 }
 
 
 func loadIP(file, code string) ([]*router.CIDR, error) {
 func loadIP(file, code string) ([]*router.CIDR, error) {
-	index := file + ":" + code
-	if IPCache[index] == nil {
-		bs, err := loadFile(file)
-		if err != nil {
-			return nil, errors.New("failed to load file: ", file).Base(err)
-		}
-		bs = find(bs, []byte(code))
-		if bs == nil {
-			return nil, errors.New("code not found in ", file, ": ", code)
-		}
-		var geoip router.GeoIP
-		if err := proto.Unmarshal(bs, &geoip); err != nil {
-			return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
-		}
-		defer runtime.GC()     // or debug.FreeOSMemory()
-		return geoip.Cidr, nil // do not cache geoip
-		IPCache[index] = &geoip
+	bs, err := loadFile(file, code)
+	if err != nil {
+		return nil, err
 	}
 	}
-	return IPCache[index].Cidr, nil
+	var geoip router.GeoIP
+	if err := proto.Unmarshal(bs, &geoip); err != nil {
+		return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
+	}
+	defer runtime.GC() // or debug.FreeOSMemory()
+	return geoip.Cidr, nil
 }
 }
 
 
 func loadSite(file, code string) ([]*router.Domain, error) {
 func loadSite(file, code string) ([]*router.Domain, error) {
-	index := file + ":" + code
-	if SiteCache[index] == nil {
-		bs, err := loadFile(file)
-		if err != nil {
-			return nil, errors.New("failed to load file: ", file).Base(err)
-		}
-		bs = find(bs, []byte(code))
-		if bs == nil {
-			return nil, errors.New("list not found in ", file, ": ", code)
-		}
-		var geosite router.GeoSite
-		if err := proto.Unmarshal(bs, &geosite); err != nil {
-			return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
-		}
-		defer runtime.GC()         // or debug.FreeOSMemory()
-		return geosite.Domain, nil // do not cache geosite
-		SiteCache[index] = &geosite
+	bs, err := loadFile(file, code)
+	if err != nil {
+		return nil, err
+	}
+	var geosite router.GeoSite
+	if err := proto.Unmarshal(bs, &geosite); err != nil {
+		return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
 	}
 	}
-	return SiteCache[index].Domain, nil
+	defer runtime.GC() // or debug.FreeOSMemory()
+	return geosite.Domain, nil
 }
 }
 
 
-func DecodeVarint(buf []byte) (x uint64, n int) {
+func decodeVarint(r *bufio.Reader) (uint64, error) {
+	var x uint64
 	for shift := uint(0); shift < 64; shift += 7 {
 	for shift := uint(0); shift < 64; shift += 7 {
-		if n >= len(buf) {
-			return 0, 0
+		b, err := r.ReadByte()
+		if err != nil {
+			return 0, err
 		}
 		}
-		b := uint64(buf[n])
-		n++
-		x |= (b & 0x7F) << shift
+		x |= (uint64(b) & 0x7F) << shift
 		if (b & 0x80) == 0 {
 		if (b & 0x80) == 0 {
-			return x, n
+			return x, nil
 		}
 		}
 	}
 	}
-
 	// The number is too large to represent in a 64-bit value.
 	// The number is too large to represent in a 64-bit value.
-	return 0, 0
+	return 0, errors.New("varint overflow")
 }
 }
 
 
-func find(data, code []byte) []byte {
+func find(r io.Reader, code []byte) []byte {
 	codeL := len(code)
 	codeL := len(code)
 	if codeL == 0 {
 	if codeL == 0 {
 		return nil
 		return nil
 	}
 	}
+
+	br := bufio.NewReaderSize(r, 64*1024)
+	need := 2 + codeL
+	prefixBuf := make([]byte, need)
+
 	for {
 	for {
-		dataL := len(data)
-		if dataL < 2 {
+		if _, err := br.ReadByte(); err != nil {
+			return nil
+		}
+
+		x, err := decodeVarint(br)
+		if err != nil {
 			return nil
 			return nil
 		}
 		}
-		x, y := DecodeVarint(data[1:])
-		if x == 0 && y == 0 {
+		bodyL := int(x)
+		if bodyL <= 0 {
 			return nil
 			return nil
 		}
 		}
-		headL, bodyL := 1+y, int(x)
-		dataL -= headL
-		if dataL < bodyL {
+
+		prefixL := bodyL
+		if prefixL > need {
+			prefixL = need
+		}
+		prefix := prefixBuf[:prefixL]
+		if _, err := io.ReadFull(br, prefix); err != nil {
 			return nil
 			return nil
 		}
 		}
-		data = data[headL:]
-		if int(data[1]) == codeL {
-			for i := 0; i < codeL && data[2+i] == code[i]; i++ {
-				if i+1 == codeL {
-					return data[:bodyL]
+
+		match := false
+		if bodyL >= need {
+			if int(prefix[1]) == codeL && bytes.Equal(prefix[2:need], code) {
+				match = true
+			}
+		}
+
+		remain := bodyL - prefixL
+		if match {
+			out := make([]byte, bodyL)
+			copy(out, prefix)
+			if remain > 0 {
+				if _, err := io.ReadFull(br, out[prefixL:]); err != nil {
+					return nil
 				}
 				}
 			}
 			}
+			return out
 		}
 		}
-		if dataL == bodyL {
-			return nil
+
+		if remain > 0 {
+			if _, err := br.Discard(remain); err != nil {
+				return nil
+			}
 		}
 		}
-		data = data[bodyL:]
 	}
 	}
 }
 }
 
 
@@ -447,7 +444,7 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) {
 			if len(country) == 0 {
 			if len(country) == 0 {
 				return nil, errors.New("empty country name in rule")
 				return nil, errors.New("empty country name in rule")
 			}
 			}
-			geoip, err := loadGeoIP(strings.ToUpper(country))
+			geoip, err := loadIP("geoip.dat", strings.ToUpper(country))
 			if err != nil {
 			if err != nil {
 				return nil, errors.New("failed to load GeoIP: ", country).Base(err)
 				return nil, errors.New("failed to load GeoIP: ", country).Base(err)
 			}
 			}
@@ -501,7 +498,7 @@ func ToCidrList(ips StringList) ([]*router.GeoIP, error) {
 			continue
 			continue
 		}
 		}
 
 
-		ipRule, err := ParseIP(ip)
+		ipRule, err := parseIP(ip)
 		if err != nil {
 		if err != nil {
 			return nil, errors.New("invalid IP: ", ip).Base(err)
 			return nil, errors.New("invalid IP: ", ip).Base(err)
 		}
 		}
@@ -655,7 +652,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
 	return rule, nil
 	return rule, nil
 }
 }
 
 
-func ParseRule(msg json.RawMessage) (*router.RoutingRule, error) {
+func parseRule(msg json.RawMessage) (*router.RoutingRule, error) {
 	rawRule := new(RouterRule)
 	rawRule := new(RouterRule)
 	err := json.Unmarshal(msg, rawRule)
 	err := json.Unmarshal(msg, rawRule)
 	if err != nil {
 	if err != nil {