Przeglądaj źródła

Implement relay server entry caching

Rod Hynes 6 miesięcy temu
rodzic
commit
a4abb2ebf5
2 zmienionych plików z 326 dodań i 55 usunięć
  1. 80 4
      psiphon/common/dsl/dsl_test.go
  2. 246 51
      psiphon/common/dsl/relay.go

+ 80 - 4
psiphon/common/dsl/dsl_test.go

@@ -61,6 +61,7 @@ type testConfig struct {
 	repeatBeforeTTL    bool
 	isConnected        bool
 	expectFailure      bool
+	cacheServerEntries bool
 }
 
 func TestDSLs(t *testing.T) {
@@ -108,6 +109,13 @@ func TestDSLs(t *testing.T) {
 
 			isConnected: true,
 		},
+		{
+			name: "cache server entries",
+
+			interruptDownloads: true,
+			enableRetries:      true,
+			cacheServerEntries: true,
+		},
 	}
 
 	for _, testConfig := range tests {
@@ -161,12 +169,33 @@ func testDSLs(testConfig *testConfig) error {
 
 	// Initialize relay
 
+	expectValidMetric := false
+	metricsValidator := func(metric string, fields common.LogFields) bool { return false }
+	if testConfig.cacheServerEntries {
+		expectValidMetric = true
+		metricsValidator = func(metric string, fields common.LogFields) bool {
+			return metric == "dsl" &&
+				fields["dsl_event"].(string) == "get-server-entries"
+		}
+	}
+
+	relayLogger := newTestLoggerWithMetricValidator("relay", metricsValidator)
+
 	relayConfig := &RelayConfig{
-		Logger:                      newTestLoggerWithComponent("relay"),
+		Logger:                      relayLogger,
 		CACertificates:              []*x509.Certificate{tlsConfig.CACertificate},
 		HostCertificate:             tlsConfig.relayCertificate,
 		DynamicServerListServiceURL: backend.getAddress(),
 		HostID:                      testHostID,
+
+		APIParameterValidator: func(params common.APIParameters) error { return nil },
+
+		APIParameterLogFieldFormatter: func(
+			_ string, _ common.GeoIPData, params common.APIParameters) common.LogFields {
+			logFields := common.LogFields{}
+			logFields.Add(common.LogFields(params))
+			return logFields
+		},
 	}
 
 	relay, err := NewRelay(relayConfig)
@@ -174,6 +203,10 @@ func testDSLs(testConfig *testConfig) error {
 		return errors.Trace(err)
 	}
 
+	if !testConfig.cacheServerEntries {
+		relay.SetCacheParameters(0, 0)
+	}
+
 	// Initialize client fetcher
 
 	// Set transfer targets that will exercise various scenarios, including
@@ -375,6 +408,11 @@ func testDSLs(testConfig *testConfig) error {
 		}
 	}
 
+	err = relayLogger.CheckMetrics(expectValidMetric)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
 	return nil
 }
 
@@ -1107,8 +1145,11 @@ func initializeTLSConfiguration() (*tlsConfig, error) {
 }
 
 type testLogger struct {
-	component     string
-	logLevelDebug int32
+	component        string
+	metricValidator  func(string, common.LogFields) bool
+	hasValidMetric   int32
+	hasInvalidMetric int32
+	logLevelDebug    int32
 }
 
 func newTestLogger() *testLogger {
@@ -1117,13 +1158,26 @@ func newTestLogger() *testLogger {
 	}
 }
 
