Просмотр исходного кода

Merge pull request #225 from rod-hynes/master

Additional test coverage
Rod Hynes 9 лет назад
Родитель
Сommit
fef122644b
4 измененных файлов с 148 добавлено и 25 удалено
  1. 1 3
      psiphon/common/throttled.go
  2. 39 17
      psiphon/common/throttled_test.go
  3. 101 0
      psiphon/server/server_test.go
  4. 7 5
      psiphon/server/udp.go

+ 1 - 3
psiphon/common/throttled.go

@@ -145,9 +145,7 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 
 
 	bytesWritten := 0
 	bytesWritten := 0
 
 
-	for i := 0; i < len(buffer); i += chunkSize {
-
-		start := i
+	for start := 0; start < len(buffer); start += chunkSize {
 		end := start + chunkSize
 		end := start + chunkSize
 		if end > len(buffer) {
 		if end > len(buffer) {
 			end = len(buffer)
 			end = len(buffer)

+ 39 - 17
psiphon/common/throttled_test.go

@@ -20,6 +20,7 @@
 package common
 package common
 
 
 import (
 import (
+	"bytes"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"math"
 	"math"
@@ -30,7 +31,7 @@ import (
 )
 )
 
 
 const (
 const (
-	serverAddress = "127.0.0.1:8080"
+	serverAddress = "127.0.0.1:8081"
 	testDataSize  = 10 * 1024 * 1024 // 10 MB
 	testDataSize  = 10 * 1024 * 1024 // 10 MB
 )
 )
 
 
@@ -47,21 +48,28 @@ func TestThrottledConn(t *testing.T) {
 		DownstreamUnlimitedBytes: 0,
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 5 * 1024 * 1024,
 		DownstreamBytesPerSecond: 5 * 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   5 * 1024 * 1024,
 	})
 	})
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
 		DownstreamUnlimitedBytes: 0,
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 2 * 1024 * 1024,
 		DownstreamBytesPerSecond: 2 * 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   2 * 1024 * 1024,
 	})
 	})
 
 
 	run(t, RateLimits{
 	run(t, RateLimits{
 		DownstreamUnlimitedBytes: 0,
 		DownstreamUnlimitedBytes: 0,
 		DownstreamBytesPerSecond: 1024 * 1024,
 		DownstreamBytesPerSecond: 1024 * 1024,
 		UpstreamUnlimitedBytes:   0,
 		UpstreamUnlimitedBytes:   0,
-		UpstreamBytesPerSecond:   0,
+		UpstreamBytesPerSecond:   1024 * 1024,
+	})
+
+	run(t, RateLimits{
+		DownstreamUnlimitedBytes: 0,
+		DownstreamBytesPerSecond: 1024 * 1024 / 8,
+		UpstreamUnlimitedBytes:   0,
+		UpstreamBytesPerSecond:   1024 * 1024 / 8,
 	})
 	})
 }
 }
 
 
@@ -72,6 +80,7 @@ func run(t *testing.T, rateLimits RateLimits) {
 	go func() {
 	go func() {
 
 
 		handler := func(w http.ResponseWriter, r *http.Request) {
 		handler := func(w http.ResponseWriter, r *http.Request) {
+			_, _ = ioutil.ReadAll(r.Body)
 			testData, _ := MakeSecureRandomBytes(testDataSize)
 			testData, _ := MakeSecureRandomBytes(testDataSize)
 			w.Write(testData)
 			w.Write(testData)
 		}
 		}
@@ -103,11 +112,14 @@ func run(t *testing.T, rateLimits RateLimits) {
 		},
 		},
 	}
 	}
 
 
-	// Download a large chunk of data, and time it
+	// Upload and download a large chunk of data, and time it
+
+	testData, _ := MakeSecureRandomBytes(testDataSize)
+	requestBody := bytes.NewReader(testData)
 
 
 	startTime := time.Now()
 	startTime := time.Now()
 
 
