certificate.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "crypto/tls"
  7. "crypto/x509"
  8. "fmt"
  9. "strings"
  10. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  11. )
// ClientHelloInfo contains information from a ClientHello message in order to
// guide application logic in the GetCertificate callback.
type ClientHelloInfo struct {
	// ServerName indicates the name of the server requested by the client
	// in order to support virtual hosting. ServerName is only set if the
	// client is using SNI (see RFC 4366, Section 3.1).
	ServerName string

	// CipherSuites lists the CipherSuites supported by the client (e.g.
	// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
	CipherSuites []CipherSuiteID

	// [Psiphon]
	// Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
	//
	// RandomBytes is a fixed-size array of handshake.RandomBytesLength bytes.
	// NOTE(review): presumably the random bytes carried in the ClientHello's
	// Random field — confirm against the code that populates this struct.
	RandomBytes [handshake.RandomBytesLength]byte
}
// CertificateRequestInfo contains information from a server's
// CertificateRequest message, which is used to demand a certificate and proof
// of control from a client.
type CertificateRequestInfo struct {
	// AcceptableCAs contains zero or more, DER-encoded, X.501
	// Distinguished Names. These are the names of root or intermediate CAs
	// that the server wishes the returned certificate to be signed by. An
	// empty slice indicates that the server has no preference.
	//
	// SupportsCertificate compares each entry byte-for-byte with the
	// RawIssuer of every certificate in a candidate chain.
	AcceptableCAs [][]byte
}
  36. // SupportsCertificate returns nil if the provided certificate is supported by
  37. // the server that sent the CertificateRequest. Otherwise, it returns an error
  38. // describing the reason for the incompatibility.
  39. // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
  40. func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error {
  41. if len(cri.AcceptableCAs) == 0 {
  42. return nil
  43. }
  44. for j, cert := range c.Certificate {
  45. x509Cert := c.Leaf
  46. // Parse the certificate if this isn't the leaf node, or if
  47. // chain.Leaf was nil.
  48. if j != 0 || x509Cert == nil {
  49. var err error
  50. if x509Cert, err = x509.ParseCertificate(cert); err != nil {
  51. return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
  52. }
  53. }
  54. for _, ca := range cri.AcceptableCAs {
  55. if bytes.Equal(x509Cert.RawIssuer, ca) {
  56. return nil
  57. }
  58. }
  59. }
  60. return errNotAcceptableCertificateChain
  61. }
  62. func (c *handshakeConfig) setNameToCertificateLocked() {
  63. nameToCertificate := make(map[string]*tls.Certificate)
  64. for i := range c.localCertificates {
  65. cert := &c.localCertificates[i]
  66. x509Cert := cert.Leaf
  67. if x509Cert == nil {
  68. var parseErr error
  69. x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
  70. if parseErr != nil {
  71. continue
  72. }
  73. }
  74. if len(x509Cert.Subject.CommonName) > 0 {
  75. nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
  76. }
  77. for _, san := range x509Cert.DNSNames {
  78. nameToCertificate[strings.ToLower(san)] = cert
  79. }
  80. }
  81. c.nameToCertificate = nameToCertificate
  82. }
  83. func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) {
  84. c.mu.Lock()
  85. defer c.mu.Unlock()
  86. if c.localGetCertificate != nil &&
  87. (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) {
  88. cert, err := c.localGetCertificate(clientHelloInfo)
  89. if cert != nil || err != nil {
  90. return cert, err
  91. }
  92. }
  93. if c.nameToCertificate == nil {
  94. c.setNameToCertificateLocked()
  95. }
  96. if len(c.localCertificates) == 0 {
  97. return nil, errNoCertificates
  98. }
  99. if len(c.localCertificates) == 1 {
  100. // There's only one choice, so no point doing any work.
  101. return &c.localCertificates[0], nil
  102. }
  103. if len(clientHelloInfo.ServerName) == 0 {
  104. return &c.localCertificates[0], nil
  105. }
  106. name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".")
  107. if cert, ok := c.nameToCertificate[name]; ok {
  108. return cert, nil
  109. }
  110. // try replacing labels in the name with wildcards until we get a
  111. // match.
  112. labels := strings.Split(name, ".")
  113. for i := range labels {
  114. labels[i] = "*"
  115. candidate := strings.Join(labels, ".")
  116. if cert, ok := c.nameToCertificate[candidate]; ok {
  117. return cert, nil
  118. }
  119. }
  120. // If nothing matches, return the first certificate.
  121. return &c.localCertificates[0], nil
  122. }
  123. // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
  124. func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) {
  125. c.mu.Lock()
  126. defer c.mu.Unlock()
  127. if c.localGetClientCertificate != nil {
  128. return c.localGetClientCertificate(cri)
  129. }
  130. for i := range c.localCertificates {
  131. chain := c.localCertificates[i]
  132. if err := cri.SupportsCertificate(&chain); err != nil {
  133. continue
  134. }
  135. return &chain, nil
  136. }
  137. // No acceptable certificate found. Don't send a certificate.
  138. return new(tls.Certificate), nil
  139. }