cbc.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ciphersuite
  4. import ( //nolint:gci
  5. "crypto/aes"
  6. "crypto/cipher"
  7. "crypto/hmac"
  8. "crypto/rand"
  9. "encoding/binary"
  10. "hash"
  11. "github.com/pion/dtls/v2/internal/util"
  12. "github.com/pion/dtls/v2/pkg/crypto/prf"
  13. "github.com/pion/dtls/v2/pkg/protocol"
  14. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  15. )
  16. // block ciphers using cipher block chaining.
  17. type cbcMode interface {
  18. cipher.BlockMode
  19. SetIV([]byte)
  20. }
  21. // CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
  22. type CBC struct {
  23. writeCBC, readCBC cbcMode
  24. writeMac, readMac []byte
  25. h prf.HashFunc
  26. }
  27. // NewCBC creates a DTLS CBC Cipher
  28. func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) {
  29. writeBlock, err := aes.NewCipher(localKey)
  30. if err != nil {
  31. return nil, err
  32. }
  33. readBlock, err := aes.NewCipher(remoteKey)
  34. if err != nil {
  35. return nil, err
  36. }
  37. writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode)
  38. if !ok {
  39. return nil, errFailedToCast
  40. }
  41. readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode)
  42. if !ok {
  43. return nil, errFailedToCast
  44. }
  45. return &CBC{
  46. writeCBC: writeCBC,
  47. writeMac: localMac,
  48. readCBC: readCBC,
  49. readMac: remoteMac,
  50. h: h,
  51. }, nil
  52. }
  53. // Encrypt encrypt a DTLS RecordLayer message
  54. func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
  55. payload := raw[recordlayer.HeaderSize:]
  56. raw = raw[:recordlayer.HeaderSize]
  57. blockSize := c.writeCBC.BlockSize()
  58. // Generate + Append MAC
  59. h := pkt.Header
  60. MAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h)
  61. if err != nil {
  62. return nil, err
  63. }
  64. payload = append(payload, MAC...)
  65. // Generate + Append padding
  66. padding := make([]byte, blockSize-len(payload)%blockSize)
  67. paddingLen := len(padding)
  68. for i := 0; i < paddingLen; i++ {
  69. padding[i] = byte(paddingLen - 1)
  70. }
  71. payload = append(payload, padding...)
  72. // Generate IV
  73. iv := make([]byte, blockSize)
  74. if _, err := rand.Read(iv); err != nil {
  75. return nil, err
  76. }
  77. // Set IV + Encrypt + Prepend IV
  78. c.writeCBC.SetIV(iv)
  79. c.writeCBC.CryptBlocks(payload, payload)
  80. payload = append(iv, payload...)
  81. // Prepend unencrypte header with encrypted payload
  82. raw = append(raw, payload...)
  83. // Update recordLayer size to include IV+MAC+Padding
  84. binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
  85. return raw, nil
  86. }
  87. // Decrypt decrypts a DTLS RecordLayer message
  88. func (c *CBC) Decrypt(in []byte) ([]byte, error) {
  89. body := in[recordlayer.HeaderSize:]
  90. blockSize := c.readCBC.BlockSize()
  91. mac := c.h()
  92. var h recordlayer.Header
  93. err := h.Unmarshal(in)
  94. switch {
  95. case err != nil:
  96. return nil, err
  97. case h.ContentType == protocol.ContentTypeChangeCipherSpec:
  98. // Nothing to encrypt with ChangeCipherSpec
  99. return in, nil
  100. case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize):
  101. return nil, errNotEnoughRoomForNonce
  102. }
  103. // Set + remove per record IV
  104. c.readCBC.SetIV(body[:blockSize])
  105. body = body[blockSize:]
  106. // Decrypt
  107. c.readCBC.CryptBlocks(body, body)
  108. // Padding+MAC needs to be checked in constant time
  109. // Otherwise we reveal information about the level of correctness
  110. paddingLen, paddingGood := examinePadding(body)
  111. if paddingGood != 255 {
  112. return nil, errInvalidMAC
  113. }
  114. macSize := mac.Size()
  115. if len(body) < macSize {
  116. return nil, errInvalidMAC
  117. }
  118. dataEnd := len(body) - macSize - paddingLen
  119. expectedMAC := body[dataEnd : dataEnd+macSize]
  120. actualMAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h)
  121. // Compute Local MAC and compare
  122. if err != nil || !hmac.Equal(actualMAC, expectedMAC) {
  123. return nil, errInvalidMAC
  124. }
  125. return append(in[:recordlayer.HeaderSize], body[:dataEnd]...), nil
  126. }
  127. func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) {
  128. h := hmac.New(hf, key)
  129. msg := make([]byte, 13)
  130. binary.BigEndian.PutUint16(msg, epoch)
  131. util.PutBigEndianUint48(msg[2:], sequenceNumber)
  132. msg[8] = byte(contentType)
  133. msg[9] = protocolVersion.Major
  134. msg[10] = protocolVersion.Minor
  135. binary.BigEndian.PutUint16(msg[11:], uint16(len(payload)))
  136. if _, err := h.Write(msg); err != nil {
  137. return nil, err
  138. } else if _, err := h.Write(payload); err != nil {
  139. return nil, err
  140. }
  141. return h.Sum(nil), nil
  142. }