-	response, err := client.Get("http://" + serverAddress)
+	response, err := client.Post("http://"+serverAddress, "application/octet-stream", requestBody)
 	if err == nil && response.StatusCode != http.StatusOK {
 	if err == nil && response.StatusCode != http.StatusOK {
 		response.Body.Close()
 		response.Body.Close()
 		err = fmt.Errorf("unexpected response code: %d", response.StatusCode)
 		err = fmt.Errorf("unexpected response code: %d", response.StatusCode)
@@ -116,6 +128,13 @@ func run(t *testing.T, rateLimits RateLimits) {
 		t.Fatalf("request failed: %s", err)
 		t.Fatalf("request failed: %s", err)
 	}
 	}
 	defer response.Body.Close()
 	defer response.Body.Close()
+
+	// Test: elapsed upload time must reflect rate limit
+
+	checkElapsedTime(t, testDataSize, rateLimits.UpstreamBytesPerSecond, time.Now().Sub(startTime))
+
+	startTime = time.Now()
+
 	body, err := ioutil.ReadAll(response.Body)
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("read response failed: %s", err)
 		t.Fatalf("read response failed: %s", err)
@@ -124,18 +143,21 @@ func run(t *testing.T, rateLimits RateLimits) {
 		t.Fatalf("unexpected response size: %d", len(body))
 		t.Fatalf("unexpected response size: %d", len(body))
 	}
 	}
 
 
