Преглед изворни кода

Fix: handle error in initial header write

mirokuratczyk пре 3 година
родитељ
комит
8758db0e9c

+ 80 - 43
psiphon/common/transforms/httpTransformer.go

@@ -122,19 +122,29 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 				}
 			}
 			if len(cl) == 0 {
-				// Either Content-Length header missing or Content-Length
-				// header value is empty, e.g. "Content-Length: ".
-				// b buffered in t.b
+				// Irrecoverable error because either Content-Length header
+				// missing, or Content-Length header value is empty, e.g.
+				// "Content-Length: ", and request body length cannot be
+				// determined.
+				//
+				// b buffered in t.b, return len(b) in an attempt to get
+				// through the current Write() sequence instead of getting
+				// stuck.
 				return len(b), errors.TraceNew("Content-Length missing")
 			}
 
-			n, err := strconv.ParseUint(string(cl), 10, 63)
+			contentLength, err := strconv.ParseUint(string(cl), 10, 63)
 			if err != nil {
-				// b buffered in t.b
+				// Irrecoverable error because Content-Length is malformed and
+				// request body length cannot be determined.
+				//
+				// b buffered in t.b, return len(b) in an attempt to get
+				// through the current Write() sequence instead of getting
+				// stuck.
 				return len(b), errors.Trace(err)
 			}
 
-			t.remain = n
+			t.remain = contentLength
 
 			// transform and write header
 
@@ -144,7 +154,13 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 			if t.transform != nil {
 				newHeaderS, err := t.transform.Apply(t.seed, string(header))
 				if err != nil {
-					// b buffered in t.b
+					// TODO: consider logging an error and skiping transform
+					// instead of returning an error, if the transform is broken
+					// then all subsequent applications may fail.
+					//
+					// b buffered in t.b, return len(b) in an attempt to get
+					// through the current Write() sequence instead of getting
+					// stuck.
 					return len(b), errors.Trace(err)
 				}
 
@@ -161,41 +177,45 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 			}
 
 			if math.MaxUint64-t.remain < uint64(len(header)) {
-				// b buffered in t.b
+				// Irrecoverable error because request is malformed:
+				// Content-Length + len(header) > math.MaxUint64.
+				//
+				// b buffered in t.b, return len(b) in an attempt to get
+				// through the current Write() sequence instead of getting
+				// stuck.
 				return len(b), errors.TraceNew("t.remain + uint64(len(header)) overflows")
 			}
 			t.remain += uint64(len(header))
 
-			err = t.writeBuffer()
+			n, err := t.writeBuffer()
+
+			written := len(b) // all bytes of b buffered in t.b
+
+			if n < len(header) ||
+				len(t.b) > 0 && t.remain == 0 {
+				// All bytes of b were not written, but all bytes of b have been
+				// buffered in t.b. Drop 1 byte of b from t.b to pretend 1 byte
+				// of b was not written to trigger another Write() call. This
+				// handles the scenario where all request bytes have been
+				// received but writing to the underlying net.Conn fails and
+				// another Write() call cannot be expected unless a value
+				// less than len(b) is returned. An alternative solution would
+				// be to retry writes, or spawn a goroutine which writes t.b,
+				// but we want to return the error to the caller immediately so
+				// it can act accordingly.
+				written = len(b) - 1
+				t.b = t.b[:len(t.b)-1]
+			}
 
 			if t.remain > 0 {
 				t.state = httpTransformerReadWriteBody
-			} else {
-				// Entire request, header and body, has been written. Return to
-				// waiting for next HTTP request header to arrive.
-				if len(t.b) > 0 {
-					// Return the number of bytes written to the underlying
-					// Conn and clear t.b instead of calling t.Write() with any
-					// remaining bytes of t.b. The caller must call Write()
-					// again with the unwritten, and unbuffered, bytes of b.
-					// Since t.remain = 0 it is guaranteed that
-					// len(b) - len(t.b) >= 0 because len(t.b) is the number of
-					// subsequent request bytes and len(b) is the number of
-					// trailing bytes of the current request plus the
-					// subsequent request bytes.
-					written := len(b) - len(t.b)
-					t.b = nil
-					return written, err
-				}
 			}
 
-			if err != nil {
-				// b buffered in t.b
-				return len(b), err
-			}
+			return written, err
 		}
 
-		// b buffered in t.b
+		// b buffered in t.b and the entire HTTP request header has not been
+		// recieved so another Write() call is expected.
 		return len(b), nil
 	}
 
@@ -204,18 +224,19 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 
 	// Must write buffered bytes first, in-order, to write bytes to underlying
 	// Conn in the same order they were received in.
-	err := t.writeBuffer()
+	_, err := t.writeBuffer()
 	if err != nil {
 		// b not written or buffered
 		return 0, errors.Trace(err)
 	}
 
