瀏覽代碼

DNS: Fix parse domain and geoip (#5499)

Fixes https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3712856715
Hossin Asaadi 5 月之前
父節點
當前提交
961c352127
共有 5 個文件被更改,包括 117 次插入22 次删除
  1. 95 0
      app/dns/dns.go
  2. 4 4
      app/dns/dns_test.go
  3. 15 0
      app/dns/nameserver.go
  4. 2 2
      app/router/config.go
  5. 1 16
      infra/conf/dns.go

+ 95 - 0
app/dns/dns.go

@@ -12,12 +12,15 @@ import (
 	"sync"
 	"time"
 
+	router "github.com/xtls/xray-core/app/router"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/platform/filesystem"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/common/strmatcher"
 	"github.com/xtls/xray-core/features/dns"
+	"google.golang.org/protobuf/proto"
 )
 
 // DNS is a DNS rely server.
@@ -97,6 +100,25 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
 	}
 
 	for _, ns := range config.NameServer {
+		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+			err := parseDomains(ns)
+			if err != nil {
+				return nil, errors.New("failed to parse dns domain rules: ").Base(err)
+			}
+
+			expectedGeoip, err := router.GetGeoIPList(ns.ExpectedGeoip)
+			if err != nil {
+				return nil, errors.New("failed to parse dns expectIPs rules: ").Base(err)
+			}
+			ns.ExpectedGeoip = expectedGeoip
+
+			unexpectedGeoip, err := router.GetGeoIPList(ns.UnexpectedGeoip)
+			if err != nil {
+				return nil, errors.New("failed to parse dns unexpectedGeoip rules: ").Base(err)
+			}
+			ns.UnexpectedGeoip = unexpectedGeoip
+
+		}
 		domainRuleCount += len(ns.PrioritizedDomain)
 	}
 
@@ -580,3 +602,76 @@ func detectGUIPlatform() bool {
 	}
 	return false
 }
+
+func parseDomains(ns *NameServer) error {
+	pureDomains := []*router.Domain{}
+
+	// convert to pure domain
+	for _, pd := range ns.PrioritizedDomain {
+		pureDomains = append(pureDomains, &router.Domain{
+			Type:  router.Domain_Type(pd.Type),
+			Value: pd.Domain,
+		})
+	}
+
+	domainList := []*router.Domain{}
+	for _, domain := range pureDomains {
+		val := strings.Split(domain.Value, "_")
+		if len(val) >= 2 {
+
+			fileName := val[0]
+			code := val[1]
+
+			bs, err := filesystem.ReadAsset(fileName)
+			if err != nil {
+				return errors.New("failed to load file: ", fileName).Base(err)
+			}
+			bs = filesystem.Find(bs, []byte(code))
+			var geosite router.GeoSite
+
+			if err := proto.Unmarshal(bs, &geosite); err != nil {
+				return errors.New("failed Unmarshal :").Base(err)
+			}
+
+			// parse attr
+			if len(val) == 3 {
+				siteWithAttr := strings.Split(val[2], ",")
+				attrs := router.ParseAttrs(siteWithAttr)
+				if !attrs.IsEmpty() {
+					filteredDomains := make([]*router.Domain, 0, len(pureDomains))
+					for _, domain := range geosite.Domain {
+						if attrs.Match(domain) {
+							filteredDomains = append(filteredDomains, domain)
+						}
+					}
+					geosite.Domain = filteredDomains
+				}
+
+			}
+
+			domainList = append(domainList, geosite.Domain...)
+
+			// update ns.OriginalRules Size
+			ruleTag := strings.Join(val, ":")
+			for i, oRule := range ns.OriginalRules {
+				if oRule.Rule == strings.ToLower(ruleTag) {
+					ns.OriginalRules[i].Size = uint32(len(geosite.Domain))
+				}
+			}
+
+		} else {
+			domainList = append(domainList, domain)
+		}
+	}
+
+	// convert back to NameServer_PriorityDomain
+	ns.PrioritizedDomain = []*NameServer_PriorityDomain{}
+	for _, pd := range domainList {
+		ns.PrioritizedDomain = append(ns.PrioritizedDomain, &NameServer_PriorityDomain{
+			Type:   ToDomainMatchingType(pd.Type),
+			Domain: pd.Value,
+		})
+	}
+
+	return nil
+}

