Просмотр исходного кода

Refactor: simplified tls_resumed_session metric gathering

Amir Khan 1 год назад
Родитель
Сommit
a524fda831
4 измененных файлов с 38 добавлено и 33 удалено
  1. 17 15
      psiphon/meekConn.go
  2. 17 11
      psiphon/tlsDialer.go
  3. 1 1
      psiphon/tlsDialer_test.go
  4. 3 6
      psiphon/tlsTunnelConn.go

+ 17 - 15
psiphon/meekConn.go

@@ -265,10 +265,6 @@ type MeekConn struct {
 	relayWaitGroup            *sync.WaitGroup
 	firstUnderlyingConn       net.Conn
 
-	// resumedTLSSession represents whether the first underlying TLS connection
-	// was resumed for metrics purposes.
-	resumedTLSSession bool
-
 	// For MeekModeObfuscatedRoundTrip
 	meekCookieEncryptionPublicKey string
 	meekObfuscatedKey             string
@@ -562,8 +558,6 @@ func DialMeek(
 			return nil, errors.Trace(err)
 		}
 
-		meek.resumedTLSSession = preConn.resumedSession
-
 		cachedTLSDialer = newCachedTLSDialer(preConn, tlsDialer)
 
 		if IsTLSConnUsingHTTP2(preConn) {
@@ -819,16 +813,26 @@ func (meek *MeekConn) underlyingDial(ctx context.Context, network, addr string)
 type cachedTLSDialer struct {
 	usedCachedConn int32
 	cachedConn     net.Conn
-	dialer         CustomTLSDialer
+	dialer         common.Dialer
+
+	// cachedConnMetrics records cachedConn metrics.
+	// These metrics do not change after the first dial.
+	cachedConnMetrics common.LogFields
 
 	mutex      sync.Mutex
 	requestCtx context.Context
 }
 
-func newCachedTLSDialer(cachedConn net.Conn, dialer CustomTLSDialer) *cachedTLSDialer {
+func newCachedTLSDialer(cachedConn net.Conn, dialer common.Dialer) *cachedTLSDialer {
+	cachedConnMetrics, ok := cachedConn.(common.MetricsSource)
+	metrics := make(common.LogFields)
+	if ok {
+		metrics = cachedConnMetrics.GetMetrics()
+	}
 	return &cachedTLSDialer{
-		cachedConn: cachedConn,
-		dialer:     dialer,
+		cachedConn:        cachedConn,
+		dialer:            dialer,
+		cachedConnMetrics: metrics,
 	}
 }
 
@@ -933,14 +937,12 @@ func (meek *MeekConn) GetMetrics() common.LogFields {
 	if ok {
 		logFields.Add(underlyingMetrics.GetMetrics())
 	}
-	if meek.cachedTLSDialer != nil {
-		logFields["tls_resumed_session"] = meek.resumedTLSSession
-	}
-
 	if quicTransport, ok := meek.transport.(*quic.QUICTransporter); ok {
 		logFields.Add(quicTransport.GetMetrics())
+	} else {
+		// For non-QUIC transports, include the TLS session resumption status.
+		logFields.Add(meek.cachedTLSDialer.cachedConnMetrics)
 	}
-
 	meek.mutex.Unlock()
 	return logFields
 }

+ 17 - 11
psiphon/tlsDialer.go

@@ -197,16 +197,9 @@ func (config *CustomTLSConfig) EnableClientSessionCache() {
 	}
 }
 
-type CustomTLSDialer = func(ctx context.Context, network, addr string) (*CustomTLSConn, error)
-
-type CustomTLSConn struct {
-	net.Conn
-	resumedSession bool
-}
-
 // NewCustomTLSDialer creates a new dialer based on CustomTLSDial.
-func NewCustomTLSDialer(config *CustomTLSConfig) CustomTLSDialer {
-	return func(ctx context.Context, network, addr string) (*CustomTLSConn, error) {
+func NewCustomTLSDialer(config *CustomTLSConfig) common.Dialer {
+	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 		return CustomTLSDial(ctx, network, addr, config)
 	}
 }
@@ -218,7 +211,7 @@ func NewCustomTLSDialer(config *CustomTLSConfig) CustomTLSDialer {
 func CustomTLSDial(
 	ctx context.Context,
 	network, addr string,
-	config *CustomTLSConfig) (*CustomTLSConn, error) {
+	config *CustomTLSConfig) (net.Conn, error) {
 
 	// Note that servers may return a chain which excludes the root CA
 	// cert https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.2.
@@ -675,12 +668,25 @@ func CustomTLSDial(
 		return nil, errors.Trace(err)
 	}
 
-	return &CustomTLSConn{
+	return &tlsConn{
 		Conn:           conn,
+		underlyingConn: rawConn,
 		resumedSession: usedSessionTicket,
 	}, nil
 }
 
+type tlsConn struct {
+	net.Conn
+	underlyingConn net.Conn
+	resumedSession bool
+}
+
+func (conn *tlsConn) GetMetrics() common.LogFields {
+	logFields := make(common.LogFields)
+	logFields["tls_resumed_session"] = conn.resumedSession
+	return logFields
+}
+
 func verifyLegacyCertificate(rawCerts [][]byte, expectedCertificate *x509.Certificate) error {
 	if len(rawCerts) < 1 {
 		return errors.TraceNew("missing certificate")

+ 1 - 1
psiphon/tlsDialer_test.go

@@ -581,7 +581,7 @@ func testTLSDialerCompatibility(t *testing.T, address string, fragmentClientHell
 			} else {
 
 				tlsVersion := ""
-				version := conn.Conn.(*utls.UConn).ConnectionState().Version
+				version := conn.(*tlsConn).Conn.(*utls.UConn).ConnectionState().Version
 				if version == utls.VersionTLS12 {
 					tlsVersion = "TLS 1.2"
 				} else if version == utls.VersionTLS13 {

+ 3 - 6
psiphon/tlsTunnelConn.go

@@ -49,8 +49,7 @@ type TLSTunnelConfig struct {
 // TLSTunnelConn is a network connection that tunnels net.Conn flows over TLS.
 type TLSTunnelConn struct {
 	net.Conn
-	tlsPadding        int
-	resumedTLSSession bool
+	tlsPadding int
 }
 
 // DialTLSTunnel returns an initialized tls-tunnel connection.
@@ -117,9 +116,8 @@ func DialTLSTunnel(
 	}
 
 	return &TLSTunnelConn{
-		Conn:              conn,
-		tlsPadding:        tlsPadding,
-		resumedTLSSession: conn.resumedSession,
+		Conn:       conn,
+		tlsPadding: tlsPadding,
 	}, nil
 }
 
@@ -163,7 +161,6 @@ func (conn *TLSTunnelConn) GetMetrics() common.LogFields {
 	logFields := make(common.LogFields)
 
 	logFields["tls_padding"] = conn.tlsPadding
-	logFields["tls_resumed_session"] = conn.resumedTLSSession
 
 	// Include metrics, such as fragmentor metrics, from the underlying dial
 	// conn. Properties of subsequent underlying dial conns are not reflected