Просмотр исходного кода

Completion of untunneled upgrade download

* Automated tests for tunneled, untunneled
  and resumed upgrade download.
* Added "disruptor" proxy capability to
  controller test.
* Fix: available upgrade version from HEAD
  request was blank due to variable shadowing.
Rod Hynes 10 лет назад
Родитель
Сommit
4bc185afa4
5 измененных файлов с 274 добавлено и 55 удалено
  1. 1 1
      .travis.yml
  2. BIN
      psiphon/controller_test.config.enc
  3. 258 47
      psiphon/controller_test.go
  4. 5 0
      psiphon/notice.go
  5. 10 7
      psiphon/upgradeDownload.go

+ 1 - 1
.travis.yml

@@ -16,5 +16,5 @@ before_install:
 - go get github.com/axw/gocov/gocov
 - go get github.com/mattn/goveralls
 - if ! go get github.com/golang/tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
-- openssl aes-256-cbc -K $encrypted_9e40808ea1e2_key -iv $encrypted_9e40808ea1e2_iv
+- openssl aes-256-cbc -K $encrypted_ae0fe824cc69_key -iv $encrypted_ae0fe824cc69_iv
   -in psiphon/controller_test.config.enc -out psiphon/controller_test.config -d

BIN
psiphon/controller_test.config.enc


+ 258 - 47
psiphon/controller_test.go

@@ -20,16 +20,45 @@
 package psiphon
 
 import (
+	"flag"
 	"fmt"
+	"io"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"net/url"
+	"os"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
+
+	socks "github.com/Psiphon-Inc/goptlib"
 )
 
+func TestMain(m *testing.M) {
+	flag.Parse()
+	os.Remove(DATA_STORE_FILENAME)
+	initDisruptor()
+	setEmitDiagnosticNotices(true)
+	os.Exit(m.Run())
+}
+
+// Note: untunneled upgrade tests must execute before
+// the "Run" tests to ensure no tunnel is established.
+// We need a way to reset the datastore after it's been
+// initialized in order to to clear out its data entries
+// and be able to arbitrarily order the tests.
+
+func TestUntunneledUpgradeDownload(t *testing.T) {
+	doUntunnledUpgradeDownload(t, false)
+}
+
+func TestUntunneledResumableUpgradeDownload(t *testing.T) {
+	doUntunnledUpgradeDownload(t, true)
+}
+
 func TestControllerRunSSH(t *testing.T) {
 	controllerRun(t, TUNNEL_PROTOCOL_SSH, false)
 }
@@ -54,6 +83,120 @@ func TestControllerRunUnfrontedMeekHTTPS(t *testing.T) {
 	controllerRun(t, TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS, true)
 }
 
+func doUntunnledUpgradeDownload(t *testing.T, disrupt bool) {
+
+	configFileContents, err := ioutil.ReadFile("controller_test.config")
+	if err != nil {
+		// Skip, don't fail, if config file is not present
+		t.Skipf("error loading configuration file: %s", err)
+	}
+	config, err := LoadConfig(configFileContents)
+	if err != nil {
+		t.Fatalf("error processing configuration file: %s", err)
+	}
+
+	if disrupt {
+		config.UpstreamProxyUrl = disruptorProxyURL
+	}
+
+	// Clear remote server list so tunnel cannot be established and
+	// untunneled upgrade download case is tested.
+	config.RemoteServerListUrl = ""
+
+	os.Remove(config.UpgradeDownloadFilename)
+
+	err = InitDataStore(config)
+	if err != nil {
+		t.Fatalf("error initializing datastore: %s", err)
+	}
+
+	controller, err := NewController(config)
+	if err != nil {
+		t.Fatalf("error creating controller: %s", err)
+	}
+
+	upgradeDownloaded := make(chan struct{}, 1)
+
+	var clientUpgradeDownloadedBytesCount int32
+
+	SetNoticeOutput(NewNoticeReceiver(
+		func(notice []byte) {
+			// TODO: log notices without logging server IPs:
+			// fmt.Fprintf(os.Stderr, "%s\n", string(notice))
+			noticeType, payload, err := GetNotice(notice)
+			if err != nil {
+				return
+			}
+			switch noticeType {
+			case "Tunnels":
+				count := int(payload["count"].(float64))
+				if count > 0 {
+					// TODO: wrong goroutine for t.FatalNow()
+					t.Fatalf("tunnel established unexpectedly")
+				}
+			case "ClientUpgradeDownloadedBytes":
+				atomic.AddInt32(&clientUpgradeDownloadedBytesCount, 1)
+				t.Logf("ClientUpgradeDownloadedBytes: %d", int(payload["bytes"].(float64)))
+			case "ClientUpgradeDownloaded":
+				select {
+				case upgradeDownloaded <- *new(struct{}):
+				default:
+				}
+			}
+		}))
+
+	// Run controller
+
+	shutdownBroadcast := make(chan struct{})
+	controllerWaitGroup := new(sync.WaitGroup)
+	controllerWaitGroup.Add(1)
+	go func() {
+		defer controllerWaitGroup.Done()
+		controller.Run(shutdownBroadcast)
+	}()
+
+	defer func() {
+		// Test: shutdown must complete within 10 seconds
+
+		close(shutdownBroadcast)
+
+		shutdownTimeout := time.NewTimer(10 * time.Second)
+
+		shutdownOk := make(chan struct{}, 1)
+		go func() {
+			controllerWaitGroup.Wait()
+			shutdownOk <- *new(struct{})
+		}()
+
+		select {
+		case <-shutdownOk:
+		case <-shutdownTimeout.C:
+			t.Fatalf("controller shutdown timeout exceeded")
+		}
+	}()
+
+	// Test: upgrade must be downloaded within 120 seconds
+
+	downloadTimeout := time.NewTimer(120 * time.Second)
+
+	select {
+	case <-upgradeDownloaded:
+		// TODO: verify downloaded file
+
+	case <-downloadTimeout.C:
+		t.Fatalf("upgrade download timeout exceeded")
+	}
+
+	// Test: with disrupt, must be multiple download progress notices
+
+	if disrupt {
+		count := atomic.LoadInt32(&clientUpgradeDownloadedBytesCount)
+		if count <= 1 {
+			t.Fatalf("unexpected upgrade download progress: %d", count)
+		}
+	}
+}
+
 type TestHostNameTransformer struct {
 }
 
