Просмотр исходного кода

DNS hosts: Support returning RCode (#4681)

j2rong4cn 10 месяцев назад
Родитель
Сommit
923b5d7229
4 измененных файлов с 76 добавлено и 18 удалено
  1. 6 1
      app/dns/dns.go
  2. 35 12
      app/dns/hosts.go
  3. 17 5
      app/dns/hosts_test.go
  4. 18 0
      features/dns/client.go

+ 6 - 1
app/dns/dns.go

@@ -204,7 +204,12 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
 	}
 
 	// Static host lookup
-	switch addrs := s.hosts.Lookup(domain, option); {
+	switch addrs, err := s.hosts.Lookup(domain, option); {
+	case err != nil:
+		if go_errors.Is(err, dns.ErrEmptyResponse) {
+			return nil, 0, dns.ErrEmptyResponse
+		}
+		return nil, 0, errors.New("returning nil for domain ", domain).Base(err)
 	case addrs == nil: // Domain not recorded in static host
 		break
 	case len(addrs) == 0: // Domain recorded, but no valid IP returned (e.g. IPv4 address with only IPv6 enabled)

+ 35 - 12
app/dns/hosts.go

@@ -2,6 +2,8 @@ package dns
 
 import (
 	"context"
+	"strconv"
+
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/strmatcher"
@@ -31,7 +33,15 @@ func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) {
 		ips := make([]net.Address, 0, len(mapping.Ip)+1)
 		switch {
 		case len(mapping.ProxiedDomain) > 0:
-			ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
+			if mapping.ProxiedDomain[0] == '#' {
+				rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:])
+				if err != nil {
+					return nil, err
+				}
+				ips = append(ips, dns.RCodeError(rcode))
+			} else {
+				ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
+			}
 		case len(mapping.Ip) > 0:
 			for _, ip := range mapping.Ip {
 				addr := net.IPAddress(ip)
@@ -58,38 +68,51 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
 	return filtered
 }
 
-func (h *StaticHosts) lookupInternal(domain string) []net.Address {
+func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) {
 	ips := make([]net.Address, 0)
 	found := false
 	for _, id := range h.matchers.Match(domain) {
+		for _, v := range h.ips[id] {
+			if err, ok := v.(dns.RCodeError); ok {
+				if uint16(err) == 0 {
+					return nil, dns.ErrEmptyResponse
+				}
+				return nil, err
+			}
+		}
 		ips = append(ips, h.ips[id]...)
 		found = true
 	}
 	if !found {
-		return nil
+		return nil, nil
 	}
-	return ips
+	return ips, nil
 }
 
-func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address {
-	switch addrs := h.lookupInternal(domain); {
+func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) {
+	switch addrs, err := h.lookupInternal(domain); {
+	case err != nil:
+		return nil, err
 	case len(addrs) == 0: // Not recorded in static hosts, return nil
-		return addrs
+		return addrs, nil
 	case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain
 		errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it")
 		if maxDepth > 0 {
-			unwrapped := h.lookup(addrs[0].Domain(), option, maxDepth-1)
+			unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1)
+			if err != nil {
+				return nil, err
+			}
 			if unwrapped != nil {
-				return unwrapped
+				return unwrapped, nil
 			}
 		}
-		return addrs
+		return addrs, nil
 	default: // IP record found, return a non-nil IP array
-		return filterIP(addrs, option)
+		return filterIP(addrs, option), nil
 	}
 }
 
 // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
-func (h *StaticHosts) Lookup(domain string, option dns.IPOption) []net.Address {
+func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
 	return h.lookup(domain, option, 5)
 }

+ 17 - 5
app/dns/hosts_test.go

@@ -12,6 +12,11 @@ import (
 
 func TestStaticHosts(t *testing.T) {
 	pb := []*Config_HostMapping{
+		{
+			Type:          DomainMatchingType_Subdomain,
+			Domain:        "lan",
+			ProxiedDomain: "#3",
+		},
 		{
 			Type:   DomainMatchingType_Full,
 			Domain: "example.com",
@@ -54,7 +59,14 @@ func TestStaticHosts(t *testing.T) {
 	common.Must(err)
 
 	{
-		ips := hosts.Lookup("example.com", dns.IPOption{
+		_, err := hosts.Lookup("example.com.lan", dns.IPOption{})
+		if dns.RCodeFromError(err) != 3 {
+			t.Error(err)
+		}
+	}
+
+	{
+		ips, _ := hosts.Lookup("example.com", dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: true,
 		})
@@ -67,7 +79,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		domain := hosts.Lookup("proxy.xray.com", dns.IPOption{
+		domain, _ := hosts.Lookup("proxy.xray.com", dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: false,
 		})
@@ -80,7 +92,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		domain := hosts.Lookup("proxy2.xray.com", dns.IPOption{
+		domain, _ := hosts.Lookup("proxy2.xray.com", dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: false,
 		})
@@ -93,7 +105,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		ips := hosts.Lookup("www.example.cn", dns.IPOption{
+		ips, _ := hosts.Lookup("www.example.cn", dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: true,
 		})
@@ -106,7 +118,7 @@ func TestStaticHosts(t *testing.T) {
 	}
 
 	{
-		ips := hosts.Lookup("baidu.com", dns.IPOption{
+		ips, _ := hosts.Lookup("baidu.com", dns.IPOption{
 			IPv4Enable: false,
 			IPv6Enable: true,
 		})

+ 18 - 0
features/dns/client.go

@@ -42,6 +42,24 @@ func (e RCodeError) Error() string {
 	return serial.Concat("rcode: ", uint16(e))
 }
 
+func (RCodeError) IP() net.IP {
+	panic("Calling IP() on a RCodeError.")
+}
+
+func (RCodeError) Domain() string {
+	panic("Calling Domain() on a RCodeError.")
+}
+
+func (RCodeError) Family() net.AddressFamily {
+	panic("Calling Family() on a RCodeError.")
+}
+
+func (e RCodeError) String() string {
+	return e.Error()
+}
+
+var _ net.Address = (*RCodeError)(nil)
+
 func RCodeFromError(err error) uint16 {
 	if err == nil {
 		return 0