Эх сурвалжийг харах

Fix: return ErrTimeout on timeouts

- StartTunnel was returning "exited unexpectedly" in timeout
  cases. Now returns ErrTimeout when EstablishTunnelTimeout is
  posted or the input Context times out.

- Cover timeout cases with tests that check the error return value.

- Use buffered channels to ensure delivery of event signals.
Rod Hynes 5 жил өмнө
parent
commit
96e2a838dc

+ 29 - 14
ClientLibrary/clientlib/clientlib.go

@@ -114,9 +114,12 @@ var ErrTimeout = std_errors.New("clientlib: tunnel establishment timeout")
 //
 //
 // noticeReceiver, if non-nil, will be called for each notice emitted by tunnel core.
 // noticeReceiver, if non-nil, will be called for each notice emitted by tunnel core.
 // NOTE: Ordinary users of this library should never need this and should pass nil.
 // NOTE: Ordinary users of this library should never need this and should pass nil.
-func StartTunnel(ctx context.Context,
-	configJSON []byte, embeddedServerEntryList string,
-	params Parameters, paramsDelta ParametersDelta,
+func StartTunnel(
+	ctx context.Context,
+	configJSON []byte,
+	embeddedServerEntryList string,
+	params Parameters,
+	paramsDelta ParametersDelta,
 	noticeReceiver func(NoticeEvent)) (retTunnel *PsiphonTunnel, retErr error) {
 	noticeReceiver func(NoticeEvent)) (retTunnel *PsiphonTunnel, retErr error) {
 
 
 	config, err := psiphon.LoadConfig(configJSON)
 	config, err := psiphon.LoadConfig(configJSON)
@@ -170,11 +173,9 @@ func StartTunnel(ctx context.Context,
 	}
 	}
 
 
 	// Will receive a value when the tunnel has successfully connected.
 	// Will receive a value when the tunnel has successfully connected.
-	connected := make(chan struct{})
-	// Will receive a value if the tunnel times out trying to connect.
-	timedOut := make(chan struct{})
+	connected := make(chan struct{}, 1)
 	// Will receive a value if an error occurs during the connection sequence.
 	// Will receive a value if an error occurs during the connection sequence.
-	errored := make(chan error)
+	errored := make(chan error, 1)
 
 
 	// Create the tunnel object
 	// Create the tunnel object
 	tunnel := new(PsiphonTunnel)
 	tunnel := new(PsiphonTunnel)
@@ -203,7 +204,7 @@ func StartTunnel(ctx context.Context,
 				tunnel.SOCKSProxyPort = int(port)
 				tunnel.SOCKSProxyPort = int(port)
 			} else if event.Type == "EstablishTunnelTimeout" {
 			} else if event.Type == "EstablishTunnelTimeout" {
 				select {
 				select {
-				case timedOut <- struct{}{}:
+				case errored <- ErrTimeout:
 				default:
 				default:
 				}
 				}
 			} else if event.Type == "Tunnels" {
 			} else if event.Type == "Tunnels" {
@@ -287,22 +288,36 @@ func StartTunnel(ctx context.Context,
 		// Start the tunnel. Only returns on error (or internal timeout).
 		// Start the tunnel. Only returns on error (or internal timeout).
 		controller.Run(controllerCtx)
 		controller.Run(controllerCtx)
 
 
+		// controller.Run does not exit until the goroutine that posts
+		// EstablishTunnelTimeout has terminated; so, if there was a
+		// EstablishTunnelTimeout event, ErrTimeout is guaranteed to be sent to
+		// errord before this next error and will be the StartTunnel return value.
+
+		var err error
+		switch ctx.Err() {
+		case context.DeadlineExceeded:
+			err = ErrTimeout
+		case context.Canceled:
+			err = errors.TraceNew("StartTunnel canceled")
+		default:
+			err = errors.TraceNew("controller.Run exited unexpectedly")
+		}
 		select {
 		select {
-		case errored <- errors.TraceNew("controller.Run exited unexpectedly"):
+		case errored <- err:
 		default:
 		default:
 		}
 		}
 	}()
 	}()
 
 
-	// Wait for an active tunnel, timeout, or error
+	// Wait for an active tunnel or error
 	select {
 	select {
 	case <-connected:
 	case <-connected:
 		return tunnel, nil
 		return tunnel, nil
-	case <-timedOut:
-		tunnel.Stop()
-		return nil, ErrTimeout
 	case err := <-errored:
 	case err := <-errored:
 		tunnel.Stop()
 		tunnel.Stop()
-		return nil, errors.TraceMsg(err, "tunnel start produced error")
+		if err != ErrTimeout {
+			err = errors.TraceMsg(err, "tunnel start produced error")
+		}
+		return nil, err
 	}
 	}
 }
 }
 
 

