فهرست منبع

Use a mock web server in server_test

Rod Hynes 9 سال پیش
والد
کامیت
2154c6030f
1فایلهای تغییر یافته به همراه58 افزوده شده و 20 حذف شده
  1. 58 20
      psiphon/server/server_test.go

+ 58 - 20
psiphon/server/server_test.go

@@ -41,12 +41,25 @@ import (
 	"golang.org/x/net/proxy"
 )
 
-var testDataDirName string
+var serverIPAddress, testDataDirName string
+var mockWebServerURL, mockWebServerExpectedResponse string
+var mockWebServerPort = 8080
 
 func TestMain(m *testing.M) {
 	flag.Parse()
 
 	var err error
+	for _, interfaceName := range []string{"eth0", "en0"} {
+		serverIPAddress, err = common.GetInterfaceIPAddress(interfaceName)
+		if err == nil {
+			break
+		}
+	}
+	if err != nil {
+		fmt.Printf("error getting server IP address: %s", err)
+		os.Exit(1)
+	}
+
 	testDataDirName, err = ioutil.TempDir("", "psiphon-server-test")
 	if err != nil {
 		fmt.Printf("TempDir failed: %s\n", err)
@@ -60,9 +73,39 @@ func TestMain(m *testing.M) {
 
 	CLIENT_VERIFICATION_REQUIRED = true
 
+	mockWebServerURL, mockWebServerExpectedResponse = runMockWebServer()
+
 	os.Exit(m.Run())
 }
 
+func runMockWebServer() (string, string) {
+
+	responseBody, _ := common.MakeRandomStringHex(100000)
+
+	serveMux := http.NewServeMux()
+	serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte(responseBody))
+	})
+	webServerAddress := fmt.Sprintf("%s:%d", serverIPAddress, mockWebServerPort)
+	server := &http.Server{
+		Addr:    webServerAddress,
+		Handler: serveMux,
+	}
+
+	go func() {
+		err := server.ListenAndServe()
+		if err != nil {
+			fmt.Printf("error running mock web server: %s\n", err)
+			os.Exit(1)
+		}
+	}()
+
+	// TODO: properly synchronize with web server readiness
+	time.Sleep(1 * time.Second)
+
+	return fmt.Sprintf("http://%s/", webServerAddress), responseBody
+}
+
 // Note: not testing fronting meek protocols, which client is
 // hard-wired to except running on privileged ports 80 and 443.
 
@@ -256,21 +299,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// create a server
 
-	var err error
-	serverIPaddress := ""
-	for _, interfaceName := range []string{"eth0", "en0"} {
-		serverIPaddress, err = common.GetInterfaceIPAddress(interfaceName)
-		if err == nil {
-			break
-		}
-	}
-	if err != nil {
-		t.Fatalf("error getting server IP address: %s", err)
-	}
-
 	serverConfigJSON, _, encodedServerEntry, err := GenerateConfig(
 		&GenerateConfigParams{
-			ServerIPAddress:      serverIPaddress,
+			ServerIPAddress:      serverIPAddress,
 			EnableSSHAPIRequests: runConfig.enableSSHAPIRequests,
 			WebServerPort:        8000,
 			TunnelProtocolPorts:  map[string]int{runConfig.tunnelProtocol: 4000},
@@ -510,7 +541,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 		// Test: tunneled web site fetch
 
-		err = makeTunneledWebRequest(t, localHTTPProxyPort)
+		err = makeTunneledWebRequest(
+			t, localHTTPProxyPort, mockWebServerURL, mockWebServerExpectedResponse)
 
 		if err == nil {
 			if runConfig.denyTrafficRules {
@@ -556,9 +588,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	}
 }
 
-func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) error {
+func makeTunneledWebRequest(
+	t *testing.T,
+	localHTTPProxyPort int,
+	requestURL, expectedResponseBody string) error {
 
-	testUrl := "https://psiphon.ca"
 	roundTripTimeout := 30 * time.Second
 
 	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", localHTTPProxyPort))
@@ -573,17 +607,21 @@ func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) error {
 		Timeout: roundTripTimeout,
 	}
 
-	response, err := httpClient.Get(testUrl)
+	response, err := httpClient.Get(requestURL)
 	if err != nil {
 		return fmt.Errorf("error sending proxied HTTP request: %s", err)
 	}
 
-	_, err = ioutil.ReadAll(response.Body)
+	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		return fmt.Errorf("error reading proxied HTTP response: %s", err)
 	}
 	response.Body.Close()
 
+	if string(body) != expectedResponseBody {
+		return fmt.Errorf("unexpected proxied HTTP response")
+	}
+
 	return nil
 }
 
@@ -834,7 +872,7 @@ func pavePsinetDatabaseFile(
 func paveTrafficRulesFile(
 	t *testing.T, trafficRulesFilename, propagationChannelID string, deny bool) {
 
-	allowTCPPorts := "443"
+	allowTCPPorts := fmt.Sprintf("%d", mockWebServerPort)
 	allowUDPPorts := "53, 123"
 
 	if deny {