@@ -78,23 +221,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
 	}
 	config, err := LoadConfig(configFileContents)
 	if err != nil {
-		t.Errorf("error processing configuration file: %s", err)
-		t.FailNow()
+		t.Fatalf("error processing configuration file: %s", err)
 	}
+
+	// Disable untunneled upgrade downloader to ensure tunneled case is tested
+	config.UpgradeDownloadClientVersionHeader = ""
+
+	os.Remove(config.UpgradeDownloadFilename)
+
 	config.TunnelProtocol = protocol
 
 	config.HostNameTransformer = hostNameTransformer
 
 	err = InitDataStore(config)
 	if err != nil {
-		t.Errorf("error initializing datastore: %s", err)
-		t.FailNow()
+		t.Fatalf("error initializing datastore: %s", err)
 	}
 
 	controller, err := NewController(config)
 	if err != nil {
-		t.Errorf("error creating controller: %s", err)
-		t.FailNow()
+		t.Fatalf("error creating controller: %s", err)
 	}
 
 	// Monitor notices for "Tunnels" with count > 1, the
@@ -105,6 +251,8 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
 	httpProxyPort := 0
 
 	tunnelEstablished := make(chan struct{}, 1)
+	upgradeDownloaded := make(chan struct{}, 1)
+
 	SetNoticeOutput(NewNoticeReceiver(
 		func(notice []byte) {
 			// TODO: log notices without logging server IPs:
@@ -122,13 +270,20 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
 					default:
 					}
 				}
+			case "ClientUpgradeDownloadedBytes":
+				t.Logf("ClientUpgradeDownloadedBytes: %d", int(payload["bytes"].(float64)))
+			case "ClientUpgradeDownloaded":
+				select {
+				case upgradeDownloaded <- *new(struct{}):
+				default:
+				}
 			case "ListeningHttpProxyPort":
 				httpProxyPort = int(payload["port"].(float64))
 			case "ConnectingServer":
 				serverProtocol := payload["protocol"]
 				if serverProtocol != protocol {
-					t.Errorf("wrong protocol selected: %s", serverProtocol)
-					t.FailNow()
+					// TODO: wrong goroutine for t.FatalNow()
+					t.Fatalf("wrong protocol selected: %s", serverProtocol)
 				}
 			}
 		}))
@@ -143,6 +298,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
 		controller.Run(shutdownBroadcast)
 	}()
 
+	defer func() {
+		// Test: shutdown must complete within 10 seconds
+
+		close(shutdownBroadcast)
+
+		shutdownTimeout := time.NewTimer(10 * time.Second)
+
+		shutdownOk := make(chan struct{}, 1)
+		go func() {
+			controllerWaitGroup.Wait()
+			shutdownOk <- *new(struct{})
+		}()
+
+		select {
+		case <-shutdownOk:
+		case <-shutdownTimeout.C:
+			t.Fatalf("controller shutdown timeout exceeded")
+		}
+	}()
+
 	// Test: tunnel must be established within 60 seconds
 
 	establishTimeout := time.NewTimer(60 * time.Second)