+ 84 - 64
ClientLibrary/clientlib/clientlib_test.go

@@ -21,67 +21,49 @@ package clientlib
 
 
 import (
 import (
 	"context"
 	"context"
-	"flag"
-	"fmt"
+	"encoding/json"
 	"io/ioutil"
 	"io/ioutil"
 	"os"
 	"os"
 	"testing"
 	"testing"
 	"time"
 	"time"
 )
 )
 
 
-var testDataDirName string
-
-func TestMain(m *testing.M) {
-	flag.Parse()
-
-	var err error
-	testDataDirName, err = ioutil.TempDir("", "psiphon-clientlib-test")
-	if err != nil {
-		fmt.Printf("TempDir failed: %s\n", err)
-		os.Exit(1)
-	}
-	defer os.RemoveAll(testDataDirName)
+func TestStartTunnel(t *testing.T) {
+	// TODO: More comprehensive tests. This is only a smoke test.
 
 
-	os.Exit(m.Run())
-}
+	clientPlatform := "clientlib_test.go"
+	networkID := "UNKNOWN"
+	timeout := 60
+	quickTimeout := 1
 
 
-func getConfigJSON(t *testing.T) []byte {
 	configJSON, err := ioutil.ReadFile("../../psiphon/controller_test.config")
 	configJSON, err := ioutil.ReadFile("../../psiphon/controller_test.config")
 	if err != nil {
 	if err != nil {
 		// Skip, don't fail, if config file is not present
 		// Skip, don't fail, if config file is not present
 		t.Skipf("error loading configuration file: %s", err)
 		t.Skipf("error loading configuration file: %s", err)
 	}
 	}
 
 
-	return configJSON
-}
-
-func TestStartTunnel(t *testing.T) {
-	// TODO: More comprehensive tests. This is only a smoke test.
+	// Initialize a fresh datastore and create a modified config which cannot
+	// connect without known servers, to be used in timeout cases.
 
 
-	configJSON := getConfigJSON(t)
-	clientPlatform := "clientlib_test.go"
-	networkID := "UNKNOWN"
-	timeout := 60
-
-	// Cancels the context after a duration. Pass 0 for no cancel.
-	// (Note that cancelling causes an error, not a timeout.)
-	contextGetter := func(cancelAfter time.Duration) func() context.Context {
-		return func() context.Context {
-			if cancelAfter == 0 {
-				return context.Background()
-			}
+	testDataDirName, err := ioutil.TempDir("", "psiphon-clientlib-test")
+	if err != nil {
+		t.Fatalf("ioutil.TempDir failed: %v", err)
+	}
+	defer os.RemoveAll(testDataDirName)
 
 
-			ctx, ctxCancel := context.WithCancel(context.Background())
-			go func() {
-				time.Sleep(cancelAfter)
-				ctxCancel()
-			}()
-			return ctx
-		}
+	var config map[string]interface{}
+	err = json.Unmarshal(configJSON, &config)
+	if err != nil {
+		t.Fatalf("json.Unmarshal failed: %v", err)
+	}
+	config["DisableRemoteServerListFetcher"] = true
+	configJSONNoFetcher, err := json.Marshal(config)
+	if err != nil {
+		t.Fatalf("json.Marshal failed: %v", err)
 	}
 	}
 
 
 	type args struct {
 	type args struct {
-		ctxGetter               func() context.Context
+		ctxTimeout              time.Duration
 		configJSON              []byte
 		configJSON              []byte
 		embeddedServerEntryList string
 		embeddedServerEntryList string
 		params                  Parameters
 		params                  Parameters
@@ -89,16 +71,16 @@ func TestStartTunnel(t *testing.T) {
 		noticeReceiver          func(NoticeEvent)
 		noticeReceiver          func(NoticeEvent)
 	}
 	}
 	tests := []struct {
 	tests := []struct {
-		name       string
-		args       args
-		wantTunnel bool
-		wantErr    bool
+		name        string
+		args        args
+		wantTunnel  bool
+		expectedErr error
 	}{
 	}{
 		{
 		{
-			name: "Success: simple",
+			name: "Failure: context timeout",
 			args: args{
 			args: args{
-				ctxGetter:               contextGetter(0),
-				configJSON:              configJSON,
+				ctxTimeout:              10 * time.Millisecond,
+				configJSON:              configJSONNoFetcher,
 				embeddedServerEntryList: "",
 				embeddedServerEntryList: "",
 				params: Parameters{
 				params: Parameters{
 					DataRootDirectory:             &testDataDirName,
 					DataRootDirectory:             &testDataDirName,
@@ -109,13 +91,31 @@ func TestStartTunnel(t *testing.T) {
 				paramsDelta:    nil,
 				paramsDelta:    nil,
 				noticeReceiver: nil,
 				noticeReceiver: nil,
 			},
 			},
-			wantTunnel: true,
-			wantErr:    false,
+			wantTunnel:  false,
+			expectedErr: ErrTimeout,
+		},
+		{
+			name: "Failure: config timeout",
+			args: args{
+				ctxTimeout:              0,
+				configJSON:              configJSONNoFetcher,
+				embeddedServerEntryList: "",
+				params: Parameters{
+					DataRootDirectory:             &testDataDirName,
+					ClientPlatform:                &clientPlatform,
+					NetworkID:                     &networkID,
+					EstablishTunnelTimeoutSeconds: &quickTimeout,
+				},
+				paramsDelta:    nil,
+				noticeReceiver: nil,
+			},
+			wantTunnel:  false,
+			expectedErr: ErrTimeout,
 		},
 		},
 		{
 		{
-			name: "Failure: timeout",
+			name: "Success: simple",
 			args: args{
 			args: args{
-				ctxGetter:               contextGetter(10 * time.Millisecond),
+				ctxTimeout:              0,
 				configJSON:              configJSON,
 				configJSON:              configJSON,
 				embeddedServerEntryList: "",
 				embeddedServerEntryList: "",
 				params: Parameters{
 				params: Parameters{
@@ -127,26 +127,46 @@ func TestStartTunnel(t *testing.T) {
 				paramsDelta:    nil,
 				paramsDelta:    nil,
 				noticeReceiver: nil,
 				noticeReceiver: nil,
 			},
 			},
-			wantTunnel: false,
-			wantErr:    true,
+			wantTunnel:  true,
+			expectedErr: nil,
 		},
 		},
 	}
 	}
 	for _, tt := range tests {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
-			gotTunnel, err := StartTunnel(tt.args.ctxGetter(),
-				tt.args.configJSON, tt.args.embeddedServerEntryList,
-				tt.args.params, tt.args.paramsDelta, tt.args.noticeReceiver)
-			if (err != nil) != tt.wantErr {
-				t.Fatalf("StartTunnel() error = %v, wantErr %v", err, tt.wantErr)
-				return
+
+			ctx := context.Background()
+			var cancelFunc context.CancelFunc
+			if tt.args.ctxTimeout > 0 {
+				ctx, cancelFunc = context.WithTimeout(ctx, tt.args.ctxTimeout)
 			}
 			}
-			if (gotTunnel != nil) != tt.wantTunnel {
+
+			tunnel, err := StartTunnel(
+				ctx,
+				tt.args.configJSON,
+				tt.args.embeddedServerEntryList,
+				tt.args.params,
+				tt.args.paramsDelta,
+				tt.args.noticeReceiver)
+
+			gotTunnel := (tunnel != nil)
+
+			if cancelFunc != nil {
+				cancelFunc()
+			}
+
+			if tunnel != nil {
+				tunnel.Stop()
+			}
+
+			if gotTunnel != tt.wantTunnel {
 				t.Errorf("StartTunnel() gotTunnel = %v, wantTunnel %v", err, tt.wantTunnel)
 				t.Errorf("StartTunnel() gotTunnel = %v, wantTunnel %v", err, tt.wantTunnel)
 			}
 			}
 
 
-			if gotTunnel != nil {
-				gotTunnel.Stop()
+			if err != tt.expectedErr {
+				t.Fatalf("StartTunnel() error = %v, expectedErr %v", err, tt.expectedErr)
+				return
 			}
 			}
+
 		})
 		})
 	}
 	}
 }
 }