|
|
@@ -20,16 +20,45 @@
|
|
|
package psiphon
|
|
|
|
|
|
import (
|
|
|
+ "flag"
|
|
|
"fmt"
|
|
|
+ "io"
|
|
|
"io/ioutil"
|
|
|
+ "net"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
+ "os"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"testing"
|
|
|
"time"
|
|
|
+
|
|
|
+ socks "github.com/Psiphon-Inc/goptlib"
|
|
|
)
|
|
|
|
|
|
+func TestMain(m *testing.M) {
|
|
|
+ flag.Parse()
|
|
|
+ os.Remove(DATA_STORE_FILENAME)
|
|
|
+ initDisruptor()
|
|
|
+ setEmitDiagnosticNotices(true)
|
|
|
+ os.Exit(m.Run())
|
|
|
+}
|
|
|
+
|
|
|
+// Note: untunneled upgrade tests must execute before
|
|
|
+// the "Run" tests to ensure no tunnel is established.
|
|
|
+// We need a way to reset the datastore after it's been
|
|
|
+// initialized in order to to clear out its data entries
|
|
|
+// and be able to arbitrarily order the tests.
|
|
|
+
|
|
|
+func TestUntunneledUpgradeDownload(t *testing.T) {
|
|
|
+ doUntunnledUpgradeDownload(t, false)
|
|
|
+}
|
|
|
+
|
|
|
+func TestUntunneledResumableUpgradeDownload(t *testing.T) {
|
|
|
+ doUntunnledUpgradeDownload(t, true)
|
|
|
+}
|
|
|
+
|
|
|
func TestControllerRunSSH(t *testing.T) {
|
|
|
controllerRun(t, TUNNEL_PROTOCOL_SSH, false)
|
|
|
}
|
|
|
@@ -54,6 +83,120 @@ func TestControllerRunUnfrontedMeekHTTPS(t *testing.T) {
|
|
|
controllerRun(t, TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS, true)
|
|
|
}
|
|
|
|
|
|
+func doUntunnledUpgradeDownload(t *testing.T, disrupt bool) {
|
|
|
+
|
|
|
+ configFileContents, err := ioutil.ReadFile("controller_test.config")
|
|
|
+ if err != nil {
|
|
|
+ // Skip, don't fail, if config file is not present
|
|
|
+ t.Skipf("error loading configuration file: %s", err)
|
|
|
+ }
|
|
|
+ config, err := LoadConfig(configFileContents)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error processing configuration file: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if disrupt {
|
|
|
+ config.UpstreamProxyUrl = disruptorProxyURL
|
|
|
+ }
|
|
|
+
|
|
|
+ // Clear remote server list so tunnel cannot be established and
|
|
|
+ // untunneled upgrade download case is tested.
|
|
|
+ config.RemoteServerListUrl = ""
|
|
|
+
|
|
|
+ os.Remove(config.UpgradeDownloadFilename)
|
|
|
+
|
|
|
+ err = InitDataStore(config)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error initializing datastore: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ controller, err := NewController(config)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error creating controller: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ upgradeDownloaded := make(chan struct{}, 1)
|
|
|
+
|
|
|
+ var clientUpgradeDownloadedBytesCount int32
|
|
|
+
|
|
|
+ SetNoticeOutput(NewNoticeReceiver(
|
|
|
+ func(notice []byte) {
|
|
|
+ // TODO: log notices without logging server IPs:
|
|
|
+ // fmt.Fprintf(os.Stderr, "%s\n", string(notice))
|
|
|
+ noticeType, payload, err := GetNotice(notice)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ switch noticeType {
|
|
|
+ case "Tunnels":
|
|
|
+ count := int(payload["count"].(float64))
|
|
|
+ if count > 0 {
|
|
|
+ // TODO: wrong goroutine for t.FatalNow()
|
|
|
+ t.Fatalf("tunnel established unexpectedly")
|
|
|
+ }
|
|
|
+ case "ClientUpgradeDownloadedBytes":
|
|
|
+ atomic.AddInt32(&clientUpgradeDownloadedBytesCount, 1)
|
|
|
+ t.Logf("ClientUpgradeDownloadedBytes: %d", int(payload["bytes"].(float64)))
|
|
|
+ case "ClientUpgradeDownloaded":
|
|
|
+ select {
|
|
|
+ case upgradeDownloaded <- *new(struct{}):
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }))
|
|
|
+
|
|
|
+ // Run controller
|
|
|
+
|
|
|
+ shutdownBroadcast := make(chan struct{})
|
|
|
+ controllerWaitGroup := new(sync.WaitGroup)
|
|
|
+ controllerWaitGroup.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer controllerWaitGroup.Done()
|
|
|
+ controller.Run(shutdownBroadcast)
|
|
|
+ }()
|
|
|
+
|
|
|
+ defer func() {
|
|
|
+ // Test: shutdown must complete within 10 seconds
|
|
|
+
|
|
|
+ close(shutdownBroadcast)
|
|
|
+
|
|
|
+ shutdownTimeout := time.NewTimer(10 * time.Second)
|
|
|
+
|
|
|
+ shutdownOk := make(chan struct{}, 1)
|
|
|
+ go func() {
|
|
|
+ controllerWaitGroup.Wait()
|
|
|
+ shutdownOk <- *new(struct{})
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-shutdownOk:
|
|
|
+ case <-shutdownTimeout.C:
|
|
|
+ t.Fatalf("controller shutdown timeout exceeded")
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Test: upgrade must be downloaded within 120 seconds
|
|
|
+
|
|
|
+ downloadTimeout := time.NewTimer(120 * time.Second)
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-upgradeDownloaded:
|
|
|
+ // TODO: verify downloaded file
|
|
|
+
|
|
|
+ case <-downloadTimeout.C:
|
|
|
+ t.Fatalf("upgrade download timeout exceeded")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test: with disrupt, must be multiple download progress notices
|
|
|
+
|
|
|
+ if disrupt {
|
|
|
+ count := atomic.LoadInt32(&clientUpgradeDownloadedBytesCount)
|
|
|
+ if count <= 1 {
|
|
|
+ t.Fatalf("unexpected upgrade download progress: %d", count)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
type TestHostNameTransformer struct {
|
|
|
}
|
|
|
|
|
|
@@ -78,23 +221,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
|
|
|
}
|
|
|
config, err := LoadConfig(configFileContents)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error processing configuration file: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error processing configuration file: %s", err)
|
|
|
}
|
|
|
+
|
|
|
+ // Disable untunneled upgrade downloader to ensure tunneled case is tested
|
|
|
+ config.UpgradeDownloadClientVersionHeader = ""
|
|
|
+
|
|
|
+ os.Remove(config.UpgradeDownloadFilename)
|
|
|
+
|
|
|
config.TunnelProtocol = protocol
|
|
|
|
|
|
config.HostNameTransformer = hostNameTransformer
|
|
|
|
|
|
err = InitDataStore(config)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error initializing datastore: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error initializing datastore: %s", err)
|
|
|
}
|
|
|
|
|
|
controller, err := NewController(config)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error creating controller: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error creating controller: %s", err)
|
|
|
}
|
|
|
|
|
|
// Monitor notices for "Tunnels" with count > 1, the
|
|
|
@@ -105,6 +251,8 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
|
|
|
httpProxyPort := 0
|
|
|
|
|
|
tunnelEstablished := make(chan struct{}, 1)
|
|
|
+ upgradeDownloaded := make(chan struct{}, 1)
|
|
|
+
|
|
|
SetNoticeOutput(NewNoticeReceiver(
|
|
|
func(notice []byte) {
|
|
|
// TODO: log notices without logging server IPs:
|
|
|
@@ -122,13 +270,20 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
|
|
|
default:
|
|
|
}
|
|
|
}
|
|
|
+ case "ClientUpgradeDownloadedBytes":
|
|
|
+ t.Logf("ClientUpgradeDownloadedBytes: %d", int(payload["bytes"].(float64)))
|
|
|
+ case "ClientUpgradeDownloaded":
|
|
|
+ select {
|
|
|
+ case upgradeDownloaded <- *new(struct{}):
|
|
|
+ default:
|
|
|
+ }
|
|
|
case "ListeningHttpProxyPort":
|
|
|
httpProxyPort = int(payload["port"].(float64))
|
|
|
case "ConnectingServer":
|
|
|
serverProtocol := payload["protocol"]
|
|
|
if serverProtocol != protocol {
|
|
|
- t.Errorf("wrong protocol selected: %s", serverProtocol)
|
|
|
- t.FailNow()
|
|
|
+ // TODO: wrong goroutine for t.FatalNow()
|
|
|
+ t.Fatalf("wrong protocol selected: %s", serverProtocol)
|
|
|
}
|
|
|
}
|
|
|
}))
|
|
|
@@ -143,6 +298,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
|
|
|
controller.Run(shutdownBroadcast)
|
|
|
}()
|
|
|
|
|
|
+ defer func() {
|
|
|
+ // Test: shutdown must complete within 10 seconds
|
|
|
+
|
|
|
+ close(shutdownBroadcast)
|
|
|
+
|
|
|
+ shutdownTimeout := time.NewTimer(10 * time.Second)
|
|
|
+
|
|
|
+ shutdownOk := make(chan struct{}, 1)
|
|
|
+ go func() {
|
|
|
+ controllerWaitGroup.Wait()
|
|
|
+ shutdownOk <- *new(struct{})
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-shutdownOk:
|
|
|
+ case <-shutdownTimeout.C:
|
|
|
+ t.Fatalf("controller shutdown timeout exceeded")
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
// Test: tunnel must be established within 60 seconds
|
|
|
|
|
|
establishTimeout := time.NewTimer(60 * time.Second)
|
|
|
@@ -150,33 +325,26 @@ func doControllerRun(t *testing.T, protocol string, hostNameTransformer HostName
|
|
|
select {
|
|
|
case <-tunnelEstablished:
|
|
|
|
|
|
- // Allow for known race condition described in NewHttpProxy():
|
|
|
- time.Sleep(1 * time.Second)
|
|
|
-
|
|
|
- // Test: fetch website through tunnel
|
|
|
- fetchWebsite(t, httpProxyPort)
|
|
|
-
|
|
|
case <-establishTimeout.C:
|
|
|
- t.Errorf("tunnel establish timeout exceeded")
|
|
|
- // ...continue with cleanup
|
|
|
+ t.Fatalf("tunnel establish timeout exceeded")
|
|
|
}
|
|
|
|
|
|
- close(shutdownBroadcast)
|
|
|
+ // Allow for known race condition described in NewHttpProxy():
|
|
|
+ time.Sleep(1 * time.Second)
|
|
|
|
|
|
- // Test: shutdown must complete within 10 seconds
|
|
|
+ // Test: fetch website through tunnel
|
|
|
+ fetchWebsite(t, httpProxyPort)
|
|
|
|
|
|
- shutdownTimeout := time.NewTimer(10 * time.Second)
|
|
|
+ // Test: upgrade must be downloaded within 60 seconds
|
|
|
|
|
|
- shutdownOk := make(chan struct{}, 1)
|
|
|
- go func() {
|
|
|
- controllerWaitGroup.Wait()
|
|
|
- shutdownOk <- *new(struct{})
|
|
|
- }()
|
|
|
+ downloadTimeout := time.NewTimer(60 * time.Second)
|
|
|
|
|
|
select {
|
|
|
- case <-shutdownOk:
|
|
|
- case <-shutdownTimeout.C:
|
|
|
- t.Errorf("controller shutdown timeout exceeded")
|
|
|
+ case <-upgradeDownloaded:
|
|
|
+ // TODO: verify downloaded file
|
|
|
+
|
|
|
+ case <-downloadTimeout.C:
|
|
|
+ t.Fatalf("upgrade download timeout exceeded")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -194,8 +362,7 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
|
|
|
|
|
|
proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", httpProxyPort))
|
|
|
if err != nil {
|
|
|
- t.Errorf("error initializing proxied HTTP request: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error initializing proxied HTTP request: %s", err)
|
|
|
}
|
|
|
|
|
|
httpClient := &http.Client{
|
|
|
@@ -207,20 +374,17 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
|
|
|
|
|
|
response, err := httpClient.Get(testUrl)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error sending proxied HTTP request: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error sending proxied HTTP request: %s", err)
|
|
|
}
|
|
|
|
|
|
body, err := ioutil.ReadAll(response.Body)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error reading proxied HTTP response: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error reading proxied HTTP response: %s", err)
|
|
|
}
|
|
|
response.Body.Close()
|
|
|
|
|
|
if !checkResponse(string(body)) {
|
|
|
- t.Errorf("unexpected proxied HTTP response")
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("unexpected proxied HTTP response")
|
|
|
}
|
|
|
|
|
|
// Test: use direct URL proxy
|
|
|
@@ -234,20 +398,17 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
|
|
|
fmt.Sprintf("http://127.0.0.1:%d/direct/%s",
|
|
|
httpProxyPort, url.QueryEscape(testUrl)))
|
|
|
if err != nil {
|
|
|
- t.Errorf("error sending direct URL request: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error sending direct URL request: %s", err)
|
|
|
}
|
|
|
|
|
|
body, err = ioutil.ReadAll(response.Body)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error reading direct URL response: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error reading direct URL response: %s", err)
|
|
|
}
|
|
|
response.Body.Close()
|
|
|
|
|
|
if !checkResponse(string(body)) {
|
|
|
- t.Errorf("unexpected direct URL response")
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("unexpected direct URL response")
|
|
|
}
|
|
|
|
|
|
// Test: use tunneled URL proxy
|
|
|
@@ -256,19 +417,69 @@ func fetchWebsite(t *testing.T, httpProxyPort int) {
|
|
|
fmt.Sprintf("http://127.0.0.1:%d/tunneled/%s",
|
|
|
httpProxyPort, url.QueryEscape(testUrl)))
|
|
|
if err != nil {
|
|
|
- t.Errorf("error sending tunneled URL request: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error sending tunneled URL request: %s", err)
|
|
|
}
|
|
|
|
|
|
body, err = ioutil.ReadAll(response.Body)
|
|
|
if err != nil {
|
|
|
- t.Errorf("error reading tunneled URL response: %s", err)
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("error reading tunneled URL response: %s", err)
|
|
|
}
|
|
|
response.Body.Close()
|
|
|
|
|
|
if !checkResponse(string(body)) {
|
|
|
- t.Errorf("unexpected tunneled URL response")
|
|
|
- t.FailNow()
|
|
|
+ t.Fatalf("unexpected tunneled URL response")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+const disruptorProxyAddress = "127.0.0.1:2160"
|
|
|
+const disruptorProxyURL = "socks4a://" + disruptorProxyAddress
|
|
|
+const disruptorMaxConnectionBytes = 2000000
|
|
|
+const disruptorMaxConnectionTime = 15 * time.Second
|
|
|
+
|
|
|
+func initDisruptor() {
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ listener, err := socks.ListenSocks("tcp", disruptorProxyAddress)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Errorf("disruptor proxy listen error: %s", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ for {
|
|
|
+ localConn, err := listener.AcceptSocks()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Errorf("disruptor proxy accept error: %s", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ go func() {
|
|
|
+ defer localConn.Close()
|
|
|
+ remoteConn, err := net.Dial("tcp", localConn.Req.Target)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Errorf("disruptor proxy dial error: %s", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer remoteConn.Close()
|
|
|
+ err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
|
|
|
+ if err != nil {
|
|
|
+ fmt.Errorf("disruptor proxy grant error: %s", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Cut connection after disruptorMaxConnectionTime
|
|
|
+ time.AfterFunc(disruptorMaxConnectionTime, func() {
|
|
|
+ localConn.Close()
|
|
|
+ remoteConn.Close()
|
|
|
+ })
|
|
|
+
|
|
|
+ // Relay connection, but only up to disruptorMaxConnectionBytes
|
|
|
+ waitGroup := new(sync.WaitGroup)
|
|
|
+ waitGroup.Add(1)
|
|
|
+ go func() {
|
|
|
+ defer waitGroup.Done()
|
|
|
+ io.CopyN(localConn, remoteConn, disruptorMaxConnectionBytes)
|
|
|
+ }()
|
|
|
+ io.CopyN(remoteConn, localConn, disruptorMaxConnectionBytes)
|
|
|
+ waitGroup.Wait()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+}
|