|
|
@@ -24,11 +24,9 @@ package psiphon
|
|
|
import (
|
|
|
"database/sql"
|
|
|
"encoding/json"
|
|
|
- "errors"
|
|
|
"fmt"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
- "strings"
|
|
|
|
|
|
_ "github.com/Psiphon-Inc/go-sqlite3"
|
|
|
)
|
|
|
@@ -36,11 +34,11 @@ import (
|
|
|
var legacyDb *sql.DB
|
|
|
var migratableServerEntries []*ServerEntry
|
|
|
|
|
|
-func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
|
|
|
- if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)); err == nil {
|
|
|
- if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, DATA_STORE_FILENAME)); os.IsNotExist(err) {
|
|
|
- NoticeInfo("sqlite DB found, boltdb not found; preparing datastore migration")
|
|
|
-
|
|
|
+func prepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
|
|
|
+ // If DATA_STORE_FILENAME does not exist on disk
|
|
|
+ if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, DATA_STORE_FILENAME)); os.IsNotExist(err) {
|
|
|
+ // If LEGACY_DATA_STORE_FILENAME exists on disk
|
|
|
+ if _, err := os.Stat(filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)); err == nil {
|
|
|
legacyDb, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?cache=private&mode=rwc", filepath.Join(config.DataStoreDirectory, LEGACY_DATA_STORE_FILENAME)))
|
|
|
if err != nil {
|
|
|
return migratableServerEntries, err
|
|
|
@@ -52,7 +50,7 @@ func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
|
|
|
return migratableServerEntries, err
|
|
|
}
|
|
|
|
|
|
- iterator, err := NewLegacyServerEntryIterator(config)
|
|
|
+ iterator, err := newlegacyServerEntryIterator(config)
|
|
|
if err != nil {
|
|
|
return migratableServerEntries, err
|
|
|
}
|
|
|
@@ -61,61 +59,60 @@ func PrepareMigrationEntries(config *Config) ([]*ServerEntry, error) {
|
|
|
for {
|
|
|
serverEntry, err := iterator.Next()
|
|
|
if err != nil {
|
|
|
- err = fmt.Errorf("Failed to iterate legacy server entries: %s", err)
|
|
|
+ err = fmt.Errorf("failed to iterate legacy server entries: %s", err)
|
|
|
break
|
|
|
}
|
|
|
if serverEntry == nil {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- NoticeInfo("Server entry (%s) prepped for migration", serverEntry.IpAddress)
|
|
|
migratableServerEntries = append(migratableServerEntries, serverEntry)
|
|
|
}
|
|
|
- NoticeInfo("All entries prepped")
|
|
|
+ NoticeInfo("%d server entries prepared for data store migration", len(migratableServerEntries))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return migratableServerEntries, nil
|
|
|
}
|
|
|
|
|
|
-func MigrateEntries(serverEntries []*ServerEntry) error {
|
|
|
+func migrateEntries(serverEntries []*ServerEntry, legacyDataStoreFilename string) error {
|
|
|
+ checkInitDataStore()
|
|
|
+
|
|
|
err := StoreServerEntries(serverEntries, false)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ NoticeInfo("%d server entries successfully migrated to new data store", len(serverEntries))
|
|
|
+
|
|
|
+ err = os.Remove(legacyDataStoreFilename)
|
|
|
+ if err != nil {
|
|
|
+ NoticeInfo("failed to delete legacy data store file '%s': %s", legacyDataStoreFilename, err)
|
|
|
+ }
|
|
|
+
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// This code is copied from the dataStore.go code used to operate the legacy
|
|
|
-// SQLite datastore. The word "Legacy" was added to all of the method names to avoid
|
|
|
+// SQLite datastore. The word "legacy" was added to all of the method names to avoid
|
|
|
// namespace conflicts with the methods used to operate the BoltDB datastore
|
|
|
|
|
|
-// LegacyServerEntryIterator is used to iterate over
|
|
|
+// legacyServerEntryIterator is used to iterate over
|
|
|
// stored server entries in rank order.
|
|
|
-type LegacyServerEntryIterator struct {
|
|
|
- region string
|
|
|
- protocol string
|
|
|
- shuffleHeadLength int
|
|
|
- transaction *sql.Tx
|
|
|
- cursor *sql.Rows
|
|
|
- isTargetServerEntryIterator bool
|
|
|
- hasNextTargetServerEntry bool
|
|
|
- targetServerEntry *ServerEntry
|
|
|
+type legacyServerEntryIterator struct {
|
|
|
+ region string
|
|
|
+ protocol string
|
|
|
+ shuffleHeadLength int
|
|
|
+ transaction *sql.Tx
|
|
|
+ cursor *sql.Rows
|
|
|
}
|
|
|
|
|
|
-// NewLegacyServerEntryIterator creates a new NewLegacyServerEntryIterator
|
|
|
-func NewLegacyServerEntryIterator(config *Config) (iterator *LegacyServerEntryIterator, err error) {
|
|
|
+// newLegacyServerEntryIterator creates a new legacyServerEntryIterator
|
|
|
+func newlegacyServerEntryIterator(config *Config) (iterator *legacyServerEntryIterator, err error) {
|
|
|
|
|
|
- // When configured, this target server entry is the only candidate
|
|
|
- if config.TargetServerEntry != "" {
|
|
|
- return newLegacyTargetServerEntryIterator(config)
|
|
|
- }
|
|
|
-
|
|
|
- iterator = &LegacyServerEntryIterator{
|
|
|
- region: config.EgressRegion,
|
|
|
- protocol: config.TunnelProtocol,
|
|
|
- shuffleHeadLength: config.TunnelPoolSize,
|
|
|
- isTargetServerEntryIterator: false,
|
|
|
+ iterator = &legacyServerEntryIterator{
|
|
|
+ region: config.EgressRegion,
|
|
|
+ protocol: config.TunnelProtocol,
|
|
|
+ shuffleHeadLength: config.TunnelPoolSize,
|
|
|
}
|
|
|
err = iterator.Reset()
|
|
|
if err != nil {
|
|
|
@@ -124,33 +121,8 @@ func NewLegacyServerEntryIterator(config *Config) (iterator *LegacyServerEntryIt
|
|
|
return iterator, nil
|
|
|
}
|
|
|
|
|
|
-// newLegacyTargetServerEntryIterator is a helper for initializing the LegacyTargetServerEntry case
|
|
|
-func newLegacyTargetServerEntryIterator(config *Config) (iterator *LegacyServerEntryIterator, err error) {
|
|
|
- serverEntry, err := DecodeServerEntry(config.TargetServerEntry)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- if config.EgressRegion != "" && serverEntry.Region != config.EgressRegion {
|
|
|
- return nil, errors.New("TargetServerEntry does not support EgressRegion")
|
|
|
- }
|
|
|
- if config.TunnelProtocol != "" {
|
|
|
- // Note: same capability/protocol mapping as in StoreServerEntry
|
|
|
- requiredCapability := strings.TrimSuffix(config.TunnelProtocol, "-OSSH")
|
|
|
- if !Contains(serverEntry.Capabilities, requiredCapability) {
|
|
|
- return nil, errors.New("TargetServerEntry does not support TunnelProtocol")
|
|
|
- }
|
|
|
- }
|
|
|
- iterator = &LegacyServerEntryIterator{
|
|
|
- isTargetServerEntryIterator: true,
|
|
|
- hasNextTargetServerEntry: true,
|
|
|
- targetServerEntry: serverEntry,
|
|
|
- }
|
|
|
- NoticeInfo("using TargetServerEntry: %s", serverEntry.IpAddress)
|
|
|
- return iterator, nil
|
|
|
-}
|
|
|
-
|
|
|
-// Close cleans up resources associated with a LegacyServerEntryIterator.
|
|
|
-func (iterator *LegacyServerEntryIterator) Close() {
|
|
|
+// Close cleans up resources associated with a legacyServerEntryIterator.
|
|
|
+func (iterator *legacyServerEntryIterator) Close() {
|
|
|
if iterator.cursor != nil {
|
|
|
iterator.cursor.Close()
|
|
|
}
|
|
|
@@ -161,23 +133,15 @@ func (iterator *LegacyServerEntryIterator) Close() {
|
|
|
iterator.transaction = nil
|
|
|
}
|
|
|
|
|
|
-// Next returns the next server entry, by rank, for a LegacyServerEntryIterator.
|
|
|
+// Next returns the next server entry, by rank, for a legacyServerEntryIterator.
|
|
|
// Returns nil with no error when there is no next item.
|
|
|
-func (iterator *LegacyServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
|
|
|
+func (iterator *legacyServerEntryIterator) Next() (serverEntry *ServerEntry, err error) {
|
|
|
defer func() {
|
|
|
if err != nil {
|
|
|
iterator.Close()
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- if iterator.isTargetServerEntryIterator {
|
|
|
- if iterator.hasNextTargetServerEntry {
|
|
|
- iterator.hasNextTargetServerEntry = false
|
|
|
- return MakeCompatibleServerEntry(iterator.targetServerEntry), nil
|
|
|
- }
|
|
|
- return nil, nil
|
|
|
- }
|
|
|
-
|
|
|
if !iterator.cursor.Next() {
|
|
|
err = iterator.cursor.Err()
|
|
|
if err != nil {
|
|
|
@@ -201,17 +165,12 @@ func (iterator *LegacyServerEntryIterator) Next() (serverEntry *ServerEntry, err
|
|
|
return MakeCompatibleServerEntry(serverEntry), nil
|
|
|
}
|
|
|
|
|
|
-// Reset a NewLegacyServerEntryIterator to the start of its cycle. The next
|
|
|
+// Reset a NewlegacyServerEntryIterator to the start of its cycle. The next
|
|
|
// call to Next will return the first server entry.
|
|
|
-func (iterator *LegacyServerEntryIterator) Reset() error {
|
|
|
+func (iterator *legacyServerEntryIterator) Reset() error {
|
|
|
iterator.Close()
|
|
|
|
|
|
- if iterator.isTargetServerEntryIterator {
|
|
|
- iterator.hasNextTargetServerEntry = true
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- count := CountLegacyServerEntries(iterator.region, iterator.protocol)
|
|
|
+ count := countLegacyServerEntries(iterator.region, iterator.protocol)
|
|
|
NoticeCandidateServers(iterator.region, iterator.protocol, count)
|
|
|
|
|
|
transaction, err := legacyDb.Begin()
|
|
|
@@ -289,16 +248,15 @@ func makeServerEntryWhereClause(
|
|
|
return whereClause, whereParams
|
|
|
}
|
|
|
|
|
|
-// CountLegacyServerEntries returns a count of stored servers for the
|
|
|
-// specified region and protocol.
|
|
|
-func CountLegacyServerEntries(region, protocol string) int {
|
|
|
+// countLegacyServerEntries returns a count of stored servers for the specified region and protocol.
|
|
|
+func countLegacyServerEntries(region, protocol string) int {
|
|
|
var count int
|
|
|
whereClause, whereParams := makeServerEntryWhereClause(region, protocol, nil)
|
|
|
query := "select count(*) from serverEntry" + whereClause
|
|
|
err := legacyDb.QueryRow(query, whereParams...).Scan(&count)
|
|
|
|
|
|
if err != nil {
|
|
|
- NoticeAlert("CountLegacyServerEntries failed: %s", err)
|
|
|
+ NoticeAlert("countLegacyServerEntries failed: %s", err)
|
|
|
return 0
|
|
|
}
|
|
|
|