mirokuratczyk 3 лет назад
Родитель
Сommit
441e8aed50

+ 1 - 1
psiphon/common/resolver/resolver.go

@@ -1643,7 +1643,7 @@ func (conn *transformDNSPacketConn) Write(b []byte) (int, error) {
 	// the network packet MTU.
 	// the network packet MTU.
 
 
 	input := hex.EncodeToString(b)
 	input := hex.EncodeToString(b)
-	output, err := conn.transform.Apply(conn.seed, input)
+	output, err := conn.transform.ApplyString(conn.seed, input)
 	if err != nil {
 	if err != nil {
 		return 0, errors.Trace(err)
 		return 0, errors.Trace(err)
 	}
 	}

+ 35 - 45
psiphon/common/transforms/httpTransformer.go

@@ -74,8 +74,9 @@ type HTTPTransformer struct {
 	// state is the HTTPTransformer state. Possible values are
 	// state is the HTTPTransformer state. Possible values are
 	// httpTransformerReadWriteHeader and httpTransformerReadWriteBody.
 	// httpTransformerReadWriteHeader and httpTransformerReadWriteBody.
 	state int64
 	state int64
-	// b is the accumulated bytes of the current HTTP request.
-	b []byte
+	// b is used to buffer the accumulated bytes of the current HTTP request
+	// header until the entire header is received and written.
+	b bytes.Buffer
 	// remain is the number of remaining HTTP request body bytes to read into b.
 	// remain is the number of remaining HTTP request body bytes to read into b.
 	remain uint64
 	remain uint64
 
 
@@ -100,7 +101,8 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 
 
 	if t.state == httpTransformerReadWriteHeader {
 	if t.state == httpTransformerReadWriteHeader {
 
 
-		t.b = append(t.b, b...)
+		// Do not need to check return value https://github.com/golang/go/blob/1e9ff255a130200fcc4ec5e911d28181fce947d5/src/bytes/buffer.go#L164
+		t.b.Write(b)
 
 
 		// Wait until the entire HTTP request header has been read. Must check
 		// Wait until the entire HTTP request header has been read. Must check
 		// all accumulated bytes incase the "\r\n\r\n" separator is written over
 		// all accumulated bytes incase the "\r\n\r\n" separator is written over
@@ -109,7 +111,7 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 
 
 		sep := []byte("\r\n\r\n")
 		sep := []byte("\r\n\r\n")
 
 
-		headerBodyLines := bytes.SplitN(t.b, sep, 2) // split header and body
+		headerBodyLines := bytes.SplitN(t.b.Bytes(), sep, 2) // split header and body
 
 
 		if len(headerBodyLines) <= 1 {
 		if len(headerBodyLines) <= 1 {
 			// b buffered in t.b and the entire HTTP request header has not been
 			// b buffered in t.b and the entire HTTP request header has not been
@@ -158,10 +160,10 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 		// transform and write header
 		// transform and write header
 
 
 		headerLen := len(headerBodyLines[0]) + len(sep)
 		headerLen := len(headerBodyLines[0]) + len(sep)
-		header := t.b[:headerLen]
+		header := t.b.Bytes()[:headerLen]
 
 
 		if t.transform != nil {
 		if t.transform != nil {
-			newHeaderS, err := t.transform.Apply(t.seed, string(header))
+			newHeader, err := t.transform.Apply(t.seed, header)
 			if err != nil {
 			if err != nil {
 				// TODO: consider logging an error and skiping transform
 				// TODO: consider logging an error and skiping transform
 				// instead of returning an error, if the transform is broken
 				// instead of returning an error, if the transform is broken
@@ -169,13 +171,18 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 				return len(b), errors.Trace(err)
 				return len(b), errors.Trace(err)
 			}
 			}
 
 
-			newHeader := []byte(newHeaderS)
-
 			// only allocate new slice if header length changed
 			// only allocate new slice if header length changed
 			if len(newHeader) == len(header) {
 			if len(newHeader) == len(header) {
-				copy(t.b[:len(header)], newHeader)
+				// Do not need to check return value. It is guaranteed that
+				// n == len(newHeader) because t.b.Len() >= n if the header
+				// size has not changed.
+				copy(t.b.Bytes()[:len(header)], newHeader)
 			} else {
 			} else {
-				t.b = append(newHeader, t.b[len(header):]...)
+				b := t.b.Bytes()
+				t.b.Reset()
+				// Do not need to check return value of bytes.Buffer.Write() https://github.com/golang/go/blob/1e9ff255a130200fcc4ec5e911d28181fce947d5/src/bytes/buffer.go#L164
+				t.b.Write(newHeader)
+				t.b.Write(b[len(header):])
 			}
 			}
 
 
 			header = newHeader
 			header = newHeader
@@ -188,12 +195,20 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 		}
 		}
 		t.remain += uint64(len(header))
 		t.remain += uint64(len(header))
 
 
-		err = t.writeBuffer()
+		if uint64(t.b.Len()) > t.remain {
+			// Should never happen, multiple requests written in a single
+			// Write() are not supported.
+			return len(b), errors.TraceNew("multiple HTTP requests in single Write() not supported")
+		}
+
+		n, err := t.b.WriteTo(t.Conn)
+		t.remain -= uint64(n)
 
 
 		if t.remain > 0 {
 		if t.remain > 0 {
 			t.state = httpTransformerReadWriteBody
 			t.state = httpTransformerReadWriteBody
 		}
 		}
 
 
+		// Do not wrap any I/O err returned by Conn
 		return len(b), err
 		return len(b), err
 	}
 	}
 
 
@@ -203,14 +218,20 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 	// Must write buffered bytes first, in-order, to write bytes to underlying
 	// Must write buffered bytes first, in-order, to write bytes to underlying
 	// net.Conn in the same order they were received in.
 	// net.Conn in the same order they were received in.
 	//
 	//
+	// Already checked that t.b does not contain bytes of a subsequent HTTP
+	// request when the header is parsed, i.e. at this point it is guaranteed
+	// that t.b.Len() <= t.remain.
+	//
 	// In practise the buffer will be empty by this point because its entire
 	// In practise the buffer will be empty by this point because its entire
-	// contents will have been written in the first call to t.writeBuffer()
+	// contents will have been written in the first call to t.b.WriteTo(t.Conn)
 	// when the header is received, parsed, and transformed; otherwise the
 	// when the header is received, parsed, and transformed; otherwise the
 	// underlying transport will have failed and the caller will not invoke
 	// underlying transport will have failed and the caller will not invoke
 	// Write() again on this instance. See HTTPTransformer.Write() comment.
 	// Write() again on this instance. See HTTPTransformer.Write() comment.
-	err := t.writeBuffer()
+	wrote, err := t.b.WriteTo(t.Conn)
+	t.remain -= uint64(wrote)
 	if err != nil {
 	if err != nil {
 		// b not written or buffered
 		// b not written or buffered
+		// Do not wrap any I/O err returned by Conn
 		return 0, err
 		return 0, err
 	}
 	}
 
 
@@ -229,41 +250,10 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 		t.remain = 0
 		t.remain = 0
 	}
 	}
 
 
+	// Do not wrap any I/O err returned by Conn
 	return n, err
 	return n, err
 }
 }
 
 
