gcm.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ciphersuite
  4. import (
  5. "crypto/aes"
  6. "crypto/cipher"
  7. "crypto/rand"
  8. "encoding/binary"
  9. "fmt"
  10. "github.com/pion/dtls/v2/pkg/protocol"
  11. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  12. )
  // Byte sizes used by the AES-GCM record protection in this file.
const (
	// gcmNonceLength is the full AES-GCM nonce size: an implicit prefix taken
	// from the write IV plus the explicit part carried inside each record.
	gcmNonceLength = 12

	// gcmTagLength is the size of the authentication tag that trails every
	// GCM ciphertext.
	gcmTagLength = 16
)
  17. // GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
  18. type GCM struct {
  19. localGCM, remoteGCM cipher.AEAD
  20. localWriteIV, remoteWriteIV []byte
  21. }
  22. // NewGCM creates a DTLS GCM Cipher
  23. func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) {
  24. localBlock, err := aes.NewCipher(localKey)
  25. if err != nil {
  26. return nil, err
  27. }
  28. localGCM, err := cipher.NewGCM(localBlock)
  29. if err != nil {
  30. return nil, err
  31. }
  32. remoteBlock, err := aes.NewCipher(remoteKey)
  33. if err != nil {
  34. return nil, err
  35. }
  36. remoteGCM, err := cipher.NewGCM(remoteBlock)
  37. if err != nil {
  38. return nil, err
  39. }
  40. return &GCM{
  41. localGCM: localGCM,
  42. localWriteIV: localWriteIV,
  43. remoteGCM: remoteGCM,
  44. remoteWriteIV: remoteWriteIV,
  45. }, nil
  46. }
  47. // Encrypt encrypt a DTLS RecordLayer message
  48. func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
  49. payload := raw[recordlayer.HeaderSize:]
  50. raw = raw[:recordlayer.HeaderSize]
  51. nonce := make([]byte, gcmNonceLength)
  52. copy(nonce, g.localWriteIV[:4])
  53. if _, err := rand.Read(nonce[4:]); err != nil {
  54. return nil, err
  55. }
  56. additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
  57. encryptedPayload := g.localGCM.Seal(nil, nonce, payload, additionalData)
  58. r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
  59. copy(r, raw)
  60. copy(r[len(raw):], nonce[4:])
  61. copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
  62. // Update recordLayer size to include explicit nonce
  63. binary.BigEndian.PutUint16(r[recordlayer.HeaderSize-2:], uint16(len(r)-recordlayer.HeaderSize))
  64. return r, nil
  65. }
  // Decrypt authenticates and decrypts a DTLS 1.2 AES-GCM record.
//
// in must hold a complete record: header, 8-byte explicit nonce, then the
// ciphertext with its gcmTagLength-byte tag. The returned slice is the
// original header followed by the plaintext. The header's length field is
// left as received.
//
// Decryption happens in place: the result shares in's backing array, and the
// bytes of in after the header are overwritten.
//
// ChangeCipherSpec records are returned unchanged. Other outcomes:
//   - A record with no room for the explicit nonce plus at least one byte
//     yields errNotEnoughRoomForNonce.
//   - A record too short to hold the tag yields an error wrapping
//     errDecryptPacket.
//   - A record that fails authentication yields an error wrapping
//     errDecryptPacket.
func (g *GCM) Decrypt(in []byte) ([]byte, error) {
	const (
		implicitNonceLength = 4
		explicitNonceLength = gcmNonceLength - implicitNonceLength
	)

	var h recordlayer.Header
	err := h.Unmarshal(in)
	switch {
	case err != nil:
		return nil, err
	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
		// ChangeCipherSpec records are passed through without decryption.
		return in, nil
	case len(in) <= recordlayer.HeaderSize+explicitNonceLength:
		return nil, errNotEnoughRoomForNonce
	case len(in) < recordlayer.HeaderSize+explicitNonceLength+gcmTagLength:
		// Reject before building the AAD. Otherwise the plaintext length passed
		// to generateAEADAdditionalData below would be negative.
		return nil, fmt.Errorf("%w: ciphertext shorter than the %d-byte GCM tag", errDecryptPacket, gcmTagLength)
	}

	// nonce = implicit prefix from the remote write IV || explicit part read
	// from the record.
	nonce := make([]byte, 0, gcmNonceLength)
	nonce = append(nonce, g.remoteWriteIV[:implicitNonceLength]...)
	nonce = append(nonce, in[recordlayer.HeaderSize:recordlayer.HeaderSize+explicitNonceLength]...)

	ciphertext := in[recordlayer.HeaderSize+explicitNonceLength:]
	// The AAD is built from the plaintext length, i.e. the ciphertext minus
	// the trailing tag.
	additionalData := generateAEADAdditionalData(&h, len(ciphertext)-gcmTagLength)
	plaintext, err := g.remoteGCM.Open(ciphertext[:0], nonce, ciphertext, additionalData)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
	}
	// plaintext lives at in[HeaderSize+8:]. The append moves it down to sit
	// directly after the header.
	return append(in[:recordlayer.HeaderSize], plaintext...), nil
}