فهرست منبع

Interrupt server entry imports when stopping

Fixes both remote and embedded server entry imports blocking stop. Any server
entries already imported when an import is interrupted are retained.

Both types of import will eventually succeed in a future session, given
sufficient time. In the remote case, the resumable download logic should
result in a short request followed by import of the previously downloaded
payload.
Rod Hynes 5 سال پیش
والد
کامیت
d9cb0fc459
5فایلهای تغییر یافته به همراه49 افزوده شده و 25 حذف شده
  1. 12 8
      ClientLibrary/clientlib/clientlib.go
  2. 10 9
      ConsoleClient/main.go
  3. 6 4
      MobileLibrary/psi/psi.go
  4. 19 4
      psiphon/dataStore.go
  5. 2 0
      psiphon/remoteServerList.go

+ 12 - 8
ClientLibrary/clientlib/clientlib.go

@@ -236,6 +236,10 @@ func StartTunnel(ctx context.Context,
 		}
 		}
 	}()
 	}()
 
 
+	// Create a cancelable context that will be used for stopping the tunnel
+	var controllerCtx context.Context
+	controllerCtx, tunnel.stopController = context.WithCancel(ctx)
+
 	// If specified, the embedded server list is loaded and stored. When there
 	// If specified, the embedded server list is loaded and stored. When there
 	// are no server candidates at all, we wait for this import to complete
 	// are no server candidates at all, we wait for this import to complete
 	// before starting the Psiphon controller. Otherwise, we import while
 	// before starting the Psiphon controller. Otherwise, we import while
