Explorar o código

Refine `PrioritizedDomain`, should fix https://github.com/XTLS/Xray-core/issues/638

hmol233 %!s(int64=5) %!d(string=hai) anos
pai
achega
1ced7985d5
Modificáronse 2 ficheiros con 52 adicións e 72 borrados
  1. 17 33
      app/dns/dns.go
  2. 35 39
      app/dns/nameserver.go

+ 17 - 33
app/dns/dns.go

@@ -12,7 +12,6 @@ import (
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/matcher/geoip"
 	"github.com/xtls/xray-core/common/matcher/geoip"
-	"github.com/xtls/xray-core/common/matcher/str"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/features"
 	"github.com/xtls/xray-core/features"
@@ -29,8 +28,6 @@ type DNS struct {
 	hosts           *StaticHosts
 	hosts           *StaticHosts
 	clients         []*Client
 	clients         []*Client
 	ctx             context.Context
 	ctx             context.Context
-	domainMatcher   str.IndexMatcher
-	matcherInfos    []DomainMatcherInfo
 }
 }
 
 
 // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher
 // DomainMatcherInfo contains information attached to index returned by Server.domainMatcher
@@ -89,9 +86,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		domainRuleCount += len(ns.PrioritizedDomain)
 		domainRuleCount += len(ns.PrioritizedDomain)
 	}
 	}
 
 
-	// MatcherInfos is ensured to cover the maximum index domainMatcher could return, where matcher's index starts from 1
-	matcherInfos := make([]DomainMatcherInfo, domainRuleCount+1)
-	domainMatcher := &str.MatcherGroup{}
 	geoipContainer := geoip.GeoIPMatcherContainer{}
 	geoipContainer := geoip.GeoIPMatcherContainer{}
 
 
 	for _, endpoint := range config.NameServers {
 	for _, endpoint := range config.NameServers {
@@ -104,22 +98,13 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 	}
 	}
 
 
 	for _, ns := range config.NameServer {
 	for _, ns := range config.NameServer {
-		clientIdx := len(clients)
-		updateDomain := func(domainRule str.Matcher, originalRuleIdx int, matcherInfos []DomainMatcherInfo) error {
-			midx := domainMatcher.Add(domainRule)
-			matcherInfos[midx] = DomainMatcherInfo{
-				clientIdx:     uint16(clientIdx),
-				domainRuleIdx: uint16(originalRuleIdx),
-			}
-			return nil
-		}
 
 
 		myClientIP := clientIP
 		myClientIP := clientIP
 		switch len(ns.ClientIp) {
 		switch len(ns.ClientIp) {
 		case net.IPv4len, net.IPv6len:
 		case net.IPv4len, net.IPv6len:
-			myClientIP = net.IP(ns.ClientIp)
+			myClientIP = ns.ClientIp
 		}
 		}
-		client, err := NewClient(ctx, ns, myClientIP, geoipContainer, &matcherInfos, updateDomain)
+		client, err := NewClient(ctx, ns, myClientIP, geoipContainer)
 		if err != nil {
 		if err != nil {
 			return nil, newError("failed to create client").Base(err)
 			return nil, newError("failed to create client").Base(err)
 		}
 		}
@@ -137,8 +122,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 		ipOption:        ipOption,
 		ipOption:        ipOption,
 		clients:         clients,
 		clients:         clients,
 		ctx:             ctx,
 		ctx:             ctx,
-		domainMatcher:   domainMatcher,
-		matcherInfos:    matcherInfos,
 		cacheStrategy:   config.CacheStrategy,
 		cacheStrategy:   config.CacheStrategy,
 		disableFallback: config.DisableFallback,
 		disableFallback: config.DisableFallback,
 	}, nil
 	}, nil
@@ -268,21 +251,22 @@ func (s *DNS) sortClients(domain string, option *dns.IPOption) []*Client {
 	}()
 	}()
 
 
 	// Priority domain matching
 	// Priority domain matching
