certificate.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "crypto/tls"
  7. "crypto/x509"
  8. "fmt"
  9. "strings"
  10. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  11. )
  12. // ClientHelloInfo contains information from a ClientHello message in order to
  13. // guide application logic in the GetCertificate.
  14. type ClientHelloInfo struct {
  15. // ServerName indicates the name of the server requested by the client
  16. // in order to support virtual hosting. ServerName is only set if the
  17. // client is using SNI (see RFC 4366, Section 3.1).
  18. ServerName string
  19. // CipherSuites lists the CipherSuites supported by the client (e.g.
  20. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
  21. CipherSuites []CipherSuiteID
  22. RandomBytes [handshake.RandomBytesLength]byte
  23. }
  24. // CertificateRequestInfo contains information from a server's
  25. // CertificateRequest message, which is used to demand a certificate and proof
  26. // of control from a client.
  27. type CertificateRequestInfo struct {
  28. // AcceptableCAs contains zero or more, DER-encoded, X.501
  29. // Distinguished Names. These are the names of root or intermediate CAs
  30. // that the server wishes the returned certificate to be signed by. An
  31. // empty slice indicates that the server has no preference.
  32. AcceptableCAs [][]byte
  33. }
  34. // SupportsCertificate returns nil if the provided certificate is supported by
  35. // the server that sent the CertificateRequest. Otherwise, it returns an error
  36. // describing the reason for the incompatibility.
  37. // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
  38. func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error {
  39. if len(cri.AcceptableCAs) == 0 {
  40. return nil
  41. }
  42. for j, cert := range c.Certificate {
  43. x509Cert := c.Leaf
  44. // Parse the certificate if this isn't the leaf node, or if
  45. // chain.Leaf was nil.
  46. if j != 0 || x509Cert == nil {
  47. var err error
  48. if x509Cert, err = x509.ParseCertificate(cert); err != nil {
  49. return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
  50. }
  51. }
  52. for _, ca := range cri.AcceptableCAs {
  53. if bytes.Equal(x509Cert.RawIssuer, ca) {
  54. return nil
  55. }
  56. }
  57. }
  58. return errNotAcceptableCertificateChain
  59. }
  60. func (c *handshakeConfig) setNameToCertificateLocked() {
  61. nameToCertificate := make(map[string]*tls.Certificate)
  62. for i := range c.localCertificates {
  63. cert := &c.localCertificates[i]
  64. x509Cert := cert.Leaf
  65. if x509Cert == nil {
  66. var parseErr error
  67. x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
  68. if parseErr != nil {
  69. continue
  70. }
  71. }
  72. if len(x509Cert.Subject.CommonName) > 0 {
  73. nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
  74. }
  75. for _, san := range x509Cert.DNSNames {
  76. nameToCertificate[strings.ToLower(san)] = cert
  77. }
  78. }
  79. c.nameToCertificate = nameToCertificate
  80. }
  81. func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) {
  82. c.mu.Lock()
  83. defer c.mu.Unlock()
  84. if c.localGetCertificate != nil &&
  85. (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) {
  86. cert, err := c.localGetCertificate(clientHelloInfo)
  87. if cert != nil || err != nil {
  88. return cert, err
  89. }
  90. }
  91. if c.nameToCertificate == nil {
  92. c.setNameToCertificateLocked()
  93. }
  94. if len(c.localCertificates) == 0 {
  95. return nil, errNoCertificates
  96. }
  97. if len(c.localCertificates) == 1 {
  98. // There's only one choice, so no point doing any work.
  99. return &c.localCertificates[0], nil
  100. }
  101. if len(clientHelloInfo.ServerName) == 0 {
  102. return &c.localCertificates[0], nil
  103. }
  104. name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".")
  105. if cert, ok := c.nameToCertificate[name]; ok {
  106. return cert, nil
  107. }
  108. // try replacing labels in the name with wildcards until we get a
  109. // match.
  110. labels := strings.Split(name, ".")
  111. for i := range labels {
  112. labels[i] = "*"
  113. candidate := strings.Join(labels, ".")
  114. if cert, ok := c.nameToCertificate[candidate]; ok {
  115. return cert, nil
  116. }
  117. }
  118. // If nothing matches, return the first certificate.
  119. return &c.localCertificates[0], nil
  120. }
  121. // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
  122. func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) {
  123. c.mu.Lock()
  124. defer c.mu.Unlock()
  125. if c.localGetClientCertificate != nil {
  126. return c.localGetClientCertificate(cri)
  127. }
  128. for i := range c.localCertificates {
  129. chain := c.localCertificates[i]
  130. if err := cri.SupportsCertificate(&chain); err != nil {
  131. continue
  132. }
  133. return &chain, nil
  134. }
  135. // No acceptable certificate found. Don't send a certificate.
  136. return new(tls.Certificate), nil
  137. }