@@ -150,33 +325,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
 	select {
 	case <-tunnelEstablished:
 
-		// Allow for known race condition described in NewHttpProxy():
-		time.Sleep(1 * time.Second)
-
-		// Test: fetch website through tunnel
-		fetchWebsite(t, httpProxyPort)
-
 	case <-establishTimeout.C:
-		t.Errorf("tunnel establish timeout exceeded")
-		// ...continue with cleanup
+		t.Fatalf("tunnel establish timeout exceeded")
 	}
 
-	close(shutdownBroadcast)
+	// Allow for known race condition described in NewHttpProxy():
+	time.Sleep(1 * time.Second)
 
-	// Test: shutdown must complete within 10 seconds
+	// Test: fetch website through tunnel
+	fetchWebsite(t, httpProxyPort)
 
-	shutdownTimeout := time.NewTimer(10 * time.Second)
+	// Test: upgrade must be downloaded within 60 seconds
 
-	shutdownOk := make(chan struct{}, 1)
-	go func() {
-		controllerWaitGroup.Wait()
-		shutdownOk <- *new(struct{})
-	}()
+	downloadTimeout := time.NewTimer(60 * time.Second)
 
 	select {
-	case <-shutdownOk:
-	case <-shutdownTimeout.C:
-		t.Errorf("controller shutdown timeout exceeded")
+	case <-upgradeDownloaded:
+		// TODO: verify downloaded file
+
+	case <-downloadTimeout.C:
+		t.Fatalf("upgrade download timeout exceeded")
 	}
 }
 
@@ -194,8 +362,7 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
 
 	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", httpProxyPort))
 	if err != nil {
-		t.Errorf("error initializing proxied HTTP request: %s", err)
-		t.FailNow()
+		t.Fatalf("error initializing proxied HTTP request: %s", err)
 	}
 
 	httpClient := &http.Client{
@@ -207,20 +374,17 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
 
 	response, err := httpClient.Get(testUrl)
 	if err != nil {
-		t.Errorf("error sending proxied HTTP request: %s", err)
-		t.FailNow()
+		t.Fatalf("error sending proxied HTTP request: %s", err)
 	}
 
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
-		t.Errorf("error reading proxied HTTP response: %s", err)
-		t.FailNow()
+		t.Fatalf("error reading proxied HTTP response: %s", err)
 	}
 	response.Body.Close()
 
 	if !checkResponse(string(body)) {
-		t.Errorf("unexpected proxied HTTP response")
-		t.FailNow()
+		t.Fatalf("unexpected proxied HTTP response")
 	}
 
 	// Test: use direct URL proxy
@@ -234,20 +398,17 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
 		fmt.Sprintf("http://127.0.0.1:%d/direct/%s",
 			httpProxyPort, url.QueryEscape(testUrl)))
 	if err != nil {
-		t.Errorf("error sending direct URL request: %s", err)
-		t.FailNow()
+		t.Fatalf("error sending direct URL request: %s", err)
 	}
 
 	body, err = ioutil.ReadAll(response.Body)
 	if err != nil {
-		t.Errorf("error reading direct URL response: %s", err)
-		t.FailNow()
+		t.Fatalf("error reading direct URL response: %s", err)
 	}
 	response.Body.Close()
 
 	if !checkResponse(string(body)) {
-		t.Errorf("unexpected direct URL response")
-		t.FailNow()
+		t.Fatalf("unexpected direct URL response")
 	}
 
 	// Test: use tunneled URL proxy
@@ -256,19 +417,69 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
 		fmt.Sprintf("http://127.0.0.1:%d/tunneled/%s",
 			httpProxyPort, url.QueryEscape(testUrl)))
 	if err != nil {
-		t.Errorf("error sending tunneled URL request: %s", err)
-		t.FailNow()
+		t.Fatalf("error sending tunneled URL request: %s", err)
 	}
 
 	body, err = ioutil.ReadAll(response.Body)
 	if err != nil {
-		t.Errorf("error reading tunneled URL response: %s", err)
-		t.FailNow()
+		t.Fatalf("error reading tunneled URL response: %s", err)
 	}
 	response.Body.Close()
 
 	if !checkResponse(string(body)) {
-		t.Errorf("unexpected tunneled URL response")
-		t.FailNow()
+		t.Fatalf("unexpected tunneled URL response")
 	}
 }