-	bytesToWrite := uint64(len(b))
-	if bytesToWrite > t.remain {
-		bytesToWrite = t.remain
+	// Only write bytes of current request
+	writeN := uint64(len(b))
+	if writeN > t.remain {
+		writeN = t.remain
 	}
 
-	n, err := t.Conn.Write(b[:bytesToWrite])
+	n, err := t.Conn.Write(b[:writeN])
 
 	// Do not need to check for underflow because n <= t.remain
 	t.remain -= uint64(n)
@@ -235,15 +256,29 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
 	return n, err
 }
 
-func (t *HTTPTransformer) writeBuffer() error {
+func (t *HTTPTransformer) writeBuffer() (written int, err error) {
+
+	// Continue writing buffered bytes until either all buffered bytes have
+	// been written or all remaining bytes of the current HTTP request have
+	// been written.
 	for len(t.b) > 0 && t.remain > 0 {
 
-		bytesToWrite := uint64(len(t.b))
-		if bytesToWrite > t.remain {
-			bytesToWrite = t.remain
+		// Write all buffered bytes of the current request
+		writeN := uint64(len(t.b))
+		if writeN > t.remain {
+			// t.b contains bytes of the next request(s), only write current
+			// request bytes.
+			writeN = t.remain
+		}
+
+		// Check for potential overflow before Write() call
+		if math.MaxInt-written < int(writeN) {
+			return written, errors.TraceNew("written + bytesToWrite overflows")
 		}
 
-		n, err := t.Conn.Write(t.b[:bytesToWrite])
+		var n int
+		n, err = t.Conn.Write(t.b[:writeN])
+		written += n
 
 		// Do not need to check for underflow because n <= t.remain
 		t.remain -= uint64(n)
@@ -254,11 +289,13 @@ func (t *HTTPTransformer) writeBuffer() error {
 			t.b = t.b[n:]
 		}
 
+		// Stop writing and return if there was an error
 		if err != nil {
-			return err
+			return
 		}
 	}
-	return nil
+
+	return
 }
 
 func WrapDialerWithHTTPTransformer(dialer common.Dialer, params *HTTPTransformerParameters) common.Dialer {

+ 70 - 36
psiphon/common/transforms/httpTransformer_test.go

@@ -45,60 +45,80 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 		chunkSize      int
 		transform      Spec
 		connWriteLimit int
+		connWriteLens  []int
 		connWriteErrs  []error
 	}
 
 	tests := []test{
 		{
-			name:       "no transform",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			name:       "written in chunks",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
 			chunkSize:  1,
 		},
 		{
-			name:           "no transform with partial write and errors",
-			input:          "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			wantOutput:     "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			name:       "written in a single write",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			chunkSize:  999,
+		},
+		{
+			name:          "written in single write with error",
+			input:         "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput:    "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			chunkSize:     999,
+			connWriteErrs: []error{errors.New("err1")},
+		},
+		{
+			name:           "written with partial write and errors",
+			input:          "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput:     "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
 			chunkSize:      1,
 			connWriteLimit: 1,
 			connWriteErrs:  []error{errors.New("err1"), errors.New("err2")},
 		},
 		{
 			name:       "transform not applied to body",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
 			chunkSize:  1,
 			transform:  Spec{[2]string{"abcd", "efgh"}},
 		},
 		{
 			name:      "Content-Length missing",
-			input:     "HTTP 1.1\r\n\r\nabcd",
+			input:     "POST / HTTP/1.1\r\n\r\nabcd",
 			wantError: errors.New("Content-Length missing"),
 			chunkSize: 1,
 		},
 		{
 			name:      "Content-Length overflow",
-			input:     fmt.Sprintf("HTTP 1.1\r\nContent-Length: %d\r\n\r\nabcd", uint64(math.MaxUint64)),
+			input:     fmt.Sprintf("POST / HTTP/1.1\r\nContent-Length: %d\r\n\r\nabcd", uint64(math.MaxUint64)),
 			wantError: errors.New("strconv.ParseUint: parsing \"18446744073709551615\": value out of range"),
 			chunkSize: 1,
 		},
 		{
-			name:       "no transform",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			name:       "incorrect Content-Length header value",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 3\r\n\r\nabcd",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 3\r\n\r\nabc",
 			chunkSize:  1,
 		},
 		{
-			name:       "incorrect Content-Length header value",
-			input:      "HTTP 1.1\r\nContent-Length: 3\r\n\r\nabcd",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 3\r\n\r\nabc",
-			chunkSize:  1,
+			name:          "written in a single write with errors and partial writes",
+			input:         "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n",
+			wantOutput:    "POST / HTTP/1.1\r\nContent-Length: 0\r\n\r\n",
+			chunkSize:     999,
+			transform:     Spec{[2]string{"Host: example.com\r\n", ""}},
+			connWriteErrs: []error{errors.New("err1"), nil, errors.New("err2"), nil, nil, errors.New("err3")},
+			connWriteLens: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
 		},
 		{
-			name:       "single HTTP request written in a single write",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcd",
-			chunkSize:  999,
+			name:          "written in a single write with error and partial write",
+			input:         "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 4\r\n\r\nabcd",
+			wantOutput:    "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcd",
+			chunkSize:     999,
+			transform:     Spec{[2]string{"Host: example.com\r\n", ""}},
+			connWriteErrs: []error{errors.New("err1")},
+			connWriteLens: []int{28}, // write lands mid "\r\n\r\n"
 		},
 		{
 			name:       "transform",
@@ -128,8 +148,8 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 		// Multiple HTTP requests written in a single write.
 		{
 			name:       "multiple HTTP requests written in a single write",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 2\r\n\r\n12",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 2\r\n\r\n12",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 2\r\n\r\n12",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 2\r\n\r\n12",
 			chunkSize:  999,
 		},
 		// Multiple HTTP requests written in a single write. A write will occur
@@ -137,15 +157,15 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 		// start of a new one.
 		{
 			name:       "multiple HTTP requests written in chunks",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 2\r\n\r\n12",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 2\r\n\r\n12",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 2\r\n\r\n12",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 2\r\n\r\n12",
 			chunkSize:  3,
 		},
 		// Multiple HTTP requests written in a single write with transform.
 		{
-			name:       "multiple HTTP requests written in a single write",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 4\r\n\r\n12HTTP 1.1\r\nContent-Length: 4\r\n\r\n34",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 100\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 100\r\n\r\n12HTTP 1.1\r\nContent-Length: 100\r\n\r\n34",
+			name:       "multiple HTTP requests written in a single write with transform",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n12POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n34",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 100\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 100\r\n\r\n12POST / HTTP/1.1\r\nContent-Length: 100\r\n\r\n34",
 			chunkSize:  999,
 			transform:  Spec{[2]string{"4", "100"}},
 		},
@@ -153,10 +173,10 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 		// write will occur where it contains both the end of the previous HTTP
 		// request and the start of a new one.
 		{
-			name:       "multiple HTTP requests written in chunks",
-			input:      "HTTP 1.1\r\nContent-Length: 4\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 4\r\n\r\n12",
-			wantOutput: "HTTP 1.1\r\nContent-Length: 100\r\n\r\nabcdHTTP 1.1\r\nContent-Length: 100\r\n\r\n12",
-			chunkSize:  3,
+			name:       "multiple HTTP requests written in chunks with transform",
+			input:      "POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n12",
+			wantOutput: "POST / HTTP/1.1\r\nContent-Length: 100\r\n\r\nabcdPOST / HTTP/1.1\r\nContent-Length: 100\r\n\r\n12",
+			chunkSize:  4, // ensure one write contains bytes from both reqs
 			transform:  Spec{[2]string{"4", "100"}},
 		},
 	}
@@ -171,6 +191,7 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
 
 			conn := testConn{
 				writeLimit: tt.connWriteLimit,
+				writeLens:  tt.connWriteLens,
 				writeErrs:  tt.connWriteErrs,
 			}
 
@@ -354,6 +375,10 @@ type testConn struct {
 	// writeLimit is the max number of bytes that will be written in a Write()
 	// call.
 	writeLimit int
+	// writeLens are returned from Write() calls in order and determine the
+	// max number of bytes that will be written. Overrides writeLimit if
+	// non-empty. If empty, then the value of writeLimit is returned.
+	writeLens []int
 	// writeErrs are returned from Write() calls in order. If empty, then a nil
 	// error is returned.
 	writeErrs []error
@@ -370,14 +395,23 @@ func (c *testConn) Write(b []byte) (n int, err error) {
 		c.writeErrs = c.writeErrs[1:]
 	}
 
-	if c.writeLimit != 0 && c.writeLimit < len(b) {
+	if len(c.writeLens) > 0 {
+		n = c.writeLens[0]
+		c.writeLens = c.writeLens[1:]
+		if len(b) <= n {
+			c.b = append(c.b, b...)
+			n = len(b)
+		} else {
+			c.b = append(c.b, b[:n]...)
+		}
+	} else if c.writeLimit != 0 && c.writeLimit < len(b) {
 		c.b = append(c.b, b[:c.writeLimit]...)
 		n = c.writeLimit
-		return
+	} else {
+		c.b = append(c.b, b...)
+		n = len(b)
 	}
 
-	c.b = append(c.b, b...)
-	n = len(b)
 	return
 }