resolver_test.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
  1. /*
  2. * Copyright (c) 2022, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package resolver
  20. import (
  21. "context"
  22. "fmt"
  23. "net"
  24. "reflect"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
  33. "github.com/miekg/dns"
  34. )
  35. func TestMakeResolveParameters(t *testing.T) {
  36. err := runTestMakeResolveParameters()
  37. if err != nil {
  38. t.Fatalf(errors.Trace(err).Error())
  39. }
  40. }
  41. func TestResolver(t *testing.T) {
  42. err := runTestResolver()
  43. if err != nil {
  44. t.Fatalf(errors.Trace(err).Error())
  45. }
  46. }
  47. func TestPublicDNSServers(t *testing.T) {
  48. IPs, metrics, err := runTestPublicDNSServers()
  49. if err != nil {
  50. t.Fatalf(errors.Trace(err).Error())
  51. }
  52. t.Logf("IPs: %v", IPs)
  53. t.Logf("Metrics: %v", metrics)
  54. }
  55. func runTestMakeResolveParameters() error {
  56. frontingProviderID := "frontingProvider"
  57. frontingDialDomain := exampleDomain
  58. alternateDNSServer := "172.16.0.1"
  59. alternateDNSServerWithPort := net.JoinHostPort(alternateDNSServer, resolverDNSPort)
  60. preferredAlternateDNSServer := "172.16.0.2"
  61. preferredAlternateDNSServerWithPort := net.JoinHostPort(preferredAlternateDNSServer, resolverDNSPort)
  62. transformName := "exampleTransform"
  63. paramValues := map[string]interface{}{
  64. "DNSResolverAttemptsPerServer": 2,
  65. "DNSResolverAttemptsPerPreferredServer": 1,
  66. "DNSResolverPreresolvedIPAddressProbability": 1.0,
  67. "DNSResolverPreresolvedIPAddressCIDRs": parameters.LabeledCIDRs{frontingProviderID: []string{exampleIPv4CIDR}},
  68. "DNSResolverAlternateServers": []string{alternateDNSServer},
  69. "DNSResolverPreferredAlternateServers": []string{preferredAlternateDNSServer},
  70. "DNSResolverPreferAlternateServerProbability": 1.0,
  71. "DNSResolverProtocolTransformProbability": 1.0,
  72. "DNSResolverProtocolTransformSpecs": transforms.Specs{transformName: exampleTransform},
  73. "DNSResolverProtocolTransformScopedSpecNames": transforms.ScopedSpecNames{preferredAlternateDNSServer: []string{transformName}},
  74. "DNSResolverQNameRandomizeCasingProbability": 1.0,
  75. "DNSResolverQNameMustMatchProbability": 1.0,
  76. "DNSResolverIncludeEDNS0Probability": 1.0,
  77. }
  78. params, err := parameters.NewParameters(nil)
  79. if err != nil {
  80. return errors.Trace(err)
  81. }
  82. _, err = params.Set("", 0, paramValues)
  83. if err != nil {
  84. return errors.Trace(err)
  85. }
  86. resolver := NewResolver(&NetworkConfig{}, "")
  87. defer resolver.Stop()
  88. resolverParams, err := resolver.MakeResolveParameters(
  89. params.Get(), frontingProviderID, frontingDialDomain)
  90. if err != nil {
  91. return errors.Trace(err)
  92. }
  93. // Test: PreresolvedIPAddress
  94. CIDRContainsIP := func(CIDR, IP string) bool {
  95. _, IPNet, _ := net.ParseCIDR(CIDR)
  96. return IPNet.Contains(net.ParseIP(IP))
  97. }
  98. if resolverParams.AttemptsPerServer != 2 ||
  99. resolverParams.AttemptsPerPreferredServer != 1 ||
  100. resolverParams.RequestTimeout != 5*time.Second ||
  101. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  102. !CIDRContainsIP(exampleIPv4CIDR, resolverParams.PreresolvedIPAddress) ||
  103. resolverParams.PreresolvedDomain != frontingDialDomain {
  104. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  105. }
  106. // Test: additional generateIPAddressFromCIDR cases
  107. for i := 0; i < 10000; i++ {
  108. for _, CIDR := range []string{exampleIPv4CIDR, exampleIPv6CIDR} {
  109. IP, err := generateIPAddressFromCIDR(CIDR)
  110. if err != nil {
  111. return errors.Trace(err)
  112. }
  113. if !CIDRContainsIP(CIDR, IP.String()) || common.IsBogon(IP) {
  114. return errors.Tracef(
  115. "invalid generated IP address %v for CIDR %v", IP, CIDR)
  116. }
  117. }
  118. }
  119. // Test: Preferred/Transform/RandomQNameCasing/QNameMustMatch/EDNS(0)
  120. paramValues["DNSResolverPreresolvedIPAddressProbability"] = 0.0
  121. _, err = params.Set("", 0, paramValues)
  122. if err != nil {
  123. return errors.Trace(err)
  124. }
  125. resolverParams, err = resolver.MakeResolveParameters(
  126. params.Get(), frontingProviderID, frontingDialDomain)
  127. if err != nil {
  128. return errors.Trace(err)
  129. }
  130. if resolverParams.AttemptsPerServer != 2 ||
  131. resolverParams.AttemptsPerPreferredServer != 1 ||
  132. resolverParams.RequestTimeout != 5*time.Second ||
  133. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  134. resolverParams.PreresolvedIPAddress != "" ||
  135. resolverParams.PreresolvedDomain != "" ||
  136. resolverParams.AlternateDNSServer != preferredAlternateDNSServerWithPort ||
  137. resolverParams.PreferAlternateDNSServer != true ||
  138. resolverParams.ProtocolTransformName != transformName ||
  139. resolverParams.ProtocolTransformSpec == nil ||
  140. resolverParams.RandomQNameCasingSeed == nil ||
  141. resolverParams.ResponseQNameMustMatch != true ||
  142. resolverParams.IncludeEDNS0 != true {
  143. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  144. }
  145. // Test: No Preferred/Transform/EDNS(0)
  146. paramValues["DNSResolverPreferAlternateServerProbability"] = 0.0
  147. paramValues["DNSResolverProtocolTransformProbability"] = 0.0
  148. paramValues["DNSResolverQNameRandomizeCasingProbability"] = 0.0
  149. paramValues["DNSResolverQNameMustMatchProbability"] = 0.0
  150. paramValues["DNSResolverIncludeEDNS0Probability"] = 0.0
  151. _, err = params.Set("", 0, paramValues)
  152. if err != nil {
  153. return errors.Trace(err)
  154. }
  155. resolverParams, err = resolver.MakeResolveParameters(
  156. params.Get(), frontingProviderID, frontingDialDomain)
  157. if err != nil {
  158. return errors.Trace(err)
  159. }
  160. if resolverParams.AttemptsPerServer != 2 ||
  161. resolverParams.AttemptsPerPreferredServer != 1 ||
  162. resolverParams.RequestTimeout != 5*time.Second ||
  163. resolverParams.AwaitTimeout != 10*time.Millisecond ||
  164. resolverParams.PreresolvedIPAddress != "" ||
  165. resolverParams.PreresolvedDomain != "" ||
  166. resolverParams.AlternateDNSServer != alternateDNSServerWithPort ||
  167. resolverParams.PreferAlternateDNSServer != false ||
  168. resolverParams.ProtocolTransformName != "" ||
  169. resolverParams.ProtocolTransformSpec != nil ||
  170. resolverParams.RandomQNameCasingSeed != nil ||
  171. resolverParams.ResponseQNameMustMatch != false ||
  172. resolverParams.IncludeEDNS0 != false {
  173. return errors.Tracef("unexpected resolver parameters: %+v", resolverParams)
  174. }
  175. return nil
  176. }
  177. func runTestResolver() error {
  178. // noResponseServer will not respond to requests
  179. noResponseServer, err := newTestDNSServer(false, false, false, false)
  180. if err != nil {
  181. return errors.Trace(err)
  182. }
  183. defer noResponseServer.stop()
  184. // invalidIPServer will respond with an invalid IP
  185. invalidIPServer, err := newTestDNSServer(true, false, false, false)
  186. if err != nil {
  187. return errors.Trace(err)
  188. }
  189. defer invalidIPServer.stop()
  190. // okServer will respond to correct requests (expected domain) with the
  191. // correct response (expected IPv4 or IPv6 address)
  192. okServer, err := newTestDNSServer(true, true, false, false)
  193. if err != nil {
  194. return errors.Trace(err)
  195. }
  196. defer okServer.stop()
  197. // alternateOkServer behaves like okServer; getRequestCount is used to
  198. // confirm that the alternate server was indeed used
  199. alternateOkServer, err := newTestDNSServer(true, true, false, false)
  200. if err != nil {
  201. return errors.Trace(err)
  202. }
  203. defer alternateOkServer.stop()
  204. // transformOkServer behaves like okServer but only responds if the
  205. // transform was applied; other servers do not respond if the transform
  206. // is applied
  207. transformOkServer, err := newTestDNSServer(true, true, true, false)
  208. if err != nil {
  209. return errors.Trace(err)
  210. }
  211. defer transformOkServer.stop()
  212. randomQNameCasingOkServer, err := newTestDNSServer(true, true, false, true)
  213. if err != nil {
  214. return errors.Trace(err)
  215. }
  216. defer randomQNameCasingOkServer.stop()
  217. servers := []string{noResponseServer.getAddr(), invalidIPServer.getAddr(), okServer.getAddr()}
  218. networkConfig := &NetworkConfig{
  219. GetDNSServers: func() []string { return servers },
  220. LogWarning: func(err error) { fmt.Printf("LogWarning: %v\n", err) },
  221. }
  222. networkID := "networkID-1"
  223. resolver := NewResolver(networkConfig, networkID)
  224. defer resolver.Stop()
  225. params := &ResolveParameters{
  226. AttemptsPerServer: 1,
  227. AttemptsPerPreferredServer: 1,
  228. RequestTimeout: 250 * time.Millisecond,
  229. AwaitTimeout: 250 * time.Millisecond,
  230. IncludeEDNS0: true,
  231. }
  232. checkResult := func(IPs []net.IP) error {
  233. var IPv4, IPv6 net.IP
  234. for _, IP := range IPs {
  235. if IP.To4() != nil {
  236. IPv4 = IP
  237. } else {
  238. IPv6 = IP
  239. }
  240. }
  241. if IPv4 == nil {
  242. return errors.TraceNew("missing IPv4 response")
  243. }
  244. if IPv4.String() != exampleIPv4 {
  245. return errors.TraceNew("unexpected IPv4 response")
  246. }
  247. if resolver.hasIPv6Route {
  248. if IPv6 == nil {
  249. return errors.TraceNew("missing IPv6 response")
  250. }
  251. if IPv6.String() != exampleIPv6 {
  252. return errors.TraceNew("unexpected IPv6 response")
  253. }
  254. }
  255. return nil
  256. }
  257. ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
  258. defer cancelFunc()
  259. // Test: should retry until okServer responds
  260. IPs, err := resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  261. if err != nil {
  262. return errors.Trace(err)
  263. }
  264. err = checkResult(IPs)
  265. if err != nil {
  266. return errors.Trace(err)
  267. }
  268. if resolver.metrics.resolves != 1 ||
  269. resolver.metrics.cacheHits != 0 ||
  270. resolver.metrics.requestsIPv4 != 3 || resolver.metrics.responsesIPv4 != 1 ||
  271. (resolver.hasIPv6Route && (resolver.metrics.requestsIPv6 != 3 || resolver.metrics.responsesIPv6 != 1)) {
  272. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  273. }
  274. // Test: cached response
  275. beforeMetrics := resolver.metrics
  276. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  277. if err != nil {
  278. return errors.Trace(err)
  279. }
  280. err = checkResult(IPs)
  281. if err != nil {
  282. return errors.Trace(err)
  283. }
  284. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  285. resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
  286. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  287. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  288. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  289. }
  290. // Test: PreresolvedIPAddress
  291. beforeMetrics = resolver.metrics
  292. params.PreresolvedIPAddress = exampleIPv4
  293. params.PreresolvedDomain = exampleDomain
  294. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  295. if err != nil {
  296. return errors.Trace(err)
  297. }
  298. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  299. return errors.TraceNew("unexpected preresolved response")
  300. }
  301. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  302. resolver.metrics.cacheHits != beforeMetrics.cacheHits ||
  303. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  304. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  305. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  306. }
  307. params.PreresolvedIPAddress = ""
  308. // Test: PreresolvedIPAddress set for different domain
  309. beforeMetrics = resolver.metrics
  310. params.PreresolvedIPAddress = exampleIPv4
  311. params.PreresolvedDomain = "not.example.com"
  312. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  313. if err != nil {
  314. return errors.Trace(err)
  315. }
  316. err = checkResult(IPs)
  317. if err != nil {
  318. return errors.Trace(err)
  319. }
  320. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  321. resolver.metrics.cacheHits != beforeMetrics.cacheHits+1 ||
  322. resolver.metrics.requestsIPv4 != beforeMetrics.requestsIPv4 ||
  323. resolver.metrics.requestsIPv6 != beforeMetrics.requestsIPv6 {
  324. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  325. }
  326. params.PreresolvedIPAddress = ""
  327. params.PreresolvedDomain = ""
  328. // Test: change network ID, which must clear cache
  329. beforeMetrics = resolver.metrics
  330. networkID = "networkID-2"
  331. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  332. if err != nil {
  333. return errors.Trace(err)
  334. }
  335. err = checkResult(IPs)
  336. if err != nil {
  337. return errors.Trace(err)
  338. }
  339. if resolver.metrics.resolves != beforeMetrics.resolves+1 ||
  340. resolver.metrics.cacheHits != beforeMetrics.cacheHits {
  341. return errors.Tracef("unexpected metrics: %+v (%+v)", resolver.metrics, beforeMetrics)
  342. }
  343. // Test: PreferAlternateDNSServer
  344. if alternateOkServer.getRequestCount() != 0 {
  345. return errors.TraceNew("unexpected alternate server request count")
  346. }
  347. resolver.cache.Flush()
  348. params.AlternateDNSServer = alternateOkServer.getAddr()
  349. params.PreferAlternateDNSServer = true
  350. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  351. if err != nil {
  352. return errors.Trace(err)
  353. }
  354. err = checkResult(IPs)
  355. if err != nil {
  356. return errors.Trace(err)
  357. }
  358. if alternateOkServer.getRequestCount() < 1 {
  359. return errors.TraceNew("unexpected alternate server request count")
  360. }
  361. params.AlternateDNSServer = ""
  362. params.PreferAlternateDNSServer = false
  363. // Test: PreferAlternateDNSServer with failed attempt (exercise maxAttempts prefer case)
  364. resolver.cache.Flush()
  365. params.AlternateDNSServer = invalidIPServer.getAddr()
  366. params.PreferAlternateDNSServer = true
  367. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  368. if err != nil {
  369. return errors.Trace(err)
  370. }
  371. err = checkResult(IPs)
  372. if err != nil {
  373. return errors.Trace(err)
  374. }
  375. params.AlternateDNSServer = ""
  376. params.PreferAlternateDNSServer = false
  377. // Test: fall over to AlternateDNSServer when no system servers
  378. beforeCount := alternateOkServer.getRequestCount()
  379. previousGetDNSServers := networkConfig.GetDNSServers
  380. networkConfig.GetDNSServers = func() []string { return nil }
  381. // Force system servers update
  382. networkID = "networkID-3"
  383. resolver.cache.Flush()
  384. params.AlternateDNSServer = alternateOkServer.getAddr()
  385. params.PreferAlternateDNSServer = false
  386. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  387. if err != nil {
  388. return errors.Trace(err)
  389. }
  390. err = checkResult(IPs)
  391. if err != nil {
  392. return errors.Trace(err)
  393. }
  394. if alternateOkServer.getRequestCount() <= beforeCount {
  395. return errors.TraceNew("unexpected alterate server request count")
  396. }
  397. // Test: use default, standard resolver when no servers
  398. resolver.cache.Flush()
  399. params.AlternateDNSServer = ""
  400. params.PreferAlternateDNSServer = false
  401. if len(resolver.systemServers) != 0 {
  402. return errors.TraceNew("unexpected server count")
  403. }
  404. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleRealDomain)
  405. if err != nil {
  406. return errors.Trace(err)
  407. }
  408. if len(IPs) == 0 {
  409. return errors.TraceNew("unexpected response")
  410. }
  411. // Test: ResolveAddress
  412. networkConfig.GetDNSServers = previousGetDNSServers
  413. // Force system servers update
  414. networkID = "networkID-4"
  415. domainAddress := net.JoinHostPort(exampleDomain, "443")
  416. address, err := resolver.ResolveAddress(ctx, networkID, params, "", domainAddress)
  417. if err != nil {
  418. return errors.Trace(err)
  419. }
  420. host, port, err := net.SplitHostPort(address)
  421. if err != nil {
  422. return errors.Trace(err)
  423. }
  424. IP := net.ParseIP(host)
  425. if IP == nil || (host != exampleIPv4 && host != exampleIPv6) || port != "443" {
  426. return errors.TraceNew("unexpected response")
  427. }
  428. // Test: protocol transform
  429. if transformOkServer.getRequestCount() != 0 {
  430. return errors.TraceNew("unexpected transform server request count")
  431. }
  432. resolver.cache.Flush()
  433. params.AttemptsPerServer = 0
  434. params.AlternateDNSServer = transformOkServer.getAddr()
  435. params.PreferAlternateDNSServer = true
  436. seed, err := prng.NewSeed()
  437. if err != nil {
  438. return errors.Trace(err)
  439. }
  440. params.ProtocolTransformName = "exampleTransform"
  441. params.ProtocolTransformSpec = exampleTransform
  442. params.ProtocolTransformSeed = seed
  443. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  444. if err != nil {
  445. return errors.Trace(err)
  446. }
  447. err = checkResult(IPs)
  448. if err != nil {
  449. return errors.Trace(err)
  450. }
  451. if transformOkServer.getRequestCount() < 1 {
  452. return errors.TraceNew("unexpected transform server request count")
  453. }
  454. params.AttemptsPerServer = 1
  455. params.AlternateDNSServer = ""
  456. params.PreferAlternateDNSServer = false
  457. params.ProtocolTransformName = ""
  458. params.ProtocolTransformSpec = nil
  459. params.ProtocolTransformSeed = nil
  460. // Test: random QName (hostname) casing
  461. //
  462. // Note: there's a (1/2)^N chance that the QName (hostname) with randomized
  463. // casing has the same casing as the input QName, where N is the number of
  464. // Unicode letters in the QName. In such an event these tests will either
  465. // give a false positive or false negative depending on the subtest.
  466. if randomQNameCasingOkServer.getRequestCount() != 0 {
  467. return errors.TraceNew("unexpected random QName casing server request count")
  468. }
  469. resolver.cache.Flush()
  470. params.AttemptsPerServer = 0
  471. params.AttemptsPerPreferredServer = 1
  472. params.AlternateDNSServer = randomQNameCasingOkServer.getAddr()
  473. params.PreferAlternateDNSServer = true
  474. params.RandomQNameCasingSeed = seed
  475. _, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  476. if err != nil {
  477. return errors.Trace(err)
  478. }
  479. resolver.cache.Flush()
  480. params.ResponseQNameMustMatch = true
  481. _, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  482. if err == nil {
  483. return errors.TraceNew("expected QName mismatch")
  484. }
  485. resolver.cache.Flush()
  486. params.AlternateDNSServer = okServer.getAddr()
  487. _, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  488. if err == nil {
  489. return errors.TraceNew("expected server to not support random QName casing")
  490. }
  491. err = checkResult(IPs)
  492. if err != nil {
  493. return errors.Trace(err)
  494. }
  495. if randomQNameCasingOkServer.getRequestCount() < 1 {
  496. return errors.TraceNew("unexpected random QName casing server request count")
  497. }
  498. params.AttemptsPerServer = 1
  499. params.AlternateDNSServer = ""
  500. params.PreferAlternateDNSServer = false
  501. params.RandomQNameCasingSeed = nil
  502. // Test: EDNS(0)
  503. resolver.cache.Flush()
  504. params.IncludeEDNS0 = true
  505. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  506. if err != nil {
  507. return errors.Trace(err)
  508. }
  509. err = checkResult(IPs)
  510. if err != nil {
  511. return errors.Trace(err)
  512. }
  513. params.IncludeEDNS0 = false
  514. // Test: input IP address
  515. beforeMetrics = resolver.metrics
  516. resolver.cache.Flush()
  517. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleIPv4)
  518. if err != nil {
  519. return errors.Trace(err)
  520. }
  521. if len(IPs) != 1 || IPs[0].String() != exampleIPv4 {
  522. return errors.TraceNew("unexpected IPv4 response")
  523. }
  524. if resolver.metrics.resolves != beforeMetrics.resolves {
  525. return errors.Tracef("unexpected metrics: %+v", resolver.metrics)
  526. }
  527. // Test: DNS cache extension
  528. resolver.cache.Flush()
  529. networkConfig.CacheExtensionInitialTTL = (exampleTTLSeconds * 2) * time.Second
  530. networkConfig.CacheExtensionVerifiedTTL = 2 * time.Hour
  531. now := time.Now()
  532. IPs, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  533. if err != nil {
  534. return errors.Trace(err)
  535. }
  536. entry, expiry, ok := resolver.cache.GetWithExpiration(exampleDomain)
  537. if !ok ||
  538. !reflect.DeepEqual(entry, IPs) ||
  539. expiry.Before(now.Add(networkConfig.CacheExtensionInitialTTL)) ||
  540. expiry.After(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  541. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  542. }
  543. resolver.VerifyCacheExtension(exampleDomain)
  544. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  545. if !ok ||
  546. !reflect.DeepEqual(entry, IPs) ||
  547. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  548. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  549. }
  550. // Set cache flush condition, which should be ignored
  551. networkID = "networkID-5"
  552. resolver.updateNetworkState(networkID)
  553. entry, expiry, ok = resolver.cache.GetWithExpiration(exampleDomain)
  554. if !ok ||
  555. !reflect.DeepEqual(entry, IPs) ||
  556. expiry.Before(now.Add(networkConfig.CacheExtensionVerifiedTTL)) {
  557. return errors.TraceNew("unexpected CacheExtensionInitialTTL state")
  558. }
  559. // Test: cancel context
  560. resolver.cache.Flush()
  561. cancelFunc()
  562. _, err = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  563. if err == nil {
  564. return errors.TraceNew("unexpected success")
  565. }
  566. // Test: cancel context while resolving
  567. // This test exercises the additional answers and await cases in
  568. // ResolveIP. The test is timing dependent, and so imperfect, but this
  569. // configuration can reproduce panics in those cases before bugs were
  570. // fixed, where DNS responses need to be received just as the context is
  571. // cancelled.
  572. networkConfig.GetDNSServers = func() []string { return []string{okServer.getAddr()} }
  573. networkID = "networkID-6"
  574. for i := 0; i < 500; i++ {
  575. resolver.cache.Flush()
  576. ctx, cancelFunc := context.WithTimeout(
  577. context.Background(), time.Duration((i%10+1)*20)*time.Microsecond)
  578. defer cancelFunc()
  579. _, _ = resolver.ResolveIP(ctx, networkID, params, exampleDomain)
  580. }
  581. return nil
  582. }
  583. func runTestPublicDNSServers() ([]net.IP, string, error) {
  584. networkConfig := &NetworkConfig{
  585. GetDNSServers: getPublicDNSServers,
  586. }
  587. networkID := "networkID-1"
  588. resolver := NewResolver(networkConfig, networkID)
  589. defer resolver.Stop()
  590. params := &ResolveParameters{
  591. AttemptsPerServer: 1,
  592. RequestTimeout: 5 * time.Second,
  593. AwaitTimeout: 1 * time.Second,
  594. IncludeEDNS0: true,
  595. }
  596. IPs, err := resolver.ResolveIP(
  597. context.Background(), networkID, params, exampleRealDomain)
  598. if err != nil {
  599. return nil, "", errors.Trace(err)
  600. }
  601. gotIPv4 := false
  602. gotIPv6 := false
  603. for _, IP := range IPs {
  604. if IP.To4() != nil {
  605. gotIPv4 = true
  606. } else {
  607. gotIPv6 = true
  608. }
  609. }
  610. if !gotIPv4 {
  611. return nil, "", errors.TraceNew("missing IPv4 response")
  612. }
  613. if !gotIPv6 && resolver.hasIPv6Route {
  614. return nil, "", errors.TraceNew("missing IPv6 response")
  615. }
  616. return IPs, resolver.GetMetrics(), nil
  617. }
  618. func getPublicDNSServers() []string {
  619. servers := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
  620. shuffledServers := make([]string, len(servers))
  621. for i, j := range prng.Perm(len(servers)) {
  622. shuffledServers[i] = servers[j]
  623. }
  624. return shuffledServers
  625. }
  626. var exampleDomain = fmt.Sprintf("%s.example.com", prng.Base64String(32))
  627. const (
  628. exampleRealDomain = "example.com"
  629. exampleIPv4 = "93.184.216.34"
  630. exampleIPv4CIDR = "93.184.216.0/24"
  631. exampleIPv6 = "2606:2800:220:1:248:1893:25c8:1946"
  632. exampleIPv6CIDR = "2606:2800:220::/48"
  633. exampleTTLSeconds = 60
  634. )
  635. // Set the reserved Z flag
  636. var exampleTransform = transforms.Spec{[2]string{"^([a-f0-9]{4})0100", "\\$\\{1\\}0140"}}
  637. type testDNSServer struct {
  638. respond bool
  639. validResponse bool
  640. expectTransform bool
  641. expectRandomQNameCasing bool
  642. addr string
  643. requestCount int32
  644. server *dns.Server
  645. }
  646. func newTestDNSServer(respond, validResponse, expectTransform, expectRandomQNameCasing bool) (*testDNSServer, error) {
  647. udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
  648. if err != nil {
  649. return nil, errors.Trace(err)
  650. }
  651. udpConn, err := net.ListenUDP("udp", udpAddr)
  652. if err != nil {
  653. return nil, errors.Trace(err)
  654. }
  655. s := &testDNSServer{
  656. respond: respond,
  657. validResponse: validResponse,
  658. expectTransform: expectTransform,
  659. expectRandomQNameCasing: expectRandomQNameCasing,
  660. addr: udpConn.LocalAddr().String(),
  661. }
  662. server := &dns.Server{
  663. PacketConn: udpConn,
  664. Handler: s,
  665. }
  666. s.server = server
  667. go server.ActivateAndServe()
  668. return s, nil
  669. }
  670. func (s *testDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
  671. atomic.AddInt32(&s.requestCount, 1)
  672. if !s.respond {
  673. return
  674. }
  675. // Check the reserved Z flag
  676. if s.expectTransform != r.MsgHdr.Zero {
  677. return
  678. }
  679. if len(r.Question) != 1 ||
  680. (!s.expectRandomQNameCasing &&
  681. r.Question[0].Name != dns.Fqdn(exampleDomain)) {
  682. return
  683. }
  684. m := new(dns.Msg)
  685. m.SetReply(r)
  686. m.Answer = make([]dns.RR, 1)
  687. if r.Question[0].Qtype == dns.TypeA {
  688. IP := net.ParseIP(exampleIPv4)
  689. if !s.validResponse {
  690. IP = net.ParseIP("127.0.0.1")
  691. }
  692. m.Answer[0] = &dns.A{
  693. Hdr: dns.RR_Header{
  694. Name: r.Question[0].Name,
  695. Rrtype: dns.TypeA,
  696. Class: dns.ClassINET,
  697. Ttl: exampleTTLSeconds},
  698. A: IP,
  699. }
  700. } else {
  701. IP := net.ParseIP(exampleIPv6)
  702. if !s.validResponse {
  703. IP = net.ParseIP("::1")
  704. }
  705. m.Answer[0] = &dns.AAAA{
  706. Hdr: dns.RR_Header{
  707. Name: r.Question[0].Name,
  708. Rrtype: dns.TypeAAAA,
  709. Class: dns.ClassINET,
  710. Ttl: exampleTTLSeconds},
  711. AAAA: IP,
  712. }
  713. }
  714. if s.expectRandomQNameCasing {
  715. // Simulate a server that does not preserve the casing of the QName.
  716. m.Question[0].Name = dns.Fqdn(exampleDomain)
  717. }
  718. w.WriteMsg(m)
  719. }
  720. func (s *testDNSServer) getAddr() string {
  721. return s.addr
  722. }
  723. func (s *testDNSServer) getRequestCount() int {
  724. return int(atomic.LoadInt32(&s.requestCount))
  725. }
  726. func (s *testDNSServer) stop() {
  727. s.server.PacketConn.Close()
  728. s.server.Shutdown()
  729. }