hosts.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package dns
  2. import (
  3. "context"
  4. "runtime"
  5. "strconv"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/common/net"
  8. "github.com/xtls/xray-core/common/strmatcher"
  9. "github.com/xtls/xray-core/features/dns"
  10. )
  11. // StaticHosts represents static domain-ip mapping in DNS server.
  12. type StaticHosts struct {
  13. ips [][]net.Address
  14. matchers strmatcher.IndexMatcher
  15. }
  16. // NewStaticHosts creates a new StaticHosts instance.
  17. func NewStaticHosts(hosts []*Config_HostMapping) (*StaticHosts, error) {
  18. g := new(strmatcher.MatcherGroup)
  19. sh := &StaticHosts{
  20. ips: make([][]net.Address, len(hosts)+16),
  21. matchers: g,
  22. }
  23. defer runtime.GC()
  24. for i, mapping := range hosts {
  25. hosts[i] = nil
  26. matcher, err := toStrMatcher(mapping.Type, mapping.Domain)
  27. if err != nil {
  28. errors.LogErrorInner(context.Background(), err, "failed to create domain matcher, ignore domain rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")
  29. continue
  30. }
  31. id := g.Add(matcher)
  32. ips := make([]net.Address, 0, len(mapping.Ip)+1)
  33. switch {
  34. case len(mapping.ProxiedDomain) > 0:
  35. if mapping.ProxiedDomain[0] == '#' {
  36. rcode, err := strconv.Atoi(mapping.ProxiedDomain[1:])
  37. if err != nil {
  38. return nil, err
  39. }
  40. ips = append(ips, dns.RCodeError(rcode))
  41. } else {
  42. ips = append(ips, net.DomainAddress(mapping.ProxiedDomain))
  43. }
  44. case len(mapping.Ip) > 0:
  45. for _, ip := range mapping.Ip {
  46. addr := net.IPAddress(ip)
  47. if addr == nil {
  48. errors.LogError(context.Background(), "invalid IP address in static hosts: ", ip, ", ignore this ip for rule [type: ", mapping.Type, ", domain: ", mapping.Domain, "]")
  49. continue
  50. }
  51. ips = append(ips, addr)
  52. }
  53. if len(ips) == 0 {
  54. continue
  55. }
  56. }
  57. sh.ips[id] = ips
  58. }
  59. return sh, nil
  60. }
  61. func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
  62. filtered := make([]net.Address, 0, len(ips))
  63. for _, ip := range ips {
  64. if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
  65. filtered = append(filtered, ip)
  66. }
  67. }
  68. return filtered
  69. }
  70. func (h *StaticHosts) lookupInternal(domain string) ([]net.Address, error) {
  71. ips := make([]net.Address, 0)
  72. found := false
  73. for _, id := range h.matchers.Match(domain) {
  74. for _, v := range h.ips[id] {
  75. if err, ok := v.(dns.RCodeError); ok {
  76. if uint16(err) == 0 {
  77. return nil, dns.ErrEmptyResponse
  78. }
  79. return nil, err
  80. }
  81. }
  82. ips = append(ips, h.ips[id]...)
  83. found = true
  84. }
  85. if !found {
  86. return nil, nil
  87. }
  88. return ips, nil
  89. }
  90. func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ([]net.Address, error) {
  91. switch addrs, err := h.lookupInternal(domain); {
  92. case err != nil:
  93. return nil, err
  94. case len(addrs) == 0: // Not recorded in static hosts, return nil
  95. return addrs, nil
  96. case len(addrs) == 1 && addrs[0].Family().IsDomain(): // Try to unwrap domain
  97. errors.LogDebug(context.Background(), "found replaced domain: ", domain, " -> ", addrs[0].Domain(), ". Try to unwrap it")
  98. if maxDepth > 0 {
  99. unwrapped, err := h.lookup(addrs[0].Domain(), option, maxDepth-1)
  100. if err != nil {
  101. return nil, err
  102. }
  103. if unwrapped != nil {
  104. return unwrapped, nil
  105. }
  106. }
  107. return addrs, nil
  108. default: // IP record found, return a non-nil IP array
  109. return filterIP(addrs, option), nil
  110. }
  111. }
  112. // Lookup returns IP addresses or proxied domain for the given domain, if exists in this StaticHosts.
  113. func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
  114. return h.lookup(domain, option, 5)
  115. }
  116. func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) {
  117. sh := &StaticHosts{
  118. ips: make([][]net.Address, matcher.Size()+1),
  119. matchers: matcher,
  120. }
  121. order := hostIPs["_ORDER"]
  122. var offset uint32
  123. img, ok := matcher.(*strmatcher.IndexMatcherGroup)
  124. if !ok {
  125. // Single matcher (e.g. only manual or only one geosite)
  126. if len(order) > 0 {
  127. pattern := order[0]
  128. ips := parseIPs(hostIPs[pattern])
  129. for i := uint32(1); i <= matcher.Size(); i++ {
  130. sh.ips[i] = ips
  131. }
  132. }
  133. return sh, nil
  134. }
  135. for i, m := range img.Matchers {
  136. if i < len(order) {
  137. pattern := order[i]
  138. ips := parseIPs(hostIPs[pattern])
  139. for j := uint32(1); j <= m.Size(); j++ {
  140. sh.ips[offset+j] = ips
  141. }
  142. offset += m.Size()
  143. }
  144. }
  145. return sh, nil
  146. }
  147. func parseIPs(raw []string) []net.Address {
  148. addrs := make([]net.Address, 0, len(raw))
  149. for _, s := range raw {
  150. if len(s) > 1 && s[0] == '#' {
  151. rcode, _ := strconv.Atoi(s[1:])
  152. addrs = append(addrs, dns.RCodeError(rcode))
  153. } else {
  154. addrs = append(addrs, net.ParseAddress(s))
  155. }
  156. }
  157. return addrs
  158. }