Просмотр исходного кода

TLS client: Simplify cert's verification code (#5656)

Fixes https://github.com/XTLS/Xray-core/issues/5655
风扇滑翔翼 4 месяцев назад
Родитель
Сommit
4632984b66

+ 78 - 0
main/commands/all/tls/hash.go

@@ -0,0 +1,78 @@
+package tls
+
+import (
+	"bytes"
+	"crypto/x509"
+	"encoding/pem"
+	"flag"
+	"fmt"
+	"os"
+	"text/tabwriter"
+
+	"github.com/xtls/xray-core/main/commands/base"
+	. "github.com/xtls/xray-core/transport/internet/tls"
+)
+
+var cmdHash = &base.Command{
+	UsageLine: "{{.Exec}} tls hash",
+	Short:     "Calculate TLS certificate hash.",
+	Long: `
+	xray tls hash --cert <cert.pem>
+	Calculate TLS certificate hash.
+	`,
+}
+
+func init() {
+	cmdHash.Run = executeHash // break init loop
+}
+
+var input = cmdHash.Flag.String("cert", "fullchain.pem", "The file path of the certificate")
+
+func executeHash(cmd *base.Command, args []string) {
+	fs := flag.NewFlagSet("hash", flag.ContinueOnError)
+	if err := fs.Parse(args); err != nil {
+		fmt.Println(err)
+		return
+	}
+	certContent, err := os.ReadFile(*input)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	var certs []*x509.Certificate
+	if bytes.Contains(certContent, []byte("BEGIN")) {
+		for {
+			block, remain := pem.Decode(certContent)
+			if block == nil {
+				break
+			}
+			cert, err := x509.ParseCertificate(block.Bytes)
+			if err != nil {
+				fmt.Println("Unable to decode certificate:", err)
+				return
+			}
+			certs = append(certs, cert)
+			certContent = remain
+		}
+	} else {
+		certs, err = x509.ParseCertificates(certContent)
+		if err != nil {
+			fmt.Println("Unable to parse certificates:", err)
+			return
+		}
+	}
+	if len(certs) == 0 {
+		fmt.Println("No certificates found")
+		return
+	}
+	tabWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+	for i, cert := range certs {
+		hash := GenerateCertHashHex(cert)
+		if i == 0 {
+			fmt.Fprintf(tabWriter, "Leaf SHA256:\t%s\n", hash)
+		} else {
+			fmt.Fprintf(tabWriter, "CA <%s> SHA256:\t%s\n", cert.Subject.CommonName, hash)
+		}
+	}
+	tabWriter.Flush()
+}

+ 0 - 44
main/commands/all/tls/leafcerthash.go

@@ -1,44 +0,0 @@
-package tls
-
-import (
-	"flag"
-	"fmt"
-	"os"
-
-	"github.com/xtls/xray-core/main/commands/base"
-	"github.com/xtls/xray-core/transport/internet/tls"
-)
-
-var cmdLeafCertHash = &base.Command{
-	UsageLine: "{{.Exec}} tls leafCertHash",
-	Short:     "Calculate TLS leaf certificate hash.",
-	Long: `
-	xray tls leafCertHash --cert <cert.pem>
-	Calculate TLS leaf certificate hash.
-	`,
-}
-
-func init() {
-	cmdLeafCertHash.Run = executeLeafCertHash // break init loop
-}
-
-var input = cmdLeafCertHash.Flag.String("cert", "fullchain.pem", "The file path of the leaf certificate")
-
-func executeLeafCertHash(cmd *base.Command, args []string) {
-	fs := flag.NewFlagSet("leafCertHash", flag.ContinueOnError)
-	if err := fs.Parse(args); err != nil {
-		fmt.Println(err)
-		return
-	}
-	certContent, err := os.ReadFile(*input)
-	if err != nil {
-		fmt.Println(err)
-		return
-	}
-	certChainHashB64, err := tls.CalculatePEMLeafCertSHA256Hash(certContent)
-	if err != nil {
-		fmt.Println("failed to decode cert", err)
-		return
-	}
-	fmt.Println(certChainHashB64)
-}

+ 8 - 8
main/commands/all/tls/ping.go

@@ -135,15 +135,15 @@ func printCertificates(tabWriter *tabwriter.Writer, certs []*x509.Certificate) {
 			CAs = append(CAs, cert)
 		}
 	}
