Browse Source

Merge pull request #351 from rod-hynes/master

Concurrent OSL schemes
Rod Hynes 9 years ago
parent
commit
39af973c9b
4 changed files with 216 additions and 105 deletions
  1. 136 102
      psiphon/common/osl/osl.go
  2. 19 2
      psiphon/common/osl/osl_test.go
  3. 23 0
      psiphon/dataStore.go
  4. 38 1
      psiphon/server/server_test.go

+ 136 - 102
psiphon/common/osl/osl.go

@@ -97,7 +97,7 @@ type Scheme struct {
 
 	// SeedSpecs is the set of different client network activity patterns
 	// that will result in issuing SLOKs. For a given time period, a distinct
-	// SLOK is issued for each SeedLevel in each SeedSpec.
+	// SLOK is issued for each SeedSpec.
 	// Duplicate subnets may appear in multiple SeedSpecs.
 	SeedSpecs []*SeedSpec
 
@@ -185,18 +185,25 @@ type KeySplit struct {
 	Threshold int
 }
 
-// ClientSeedState tracks the progress of a client towards seeding SLOKs.
+// ClientSeedState tracks the progress of a client towards seeding SLOKs
+// across all schemes the client qualifies for.
 type ClientSeedState struct {
-	scheme               *Scheme
 	propagationChannelID string
 	signalIssueSLOKs     chan struct{}
-	progress             []*TrafficValues
-	progressSLOKTime     int64
+	seedProgress         []*ClientSeedProgress
 	mutex                sync.Mutex
 	issuedSLOKs          map[string]*SLOK
 	payloadSLOKs         []*SLOK
 }
 
+// ClientSeedProgress tracks client progress towards seeding SLOKs for
+// a particular scheme.
+type ClientSeedProgress struct {
+	scheme           *Scheme
+	trafficProgress  []*TrafficValues
+	progressSLOKTime int64
+}
+
 // ClientSeedPortForward maps a client port forward, which is relaying
 // traffic to a specific upstream address, to all seed state progress
 // counters for SeedSpecs with subnets containing the upstream address.
@@ -204,8 +211,16 @@ type ClientSeedState struct {
 // and duration count towards the progress of these SeedSpecs and
 // associated SLOKs.
 type ClientSeedPortForward struct {
-	state           *ClientSeedState
-	progressIndexes []int
+	state              *ClientSeedState
+	progressReferences []progressReference
+}
+
+// progressReference points to a particular ClientSeedProgress and
+// TrafficValues instance to update with traffic events for a
+// ClientSeedPortForward.
+type progressReference struct {
+	seedProgressIndex    int
+	trafficProgressIndex int
 }
 
 // slokReference uniquely identifies a SLOK by specifying all the fields
@@ -347,8 +362,16 @@ func (config *Config) NewClientSeedState(
 	config.ReloadableFile.RLock()
 	defer config.ReloadableFile.RUnlock()
 
+	state := &ClientSeedState{
+		propagationChannelID: propagationChannelID,
+		signalIssueSLOKs:     signalIssueSLOKs,
+		issuedSLOKs:          make(map[string]*SLOK),
+		payloadSLOKs:         nil,
+	}
+
 	for _, scheme := range config.Schemes {
-		// Only the first matching scheme is selected.
+
+		// All matching schemes are selected.
 		// Note: this implementation assumes a few simple schemes. For more
 		// schemes with many propagation channel IDs or region filters, use
 		// maps for more efficient lookup.
@@ -359,24 +382,22 @@ func (config *Config) NewClientSeedState(
 			// Empty progress is initialized up front for all seed specs. Once
 			// created, the progress structure is read-only (the slice, not the
 			// TrafficValue fields); this permits lock-free operation.
-			progress := make([]*TrafficValues, len(scheme.SeedSpecs))
+			trafficProgress := make([]*TrafficValues, len(scheme.SeedSpecs))
 			for index := 0; index < len(scheme.SeedSpecs); index++ {
-				progress[index] = &TrafficValues{}
+				trafficProgress[index] = &TrafficValues{}
 			}
 
-			return &ClientSeedState{
-				scheme:               scheme,
-				propagationChannelID: propagationChannelID,
-				signalIssueSLOKs:     signalIssueSLOKs,
-				progressSLOKTime:     getSLOKTime(scheme.SeedPeriodNanoseconds),
-				progress:             progress,
-				issuedSLOKs:          make(map[string]*SLOK),
-				payloadSLOKs:         nil,
+			seedProgress := &ClientSeedProgress{
+				scheme:           scheme,
+				progressSLOKTime: getSLOKTime(scheme.SeedPeriodNanoseconds),
+				trafficProgress:  trafficProgress,
 			}
+
+			state.seedProgress = append(state.seedProgress, seedProgress)
 		}
 	}
 
-	return &ClientSeedState{}
+	return state
 }
 
 // NewClientSeedPortForwardState creates a new client port forward
@@ -393,34 +414,41 @@ func (state *ClientSeedState) NewClientSeedPortForward(
 	// Concurrency: access to ClientSeedState is unsynchronized
 	// but references only read-only fields.
 
-	if state.scheme == nil {
+	if len(state.seedProgress) == 0 {
 		return nil
 	}
 
-	var progressIndexes []int
+	var progressReferences []progressReference
 
 	// Determine which seed spec subnets contain upstreamIPAddress
 	// and point to the progress for each. When progress is reported,
 	// it is added directly to all of these TrafficValues instances.
 	// Assumes seedProgress.trafficProgress entries correspond 1-to-1 with
 	// seedProgress.scheme.subnetLookups.
-	// Note: this implementation assumes a small number of seed specs.
-	// For larger numbers, instead of N SubnetLookups, create a single
-	// SubnetLookup which returns, for a given IP address, all matching
-	// subnets and associated seed specs.
-	for index, subnetLookup := range state.scheme.subnetLookups {
-		if subnetLookup.ContainsIPAddress(upstreamIPAddress) {
-			progressIndexes = append(progressIndexes, index)
+	// Note: this implementation assumes a small number of schemes and
+	// seed specs. For larger numbers, instead of N SubnetLookups, create
+	// a single SubnetLookup which returns, for a given IP address, all
+	// matching subnets and associated seed specs.
+	for seedProgressIndex, seedProgress := range state.seedProgress {
+		for trafficProgressIndex, subnetLookup := range seedProgress.scheme.subnetLookups {
+			if subnetLookup.ContainsIPAddress(upstreamIPAddress) {
+				progressReferences = append(
+					progressReferences,
+					progressReference{
+						seedProgressIndex:    seedProgressIndex,
+						trafficProgressIndex: trafficProgressIndex,
+					})
+			}
 		}
 	}
 
-	if progressIndexes == nil {
+	if progressReferences == nil {
 		return nil
 	}
 
 	return &ClientSeedPortForward{
-		state:           state,
-		progressIndexes: progressIndexes,
+		state:              state,
+		progressReferences: progressReferences,
 	}
 }
 
@@ -449,47 +477,50 @@ func (portForward *ClientSeedPortForward) UpdateProgress(
 	// to read-only fields, atomic, or channels, except in the case of a time
 	// period rollover, in which case a mutex is acquired.
 
-	slokTime := getSLOKTime(portForward.state.scheme.SeedPeriodNanoseconds)
-
-	// If the SLOK time period has changed since progress was last recorded,
-	// call issueSLOKs which will issue any SLOKs for that past time period
-	// and then clear all progress. Progress will then be recorded for the
-	// current time period.
-	// As it acquires the state mutex, issueSLOKs may stall other port
-	// forwards for this client. The delay is minimized by SLOK caching,
-	// which avoids redundant crypto operations.
-	if slokTime != atomic.LoadInt64(&portForward.state.progressSLOKTime) {
-		portForward.state.mutex.Lock()
-		portForward.state.issueSLOKs()
-		portForward.state.mutex.Unlock()
-
-		// Call to issueSLOKs may have issued new SLOKs. Note that
-		// this will only happen if the time period rolls over with
-		// sufficient progress pending while the signalIssueSLOKs
-		// receiver did not call IssueSLOKs soon enough.
-		portForward.state.sendIssueSLOKsSignal()
-	}
+	for _, progressReference := range portForward.progressReferences {
+
+		seedProgress := portForward.state.seedProgress[progressReference.seedProgressIndex]
+		trafficProgress := seedProgress.trafficProgress[progressReference.trafficProgressIndex]
+
+		slokTime := getSLOKTime(seedProgress.scheme.SeedPeriodNanoseconds)
+
+		// If the SLOK time period has changed since progress was last recorded,
+		// call issueSLOKs which will issue any SLOKs for that past time period
+		// and then clear all progress. Progress will then be recorded for the
+		// current time period.
+		// As it acquires the state mutex, issueSLOKs may stall other port
+		// forwards for this client. The delay is minimized by SLOK caching,
+		// which avoids redundant crypto operations.
+		if slokTime != atomic.LoadInt64(&seedProgress.progressSLOKTime) {
+			portForward.state.mutex.Lock()
+			portForward.state.issueSLOKs()
+			portForward.state.mutex.Unlock()
+
+			// Call to issueSLOKs may have issued new SLOKs. Note that
+			// this will only happen if the time period rolls over with
+			// sufficient progress pending while the signalIssueSLOKs
+			// receiver did not call IssueSLOKs soon enough.
+			portForward.state.sendIssueSLOKsSignal()
+		}
 
-	// Add directly to the permanent TrafficValues progress accumulators
-	// for the state's seed specs. Concurrently, other port forwards may
-	// be adding to the same accumulators. Also concurrently, another
-	// goroutine may be invoking issueSLOKs, which zeros all the accumulators.
-	// As a consequence, progress may be dropped at the exact time of
-	// time period rollover.
-	for _, progressIndex := range portForward.progressIndexes {
+		// Add directly to the permanent TrafficValues progress accumulators
+		// for the state's seed specs. Concurrently, other port forwards may
+		// be adding to the same accumulators. Also concurrently, another
+		// goroutine may be invoking issueSLOKs, which zeros all the accumulators.
+		// As a consequence, progress may be dropped at the exact time of
+		// time period rollover.
 
-		seedSpec := portForward.state.scheme.SeedSpecs[progressIndex]
-		progress := portForward.state.progress[progressIndex]
+		seedSpec := seedProgress.scheme.SeedSpecs[progressReference.trafficProgressIndex]
 
-		alreadyExceedsTargets := progress.exceeds(&seedSpec.Targets)
+		alreadyExceedsTargets := trafficProgress.exceeds(&seedSpec.Targets)
 
-		atomic.AddInt64(&progress.BytesRead, bytesRead)
-		atomic.AddInt64(&progress.BytesWritten, bytesWritten)
-		atomic.AddInt64(&progress.PortForwardDurationNanoseconds, durationNanoseconds)
+		atomic.AddInt64(&trafficProgress.BytesRead, bytesRead)
+		atomic.AddInt64(&trafficProgress.BytesWritten, bytesWritten)
+		atomic.AddInt64(&trafficProgress.PortForwardDurationNanoseconds, durationNanoseconds)
 
 		// With the target newly met for a SeedSpec, a new
 		// SLOK *may* be issued.
-		if !alreadyExceedsTargets && progress.exceeds(&seedSpec.Targets) {
+		if !alreadyExceedsTargets && trafficProgress.exceeds(&seedSpec.Targets) {
 			portForward.state.sendIssueSLOKsSignal()
 		}
 	}
@@ -514,54 +545,57 @@ func (state *ClientSeedState) issueSLOKs() {
 
 	// Concurrency: the caller must lock state.mutex.
 
-	if state.scheme == nil {
+	if len(state.seedProgress) == 0 {
 		return
 	}
 
-	progressSLOKTime := time.Unix(0, state.progressSLOKTime)
+	for _, seedProgress := range state.seedProgress {
 
-	for index, progress := range state.progress {
+		progressSLOKTime := time.Unix(0, seedProgress.progressSLOKTime)
 
-		seedSpec := state.scheme.SeedSpecs[index]
+		for index, trafficProgress := range seedProgress.trafficProgress {
 
-		if progress.exceeds(&seedSpec.Targets) {
+			seedSpec := seedProgress.scheme.SeedSpecs[index]
 
-			ref := &slokReference{
-				PropagationChannelID: state.propagationChannelID,
-				SeedSpecID:           string(seedSpec.ID),
-				Time:                 progressSLOKTime,
-			}
+			if trafficProgress.exceeds(&seedSpec.Targets) {
 
-			state.scheme.derivedSLOKCacheMutex.RLock()
-			slok, ok := state.scheme.derivedSLOKCache[*ref]
-			state.scheme.derivedSLOKCacheMutex.RUnlock()
-			if !ok {
-				slok = state.scheme.deriveSLOK(ref)
-				state.scheme.derivedSLOKCacheMutex.Lock()
-				state.scheme.derivedSLOKCache[*ref] = slok
-				state.scheme.derivedSLOKCacheMutex.Unlock()
-			}
+				ref := &slokReference{
+					PropagationChannelID: state.propagationChannelID,
+					SeedSpecID:           string(seedSpec.ID),
+					Time:                 progressSLOKTime,
+				}
 
-			// Previously issued SLOKs are not re-added to
-			// the payload.
-			if state.issuedSLOKs[string(slok.ID)] == nil {
-				state.issuedSLOKs[string(slok.ID)] = slok
-				state.payloadSLOKs = append(state.payloadSLOKs, slok)
+				seedProgress.scheme.derivedSLOKCacheMutex.RLock()
+				slok, ok := seedProgress.scheme.derivedSLOKCache[*ref]
+				seedProgress.scheme.derivedSLOKCacheMutex.RUnlock()
+				if !ok {
+					slok = seedProgress.scheme.deriveSLOK(ref)
+					seedProgress.scheme.derivedSLOKCacheMutex.Lock()
+					seedProgress.scheme.derivedSLOKCache[*ref] = slok
+					seedProgress.scheme.derivedSLOKCacheMutex.Unlock()
+				}
+
+				// Previously issued SLOKs are not re-added to
+				// the payload.
+				if state.issuedSLOKs[string(slok.ID)] == nil {
+					state.issuedSLOKs[string(slok.ID)] = slok
+					state.payloadSLOKs = append(state.payloadSLOKs, slok)
+				}
 			}
 		}
-	}
 
-	slokTime := getSLOKTime(state.scheme.SeedPeriodNanoseconds)
-
-	if slokTime != atomic.LoadInt64(&state.progressSLOKTime) {
-		atomic.StoreInt64(&state.progressSLOKTime, slokTime)
-		// The progress map structure is not reset or modifed; instead
-		// the mapped accumulator values are zeroed. Concurrently, port
-		// forward relay goroutines continue to add to these accumulators.
-		for _, progress := range state.progress {
-			atomic.StoreInt64(&progress.BytesRead, 0)
-			atomic.StoreInt64(&progress.BytesWritten, 0)
-			atomic.StoreInt64(&progress.PortForwardDurationNanoseconds, 0)
+		slokTime := getSLOKTime(seedProgress.scheme.SeedPeriodNanoseconds)
+
+		if slokTime != atomic.LoadInt64(&seedProgress.progressSLOKTime) {
+			atomic.StoreInt64(&seedProgress.progressSLOKTime, slokTime)
+			// The progress map structure is not reset or modified; instead
+			// the mapped accumulator values are zeroed. Concurrently, port
+			// forward relay goroutines continue to add to these accumulators.
+			for _, trafficProgress := range seedProgress.trafficProgress {
+				atomic.StoreInt64(&trafficProgress.BytesRead, 0)
+				atomic.StoreInt64(&trafficProgress.BytesWritten, 0)
+				atomic.StoreInt64(&trafficProgress.PortForwardDurationNanoseconds, 0)
+			}
 		}
 	}
 }
@@ -580,7 +614,7 @@ func (state *ClientSeedState) GetSeedPayload() *SeedPayload {
 	state.mutex.Lock()
 	defer state.mutex.Unlock()
 
-	if state.scheme == nil {
+	if len(state.seedProgress) == 0 {
 		return &SeedPayload{}
 	}
 

+ 19 - 2
psiphon/common/osl/osl_test.go

@@ -40,7 +40,7 @@ func TestOSL(t *testing.T) {
 
       "Regions" : ["US", "CA"],
 
-      "PropagationChannelIDs" : ["2995DB0C968C59C4F23E87988D9C0D41", "E742C25A6D8BA8C17F37E725FA628569"],
+      "PropagationChannelIDs" : ["2995DB0C968C59C4F23E87988D9C0D41", "E742C25A6D8BA8C17F37E725FA628569", "B4A780E67695595FA486E9B900EA7335"],
 
       "MasterKey" : "wFuSbqU/pJ/35vRmoM8T9ys1PgDa8uzJps1Y+FNKa5U=",
 
@@ -100,7 +100,7 @@ func TestOSL(t *testing.T) {
 
       "Regions" : ["US", "CA"],
 
-      "PropagationChannelIDs" : ["36F1CF2DF1250BF0C7BA0629CE3DC657"],
+      "PropagationChannelIDs" : ["36F1CF2DF1250BF0C7BA0629CE3DC657", "B4A780E67695595FA486E9B900EA7335"],
 
       "MasterKey" : "fcyQy8JSxLXHt/Iom9Qj9wMnSjrsccTiiSPEsJicet4=",
 
@@ -296,6 +296,23 @@ func TestOSL(t *testing.T) {
 		}
 	})
 
+	t.Run("concurrent schemes", func(t *testing.T) {
+
+		rolloverToNextSLOKTime()
+
+		clientSeedState := config.NewClientSeedState("US", "B4A780E67695595FA486E9B900EA7335", nil)
+
+		clientSeedPortForward := clientSeedState.NewClientSeedPortForward(net.ParseIP("192.168.0.1"))
+
+		clientSeedPortForward.UpdateProgress(5, 5, 5)
+
+		clientSeedPortForward.UpdateProgress(5, 5, 5)
+
+		if len(clientSeedState.GetSeedPayload().SLOKs) != 5 {
+			t.Fatalf("expected 5 SLOKs, got %d", len(clientSeedState.GetSeedPayload().SLOKs))
+		}
+	})
+
 	signingPublicKey, signingPrivateKey, err := common.GenerateAuthenticatedDataPackageKeys()
 	if err != nil {
 		t.Fatalf("GenerateAuthenticatedDataPackageKeys failed: %s", err)

+ 23 - 0
psiphon/dataStore.go

@@ -1101,6 +1101,29 @@ func resetAllPersistentStatsToUnreported() error {
 	return nil
 }
 
+// CountSLOKs returns the total number of SLOK records.
+func CountSLOKs() int {
+	checkInitDataStore()
+
+	count := 0
+
+	err := singleton.db.View(func(tx *bolt.Tx) error {
+		bucket := tx.Bucket([]byte(slokBucket))
+		cursor := bucket.Cursor()
+		for key, _ := cursor.First(); key != nil; key, _ = cursor.Next() {
+			count++
+		}
+		return nil
+	})
+
+	if err != nil {
+		NoticeAlert("CountSLOKs failed: %s", err)
+		return 0
+	}
+
+	return count
+}
+
 // DeleteSLOKs deletes all SLOK records.
 func DeleteSLOKs() error {
 	checkInitDataStore()

+ 38 - 1
psiphon/server/server_test.go

@@ -414,6 +414,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	if err != nil {
 		t.Fatalf("error initializing client datastore: %s", err)
 	}
+	psiphon.DeleteSLOKs()
 
 	controller, err := psiphon.NewController(clientConfig)
 	if err != nil {
@@ -544,8 +545,14 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	// Test: await SLOK payload
 
 	if !runConfig.denyTrafficRules {
+
 		time.Sleep(1 * time.Second)
 		waitOnNotification(t, slokSeeded, timeoutSignal, "SLOK seeded timeout exceeded")
+
+		numSLOKs := psiphon.CountSLOKs()
+		if numSLOKs != expectedNumSLOKs {
+			t.Fatalf("unexpected number of SLOKs: %d", numSLOKs)
+		}
 	}
 }
 
@@ -871,6 +878,8 @@ func paveTrafficRulesFile(
 	}
 }
 
+var expectedNumSLOKs = 3
+
 func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
 
 	oslConfigJSONFormat := `
@@ -911,6 +920,32 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
               "Threshold": 2
             }
           ]
+        },
+        {
+          "Epoch" : "%s",
+          "Regions" : [],
+          "PropagationChannelIDs" : ["%s"],
+          "MasterKey" : "HDc/mvd7e+lKDJD0fMpJW66YJ/VW4iqDRjeclEsMnro=",
+          "SeedSpecs" : [
+            {
+              "ID" : "/M0vsT0IjzmI0MvTI9IYe8OVyeQGeaPZN2xGxfLw/UQ=",
+              "UpstreamSubnets" : ["0.0.0.0/0"],
+              "Targets" :
+              {
+                  "BytesRead" : 1,
+                  "BytesWritten" : 1,
+                  "PortForwardDurationNanoseconds" : 1
+              }
+            }
+          ],
+          "SeedSpecThreshold" : 1,
+          "SeedPeriodNanoseconds" : 10000000000,
+          "SeedPeriodKeySplits": [
+            {
+              "Total": 1,
+              "Threshold": 1
+            }
+          ]
         }
       ]
     }
@@ -923,7 +958,9 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
 	epochStr := epoch.Format(time.RFC3339Nano)
 
 	oslConfigJSON := fmt.Sprintf(
-		oslConfigJSONFormat, epochStr, propagationChannelID)
+		oslConfigJSONFormat,
+		epochStr, propagationChannelID,
+		epochStr, propagationChannelID)
 
 	err := ioutil.WriteFile(oslConfigFilename, []byte(oslConfigJSON), 0600)
 	if err != nil {