Explorar el Código

migration errors are no longer fatal; remove legacy data store file on successful migration; removed all references to target server entries in the migration code; do not export any methods or types from the migration code; placed call to migrateEntries inside of singleton.init.Do after singleton.db is assigned; removed unused global migratableServerEntries; reversed the order of the data store file existence checks and added explanatory comments

Michael Goldberger hace 10 años
padre
commit
a24febe34b
Se han modificado 3 ficheros con 53 adiciones y 91 borrados
  1. 1 2
      ConsoleClient/psiphonClient.go
  2. 10 5
      psiphon/dataStore_alt.go
  3. 42 84
      psiphon/migrateDataStore.go

+ 1 - 2
ConsoleClient/psiphonClient.go

@@ -28,8 +28,7 @@ import (
 	"runtime/pprof"
 	"sync"
 
-	// TODO: Put this back to the real github URL (can't go get from a branch, so this seemed reasonable for now)
-	"../psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 )
 
 func main() {

+ 10 - 5
psiphon/dataStore_alt.go

@@ -72,7 +72,7 @@ func InitDataStore(config *Config) (err error) {
 	var migratableServerEntries []*ServerEntry
 
 	singleton.init.Do(func() {
-		migratableServerEntries, err = PrepareMigrationEntries(config)
+		migratableServerEntries, err = prepareMigrationEntries(config)
 		if err != nil {
 			return
 		}
@@ -109,11 +109,16 @@ func InitDataStore(config *Config) (err error) {
 		}
 
 		singleton.db = db
-	})
 
-	if len(migratableServerEntries) > 0 {
-		err = MigrateEntries(migratableServerEntries)
-	}
+		// The migrateServerEntries function requires the data store is
+		// initialized prior to execution so that migrated entries can be stored
+		if len(migratableServerEntries) > 0 {
+			migrationFailures := migrateEntries(migratableServerEntries, filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME))
+			if migrationFailures != nil {
+				NoticeAlert("initDataStore failed to migrate legacy server entries: %s", migrationFailures)
+			}
+		}
+	})
 
 	return err
 }

+ 42 - 84
psiphon/migrateDataStore.go

@@ -24,11 +24,9 @@ package psiphon
 import (
 	"database/sql"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
-	"strings"
 
 	_ "github.com/Psiphon-Inc/go-sqlite3"
 )
@@ -36,11 +34,11 @@ import (
 var legacyDb *sql.DB
 var migratableServerEntries []*ServerEntry
 
-func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
-	if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)); err == nil {
-		if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, DATA_STORE_FILENAME)); os.IsNotExist(err) {
-			NoticeInfo("sqlite DB found, boltdb not found; preparing datastore migration")
-
+func prepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
+	// If DATA_STORE_FILENAME does not exist on disk
+	if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, DATA_STORE_FILENAME)); os.IsNotExist(err) {
+		// If LEGACY_DATA_STORE_FILENAME exists on disk
+		if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)); err == nil {
 			legacyDb, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?cache=private&mode=rwc", filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)))
 			if err != nil {
 				return migratableServerEntries, err
@@ -52,7 +50,7 @@ func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
 				return migratableServerEntries, err
 			}
 
-			iterator, err := NewLegacyServerEntryIterator(config)
+			iterator, err := newlegacyServerEntryIterator(config)
 			if err != nil {
 				return migratableServerEntries, err
 			}
@@ -61,61 +59,60 @@ func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
 			for {
 				serverEntry, err := iterator.Next()
 				if err != nil {
-					err = fmt.Errorf("Failed to iterate legacy server entries: %s", err)
+					err = fmt.Errorf("failed to iterate legacy server entries: %s", err)
 					break
 				}
 				if serverEntry == nil {
 					break
 				}
 
-				NoticeInfo("Server entry (%s) prepped for migration", serverEntry.IpAddress)
 				migratableServerEntries = append(migratableServerEntries, serverEntry)
 			}
-			NoticeInfo("All entries prepped")
+			NoticeInfo("%d server entries prepared for data store migration", len(migratableServerEntries))
 		}
 	}
 
 	return migratableServerEntries, nil
 }
 
