resolver_test.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844
  1. /*
  2. * Copyright (c) 2022, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package resolver
  20. import (
  21. "context"
  22. "fmt"
  23. "net"
  24. "reflect"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
  33. "github.com/miekg/dns"
  34. )
  35. func TestMakeResolveParameters(t *testing.T) {
  36. err := runTestMakeResolveParameters()
  37. if err != nil {
  38. t.Fatalf(errors.Trace(err).Error())
  39. }
  40. }
  41. func TestResolver(t *testing.T) {
  42. err := runTestResolver()
  43. if err != nil {
  44. t.Fatalf(errors.Trace(err).Error())
  45. }
  46. }
  47. func TestPublicDNSServers(t *testing.T) {
  48. IPs, metrics, err := runTestPublicDNSServers()
  49. if err != nil {
  50. t.Fatalf(errors.Trace(err).Error())
  51. }
  52. t.Logf("IPs: %v", IPs)
  53. t.Logf("Metrics: %v", metrics)
  54. }
  55. func runTestMakeResolveParameters() error {
  56. frontingProviderID := "frontingProvider"
  57. frontingDialDomain := exampleDomain
  58. alternateDNSServer := "172.16.0.1"
  59. alternateDNSServerWithPort := net.JoinHostPort(alternateDNSServer, resolverDNSPort)
  60. preferredAlternateDNSServer := "172.16.0.2"
  61. preferredAlternateDNSServerWithPort := net.JoinHostPort(preferredAlternateDNSServer, resolverDNSPort)
  62. transformName := "exampleTransform"
  63. paramValues := map[string]interface{}{
  64. "DNSResolverAttemptsPerServer": 2,
  65. "DNSResolverAttemptsPerPreferredServer": 1,
  66. "DNSResolverPreresolvedIPAddressProbability": 1.0,
  67. "DNSResolverPreresolvedIPAddressCIDRs": parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}},
  68. "DNSResolverAlternateServers": []string{alternateDNSServer},
  69. "DNSResolverPreferredAlternateServers": []string{preferredAlternateDNSServer},
  70. "DNSResolverPreferAlternateServerProbability": 1.0,
  71. "DNSResolverProtocolTransformProbability": 1.0,
  72. "DNSResolverProtocolTransformSpecs": transforms.Specs{transformName: exampleTransform},
  73. "DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}},
  74. "DNSResolverIncludeEDNS0Probability": 1.0,
  75. }
  76. params, err := parameters.NewParameters(nil)
  77. if err != nil {
  78. return errors.Trace(err)
  79. }
  80. _, err = params.Set("", false, paramValues)
  81. if err != nil {
  82. return errors.Trace(err)
  83. }
  84. resolver := NewResolver(&NetworkConfig{}, "")
  85. defer resolver.Stop()
  86. resolverParams, err := resolver.MakeResolveParameters(
  87. params.Get(), frontingProviderID, frontingDialDomain)
  88. if err != nil {
  89. return errors.Trace(err)
  90. }
  91. // Test: PreresolvedIPAddress
  92. CIDRContainsIP := func(CIDR, IP string) bool {
  93. _, IPNet, _ := net.ParseCIDR(CIDR)
  94. return IPNet.Contains(net.ParseIP(IP))
  95. }
  96. if resolverParams.AttemptsPerServer != 2 ||
  97. resolverParams.AttemptsPerPreferredServer != 1 ||
  98. resolverParams.RequestTimeout != 5*time.Second ||
  99. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  100. !CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) ||
  101. resolverParams.PreresolvedDomain != frontingDialDomain {
  102. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  103. }
  104. // Test: additional generateIPAddressFromCIDR cases
  105. for i := 0; i < 10000; i++ {
  106. for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} {
  107. IP, err := generateIPAddressFromCIDR(CIDR)
  108. if err != nil {
  109. return errors.Trace(err)
  110. }
  111. if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) {
  112. return errors.Tracef(
  113. "invalid generated IP address %v for CIDR %v", IP, CIDR)
  114. }
  115. }
  116. }
  117. // Test: Preferred/Transform/EDNS(0)
  118. paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
  119. _, err = params.Set("", false, paramValues)
  120. if err != nil {
  121. return errors.Trace(err)
  122. }
  123. resolverParams, err = resolver.MakeResolveParameters(
  124. params.Get(), frontingProviderID, frontingDialDomain)
  125. if err != nil {
  126. return errors.Trace(err)
  127. }
  128. if resolverParams.AttemptsPerServer != 2 ||
  129. resolverParams.AttemptsPerPreferredServer != 1 ||
  130. resolverParams.RequestTimeout != 5*time.Second ||
  131. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  132. resolverParams.PreresolvedIPAddress != "" ||
  133. resolverParams.PreresolvedDomain != "" ||
  134. resolverParams.AlternateDNSServer != preferredAlternateDNSServerWithPort ||
  135. resolverParams.PreferAlternateDNSServer != true ||
  136. resolverParams.ProtocolTransformName != transformName ||
  137. resolverParams.ProtocolTransformSpec == nil ||
  138. resolverParams.IncludeEDNS0 != true {
  139. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  140. }
  141. // Test: No Preferred/Transform/EDNS(0)
  142. paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
  143. paramValues["DNSResolverProtocolTransformProbability"] = 0.0
  144. paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
  145. _, err = params.Set("", false, paramValues)
  146. if err != nil {
  147. return errors.Trace(err)
  148. }
  149. resolverParams, err = resolver.MakeResolveParameters(
  150. params.Get(), frontingProviderID, frontingDialDomain)
  151. if err != nil {
  152. return errors.Trace(err)
  153. }
  154. if resolverParams.AttemptsPerServer != 2 ||
  155. resolverParams.AttemptsPerPreferredServer != 1 ||
  156. resolverParams.RequestTimeout != 5*time.Second ||
  157. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  158. resolverParams.PreresolvedIPAddress != "" ||
  159. resolverParams.PreresolvedDomain != "" ||
  160. resolverParams.AlternateDNSServer != alternateDNSServerWithPort ||
  161. resolverParams.PreferAlternateDNSServer != false ||
  162. resolverParams.ProtocolTransformName != "" ||
  163. resolverParams.ProtocolTransformSpec != nil ||
  164. resolverParams.IncludeEDNS0 != false {
  165. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  166. }
  167. return nil
  168. }
  169. func runTestResolver() error {
  170. // noResponseServer will not respond to requests
  171. noResponseServer, err := newTestDNSServer(false, false, false)
  172. if err != nil {
  173. return errors.Trace(err)
  174. }
  175. defer noResponseServer.stop()
  176. // invalidIPServer will respond with an invalid IP
  177. invalidIPServer, err := newTestDNSServer(true, false, false)
  178. if err != nil {
  179. return errors.Trace(err)
  180. }
  181. defer invalidIPServer.stop()
  182. // okServer will respond to correct requests (expected domain) with the
  183. // correct response (expected IPv4 or IPv6 address)
  184. okServer, err := newTestDNSServer(true, true, false)
  185. if err != nil {
  186. return errors.Trace(err)
  187. }
  188. defer okServer.stop()
  189. // alternateOkServer behaves like okServer; getRequestCount is used to
  190. // confirm that the alternate server was indeed used
  191. alternateOkServer, err := newTestDNSServer(true, true, false)
  192. if err != nil {
  193. return errors.Trace(err)
  194. }
  195. defer alternateOkServer.stop()
  196. // transformOkServer behaves like okServer but only responds if the
  197. // transform was applied; other servers do not respond if the transform
  198. // is applied
  199. transformOkServer, err := newTestDNSServer(true, true, true)
  200. if err != nil {
  201. return errors.Trace(err)
  202. }
  203. defer transformOkServer.stop()
  204. servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
  205. networkConfig := &NetworkConfig{
  206. GetDNSServers: func() []string { return servers },
  207. LogWarning: func(err error) { fmt.Printf("LogWarning: %v\n", err) },
  208. }
  209. networkID := "networkID-1"
  210. resolver := NewResolver(networkConfig, networkID)
  211. defer resolver.Stop()
  212. params := &ResolveParameters{
  213. AttemptsPerServer: 1,
  214. AttemptsPerPreferredServer: 1,
  215. RequestTimeout: 250 * time.Millisecond,
  216. AwaitTimeout: 250 * time.Millisecond,
  217. IncludeEDNS0: true,
  218. }
  219. checkResult := func(IPs []net.IP) error {
  220. var IPv4, IPv6 net.IP
  221. for _, IP := range IPs {
  222. if IP.To4() != nil {
  223. IPv4 = IP
  224. } else {
  225. IPv6 = IP
  226. }
  227. }
  228. if IPv4 == nil {
  229. return errors.TraceNew("missing IPv4 response")
  230. }
  231. if IPv4.String() != exampleIPv4 {
  232. return errors.TraceNew("unexpected IPv4 response")
  233. }
  234. if resolver.hasIPv6Route {
  235. if IPv6 == nil {
  236. return errors.TraceNew("missing IPv6 response")
  237. }
  238. if IPv6.String() != exampleIPv6 {
  239. return errors.TraceNew("unexpected IPv6 response")
  240. }
  241. }
  242. return nil
  243. }
  244. ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
  245. defer cancelFunc()
  246. // Test: should retry until okServer responds
  247. IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  248. if err != nil {
  249. return errors.Trace(err)
  250. }
  251. err = checkResult(IPs)
  252. if err != nil {
  253. return errors.Trace(err)
  254. }
  255. if resolver.metrics.resolves != 1 ||
  256. resolver.metrics.cacheHits != 0 ||
  257. resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 ||
  258. (resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) {
  259. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  260. }
  261. // Test: cached response
  262. beforeMetrics := resolver.metrics
  263. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  264. if err != nil {
  265. return errors.Trace(err)
  266. }
  267. err = checkResult(IPs)
  268. if err != nil {
  269. return errors.Trace(err)
  270. }
  271. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  272. resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
  273. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  274. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  275. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  276. }
  277. // Test: PreresolvedIPAddress
  278. beforeMetrics = resolver.metrics
  279. params.PreresolvedIPAddress = exampleIPv4
  280. params.PreresolvedDomain = exampleDomain
  281. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  282. if err != nil {
  283. return errors.Trace(err)
  284. }
  285. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  286. return errors.TraceNew("unexpected preresolved response")
  287. }
  288. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  289. resolver.metrics.cacheHits != beforeMetrics.cacheHits ||
  290. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  291. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  292. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  293. }
  294. params.PreresolvedIPAddress = ""
  295. // Test: PreresolvedIPAddress set for different domain
  296. beforeMetrics = resolver.metrics
  297. params.PreresolvedIPAddress = exampleIPv4
  298. params.PreresolvedDomain = "not.example.com"
  299. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  300. if err != nil {
  301. return errors.Trace(err)
  302. }
  303. err = checkResult(IPs)
  304. if err != nil {
  305. return errors.Trace(err)
  306. }
  307. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  308. resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
  309. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  310. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  311. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  312. }
  313. params.PreresolvedIPAddress = ""
  314. params.PreresolvedDomain = ""
  315. // Test: change network ID, which must clear cache
  316. beforeMetrics = resolver.metrics
  317. networkID = "networkID-2"
  318. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  319. if err != nil {
  320. return errors.Trace(err)
  321. }
  322. err = checkResult(IPs)
  323. if err != nil {
  324. return errors.Trace(err)
  325. }
  326. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  327. resolver.metrics.cacheHits != beforeMetrics.cacheHits {
  328. return errors.Tracef("unexpected metrics: %+v (%+v)", resolver.metrics, beforeMetrics)
  329. }
  330. // Test: PreferAlternateDNSServer
  331. if alternateOkServer.getRequestCount() != 0 {
  332. return errors.TraceNew("unexpected alternate server request count")
  333. }
  334. resolver.cache.Flush()
  335. params.AlternateDNSServer = alternateOkServer.getAddr()
  336. params.PreferAlternateDNSServer = true
  337. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  338. if err != nil {
  339. return errors.Trace(err)
  340. }
  341. err = checkResult(IPs)
  342. if err != nil {
  343. return errors.Trace(err)
  344. }
  345. if alternateOkServer.getRequestCount() < 1 {
  346. return errors.TraceNew("unexpected alternate server request count")
  347. }
  348. params.AlternateDNSServer = ""
  349. params.PreferAlternateDNSServer = false
  350. // Test: PreferAlternateDNSServer with failed attempt (exercise maxAttempts prefer case)
  351. resolver.cache.Flush()
  352. params.AlternateDNSServer = invalidIPServer.getAddr()
  353. params.PreferAlternateDNSServer = true
  354. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  355. if err != nil {
  356. return errors.Trace(err)
  357. }
  358. err = checkResult(IPs)
  359. if err != nil {
  360. return errors.Trace(err)
  361. }
  362. params.AlternateDNSServer = ""
  363. params.PreferAlternateDNSServer = false
  364. // Test: fall over to AlternateDNSServer when no system servers
  365. beforeCount := alternateOkServer.getRequestCount()
  366. previousGetDNSServers := networkConfig.GetDNSServers
  367. networkConfig.GetDNSServers = func() []string { return nil }
  368. // Force system servers update
  369. networkID = "networkID-3"
  370. resolver.cache.Flush()
  371. params.AlternateDNSServer = alternateOkServer.getAddr()
  372. params.PreferAlternateDNSServer = false
  373. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  374. if err != nil {
  375. return errors.Trace(err)
  376. }
  377. err = checkResult(IPs)
  378. if err != nil {
  379. return errors.Trace(err)
  380. }
  381. if alternateOkServer.getRequestCount() <= beforeCount {
  382. return errors.TraceNew("unexpected alterate server request count")
  383. }
  384. // Test: use default, standard resolver when no servers
  385. resolver.cache.Flush()
  386. params.AlternateDNSServer = ""
  387. params.PreferAlternateDNSServer = false
  388. if len(resolver.systemServers) != 0 {
  389. return errors.TraceNew("unexpected server count")
  390. }
  391. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  392. if err != nil {
  393. return errors.Trace(err)
  394. }
  395. if len(IPs) == 0 {
  396. return errors.TraceNew("unexpected response")
  397. }
  398. // Test: ResolveAddress
  399. networkConfig.GetDNSServers = previousGetDNSServers
  400. // Force system servers update
  401. networkID = "networkID-4"
  402. domainAddress := net.JoinHostPort(exampleDomain, "443")
  403. address, err := resolver.ResolveAddress(ctx, networkID, params, domainAddress)
  404. if err != nil {
  405. return errors.Trace(err)
  406. }
  407. host, port, err := net.SplitHostPort(address)
  408. if err != nil {
  409. return errors.Trace(err)
  410. }
  411. IP := net.ParseIP(host)
  412. if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" {
  413. return errors.TraceNew("unexpected response")
  414. }
  415. // Test: protocol transform
  416. if transformOkServer.getRequestCount() != 0 {
  417. return errors.TraceNew("unexpected transform server request count")
  418. }
  419. resolver.cache.Flush()
  420. params.AlternateDNSServer = transformOkServer.getAddr()
  421. params.PreferAlternateDNSServer = true
  422. seed, err := prng.NewSeed()
  423. if err != nil {
  424. return errors.Trace(err)
  425. }
  426. params.ProtocolTransformName = "exampleTransform"
  427. params.ProtocolTransformSpec = exampleTransform
  428. params.ProtocolTransformSeed = seed
  429. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  430. if err != nil {
  431. return errors.Trace(err)
  432. }
  433. err = checkResult(IPs)
  434. if err != nil {
  435. return errors.Trace(err)
  436. }
  437. if transformOkServer.getRequestCount() < 1 {
  438. return errors.TraceNew("unexpected transform server request count")
  439. }
  440. params.AlternateDNSServer = ""
  441. params.PreferAlternateDNSServer = false
  442. params.ProtocolTransformName = ""
  443. params.ProtocolTransformSpec = nil
  444. params.ProtocolTransformSeed = nil
  445. // Test: EDNS(0)
  446. resolver.cache.Flush()
  447. params.IncludeEDNS0 = true
  448. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  449. if err != nil {
  450. return errors.Trace(err)
  451. }
  452. err = checkResult(IPs)
  453. if err != nil {
  454. return errors.Trace(err)
  455. }
  456. params.IncludeEDNS0 = false
  457. // Test: input IP address
  458. beforeMetrics = resolver.metrics
  459. resolver.cache.Flush()
  460. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4)
  461. if err != nil {
  462. return errors.Trace(err)
  463. }
  464. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  465. return errors.TraceNew("unexpected IPv4 response")
  466. }
  467. if resolver.metrics.resolves != beforeMetrics.resolves {
  468. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  469. }
  470. // Test: DNS cache extension
  471. resolver.cache.Flush()
  472. networkConfig.CacheExtensionInitialTTL = (exampleTTLSeconds * 2) * time.Second
  473. networkConfig.CacheExtensionVerifiedTTL = 2 * time.Hour
  474. now := time.Now()
  475. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  476. if err != nil {
  477. return errors.Trace(err)
  478. }
  479. entry, expiry, ok := resolver.cache.GetWithExpiration(exampleDomain)
  480. if !ok ||
  481. !reflect.DeepEqual(entry, IPs) ||
  482. expiry.Before(now.Add(networkConfig.CacheExtensionInitialTTL)) ||
  483. expiry.After(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  484. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  485. }
  486. resolver.VerifyCacheExtension(exampleDomain)
  487. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  488. if !ok ||
  489. !reflect.DeepEqual(entry, IPs) ||
  490. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  491. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  492. }
  493. // Set cache flush condition, which should be ignored
  494. networkID = "networkID-5"
  495. resolver.updateNetworkState(networkID)
  496. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  497. if !ok ||
  498. !reflect.DeepEqual(entry, IPs) ||
  499. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  500. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  501. }
  502. // Test: cancel context
  503. resolver.cache.Flush()
  504. cancelFunc()
  505. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  506. if err == nil {
  507. return errors.TraceNew("unexpected success")
  508. }
  509. // Test: cancel context while resolving
  510. // This test exercises the additional answers and await cases in
  511. // ResolveIP. The test is timing dependent, and so imperfect, but this
  512. // configuration can reproduce panics in those cases before bugs were
  513. // fixed, where DNS responses need to be received just as the context is
  514. // cancelled.
  515. networkConfig.GetDNSServers = func() []string { return []string{okServer.getAddr()} }
  516. networkID = "networkID-6"
  517. for i := 0; i < 500; i++ {
  518. resolver.cache.Flush()
  519. ctx, cancelFunc := context.WithTimeout(
  520. context.Background(), time.Duration((i%10+1)*20)*time.Microsecond)
  521. defer cancelFunc()
  522. _, _ = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  523. }
  524. return nil
  525. }
  526. func runTestPublicDNSServers() ([]net.IP, string, error) {
  527. networkConfig := &NetworkConfig{
  528. GetDNSServers: getPublicDNSServers,
  529. }
  530. networkID := "networkID-1"
  531. resolver := NewResolver(networkConfig, networkID)
  532. defer resolver.Stop()
  533. params := &ResolveParameters{
  534. AttemptsPerServer: 1,
  535. RequestTimeout: 5 * time.Second,
  536. AwaitTimeout: 1 * time.Second,
  537. IncludeEDNS0: true,
  538. }
  539. IPs, err := resolver.ResolveIP(
  540. context.Background(), networkID, params, exampleDomain)
  541. if err != nil {
  542. return nil, "", errors.Trace(err)
  543. }
  544. gotIPv4 := false
  545. gotIPv6 := false
  546. for _, IP := range IPs {
  547. if IP.To4() != nil {
  548. gotIPv4 = true
  549. } else {
  550. gotIPv6 = true
  551. }
  552. }
  553. if !gotIPv4 {
  554. return nil, "", errors.TraceNew("missing IPv4 response")
  555. }
  556. if !gotIPv6 && resolver.hasIPv6Route {
  557. return nil, "", errors.TraceNew("missing IPv6 response")
  558. }
  559. return IPs, resolver.GetMetrics(), nil
  560. }
  561. func getPublicDNSServers() []string {
  562. servers := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
  563. shuffledServers := make([]string, len(servers))
  564. for i, j := range prng.Perm(len(servers)) {
  565. shuffledServers[i] = servers[j]
  566. }
  567. return shuffledServers
  568. }
  // Fixed values shared by the resolver tests and the test DNS server.
  const (
  	// exampleDomain is the only name the test DNS server answers for.
  	exampleDomain = "example.com"

  	// exampleIPv4 is the A record returned for valid responses;
  	// exampleIPv4CIDR is used for preresolved address generation.
  	exampleIPv4     = "93.184.216.34"
  	exampleIPv4CIDR = "93.184.216.0/24"

  	// exampleIPv6 is the AAAA record returned for valid responses;
  	// exampleIPv6CIDR is used for address generation checks.
  	exampleIPv6     = "2606:2800:220:1:248:1893:25c8:1946"
  	exampleIPv6CIDR = "2606:2800:220::/48"

  	// exampleTTLSeconds is the TTL, in seconds, set on every test DNS
  	// server answer.
  	exampleTTLSeconds = 60
  )
  577. // Set the reserved Z flag
  578. var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
  579. type testDNSServer struct {
  580. respond bool
  581. validResponse bool
  582. expectTransform bool
  583. addr string
  584. requestCount int32
  585. server *dns.Server
  586. }
  587. func newTestDNSServer(respond, validResponse, expectTransform bool) (*testDNSServer, error) {
  588. udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
  589. if err != nil {
  590. return nil, errors.Trace(err)
  591. }
  592. udpConn, err := net.ListenUDP("udp", udpAddr)
  593. if err != nil {
  594. return nil, errors.Trace(err)
  595. }
  596. s := &testDNSServer{
  597. respond: respond,
  598. validResponse: validResponse,
  599. expectTransform: expectTransform,
  600. addr: udpConn.LocalAddr().String(),
  601. }
  602. server := &dns.Server{
  603. PacketConn: udpConn,
  604. Handler: s,
  605. }
  606. s.server = server
  607. go server.ActivateAndServe()
  608. return s, nil
  609. }
  610. func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
  611. atomic.AddInt32(&s.requestCount, 1)
  612. if !s.respond {
  613. return
  614. }
  615. // Check the reserved Z flag
  616. if s.expectTransform != r.MsgHdr.Zero {
  617. return
  618. }
  619. if len(r.Question) != 1 || r.Question[0].Name != dns.Fqdn(exampleDomain) {
  620. return
  621. }
  622. m := new(dns.Msg)
  623. m.SetReply(r)
  624. m.Answer = make([]dns.RR, 1)
  625. if r.Question[0].Qtype == dns.TypeA {
  626. IP := net.ParseIP(exampleIPv4)
  627. if !s.validResponse {
  628. IP = net.ParseIP("127.0.0.1")
  629. }
  630. m.Answer[0] = &dns.A{
  631. Hdr: dns.RR_Header{
  632. Name: r.Question[0].Name,
  633. Rrtype: dns.TypeA,
  634. Class: dns.ClassINET,
  635. Ttl: exampleTTLSeconds},
  636. A: IP,
  637. }
  638. } else {
  639. IP := net.ParseIP(exampleIPv6)
  640. if !s.validResponse {
  641. IP = net.ParseIP("::1")
  642. }
  643. m.Answer[0] = &dns.AAAA{
  644. Hdr: dns.RR_Header{
  645. Name: r.Question[0].Name,
  646. Rrtype: dns.TypeAAAA,
  647. Class: dns.ClassINET,
  648. Ttl: exampleTTLSeconds},
  649. AAAA: IP,
  650. }
  651. }
  652. w.WriteMsg(m)
  653. }
  654. func (s *testDNSServer) getAddr() string {
  655. return s.addr
  656. }
  657. func (s *testDNSServer) getRequestCount() int {
  658. return int(atomic.LoadInt32(&s.requestCount))
  659. }
  660. func (s *testDNSServer) stop() {
  661. s.server.PacketConn.Close()
  662. s.server.Shutdown()
  663. }