-	duration := time.Now().Sub(startTime)
+	// Test: elapsed download time must reflect rate limit
+
+	checkElapsedTime(t, testDataSize, rateLimits.DownstreamBytesPerSecond, time.Now().Sub(startTime))
+}
 
 
-	// Test: elapsed time must reflect rate limit
+func checkElapsedTime(t *testing.T, dataSize int, rateLimit int64, duration time.Duration) {
 
 
-	// No rate limit should finish under a couple seconds
+	// With no rate limit, should finish under a couple seconds
 	floorElapsedTime := 0 * time.Second
 	floorElapsedTime := 0 * time.Second
 	ceilingElapsedTime := 2 * time.Second
 	ceilingElapsedTime := 2 * time.Second
 
 
-	if rateLimits.DownstreamBytesPerSecond != 0 {
+	if rateLimit != 0 {
 		// With rate limit, should finish within a couple seconds or so of data size / bytes-per-second;
 		// With rate limit, should finish within a couple seconds or so of data size / bytes-per-second;
-		// won't be eaxact due to request overhead and approximations in "ratelimit" package
-		expectedElapsedTime := float64(testDataSize) / float64(rateLimits.DownstreamBytesPerSecond)
+		// won't be exact due to request overhead and approximations in "ratelimit" package
+		expectedElapsedTime := float64(testDataSize) / float64(rateLimit)
 		floorElapsedTime = time.Duration(int64(math.Floor(expectedElapsedTime))) * time.Second
 		floorElapsedTime = time.Duration(int64(math.Floor(expectedElapsedTime))) * time.Second
 		floorElapsedTime -= 1500 * time.Millisecond
 		floorElapsedTime -= 1500 * time.Millisecond
 		if floorElapsedTime < 0 {
 		if floorElapsedTime < 0 {
@@ -146,18 +168,18 @@ func run(t *testing.T, rateLimits RateLimits) {
 	}
 	}
 
 
 	t.Logf(
 	t.Logf(
-		"data size: %d; downstream rate limit: %d; elapsed time: %s; expected time: [%s,%s]",
-		testDataSize,
-		rateLimits.DownstreamBytesPerSecond,
+		"\ndata size: %d\nrate limit: %d\nelapsed time: %s\nexpected time: [%s,%s]\n\n",
+		dataSize,
+		rateLimit,
 		duration,
 		duration,
 		floorElapsedTime,
 		floorElapsedTime,
 		ceilingElapsedTime)
 		ceilingElapsedTime)
 
 
 	if duration < floorElapsedTime {
 	if duration < floorElapsedTime {
-		t.Fatalf("unexpected duration: %s < %s", duration, floorElapsedTime)
+		t.Errorf("unexpected duration: %s < %s", duration, floorElapsedTime)
 	}
 	}
 
 
 	if duration > ceilingElapsedTime {
 	if duration > ceilingElapsedTime {
-		t.Fatalf("unexpected duration: %s > %s", duration, ceilingElapsedTime)
+		t.Errorf("unexpected duration: %s > %s", duration, ceilingElapsedTime)
 	}
 	}
 }
 }

+ 101 - 0
psiphon/server/server_test.go

@@ -24,6 +24,7 @@ import (
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"os"
 	"os"
@@ -34,6 +35,7 @@ import (
 
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"golang.org/x/net/proxy"
 )
 )
 
 
 func TestMain(m *testing.M) {
 func TestMain(m *testing.M) {
@@ -165,6 +167,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
 	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
 	serverConfig.(map[string]interface{})["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig.(map[string]interface{})["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
 	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
+	serverConfig.(map[string]interface{})["LogLevel"] = "debug"
+
 	serverConfigJSON, _ = json.Marshal(serverConfig)
 	serverConfigJSON, _ = json.Marshal(serverConfig)
 
 
 	// run server
 	// run server
@@ -225,6 +229,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 
 	// TODO: currently, TargetServerEntry only works with one tunnel
 	// TODO: currently, TargetServerEntry only works with one tunnel
 	numTunnels := 1
 	numTunnels := 1
+	localSOCKSProxyPort := 1081
 	localHTTPProxyPort := 8081
 	localHTTPProxyPort := 8081
 	establishTunnelPausePeriodSeconds := 1
 	establishTunnelPausePeriodSeconds := 1
 
 
@@ -245,6 +250,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	clientConfig.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
 	clientConfig.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
 	clientConfig.TargetServerEntry = string(encodedServerEntry)
 	clientConfig.TargetServerEntry = string(encodedServerEntry)
 	clientConfig.TunnelProtocol = runConfig.tunnelProtocol
 	clientConfig.TunnelProtocol = runConfig.tunnelProtocol
+	clientConfig.LocalSocksProxyPort = localSOCKSProxyPort
 	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
 	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
 
 
 	err = psiphon.InitDataStore(clientConfig)
 	err = psiphon.InitDataStore(clientConfig)
@@ -338,6 +344,16 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 
 	// Test: tunneled web site fetch
 	// Test: tunneled web site fetch
 
 
+	makeTunneledWebRequest(t, localHTTPProxyPort)
+
+	// Test: tunneled UDP packet
+
+	udpgwServerAddress := serverConfig.(map[string]interface{})["UDPInterceptUdpgwServerAddress"].(string)
+	makeTunneledDNSRequest(t, localSOCKSProxyPort, udpgwServerAddress)
+}
+
+func makeTunneledWebRequest(t *testing.T, localHTTPProxyPort int) {
+
 	testUrl := "https://psiphon.ca"
 	testUrl := "https://psiphon.ca"
 	roundTripTimeout := 30 * time.Second
 	roundTripTimeout := 30 * time.Second
 
 
@@ -365,6 +381,91 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	response.Body.Close()
 	response.Body.Close()
 }
 }
 
 
+func makeTunneledDNSRequest(t *testing.T, localSOCKSProxyPort int, udpgwServerAddress string) {
+
+	testHostname := "psiphon.ca"
+	timeout := 10 * time.Second
+
+	localUDPProxyAddress, err := net.ResolveUDPAddr("udp", "127.0.0.1:7301")
+	if err != nil {
+		t.Fatalf("ResolveUDPAddr failed: %s", err)
+	}
+
+	go func() {
+
+		serverUDPConn, err := net.ListenUDP("udp", localUDPProxyAddress)
+		if err != nil {
+			t.Fatalf("ListenUDP failed: %s", err)
+		}
+		defer serverUDPConn.Close()
+
+		udpgwPreambleSize := 11 // see writeUdpgwPreamble
+		buffer := make([]byte, udpgwProtocolMaxMessageSize)
+		packetSize, clientAddr, err := serverUDPConn.ReadFromUDP(
+			buffer[udpgwPreambleSize:len(buffer)])
+		if err != nil {
+			t.Fatalf("serverUDPConn.Read failed: %s", err)
+		}
+
+		socksProxyAddress := fmt.Sprintf("127.0.0.1:%d", localSOCKSProxyPort)
+
+		dialer, err := proxy.SOCKS5("tcp", socksProxyAddress, nil, proxy.Direct)
+		if err != nil {
+			t.Fatalf("proxy.SOCKS5 failed: %s", err)
+		}
+
+		socksTCPConn, err := dialer.Dial("tcp", udpgwServerAddress)
+		if err != nil {
+			t.Fatalf("dialer.Dial failed: %s", err)
+		}
+		defer socksTCPConn.Close()
+
+		err = writeUdpgwPreamble(
+			udpgwPreambleSize,
+			udpgwProtocolFlagDNS,
+			0,
+			make([]byte, 4), // ignored due to transparent DNS forwarding
+			53,
+			uint16(packetSize),
+			buffer)
+		if err != nil {
+			t.Fatalf("writeUdpgwPreamble failed: %s", err)
+		}
+
+		_, err = socksTCPConn.Write(buffer[0 : udpgwPreambleSize+packetSize])
+		if err != nil {
+			t.Fatalf("socksTCPConn.Write failed: %s", err)
+		}
+
+		updgwProtocolMessage, err := readUdpgwMessage(socksTCPConn, buffer)
+		if err != nil {
+			t.Fatalf("readUdpgwMessage failed: %s", err)
+		}
+
+		_, err = serverUDPConn.WriteToUDP(updgwProtocolMessage.packet, clientAddr)
+		if err != nil {
+			t.Fatalf("serverUDPConn.Write failed: %s", err)
+		}
+	}()
+
+	// TODO: properly synchronize with server startup
+	time.Sleep(1 * time.Second)
+
+	clientUDPConn, err := net.DialUDP("udp", nil, localUDPProxyAddress)
+	if err != nil {
+		t.Fatalf("DialUDP failed: %s", err)
+	}
+	defer clientUDPConn.Close()
+
+	clientUDPConn.SetReadDeadline(time.Now().Add(timeout))
+	clientUDPConn.SetWriteDeadline(time.Now().Add(timeout))
+
+	_, _, err = psiphon.ResolveIP(testHostname, clientUDPConn)
+	if err != nil {
+		t.Fatalf("ResolveIP failed: %s", err)
+	}
+}
+
 func pavePsinetDatabaseFile(t *testing.T, psinetFilename string) (string, string) {
 func pavePsinetDatabaseFile(t *testing.T, psinetFilename string) (string, string) {
 
 
 	sponsorID, _ := common.MakeRandomStringHex(8)
 	sponsorID, _ := common.MakeRandomStringHex(8)

+ 7 - 5
psiphon/server/udp.go

@@ -324,6 +324,7 @@ func (portForward *udpPortForward) relayDownstream() {
 
 
 		err = writeUdpgwPreamble(
 		err = writeUdpgwPreamble(
 			portForward.preambleSize,
 			portForward.preambleSize,
+			0,
 			portForward.connID,
 			portForward.connID,
 			portForward.remoteIP,
 			portForward.remoteIP,
 			portForward.remotePort,
 			portForward.remotePort,
@@ -377,7 +378,7 @@ const (
 	udpgwProtocolMaxMessageSize  = udpgwProtocolMaxPreambleSize + udpgwProtocolMaxPayloadSize
 	udpgwProtocolMaxMessageSize  = udpgwProtocolMaxPreambleSize + udpgwProtocolMaxPayloadSize
 )
 )
 
 
-type udpProtocolMessage struct {
+type udpgwProtocolMessage struct {
 	connID              uint16
 	connID              uint16
 	preambleSize        int
 	preambleSize        int
 	remoteIP            []byte
 	remoteIP            []byte
@@ -388,7 +389,7 @@ type udpProtocolMessage struct {
 }
 }
 
 
 func readUdpgwMessage(
 func readUdpgwMessage(
-	reader io.Reader, buffer []byte) (*udpProtocolMessage, error) {
+	reader io.Reader, buffer []byte) (*udpgwProtocolMessage, error) {
 
 
 	// udpgw message layout:
 	// udpgw message layout:
 	//
 	//
@@ -455,9 +456,9 @@ func readUdpgwMessage(
 		}
 		}
 
 
 		// Assemble message
 		// Assemble message
-		// Note: udpProtocolMessage.packet references memory in the input buffer
+		// Note: udpgwProtocolMessage.packet references memory in the input buffer
 
 
-		message := &udpProtocolMessage{
+		message := &udpgwProtocolMessage{
 			connID:              connID,
 			connID:              connID,
 			preambleSize:        packetStart,
 			preambleSize:        packetStart,
 			remoteIP:            remoteIP,
 			remoteIP:            remoteIP,
@@ -473,6 +474,7 @@ func readUdpgwMessage(
 
 
 func writeUdpgwPreamble(
 func writeUdpgwPreamble(
 	preambleSize int,
 	preambleSize int,
+	flags uint8,
 	connID uint16,
 	connID uint16,
 	remoteIP []byte,
 	remoteIP []byte,
 	remotePort uint16,
 	remotePort uint16,
@@ -490,7 +492,7 @@ func writeUdpgwPreamble(
 	buffer[1] = byte(size >> 8)
 	buffer[1] = byte(size >> 8)
 
 
 	// flags
 	// flags
-	buffer[2] = 0
+	buffer[2] = flags
 
 
 	// connID
 	// connID
 	buffer[3] = byte(connID & 0xFF)
 	buffer[3] = byte(connID & 0xFF)