|
|
@@ -21,67 +21,49 @@ package clientlib
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
- "flag"
|
|
|
- "fmt"
|
|
|
+ "encoding/json"
|
|
|
"io/ioutil"
|
|
|
"os"
|
|
|
"testing"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-var testDataDirName string
|
|
|
-
|
|
|
-func TestMain(m *testing.M) {
|
|
|
- flag.Parse()
|
|
|
-
|
|
|
- var err error
|
|
|
- testDataDirName, err = ioutil.TempDir("", "psiphon-clientlib-test")
|
|
|
- if err != nil {
|
|
|
- fmt.Printf("TempDir failed: %s\n", err)
|
|
|
- os.Exit(1)
|
|
|
- }
|
|
|
- defer os.RemoveAll(testDataDirName)
|
|
|
+func TestStartTunnel(t *testing.T) {
|
|
|
+ // TODO: More comprehensive tests. This is only a smoke test.
|
|
|
|
|
|
- os.Exit(m.Run())
|
|
|
-}
|
|
|
+ clientPlatform := "clientlib_test.go"
|
|
|
+ networkID := "UNKNOWN"
|
|
|
+ timeout := 60
|
|
|
+ quickTimeout := 1
|
|
|
|
|
|
-func getConfigJSON(t *testing.T) []byte {
|
|
|
configJSON, err := ioutil.ReadFile("../../psiphon/controller_test.config")
|
|
|
if err != nil {
|
|
|
// Skip, don't fail, if config file is not present
|
|
|
t.Skipf("error loading configuration file: %s", err)
|
|
|
}
|
|
|
|
|
|
- return configJSON
|
|
|
-}
|
|
|
-
|
|
|
-func TestStartTunnel(t *testing.T) {
|
|
|
- // TODO: More comprehensive tests. This is only a smoke test.
|
|
|
+ // Initialize a fresh datastore and create a modified config which cannot
|
|
|
+ // connect without known servers, to be used in timeout cases.
|
|
|
|
|
|
- configJSON := getConfigJSON(t)
|
|
|
- clientPlatform := "clientlib_test.go"
|
|
|
- networkID := "UNKNOWN"
|
|
|
- timeout := 60
|
|
|
-
|
|
|
- // Cancels the context after a duration. Pass 0 for no cancel.
|
|
|
- // (Note that cancelling causes an error, not a timeout.)
|
|
|
- contextGetter := func(cancelAfter time.Duration) func() context.Context {
|
|
|
- return func() context.Context {
|
|
|
- if cancelAfter == 0 {
|
|
|
- return context.Background()
|
|
|
- }
|
|
|
+ testDataDirName, err := ioutil.TempDir("", "psiphon-clientlib-test")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("ioutil.TempDir failed: %v", err)
|
|
|
+ }
|
|
|
+ defer os.RemoveAll(testDataDirName)
|
|
|
|
|
|
- ctx, ctxCancel := context.WithCancel(context.Background())
|
|
|
- go func() {
|
|
|
- time.Sleep(cancelAfter)
|
|
|
- ctxCancel()
|
|
|
- }()
|
|
|
- return ctx
|
|
|
- }
|
|
|
+ var config map[string]interface{}
|
|
|
+ err = json.Unmarshal(configJSON, &config)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("json.Unmarshal failed: %v", err)
|
|
|
+ }
|
|
|
+ config["DisableRemoteServerListFetcher"] = true
|
|
|
+ configJSONNoFetcher, err := json.Marshal(config)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("json.Marshal failed: %v", err)
|
|
|
}
|
|
|
|
|
|
type args struct {
|
|
|
- ctxGetter func() context.Context
|
|
|
+ ctxTimeout time.Duration
|
|
|
configJSON []byte
|
|
|
embeddedServerEntryList string
|
|
|
params Parameters
|
|
|
@@ -89,16 +71,16 @@ func TestStartTunnel(t *testing.T) {
|
|
|
noticeReceiver func(NoticeEvent)
|
|
|
}
|
|
|
tests := []struct {
|
|
|
- name string
|
|
|
- args args
|
|
|
- wantTunnel bool
|
|
|
- wantErr bool
|
|
|
+ name string
|
|
|
+ args args
|
|
|
+ wantTunnel bool
|
|
|
+ expectedErr error
|
|
|
}{
|
|
|
{
|
|
|
- name: "Success: simple",
|
|
|
+ name: "Failure: context timeout",
|
|
|
args: args{
|
|
|
- ctxGetter: contextGetter(0),
|
|
|
- configJSON: configJSON,
|
|
|
+ ctxTimeout: 10 * time.Millisecond,
|
|
|
+ configJSON: configJSONNoFetcher,
|
|
|
embeddedServerEntryList: "",
|
|
|
params: Parameters{
|
|
|
DataRootDirectory: &testDataDirName,
|
|
|
@@ -109,13 +91,31 @@ func TestStartTunnel(t *testing.T) {
|
|
|
paramsDelta: nil,
|
|
|
noticeReceiver: nil,
|
|
|
},
|
|
|
- wantTunnel: true,
|
|
|
- wantErr: false,
|
|
|
+ wantTunnel: false,
|
|
|
+ expectedErr: ErrTimeout,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Failure: config timeout",
|
|
|
+ args: args{
|
|
|
+ ctxTimeout: 0,
|
|
|
+ configJSON: configJSONNoFetcher,
|
|
|
+ embeddedServerEntryList: "",
|
|
|
+ params: Parameters{
|
|
|
+ DataRootDirectory: &testDataDirName,
|
|
|
+ ClientPlatform: &clientPlatform,
|
|
|
+ NetworkID: &networkID,
|
|
|
+ EstablishTunnelTimeoutSeconds: &quickTimeout,
|
|
|
+ },
|
|
|
+ paramsDelta: nil,
|
|
|
+ noticeReceiver: nil,
|
|
|
+ },
|
|
|
+ wantTunnel: false,
|
|
|
+ expectedErr: ErrTimeout,
|
|
|
},
|
|
|
{
|
|
|
- name: "Failure: timeout",
|
|
|
+ name: "Success: simple",
|
|
|
args: args{
|
|
|
- ctxGetter: contextGetter(10 * time.Millisecond),
|
|
|
+ ctxTimeout: 0,
|
|
|
configJSON: configJSON,
|
|
|
embeddedServerEntryList: "",
|
|
|
params: Parameters{
|
|
|
@@ -127,26 +127,46 @@ func TestStartTunnel(t *testing.T) {
|
|
|
paramsDelta: nil,
|
|
|
noticeReceiver: nil,
|
|
|
},
|
|
|
- wantTunnel: false,
|
|
|
- wantErr: true,
|
|
|
+ wantTunnel: true,
|
|
|
+ expectedErr: nil,
|
|
|
},
|
|
|
}
|
|
|
for _, tt := range tests {
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
- gotTunnel, err := StartTunnel(tt.args.ctxGetter(),
|
|
|
- tt.args.configJSON, tt.args.embeddedServerEntryList,
|
|
|
- tt.args.params, tt.args.paramsDelta, tt.args.noticeReceiver)
|
|
|
- if (err != nil) != tt.wantErr {
|
|
|
- t.Fatalf("StartTunnel() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
- return
|
|
|
+
|
|
|
+ ctx := context.Background()
|
|
|
+ var cancelFunc context.CancelFunc
|
|
|
+ if tt.args.ctxTimeout > 0 {
|
|
|
+ ctx, cancelFunc = context.WithTimeout(ctx, tt.args.ctxTimeout)
|
|
|
}
|
|
|
- if (gotTunnel != nil) != tt.wantTunnel {
|
|
|
+
|
|
|
+ tunnel, err := StartTunnel(
|
|
|
+ ctx,
|
|
|
+ tt.args.configJSON,
|
|
|
+ tt.args.embeddedServerEntryList,
|
|
|
+ tt.args.params,
|
|
|
+ tt.args.paramsDelta,
|
|
|
+ tt.args.noticeReceiver)
|
|
|
+
|
|
|
+ gotTunnel := (tunnel != nil)
|
|
|
+
|
|
|
+ if cancelFunc != nil {
|
|
|
+ cancelFunc()
|
|
|
+ }
|
|
|
+
|
|
|
+ if tunnel != nil {
|
|
|
+ tunnel.Stop()
|
|
|
+ }
|
|
|
+
|
|
|
+ if gotTunnel != tt.wantTunnel {
|
|
|
t.Errorf("StartTunnel() gotTunnel = %v, wantTunnel %v", err, tt.wantTunnel)
|
|
|
}
|
|
|
|
|
|
- if gotTunnel != nil {
|
|
|
- gotTunnel.Stop()
|
|
|
+ if err != tt.expectedErr {
|
|
|
+ t.Fatalf("StartTunnel() error = %v, expectedErr %v", err, tt.expectedErr)
|
|
|
+ return
|
|
|
}
|
|
|
+
|
|
|
})
|
|
|
}
|
|
|
}
|