Selaa lähdekoodia

Enabled 0-RTT for QUIC connections

Amir Khan 2 vuotta sitten
vanhempi
sitoutus
74775c41fc

+ 2 - 0
psiphon/common/parameters/parameters.go

@@ -112,6 +112,7 @@ const (
 	LimitQUICVersionsProbability                     = "LimitQUICVersionsProbability"
 	LimitQUICVersions                                = "LimitQUICVersions"
 	DisableFrontingProviderQUICVersions              = "DisableFrontingProviderQUICVersions"
+	QUICDialEarlyProbability                         = "QUICDialEarlyProbability"
 	QUICDisableClientPathMTUDiscoveryProbability     = "QUICDisableClientPathMTUDiscoveryProbability"
 	FragmentorProbability                            = "FragmentorProbability"
 	FragmentorLimitProtocols                         = "FragmentorLimitProtocols"
@@ -451,6 +452,7 @@ var defaultParameters = map[string]struct {
 	LimitQUICVersionsProbability:                 {value: 1.0, minimum: 0.0},
 	LimitQUICVersions:                            {value: protocol.QUICVersions{}},
 	DisableFrontingProviderQUICVersions:          {value: protocol.LabeledQUICVersions{}},
+	QUICDialEarlyProbability:                     {value: 1.0, minimum: 0.0},
 	QUICDisableClientPathMTUDiscoveryProbability: {value: 0.0, minimum: 0.0},
 
 	FragmentorProbability:              {value: 0.5, minimum: 0.0},

+ 5 - 0
psiphon/common/quic/gquic.go

@@ -105,6 +105,11 @@ func (c *gQUICConnection) isErrorIndicatingClosed(err error) bool {
 	return false
 }
 
+func (c *gQUICConnection) isEarlyDataRejected(err error) bool {
+	// 0-RTT is not supported by gQUIC.
+	return false
+}
+
 func gQUICDialContext(
 	ctx context.Context,
 	packetConn net.PacketConn,

+ 2 - 0
psiphon/common/quic/obfuscator_test.go

@@ -132,6 +132,8 @@ func runNonceTransformer(t *testing.T, quicVersion string) {
 				TransformSpec: transforms.Spec{{"^.{24}", "ffff00000000000000000000"}},
 			},
 			false,
+			false,
+			nil,
 		)
 
 		return nil

+ 33 - 8
psiphon/common/quic/quic.go

@@ -243,7 +243,7 @@ func Listen(
 		// Skipping muxListener also avoids the additional overhead of
 		// pumping read packets though mux channels.
 
-		tlsConfig, ietfQUICConfig := makeIETFConfig(
+		tlsConfig, ietfQUICConfig := makeServerIETFConfig(
 			obfuscatedPacketConn, verifyClientHelloRandom, tlsCertificate)
 
 		tr := newIETFTransport(obfuscatedPacketConn)
@@ -281,7 +281,7 @@ func Listen(
 	}, nil
 }
 
-func makeIETFConfig(
+func makeServerIETFConfig(
 	conn *ObfuscatedPacketConn,
 	verifyClientHelloRandom func(net.Addr, []byte) bool,
 	tlsCertificate tls.Certificate) (*tls.Config, *ietf_quic.Config) {
@@ -292,6 +292,7 @@ func makeIETFConfig(
 	}
 
 	ietfQUICConfig := &ietf_quic.Config{
+		Allow0RTT:             true,
 		HandshakeIdleTimeout:  SERVER_HANDSHAKE_TIMEOUT,
 		MaxIdleTimeout:        serverIdleTimeout,
 		MaxIncomingStreams:    1,
@@ -369,7 +370,9 @@ func Dial(
 	obfuscationKey string,
 	obfuscationPaddingSeed *prng.Seed,
 	obfuscationNonceTransformerParameters *transforms.ObfuscatorSeedTransformerParameters,
-	disablePathMTUDiscovery bool) (net.Conn, error) {
+	disablePathMTUDiscovery bool,
+	dialEarly bool,
+	tlsClientSessionCache tls.ClientSessionCache) (net.Conn, error) {
 
 	if quicVersion == "" {
 		return nil, errors.TraceNew("missing version")
@@ -488,7 +491,9 @@ func Dial(
 		getClientHelloRandom,
 		maxPacketSizeAdjustment,
 		disablePathMTUDiscovery,
-		false)
+		dialEarly,
+		tlsClientSessionCache)
+
 	if err != nil {
 		packetConn.Close()
 		return nil, errors.Trace(err)
@@ -697,6 +702,8 @@ type QUICTransporter struct {
 	quicVersion             string
 	clientHelloSeed         *prng.Seed
 	disablePathMTUDiscovery bool
+	dialEarly               bool
+	tlsClientSessionCache   tls.ClientSessionCache
 	packetConn              atomic.Value
 
 	mutex sync.Mutex
@@ -711,7 +718,9 @@ func NewQUICTransporter(
 	quicSNIAddress string,
 	quicVersion string,
 	clientHelloSeed *prng.Seed,
-	disablePathMTUDiscovery bool) (*QUICTransporter, error) {
+	disablePathMTUDiscovery bool,
+	dialEarly bool,
+	tlsClientSessionCache tls.ClientSessionCache) (*QUICTransporter, error) {
 
 	if quicVersion == "" {
 		return nil, errors.TraceNew("missing version")
@@ -733,6 +742,8 @@ func NewQUICTransporter(
 		quicVersion:             quicVersion,
 		clientHelloSeed:         clientHelloSeed,
 		disablePathMTUDiscovery: disablePathMTUDiscovery,
+		dialEarly:               dialEarly,
+		tlsClientSessionCache:   tlsClientSessionCache,
 		ctx:                     ctx,
 	}
 
@@ -836,7 +847,9 @@ func (t *QUICTransporter) dialQUIC() (retConnection quicConnection, retErr error
 		nil,
 		0,
 		t.disablePathMTUDiscovery,
-		true)
+		t.dialEarly,
+		t.tlsClientSessionCache)
+
 	if err != nil {
 		packetConn.Close()
 		return nil, errors.Trace(err)
@@ -887,6 +900,7 @@ type quicConnection interface {
 	AcceptStream() (quicStream, error)
 	OpenStream() (quicStream, error)
 	isErrorIndicatingClosed(err error) bool
+	isEarlyDataRejected(err error) bool
 }
 
 type quicStream interface {
@@ -960,6 +974,13 @@ func (c *ietfQUICConnection) isErrorIndicatingClosed(err error) bool {
 		errStr == "timeout: no recent network activity"
 }
 
+func (c *ietfQUICConnection) isEarlyDataRejected(err error) bool {
+	if err == nil {
+		return false
+	}
+	return err == ietf_quic.Err0RTTRejected
+}
+
 func dialQUIC(
 	ctx context.Context,
 	packetConn net.PacketConn,
@@ -971,7 +992,8 @@ func dialQUIC(
 	getClientHelloRandom func() ([]byte, error),
 	clientMaxPacketSizeAdjustment int,
 	disablePathMTUDiscovery bool,
-	dialEarly bool) (quicConnection, error) {
+	dialEarly bool,
+	tlsClientSessionCache tls.ClientSessionCache) (quicConnection, error) {
 
 	if isIETFVersionNumber(versionNumber) {
 		quicConfig := &ietf_quic.Config{
@@ -1004,9 +1026,11 @@ func dialQUIC(
 			InsecureSkipVerify: true,
 			NextProtos:         []string{getALPN(versionNumber)},
 			ServerName:         sni,
+			ClientSessionCache: tlsClientSessionCache,
 		}
 
 		if dialEarly {
+			// Attempting 0-RTT if possible.
 			dialConnection, err = ietf_quic.DialEarly(
 				ctx,
 				packetConn,
@@ -1021,6 +1045,7 @@ func dialQUIC(
 				tlsConfig,
 				quicConfig)
 		}
+
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
@@ -1195,7 +1220,7 @@ func newMuxListener(
 
 	listener.ietfQUICConn = newMuxPacketConn(conn.LocalAddr(), listener)
 
-	tlsConfig, ietfQUICConfig := makeIETFConfig(
+	tlsConfig, ietfQUICConfig := makeServerIETFConfig(
 		conn, verifyClientHelloRandom, tlsCertificate)
 
 	tr := newIETFTransport(listener.ietfQUICConn)

+ 3 - 1
psiphon/common/quic/quic_test.go

@@ -203,7 +203,9 @@ func runQUIC(
 				clientObfuscationKey,
 				obfuscationPaddingSeed,
 				nil,
-				disablePathMTUDiscovery)
+				disablePathMTUDiscovery,
+				true,
+				nil)
 
 			if invokeAntiProbing {
 

+ 12 - 0
psiphon/config.go

@@ -846,6 +846,9 @@ type Config struct {
 	// LimitTunnelDialPortNumbers is for testing purposes.
 	LimitTunnelDialPortNumbers parameters.TunnelProtocolPortLists
 
+	// QUICDialEarlyProbability is for testing purposes.
+	QUICDialEarlyProbability *float64
+
 	// QUICDisablePathMTUDiscoveryProbability is for testing purposes.
 	QUICDisablePathMTUDiscoveryProbability *float64
 
@@ -1989,6 +1992,10 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.LimitTunnelDialPortNumbers] = config.LimitTunnelDialPortNumbers
 	}
 
+	if config.QUICDialEarlyProbability != nil {
+		applyParameters[parameters.QUICDialEarlyProbability] = *config.QUICDialEarlyProbability
+	}
+
 	if config.QUICDisablePathMTUDiscoveryProbability != nil {
 		applyParameters[parameters.QUICDisableClientPathMTUDiscoveryProbability] = *config.QUICDisablePathMTUDiscoveryProbability
 	}
@@ -2541,6 +2548,11 @@ func (config *Config) setDialParametersHash() {
 		hash.Write(encodedLimitTunnelDialPortNumbers)
 	}
 
+	if config.QUICDialEarlyProbability != nil {
+		hash.Write([]byte("QUICDialEarlyProbability"))
+		binary.Write(hash, binary.LittleEndian, *config.QUICDialEarlyProbability)
+	}
+
 	if config.QUICDisablePathMTUDiscoveryProbability != nil {
 		hash.Write([]byte("QUICDisablePathMTUDiscoveryProbability"))
 		binary.Write(hash, binary.LittleEndian, *config.QUICDisablePathMTUDiscoveryProbability)

+ 9 - 0
psiphon/controller.go

@@ -33,6 +33,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
@@ -41,6 +42,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/resolver"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 	lrucache "github.com/cognusion/go-cache-lru"
+	utls "github.com/refraction-networking/utls"
 )
 
 // Controller is a tunnel lifecycle coordinator. It manages lists of servers to
@@ -88,6 +90,8 @@ type Controller struct {
 	staggerMutex                            sync.Mutex
 	resolver                                *resolver.Resolver
 	steeringIPCache                         *lrucache.Cache
+	tlsClientSessionCache                   tls.ClientSessionCache
+	utlsClientSessionCache                  utls.ClientSessionCache
 }
 
 // NewController initializes a new controller.
@@ -157,6 +161,9 @@ func NewController(config *Config) (controller *Controller, err error) {
 			steeringIPCacheTTL,
 			1*time.Minute,
 			steeringIPCacheMaxEntries),
+
+		tlsClientSessionCache:  tls.NewLRUClientSessionCache(0),
+		utlsClientSessionCache: utls.NewLRUClientSessionCache(0),
 	}
 
 	// Initialize untunneledDialConfig, used by untunneled dials including
@@ -2194,6 +2201,8 @@ loop:
 		dialParams, err := MakeDialParameters(
 			controller.config,
 			controller.steeringIPCache,
+			controller.tlsClientSessionCache,
+			controller.utlsClientSessionCache,
 			upstreamProxyErrorCallback,
 			canReplay,
 			selectProtocol,

+ 54 - 0
psiphon/dialParameters.go

@@ -32,6 +32,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
@@ -44,6 +45,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
 	lrucache "github.com/cognusion/go-cache-lru"
+	utls "github.com/refraction-networking/utls"
 	"golang.org/x/net/bpf"
 )
 
@@ -131,6 +133,7 @@ type DialParameters struct {
 	QUICClientHelloSeed                      *prng.Seed
 	ObfuscatedQUICPaddingSeed                *prng.Seed
 	ObfuscatedQUICNonceTransformerParameters *transforms.ObfuscatorSeedTransformerParameters
+	QUICDialEarly                            bool
 	QUICDisablePathMTUDiscovery              bool
 
 	ConjureCachedRegistrationTTL        time.Duration
@@ -165,6 +168,10 @@ type DialParameters struct {
 	steeringIPCache    *lrucache.Cache `json:"-"`
 	steeringIPCacheKey string          `json:"-"`
 
+	quicTLSSessionCacheKey      string                  `json:"-"`
+	QUICTLSClientSessionCache   tls.ClientSessionCache  `json:"-"`
+	directTLSClientSessionCache utls.ClientSessionCache `json:"-"`
+
 	dialConfig *DialConfig `json:"-"`
 	meekConfig *MeekConfig `json:"-"`
 }
@@ -189,6 +196,8 @@ type DialParameters struct {
 func MakeDialParameters(
 	config *Config,
 	steeringIPCache *lrucache.Cache,
+	quicTLSClientSessionCache tls.ClientSessionCache,
+	directTLSClientSessionCache utls.ClientSessionCache,
 	upstreamProxyErrorCallback func(error),
 	canReplay func(serverEntry *protocol.ServerEntry, replayProtocol string) bool,
 	selectProtocol func(serverEntry *protocol.ServerEntry) (string, bool),
@@ -362,6 +371,8 @@ func MakeDialParameters(
 
 	dialParams.steeringIPCache = steeringIPCache
 
+	dialParams.directTLSClientSessionCache = directTLSClientSessionCache
+
 	dialParams.ServerEntry = serverEntry
 	dialParams.NetworkID = networkID
 	dialParams.IsReplay = isReplay
@@ -797,11 +808,23 @@ func MakeDialParameters(
 			}
 		}
 
+		dialParams.QUICDialEarly = p.WeightedCoinFlip(parameters.QUICDialEarlyProbability)
+
 		dialParams.QUICDisablePathMTUDiscovery =
 			protocol.QUICVersionUsesPathMTUDiscovery(dialParams.QUICVersion) &&
 				p.WeightedCoinFlip(parameters.QUICDisableClientPathMTUDiscoveryProbability)
 	}
 
+	// Sets up client session caching for QUIC with a TLS cache key unique to current endpoint.
+	if protocol.TunnelProtocolUsesQUIC(dialParams.TunnelProtocol) {
+		dialPortNumber, err := serverEntry.GetDialPortNumber(dialParams.TunnelProtocol)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+		dialParams.quicTLSSessionCacheKey = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
+		dialParams.QUICTLSClientSessionCache = WrapClientSessionCache(quicTLSClientSessionCache, dialParams.quicTLSSessionCacheKey)
+	}
+
 	if (!isReplay || !replayObfuscatedQUIC) &&
 		protocol.QUICVersionIsObfuscated(dialParams.QUICVersion) {
 
@@ -1305,6 +1328,8 @@ func MakeDialParameters(
 			UseQUIC:                       protocol.TunnelProtocolUsesFrontedMeekQUIC(dialParams.TunnelProtocol),
 			QUICVersion:                   dialParams.QUICVersion,
 			QUICClientHelloSeed:           dialParams.QUICClientHelloSeed,
+			QUICDialEarly:                 dialParams.QUICDialEarly,
+			QuicTlsClientSessionCache:     dialParams.meekConfig.QuicTlsClientSessionCache,
 			QUICDisablePathMTUDiscovery:   dialParams.QUICDisablePathMTUDiscovery,
 			UseHTTPS:                      usingTLS,
 			TLSProfile:                    dialParams.TLSProfile,
@@ -1368,6 +1393,7 @@ func (dialParams *DialParameters) GetTLSOSSHConfig(config *Config) *TLSTunnelCon
 			NoDefaultTLSSessionID:    &dialParams.NoDefaultTLSSessionID,
 			RandomizedTLSProfileSeed: dialParams.RandomizedTLSProfileSeed,
 			FragmentClientHello:      dialParams.TLSFragmentClientHello,
+			ClientSessionCache:       dialParams.directTLSClientSessionCache,
 		},
 		// Obfuscated session tickets are not used because TLS-OSSH uses TLS 1.3.
 		UseObfuscatedSessionTickets: false,
@@ -1455,6 +1481,12 @@ func (dialParams *DialParameters) Failed(config *Config) {
 	if dialParams.steeringIPCacheKey != "" {
 		dialParams.steeringIPCache.Delete(dialParams.steeringIPCacheKey)
 	}
+
+	// Clear the TLS client session cache to avoid (potentially) reusing failed sessions.
+	if protocol.TunnelProtocolUsesQUIC(dialParams.TunnelProtocol) {
+		dialParams.QUICTLSClientSessionCache.Put(dialParams.quicTLSSessionCacheKey, nil)
+	}
+
 }
 
 func (dialParams *DialParameters) GetTLSVersionForMetrics() string {
@@ -1939,3 +1971,25 @@ func selectConjureTransport(
 
 	return transports[choice]
 }
+
+type tlsClientSessionCacheWrapper struct {
+	tls.ClientSessionCache
+
+	// sessinoKey specifies the value of the hard-coded TLS session cache key.
+	sessionKey string
+}
+
+func WrapClientSessionCache(cache tls.ClientSessionCache, sessionKey string) tls.ClientSessionCache {
+	return &tlsClientSessionCacheWrapper{
+		ClientSessionCache: cache,
+		sessionKey:         sessionKey,
+	}
+}
+
+func (c *tlsClientSessionCacheWrapper) Get(_ string) (session *tls.ClientSessionState, ok bool) {
+	return c.ClientSessionCache.Get(c.sessionKey)
+}
+
+func (c *tlsClientSessionCacheWrapper) Put(_ string, cs *tls.ClientSessionState) {
+	c.ClientSessionCache.Put(c.sessionKey, cs)
+}

+ 18 - 3
psiphon/meekConn.go

@@ -23,7 +23,7 @@ import (
 	"bytes"
 	"context"
 	"crypto/rand"
-	"crypto/tls"
+	std_tls "crypto/tls"
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
@@ -37,6 +37,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	tls "github.com/Psiphon-Labs/psiphon-tls"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
@@ -120,6 +121,9 @@ type MeekConfig struct {
 	// QUICClientHelloSeed is used for randomized QUIC Client Hellos.
 	QUICClientHelloSeed *prng.Seed
 
+	// QUICDialEarly indicates whether the client should attempt 0-RTT.
+	QUICDialEarly bool
+
 	// QUICDisablePathMTUDiscovery indicates whether to disable path MTU
 	// discovery in the QUIC client.
 	QUICDisablePathMTUDiscovery bool
@@ -131,6 +135,10 @@ type MeekConfig struct {
 	// underlying TLS connections created by this meek connection.
 	TLSProfile string
 
+	// QuicTlsClientSessionCache specifies the TLS session cache to use
+	// for Meek connections that use HTTP/2 over QUIC.
+	QuicTlsClientSessionCache tls.ClientSessionCache
+
 	// TLSFragmentClientHello specifies whether to fragment the TLS Client Hello.
 	TLSFragmentClientHello bool
 
@@ -295,6 +303,11 @@ func DialMeek(
 			"invalid config: only one of UseQUIC or UseHTTPS may be set")
 	}
 
+	if meekConfig.UseQUIC && meekConfig.QuicTlsClientSessionCache == nil {
+		return nil, errors.TraceNew(
+			"invalid config: TLSClientSessionCache must be set when UseQUIC is set")
+	}
+
 	if meekConfig.UseQUIC &&
 		(meekConfig.VerifyServerName != "" || len(meekConfig.VerifyPins) > 0) {
 
@@ -406,7 +419,9 @@ func DialMeek(
 			meekConfig.SNIServerName,
 			meekConfig.QUICVersion,
 			meekConfig.QUICClientHelloSeed,
-			meekConfig.QUICDisablePathMTUDiscovery)
+			meekConfig.QUICDisablePathMTUDiscovery,
+			meekConfig.QUICDialEarly,
+			meekConfig.QuicTlsClientSessionCache)
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
@@ -540,7 +555,7 @@ func DialMeek(
 		if IsTLSConnUsingHTTP2(preConn) {
 			NoticeInfo("negotiated HTTP/2 for %s", meekConfig.DiagnosticID)
 			transport = &http2.Transport{
-				DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
+				DialTLS: func(network, addr string, _ *std_tls.Config) (net.Conn, error) {
 					return cachedTLSDialer.dial(network, addr)
 				},
 			}

+ 2 - 0
psiphon/tactics.go

@@ -218,6 +218,8 @@ func fetchTactics(
 		config,
 		nil,
 		nil,
+		nil,
+		nil,
 		canReplay,
 		selectProtocol,
 		serverEntry,

+ 4 - 4
psiphon/tlsDialer.go

@@ -184,15 +184,15 @@ type CustomTLSConfig struct {
 	// FragmentClientHello specifies whether to fragment the ClientHello.
 	FragmentClientHello bool
 
-	clientSessionCache utls.ClientSessionCache
+	ClientSessionCache utls.ClientSessionCache
 }
 
 // EnableClientSessionCache initializes a cache to use to persist session
 // tickets, enabling TLS session resumability across multiple
 // CustomTLSDial calls or dialers using the same CustomTLSConfig.
 func (config *CustomTLSConfig) EnableClientSessionCache() {
-	if config.clientSessionCache == nil {
-		config.clientSessionCache = utls.NewLRUClientSessionCache(0)
+	if config.ClientSessionCache == nil {
+		config.ClientSessionCache = utls.NewLRUClientSessionCache(0)
 	}
 }
 
@@ -434,7 +434,7 @@ func CustomTLSDial(
 		}
 	}
 
-	clientSessionCache := config.clientSessionCache
+	clientSessionCache := config.ClientSessionCache
 	if clientSessionCache == nil {
 		clientSessionCache = utls.NewLRUClientSessionCache(0)
 	}

+ 3 - 1
psiphon/tunnel.go

@@ -793,7 +793,9 @@ func dialTunnel(
 			dialParams.ServerEntry.SshObfuscatedKey,
 			dialParams.ObfuscatedQUICPaddingSeed,
 			dialParams.ObfuscatedQUICNonceTransformerParameters,
-			dialParams.QUICDisablePathMTUDiscovery)
+			dialParams.QUICDisablePathMTUDiscovery,
+			dialParams.QUICDialEarly,
+			dialParams.QUICTLSClientSessionCache)
 		if err != nil {
 			return nil, errors.Trace(err)
 		}