resolver_test.go 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817
  1. /*
  2. * Copyright (c) 2022, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package resolver
  20. import (
  21. "context"
  22. "fmt"
  23. "net"
  24. "reflect"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
  33. "github.com/miekg/dns"
  34. )
  35. func TestMakeResolveParameters(t *testing.T) {
  36. err := runTestMakeResolveParameters()
  37. if err != nil {
  38. t.Fatalf(errors.Trace(err).Error())
  39. }
  40. }
  41. func TestResolver(t *testing.T) {
  42. err := runTestResolver()
  43. if err != nil {
  44. t.Fatalf(errors.Trace(err).Error())
  45. }
  46. }
  47. func TestPublicDNSServers(t *testing.T) {
  48. IPs, metrics, err := runTestPublicDNSServers()
  49. if err != nil {
  50. t.Fatalf(errors.Trace(err).Error())
  51. }
  52. t.Logf("IPs: %v", IPs)
  53. t.Logf("Metrics: %v", metrics)
  54. }
  55. func runTestMakeResolveParameters() error {
  56. frontingProviderID := "frontingProvider"
  57. alternateDNSServer := "172.16.0.1"
  58. alternateDNSServerWithPort := net.JoinHostPort(alternateDNSServer, resolverDNSPort)
  59. preferredAlternateDNSServer := "172.16.0.2"
  60. preferredAlternateDNSServerWithPort := net.JoinHostPort(preferredAlternateDNSServer, resolverDNSPort)
  61. transformName := "exampleTransform"
  62. paramValues := map[string]interface{}{
  63. "DNSResolverAttemptsPerServer": 2,
  64. "DNSResolverAttemptsPerPreferredServer": 1,
  65. "DNSResolverPreresolvedIPAddressProbability": 1.0,
  66. "DNSResolverPreresolvedIPAddressCIDRs": parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}},
  67. "DNSResolverAlternateServers": []string{alternateDNSServer},
  68. "DNSResolverPreferredAlternateServers": []string{preferredAlternateDNSServer},
  69. "DNSResolverPreferAlternateServerProbability": 1.0,
  70. "DNSResolverProtocolTransformProbability": 1.0,
  71. "DNSResolverProtocolTransformSpecs": transforms.Specs{transformName: exampleTransform},
  72. "DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}},
  73. "DNSResolverIncludeEDNS0Probability": 1.0,
  74. }
  75. params, err := parameters.NewParameters(nil)
  76. if err != nil {
  77. return errors.Trace(err)
  78. }
  79. _, err = params.Set("", false, paramValues)
  80. if err != nil {
  81. return errors.Trace(err)
  82. }
  83. resolver := NewResolver(&NetworkConfig{}, "")
  84. defer resolver.Stop()
  85. resolverParams, err := resolver.MakeResolveParameters(
  86. params.Get(), frontingProviderID)
  87. if err != nil {
  88. return errors.Trace(err)
  89. }
  90. // Test: PreresolvedIPAddress
  91. CIDRContainsIP := func(CIDR, IP string) bool {
  92. _, IPNet, _ := net.ParseCIDR(CIDR)
  93. return IPNet.Contains(net.ParseIP(IP))
  94. }
  95. if resolverParams.AttemptsPerServer != 2 ||
  96. resolverParams.AttemptsPerPreferredServer != 1 ||
  97. resolverParams.RequestTimeout != 5*time.Second ||
  98. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  99. !CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) ||
  100. resolverParams.AlternateDNSServer != "" ||
  101. resolverParams.PreferAlternateDNSServer != false ||
  102. resolverParams.ProtocolTransformName != "" ||
  103. resolverParams.ProtocolTransformSpec != nil ||
  104. resolverParams.IncludeEDNS0 != false {
  105. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  106. }
  107. // Test: additional generateIPAddressFromCIDR cases
  108. for i := 0; i < 10000; i++ {
  109. for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} {
  110. IP, err := generateIPAddressFromCIDR(CIDR)
  111. if err != nil {
  112. return errors.Trace(err)
  113. }
  114. if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) {
  115. return errors.Tracef(
  116. "invalid generated IP address %v for CIDR %v", IP, CIDR)
  117. }
  118. }
  119. }
  120. // Test: Preferred/Transform/EDNS(0)
  121. paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
  122. _, err = params.Set("", false, paramValues)
  123. if err != nil {
  124. return errors.Trace(err)
  125. }
  126. resolverParams, err = resolver.MakeResolveParameters(
  127. params.Get(), frontingProviderID)
  128. if err != nil {
  129. return errors.Trace(err)
  130. }
  131. if resolverParams.AttemptsPerServer != 2 ||
  132. resolverParams.AttemptsPerPreferredServer != 1 ||
  133. resolverParams.RequestTimeout != 5*time.Second ||
  134. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  135. resolverParams.PreresolvedIPAddress != "" ||
  136. resolverParams.AlternateDNSServer != preferredAlternateDNSServerWithPort ||
  137. resolverParams.PreferAlternateDNSServer != true ||
  138. resolverParams.ProtocolTransformName != transformName ||
  139. resolverParams.ProtocolTransformSpec == nil ||
  140. resolverParams.IncludeEDNS0 != true {
  141. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  142. }
  143. // Test: No Preferred/Transform/EDNS(0)
  144. paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
  145. paramValues["DNSResolverProtocolTransformProbability"] = 0.0
  146. paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
  147. _, err = params.Set("", false, paramValues)
  148. if err != nil {
  149. return errors.Trace(err)
  150. }
  151. resolverParams, err = resolver.MakeResolveParameters(
  152. params.Get(), frontingProviderID)
  153. if err != nil {
  154. return errors.Trace(err)
  155. }
  156. if resolverParams.AttemptsPerServer != 2 ||
  157. resolverParams.AttemptsPerPreferredServer != 1 ||
  158. resolverParams.RequestTimeout != 5*time.Second ||
  159. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  160. resolverParams.PreresolvedIPAddress != "" ||
  161. resolverParams.AlternateDNSServer != alternateDNSServerWithPort ||
  162. resolverParams.PreferAlternateDNSServer != false ||
  163. resolverParams.ProtocolTransformName != "" ||
  164. resolverParams.ProtocolTransformSpec != nil ||
  165. resolverParams.IncludeEDNS0 != false {
  166. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  167. }
  168. return nil
  169. }
  170. func runTestResolver() error {
  171. // noResponseServer will not respond to requests
  172. noResponseServer, err := newTestDNSServer(false, false, false)
  173. if err != nil {
  174. return errors.Trace(err)
  175. }
  176. defer noResponseServer.stop()
  177. // invalidIPServer will respond with an invalid IP
  178. invalidIPServer, err := newTestDNSServer(true, false, false)
  179. if err != nil {
  180. return errors.Trace(err)
  181. }
  182. defer invalidIPServer.stop()
  183. // okServer will respond to correct requests (expected domain) with the
  184. // correct response (expected IPv4 or IPv6 address)
  185. okServer, err := newTestDNSServer(true, true, false)
  186. if err != nil {
  187. return errors.Trace(err)
  188. }
  189. defer okServer.stop()
  190. // alternateOkServer behaves like okServer; getRequestCount is used to
  191. // confirm that the alternate server was indeed used
  192. alternateOkServer, err := newTestDNSServer(true, true, false)
  193. if err != nil {
  194. return errors.Trace(err)
  195. }
  196. defer alternateOkServer.stop()
  197. // transformOkServer behaves like okServer but only responds if the
  198. // transform was applied; other servers do not respond if the transform
  199. // is applied
  200. transformOkServer, err := newTestDNSServer(true, true, true)
  201. if err != nil {
  202. return errors.Trace(err)
  203. }
  204. defer transformOkServer.stop()
  205. servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
  206. networkConfig := &NetworkConfig{
  207. GetDNSServers: func() []string { return servers },
  208. LogWarning: func(err error) { fmt.Printf("LogWarning: %v\n", err) },
  209. }
  210. networkID := "networkID-1"
  211. resolver := NewResolver(networkConfig, networkID)
  212. defer resolver.Stop()
  213. params := &ResolveParameters{
  214. AttemptsPerServer: 1,
  215. AttemptsPerPreferredServer: 1,
  216. RequestTimeout: 250 * time.Millisecond,
  217. AwaitTimeout: 250 * time.Millisecond,
  218. IncludeEDNS0: true,
  219. }
  220. checkResult := func(IPs []net.IP) error {
  221. var IPv4, IPv6 net.IP
  222. for _, IP := range IPs {
  223. if IP.To4() != nil {
  224. IPv4 = IP
  225. } else {
  226. IPv6 = IP
  227. }
  228. }
  229. if IPv4 == nil {
  230. return errors.TraceNew("missing IPv4 response")
  231. }
  232. if IPv4.String() != exampleIPv4 {
  233. return errors.TraceNew("unexpected IPv4 response")
  234. }
  235. if resolver.hasIPv6Route {
  236. if IPv6 == nil {
  237. return errors.TraceNew("missing IPv6 response")
  238. }
  239. if IPv6.String() != exampleIPv6 {
  240. return errors.TraceNew("unexpected IPv6 response")
  241. }
  242. }
  243. return nil
  244. }
  245. ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
  246. defer cancelFunc()
  247. // Test: should retry until okServer responds
  248. IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  249. if err != nil {
  250. return errors.Trace(err)
  251. }
  252. err = checkResult(IPs)
  253. if err != nil {
  254. return errors.Trace(err)
  255. }
  256. if resolver.metrics.resolves != 1 ||
  257. resolver.metrics.cacheHits != 0 ||
  258. resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 ||
  259. (resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) {
  260. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  261. }
  262. // Test: cached response
  263. beforeMetrics := resolver.metrics
  264. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  265. if err != nil {
  266. return errors.Trace(err)
  267. }
  268. err = checkResult(IPs)
  269. if err != nil {
  270. return errors.Trace(err)
  271. }
  272. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  273. resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
  274. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  275. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  276. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  277. }
  278. // Test: PreresolvedIPAddress
  279. beforeMetrics = resolver.metrics
  280. params.PreresolvedIPAddress = exampleIPv4
  281. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  282. if err != nil {
  283. return errors.Trace(err)
  284. }
  285. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  286. return errors.TraceNew("unexpected preresolved response")
  287. }
  288. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  289. resolver.metrics.cacheHits != beforeMetrics.cacheHits ||
  290. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  291. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  292. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  293. }
  294. params.PreresolvedIPAddress = ""
  295. // Test: change network ID, which must clear cache
  296. beforeMetrics = resolver.metrics
  297. networkID = "networkID-2"
  298. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  299. if err != nil {
  300. return errors.Trace(err)
  301. }
  302. err = checkResult(IPs)
  303. if err != nil {
  304. return errors.Trace(err)
  305. }
  306. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  307. resolver.metrics.cacheHits != beforeMetrics.cacheHits {
  308. return errors.Tracef("unexpected metrics: %+v (%+v)", resolver.metrics, beforeMetrics)
  309. }
  310. // Test: PreferAlternateDNSServer
  311. if alternateOkServer.getRequestCount() != 0 {
  312. return errors.TraceNew("unexpected alternate server request count")
  313. }
  314. resolver.cache.Flush()
  315. params.AlternateDNSServer = alternateOkServer.getAddr()
  316. params.PreferAlternateDNSServer = true
  317. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  318. if err != nil {
  319. return errors.Trace(err)
  320. }
  321. err = checkResult(IPs)
  322. if err != nil {
  323. return errors.Trace(err)
  324. }
  325. if alternateOkServer.getRequestCount() < 1 {
  326. return errors.TraceNew("unexpected alternate server request count")
  327. }
  328. params.AlternateDNSServer = ""
  329. params.PreferAlternateDNSServer = false
  330. // Test: PreferAlternateDNSServer with failed attempt (exercise maxAttempts prefer case)
  331. resolver.cache.Flush()
  332. params.AlternateDNSServer = invalidIPServer.getAddr()
  333. params.PreferAlternateDNSServer = true
  334. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  335. if err != nil {
  336. return errors.Trace(err)
  337. }
  338. err = checkResult(IPs)
  339. if err != nil {
  340. return errors.Trace(err)
  341. }
  342. params.AlternateDNSServer = ""
  343. params.PreferAlternateDNSServer = false
  344. // Test: fall over to AlternateDNSServer when no system servers
  345. beforeCount := alternateOkServer.getRequestCount()
  346. previousGetDNSServers := networkConfig.GetDNSServers
  347. networkConfig.GetDNSServers = func() []string { return nil }
  348. // Force system servers update
  349. networkID = "networkID-3"
  350. resolver.cache.Flush()
  351. params.AlternateDNSServer = alternateOkServer.getAddr()
  352. params.PreferAlternateDNSServer = false
  353. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  354. if err != nil {
  355. return errors.Trace(err)
  356. }
  357. err = checkResult(IPs)
  358. if err != nil {
  359. return errors.Trace(err)
  360. }
  361. if alternateOkServer.getRequestCount() <= beforeCount {
  362. return errors.TraceNew("unexpected alterate server request count")
  363. }
  364. // Test: use default, standard resolver when no servers
  365. resolver.cache.Flush()
  366. params.AlternateDNSServer = ""
  367. params.PreferAlternateDNSServer = false
  368. if len(resolver.systemServers) != 0 {
  369. return errors.TraceNew("unexpected server count")
  370. }
  371. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  372. if err != nil {
  373. return errors.Trace(err)
  374. }
  375. if len(IPs) == 0 {
  376. return errors.TraceNew("unexpected response")
  377. }
  378. // Test: ResolveAddress
  379. networkConfig.GetDNSServers = previousGetDNSServers
  380. // Force system servers update
  381. networkID = "networkID-4"
  382. domainAddress := net.JoinHostPort(exampleDomain, "443")
  383. address, err := resolver.ResolveAddress(ctx, networkID, params, domainAddress)
  384. if err != nil {
  385. return errors.Trace(err)
  386. }
  387. host, port, err := net.SplitHostPort(address)
  388. if err != nil {
  389. return errors.Trace(err)
  390. }
  391. IP := net.ParseIP(host)
  392. if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" {
  393. return errors.TraceNew("unexpected response")
  394. }
  395. // Test: protocol transform
  396. if transformOkServer.getRequestCount() != 0 {
  397. return errors.TraceNew("unexpected transform server request count")
  398. }
  399. resolver.cache.Flush()
  400. params.AlternateDNSServer = transformOkServer.getAddr()
  401. params.PreferAlternateDNSServer = true
  402. seed, err := prng.NewSeed()
  403. if err != nil {
  404. return errors.Trace(err)
  405. }
  406. params.ProtocolTransformName = "exampleTransform"
  407. params.ProtocolTransformSpec = exampleTransform
  408. params.ProtocolTransformSeed = seed
  409. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  410. if err != nil {
  411. return errors.Trace(err)
  412. }
  413. err = checkResult(IPs)
  414. if err != nil {
  415. return errors.Trace(err)
  416. }
  417. if transformOkServer.getRequestCount() < 1 {
  418. return errors.TraceNew("unexpected transform server request count")
  419. }
  420. params.AlternateDNSServer = ""
  421. params.PreferAlternateDNSServer = false
  422. params.ProtocolTransformName = ""
  423. params.ProtocolTransformSpec = nil
  424. params.ProtocolTransformSeed = nil
  425. // Test: EDNS(0)
  426. resolver.cache.Flush()
  427. params.IncludeEDNS0 = true
  428. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  429. if err != nil {
  430. return errors.Trace(err)
  431. }
  432. err = checkResult(IPs)
  433. if err != nil {
  434. return errors.Trace(err)
  435. }
  436. params.IncludeEDNS0 = false
  437. // Test: input IP address
  438. beforeMetrics = resolver.metrics
  439. resolver.cache.Flush()
  440. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4)
  441. if err != nil {
  442. return errors.Trace(err)
  443. }
  444. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  445. return errors.TraceNew("unexpected IPv4 response")
  446. }
  447. if resolver.metrics.resolves != beforeMetrics.resolves {
  448. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  449. }
  450. // Test: DNS cache extension
  451. resolver.cache.Flush()
  452. networkConfig.CacheExtensionInitialTTL = (exampleTTLSeconds * 2) * time.Second
  453. networkConfig.CacheExtensionVerifiedTTL = 2 * time.Hour
  454. now := time.Now()
  455. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  456. if err != nil {
  457. return errors.Trace(err)
  458. }
  459. entry, expiry, ok := resolver.cache.GetWithExpiration(exampleDomain)
  460. if !ok ||
  461. !reflect.DeepEqual(entry, IPs) ||
  462. expiry.Before(now.Add(networkConfig.CacheExtensionInitialTTL)) ||
  463. expiry.After(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  464. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  465. }
  466. resolver.VerifyCacheExtension(exampleDomain)
  467. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  468. if !ok ||
  469. !reflect.DeepEqual(entry, IPs) ||
  470. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  471. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  472. }
  473. // Set cache flush condition, which should be ignored
  474. networkID = "networkID-5"
  475. resolver.updateNetworkState(networkID)
  476. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  477. if !ok ||
  478. !reflect.DeepEqual(entry, IPs) ||
  479. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  480. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  481. }
  482. // Test: cancel context
  483. resolver.cache.Flush()
  484. cancelFunc()
  485. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  486. if err == nil {
  487. return errors.TraceNew("unexpected success")
  488. }
  489. // Test: cancel context while resolving
  490. // This test exercises the additional answers and await cases in
  491. // ResolveIP. The test is timing dependent, and so imperfect, but this
  492. // configuration can reproduce panics in those cases before bugs were
  493. // fixed, where DNS responses need to be received just as the context is
  494. // cancelled.
  495. networkConfig.GetDNSServers = func() []string { return []string{okServer.getAddr()} }
  496. networkID = "networkID-6"
  497. for i := 0; i < 500; i++ {
  498. resolver.cache.Flush()
  499. ctx, cancelFunc := context.WithTimeout(
  500. context.Background(), time.Duration((i%10+1)*20)*time.Microsecond)
  501. defer cancelFunc()
  502. _, _ = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  503. }
  504. return nil
  505. }
  506. func runTestPublicDNSServers() ([]net.IP, string, error) {
  507. networkConfig := &NetworkConfig{
  508. GetDNSServers: getPublicDNSServers,
  509. }
  510. networkID := "networkID-1"
  511. resolver := NewResolver(networkConfig, networkID)
  512. defer resolver.Stop()
  513. params := &ResolveParameters{
  514. AttemptsPerServer: 1,
  515. RequestTimeout: 5 * time.Second,
  516. AwaitTimeout: 1 * time.Second,
  517. IncludeEDNS0: true,
  518. }
  519. IPs, err := resolver.ResolveIP(
  520. context.Background(), networkID, params, exampleDomain)
  521. if err != nil {
  522. return nil, "", errors.Trace(err)
  523. }
  524. gotIPv4 := false
  525. gotIPv6 := false
  526. for _, IP := range IPs {
  527. if IP.To4() != nil {
  528. gotIPv4 = true
  529. } else {
  530. gotIPv6 = true
  531. }
  532. }
  533. if !gotIPv4 {
  534. return nil, "", errors.TraceNew("missing IPv4 response")
  535. }
  536. if !gotIPv6 && resolver.hasIPv6Route {
  537. return nil, "", errors.TraceNew("missing IPv6 response")
  538. }
  539. return IPs, resolver.GetMetrics(), nil
  540. }
  541. func getPublicDNSServers() []string {
  542. servers := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
  543. shuffledServers := make([]string, len(servers))
  544. for i, j := range prng.Perm(len(servers)) {
  545. shuffledServers[i] = servers[j]
  546. }
  547. return shuffledServers
  548. }
  // Fixture values shared by the tests in this file. The local test DNS
  // servers answer queries for exampleDomain with exampleIPv4/exampleIPv6,
  // using a TTL of exampleTTLSeconds.
  const (
  	exampleDomain     = "example.com"
  	exampleTTLSeconds = 60

  	exampleIPv4     = "93.184.216.34"
  	exampleIPv4CIDR = "93.184.216.0/24"

  	exampleIPv6     = "2606:2800:220:1:248:1893:25c8:1946"
  	exampleIPv6CIDR = "2606:2800:220::/48"
  )
  557. // Set the reserved Z flag
  558. var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
  559. type testDNSServer struct {
  560. respond bool
  561. validResponse bool
  562. expectTransform bool
  563. addr string
  564. requestCount int32
  565. server *dns.Server
  566. }
  567. func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) {
  568. udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
  569. if err != nil {
  570. return nil, errors.Trace(err)
  571. }
  572. udpConn, err := net.ListenUDP("udp", udpAddr)
  573. if err != nil {
  574. return nil, errors.Trace(err)
  575. }
  576. s := &testDNSServer{
  577. respond: respond,
  578. validResponse: validResponse,
  579. expectTransform: expectTransform,
  580. addr: udpConn.LocalAddr().String(),
  581. }
  582. server := &dns.Server{
  583. PacketConn: udpConn,
  584. Handler: s,
  585. }
  586. s.server = server
  587. go server.ActivateAndServe()
  588. return s, nil
  589. }
  590. func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
  591. atomic.AddInt32(&s.requestCount, 1)
  592. if !s.respond {
  593. return
  594. }
  595. // Check the reserved Z flag
  596. if s.expectTransform != r.MsgHdr.Zero {
  597. return
  598. }
  599. if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) {
  600. return
  601. }
  602. m := new(dns.Msg)
  603. m.SetReply(r)
  604. m.Answer = make([]dns.RR, 1)
  605. if r.Question[0].Qtype == dns.TypeA {
  606. IP := net.ParseIP(exampleIPv4)
  607. if !s.validResponse {
  608. IP = net.ParseIP("127.0.0.1")
  609. }
  610. m.Answer[0] = &dns.A{
  611. Hdr: dns.RR_Header{
  612. Name: r.Question[0].Name,
  613. Rrtype: dns.TypeA,
  614. Class: dns.ClassINET,
  615. Ttl: exampleTTLSeconds},
  616. A: IP,
  617. }
  618. } else {
  619. IP := net.ParseIP(exampleIPv6)
  620. if !s.validResponse {
  621. IP = net.ParseIP("::1")
  622. }
  623. m.Answer[0] = &dns.AAAA{
  624. Hdr: dns.RR_Header{
  625. Name: r.Question[0].Name,
  626. Rrtype: dns.TypeAAAA,
  627. Class: dns.ClassINET,
  628. Ttl: exampleTTLSeconds},
  629. AAAA: IP,
  630. }
  631. }
  632. w.WriteMsg(m)
  633. }
  634. func (s *testDNSServer) getAddr() string {
  635. return s.addr
  636. }
  637. func (s *testDNSServer) getRequestCount() int {
  638. return int(atomic.LoadInt32(&s.requestCount))
  639. }
  640. func (s *testDNSServer) stop() {
  641. s.server.PacketConn.Close()
  642. s.server.Shutdown()
  643. }