|
|
@@ -31,6 +31,7 @@ import (
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/dns"
|
|
|
@@ -175,22 +176,49 @@ func (d *NetDialer) DialContext(ctx context.Context, network, address string) (n
|
|
|
|
|
|
// LocalProxyRelay sends to remoteConn bytes received from localConn,
|
|
|
// and sends to localConn bytes received from remoteConn.
|
|
|
+//
|
|
|
+// LocalProxyRelay must close localConn in order to interrupt blocking
|
|
|
+// I/O calls when the upstream port forward is closed. remoteConn is
|
|
|
+// also closed before returning.
|
|
|
func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
|
|
|
+
|
|
|
+ closing := int32(0)
|
|
|
+
|
|
|
copyWaitGroup := new(sync.WaitGroup)
|
|
|
copyWaitGroup.Add(1)
|
|
|
+
|
|
|
go func() {
|
|
|
defer copyWaitGroup.Done()
|
|
|
+
|
|
|
_, err := io.Copy(localConn, remoteConn)
|
|
|
- if err != nil {
|
|
|
+ if err != nil && atomic.LoadInt32(&closing) != 1 {
|
|
|
err = fmt.Errorf("Relay failed: %s", common.ContextError(err))
|
|
|
NoticeLocalProxyError(proxyType, err)
|
|
|
}
|
|
|
+
|
|
|
+ // When the server closes a port forward, ex. due to idle timeout,
|
|
|
+ // remoteConn.Read will return EOF, which causes the downstream io.Copy to
|
|
|
+ // return (with a nil error). To ensure the downstream local proxy
|
|
|
+ // connection also closes at this point, we interrupt the blocking upstream
|
|
|
+ // io.Copy by closing localConn.
|
|
|
+
|
|
|
+ atomic.StoreInt32(&closing, 1)
|
|
|
+ localConn.Close()
|
|
|
}()
|
|
|
+
|
|
|
_, err := io.Copy(remoteConn, localConn)
|
|
|
- if err != nil {
|
|
|
+ if err != nil && atomic.LoadInt32(&closing) != 1 {
|
|
|
err = fmt.Errorf("Relay failed: %s", common.ContextError(err))
|
|
|
NoticeLocalProxyError(proxyType, err)
|
|
|
}
|
|
|
+
|
|
|
+ // When a local proxy peer connection closes, localConn.Read will return EOF.
|
|
|
+ // As above, close the other end of the relay to ensure immediate shutdown,
|
|
|
+ // as no more data can be relayed.
|
|
|
+
|
|
|
+ atomic.StoreInt32(&closing, 1)
|
|
|
+ remoteConn.Close()
|
|
|
+
|
|
|
copyWaitGroup.Wait()
|
|
|
}
|
|
|
|
|
|
@@ -275,8 +303,8 @@ func MakeUntunneledHTTPClient(
|
|
|
// Note: when verifyLegacyCertificate is not nil, some
|
|
|
// of the other CustomTLSConfig is overridden.
|
|
|
tlsConfig := &CustomTLSConfig{
|
|
|
- ClientParameters: config.clientParameters,
|
|
|
- Dial: dialer,
|
|
|
+ ClientParameters: config.clientParameters,
|
|
|
+ Dial: dialer,
|
|
|
VerifyLegacyCertificate: verifyLegacyCertificate,
|
|
|
UseDialAddrSNI: true,
|
|
|
SNIServerName: "",
|