-	for _, match := range s.domainMatcher.Match(domain) {
-		info := s.matcherInfos[match]
-		client := s.clients[info.clientIdx]
-		domainRule := client.domains[info.domainRuleIdx]
-		if !canQueryOnClient(option, client) {
-			newError("skipping the client " + client.Name()).AtDebug().WriteToLog()
-			continue
-		}
-		domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", domainRule, info.clientIdx))
-		if clientUsed[info.clientIdx] {
-			continue
+	for clientIdx, client := range s.clients {
+		if ids := client.domainMatcher.Match(domain); len(ids) > 0 {
+			if !canQueryOnClient(option, client) {
+				newError("skipping the client " + client.Name()).AtDebug().WriteToLog()
+				continue
+			}
+			for _, id := range ids {
+				rule := client.findRule(id)
+				domainRules = append(domainRules, fmt.Sprintf("%s(DNS idx:%d)", rule, clientIdx))
+			}
+			if clientUsed[clientIdx] {
+				continue
+			}
+			clients = append(clients, client)
+			clientNames = append(clientNames, client.Name())
 		}
 		}
-		clientUsed[info.clientIdx] = true
-		clients = append(clients, client)
-		clientNames = append(clientNames, client.Name())
 	}
 	}
 
 
 	if !s.disableFallback {
 	if !s.disableFallback {

+ 35 - 39
app/dns/nameserver.go

@@ -25,11 +25,23 @@ type Server interface {
 
 
 // Client is the interface for DNS client.
 // Client is the interface for DNS client.
 type Client struct {
 type Client struct {
-	server       Server
-	clientIP     net.IP
-	skipFallback bool
-	domains      []string
-	expectIPs    []*geoip.GeoIPMatcher
+	server        Server
+	clientIP      net.IP
+	skipFallback  bool
+	expectIPs     []*geoip.GeoIPMatcher
+	domainMatcher str.MatcherGroup
+	originRules   []*NameServer_OriginalRule
+}
+
+func (c Client) findRule(idx uint32) string {
+	for _, r := range c.originRules {
+		if idx <= r.Size {
+			return r.Rule
+		}
+		idx -= r.Size
+	}
+
+	return "unknown rule"
 }
 }
 
 
 var errExpectedIPNonMatch = errors.New("expectIPs not match")
 var errExpectedIPNonMatch = errors.New("expectIPs not match")
@@ -64,7 +76,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
 }
 }
 
 
 // NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs.
 // NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs.
-func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer, matcherInfos *[]DomainMatcherInfo, updateDomainRule func(str.Matcher, int, []DomainMatcherInfo) error) (*Client, error) {
+func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container geoip.GeoIPMatcherContainer) (*Client, error) {
 	client := &Client{}
 	client := &Client{}
 
 
 	err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
 	err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
@@ -79,55 +91,38 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g
 			ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...)
 			ns.PrioritizedDomain = append(ns.PrioritizedDomain, localTLDsAndDotlessDomains...)
 			ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule)
 			ns.OriginalRules = append(ns.OriginalRules, localTLDsAndDotlessDomainsRule)
 			// The following lines is a solution to avoid core panics(rule index out of range) when setting `localhost` DNS client in config.
 			// The following lines is a solution to avoid core panics(rule index out of range) when setting `localhost` DNS client in config.
-			// Because the `localhost` DNS client will apend len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule.
+			// Because the `localhost` DNS client will append len(localTLDsAndDotlessDomains) rules into matcherInfos to match `geosite:private` default rule.
 			// But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range).
 			// But `matcherInfos` has no enough length to add rules, which leads to core panics (rule index out of range).
 			// To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification.
 			// To avoid this, the length of `matcherInfos` must be equal to the expected, so manually append it with Golang default zero value first for later modification.
-			for i := 0; i < len(localTLDsAndDotlessDomains); i++ {
-				*matcherInfos = append(*matcherInfos, DomainMatcherInfo{
-					clientIdx:     uint16(0),
-					domainRuleIdx: uint16(0),
-				})
-			}
+			// ;)
+			/*
+				for i := 0; i < len(localTLDsAndDotlessDomains); i++ {
+					*matcherInfos = append(*matcherInfos, DomainMatcherInfo{
+						clientIdx:     uint16(0),
+						domainRuleIdx: uint16(0),
+					})
+				}
+			*/
 		}
 		}
 
 
 		// Establish domain rules
 		// Establish domain rules
-		var rules []string
-		ruleCurr := 0
-		ruleIter := 0
+		var domainMatcher = str.MatcherGroup{}
 		for _, domain := range ns.PrioritizedDomain {
 		for _, domain := range ns.PrioritizedDomain {
 			domainRule, err := toStrMatcher(domain.Type, domain.Value)
 			domainRule, err := toStrMatcher(domain.Type, domain.Value)
 			if err != nil {
 			if err != nil {
 				return newError("failed to create prioritized domain").Base(err).AtWarning()
 				return newError("failed to create prioritized domain").Base(err).AtWarning()
 			}
 			}
-			originalRuleIdx := ruleCurr
-			if ruleCurr < len(ns.OriginalRules) {
-				rule := ns.OriginalRules[ruleCurr]
-				if ruleCurr >= len(rules) {
-					rules = append(rules, rule.Rule)
-				}
-				ruleIter++
-				if ruleIter >= int(rule.Size) {
-					ruleIter = 0
-					ruleCurr++
-				}
-			} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
-				rules = append(rules, domainRule.String())
-				ruleCurr++
-			}
-			err = updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
-			if err != nil {
-				return newError("failed to create prioritized domain").Base(err).AtWarning()
-			}
+			domainMatcher.Add(domainRule)
 		}
 		}
 
 
 		// Establish expected IPs
 		// Establish expected IPs
-		var matchers []*geoip.GeoIPMatcher
+		var ipMatchers []*geoip.GeoIPMatcher
 		for _, geoip := range ns.Geoip {
 		for _, geoip := range ns.Geoip {
 			matcher, err := container.Add(geoip)
 			matcher, err := container.Add(geoip)
 			if err != nil {
 			if err != nil {
 				return newError("failed to create ip matcher").Base(err).AtWarning()
 				return newError("failed to create ip matcher").Base(err).AtWarning()
 			}
 			}
-			matchers = append(matchers, matcher)
+			ipMatchers = append(ipMatchers, matcher)
 		}
 		}
 
 
 		if len(clientIP) > 0 {
 		if len(clientIP) > 0 {
@@ -141,8 +136,9 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container g
 
 
 		client.server = server
 		client.server = server
 		client.clientIP = clientIP
 		client.clientIP = clientIP
-		client.domains = rules
-		client.expectIPs = matchers
+		client.expectIPs = ipMatchers
+		client.originRules = ns.OriginalRules
+		client.domainMatcher = domainMatcher
 		return nil
 		return nil
 	})
 	})
 	return client, err
 	return client, err