Просмотр исходного кода

Ensure io.CopyBuffer uses a buffer where intended

- In these existing cases, the memory impact of buffer
  size is more important than potential zero-copy
  optimizations.

- Both existing cases happen to use the specified
  buffer due to the dst and src types. This change
  should simply make the code more robust in case
  of future changes.
Rod Hynes 6 лет назад
Родитель
Сommit
b67dc7d1bb
2 измененных файлов с 12 добавлено и 6 удалено
  1. 7 1
      psiphon/common/utils.go
  2. 5 5
      psiphon/server/tunnelServer.go

+ 7 - 1
psiphon/common/utils.go

@@ -167,10 +167,16 @@ func FormatByteCount(bytes uint64) string {
 		"%.1f%c", float64(bytes)/math.Pow(float64(base), float64(exp)), "KMGTPEZ"[exp-1])
 		"%.1f%c", float64(bytes)/math.Pow(float64(base), float64(exp)), "KMGTPEZ"[exp-1])
 }
 }
 
 
+// CopyBuffer calls io.CopyBuffer, masking out any src.WriteTo or dst.ReadFrom
+// to force use of the specified buf.
+func CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
+	return io.CopyBuffer(struct{ io.Writer }{dst}, struct{ io.Reader }{src}, buf)
+}
+
 func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) {
 func CopyNBuffer(dst io.Writer, src io.Reader, n int64, buf []byte) (written int64, err error) {
 	// Based on io.CopyN:
 	// Based on io.CopyN:
 	// https://github.com/golang/go/blob/release-branch.go1.11/src/io/io.go#L339
 	// https://github.com/golang/go/blob/release-branch.go1.11/src/io/io.go#L339
-	written, err = io.CopyBuffer(dst, io.LimitReader(src, n), buf)
+	written, err = CopyBuffer(dst, io.LimitReader(src, n), buf)
 	if written == n {
 	if written == n {
 		return n, nil
 		return n, nil
 	}
 	}

+ 5 - 5
psiphon/server/tunnelServer.go

@@ -3369,10 +3369,10 @@ func (sshClient *sshClient) handleTCPChannel(
 	relayWaitGroup.Add(1)
 	relayWaitGroup.Add(1)
 	go func() {
 	go func() {
 		defer relayWaitGroup.Done()
 		defer relayWaitGroup.Done()
-		// io.Copy allocates a 32K temporary buffer, and each port forward relay uses
-		// two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
-		// overall memory footprint.
-		bytes, err := io.CopyBuffer(
+		// io.Copy allocates a 32K temporary buffer, and each port forward relay
+		// uses two of these buffers; using common.CopyBuffer with a smaller buffer
+		// reduces the overall memory footprint.
+		bytes, err := common.CopyBuffer(
 			fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
 			fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
 		atomic.AddInt64(&bytesDown, bytes)
 		atomic.AddInt64(&bytesDown, bytes)
 		if err != nil && err != io.EOF {
 		if err != nil && err != io.EOF {
@@ -3385,7 +3385,7 @@ func (sshClient *sshClient) handleTCPChannel(
 		// be flowing?
 		// be flowing?
 		fwdChannel.Close()
 		fwdChannel.Close()
 	}()
 	}()
-	bytes, err := io.CopyBuffer(
+	bytes, err := common.CopyBuffer(
 		fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
 		fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
 	atomic.AddInt64(&bytesUp, bytes)
 	atomic.AddInt64(&bytesUp, bytes)
 	if err != nil && err != io.EOF {
 	if err != nil && err != io.EOF {