|
|
@@ -220,7 +220,7 @@ func (d *NetDialer) DialContext(ctx context.Context, network, address string) (n
|
|
|
// LocalProxyRelay must close localConn in order to interrupt blocking
|
|
|
// I/O calls when the upstream port forward is closed. remoteConn is
|
|
|
// also closed before returning.
|
|
|
-func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
|
|
|
+func LocalProxyRelay(config *Config, proxyType string, localConn, remoteConn net.Conn) {
|
|
|
|
|
|
closing := int32(0)
|
|
|
|
|
|
@@ -230,7 +230,7 @@ func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
|
|
|
go func() {
|
|
|
defer copyWaitGroup.Done()
|
|
|
|
|
|
- _, err := io.Copy(localConn, remoteConn)
|
|
|
+ _, err := RelayCopyBuffer(config, localConn, remoteConn)
|
|
|
if err != nil && atomic.LoadInt32(&closing) != 1 {
|
|
|
NoticeLocalProxyError(proxyType, errors.TraceMsg(err, "Relay failed"))
|
|
|
}
|
|
|
@@ -245,7 +245,7 @@ func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
|
|
|
localConn.Close()
|
|
|
}()
|
|
|
|
|
|
- _, err := io.Copy(remoteConn, localConn)
|
|
|
+ _, err := RelayCopyBuffer(config, remoteConn, localConn)
|
|
|
if err != nil && atomic.LoadInt32(&closing) != 1 {
|
|
|
NoticeLocalProxyError(proxyType, errors.TraceMsg(err, "Relay failed"))
|
|
|
}
|
|
|
@@ -260,6 +260,29 @@ func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
|
|
|
copyWaitGroup.Wait()
|
|
|
}
|
|
|
|
|
|
+// RelayCopyBuffer performs an io.Copy, optionally using a smaller buffer when
|
|
|
+// config.LimitRelayBufferSizes is set.
|
|
|
+func RelayCopyBuffer(config *Config, dst io.Writer, src io.Reader) (int64, error) {
|
|
|
+
|
|
|
+ // By default, io.CopyBuffer will allocate a 32K buffer when a nil buffer
|
|
|
+ // is passed in. When configured, make and specify a smaller buffer. But
|
|
|
+ // only if src doesn't implement WriterTo and dst doesn't implement
|
|
|
+ // ReaderFrom, as in those cases io.CopyBuffer entirely avoids a buffer
|
|
|
+ // allocation.
|
|
|
+
|
|
|
+ var buffer []byte
|
|
|
+ if config.LimitRelayBufferSizes {
|
|
|
+ _, isWT := src.(io.WriterTo)
|
|
|
+ _, isRF := dst.(io.ReaderFrom)
|
|
|
+ if !isWT && !isRF {
|
|
|
+ buffer = make([]byte, 4096)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Do not wrap any I/O errors
|
|
|
+ return io.CopyBuffer(dst, src, buffer)
|
|
|
+}
|
|
|
+
|
|
|
// WaitForNetworkConnectivity uses a NetworkConnectivityChecker to
|
|
|
// periodically check for network connectivity. It returns true if
|
|
|
// no NetworkConnectivityChecker is provided (waiting is disabled)
|