mirokuratczyk 2 лет назад
Родитель
Сommit
b2c25ec8c8

+ 9 - 7
psiphon/common/transforms/httpNormalizer_test.go

@@ -115,13 +115,15 @@ func runHTTPNormalizerTest(tt *httpNormalizerTest, useNormalizer bool) error {
 		}
 		}
 	}
 	}
 
 
+	// Calling Read on an instance of HTTPNormalizer will return io.EOF once a
+	// passthrough has been activated.
 	if tt.validateMeekCookie != nil && err == io.EOF {
 	if tt.validateMeekCookie != nil && err == io.EOF {
 
 
 		// wait for passthrough to complete
 		// wait for passthrough to complete
 
 
 		timeout := time.After(time.Second)
 		timeout := time.After(time.Second)
 
 
-		for len(passthroughConn.readBuffer) != 0 || len(conn.readBuffer) != 0 {
+		for len(passthroughConn.ReadBuffer()) != 0 || len(conn.ReadBuffer()) != 0 {
 
 
 			select {
 			select {
 			case <-timeout:
 			case <-timeout:
@@ -149,20 +151,20 @@ func runHTTPNormalizerTest(tt *httpNormalizerTest, useNormalizer bool) error {
 			return errors.TraceNew("expected to read no bytes")
 			return errors.TraceNew("expected to read no bytes")
 		}
 		}
 
 
-		if string(passthroughConn.readBuffer) != "" {
+		if string(passthroughConn.ReadBuffer()) != "" {
 			return errors.TraceNew("expected read buffer to be emptied")
 			return errors.TraceNew("expected read buffer to be emptied")
 		}
 		}
 
 
-		if string(passthroughConn.writeBuffer) != tt.wantOutput {
-			return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(passthroughConn.writeBuffer)), len(passthroughConn.writeBuffer))
+		if string(passthroughConn.WriteBuffer()) != tt.wantOutput {
+			return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(passthroughConn.WriteBuffer())), len(passthroughConn.WriteBuffer()))
 		}
 		}
 
 
-		if string(conn.readBuffer) != "" {
+		if string(conn.ReadBuffer()) != "" {
 			return errors.TraceNew("expected read buffer to be emptied")
 			return errors.TraceNew("expected read buffer to be emptied")
 		}
 		}
 
 
-		if string(conn.writeBuffer) != passthroughMessage {
-			return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(passthroughMessage), len(passthroughMessage), escapeNewlines(string(conn.writeBuffer)), len(conn.writeBuffer))
+		if string(conn.WriteBuffer()) != passthroughMessage {
+			return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(passthroughMessage), len(passthroughMessage), escapeNewlines(string(conn.WriteBuffer())), len(conn.WriteBuffer()))
 		}
 		}
 	}
 	}
 
 

+ 42 - 7
psiphon/common/transforms/httpTransformer_test.go

@@ -29,6 +29,7 @@ import (
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"strings"
 	"strings"
+	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -271,8 +272,8 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 				if err != nil {
 				if err != nil {
 					t.Fatalf("unexpected error %v", err)
 					t.Fatalf("unexpected error %v", err)
 				}
 				}
-				if string(conn.writeBuffer) != tt.wantOutput {
-					t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.writeBuffer)), len(conn.writeBuffer))
+				if string(conn.WriteBuffer()) != tt.wantOutput {
+					t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.WriteBuffer())), len(conn.WriteBuffer()))
 				}
 				}
 			} else {
 			} else {
 				// tt.wantError != nil
 				// tt.wantError != nil
@@ -461,10 +462,16 @@ func escapeNewlines(s string) string {
 }
 }
 
 
 type testConn struct {
 type testConn struct {
-	// writeBuffer are the accumulated bytes from Write() calls.
-	writeBuffer []byte
+	readLock sync.Mutex
 	// readBuffer are the bytes to return from Read() calls.
 	// readBuffer are the bytes to return from Read() calls.
 	readBuffer []byte
 	readBuffer []byte
+	// readErrs are returned from Read() calls in order. If empty, then a nil
+	// error is returned.
+	readErrs []error
+
+	writeLock sync.Mutex
+	// writeBuffer are the accumulated bytes from Write() calls.
+	writeBuffer []byte
 	// writeLimit is the max number of bytes that will be written in a Write()
 	// writeLimit is the max number of bytes that will be written in a Write()
 	// call.
 	// call.
 	writeLimit int
 	writeLimit int
@@ -475,15 +482,28 @@ type testConn struct {
 	// writeErrs are returned from Write() calls in order. If empty, then a nil
 	// writeErrs are returned from Write() calls in order. If empty, then a nil
 	// error is returned.
 	// error is returned.
 	writeErrs []error
 	writeErrs []error
-	// readErrs are returned from Read() calls in order. If empty, then a nil
-	// error is returned.
-	readErrs []error
 
 
 	net.Conn
 	net.Conn
 }
 }
 
 
+// ReadBuffer returns a copy of the underlying readBuffer. The length of the
+// returned buffer is also the number of bytes remaining to be Read when Conn
+// is not set.
+func (c *testConn) ReadBuffer() []byte {
+	c.readLock.Lock()
+	defer c.readLock.Unlock()
+
+	readBufferCopy := make([]byte, len(c.readBuffer))
+	copy(readBufferCopy, c.readBuffer)
+
+	return readBufferCopy
+}
+
 func (c *testConn) Read(b []byte) (n int, err error) {
 func (c *testConn) Read(b []byte) (n int, err error) {
 
 
+	c.readLock.Lock()
+	defer c.readLock.Unlock()
+
 	if len(c.readErrs) > 0 {
 	if len(c.readErrs) > 0 {
 		err = c.readErrs[0]
 		err = c.readErrs[0]
 		c.readErrs = c.readErrs[1:]
 		c.readErrs = c.readErrs[1:]
@@ -509,8 +529,23 @@ func (c *testConn) Read(b []byte) (n int, err error) {
 	return
 	return
 }
 }
 
 
+// WriteBuffer returns a copy of the underlying writeBuffer, which is the
+// accumulation of all bytes written with Write.
+func (c *testConn) WriteBuffer() []byte {
+	c.readLock.Lock()
+	defer c.readLock.Unlock()
+
+	writeBufferCopy := make([]byte, len(c.writeBuffer))
+	copy(writeBufferCopy, c.writeBuffer)
+
+	return writeBufferCopy
+}
+
 func (c *testConn) Write(b []byte) (n int, err error) {
 func (c *testConn) Write(b []byte) (n int, err error) {
 
 
+	c.writeLock.Lock()
+	defer c.writeLock.Unlock()
+
 	if len(c.writeErrs) > 0 {
 	if len(c.writeErrs) > 0 {
 		err = c.writeErrs[0]
 		err = c.writeErrs[0]
 		c.writeErrs = c.writeErrs[1:]
 		c.writeErrs = c.writeErrs[1:]