-func (t *HTTPTransformer) writeBuffer() error {
-
-	if uint64(len(t.b)) > t.remain {
-		// Should never happen, multiple requests written in a single
-		// Write() are not supported.
-		return errors.TraceNew("multiple HTTP requests in single Write() not supported")
-	}
-
-	// Continue to Write() buffered bytes to underlying net.Conn until Write()
-	// fails or all buffered bytes are written.
-	for len(t.b) > 0 {
-
-		var n int
-		n, err := t.Conn.Write(t.b)
-
-		t.remain -= uint64(n)
-
-		if n == len(t.b) {
-			t.b = nil
-		} else {
-			t.b = t.b[n:]
-		}
-
-		// Stop writing and return if there was an error
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
 func WrapDialerWithHTTPTransformer(dialer common.Dialer, params *HTTPTransformerParameters) common.Dialer {
 func WrapDialerWithHTTPTransformer(dialer common.Dialer, params *HTTPTransformerParameters) common.Dialer {
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 		conn, err := dialer(ctx, network, addr)
 		conn, err := dialer(ctx, network, addr)

+ 4 - 4
psiphon/common/transforms/httpTransformer_test.go

@@ -90,7 +90,7 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 			wantOutput:     "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
 			wantOutput:     "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
 			chunkSize:      1,
 			chunkSize:      1,
 			connWriteLimit: 1,
 			connWriteLimit: 1,
-			connWriteErrs:  []error{nil, errors.New("err1"), errors.New("err2")},
+			connWriteErrs:  []error{errors.New("err1"), errors.New("err2")},
 			wantError:      errors.New("err1"),
 			wantError:      errors.New("err1"),
 		},
 		},
 		{
 		{
@@ -257,6 +257,9 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 				if err != nil {
 				if err != nil {
 					t.Fatalf("unexpected error %v", err)
 					t.Fatalf("unexpected error %v", err)
 				}
 				}
+				if string(conn.b) != tt.wantOutput {
+					t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.b)), len(conn.b))
+				}
 			} else {
 			} else {
 				// tt.wantError != nil
 				// tt.wantError != nil
 				if err == nil {
 				if err == nil {
@@ -265,9 +268,6 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 					t.Fatalf("expected error %v got %v", tt.wantError, err)
 					t.Fatalf("expected error %v got %v", tt.wantError, err)
 				}
 				}
 			}
 			}