+
+const disruptorProxyAddress = "127.0.0.1:2160"
+const disruptorProxyURL = "socks4a://" + disruptorProxyAddress
+const disruptorMaxConnectionBytes = 2000000
+const disruptorMaxConnectionTime = 15 * time.Second
+
+func initDisruptor() {
+
+	go func() {
+		listener, err := socks.ListenSocks("tcp", disruptorProxyAddress)
+		if err != nil {
+			fmt.Errorf("disruptor proxy listen error: %s", err)
+			return
+		}
+		for {
+			localConn, err := listener.AcceptSocks()
+			if err != nil {
+				fmt.Errorf("disruptor proxy accept error: %s", err)
+				return
+			}
+			go func() {
+				defer localConn.Close()
+				remoteConn, err := net.Dial("tcp", localConn.Req.Target)
+				if err != nil {
+					fmt.Errorf("disruptor proxy dial error: %s", err)
+					return
+				}
+				defer remoteConn.Close()
+				err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
+				if err != nil {
+					fmt.Errorf("disruptor proxy grant error: %s", err)
+					return
+				}
+
+				// Cut connection after disruptorMaxConnectionTime
+				time.AfterFunc(disruptorMaxConnectionTime, func() {
+					localConn.Close()
+					remoteConn.Close()
+				})
+
+				// Relay connection, but only up to disruptorMaxConnectionBytes
+				waitGroup := new(sync.WaitGroup)
+				waitGroup.Add(1)
+				go func() {
+					defer waitGroup.Done()
+					io.CopyN(localConn, remoteConn, disruptorMaxConnectionBytes)
+				}()
+				io.CopyN(remoteConn, localConn, disruptorMaxConnectionBytes)
+				waitGroup.Wait()
+			}()
+		}
+	}()
+}

+ 5 - 0
psiphon/notice.go

@@ -228,6 +228,11 @@ func NoticeUpstreamProxyError(err error) {
 	outputNotice("UpstreamProxyError", false, true, "message", err.Error())
 }
 
+// NoticeClientUpgradeDownloadedBytes reports client upgrade download progress.
+func NoticeClientUpgradeDownloadedBytes(bytes int64) {
+	outputNotice("ClientUpgradeDownloadedBytes", false, false, "bytes", bytes)
+}
+
 // NoticeClientUpgradeDownloaded indicates that a client upgrade download
 // is complete and available at the destination specified.
 func NoticeClientUpgradeDownloaded(filename string) {

+ 10 - 7
psiphon/upgradeDownload.go

@@ -112,8 +112,8 @@ func DownloadUpgrade(
 
 		// Note: if the header is missing, Header.Get returns "" and then
 		// strconv.Atoi returns a parse error.
-		headerValue := response.Header.Get(config.UpgradeDownloadClientVersionHeader)
-		availableClientVersion, err := strconv.Atoi(headerValue)
+		availableClientVersion := response.Header.Get(config.UpgradeDownloadClientVersionHeader)
+		checkAvailableClientVersion, err := strconv.Atoi(availableClientVersion)
 		if err != nil {
 			// If the header is missing or malformed, we can't determine the available
 			// version number. This is unexpected; but if it happens, it's likely due
@@ -123,12 +123,14 @@ func DownloadUpgrade(
 			// download later in the session).
 			NoticeAlert(
 				"failed to download upgrade: invalid %s header value %s: %s",
-				config.UpgradeDownloadClientVersionHeader, headerValue, err)
+				config.UpgradeDownloadClientVersionHeader, availableClientVersion, err)
 			return nil
 		}
 
-		if currentClientVersion >= availableClientVersion {
-			NoticeInfo("skipping download of available client version %d", availableClientVersion)
+		if currentClientVersion >= checkAvailableClientVersion {
+			NoticeInfo(
+				"skipping download of available client version %d",
+				checkAvailableClientVersion)
 		}
 	}
 
@@ -218,12 +220,13 @@ func DownloadUpgrade(
 	// A partial download occurs when this copy is interrupted. The io.Copy
 	// will fail, leaving a partial download in place (.part and .part.etag).
 	n, err := io.Copy(NewSyncFileWriter(file), response.Body)
+
+	NoticeClientUpgradeDownloadedBytes(n)
+
 	if err != nil {
 		return ContextError(err)
 	}
 
-	NoticeInfo("client upgrade downloaded bytes: %d", n)
-
 	// Ensure the file is flushed to disk. The deferred close
 	// will be a noop when this succeeds.
 	err = file.Close()