|
|
@@ -74,8 +74,9 @@ type HTTPTransformer struct {
|
|
|
// state is the HTTPTransformer state. Possible values are
|
|
|
// httpTransformerReadWriteHeader and httpTransformerReadWriteBody.
|
|
|
state int64
|
|
|
- // b is the accumulated bytes of the current HTTP request.
|
|
|
- b []byte
|
|
|
+ // b is used to buffer the accumulated bytes of the current HTTP request
|
|
|
+ // header until the entire header is received and written.
|
|
|
+ b bytes.Buffer
|
|
|
// remain is the number of remaining HTTP request body bytes to read into b.
|
|
|
remain uint64
|
|
|
|
|
|
@@ -100,7 +101,8 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
|
|
|
if t.state == httpTransformerReadWriteHeader {
|
|
|
|
|
|
- t.b = append(t.b, b...)
|
|
|
+ // Do not need to check return value https://github.com/golang/go/blob/1e9ff255a130200fcc4ec5e911d28181fce947d5/src/bytes/buffer.go#L164
|
|
|
+ t.b.Write(b)
|
|
|
|
|
|
// Wait until the entire HTTP request header has been read. Must check
|
|
|
// all accumulated bytes incase the "\r\n\r\n" separator is written over
|
|
|
@@ -109,7 +111,7 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
|
|
|
sep := []byte("\r\n\r\n")
|
|
|
|
|
|
- headerBodyLines := bytes.SplitN(t.b, sep, 2) // split header and body
|
|
|
+ headerBodyLines := bytes.SplitN(t.b.Bytes(), sep, 2) // split header and body
|
|
|
|
|
|
if len(headerBodyLines) <= 1 {
|
|
|
// b buffered in t.b and the entire HTTP request header has not been
|
|
|
@@ -158,10 +160,10 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
// transform and write header
|
|
|
|
|
|
headerLen := len(headerBodyLines[0]) + len(sep)
|
|
|
- header := t.b[:headerLen]
|
|
|
+ header := t.b.Bytes()[:headerLen]
|
|
|
|
|
|
if t.transform != nil {
|
|
|
- newHeaderS, err := t.transform.Apply(t.seed, string(header))
|
|
|
+ newHeader, err := t.transform.Apply(t.seed, header)
|
|
|
if err != nil {
|
|
|
// TODO: consider logging an error and skiping transform
|
|
|
// instead of returning an error, if the transform is broken
|
|
|
@@ -169,13 +171,18 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
return len(b), errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
- newHeader := []byte(newHeaderS)
|
|
|
-
|
|
|
// only allocate new slice if header length changed
|
|
|
if len(newHeader) == len(header) {
|
|
|
- copy(t.b[:len(header)], newHeader)
|
|
|
+ // Do not need to check return value. It is guaranteed that
|
|
|
+ // n == len(newHeader) because t.b.Len() >= n if the header
|
|
|
+ // size has not changed.
|
|
|
+ copy(t.b.Bytes()[:len(header)], newHeader)
|
|
|
} else {
|
|
|
- t.b = append(newHeader, t.b[len(header):]...)
|
|
|
+ b := t.b.Bytes()
|
|
|
+ t.b.Reset()
|
|
|
+ // Do not need to check return value of bytes.Buffer.Write() https://github.com/golang/go/blob/1e9ff255a130200fcc4ec5e911d28181fce947d5/src/bytes/buffer.go#L164
|
|
|
+ t.b.Write(newHeader)
|
|
|
+ t.b.Write(b[len(header):])
|
|
|
}
|
|
|
|
|
|
header = newHeader
|
|
|
@@ -188,12 +195,20 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
}
|
|
|
t.remain += uint64(len(header))
|
|
|
|
|
|
- err = t.writeBuffer()
|
|
|
+ if uint64(t.b.Len()) > t.remain {
|
|
|
+ // Should never happen, multiple requests written in a single
|
|
|
+ // Write() are not supported.
|
|
|
+ return len(b), errors.TraceNew("multiple HTTP requests in single Write() not supported")
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := t.b.WriteTo(t.Conn)
|
|
|
+ t.remain -= uint64(n)
|
|
|
|
|
|
if t.remain > 0 {
|
|
|
t.state = httpTransformerReadWriteBody
|
|
|
}
|
|
|
|
|
|
+ // Do not wrap any I/O err returned by Conn
|
|
|
return len(b), err
|
|
|
}
|
|
|
|
|
|
@@ -203,14 +218,20 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
// Must write buffered bytes first, in-order, to write bytes to underlying
|
|
|
// net.Conn in the same order they were received in.
|
|
|
//
|
|
|
+ // Already checked that t.b does not contain bytes of a subsequent HTTP
|
|
|
+ // request when the header is parsed, i.e. at this point it is guaranteed
|
|
|
+ // that t.b.Len() <= t.remain.
|
|
|
+ //
|
|
|
// In practise the buffer will be empty by this point because its entire
|
|
|
- // contents will have been written in the first call to t.writeBuffer()
|
|
|
+ // contents will have been written in the first call to t.b.WriteTo(t.Conn)
|
|
|
// when the header is received, parsed, and transformed; otherwise the
|
|
|
// underlying transport will have failed and the caller will not invoke
|
|
|
// Write() again on this instance. See HTTPTransformer.Write() comment.
|
|
|
- err := t.writeBuffer()
|
|
|
+ wrote, err := t.b.WriteTo(t.Conn)
|
|
|
+ t.remain -= uint64(wrote)
|
|
|
if err != nil {
|
|
|
// b not written or buffered
|
|
|
+ // Do not wrap any I/O err returned by Conn
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
@@ -229,41 +250,10 @@ func (t *HTTPTransformer) Write(b []byte) (int, error) {
|
|
|
t.remain = 0
|
|
|
}
|
|
|
|
|
|
+ // Do not wrap any I/O err returned by Conn
|
|
|
return n, err
|
|
|
}
|
|
|
|
|
|
-func (t *HTTPTransformer) writeBuffer() error {
|
|
|
-
|
|
|
- if uint64(len(t.b)) > t.remain {
|
|
|
- // Should never happen, multiple requests written in a single
|
|
|
- // Write() are not supported.
|
|
|
- return errors.TraceNew("multiple HTTP requests in single Write() not supported")
|
|
|
- }
|
|
|
-
|
|
|
- // Continue to Write() buffered bytes to underlying net.Conn until Write()
|
|
|
- // fails or all buffered bytes are written.
|
|
|
- for len(t.b) > 0 {
|
|
|
-
|
|
|
- var n int
|
|
|
- n, err := t.Conn.Write(t.b)
|
|
|
-
|
|
|
- t.remain -= uint64(n)
|
|
|
-
|
|
|
- if n == len(t.b) {
|
|
|
- t.b = nil
|
|
|
- } else {
|
|
|
- t.b = t.b[n:]
|
|
|
- }
|
|
|
-
|
|
|
- // Stop writing and return if there was an error
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
func WrapDialerWithHTTPTransformer(dialer common.Dialer, params *HTTPTransformerParameters) common.Dialer {
|
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
conn, err := dialer(ctx, network, addr)
|