-	fmt.Fprintf(tabWriter, "Certificate chain's total length: \t %d (certs count: %s)\n", length, strconv.Itoa(len(certs)))
+	fmt.Fprintf(tabWriter, "Certificate chain's total length:\t%d (certs count: %s)\n", length, strconv.Itoa(len(certs)))
 	if leaf != nil {
-		fmt.Fprintf(tabWriter, "Cert's signature algorithm: \t %s\n", leaf.SignatureAlgorithm.String())
-		fmt.Fprintf(tabWriter, "Cert's publicKey algorithm: \t %s\n", leaf.PublicKeyAlgorithm.String())
-		fmt.Fprintf(tabWriter, "Cert's leaf SHA256: \t %s\n", hex.EncodeToString(GenerateCertHash(leaf)))
+		fmt.Fprintf(tabWriter, "Cert's signature algorithm:\t%s\n", leaf.SignatureAlgorithm.String())
+		fmt.Fprintf(tabWriter, "Cert's publicKey algorithm:\t%s\n", leaf.PublicKeyAlgorithm.String())
+		fmt.Fprintf(tabWriter, "Cert's leaf SHA256:\t%s\n", hex.EncodeToString(GenerateCertHash(leaf)))
 		for _, ca := range CAs {
-			fmt.Fprintf(tabWriter, "Cert's CA: %s SHA256: \t %s\n", ca.Subject.CommonName, hex.EncodeToString(GenerateCertHash(ca)))
+			fmt.Fprintf(tabWriter, "Cert's CA <%s> SHA256:\t%s\n", ca.Subject.CommonName, hex.EncodeToString(GenerateCertHash(ca)))
 		}
-		fmt.Fprintf(tabWriter, "Cert's allowed domains: \t %v\n", leaf.DNSNames)
+		fmt.Fprintf(tabWriter, "Cert's allowed domains:\t%v\n", leaf.DNSNames)
 	}
 }
 
@@ -156,11 +156,11 @@ func printTLSConnDetail(tabWriter *tabwriter.Writer, tlsConn *utls.UConn) {
 	case gotls.VersionTLS12:
 		tlsVersion = "TLS 1.2"
 	}
-	fmt.Fprintf(tabWriter, "TLS Version: \t %s\n", tlsVersion)
+	fmt.Fprintf(tabWriter, "TLS Version:\t%s\n", tlsVersion)
 	curveID := utils.AccessField[utls.CurveID](tlsConn.Conn, "curveID")
 	if curveID != nil {
 		PostQuantum := (*curveID == utls.X25519MLKEM768)
-		fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange: \t %t (%s)\n", PostQuantum, curveID.String())
+		fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange:\t%t (%s)\n", PostQuantum, curveID.String())
 	} else {
 		fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange:  false (RSA Exchange)\n")
 	}

+ 1 - 1
main/commands/all/tls/tls.go

@@ -13,7 +13,7 @@ var CmdTLS = &base.Command{
 	Commands: []*base.Command{
 		cmdCert,
 		cmdPing,
-		cmdLeafCertHash,
+		cmdHash,
 		cmdECH,
 	},
 }

+ 7 - 5
transport/internet/tls/config.go

@@ -289,9 +289,6 @@ func (r *RandCarrier) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509
 	if len(certs) == 0 {
 		return errors.New("unexpected certs")
 	}
-	if certs[0].IsCA {
-		slices.Reverse(certs)
-	}
 
 	// directly return success if pinned cert is leaf
 	// or replace RootCAs if pinned cert is CA (and can be used in VerifyPeerCertByName)
@@ -558,14 +555,19 @@ const (
 )
 
 func verifyChain(certs []*x509.Certificate, pinnedPeerCertSha256 [][]byte) (verifyResult, *x509.Certificate) {
+	leafHash := GenerateCertHash(certs[0])
+	for _, c := range pinnedPeerCertSha256 {
+		if hmac.Equal(leafHash, c) {
+			return foundLeaf, nil
+		}
+	}
+	certs = certs[1:] // skip leaf
 	for _, cert := range certs {
 		certHash := GenerateCertHash(cert)
 		for _, c := range pinnedPeerCertSha256 {
 			if hmac.Equal(certHash, c) {
 				if cert.IsCA {
 					return foundCA, cert
-				} else {
-					return foundLeaf, cert
 				}
 			}
 		}

+ 11 - 20
transport/internet/tls/pin.go

@@ -4,28 +4,8 @@ import (
 	"crypto/sha256"
 	"crypto/x509"
 	"encoding/hex"
-	"encoding/pem"
 )
 
-func CalculatePEMLeafCertSHA256Hash(certContent []byte) (string, error) {
-	var leafCert *x509.Certificate
-	for {
-		var err error
-		block, remain := pem.Decode(certContent)
-		if block == nil {
-			break
-		}
-		leafCert, err = x509.ParseCertificate(block.Bytes)
-		if err != nil {
-			return "", err
-		}
-		certContent = remain
-	}
-	certHash := GenerateCertHash(leafCert)
-	certHashHex := hex.EncodeToString(certHash)
-	return certHashHex, nil
-}
-
 // []byte must be ASN.1 DER content
 func GenerateCertHash[T *x509.Certificate | []byte](cert T) []byte {
 	var out [32]byte
@@ -37,3 +17,14 @@ func GenerateCertHash[T *x509.Certificate | []byte](cert T) []byte {
 	}
 	return out[:]
 }
+
+func GenerateCertHashHex[T *x509.Certificate | []byte](cert T) string {
+	var out [32]byte
+	switch v := any(cert).(type) {
+	case *x509.Certificate:
+		out = sha256.Sum256(v.Raw)
+	case []byte:
+		out = sha256.Sum256(v)
+	}
+	return hex.EncodeToString(out[:])
+}