@@ -246,13 +250,14 @@ func StartTunnel(ctx context.Context,
 	// still started: either existing candidate servers may suffice, or the
 	// still started: either existing candidate servers may suffice, or the
 	// remote server list fetch may obtain candidate servers.
 	// remote server list fetch may obtain candidate servers.
 	//
 	//
-	// TODO: abort import if controller run ctx is cancelled. Currently, this
-	// import will block shutdown.
+	// The import will be interrupted if it's still running when the controller
+	// is stopped.
 	tunnel.embeddedServerListWaitGroup.Add(1)
 	tunnel.embeddedServerListWaitGroup.Add(1)
 	go func() {
 	go func() {
 		defer tunnel.embeddedServerListWaitGroup.Done()
 		defer tunnel.embeddedServerListWaitGroup.Done()
 
 
 		err := psiphon.ImportEmbeddedServerEntries(
 		err := psiphon.ImportEmbeddedServerEntries(
+			controllerCtx,
 			config,
 			config,
 			"",
 			"",
 			embeddedServerEntryList)
 			embeddedServerEntryList)
@@ -269,13 +274,11 @@ func StartTunnel(ctx context.Context,
 	// Create the Psiphon controller
 	// Create the Psiphon controller
 	controller, err := psiphon.NewController(config)
 	controller, err := psiphon.NewController(config)
 	if err != nil {
 	if err != nil {
+		tunnel.stopController()
+		tunnel.embeddedServerListWaitGroup.Wait()
 		return nil, errors.TraceMsg(err, "psiphon.NewController failed")
 		return nil, errors.TraceMsg(err, "psiphon.NewController failed")
 	}
 	}
 
 
-	// Create a cancelable context that will be used for stopping the tunnel
-	var controllerCtx context.Context
-	controllerCtx, tunnel.stopController = context.WithCancel(ctx)
-
 	// Begin tunnel connection
 	// Begin tunnel connection
 	tunnel.controllerWaitGroup.Add(1)
 	tunnel.controllerWaitGroup.Add(1)
 	go func() {
 	go func() {
@@ -306,9 +309,10 @@ func StartTunnel(ctx context.Context,
 // Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
 // Stop stops/disconnects/shuts down the tunnel. It is safe to call when not connected.
 // Not safe to call concurrently with Start.
 // Not safe to call concurrently with Start.
 func (tunnel *PsiphonTunnel) Stop() {
 func (tunnel *PsiphonTunnel) Stop() {
-	if tunnel.stopController != nil {
-		tunnel.stopController()
+	if tunnel.stopController == nil {
+		return
 	}
 	}
+	tunnel.stopController()
 	tunnel.controllerWaitGroup.Wait()
 	tunnel.controllerWaitGroup.Wait()
 	tunnel.embeddedServerListWaitGroup.Wait()
 	tunnel.embeddedServerListWaitGroup.Wait()
 	psiphon.CloseDataStore()
 	psiphon.CloseDataStore()

+ 10 - 9
ConsoleClient/main.go

@@ -248,15 +248,15 @@ func main() {
 		}
 		}
 	}
 	}
 
 
-	err = worker.Init(config)
+	workCtx, stopWork := context.WithCancel(context.Background())
+	defer stopWork()
+
+	err = worker.Init(workCtx, config)
 	if err != nil {
 	if err != nil {
 		psiphon.NoticeError("error in init: %s", err)
 		psiphon.NoticeError("error in init: %s", err)
 		os.Exit(1)
 		os.Exit(1)
 	}
 	}
 
 
-	workCtx, stopWork := context.WithCancel(context.Background())
-	defer stopWork()
-
 	workWaitGroup := new(sync.WaitGroup)
 	workWaitGroup := new(sync.WaitGroup)
 	workWaitGroup.Add(1)
 	workWaitGroup.Add(1)
 	go func() {
 	go func() {
@@ -352,7 +352,7 @@ func (p *tunProvider) GetSecondaryDnsServer() string {
 // compiled executable.
 // compiled executable.
 type Worker interface {
 type Worker interface {
 	// Init is called once for the worker to perform any initialization.
 	// Init is called once for the worker to perform any initialization.
-	Init(config *psiphon.Config) error
+	Init(ctx context.Context, config *psiphon.Config) error
 	// Run is called once, after Init(..), for the worker to perform its
 	// Run is called once, after Init(..), for the worker to perform its
 	// work. The provided context should control the lifetime of the work
 	// work. The provided context should control the lifetime of the work
 	// being performed.
 	// being performed.
@@ -367,7 +367,7 @@ type TunnelWorker struct {
 }
 }
 
 
 // Init implements the Worker interface.
 // Init implements the Worker interface.
-func (w *TunnelWorker) Init(config *psiphon.Config) error {
+func (w *TunnelWorker) Init(ctx context.Context, config *psiphon.Config) error {
 
 
 	// Initialize data store
 	// Initialize data store
 
 
@@ -387,8 +387,8 @@ func (w *TunnelWorker) Init(config *psiphon.Config) error {
 	// still started: either existing candidate servers may suffice, or the
 	// still started: either existing candidate servers may suffice, or the
 	// remote server list fetch may obtain candidate servers.
 	// remote server list fetch may obtain candidate servers.
 	//
 	//
-	// TODO: abort import if controller run ctx is cancelled. Currently, this
-	// import will block shutdown.
+	// The import will be interrupted if it's still running when the controller
+	// is stopped.
 	if w.embeddedServerEntryListFilename != "" {
 	if w.embeddedServerEntryListFilename != "" {
 		w.embeddedServerListWaitGroup = new(sync.WaitGroup)
 		w.embeddedServerListWaitGroup = new(sync.WaitGroup)
 		w.embeddedServerListWaitGroup.Add(1)
 		w.embeddedServerListWaitGroup.Add(1)
@@ -396,6 +396,7 @@ func (w *TunnelWorker) Init(config *psiphon.Config) error {
 			defer w.embeddedServerListWaitGroup.Done()
 			defer w.embeddedServerListWaitGroup.Done()
 
 
 			err = psiphon.ImportEmbeddedServerEntries(
 			err = psiphon.ImportEmbeddedServerEntries(
+				ctx,
 				config,
 				config,
 				w.embeddedServerEntryListFilename,
 				w.embeddedServerEntryListFilename,
 				"")
 				"")
@@ -441,7 +442,7 @@ type FeedbackWorker struct {
 }
 }
 
 
 // Init implements the Worker interface.
 // Init implements the Worker interface.
-func (f *FeedbackWorker) Init(config *psiphon.Config) error {
+func (f *FeedbackWorker) Init(ctx context.Context, config *psiphon.Config) error {
 
 
 	// The datastore is not opened here, with psiphon.OpenDatastore,
 	// The datastore is not opened here, with psiphon.OpenDatastore,
 	// because it is opened/closed transiently in the psiphon.SendFeedback
 	// because it is opened/closed transiently in the psiphon.SendFeedback

+ 6 - 4
MobileLibrary/psi/psi.go

@@ -184,6 +184,8 @@ func Start(
 		return fmt.Errorf("error initializing datastore: %s", err)
 		return fmt.Errorf("error initializing datastore: %s", err)
 	}
 	}
 
 
+	controllerCtx, stopController = context.WithCancel(context.Background())
+
 	// If specified, the embedded server list is loaded and stored. When there
 	// If specified, the embedded server list is loaded and stored. When there
 	// are no server candidates at all, we wait for this import to complete
 	// are no server candidates at all, we wait for this import to complete
 	// before starting the Psiphon controller. Otherwise, we import while
 	// before starting the Psiphon controller. Otherwise, we import while
@@ -194,14 +196,15 @@ func Start(
 	// still started: either existing candidate servers may suffice, or the
 	// still started: either existing candidate servers may suffice, or the
 	// remote server list fetch may obtain candidate servers.
 	// remote server list fetch may obtain candidate servers.
 	//
 	//
-	// TODO: abort import if controller run ctx is cancelled. Currently, this
-	// import will block shutdown.
+	// The import will be interrupted if it's still running when the controller
+	// is stopped.
 	embeddedServerListWaitGroup = new(sync.WaitGroup)
 	embeddedServerListWaitGroup = new(sync.WaitGroup)
 	embeddedServerListWaitGroup.Add(1)
 	embeddedServerListWaitGroup.Add(1)
 	go func() {
 	go func() {
 		defer embeddedServerListWaitGroup.Done()
 		defer embeddedServerListWaitGroup.Done()
 
 
 		err := psiphon.ImportEmbeddedServerEntries(
 		err := psiphon.ImportEmbeddedServerEntries(
+			controllerCtx,
 			config,
 			config,
 			embeddedServerEntryListFilename,
 			embeddedServerEntryListFilename,
 			embeddedServerEntryList)
 			embeddedServerEntryList)
@@ -217,13 +220,12 @@ func Start(
 
 
 	controller, err = psiphon.NewController(config)
 	controller, err = psiphon.NewController(config)
 	if err != nil {
 	if err != nil {
+		stopController()
 		embeddedServerListWaitGroup.Wait()
 		embeddedServerListWaitGroup.Wait()
 		psiphon.CloseDataStore()
 		psiphon.CloseDataStore()
 		return fmt.Errorf("error initializing controller: %s", err)
 		return fmt.Errorf("error initializing controller: %s", err)
 	}
 	}
 
 
-	controllerCtx, stopController = context.WithCancel(context.Background())
-
 	controllerWaitGroup = new(sync.WaitGroup)
 	controllerWaitGroup = new(sync.WaitGroup)
 	controllerWaitGroup.Add(1)
 	controllerWaitGroup.Add(1)
 	go func() {
 	go func() {

+ 19 - 4
psiphon/dataStore.go

@@ -21,6 +21,7 @@ package psiphon
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"context"
 	"encoding/json"
 	"encoding/json"
 	"math"
 	"math"
 	"os"
 	"os"
@@ -362,9 +363,12 @@ func StoreServerEntries(
 	return nil
 	return nil
 }
 }
 
 
-// StreamingStoreServerEntries stores a list of server entries.
-// There is an independent transaction for each entry insert/update.
+// StreamingStoreServerEntries stores a list of server entries. There is an
+// independent transaction for each entry insert/update.
+// StreamingStoreServerEntries stops early and returns an error if ctx becomes
+// done; any server entries stored up to that point are retained.
 func StreamingStoreServerEntries(
 func StreamingStoreServerEntries(
+	ctx context.Context,
 	config *Config,
 	config *Config,
 	serverEntries *protocol.StreamingServerEntryDecoder,
 	serverEntries *protocol.StreamingServerEntryDecoder,
 	replaceIfExists bool) error {
 	replaceIfExists bool) error {
@@ -376,6 +380,13 @@ func StreamingStoreServerEntries(
 
 
 	n := 0
 	n := 0
 	for {
 	for {
+
+		select {
+		case <-ctx.Done():
+			return errors.Trace(ctx.Err())
+		default:
+		}
+
 		serverEntry, err := serverEntries.Next()
 		serverEntry, err := serverEntries.Next()
 		if err != nil {
 		if err != nil {
 			return errors.Trace(err)
 			return errors.Trace(err)
@@ -383,7 +394,7 @@ func StreamingStoreServerEntries(
 
 
 		if serverEntry == nil {
 		if serverEntry == nil {
 			// No more server entries
 			// No more server entries
-			break
+			return nil
 		}
 		}
 
 
 		err = StoreServerEntry(serverEntry, replaceIfExists)
 		err = StoreServerEntry(serverEntry, replaceIfExists)
@@ -404,8 +415,11 @@ func StreamingStoreServerEntries(
 // ImportEmbeddedServerEntries loads, decodes, and stores a list of server
 // ImportEmbeddedServerEntries loads, decodes, and stores a list of server
 // entries. If embeddedServerEntryListFilename is not empty,
 // entries. If embeddedServerEntryListFilename is not empty,
 // embeddedServerEntryList will be ignored and the encoded server entry list
 // embeddedServerEntryList will be ignored and the encoded server entry list
-// will be loaded from the specified file.
+// will be loaded from the specified file. The import process stops early if
+// ctx becomes done; any server entries imported up to that point are
+// retained.
 func ImportEmbeddedServerEntries(
 func ImportEmbeddedServerEntries(
+	ctx context.Context,
 	config *Config,
 	config *Config,
 	embeddedServerEntryListFilename string,
 	embeddedServerEntryListFilename string,
 	embeddedServerEntryList string) error {
 	embeddedServerEntryList string) error {
@@ -419,6 +433,7 @@ func ImportEmbeddedServerEntries(
 		defer file.Close()
 		defer file.Close()
 
 
 		err = StreamingStoreServerEntries(
 		err = StreamingStoreServerEntries(
+			ctx,
 			config,
 			config,
 			protocol.NewStreamingServerEntryDecoder(
 			protocol.NewStreamingServerEntryDecoder(
 				file,
 				file,

+ 2 - 0
psiphon/remoteServerList.go

@@ -104,6 +104,7 @@ func FetchCommonRemoteServerList(
 	authenticatedDownload = true
 	authenticatedDownload = true
 
 
 	err = StreamingStoreServerEntries(
 	err = StreamingStoreServerEntries(
+		ctx,
 		config,
 		config,
 		protocol.NewStreamingServerEntryDecoder(
 		protocol.NewStreamingServerEntryDecoder(
 			serverListPayloadReader,
 			serverListPayloadReader,
@@ -383,6 +384,7 @@ func downloadOSLFileSpec(
 	authenticatedDownload = true
 	authenticatedDownload = true
 
 
 	err = StreamingStoreServerEntries(
 	err = StreamingStoreServerEntries(
+		ctx,
 		config,
 		config,
 		protocol.NewStreamingServerEntryDecoder(
 		protocol.NewStreamingServerEntryDecoder(
 			serverListPayloadReader,
 			serverListPayloadReader,