ccm.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ciphersuite
  4. import (
  5. "crypto/aes"
  6. "crypto/rand"
  7. "encoding/binary"
  8. "fmt"
  9. "github.com/pion/dtls/v2/pkg/crypto/ccm"
  10. "github.com/pion/dtls/v2/pkg/protocol"
  11. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  12. )
  13. // CCMTagLen is the length of Authentication Tag
  14. type CCMTagLen int
  15. // CCM Enums
  16. const (
  17. CCMTagLength8 CCMTagLen = 8
  18. CCMTagLength CCMTagLen = 16
  19. ccmNonceLength = 12
  20. )
  21. // CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
  22. type CCM struct {
  23. localCCM, remoteCCM ccm.CCM
  24. localWriteIV, remoteWriteIV []byte
  25. tagLen CCMTagLen
  26. }
  27. // NewCCM creates a DTLS GCM Cipher
  28. func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) {
  29. localBlock, err := aes.NewCipher(localKey)
  30. if err != nil {
  31. return nil, err
  32. }
  33. localCCM, err := ccm.NewCCM(localBlock, int(tagLen), ccmNonceLength)
  34. if err != nil {
  35. return nil, err
  36. }
  37. remoteBlock, err := aes.NewCipher(remoteKey)
  38. if err != nil {
  39. return nil, err
  40. }
  41. remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), ccmNonceLength)
  42. if err != nil {
  43. return nil, err
  44. }
  45. return &CCM{
  46. localCCM: localCCM,
  47. localWriteIV: localWriteIV,
  48. remoteCCM: remoteCCM,
  49. remoteWriteIV: remoteWriteIV,
  50. tagLen: tagLen,
  51. }, nil
  52. }
  53. // Encrypt encrypt a DTLS RecordLayer message
  54. func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
  55. payload := raw[recordlayer.HeaderSize:]
  56. raw = raw[:recordlayer.HeaderSize]
  57. nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
  58. if _, err := rand.Read(nonce[4:]); err != nil {
  59. return nil, err
  60. }
  61. additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
  62. encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
  63. encryptedPayload = append(nonce[4:], encryptedPayload...)
  64. raw = append(raw, encryptedPayload...)
  65. // Update recordLayer size to include explicit nonce
  66. binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
  67. return raw, nil
  68. }
  69. // Decrypt decrypts a DTLS RecordLayer message
  70. func (c *CCM) Decrypt(in []byte) ([]byte, error) {
  71. var h recordlayer.Header
  72. err := h.Unmarshal(in)
  73. switch {
  74. case err != nil:
  75. return nil, err
  76. case h.ContentType == protocol.ContentTypeChangeCipherSpec:
  77. // Nothing to encrypt with ChangeCipherSpec
  78. return in, nil
  79. case len(in) <= (8 + recordlayer.HeaderSize):
  80. return nil, errNotEnoughRoomForNonce
  81. }
  82. nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
  83. out := in[recordlayer.HeaderSize+8:]
  84. additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen))
  85. out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
  86. if err != nil {
  87. return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
  88. }
  89. return append(in[:recordlayer.HeaderSize], out...), nil
  90. }