Procházet zdrojové kódy

Fix: iterator must filter out server entries that don't match specified region/protocol

Rod Hynes před 10 roky
rodič
revize
104e93fd28
2 změnil soubory, kde provedl 49 přidání a 25 odebrání
  1. 6 0
      psiphon/controller_test.go
  2. 43 25
      psiphon/dataStore_alt.go

+ 6 - 0
psiphon/controller_test.go

@@ -99,6 +99,12 @@ func controllerRun(t *testing.T, protocol string) {
 				}
 				}
 			case "ListeningHttpProxyPort":
 			case "ListeningHttpProxyPort":
 				httpProxyPort = int(payload["port"].(float64))
 				httpProxyPort = int(payload["port"].(float64))
+			case "ConnectingServer":
+				serverProtocol := payload["protocol"]
+				if serverProtocol != protocol {
+					t.Errorf("wrong protocol selected: %s", serverProtocol)
+					t.FailNow()
+				}
 			}
 			}
 		}))
 		}))
 
 

+ 43 - 25
psiphon/dataStore_alt.go

@@ -129,8 +129,15 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 		return ContextError(errors.New("invalid server entry"))
 		return ContextError(errors.New("invalid server entry"))
 	}
 	}
 
 
-	serverEntryExists := false
+	// BoltDB implementation note:
+	// For simplicity, we don't maintain indexes on server entry
+	// region or supported protocols. Instead, we perform full-bucket
+	// scans with a filter. With a small enough database (thousands or
+	// even tens of thousand of server entries) and common enough
+	// values (e.g., many servers support all protocols), performance
+	// is expected to be acceptable.
 
 
+	serverEntryExists := false
 	err = singleton.db.Update(func(tx *bolt.Tx) error {
 	err = singleton.db.Update(func(tx *bolt.Tx) error {
 
 
 		serverEntries := tx.Bucket([]byte(serverEntriesBucket))
 		serverEntries := tx.Bucket([]byte(serverEntriesBucket))
@@ -442,33 +449,44 @@ func (iterator *ServerEntryIterator) Next() (serverEntry *ServerEntry, err error
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	if iterator.serverEntryIndex >= len(iterator.serverEntryIds) {
-		// There is no next item
-		return nil, nil
-	}
+	// There are no region/protocol indexes for the server entries bucket.
+	// Loop until we have the next server entry that matches the iterator
+	// filter requirements.
+	for {
+		if iterator.serverEntryIndex >= len(iterator.serverEntryIds) {
+			// There is no next item
+			return nil, nil
+		}
 
 
-	serverEntryId := iterator.serverEntryIds[iterator.serverEntryIndex]
-	iterator.serverEntryIndex += 1
+		serverEntryId := iterator.serverEntryIds[iterator.serverEntryIndex]
+		iterator.serverEntryIndex += 1
 
 
-	var data []byte
-	err = singleton.db.View(func(tx *bolt.Tx) error {
-		bucket := tx.Bucket([]byte(serverEntriesBucket))
-		data = bucket.Get([]byte(serverEntryId))
-		return nil
-	})
-	if err != nil {
-		return nil, ContextError(err)
-	}
+		var data []byte
+		err = singleton.db.View(func(tx *bolt.Tx) error {
+			bucket := tx.Bucket([]byte(serverEntriesBucket))
+			data = bucket.Get([]byte(serverEntryId))
+			return nil
+		})
+		if err != nil {
+			return nil, ContextError(err)
+		}
 
 
-	if data == nil {
-		return nil, ContextError(
-			fmt.Errorf("Unexpected missing server entry: %s", serverEntryId))
-	}
+		if data == nil {
+			return nil, ContextError(
+				fmt.Errorf("Unexpected missing server entry: %s", serverEntryId))
+		}
 
 
-	serverEntry = new(ServerEntry)
-	err = json.Unmarshal(data, serverEntry)
-	if err != nil {
-		return nil, ContextError(err)
+		serverEntry = new(ServerEntry)
+		err = json.Unmarshal(data, serverEntry)
+		if err != nil {
+			return nil, ContextError(err)
+		}
+
+		if (iterator.region == "" || serverEntry.Region == iterator.region) &&
+			(iterator.protocol == "" || serverEntrySupportsProtocol(serverEntry, iterator.protocol)) {
+
+			break
+		}
 	}
 	}
 
 
 	return MakeCompatibleServerEntry(serverEntry), nil
 	return MakeCompatibleServerEntry(serverEntry), nil
@@ -519,7 +537,7 @@ func CountServerEntries(region, protocol string) int {
 	count := 0
 	count := 0
 	err := scanServerEntries(func(serverEntry *ServerEntry) {
 	err := scanServerEntries(func(serverEntry *ServerEntry) {
 		if (region == "" || serverEntry.Region == region) &&
 		if (region == "" || serverEntry.Region == region) &&
-			serverEntrySupportsProtocol(serverEntry, protocol) {
+			(protocol == "" || serverEntrySupportsProtocol(serverEntry, protocol)) {
 			count += 1
 			count += 1
 		}
 		}
 	})
 	})