+ 4 - 4
app/dns/dns_test.go

@@ -541,7 +541,7 @@ func TestIPMatch(t *testing.T) {
 						},
 						ExpectedGeoip: []*router.GeoIP{
 							{
-								CountryCode: "local",
+								// local
 								Cidr: []*router.CIDR{
 									{
 										// inner ip, will not match
@@ -565,7 +565,7 @@ func TestIPMatch(t *testing.T) {
 						},
 						ExpectedGeoip: []*router.GeoIP{
 							{
-								CountryCode: "test",
+								// test
 								Cidr: []*router.CIDR{
 									{
 										Ip:     []byte{8, 8, 8, 8},
@@ -574,7 +574,7 @@ func TestIPMatch(t *testing.T) {
 								},
 							},
 							{
-								CountryCode: "test",
+								// test
 								Cidr: []*router.CIDR{
 									{
 										Ip:     []byte{8, 8, 8, 4},
@@ -669,7 +669,7 @@ func TestLocalDomain(t *testing.T) {
 						},
 						ExpectedGeoip: []*router.GeoIP{
 							{ // Will match localhost, localhost-a and localhost-b,
-								CountryCode: "local",
+								// local
 								Cidr: []*router.CIDR{
 									{Ip: []byte{127, 0, 0, 2}, Prefix: 32},
 									{Ip: []byte{127, 0, 0, 3}, Prefix: 32},

+ 15 - 0
app/dns/nameserver.go

@@ -297,3 +297,18 @@ func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption)
 		return ipOption
 	}
 }
+
+func ToDomainMatchingType(t router.Domain_Type) DomainMatchingType {
+	switch t {
+	case router.Domain_Domain:
+		return DomainMatchingType_Subdomain
+	case router.Domain_Full:
+		return DomainMatchingType_Full
+	case router.Domain_Plain:
+		return DomainMatchingType_Keyword
+	case router.Domain_Regex:
+		return DomainMatchingType_Regex
+	default:
+		panic("unknown domain type")
+	}
+}

+ 2 - 2
app/router/config.go

@@ -79,7 +79,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		geoip := rr.Geoip
 		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
 			var err error
-			geoip, err = getGeoIPList(rr.Geoip)
+			geoip, err = GetGeoIPList(rr.Geoip)
 			if err != nil {
 				return nil, errors.New("failed to build geoip from mmap").Base(err)
 			}
@@ -188,7 +188,7 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
 	}
 }
 
-func getGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
+func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
 	geoipList := []*GeoIP{}
 	for _, ip := range ips {
 		if ip.CountryCode != "" {

+ 1 - 16
infra/conf/dns.go

@@ -80,21 +80,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
 	return errors.New("failed to parse name server: ", string(data))
 }
 
-func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType {
-	switch t {
-	case router.Domain_Domain:
-		return dns.DomainMatchingType_Subdomain
-	case router.Domain_Full:
-		return dns.DomainMatchingType_Full
-	case router.Domain_Plain:
-		return dns.DomainMatchingType_Keyword
-	case router.Domain_Regex:
-		return dns.DomainMatchingType_Regex
-	default:
-		panic("unknown domain type")
-	}
-}
-
 func (c *NameServerConfig) Build() (*dns.NameServer, error) {
 	if c.Address == nil {
 		return nil, errors.New("NameServer address is not specified.")
@@ -111,7 +96,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
 
 		for _, pd := range parsedDomain {
 			domains = append(domains, &dns.NameServer_PriorityDomain{
-				Type:   toDomainMatchingType(pd.Type),
+				Type:   dns.ToDomainMatchingType(pd.Type),
 				Domain: pd.Value,
 			})
 		}