-func MigrateEntries(serverEntries []*ServerEntry) error {
+func migrateEntries(serverEntries []*ServerEntry, legacyDataStoreFilename string) error {
+	checkInitDataStore()
+
 	err := StoreServerEntries(serverEntries, false)
 	if err != nil {
 		return err
 	}
+	NoticeInfo("%d server entries successfully migrated to new data store", len(serverEntries))
+
+	err = os.Remove(legacyDataStoreFilename)
+	if err != nil {
+		NoticeInfo("failed to delete legacy data store file '%s': %s", legacyDataStoreFilename, err)
+	}
+
 	return nil
 }
 
 // This code is copied from the dataStore.go code used to operate the legacy
-// SQLite datastore. The word "Legacy" was added to all of the method names to avoid
+// SQLite datastore. The word "legacy" was added to all of the method names to avoid
 // namespace conflicts with the methods used to operate the BoltDB datastore
 
-// LegacyServerEntryIterator is used to iterate over
+// legacyServerEntryIterator is used to iterate over
 // stored server entries in rank order.
-type LegacyServerEntryIterator struct {
-	region                      string
-	protocol                    string
-	shuffleHeadLength           int
-	transaction                 *sql.Tx
-	cursor                      *sql.Rows
-	isTargetServerEntryIterator bool
-	hasNextTargetServerEntry    bool
-	targetServerEntry           *ServerEntry
+type legacyServerEntryIterator struct {
+	region            string
+	protocol          string
+	shuffleHeadLength int
+	transaction       *sql.Tx
+	cursor            *sql.Rows
 }
 
-// NewLegacyServerEntryIterator creates a new NewLegacyServerEntryIterator
-func NewLegacyServerEntryIterator(config *Config) (iterator *LegacyServerEntryIterator, err error) {
+// newLegacyServerEntryIterator creates a new legacyServerEntryIterator
+func newlegacyServerEntryIterator(config *Config) (iterator *legacyServerEntryIterator, err error) {
 
-	// When configured, this target server entry is the only candidate
-	if config.TargetServerEntry != "" {
-		return newLegacyTargetServerEntryIterator(config)
-	}
-
-	iterator = &LegacyServerEntryIterator{
-		region:                      config.EgressRegion,
-		protocol:                    config.TunnelProtocol,
-		shuffleHeadLength:           config.TunnelPoolSize,
-		isTargetServerEntryIterator: false,
+	iterator = &legacyServerEntryIterator{
+		region:            config.EgressRegion,
+		protocol:          config.TunnelProtocol,
+		shuffleHeadLength: config.TunnelPoolSize,
 	}
 	err = iterator.Reset()
 	if err != nil {
@@ -124,33 +121,8 @@ func NewLegacyServerEntryIterator(config *Config) (iterator *LegacyServerEntryIt
 	return iterator, nil
 }
 
-// newLegacyTargetServerEntryIterator is a helper for initializing the LegacyTargetServerEntry case
-func newLegacyTargetServerEntryIterator(config *Config) (iterator *LegacyServerEntryIterator, err error) {
-	serverEntry, err := DecodeServerEntry(config.TargetServerEntry)
-	if err != nil {
-		return nil, err
-	}
-	if config.EgressRegion != "" && serverEntry.Region != config.EgressRegion {
-		return nil, errors.New("TargetServerEntry does not support EgressRegion")
-	}
-	if config.TunnelProtocol != "" {
-		// Note: same capability/protocol mapping as in StoreServerEntry
-		requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
-		if !Contains(serverEntry.Capabilities, requiredCapability) {
-			return nil, errors.New("TargetServerEntry does not support TunnelProtocol")
-		}
-	}
-	iterator = &LegacyServerEntryIterator{
-		isTargetServerEntryIterator: true,
-		hasNextTargetServerEntry:    true,
-		targetServerEntry:           serverEntry,
-	}
-	NoticeInfo("using TargetServerEntry: %s", serverEntry.IpAddress)
-	return iterator, nil
-}
-
-// Close cleans up resources associated with a LegacyServerEntryIterator.
-func (iterator *LegacyServerEntryIterator) Close() {
+// Close cleans up resources associated with a legacyServerEntryIterator.
+func (iterator *legacyServerEntryIterator) Close() {
 	if iterator.cursor != nil {
 		iterator.cursor.Close()
 	}
@@ -161,23 +133,15 @@ func (iterator *LegacyServerEntryIterator) Close() {
 	iterator.transaction = nil
 }
 
-// Next returns the next server entry, by rank, for a LegacyServerEntryIterator.
+// Next returns the next server entry, by rank, for a legacyServerEntryIterator.
 // Returns nil with no error when there is no next item.
-func (iterator *LegacyServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
+func (iterator *legacyServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
 	defer func() {
 		if err != nil {
 			iterator.Close()
 		}
 	}()
 
-	if iterator.isTargetServerEntryIterator {
-		if iterator.hasNextTargetServerEntry {
-			iterator.hasNextTargetServerEntry = false
-			return MakeCompatibleServerEntry(iterator.targetServerEntry), nil
-		}
-		return nil, nil
-	}
-
 	if !iterator.cursor.Next() {
 		err = iterator.cursor.Err()
 		if err != nil {
@@ -201,17 +165,12 @@ func (iterator *LegacyServerEntryIterator) Next() (serverEntry *ServerEntry, err
 	return MakeCompatibleServerEntry(serverEntry), nil
 }
 
-// Reset a NewLegacyServerEntryIterator to the start of its cycle. The next
+// Reset a NewlegacyServerEntryIterator to the start of its cycle. The next
 // call to Next will return the first server entry.
-func (iterator *LegacyServerEntryIterator) Reset() error {
+func (iterator *legacyServerEntryIterator) Reset() error {
 	iterator.Close()
 
-	if iterator.isTargetServerEntryIterator {
-		iterator.hasNextTargetServerEntry = true
-		return nil
-	}
-
-	count := CountLegacyServerEntries(iterator.region, iterator.protocol)
+	count := countLegacyServerEntries(iterator.region, iterator.protocol)
 	NoticeCandidateServers(iterator.region, iterator.protocol, count)
 
 	transaction, err := legacyDb.Begin()
@@ -289,16 +248,15 @@ func makeServerEntryWhereClause(
 	return whereClause, whereParams
 }
 
-// CountLegacyServerEntries returns a count of stored servers for the
-// specified region and protocol.
-func CountLegacyServerEntries(region, protocol string) int {
+// countLegacyServerEntries returns a count of stored servers for the specified region and protocol.
+func countLegacyServerEntries(region, protocol string) int {
 	var count int
 	whereClause, whereParams := makeServerEntryWhereClause(region, protocol, nil)
 	query := "select count(*) from serverEntry" + whereClause
 	err := legacyDb.QueryRow(query, whereParams...).Scan(&count)
 
 	if err != nil {
-		NoticeAlert("CountLegacyServerEntries failed: %s", err)
+		NoticeAlert("countLegacyServerEntries failed: %s", err)
 		return 0
 	}