-			if tt.wantError == nil && string(conn.b) != tt.wantOutput {
-				t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.b)), len(conn.b))
-			}
 		})
 		})
 	}
 	}
 }
 }

+ 43 - 17
psiphon/common/transforms/transforms.go

@@ -61,7 +61,7 @@ func (specs Specs) Validate() error {
 	}
 	}
 	for _, spec := range specs {
 	for _, spec := range specs {
 		// Call Apply to compile/validate the regular expressions and generators.
 		// Call Apply to compile/validate the regular expressions and generators.
-		_, err := spec.Apply(seed, "")
+		_, err := spec.ApplyString(seed, "")
 		if err != nil {
 		if err != nil {
 			return errors.Trace(err)
 			return errors.Trace(err)
 		}
 		}
@@ -142,30 +142,56 @@ func (specs Specs) Select(scope string, scopedSpecs ScopedSpecNames) (string, Sp
 	return specName, spec
 	return specName, spec
 }
 }
 
 
-// Apply applies the Spec to the input string, producing the output string.
+// ApplyString applies the Spec to the input string, producing the output string.
 //
 //
 // The input seed is used for all random generation. The same seed can be
 // The input seed is used for all random generation. The same seed can be
 // supplied to produce the same output, for replay.
 // supplied to produce the same output, for replay.
-func (spec Spec) Apply(seed *prng.Seed, input string) (string, error) {
-
-	// TODO: the compiled regexp and regen could be cached, but the seed is an
-	// issue with caching the regen.
+func (spec Spec) ApplyString(seed *prng.Seed, input string) (string, error) {
 
 
 	value := input
 	value := input
 	for _, transform := range spec {
 	for _, transform := range spec {
 
 
-		args := &regen.GeneratorArgs{
-			RngSource: prng.NewPRNGWithSeed(seed),
-			Flags:     syntax.OneLine | syntax.NonGreedy,
-		}
-		rg, err := regen.NewGenerator(transform[1], args)
-		if err != nil {
-			panic(err.Error())
-		}
-		replacement := rg.Generate()
-
-		re := regexp.MustCompile(transform[0])
+		re, replacement := makeRegexAndRepl(seed, transform)
 		value = re.ReplaceAllString(value, replacement)
 		value = re.ReplaceAllString(value, replacement)
 	}
 	}
 	return value, nil
 	return value, nil
 }
 }
+
+// Apply applies the Spec to the input bytes, producing the output bytes.
+//
+// The input seed is used for all random generation. The same seed can be
+// supplied to produce the same output, for replay.
+func (spec Spec) Apply(seed *prng.Seed, input []byte) ([]byte, error) {
+
+	value := input
+	for _, transform := range spec {
+
+		re, replacement := makeRegexAndRepl(seed, transform)
+		value = re.ReplaceAll(value, []byte(replacement))
+	}
+	return value, nil
+}
+
+// makeRegexAndRepl generates the regex and replacement for a given seed and
+// transform. The same seed can be supplied to produce the same output, for
+// replay.
+func makeRegexAndRepl(seed *prng.Seed, transform [2]string) (re *regexp.Regexp, replacement string) {
+
+	// TODO: the compiled regexp and regen could be cached, but the seed is an
+	// issue with caching the regen.
+
+	args := &regen.GeneratorArgs{
+		RngSource: prng.NewPRNGWithSeed(seed),
+		Flags:     syntax.OneLine | syntax.NonGreedy,
+	}
+	rg, err := regen.NewGenerator(transform[1], args)
+	if err != nil {
+		panic(err.Error())
+	}
+
+	replacement = rg.Generate()
+
+	re = regexp.MustCompile(transform[0])
+
+	return
+}

+ 3 - 3
psiphon/common/transforms/transforms_test.go

@@ -87,7 +87,7 @@ func runTestTransforms() error {
 	}
 	}
 
 
 	input := "aa0aa0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa"
 	input := "aa0aa0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa0bb0aa"
-	output, err := spec.Apply(seed, input)
+	output, err := spec.ApplyString(seed, input)
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
@@ -102,7 +102,7 @@ func runTestTransforms() error {
 
 
 	previousOutput := output
 	previousOutput := output
 
 
-	output, err = spec.Apply(seed, input)
+	output, err = spec.ApplyString(seed, input)
 	if err != nil {
 	if err != nil {
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
@@ -121,7 +121,7 @@ func runTestTransforms() error {
 			return errors.Trace(err)
 			return errors.Trace(err)
 		}
 		}
 
 
-		output, err = spec.Apply(seed, input)
+		output, err = spec.ApplyString(seed, input)
 		if err != nil {
 		if err != nil {
 			return errors.Trace(err)
 			return errors.Trace(err)
 		}
 		}