| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- package hpke
- import (
- "crypto/cipher"
- "fmt"
- )
- type encdecContext struct {
- // Serialized parameters
- suite Suite
- sharedSecret []byte
- secret []byte
- keyScheduleContext []byte
- exporterSecret []byte
- key []byte
- baseNonce []byte
- sequenceNumber []byte
- // Operational parameters
- cipher.AEAD
- nonce []byte
- }
- type (
- sealContext struct{ *encdecContext }
- openContext struct{ *encdecContext }
- )
- // Export takes a context string exporterContext and a desired length (in
- // bytes), and produces a secret derived from the internal exporter secret
- // using the corresponding KDF Expand function. It panics if length is
- // greater than 255*N bytes, where N is the size (in bytes) of the KDF's
- // output.
- func (c *encdecContext) Export(exporterContext []byte, length uint) []byte {
- maxLength := uint(255 * c.suite.kdfID.ExtractSize())
- if length > maxLength {
- panic(fmt.Errorf("output length must be lesser than %v bytes", maxLength))
- }
- return c.suite.labeledExpand(c.exporterSecret, []byte("sec"),
- exporterContext, uint16(length))
- }
- func (c *encdecContext) Suite() Suite {
- return c.suite
- }
- func (c *encdecContext) calcNonce() []byte {
- for i := range c.baseNonce {
- c.nonce[i] = c.baseNonce[i] ^ c.sequenceNumber[i]
- }
- return c.nonce
- }
- func (c *encdecContext) increment() error {
- // tests whether the sequence number is all-ones, which prevents an
- // overflow after the increment.
- allOnes := byte(0xFF)
- for i := range c.sequenceNumber {
- allOnes &= c.sequenceNumber[i]
- }
- if allOnes == byte(0xFF) {
- return ErrAEADSeqOverflows
- }
- // performs an increment by 1 and verifies whether the sequence overflows.
- carry := uint(1)
- for i := len(c.sequenceNumber) - 1; i >= 0; i-- {
- sum := uint(c.sequenceNumber[i]) + carry
- carry = sum >> 8
- c.sequenceNumber[i] = byte(sum & 0xFF)
- }
- if carry != 0 {
- return ErrAEADSeqOverflows
- }
- return nil
- }
- func (c *sealContext) Seal(pt, aad []byte) ([]byte, error) {
- ct := c.AEAD.Seal(nil, c.calcNonce(), pt, aad)
- err := c.increment()
- if err != nil {
- for i := range ct {
- ct[i] = 0
- }
- return nil, err
- }
- return ct, nil
- }
- func (c *openContext) Open(ct, aad []byte) ([]byte, error) {
- pt, err := c.AEAD.Open(nil, c.calcNonce(), ct, aad)
- if err != nil {
- return nil, err
- }
- err = c.increment()
- if err != nil {
- for i := range pt {
- pt[i] = 0
- }
- return nil, err
- }
- return pt, nil
- }
|