aead.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package hpke
  2. import (
  3. "crypto/cipher"
  4. "fmt"
  5. )
  6. type encdecContext struct {
  7. // Serialized parameters
  8. suite Suite
  9. sharedSecret []byte
  10. secret []byte
  11. keyScheduleContext []byte
  12. exporterSecret []byte
  13. key []byte
  14. baseNonce []byte
  15. sequenceNumber []byte
  16. // Operational parameters
  17. cipher.AEAD
  18. nonce []byte
  19. }
  20. type (
  21. sealContext struct{ *encdecContext }
  22. openContext struct{ *encdecContext }
  23. )
  24. // Export takes a context string exporterContext and a desired length (in
  25. // bytes), and produces a secret derived from the internal exporter secret
  26. // using the corresponding KDF Expand function. It panics if length is
  27. // greater than 255*N bytes, where N is the size (in bytes) of the KDF's
  28. // output.
  29. func (c *encdecContext) Export(exporterContext []byte, length uint) []byte {
  30. maxLength := uint(255 * c.suite.kdfID.ExtractSize())
  31. if length > maxLength {
  32. panic(fmt.Errorf("output length must be lesser than %v bytes", maxLength))
  33. }
  34. return c.suite.labeledExpand(c.exporterSecret, []byte("sec"),
  35. exporterContext, uint16(length))
  36. }
  37. func (c *encdecContext) Suite() Suite {
  38. return c.suite
  39. }
  40. func (c *encdecContext) calcNonce() []byte {
  41. for i := range c.baseNonce {
  42. c.nonce[i] = c.baseNonce[i] ^ c.sequenceNumber[i]
  43. }
  44. return c.nonce
  45. }
  46. func (c *encdecContext) increment() error {
  47. // tests whether the sequence number is all-ones, which prevents an
  48. // overflow after the increment.
  49. allOnes := byte(0xFF)
  50. for i := range c.sequenceNumber {
  51. allOnes &= c.sequenceNumber[i]
  52. }
  53. if allOnes == byte(0xFF) {
  54. return ErrAEADSeqOverflows
  55. }
  56. // performs an increment by 1 and verifies whether the sequence overflows.
  57. carry := uint(1)
  58. for i := len(c.sequenceNumber) - 1; i >= 0; i-- {
  59. sum := uint(c.sequenceNumber[i]) + carry
  60. carry = sum >> 8
  61. c.sequenceNumber[i] = byte(sum & 0xFF)
  62. }
  63. if carry != 0 {
  64. return ErrAEADSeqOverflows
  65. }
  66. return nil
  67. }
  68. func (c *sealContext) Seal(pt, aad []byte) ([]byte, error) {
  69. ct := c.AEAD.Seal(nil, c.calcNonce(), pt, aad)
  70. err := c.increment()
  71. if err != nil {
  72. for i := range ct {
  73. ct[i] = 0
  74. }
  75. return nil, err
  76. }
  77. return ct, nil
  78. }
  79. func (c *openContext) Open(ct, aad []byte) ([]byte, error) {
  80. pt, err := c.AEAD.Open(nil, c.calcNonce(), ct, aad)
  81. if err != nil {
  82. return nil, err
  83. }
  84. err = c.increment()
  85. if err != nil {
  86. for i := range pt {
  87. pt[i] = 0
  88. }
  89. return nil, err
  90. }
  91. return pt, nil
  92. }