dsl_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. /*
  2. * Copyright (c) 2025, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package dsl
  20. import (
  21. "bytes"
  22. "context"
  23. "encoding/base64"
  24. "encoding/hex"
  25. "io/ioutil"
  26. "os"
  27. "runtime/debug"
  28. "sync"
  29. "sync/atomic"
  30. "testing"
  31. "time"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  33. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  34. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
  35. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  36. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/internal/testutils"
  37. )
  38. type testConfig struct {
  39. name string
  40. alreadyDiscovered bool
  41. requireOSLKeys bool
  42. interruptDownloads bool
  43. enableRetries bool
  44. repeatBeforeTTL bool
  45. isTunneled bool
  46. expectFailure bool
  47. cacheServerEntries bool
  48. }
  49. func TestDSLs(t *testing.T) {
  50. tests := []*testConfig{
  51. {
  52. name: "undiscovered server entries",
  53. },
  54. {
  55. name: "require OSL keys",
  56. requireOSLKeys: true,
  57. },
  58. {
  59. name: "interruptions without retry",
  60. interruptDownloads: true,
  61. expectFailure: true,
  62. },
  63. {
  64. name: "interruptions with retry",
  65. interruptDownloads: true,
  66. enableRetries: true,
  67. },
  68. {
  69. name: "require OSL keys with interruptions",
  70. requireOSLKeys: true,
  71. interruptDownloads: true,
  72. enableRetries: true,
  73. },
  74. {
  75. name: "repeat before TTL",
  76. repeatBeforeTTL: true,
  77. },
  78. {
  79. name: "previously discovered server entries",
  80. alreadyDiscovered: true,
  81. },
  82. {
  83. name: "first request is-tunneled",
  84. isTunneled: true,
  85. },
  86. {
  87. name: "cache server entries",
  88. interruptDownloads: true,
  89. enableRetries: true,
  90. cacheServerEntries: true,
  91. },
  92. }
  93. for _, testConfig := range tests {
  94. t.Run(testConfig.name, func(t *testing.T) {
  95. err := testDSLs(testConfig)
  96. if err != nil && !testConfig.expectFailure {
  97. t.Fatal(err.Error())
  98. }
  99. })
  100. }
  101. }
  102. var (
  103. testClientIP = "192.168.0.1"
  104. testClientGeoIPData = common.GeoIPData{
  105. Country: "Country",
  106. City: "City",
  107. ISP: "ISP",
  108. ASN: "ASN",
  109. ASO: "ASO",
  110. }
  111. testHostID = "host_id"
  112. )
  113. func testDSLs(testConfig *testConfig) error {
  114. testDataDirName, err := ioutil.TempDir("", "psiphon-dsl-test")
  115. if err != nil {
  116. return errors.Trace(err)
  117. }
  118. defer os.RemoveAll(testDataDirName)
  119. // Initialize OSLs
  120. var backendOSLPaveData1 []*osl.PaveData
  121. var backendOSLPaveData2 []*osl.PaveData
  122. var clientSLOKs []*osl.SLOK
  123. if testConfig.requireOSLKeys {
  124. var err error
  125. backendOSLPaveData1, backendOSLPaveData2, clientSLOKs, err =
  126. testutils.InitializeTestOSLPaveData()
  127. if err != nil {
  128. return errors.Trace(err)
  129. }
  130. }
  131. // Initialize backend
  132. tlsConfig, err := testutils.NewTestDSLTLSConfig()
  133. if err != nil {
  134. return errors.Trace(err)
  135. }
  136. backend, err := testutils.NewTestDSLBackend(
  137. NewBackendTestShim(),
  138. tlsConfig,
  139. testClientIP, &testClientGeoIPData, testHostID,
  140. backendOSLPaveData1)
  141. if err != nil {
  142. return errors.Trace(err)
  143. }
  144. err = backend.Start()
  145. if err != nil {
  146. return errors.Trace(err)
  147. }
  148. defer backend.Stop()
  149. // Initialize relay
  150. expectValidMetric := false
  151. metricsValidator := func(metric string, fields common.LogFields) bool { return false }
  152. if testConfig.cacheServerEntries {
  153. expectValidMetric = true
  154. metricsValidator = func(metric string, fields common.LogFields) bool {
  155. return metric == "dsl_relay_get_server_entries"
  156. }
  157. }
  158. relayLogger := testutils.NewTestLoggerWithMetricValidator("relay", metricsValidator)
  159. relayCACertificatesFilename,
  160. relayHostCertificateFilename,
  161. relayHostKeyFilename,
  162. err := tlsConfig.WriteRelayFiles(testDataDirName)
  163. if err != nil {
  164. return errors.Trace(err)
  165. }
  166. relayGetServiceAddress := func(_ common.GeoIPData) (string, error) {
  167. return backend.GetAddress(), nil
  168. }
  169. relayConfig := &RelayConfig{
  170. Logger: relayLogger,
  171. CACertificatesFilename: relayCACertificatesFilename,
  172. HostCertificateFilename: relayHostCertificateFilename,
  173. HostKeyFilename: relayHostKeyFilename,
  174. GetServiceAddress: relayGetServiceAddress,
  175. HostID: testHostID,
  176. APIParameterValidator: func(params common.APIParameters) error { return nil },
  177. APIParameterLogFieldFormatter: func(
  178. _ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
  179. logFields := common.LogFields{}
  180. logFields.Add(common.LogFields(params))
  181. return logFields
  182. },
  183. }
  184. relay, err := NewRelay(relayConfig)
  185. if err != nil {
  186. return errors.Trace(err)
  187. }
  188. if !testConfig.cacheServerEntries {
  189. relay.SetCacheParameters(0, 0)
  190. }
  191. // Initialize client fetcher
  192. // Set transfer targets that will exercise various scenarios, including
  193. // requiring request size backoff (e.g. see Fetcher.doGetServerEntriesRequest)
  194. // to succeed.
  195. discoverCount := 128
  196. getCount := 64
  197. oslCount := 1
  198. interruptLimit := 0
  199. if testConfig.interruptDownloads {
  200. interruptLimit = 8192
  201. }
  202. retryCount := 0
  203. if testConfig.enableRetries {
  204. retryCount = 20
  205. }
  206. isTunneled := testConfig.isTunneled
  207. if isTunneled {
  208. discoverCount = 1
  209. }
  210. if backend.GetServerEntryCount(isTunneled) != 128 {
  211. return errors.TraceNew("unexpected server entry count")
  212. }
  213. dslClient := newDSLClient(clientSLOKs)
  214. clientRelayRoundTripper := func(
  215. ctx context.Context,
  216. requestPayload []byte) ([]byte, error) {
  217. // Normally, the Fetcher.RoundTripper would add a circumvention,
  218. // blocking resistant first hop. For this test, it's just a stub that
  219. // directly invokes the relay.
  220. responsePayload, err := relay.HandleRequest(
  221. ctx,
  222. nil,
  223. testClientIP,
  224. testClientGeoIPData,
  225. isTunneled,
  226. requestPayload)
  227. if err != nil {
  228. return GetRelayGenericErrorResponse(), errors.Trace(err)
  229. }
  230. // Simulate interruption of large response.
  231. if interruptLimit > 0 && len(responsePayload) > interruptLimit {
  232. return nil, errors.TraceNew("interrupted")
  233. }
  234. return responsePayload, nil
  235. }
  236. // TODO: exercise BaseAPIParameters?
  237. var unexpectedServerEntrySource atomic.Int32
  238. var unexpectedServerEntryPrioritizeDial atomic.Int32
  239. datastoreHasServerEntryWithCheck := func(
  240. tag ServerEntryTag,
  241. version int,
  242. prioritizeDial bool) bool {
  243. _, expectedPrioritizeDial, err := backend.GetServerEntryProperties(tag.String())
  244. if err != nil || prioritizeDial != expectedPrioritizeDial {
  245. unexpectedServerEntryPrioritizeDial.Store(1)
  246. }
  247. return dslClient.DatastoreHasServerEntry(tag, version)
  248. }
  249. datastoreStoreServerEntryWithCheck := func(
  250. packedServerEntryFields protocol.PackedServerEntryFields,
  251. source string,
  252. prioritizeDial bool) error {
  253. serverEntryFields, _ := protocol.DecodePackedServerEntryFields(packedServerEntryFields)
  254. tag := serverEntryFields.GetTag()
  255. expectedSource, expectedPrioritizeDial, err := backend.GetServerEntryProperties(tag)
  256. if err != nil || prioritizeDial != expectedPrioritizeDial {
  257. unexpectedServerEntryPrioritizeDial.Store(1)
  258. }
  259. if err != nil || source != expectedSource {
  260. unexpectedServerEntrySource.Store(1)
  261. }
  262. return errors.Trace(
  263. dslClient.DatastoreStoreServerEntry(packedServerEntryFields, source))
  264. }
  265. fetcherConfig := &FetcherConfig{
  266. Logger: testutils.NewTestLoggerWithComponent("fetcher"),
  267. RoundTripper: clientRelayRoundTripper,
  268. DatastoreGetLastFetchTime: dslClient.DatastoreGetLastFetchTime,
  269. DatastoreSetLastFetchTime: dslClient.DatastoreSetLastFetchTime,
  270. DatastoreGetLastActiveOSLsTime: dslClient.DatastoreGetLastActiveOSLsTime,
  271. DatastoreSetLastActiveOSLsTime: dslClient.DatastoreSetLastActiveOSLsTime,
  272. DatastoreHasServerEntry: datastoreHasServerEntryWithCheck,
  273. DatastoreStoreServerEntry: datastoreStoreServerEntryWithCheck,
  274. DatastoreKnownOSLIDs: dslClient.DatastoreKnownOSLIDs,
  275. DatastoreGetOSLState: dslClient.DatastoreGetOSLState,
  276. DatastoreStoreOSLState: dslClient.DatastoreStoreOSLState,
  277. DatastoreDeleteOSLState: dslClient.DatastoreDeleteOSLState,
  278. DatastoreSLOKLookup: dslClient.DatastoreSLOKLookup,
  279. RequestTimeout: 1 * time.Second,
  280. RequestRetryCount: retryCount,
  281. RequestRetryDelay: 1 * time.Millisecond,
  282. RequestRetryDelayJitter: 0.1,
  283. FetchTTL: 1 * time.Hour,
  284. DiscoverServerEntriesMinCount: discoverCount,
  285. DiscoverServerEntriesMaxCount: discoverCount,
  286. GetServerEntriesMinCount: getCount,
  287. GetServerEntriesMaxCount: getCount,
  288. GetLastActiveOSLsTTL: 1 * time.Hour,
  289. GetOSLFileSpecsMinCount: oslCount,
  290. GetOSLFileSpecsMaxCount: oslCount,
  291. DoGarbageCollection: debug.FreeOSMemory,
  292. }
  293. fetcher, err := NewFetcher(fetcherConfig)
  294. if err != nil {
  295. return errors.Trace(err)
  296. }
  297. // Fetch server entries
  298. ctx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
  299. defer cancelFunc()
  300. err = fetcher.Run(ctx)
  301. if testConfig.expectFailure && err == nil {
  302. err = errors.TraceNew("unexpected success")
  303. }
  304. if err != nil {
  305. return errors.Trace(err)
  306. }
  307. if testConfig.repeatBeforeTTL {
  308. // Invoke fetch again with before the last discover time TTL expires.
  309. // The always-failing round tripper will be hit if an unexpected
  310. // request is sent.
  311. fetcherConfig.RoundTripper = func(
  312. context.Context,
  313. []byte) ([]byte, error) {
  314. return nil, errors.TraceNew("round trip not permitted")
  315. }
  316. err = fetcher.Run(ctx)
  317. if err != nil {
  318. return errors.Trace(err)
  319. }
  320. }
  321. if testConfig.alreadyDiscovered && testConfig.isTunneled {
  322. return errors.TraceNew("invalid test configuration")
  323. }
  324. if testConfig.alreadyDiscovered {
  325. // Fetch again after resetting the last discover time TTL. A
  326. // DiscoverServerEntries request will be sent, but all tags should be
  327. // known, and no GetServerEntries requests should be sent or any
  328. // server entries stores, as will be checked via
  329. // dslClient.serverEntryStoreCount.
  330. dslClient.lastFetchTime = time.Time{}
  331. dslClient.lastActiveOSLsTime = time.Time{}
  332. err = fetcher.Run(ctx)
  333. if err != nil {
  334. return errors.Trace(err)
  335. }
  336. }
  337. if testConfig.isTunneled {
  338. if dslClient.serverEntryStoreCount != 1 {
  339. return errors.Tracef(
  340. "unexpected server entry store count: %d", dslClient.serverEntryStoreCount)
  341. }
  342. // If the first request was isTunneled, only one server entry will
  343. // have been fetched. Do another full fetch, and the following
  344. // dslClient.serverEntryStoreCount check will demonstrate that all
  345. // remaining server entries were downloaded and stored.
  346. dslClient.lastFetchTime = time.Time{}
  347. discoverCount = 128
  348. fetcherConfig.DiscoverServerEntriesMinCount = discoverCount
  349. fetcherConfig.DiscoverServerEntriesMaxCount = discoverCount
  350. err = fetcher.Run(ctx)
  351. if err != nil {
  352. return errors.Trace(err)
  353. }
  354. }
  355. // TODO: check "updated" and "known" counters in "DSL: fetched server
  356. // entries" logs.
  357. if dslClient.serverEntryStoreCount != backend.GetServerEntryCount(isTunneled) {
  358. return errors.Tracef(
  359. "unexpected server entry store count: %d", dslClient.serverEntryStoreCount)
  360. }
  361. if testConfig.requireOSLKeys {
  362. // Rotate to the next OSL period and clear all server entries. The
  363. // fetcher will download the new, unknown OSL and reassemble the key,
  364. // or else no server entries will be downloaded. Check that the
  365. // fetcher cleans up the old, no longer active OSL state via
  366. // dslClient.deleteOSLStateCount.
  367. dslClient.lastFetchTime = time.Time{}
  368. dslClient.lastActiveOSLsTime = time.Time{}
  369. dslClient.serverEntries = make(map[string]protocol.ServerEntryFields)
  370. backend.SetOSLPaveData(backendOSLPaveData2)
  371. err = fetcher.Run(ctx)
  372. if err != nil {
  373. return errors.Trace(err)
  374. }
  375. if dslClient.serverEntryStoreCount != backend.GetServerEntryCount(isTunneled) {
  376. return errors.Tracef(
  377. "unexpected server entry store count: %d", dslClient.serverEntryStoreCount)
  378. }
  379. if dslClient.deleteOSLStateCount < 1 {
  380. return errors.Tracef(
  381. "unexpected delete OSL state count: %d", dslClient.deleteOSLStateCount)
  382. }
  383. }
  384. err = relayLogger.CheckMetrics(expectValidMetric)
  385. if err != nil {
  386. return errors.Trace(err)
  387. }
  388. if unexpectedServerEntrySource.Load() != 0 {
  389. return errors.TraceNew("unexpected server entry source")
  390. }
  391. if unexpectedServerEntryPrioritizeDial.Load() != 0 {
  392. return errors.TraceNew("unexpected server entry prioritize dial")
  393. }
  394. return nil
  395. }
  396. type dslClient struct {
  397. mutex sync.Mutex
  398. lastFetchTime time.Time
  399. lastActiveOSLsTime time.Time
  400. serverEntries map[string]protocol.ServerEntryFields
  401. serverEntryStoreCount int
  402. oslStates map[string][]byte
  403. deleteOSLStateCount int
  404. SLOKs []*osl.SLOK
  405. }
  406. func newDSLClient(SLOKs []*osl.SLOK) *dslClient {
  407. return &dslClient{
  408. serverEntries: make(map[string]protocol.ServerEntryFields),
  409. oslStates: make(map[string][]byte),
  410. SLOKs: SLOKs,
  411. }
  412. }
  413. func (c *dslClient) DatastoreGetLastFetchTime() (time.Time, error) {
  414. c.mutex.Lock()
  415. defer c.mutex.Unlock()
  416. return c.lastFetchTime, nil
  417. }
  418. func (c *dslClient) DatastoreSetLastFetchTime(time time.Time) error {
  419. c.mutex.Lock()
  420. defer c.mutex.Unlock()
  421. c.lastFetchTime = time
  422. return nil
  423. }
  424. func (c *dslClient) DatastoreGetLastActiveOSLsTime() (time.Time, error) {
  425. c.mutex.Lock()
  426. defer c.mutex.Unlock()
  427. return c.lastActiveOSLsTime, nil
  428. }
  429. func (c *dslClient) DatastoreSetLastActiveOSLsTime(time time.Time) error {
  430. c.mutex.Lock()
  431. defer c.mutex.Unlock()
  432. c.lastActiveOSLsTime = time
  433. return nil
  434. }
  435. func (c *dslClient) DatastoreHasServerEntry(tag ServerEntryTag, version int) bool {
  436. c.mutex.Lock()
  437. defer c.mutex.Unlock()
  438. _, ok := c.serverEntries[base64.StdEncoding.EncodeToString(tag)]
  439. return ok
  440. }
  441. func (c *dslClient) DatastoreStoreServerEntry(
  442. packedServerEntryFields protocol.PackedServerEntryFields, source string) error {
  443. c.mutex.Lock()
  444. defer c.mutex.Unlock()
  445. c.serverEntryStoreCount += 1
  446. serverEntryFields, err := protocol.DecodePackedServerEntryFields(packedServerEntryFields)
  447. if err != nil {
  448. return errors.Trace(err)
  449. }
  450. serverEntryFields.SetLocalSource(source)
  451. serverEntryFields.SetLocalTimestamp(
  452. common.TruncateTimestampToHour(common.GetCurrentTimestamp()))
  453. c.serverEntries[serverEntryFields.GetTag()] = serverEntryFields
  454. return nil
  455. }
  456. func (c *dslClient) DatastoreKnownOSLIDs() ([]OSLID, error) {
  457. c.mutex.Lock()
  458. defer c.mutex.Unlock()
  459. var IDs []OSLID
  460. for IDStr := range c.oslStates {
  461. ID, _ := hex.DecodeString(IDStr)
  462. IDs = append(IDs, ID)
  463. }
  464. return IDs, nil
  465. }
  466. func (c *dslClient) DatastoreGetOSLState(ID OSLID) ([]byte, error) {
  467. c.mutex.Lock()
  468. defer c.mutex.Unlock()
  469. state, ok := c.oslStates[hex.EncodeToString(ID)]
  470. if !ok {
  471. return nil, nil
  472. }
  473. return state, nil
  474. }
  475. func (c *dslClient) DatastoreStoreOSLState(ID OSLID, state []byte) error {
  476. c.mutex.Lock()
  477. defer c.mutex.Unlock()
  478. c.oslStates[hex.EncodeToString(ID)] = state
  479. return nil
  480. }
  481. func (c *dslClient) DatastoreDeleteOSLState(ID OSLID) error {
  482. c.mutex.Lock()
  483. defer c.mutex.Unlock()
  484. c.deleteOSLStateCount += 1
  485. delete(c.oslStates, hex.EncodeToString(ID))
  486. return nil
  487. }
  488. func (c *dslClient) DatastoreSLOKLookup(SLOKID []byte) []byte {
  489. c.mutex.Lock()
  490. defer c.mutex.Unlock()
  491. for _, slok := range c.SLOKs {
  492. if bytes.Equal(slok.ID, SLOKID) {
  493. return slok.Key
  494. }
  495. }
  496. return nil
  497. }
  498. func (c *dslClient) DatastoreFatalError(err error) {
  499. panic(err.Error())
  500. }