certificate_test.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "crypto/tls"
  6. "reflect"
  7. "testing"
  8. "github.com/pion/dtls/v2/pkg/crypto/selfsign"
  9. )
  10. func TestGetCertificate(t *testing.T) {
  11. certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test")
  12. if err != nil {
  13. t.Fatal(err)
  14. }
  15. certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test")
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. certificateRandom, err := selfsign.GenerateSelfSigned()
  20. if err != nil {
  21. t.Fatal(err)
  22. }
  23. testCases := []struct {
  24. localCertificates []tls.Certificate
  25. desc string
  26. serverName string
  27. expectedCertificate tls.Certificate
  28. getCertificate func(info *ClientHelloInfo) (*tls.Certificate, error)
  29. }{
  30. {
  31. desc: "Simple match in CN",
  32. localCertificates: []tls.Certificate{
  33. certificateRandom,
  34. certificateTest,
  35. certificateWildcard,
  36. },
  37. serverName: "test.test",
  38. expectedCertificate: certificateTest,
  39. },
  40. {
  41. desc: "Simple match in SANs",
  42. localCertificates: []tls.Certificate{
  43. certificateRandom,
  44. certificateTest,
  45. certificateWildcard,
  46. },
  47. serverName: "www.test.test",
  48. expectedCertificate: certificateTest,
  49. },
  50. {
  51. desc: "Wildcard match",
  52. localCertificates: []tls.Certificate{
  53. certificateRandom,
  54. certificateTest,
  55. certificateWildcard,
  56. },
  57. serverName: "foo.test.test",
  58. expectedCertificate: certificateWildcard,
  59. },
  60. {
  61. desc: "No match return first",
  62. localCertificates: []tls.Certificate{
  63. certificateRandom,
  64. certificateTest,
  65. certificateWildcard,
  66. },
  67. serverName: "foo.bar",
  68. expectedCertificate: certificateRandom,
  69. },
  70. {
  71. desc: "Get certificate from callback",
  72. getCertificate: func(info *ClientHelloInfo) (*tls.Certificate, error) {
  73. return &certificateTest, nil
  74. },
  75. expectedCertificate: certificateTest,
  76. },
  77. }
  78. for _, test := range testCases {
  79. test := test
  80. t.Run(test.desc, func(t *testing.T) {
  81. cfg := &handshakeConfig{
  82. localCertificates: test.localCertificates,
  83. localGetCertificate: test.getCertificate,
  84. }
  85. cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName})
  86. if err != nil {
  87. t.Fatal(err)
  88. }
  89. if !reflect.DeepEqual(cert.Leaf, test.expectedCertificate.Leaf) {
  90. t.Fatalf("Certificate does not match: expected(%v) actual(%v)", test.expectedCertificate.Leaf, cert.Leaf)
  91. }
  92. })
  93. }
  94. }