-func newTestLoggerWithComponent(component string) *testLogger {
+func newTestLoggerWithComponent(
+	component string) *testLogger {
+
 	return &testLogger{
 		component:     component,
 		logLevelDebug: 0,
 	}
 }
 
+func newTestLoggerWithMetricValidator(
+	component string,
+	metricValidator func(string, common.LogFields) bool) *testLogger {
+
+	return &testLogger{
+		component:       component,
+		metricValidator: metricValidator,
+		logLevelDebug:   0,
+	}
+}
+
 func (logger *testLogger) WithTrace() common.LogTrace {
 	return &testLoggerTrace{
 		logger: logger,
@@ -1140,6 +1194,17 @@ func (logger *testLogger) WithTraceFields(fields common.LogFields) common.LogTra
 }
 
 func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
+
+	if logger.metricValidator != nil {
+		if logger.metricValidator(metric, fields) {
+			atomic.StoreInt32(&logger.hasValidMetric, 1)
+		} else {
+			atomic.StoreInt32(&logger.hasInvalidMetric, 1)
+		}
+		// Don't print log.
+		return
+	}
+
 	jsonFields, _ := json.Marshal(fields)
 	var component string
 	if len(logger.component) > 0 {
@@ -1153,6 +1218,17 @@ func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
 		string(jsonFields))
 }
 
+func (logger *testLogger) CheckMetrics(expectValidMetric bool) error {
+
+	if expectValidMetric && atomic.LoadInt32(&logger.hasValidMetric) != 1 {
+		return errors.TraceNew("missing valid metric")
+	}
+	if atomic.LoadInt32(&logger.hasInvalidMetric) == 1 {
+		return errors.TraceNew("has invalid metric")
+	}
+	return nil
+}
+
 func (logger *testLogger) IsLogLevelDebug() bool {
 	return atomic.LoadInt32(&logger.logLevelDebug) == 1
 }

+ 246 - 51
psiphon/common/dsl/relay.go

@@ -24,6 +24,7 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -58,12 +59,19 @@ const (
 type RelayConfig struct {
 	Logger common.Logger
 
-	CACertificates  []*x509.Certificate
+	CACertificates []*x509.Certificate
+
 	HostCertificate *tls.Certificate
 
 	DynamicServerListServiceURL string
 
 	HostID string
+
+	// APIParameterValidator is a callback that validates base API metrics.
+	APIParameterValidator common.APIParameterValidator
+
+	// APIParameterLogFieldFormatter is a callback that formats base API metrics.
+	APIParameterLogFieldFormatter common.APIParameterLogFieldFormatter
 }
 
 // Relay is an intermediary between a DSL client and the DSL backend which
@@ -82,13 +90,13 @@ type Relay struct {
 	tlsConfig     *tls.Config
 	errorResponse []byte
 
-	mutex                      sync.Mutex
-	httpClient                 *http.Client
-	requestTimeout             time.Duration
-	requestRetryCount          int
-	serverEntryCache           *lrucache.Cache
-	serverEntryCacheDefaultTTL time.Duration
-	serverEntryCacheMaxSize    int
+	mutex                   sync.Mutex
+	httpClient              *http.Client
+	requestTimeout          time.Duration
+	requestRetryCount       int
+	serverEntryCache        *lrucache.Cache
+	serverEntryCacheTTL     time.Duration
+	serverEntryCacheMaxSize int
 }
 
 // NewRelay creates a new Relay.
@@ -184,27 +192,34 @@ func (r *Relay) SetRequestParameters(
 // entry caching. When the parameters change, any existing cache is flushed
 // and replaced.
 func (r *Relay) SetCacheParameters(
-	defaultTTL time.Duration,
+	TTL time.Duration,
 	maxSize int) {
 
 	r.mutex.Lock()
 	defer r.mutex.Unlock()
 
 	if r.serverEntryCache == nil ||
-		r.serverEntryCacheDefaultTTL != defaultTTL ||
+		r.serverEntryCacheTTL != TTL ||
 		r.serverEntryCacheMaxSize != maxSize {
 
 		if r.serverEntryCache != nil {
 			r.serverEntryCache.Flush()
 		}
 
-		r.serverEntryCacheDefaultTTL = defaultTTL
+		r.serverEntryCacheTTL = TTL
 		r.serverEntryCacheMaxSize = maxSize
 
-		r.serverEntryCache = lrucache.NewWithLRU(
-			r.serverEntryCacheDefaultTTL,
-			1*time.Minute,
-			r.serverEntryCacheMaxSize)
+		if r.serverEntryCacheTTL > 0 {
+
+			r.serverEntryCache = lrucache.NewWithLRU(
+				r.serverEntryCacheTTL,
+				1*time.Minute,
+				r.serverEntryCacheMaxSize)
+
+		} else {
+
+			r.serverEntryCache = nil
+		}
 	}
 }
 
@@ -273,30 +288,65 @@ func (r *Relay) handleRequest(
 			"unknown request type %d", relayedRequest.RequestType)
 	}
 
-	// TODO: implement transparent server entry caching.
+	// Transparent caching:
 	//
 	// For requestTypeGetServerEntries, peek at the RelayedResponse.Response
 	// and extract server entries and add to the local cache, keyed by server
-	// entry tag. When the server entry has a specific TTL, use that as the
-	// cache TTL, otherwise using serverEntryCacheDefaultTTL.
+	// entry tag.
 	//
 	// Peek at RelayedRequest.Request, and if all requested server entries are
 	// in the cache, serve the request entirely from the local cache.
-	// Consider also modifying requests to only fetch server entries that are
-	// not cached.
 	//
-	// Also handle for changes to server entry version.
-
-	requestCtx := ctx
-	if requestTimeout > 0 {
-		var requestCancelFunc context.CancelFunc
-		requestCtx, requestCancelFunc = context.WithTimeout(ctx, requestTimeout)
-		defer requestCancelFunc()
+	// The backend DSL may enforce a limited time interval in which certain
+	// server entries can be discovered. This cache doesn't bypass this,
+	// since DiscoverServerEntries isn't cached and always passed through to
+	// the DSL backend. Clients must discover the large, random server entry
+	// tags via DiscoverServerEntries within the designated time interval;
+	// then clients may download the server entries via GetServerEntries at
+	// any time, and this may be cached.
+	//
+	// Limitation: this cache ignores server entry version and may serve a
+	// version that's older than the latest within the cache TTL.
+	//
+	// - Server entry version changes are assumed to be rare.
+	//
+	// - The cache will be updated with a new version as soon as
+	//   cacheGetServerEntriesResponse sees it.
+	//
+	// - Use a reasonable TTL such as 24h; cache entry TTLs aren't extended on
+	//   hits, so any old version will eventually be removed.
+	//
+	// - A more complicated scheme is possible: also peek at
+	//   DiscoverServerEntriesResponses and, for each tag/version pair, if
+	//   the tag is in the cache and the cached entry is an old version,
+	//   delete from the cache. This would require unpacking each server entry.
+
+	var response []byte
+	cachedResponse := false
+
+	if relayedRequest.RequestType == requestTypeGetServerEntries {
+		var err error
+		response, err = r.getCachedGetServerEntriesResponse(
+			relayedRequest.Request, clientGeoIPData)
+		if err != nil {
+			r.config.Logger.WithTraceFields(common.LogFields{
+				"error": err.Error(),
+			}).Warning("DSL: serve cached response failed")
+			// Proceed with relaying request
+		}
+		cachedResponse = err == nil && response != nil
 	}
 
-	url := fmt.Sprintf("https://%s%s", r.config.DynamicServerListServiceURL, path)
+	for i := 0; !cachedResponse; i++ {
 
-	for i := 0; ; i++ {
+		requestCtx := ctx
+		if requestTimeout > 0 {
+			var requestCancelFunc context.CancelFunc
+			requestCtx, requestCancelFunc = context.WithTimeout(ctx, requestTimeout)
+			defer requestCancelFunc()
+		}
+
+		url := fmt.Sprintf("https://%s%s", r.config.DynamicServerListServiceURL, path)
 
 		httpRequest, err := http.NewRequestWithContext(
 			requestCtx, "POST", url, bytes.NewBuffer(relayedRequest.Request))
@@ -328,41 +378,186 @@ func (r *Relay) handleRequest(
 			err = errors.Tracef("unexpected response code: %d", httpResponse.StatusCode)
 		}
 
-		var response []byte
 		if err == nil {
 			response, err = io.ReadAll(httpResponse.Body)
 			httpResponse.Body.Close()
 		}
 
-		if err != nil {
-
-			r.config.Logger.WithTraceFields(common.LogFields{
-				"duration": duration.String(),
-				"error":    err.Error(),
-			}).Warning("DSL: service request attempt failed")
+		if err == nil {
 
-			// Retry on network errors.
-			if i < requestRetryCount && ctx.Err() == nil {
-				continue
+			if relayedRequest.RequestType == requestTypeGetServerEntries {
+				err := r.cacheGetServerEntriesResponse(
+					relayedRequest.Request, response)
+				if err != nil {
+					r.config.Logger.WithTraceFields(common.LogFields{
+						"error": err.Error(),
+					}).Warning("DSL: cache response failed")
+					// Proceed with relaying response
+				}
 			}
 
-			return nil, errors.Tracef("all attempts failed")
+			break
 		}
 
-		cborRelayedResponse, err := protocol.CBOREncoding.Marshal(
-			&RelayedResponse{
-				Response: response,
-			})
-		if err != nil {
-			return nil, errors.Trace(err)
+		r.config.Logger.WithTraceFields(common.LogFields{
+			"duration": duration.String(),
+			"error":    err.Error(),
+		}).Warning("DSL: service request attempt failed")
+
+		// Retry on network errors.
+		if i < requestRetryCount && ctx.Err() == nil {
+			continue
+		}
+
+		return nil, errors.Tracef("all attempts failed")
+	}
+
+	cborRelayedResponse, err := protocol.CBOREncoding.Marshal(
+		&RelayedResponse{
+			Response: response,
+		})
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if len(cborRelayedResponse) > MaxRelayPayloadSize {
+		return nil, errors.Tracef(
+			"response size %d exceeds limit %d",
+			len(cborRelayedResponse), MaxRelayPayloadSize)
+	}
+
+	return cborRelayedResponse, nil
+}
+
+func (r *Relay) cacheGetServerEntriesResponse(
+	cborRequest []byte,
+	cborResponse []byte) error {
+
+	if r.serverEntryCacheTTL == 0 {
+		// Caching is disabled
+		return nil
+	}
+
+	var request GetServerEntriesRequest
+	err := cbor.Unmarshal(cborRequest, &request)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	var response GetServerEntriesResponse
+	err = cbor.Unmarshal(cborResponse, &response)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	if len(request.ServerEntryTags) != len(response.SourcedServerEntries) {
+		return errors.TraceNew("unexpected entry count mismatch")
+	}
+
+	for i, serverEntryTag := range request.ServerEntryTags {
+
+		if response.SourcedServerEntries[i] != nil {
+
+			// This will update any existing cached copy of the server entry for
+			// this tag, in case the server entry version is new. This also
+			// extends the cache TTL, since the server entry is fresh.
+
+			r.serverEntryCache.Set(
+				string(serverEntryTag),
+				response.SourcedServerEntries[i],
+				lrucache.DefaultExpiration)
+
+		} else {
+
+			// In this case, the DSL backend is indicating that the server
+			// entry for the requested tag no longer exists, perhaps due to
+			// server pruning since the DiscoverServerEntries request. This
+			// is an edge case since DiscoverServerEntries won't return
+			// invalid tags and so the "nil" value/state isn't cached.
+
+			r.serverEntryCache.Delete(string(serverEntryTag))
 		}
+	}
 
-		if len(cborRelayedResponse) > MaxRelayPayloadSize {
-			return nil, errors.Tracef(
-				"response size %d exceeds limit %d",
-				len(cborRelayedResponse), MaxRelayPayloadSize)
+	return nil
+}
+
+func (r *Relay) getCachedGetServerEntriesResponse(
+	cborRequest []byte,
+	clientGeoIPData common.GeoIPData) ([]byte, error) {
+
+	if r.serverEntryCacheTTL == 0 {
+		// Caching is disabled
+		return nil, nil
+	}
+
+	var request GetServerEntriesRequest
+	err := cbor.Unmarshal(cborRequest, &request)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Since we anticipate that most server entries will be cached, allocate
+	// response slices optimistically.
+	//
+	// TODO: check for sufficient cache entries before allocating these
+	// response slices? Would doubling the cache lookups use less resources
+	// than unused allocations?
+
+	serverEntryTags := make([]string, len(request.ServerEntryTags))
+
+	var response GetServerEntriesResponse
+	response.SourcedServerEntries = make([]*SourcedServerEntry, len(request.ServerEntryTags))
+
+	for i, serverEntryTag := range request.ServerEntryTags {
+		cacheEntry, ok := r.serverEntryCache.Get(string(serverEntryTag))
+		if !ok {
+
+			// The request can't be served from the cache, as some server
+			// entry tags aren't present. Fall back to a full request to the
+			// DSL backend.
+			//
+			// As a potential future enhancement, consider partially serving
+			// from the cache, after making a DSL request for just the
+			// unknown server entries?
+			return nil, nil
 		}
 
-		return cborRelayedResponse, nil
+		// The cached entry's TTL is not extended on a hit.
+
+		// serverEntryTags are used for logging the request event when served
+		// from the cache. Use the same string encoding as
+		// protocol.GenerateServerEntryTag.
+		serverEntryTags[i] = base64.StdEncoding.EncodeToString(serverEntryTag)
+
+		response.SourcedServerEntries[i] = cacheEntry.(*SourcedServerEntry)
 	}
+
+	cborResponse, err := protocol.CBOREncoding.Marshal(&response)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	// Log the request event. Since this request is served from the relay
+	// cache, the DSL backend will not see the request and log the event
+	// itself. This log should match the DSL log format and can be shipped to
+	// the same log aggregator.
+
+	baseParams, err := protocol.DecodePackedAPIParameters(request.BaseAPIParameters)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	err = r.config.APIParameterValidator(baseParams)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields := r.config.APIParameterLogFieldFormatter("", clientGeoIPData, baseParams)
+	logFields["dsl_event"] = "get-server-entries"
+	logFields["host_id"] = r.config.HostID
+	logFields["server_entry_tags"] = serverEntryTags
+	r.config.Logger.LogMetric("dsl", logFields)
+
+	return cborResponse, nil
 }