|
|
@@ -34,7 +34,6 @@ import (
|
|
|
"net/url"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
- "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
|
|
|
@@ -244,15 +243,14 @@ type MeekConn struct {
|
|
|
tlsPadding int
|
|
|
limitRequestPayloadLength int
|
|
|
redialTLSProbability float64
|
|
|
- underlyingDialer common.Dialer
|
|
|
- cachedTLSDialer *cachedTLSDialer
|
|
|
transport transporter
|
|
|
- mutex sync.Mutex
|
|
|
- isClosed bool
|
|
|
- runCtx context.Context
|
|
|
- stopRunning context.CancelFunc
|
|
|
- relayWaitGroup *sync.WaitGroup
|
|
|
- firstUnderlyingConn net.Conn
|
|
|
+ connManager *meekUnderlyingConnManager
|
|
|
+
|
|
|
+ mutex sync.Mutex
|
|
|
+ isClosed bool
|
|
|
+ runCtx context.Context
|
|
|
+ stopRunning context.CancelFunc
|
|
|
+ relayWaitGroup *sync.WaitGroup
|
|
|
|
|
|
// For MeekModeObfuscatedRoundTrip
|
|
|
meekCookieEncryptionPublicKey string
|
|
|
@@ -324,20 +322,6 @@ func DialMeek(
|
|
|
|
|
|
runCtx, stopRunning := context.WithCancel(context.Background())
|
|
|
|
|
|
- cleanupStopRunning := true
|
|
|
- cleanupCachedTLSDialer := true
|
|
|
- var cachedTLSDialer *cachedTLSDialer
|
|
|
-
|
|
|
- // Cleanup in error cases
|
|
|
- defer func() {
|
|
|
- if cleanupStopRunning {
|
|
|
- stopRunning()
|
|
|
- }
|
|
|
- if cleanupCachedTLSDialer && cachedTLSDialer != nil {
|
|
|
- cachedTLSDialer.close()
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
meek := &MeekConn{
|
|
|
params: meekConfig.Parameters,
|
|
|
mode: meekConfig.Mode,
|
|
|
@@ -348,6 +332,19 @@ func DialMeek(
|
|
|
relayWaitGroup: new(sync.WaitGroup),
|
|
|
}
|
|
|
|
|
|
+ cleanupStopRunning := true
|
|
|
+ cleanupConns := true
|
|
|
+
|
|
|
+ // Cleanup in error cases
|
|
|
+ defer func() {
|
|
|
+ if cleanupStopRunning {
|
|
|
+ meek.stopRunning()
|
|
|
+ }
|
|
|
+ if cleanupConns && meek.connManager != nil {
|
|
|
+ meek.connManager.closeAll()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
if meek.mode == MeekModeRelay {
|
|
|
var err error
|
|
|
meek.cookie,
|
|
|
@@ -396,13 +393,15 @@ func DialMeek(
|
|
|
return packetConn, remoteAddr, nil
|
|
|
}
|
|
|
|
|
|
+ meek.connManager = newMeekUnderlyingConnManager(nil, nil, udpDialer)
|
|
|
+
|
|
|
var err error
|
|
|
transport, err = quic.NewQUICTransporter(
|
|
|
ctx,
|
|
|
func(message string) {
|
|
|
NoticeInfo(message)
|
|
|
},
|
|
|
- udpDialer,
|
|
|
+ meek.connManager.dialPacketConn,
|
|
|
meekConfig.SNIServerName,
|
|
|
meekConfig.QUICVersion,
|
|
|
meekConfig.QUICClientHelloSeed,
|
|
|
@@ -448,12 +447,10 @@ func DialMeek(
|
|
|
|
|
|
scheme = "https"
|
|
|
|
|
|
- meek.initUnderlyingDialer(dialConfig)
|
|
|
-
|
|
|
tlsConfig := &CustomTLSConfig{
|
|
|
Parameters: meekConfig.Parameters,
|
|
|
DialAddr: meekConfig.DialAddress,
|
|
|
- Dial: meek.underlyingDial,
|
|
|
+ Dial: NewTCPDialer(dialConfig),
|
|
|
SNIServerName: meekConfig.SNIServerName,
|
|
|
SkipVerify: skipVerify,
|
|
|
VerifyServerName: meekConfig.VerifyServerName,
|
|
|
@@ -531,22 +528,19 @@ func DialMeek(
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
- cachedTLSDialer = newCachedTLSDialer(preConn, tlsDialer)
|
|
|
+ meek.connManager = newMeekUnderlyingConnManager(preConn, tlsDialer, nil)
|
|
|
|
|
|
if IsTLSConnUsingHTTP2(preConn) {
|
|
|
NoticeInfo("negotiated HTTP/2 for %s", meekConfig.DiagnosticID)
|
|
|
transport = &http2.Transport{
|
|
|
DialTLSContext: func(
|
|
|
ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
|
|
- return cachedTLSDialer.dial(ctx, network, addr)
|
|
|
+ return meek.connManager.dial(ctx, network, addr)
|
|
|
},
|
|
|
}
|
|
|
} else {
|
|
|
transport = &http.Transport{
|
|
|
- DialTLSContext: func(
|
|
|
- ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
- return cachedTLSDialer.dial(ctx, network, addr)
|
|
|
- },
|
|
|
+ DialTLSContext: meek.connManager.dial,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -577,8 +571,7 @@ func DialMeek(
|
|
|
*copyDialConfig = *dialConfig
|
|
|
copyDialConfig.UpstreamProxyURL = ""
|
|
|
|
|
|
- meek.initUnderlyingDialer(copyDialConfig)
|
|
|
- dialer = meek.underlyingDial
|
|
|
+ dialer = NewTCPDialer(copyDialConfig)
|
|
|
|
|
|
// In this proxy case, the destination server address is in the
|
|
|
// request line URL. net/http will render the request line using
|
|
|
@@ -602,8 +595,7 @@ func DialMeek(
|
|
|
// If dialConfig.UpstreamProxyURL is set, HTTP proxying via
|
|
|
// CONNECT will be used by the dialer.
|
|
|
|
|
|
- meek.initUnderlyingDialer(dialConfig)
|
|
|
- baseDialer := meek.underlyingDial
|
|
|
+ baseDialer := NewTCPDialer(dialConfig)
|
|
|
|
|
|
// The dialer ignores any address that http.Transport will pass in
|
|
|
// (derived from the HTTP request URL) and always dials
|
|
|
@@ -617,14 +609,19 @@ func DialMeek(
|
|
|
// Only apply transformer if it will perform a transform; otherwise
|
|
|
// applying a no-op transform will incur an unnecessary performance
|
|
|
// cost.
|
|
|
- if meekConfig.HTTPTransformerParameters != nil && meekConfig.HTTPTransformerParameters.ProtocolTransformSpec != nil {
|
|
|
- dialer = transforms.WrapDialerWithHTTPTransformer(dialer, meekConfig.HTTPTransformerParameters)
|
|
|
+ if meekConfig.HTTPTransformerParameters != nil &&
|
|
|
+ meekConfig.HTTPTransformerParameters.ProtocolTransformSpec != nil {
|
|
|
+
|
|
|
+ dialer = transforms.WrapDialerWithHTTPTransformer(
|
|
|
+ dialer, meekConfig.HTTPTransformerParameters)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ meek.connManager = newMeekUnderlyingConnManager(nil, dialer, nil)
|
|
|
+
|
|
|
httpTransport := &http.Transport{
|
|
|
Proxy: proxyUrl,
|
|
|
- DialContext: dialer,
|
|
|
+ DialContext: meek.connManager.dial,
|
|
|
}
|
|
|
|
|
|
if proxyUrl != nil {
|
|
|
@@ -694,12 +691,11 @@ func DialMeek(
|
|
|
|
|
|
meek.url = url
|
|
|
meek.additionalHeaders = additionalHeaders
|
|
|
- meek.cachedTLSDialer = cachedTLSDialer
|
|
|
meek.transport = transport
|
|
|
|
|
|
// stopRunning and cachedTLSDialer will now be closed in meek.Close()
|
|
|
cleanupStopRunning = false
|
|
|
- cleanupCachedTLSDialer = false
|
|
|
+ cleanupConns = false
|
|
|
|
|
|
// Allocate relay resources, including buffers and running the relay
|
|
|
// go routine, only when running in relay mode.
|
|
|
@@ -763,56 +759,175 @@ func DialMeek(
|
|
|
return meek, nil
|
|
|
}
|
|
|
|
|
|
-func (meek *MeekConn) initUnderlyingDialer(dialConfig *DialConfig) {
|
|
|
+type meekPacketConnDialer func(ctx context.Context) (net.PacketConn, *net.UDPAddr, error)
|
|
|
|
|
|
- // Not safe for concurrent calls; should be called only from DialMeek.
|
|
|
- meek.underlyingDialer = NewTCPDialer(dialConfig)
|
|
|
+// meekUnderlyingConnManager tracks the TCP/TLS and UDP connections underlying
|
|
|
+// the meek HTTP/HTTPS/QUIC transports. This tracking is used to:
|
|
|
+//
|
|
|
+// - Use the cached predial TLS conn created in DialMeek.
|
|
|
+// - Gather metrics from mechanisms enabled in the underlying conns, such as
|
|
|
+// the fragmentor, or inproxy.
|
|
|
+// - Fully close all underlying connections with the MeekConn is closed.
|
|
|
+type meekUnderlyingConnManager struct {
|
|
|
+ mutex sync.Mutex
|
|
|
+ cachedConn net.Conn
|
|
|
+ firstConn net.Conn
|
|
|
+ firstPacketConn net.PacketConn
|
|
|
+
|
|
|
+ dialer common.Dialer
|
|
|
+ managedConns *common.Conns[net.Conn]
|
|
|
+
|
|
|
+ packetConnDialer meekPacketConnDialer
|
|
|
+ managedPacketConns *common.Conns[net.PacketConn]
|
|
|
}
|
|
|
|
|
|
-func (meek *MeekConn) underlyingDial(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
- conn, err := meek.underlyingDialer(ctx, network, addr)
|
|
|
- if err == nil {
|
|
|
- meek.mutex.Lock()
|
|
|
- if meek.firstUnderlyingConn == nil {
|
|
|
- // Keep a reference to the first underlying conn to be used as a
|
|
|
- // common.MetricsSource in GetMetrics. This enables capturing
|
|
|
- // metrics such as fragmentor configuration.
|
|
|
- meek.firstUnderlyingConn = conn
|
|
|
- }
|
|
|
- meek.mutex.Unlock()
|
|
|
- }
|
|
|
+type meekUnderlyingConn struct {
|
|
|
+ net.Conn
|
|
|
+ connManager *meekUnderlyingConnManager
|
|
|
+}
|
|
|
+
|
|
|
+func (conn *meekUnderlyingConn) Close() error {
|
|
|
+ conn.connManager.managedConns.Remove(conn)
|
|
|
+
|
|
|
// Note: no trace error to preserve error type
|
|
|
- return conn, err
|
|
|
+ return conn.Conn.Close()
|
|
|
+}
|
|
|
+
|
|
|
+type meekUnderlyingPacketConn struct {
|
|
|
+ net.PacketConn
|
|
|
+ connManager *meekUnderlyingConnManager
|
|
|
}
|
|
|
|
|
|
-type cachedTLSDialer struct {
|
|
|
- usedCachedConn int32
|
|
|
- cachedConn net.Conn
|
|
|
- dialer common.Dialer
|
|
|
+func (packetConn *meekUnderlyingPacketConn) Close() error {
|
|
|
+ packetConn.connManager.managedPacketConns.Remove(packetConn)
|
|
|
+ return packetConn.PacketConn.Close()
|
|
|
}
|
|
|
|
|
|
-func newCachedTLSDialer(cachedConn net.Conn, dialer common.Dialer) *cachedTLSDialer {
|
|
|
- return &cachedTLSDialer{
|
|
|
- cachedConn: cachedConn,
|
|
|
- dialer: dialer,
|
|
|
+func newMeekUnderlyingConnManager(
|
|
|
+ cachedConn net.Conn,
|
|
|
+ dialer common.Dialer,
|
|
|
+ packetConnDialer meekPacketConnDialer) *meekUnderlyingConnManager {
|
|
|
+
|
|
|
+ m := &meekUnderlyingConnManager{
|
|
|
+ dialer: dialer,
|
|
|
+ managedConns: common.NewConns[net.Conn](),
|
|
|
+
|
|
|
+ packetConnDialer: packetConnDialer,
|
|
|
+ managedPacketConns: common.NewConns[net.PacketConn](),
|
|
|
}
|
|
|
+
|
|
|
+ if cachedConn != nil {
|
|
|
+ m.cachedConn = &meekUnderlyingConn{Conn: cachedConn, connManager: m}
|
|
|
+ m.firstConn = cachedConn
|
|
|
+ }
|
|
|
+
|
|
|
+ return m
|
|
|
}
|
|
|
|
|
|
-func (c *cachedTLSDialer) dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
- if atomic.CompareAndSwapInt32(&c.usedCachedConn, 0, 1) {
|
|
|
- conn := c.cachedConn
|
|
|
- c.cachedConn = nil
|
|
|
+func (m *meekUnderlyingConnManager) GetMetrics() common.LogFields {
|
|
|
+
|
|
|
+ logFields := common.LogFields{}
|
|
|
+
|
|
|
+ m.mutex.Lock()
|
|
|
+ underlyingMetrics, ok := m.firstConn.(common.MetricsSource)
|
|
|
+ if ok {
|
|
|
+ logFields.Add(underlyingMetrics.GetMetrics())
|
|
|
+ }
|
|
|
+
|
|
|
+ underlyingMetrics, ok = m.firstPacketConn.(common.MetricsSource)
|
|
|
+ if ok {
|
|
|
+ logFields.Add(underlyingMetrics.GetMetrics())
|
|
|
+ }
|
|
|
+ m.mutex.Unlock()
|
|
|
+
|
|
|
+ return logFields
|
|
|
+}
|
|
|
+
|
|
|
+func (m *meekUnderlyingConnManager) dial(
|
|
|
+ ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
+
|
|
|
+ if m.managedConns.IsClosed() {
|
|
|
+ return nil, errors.TraceNew("closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Consume the cached conn when present.
|
|
|
+
|
|
|
+ m.mutex.Lock()
|
|
|
+ var conn net.Conn
|
|
|
+ if m.cachedConn != nil {
|
|
|
+ conn = m.cachedConn
|
|
|
+ m.cachedConn = nil
|
|
|
+ }
|
|
|
+ m.mutex.Unlock()
|
|
|
+
|
|
|
+ if conn != nil {
|
|
|
return conn, nil
|
|
|
}
|
|
|
|
|
|
- return c.dialer(ctx, network, addr)
|
|
|
+ // The mutex lock is not held for the duration of dial, allowing for
|
|
|
+ // concurrent dials.
|
|
|
+
|
|
|
+ conn, err := m.dialer(ctx, network, addr)
|
|
|
+ if err != nil {
|
|
|
+ // Note: no trace error to preserve error type
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Keep a reference to the first underlying conn to be used as a
|
|
|
+ // common.MetricsSource in GetMetrics. This enables capturing metrics
|
|
|
+ // such as fragmentor configuration.
|
|
|
+
|
|
|
+ m.mutex.Lock()
|
|
|
+ if m.firstConn == nil {
|
|
|
+ m.firstConn = conn
|
|
|
+ }
|
|
|
+ m.mutex.Unlock()
|
|
|
+
|
|
|
+ // Wrap the dialed conn with meekUnderlyingConn, which will remove the
|
|
|
+ // conn from the set of tracked conns when the conn is closed.
|
|
|
+
|
|
|
+ conn = &meekUnderlyingConn{Conn: conn, connManager: m}
|
|
|
+
|
|
|
+ if !m.managedConns.Add(conn) {
|
|
|
+ _ = conn.Close()
|
|
|
+ return nil, errors.TraceNew("closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ return conn, nil
|
|
|
}
|
|
|
|
|
|
-func (c *cachedTLSDialer) close() {
|
|
|
- if atomic.CompareAndSwapInt32(&c.usedCachedConn, 0, 1) {
|
|
|
- c.cachedConn.Close()
|
|
|
- c.cachedConn = nil
|
|
|
+func (m *meekUnderlyingConnManager) dialPacketConn(
|
|
|
+ ctx context.Context) (net.PacketConn, *net.UDPAddr, error) {
|
|
|
+
|
|
|
+ if m.managedPacketConns.IsClosed() {
|
|
|
+ return nil, nil, errors.TraceNew("closed")
|
|
|
}
|
|
|
+
|
|
|
+ packetConn, addr, err := m.packetConnDialer(ctx)
|
|
|
+ if err != nil {
|
|
|
+ // Note: no trace error to preserve error type
|
|
|
+ return nil, nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ m.mutex.Lock()
|
|
|
+ if m.firstPacketConn != nil {
|
|
|
+ m.firstPacketConn = packetConn
|
|
|
+ }
|
|
|
+ m.mutex.Unlock()
|
|
|
+
|
|
|
+ packetConn = &meekUnderlyingPacketConn{PacketConn: packetConn, connManager: m}
|
|
|
+
|
|
|
+ if !m.managedPacketConns.Add(packetConn) {
|
|
|
+ _ = packetConn.Close()
|
|
|
+ return nil, nil, errors.TraceNew("closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ return packetConn, addr, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (m *meekUnderlyingConnManager) closeAll() {
|
|
|
+ m.managedConns.CloseAll()
|
|
|
+ m.managedPacketConns.CloseAll()
|
|
|
}
|
|
|
|
|
|
// Close terminates the meek connection and releases its resources. In in
|
|
|
@@ -828,31 +943,12 @@ func (meek *MeekConn) Close() (err error) {
|
|
|
|
|
|
if !isClosed {
|
|
|
meek.stopRunning()
|
|
|
- if meek.cachedTLSDialer != nil {
|
|
|
- meek.cachedTLSDialer.close()
|
|
|
- }
|
|
|
-
|
|
|
- // stopRunning interrupts HTTP requests in progress by closing the context
|
|
|
- // associated with the request. In the case of h2quic.RoundTripper, testing
|
|
|
- // indicates that quic-go.receiveStream.readImpl is _not_ interrupted in
|
|
|
- // this case, and so an in-flight FRONTED-MEEK-QUIC round trip may hang shutdown
|
|
|
- // in relayRoundTrip->readPayload->...->quic-go.receiveStream.readImpl.
|
|
|
- // TODO: check if this is still the case in newer quic-go versions.
|
|
|
- //
|
|
|
- // To workaround this, we call CloseIdleConnections _before_ Wait, as, in
|
|
|
- // the case of QUICTransporter, this closes the underlying UDP sockets which
|
|
|
- // interrupts any blocking I/O calls.
|
|
|
- //
|
|
|
- // The standard CloseIdleConnections call _after_ wait is for the net/http
|
|
|
- // case: it only closes idle connections, so the call should be after wait.
|
|
|
- // This call is intended to clean up all network resources deterministically
|
|
|
- // before Close returns.
|
|
|
- if meek.isQUIC {
|
|
|
- meek.transport.CloseIdleConnections()
|
|
|
- }
|
|
|
-
|
|
|
+ meek.connManager.closeAll()
|
|
|
meek.relayWaitGroup.Wait()
|
|
|
- meek.transport.CloseIdleConnections()
|
|
|
+
|
|
|
+ // meek.transport.CloseIdleConnections is no longed called here since
|
|
|
+ // meekUnderlyingConnManager.closeAll will terminate all underlying
|
|
|
+ // connections and prevent opening any new connections.
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
@@ -877,16 +973,12 @@ func (meek *MeekConn) GetMetrics() common.LogFields {
|
|
|
logFields["meek_limit_request"] = meek.limitRequestPayloadLength
|
|
|
logFields["meek_redial_probability"] = meek.redialTLSProbability
|
|
|
}
|
|
|
+
|
|
|
// Include metrics, such as fragmentor metrics, from the _first_ underlying
|
|
|
// dial conn. Properties of subsequent underlying dial conns are not reflected
|
|
|
// in these metrics; we assume that the first dial conn, which most likely
|
|
|
// transits the various protocol handshakes, is most significant.
|
|
|
- meek.mutex.Lock()
|
|
|
- underlyingMetrics, ok := meek.firstUnderlyingConn.(common.MetricsSource)
|
|
|
- if ok {
|
|
|
- logFields.Add(underlyingMetrics.GetMetrics())
|
|
|
- }
|
|
|
- meek.mutex.Unlock()
|
|
|
+ logFields.Add(meek.connManager.GetMetrics())
|
|
|
return logFields
|
|
|
}
|
|
|
|