relay.go 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. /*
  2. * Copyright (c) 2025, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package dsl
  20. import (
  21. "bytes"
  22. "context"
  23. "crypto/tls"
  24. "crypto/x509"
  25. "encoding/json"
  26. "fmt"
  27. "io"
  28. "net/http"
  29. "os"
  30. "sync"
  31. "time"
  32. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  33. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  34. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  35. lrucache "github.com/cognusion/go-cache-lru"
  36. "github.com/fxamacker/cbor/v2"
  37. )
  38. const (
  39. defaultMaxHttpConns = 100
  40. defaultMaxHttpIdleConns = 10
  41. defaultHttpIdleConnTimeout = 120 * time.Second
  42. defaultRequestTimeout = 30 * time.Second
  43. defaultRequestRetryCount = 1
  44. defaultServerEntryCacheTTL = 24 * time.Hour
  45. defaultServerEntryCacheMaxSize = 250000
  46. defaultOSLFileSpecCacheTTL = 24 * time.Hour
  47. defaultOSLFileSpecCacheMaxSize = 250000
  48. )
// RelayConfig specifies the configuration for a Relay.
//
// The CACertificates and HostCertificate/Key parameters are used for mutually
// authenticated TLS between the Relay and the DSL backend. The HostID value
// is sent to the DSL backend for logging, and should be populated with the
// HostID in psiphond.config.
type RelayConfig struct {
	// Logger receives the Relay's warning logs and the metrics logged for
	// requests served from the local cache.
	Logger common.Logger

	// CACertificatesFilename names a PEM file of CA certificates used as the
	// RootCAs for verifying the DSL backend. HostCertificateFilename and
	// HostKeyFilename name the PEM certificate/key pair presented to the
	// backend as the client certificate. All three files are loaded by
	// Relay.Reload.
	CACertificatesFilename  string
	HostCertificateFilename string
	HostKeyFilename         string

	// GetServiceAddress returns the DSL backend address for a client, given
	// that client's GeoIP data. The address is used to form the request URL
	// "https://<address><path>" and is looked up on every request attempt.
	GetServiceAddress func(
		clientGeoIPData common.GeoIPData) (string, error)

	// HostID is sent to the DSL backend in the PsiphonHostIDHeader header.
	HostID string

	// APIParameterValidator is a callback that validates base API metrics.
	APIParameterValidator common.APIParameterValidator

	// APIParameterLogFieldFormatter is a callback that formats base API
	// metrics into log fields.
	APIParameterLogFieldFormatter common.APIParameterLogFieldFormatter
}
// Relay is an intermediary between a DSL client and the DSL backend which
// provides circumvention and blocking resistance. Relays include in-proxy
// brokers, and Psiphon servers. See the "Relay API layer" comment section in
// api.go for more details.
//
// The Relay maintains a pool of persistent HTTP connections for making
// requests.
//
// The Relay supports transparent caching of server entries, where
// GetServerEntriesRequest requests may be fully or partially served out of
// the local cache.
type Relay struct {
	// config is the configuration passed to NewRelay.
	config *RelayConfig

	// These ReloadableFiles are used by Reload only to detect changes to the
	// TLS files named in config; the file contents are read separately.
	caCertificatesFile  common.ReloadableFile
	hostCertificateFile common.ReloadableFile
	hostKeyFile         common.ReloadableFile

	// mutex guards tlsSessionCache, tlsConfig, httpClient, and the request
	// and cache parameter fields below. These may be replaced by Reload,
	// SetRequestParameters, and SetCacheParameters while requests are being
	// handled.
	mutex sync.Mutex

	// tlsSessionCache and tlsConfig are recreated by Reload; tlsConfig is the
	// TLS client configuration installed in httpClient's transport.
	tlsSessionCache tls.ClientSessionCache
	tlsConfig       *tls.Config

	// httpClient is replaced, not mutated, when the TLS configuration or the
	// connection pool parameters change.
	httpClient *http.Client

	// requestTimeout bounds each upstream attempt (no bound when not
	// positive); requestRetryCount is the number of retries after a failed
	// first attempt.
	requestTimeout    time.Duration
	requestRetryCount int

	// The caches are nil when their TTL is not positive, which disables
	// caching for that request type.
	serverEntryCache        *lrucache.Cache
	serverEntryCacheTTL     time.Duration
	serverEntryCacheMaxSize int
	oslFileSpecCache        *lrucache.Cache
	oslFileSpecCacheTTL     time.Duration
	oslFileSpecCacheMaxSize int

	// Pools of response slices reused when serving requests from the caches.
	getServerEntriesBufferPool sync.Pool
	getOSLFileSpecsBufferPool  sync.Pool
}
  99. // NewRelay creates a new Relay.
  100. func NewRelay(config *RelayConfig) (*Relay, error) {
  101. relay := &Relay{
  102. config: config,
  103. caCertificatesFile: common.NewReloadableFile(config.CACertificatesFilename, false, nil),
  104. hostCertificateFile: common.NewReloadableFile(config.HostCertificateFilename, false, nil),
  105. hostKeyFile: common.NewReloadableFile(config.HostKeyFilename, false, nil),
  106. tlsSessionCache: tls.NewLRUClientSessionCache(0),
  107. }
  108. _, err := relay.Reload()
  109. if err != nil {
  110. return nil, errors.Trace(err)
  111. }
  112. relay.SetRequestParameters(
  113. defaultMaxHttpConns,
  114. defaultMaxHttpIdleConns,
  115. defaultHttpIdleConnTimeout,
  116. defaultRequestTimeout,
  117. defaultRequestRetryCount)
  118. relay.SetCacheParameters(
  119. defaultServerEntryCacheTTL,
  120. defaultServerEntryCacheMaxSize,
  121. defaultOSLFileSpecCacheTTL,
  122. defaultOSLFileSpecCacheMaxSize)
  123. relay.getServerEntriesBufferPool.New = func() any { return []*SourcedServerEntry{} }
  124. relay.getOSLFileSpecsBufferPool.New = func() any { return []OSLFileSpec{} }
  125. return relay, nil
  126. }
  127. // Reload reloads the TLS configuration when the file contents have changed.
  128. //
  129. // Reload implements the common.Reloader interface.
  130. func (r *Relay) Reload() (bool, error) {
  131. // The common.ReloadableFile.reloadAction callback not used; instead,
  132. // ReloadableFiles are used to check for changed file contents. When any
  133. // file has changed, all TLS configuration files are reloaded and the TLS
  134. // configuration is reinitialized.
  135. reloadedAny := false
  136. reloaded, err := r.caCertificatesFile.Reload()
  137. if err != nil {
  138. return false, errors.Trace(err)
  139. }
  140. reloadedAny = reloadedAny || reloaded
  141. reloaded, err = r.hostCertificateFile.Reload()
  142. if err != nil {
  143. return false, errors.Trace(err)
  144. }
  145. reloadedAny = reloadedAny || reloaded
  146. reloaded, err = r.hostKeyFile.Reload()
  147. if err != nil {
  148. return false, errors.Trace(err)
  149. }
  150. reloadedAny = reloadedAny || reloaded
  151. if !reloadedAny {
  152. return false, nil
  153. }
  154. caCertsPEM, err := os.ReadFile(r.config.CACertificatesFilename)
  155. if err != nil {
  156. return false, errors.Trace(err)
  157. }
  158. caCertificates := x509.NewCertPool()
  159. if !caCertificates.AppendCertsFromPEM(caCertsPEM) {
  160. return false, errors.TraceNew("AppendCertsFromPEM failed")
  161. }
  162. hostCertificate, err := tls.LoadX509KeyPair(
  163. r.config.HostCertificateFilename,
  164. r.config.HostKeyFilename)
  165. if err != nil {
  166. return false, errors.Trace(err)
  167. }
  168. r.mutex.Lock()
  169. defer r.mutex.Unlock()
  170. r.tlsSessionCache = tls.NewLRUClientSessionCache(0)
  171. r.tlsConfig = &tls.Config{
  172. RootCAs: caCertificates,
  173. Certificates: []tls.Certificate{hostCertificate},
  174. ClientSessionCache: r.tlsSessionCache,
  175. }
  176. if r.httpClient != nil {
  177. // Replace the http.Client if it exists. See the comment in
  178. // SetRequestParameters regarding in-flight requests and idle timeout
  179. // limitations.
  180. httpTransport := r.httpClient.Transport.(*http.Transport)
  181. r.httpClient = &http.Client{
  182. Transport: &http.Transport{
  183. TLSClientConfig: r.tlsConfig,
  184. ForceAttemptHTTP2: true,
  185. MaxConnsPerHost: httpTransport.MaxConnsPerHost,
  186. MaxIdleConns: httpTransport.MaxIdleConns,
  187. MaxIdleConnsPerHost: httpTransport.MaxIdleConnsPerHost,
  188. IdleConnTimeout: httpTransport.IdleConnTimeout,
  189. },
  190. }
  191. }
  192. return true, nil
  193. }
  194. // WillReload implements the common.Reloader interface.
  195. func (r *Relay) WillReload() bool {
  196. return true
  197. }
  198. // ReloadLogDescription implements the common.Reloader interface.
  199. func (r *Relay) ReloadLogDescription() string {
  200. return "DSL Relay TLS configuration"
  201. }
  202. // SetRequestParameters updates the HTTP request parameters used for upstream
  203. // requests.
  204. func (r *Relay) SetRequestParameters(
  205. maxHttpConns int,
  206. maxHttpIdleConns int,
  207. httpIdleConnTimeout time.Duration,
  208. requestTimeout time.Duration,
  209. requestRetryCount int) {
  210. r.mutex.Lock()
  211. defer r.mutex.Unlock()
  212. r.requestTimeout = requestTimeout
  213. r.requestRetryCount = requestRetryCount
  214. // The http.Client client is replaced when the net/http configuration has
  215. // changed. Any in-flight requests using the previous http.Client will
  216. // continue until complete and eventually the previous http.Client will
  217. // be garbage collected.
  218. //
  219. // TODO: don't retain the previous http.Client for as long as
  220. // http.Transport.IdleConnTimeout.
  221. var httpTransport *http.Transport
  222. if r.httpClient != nil {
  223. httpTransport = r.httpClient.Transport.(*http.Transport)
  224. }
  225. if r.httpClient == nil ||
  226. httpTransport.MaxConnsPerHost != maxHttpConns ||
  227. httpTransport.MaxIdleConns != maxHttpIdleConns ||
  228. httpTransport.IdleConnTimeout != httpIdleConnTimeout {
  229. r.httpClient = &http.Client{
  230. Transport: &http.Transport{
  231. TLSClientConfig: r.tlsConfig,
  232. ForceAttemptHTTP2: true,
  233. MaxConnsPerHost: maxHttpConns,
  234. MaxIdleConns: maxHttpIdleConns,
  235. MaxIdleConnsPerHost: maxHttpIdleConns,
  236. IdleConnTimeout: httpIdleConnTimeout,
  237. },
  238. }
  239. }
  240. }
  241. // SetCacheParameters updates the parameters used for transparent server
  242. // entry caching. When the parameters change, any existing cache is flushed
  243. // and replaced.
  244. func (r *Relay) SetCacheParameters(
  245. serverEntryCacheTTL time.Duration,
  246. serverEntryCacheMaxSize int,
  247. oslFileSpecCacheTTL time.Duration,
  248. oslFileSpecCacheMaxSize int) {
  249. r.mutex.Lock()
  250. defer r.mutex.Unlock()
  251. if r.serverEntryCache == nil ||
  252. r.serverEntryCacheTTL != serverEntryCacheTTL ||
  253. r.serverEntryCacheMaxSize != serverEntryCacheMaxSize {
  254. if r.serverEntryCache != nil {
  255. r.serverEntryCache.Flush()
  256. }
  257. r.serverEntryCacheTTL = serverEntryCacheTTL
  258. r.serverEntryCacheMaxSize = serverEntryCacheMaxSize
  259. if r.serverEntryCacheTTL > 0 {
  260. r.serverEntryCache = lrucache.NewWithLRU(
  261. r.serverEntryCacheTTL,
  262. 1*time.Minute,
  263. r.serverEntryCacheMaxSize)
  264. } else {
  265. r.serverEntryCache = nil
  266. }
  267. }
  268. if r.oslFileSpecCache == nil ||
  269. r.oslFileSpecCacheTTL != oslFileSpecCacheTTL ||
  270. r.oslFileSpecCacheMaxSize != oslFileSpecCacheMaxSize {
  271. if r.oslFileSpecCache != nil {
  272. r.oslFileSpecCache.Flush()
  273. }
  274. r.oslFileSpecCacheTTL = oslFileSpecCacheTTL
  275. r.oslFileSpecCacheMaxSize = oslFileSpecCacheMaxSize
  276. if r.oslFileSpecCacheTTL > 0 {
  277. r.oslFileSpecCache = lrucache.NewWithLRU(
  278. r.oslFileSpecCacheTTL,
  279. 1*time.Minute,
  280. r.oslFileSpecCacheMaxSize)
  281. } else {
  282. r.oslFileSpecCache = nil
  283. }
  284. }
  285. }
  286. // HandleRequest relays a DSL request.
  287. //
  288. // If an extendTimeout callback is specified, it will be called with the
  289. // expected maximum request timeout, including retries; this callback may be
  290. // used to customize the response timeout for a transport handler.
  291. //
  292. // Set isClientTunneled when the relay uses a connected Psiphon tunnel.
  293. //
  294. // In the case of an error, the caller must log the error and send
  295. // dsl.GenericErrorResponse to the client. This generic error response
  296. // ensures that the client receives a DSL response and doesn't consider the
  297. // DSL FetcherRoundTripper to have failed.
  298. func (r *Relay) HandleRequest(
  299. ctx context.Context,
  300. extendTimeout func(time.Duration),
  301. clientIP string,
  302. clientGeoIPData common.GeoIPData,
  303. isClientTunneled bool,
  304. cborRelayedRequest []byte) ([]byte, error) {
  305. r.mutex.Lock()
  306. httpClient := r.httpClient
  307. requestTimeout := r.requestTimeout
  308. requestRetryCount := r.requestRetryCount
  309. r.mutex.Unlock()
  310. if extendTimeout != nil {
  311. extendTimeout(requestTimeout * time.Duration(requestRetryCount))
  312. }
  313. if httpClient == nil {
  314. return nil, errors.TraceNew("missing http client")
  315. }
  316. if len(cborRelayedRequest) > MaxRelayPayloadSize {
  317. return nil, errors.Tracef(
  318. "request size %d exceeds limit %d",
  319. len(cborRelayedRequest), MaxRelayPayloadSize)
  320. }
  321. var relayedRequest *RelayedRequest
  322. err := cbor.Unmarshal(cborRelayedRequest, &relayedRequest)
  323. if err != nil {
  324. return nil, errors.Trace(err)
  325. }
  326. if relayedRequest.Version != requestVersion {
  327. return nil, errors.Tracef(
  328. "unexpected request version %d", relayedRequest.Version)
  329. }
  330. path, ok := requestTypeToHTTPPath[relayedRequest.RequestType]
  331. if !ok {
  332. return nil, errors.Tracef(
  333. "unknown request type %d", relayedRequest.RequestType)
  334. }
  335. // Transparent caching:
  336. //
  337. // For requestTypeGetServerEntries, peek at the RelayedResponse.Response
  338. // and extract server entries and add to the local cache, keyed by server
  339. // entry tag.
  340. //
  341. // Peek at RelayedRequest.Request, and if all requested server entries are
  342. // in the cache, serve the request entirely from the local cache.
  343. //
  344. // The backend DSL may enforce a limited time interval in which certain
  345. // server entries can be discovered. This cache doesn't bypass this,
  346. // since DiscoveryServerEntries isn't cached and always passed through to
  347. // the DSL backend. Clients must discover the large, random server entry
  348. // tags via DiscoveryServerEntries within the designated time interval;
  349. // then clients may download the server entries via GetServerEntries at
  350. // any time, and this may be cached.
  351. //
  352. // Limitation: this cache ignores server entry version and may serve a
  353. // version that's older that the latest within the cache TTL.
  354. //
  355. // - Server entry version changes are assumed to be rare.
  356. //
  357. // - The cache will be updated with a new version as soon as
  358. // cacheGetServerEntriesResponse sees it.
  359. //
  360. // - Use a reasonable TTL such as 24h; cache entry TTLs aren't extended on
  361. // hits, so any old version will eventually be removed.
  362. //
  363. // - A more complicated scheme is possible: also peek at
  364. // DiscoverServerEntriesResponses and, for each tag/version pair, if
  365. // the tag is in the cache and the cached entry is an old version,
  366. // delete from the cache. This would require unpacking each server entry.
  367. //
  368. // Similarly, for requestTypeGetOSLFileSpecs, peek at the
  369. // RelayedResponse.Response and extract OSL file specs and add to the
  370. // local cache, keyed by OSL ID; and peek at RelayedRequest.Request, and
  371. // if all requested OSL file specs are in the cache, serve the request
  372. // entirely from the local cache.
  373. var response []byte
  374. cachedResponse := false
  375. var serveCachedResponse func([]byte, common.GeoIPData) ([]byte, error)
  376. var updateCache func([]byte, []byte) error
  377. switch relayedRequest.RequestType {
  378. case requestTypeGetServerEntries:
  379. serveCachedResponse = r.getCachedGetServerEntriesResponse
  380. updateCache = r.cacheGetServerEntriesResponse
  381. case requestTypeGetOSLFileSpecs:
  382. serveCachedResponse = r.getCachedGetOSLFileSpecsResponse
  383. updateCache = r.cacheGetOSLFileSpecsResponse
  384. }
  385. if serveCachedResponse != nil {
  386. var err error
  387. response, err = serveCachedResponse(
  388. relayedRequest.Request, clientGeoIPData)
  389. if err != nil {
  390. r.config.Logger.WithTraceFields(common.LogFields{
  391. "error": err.Error(),
  392. }).Warning("DSL: serve cached response failed")
  393. // Proceed with relaying request, even if the failure was due to
  394. // an error in DecodePackedAPIParameters or APIParameterValidator.
  395. // This allows the DSL backend to make the authoritative decision
  396. // and also log all failure cases.
  397. }
  398. cachedResponse = err == nil && response != nil
  399. }
  400. for i := 0; !cachedResponse; i++ {
  401. requestCtx := ctx
  402. if requestTimeout > 0 {
  403. var requestCancelFunc context.CancelFunc
  404. requestCtx, requestCancelFunc = context.WithTimeout(ctx, requestTimeout)
  405. defer requestCancelFunc()
  406. }
  407. serviceAddress, err := r.config.GetServiceAddress(clientGeoIPData)
  408. if err != nil {
  409. return nil, errors.Trace(err)
  410. }
  411. url := fmt.Sprintf("https://%s%s", serviceAddress, path)
  412. httpRequest, err := http.NewRequestWithContext(
  413. requestCtx, "POST", url, bytes.NewBuffer(relayedRequest.Request))
  414. if err != nil {
  415. return nil, errors.Trace(err)
  416. }
  417. // Attach the client IP and GeoIPData. The raw IP may be used, by the
  418. // DSL backend, in server entry selection logic; the GeoIP data is
  419. // for stats, and may also be used in server entry selection logic.
  420. // Sending preresolved GeoIP data saves the DSL backend from needing
  421. // its own GeoIP resolver, and ensures, for a given client a
  422. // consistent GeoIP view between the Psiphon server and the DSL backend.
  423. jsonGeoIPData, err := json.Marshal(clientGeoIPData)
  424. if err != nil {
  425. return nil, errors.Trace(err)
  426. }
  427. httpRequest.Header.Set(PsiphonClientIPHeader, clientIP)
  428. httpRequest.Header.Set(PsiphonClientGeoIPDataHeader, string(jsonGeoIPData))
  429. if isClientTunneled {
  430. httpRequest.Header.Set(PsiphonClientTunneledHeader, "true")
  431. } else {
  432. httpRequest.Header.Set(PsiphonClientTunneledHeader, "false")
  433. }
  434. httpRequest.Header.Set(PsiphonHostIDHeader, r.config.HostID)
  435. startTime := time.Now()
  436. httpResponse, err := httpClient.Do(httpRequest)
  437. duration := time.Since(startTime)
  438. if err == nil && httpResponse.StatusCode != http.StatusOK {
  439. httpResponse.Body.Close()
  440. err = errors.Tracef("unexpected response code: %d", httpResponse.StatusCode)
  441. }
  442. if err == nil {
  443. response, err = io.ReadAll(httpResponse.Body)
  444. httpResponse.Body.Close()
  445. }
  446. if err == nil {
  447. if updateCache != nil {
  448. err := updateCache(
  449. relayedRequest.Request, response)
  450. if err != nil {
  451. r.config.Logger.WithTraceFields(common.LogFields{
  452. "error": err.Error(),
  453. }).Warning("DSL: update cache failed")
  454. // Proceed with relaying response
  455. }
  456. }
  457. break
  458. }
  459. r.config.Logger.WithTraceFields(common.LogFields{
  460. "duration": duration.String(),
  461. "error": err.Error(),
  462. }).Warning("DSL: service request attempt failed")
  463. // Retry on network errors.
  464. if i < requestRetryCount && ctx.Err() == nil {
  465. continue
  466. }
  467. return nil, errors.Tracef("all attempts failed")
  468. }
  469. // Compress GetServerEntriesResponse responses.
  470. //
  471. // The CBOR-encoded SourcedServerEntry/protocol.PackedServerEntryFields
  472. // items in GetServerEntriesResponse benefit from compression due to
  473. // repeating server entry values. Only this response is compressed, as
  474. // other responses almost completely consist of non-repeating random
  475. // values.
  476. //
  477. // Compression is only added at the relay->client hop, to avoid additonal
  478. // CPU load on the DSL backend, and avoid relays having to always
  479. // decompress the backend response in cacheGetServerEntriesResponse.
  480. compression := common.CompressionNone
  481. if relayedRequest.RequestType == requestTypeGetServerEntries {
  482. compression = common.CompressionZlib
  483. }
  484. compressedResponse, err := common.Compress(compression, response)
  485. if err != nil {
  486. return nil, errors.Trace(err)
  487. }
  488. cborRelayedResponse, err := protocol.CBOREncoding.Marshal(
  489. &RelayedResponse{
  490. Compression: compression,
  491. Response: compressedResponse,
  492. })
  493. if err != nil {
  494. return nil, errors.Trace(err)
  495. }
  496. if len(cborRelayedResponse) > MaxRelayPayloadSize {
  497. return nil, errors.Tracef(
  498. "response size %d exceeds limit %d",
  499. len(cborRelayedResponse), MaxRelayPayloadSize)
  500. }
  501. return cborRelayedResponse, nil
  502. }
  503. func (r *Relay) cacheGetServerEntriesResponse(
  504. cborRequest []byte,
  505. cborResponse []byte) error {
  506. r.mutex.Lock()
  507. cache := r.serverEntryCache
  508. r.mutex.Unlock()
  509. if cache == nil {
  510. // Caching is disabled
  511. return nil
  512. }
  513. var request GetServerEntriesRequest
  514. err := cbor.Unmarshal(cborRequest, &request)
  515. if err != nil {
  516. return errors.Trace(err)
  517. }
  518. var response GetServerEntriesResponse
  519. err = cbor.Unmarshal(cborResponse, &response)
  520. if err != nil {
  521. return errors.Trace(err)
  522. }
  523. if len(request.ServerEntryTags) != len(response.SourcedServerEntries) {
  524. return errors.TraceNew("unexpected entry count mismatch")
  525. }
  526. for i, serverEntryTag := range request.ServerEntryTags {
  527. if response.SourcedServerEntries[i] != nil {
  528. // This will update any existing cached copy of the server entry for
  529. // this tag, in case the server entry version is new. This also
  530. // extends the cache TTL, since the server entry is fresh.
  531. cache.Set(
  532. string(serverEntryTag),
  533. response.SourcedServerEntries[i],
  534. lrucache.DefaultExpiration)
  535. } else {
  536. // In this case, the DSL backend is indicating that the server
  537. // entry for the requested tag no longer exists, perhaps due to
  538. // server pruning since the DiscoverServerEntries request. This
  539. // is an edge case since DiscoverServerEntries won't return
  540. // invalid tags and so the "nil" value/state isn't cached.
  541. cache.Delete(string(serverEntryTag))
  542. }
  543. }
  544. return nil
  545. }
  546. func (r *Relay) getCachedGetServerEntriesResponse(
  547. cborRequest []byte,
  548. clientGeoIPData common.GeoIPData) ([]byte, error) {
  549. r.mutex.Lock()
  550. cache := r.serverEntryCache
  551. r.mutex.Unlock()
  552. if cache == nil {
  553. // Caching is disabled
  554. return nil, nil
  555. }
  556. var request GetServerEntriesRequest
  557. err := cbor.Unmarshal(cborRequest, &request)
  558. if err != nil {
  559. return nil, errors.Trace(err)
  560. }
  561. // Since we anticipate that most server entries will be cached, allocate
  562. // response slices optimistically. Use buffer pools to mitigate GC churn.
  563. //
  564. // TODO: check for sufficient cache entries before allocating these
  565. // response slices? Would doubling the cache lookups use less resources
  566. // than unused allocations?
  567. buffer := r.getServerEntriesBufferPool.Get().([]*SourcedServerEntry)
  568. size := len(request.ServerEntryTags)
  569. if cap(buffer) < size {
  570. buffer = make([]*SourcedServerEntry, size)
  571. } else {
  572. buffer = buffer[:size]
  573. }
  574. defer func() {
  575. clear(buffer)
  576. r.getServerEntriesBufferPool.Put(buffer)
  577. }()
  578. var response GetServerEntriesResponse
  579. response.SourcedServerEntries = buffer
  580. for i, serverEntryTag := range request.ServerEntryTags {
  581. cacheEntry, ok := cache.Get(string(serverEntryTag))
  582. if !ok {
  583. // The request can't be served from the cache, as some server
  584. // entry tags aren't present. Fall back to a full request to the
  585. // DSL backend.
  586. //
  587. // As a potential future enhancement, consider partially serving
  588. // from the cache, after making a DSL request for just the
  589. // unknown server entries?
  590. return nil, nil
  591. }
  592. // The cached entry's TTL is not extended on a hit.
  593. response.SourcedServerEntries[i] = cacheEntry.(*SourcedServerEntry)
  594. }
  595. cborResponse, err := protocol.CBOREncoding.Marshal(&response)
  596. if err != nil {
  597. return nil, errors.Trace(err)
  598. }
  599. // Log the request event. Since this request is served from the relay
  600. // cache, the DSL backend will not see the request and log the event
  601. // itself. This log should match the DSL log format and can be shipped to
  602. // the same log aggregator.
  603. baseParams, err := protocol.DecodePackedAPIParameters(request.BaseAPIParameters)
  604. if err != nil {
  605. return nil, errors.Trace(err)
  606. }
  607. err = r.config.APIParameterValidator(baseParams)
  608. if err != nil {
  609. return nil, errors.Trace(err)
  610. }
  611. logFields := r.config.APIParameterLogFieldFormatter("", clientGeoIPData, baseParams)
  612. logFields["server_entry_tag_count"] = len(response.SourcedServerEntries)
  613. r.config.Logger.LogMetric("dsl_relay_get_server_entries", logFields)
  614. return cborResponse, nil
  615. }
  616. func (r *Relay) cacheGetOSLFileSpecsResponse(
  617. cborRequest []byte,
  618. cborResponse []byte) error {
  619. r.mutex.Lock()
  620. cache := r.oslFileSpecCache
  621. r.mutex.Unlock()
  622. if cache == nil {
  623. // Caching is disabled
  624. return nil
  625. }
  626. var request GetOSLFileSpecsRequest
  627. err := cbor.Unmarshal(cborRequest, &request)
  628. if err != nil {
  629. return errors.Trace(err)
  630. }
  631. var response GetOSLFileSpecsResponse
  632. err = cbor.Unmarshal(cborResponse, &response)
  633. if err != nil {
  634. return errors.Trace(err)
  635. }
  636. if len(request.OSLIDs) != len(response.OSLFileSpecs) {
  637. return errors.TraceNew("unexpected spec count mismatch")
  638. }
  639. for i, oslID := range request.OSLIDs {
  640. if response.OSLFileSpecs[i] != nil {
  641. // This will extend the cache TTL for existing entries.
  642. cache.Set(
  643. string(oslID),
  644. response.OSLFileSpecs[i],
  645. lrucache.DefaultExpiration)
  646. } else {
  647. // In this case, the DSL backend is indicating that the OSL file
  648. // spec is not longer active or available for distribution.
  649. cache.Delete(string(oslID))
  650. }
  651. }
  652. return nil
  653. }
  654. func (r *Relay) getCachedGetOSLFileSpecsResponse(
  655. cborRequest []byte,
  656. clientGeoIPData common.GeoIPData) ([]byte, error) {
  657. r.mutex.Lock()
  658. cache := r.oslFileSpecCache
  659. r.mutex.Unlock()
  660. if cache == nil {
  661. // Caching is disabled
  662. return nil, nil
  663. }
  664. var request GetOSLFileSpecsRequest
  665. err := cbor.Unmarshal(cborRequest, &request)
  666. if err != nil {
  667. return nil, errors.Trace(err)
  668. }
  669. // This logic mirrors getCachedGetServerEntriesResponse. See the comments
  670. // in that function.
  671. buffer := r.getOSLFileSpecsBufferPool.Get().([]OSLFileSpec)
  672. size := len(request.OSLIDs)
  673. if cap(buffer) < size {
  674. buffer = make([]OSLFileSpec, size)
  675. } else {
  676. buffer = buffer[:size]
  677. }
  678. defer func() {
  679. clear(buffer)
  680. r.getOSLFileSpecsBufferPool.Put(buffer)
  681. }()
  682. var response GetOSLFileSpecsResponse
  683. response.OSLFileSpecs = buffer
  684. for i, oslID := range request.OSLIDs {
  685. cacheEntry, ok := cache.Get(string(oslID))
  686. if !ok {
  687. return nil, nil
  688. }
  689. response.OSLFileSpecs[i] = cacheEntry.(OSLFileSpec)
  690. }
  691. cborResponse, err := protocol.CBOREncoding.Marshal(&response)
  692. if err != nil {
  693. return nil, errors.Trace(err)
  694. }
  695. baseParams, err := protocol.DecodePackedAPIParameters(request.BaseAPIParameters)
  696. if err != nil {
  697. return nil, errors.Trace(err)
  698. }
  699. err = r.config.APIParameterValidator(baseParams)
  700. if err != nil {
  701. return nil, errors.Trace(err)
  702. }
  703. logFields := r.config.APIParameterLogFieldFormatter("", clientGeoIPData, baseParams)
  704. logFields["osl_id_count"] = len(response.OSLFileSpecs)
  705. r.config.Logger.LogMetric("dsl_relay_get_osl_file_specs", logFields)
  706. return cborResponse, nil
  707. }
  708. var relayGenericErrorResponse []byte
  709. func init() {
  710. // Pre-marshal a generic, non-revealing error code to return on any
  711. // upstream failure.
  712. cborErrorResponse, err := protocol.CBOREncoding.Marshal(
  713. &RelayedResponse{
  714. Error: 1,
  715. })
  716. if err != nil {
  717. panic(err.Error())
  718. }
  719. relayGenericErrorResponse = cborErrorResponse
  720. }
  721. func GetRelayGenericErrorResponse() []byte {
  722. return relayGenericErrorResponse
  723. }