Răsfoiți Sursa

Add test coverage for upstream rate limiting

Rod Hynes 9 ani în urmă
părinte
comite
778de2a8f7
1 a modificat fișierele cu 39 adăugiri și 17 ștergeri
  1. 39 17
      psiphon/common/throttled_test.go

+ 39 - 17
psiphon/common/throttled_test.go

@@ -20,6 +20,7 @@
 package common
 
 import (
+	"bytes"
 	"fmt"
 	"io/ioutil"
 	"math"
@@ -30,7 +31,7 @@ import (
 )
 
 const (
-	serverAddress = "127.0.0.1:8080"
+	serverAddress = "127.0.0.1:8081"
 	testDataSize  = 10 * 1024 * 1024 // 10 MB
 )
 
@@ -47,21 +48,28 @@ func TestThrottledConn(t *testing.T) {
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 5 * 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   5 * 1024 * 1024,
 	})
 
 	run(t, RateLimits{
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 2 * 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   2 * 1024 * 1024,
 	})
 
 	run(t, RateLimits{
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   1024 * 1024,
+	})
+
+	run(t, RateLimits{
+		DownstreamUnlimitedBytes: 0,
+		DownstreamBytesPerSecond: 1024 * 1024 / 8,
+		UpstreamUnlimitedBytes:   0,
+		UpstreamBytesPerSecond:   1024 * 1024 / 8,
 	})
 }
 
@@ -72,6 +80,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 	go func() {
 
 		handler := func(w http.ResponseWriter, r *http.Request) {
+			_, _ = ioutil.ReadAll(r.Body)
 			testData, _ := MakeSecureRandomBytes(testDataSize)
 			w.Write(testData)
 		}
@@ -103,11 +112,14 @@ func run(t *testing.T, rateLimits RateLimits) {
 		},
 	}
 
-	// Download a large chunk of data, and time it
+	// Upload and download a large chunk of data, and time it
+
+	testData, _ := MakeSecureRandomBytes(testDataSize)
+	requestBody := bytes.NewReader(testData)
 
 	startTime := time.Now()
 
-	response, err := client.Get("http://" + serverAddress)
+	response, err := client.Post("http://"+serverAddress, "application/octet-stream", requestBody)
 	if err == nil && response.StatusCode != http.StatusOK {
 		response.Body.Close()
 		err = fmt.Errorf("unexpected response code: %d", response.StatusCode)
@@ -116,6 +128,13 @@ func run(t *testing.T, rateLimits RateLimits) {
 		t.Fatalf("request failed: %s", err)
 	}
 	defer response.Body.Close()
+
+	// Test: elapsed upload time must reflect rate limit
+
+	checkElapsedTime(t, testDataSize, rateLimits.UpstreamBytesPerSecond, time.Now().Sub(startTime))
+
+	startTime = time.Now()
+
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		t.Fatalf("read response failed: %s", err)
@@ -124,18 +143,21 @@ func run(t *testing.T, rateLimits RateLimits) {
 		t.Fatalf("unexpected response size: %d", len(body))
 	}
 
-	duration := time.Now().Sub(startTime)
+	// Test: elapsed download time must reflect rate limit
+
+	checkElapsedTime(t, testDataSize, rateLimits.DownstreamBytesPerSecond, time.Now().Sub(startTime))
+}
 
-	// Test: elapsed time must reflect rate limit
+func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time.Duration) {
 
-	// No rate limit should finish under a couple seconds
+	// With no rate limit, should finish under a couple seconds
 	floorElapsedTime := 0 * time.Second
 	ceilingElapsedTime := 2 * time.Second
 
-	if rateLimits.DownstreamBytesPerSecond != 0 {
+	if rateLimit != 0 {
 		// With rate limit, should finish within a couple seconds or so of data size / bytes-per-second;
-		// won't be eaxact due to request overhead and approximations in "ratelimit" package
-		expectedElapsedTime := float64(testDataSize) / float64(rateLimits.DownstreamBytesPerSecond)
+		// won't be exact due to request overhead and approximations in "ratelimit" package
+		expectedElapsedTime := float64(testDataSize) / float64(rateLimit)
 		floorElapsedTime = time.Duration(int64(math.Floor(expectedElapsedTime))) * time.Second
 		floorElapsedTime -= 1500 * time.Millisecond
 		if floorElapsedTime < 0 {
@@ -146,18 +168,18 @@ func run(t *testing.T, rateLimits RateLimits) {
 	}
 
 	t.Logf(
-		"data size: %d; downstream rate limit: %d; elapsed time: %s; expected time: [%s,%s]",
-		testDataSize,
-		rateLimits.DownstreamBytesPerSecond,
+		"\ndata size: %d\nrate limit: %d\nelapsed time: %s\nexpected time: [%s,%s]\n\n",
+		dataSize,
+		rateLimit,
 		duration,
 		floorElapsedTime,
 		ceilingElapsedTime)
 
 	if duration < floorElapsedTime {
-		t.Fatalf("unexpected duration: %s < %s", duration, floorElapsedTime)
+		t.Errorf("unexpected duration: %s < %s", duration, floorElapsedTime)
 	}
 
 	if duration > ceilingElapsedTime {
-		t.Fatalf("unexpected duration: %s > %s", duration, ceilingElapsedTime)
+		t.Errorf("unexpected duration: %s > %s", duration, ceilingElapsedTime)
 	}
 }