util.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package hpke
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. )
  7. func (st state) keySchedule(ss, info, psk, pskID []byte) (*encdecContext, error) {
  8. if err := st.verifyPSKInputs(psk, pskID); err != nil {
  9. return nil, err
  10. }
  11. pskIDHash := st.labeledExtract(nil, []byte("psk_id_hash"), pskID)
  12. infoHash := st.labeledExtract(nil, []byte("info_hash"), info)
  13. keySchCtx := append(append(
  14. []byte{st.modeID},
  15. pskIDHash...),
  16. infoHash...)
  17. secret := st.labeledExtract(ss, []byte("secret"), psk)
  18. Nk := uint16(st.aeadID.KeySize())
  19. key := st.labeledExpand(secret, []byte("key"), keySchCtx, Nk)
  20. aead, err := st.aeadID.New(key)
  21. if err != nil {
  22. return nil, err
  23. }
  24. Nn := uint16(aead.NonceSize())
  25. baseNonce := st.labeledExpand(secret, []byte("base_nonce"), keySchCtx, Nn)
  26. exporterSecret := st.labeledExpand(
  27. secret,
  28. []byte("exp"),
  29. keySchCtx,
  30. uint16(st.kdfID.ExtractSize()),
  31. )
  32. return &encdecContext{
  33. st.Suite,
  34. ss,
  35. secret,
  36. keySchCtx,
  37. exporterSecret,
  38. key,
  39. baseNonce,
  40. make([]byte, Nn),
  41. aead,
  42. make([]byte, Nn),
  43. }, nil
  44. }
  45. func (st state) verifyPSKInputs(psk, pskID []byte) error {
  46. gotPSK := psk != nil
  47. gotPSKID := pskID != nil
  48. if gotPSK != gotPSKID {
  49. return errors.New("inconsistent PSK inputs")
  50. }
  51. switch st.modeID {
  52. case modeBase | modeAuth:
  53. if gotPSK {
  54. return errors.New("PSK input provided when not needed")
  55. }
  56. case modePSK | modeAuthPSK:
  57. if !gotPSK {
  58. return errors.New("missing required PSK input")
  59. }
  60. }
  61. return nil
  62. }
  63. // Params returns the codepoints for the algorithms comprising the suite.
  64. func (suite Suite) Params() (KEM, KDF, AEAD) {
  65. return suite.kemID, suite.kdfID, suite.aeadID
  66. }
  67. func (suite Suite) String() string {
  68. return fmt.Sprintf(
  69. "kem_id: %v kdf_id: %v aead_id: %v",
  70. suite.kemID, suite.kdfID, suite.aeadID,
  71. )
  72. }
  73. func (suite Suite) getSuiteID() (id [10]byte) {
  74. id[0], id[1], id[2], id[3] = 'H', 'P', 'K', 'E'
  75. binary.BigEndian.PutUint16(id[4:6], uint16(suite.kemID))
  76. binary.BigEndian.PutUint16(id[6:8], uint16(suite.kdfID))
  77. binary.BigEndian.PutUint16(id[8:10], uint16(suite.aeadID))
  78. return
  79. }
  80. func (suite Suite) isValid() bool {
  81. return suite.kemID.IsValid() &&
  82. suite.kdfID.IsValid() &&
  83. suite.aeadID.IsValid()
  84. }
  85. func (suite Suite) labeledExtract(salt, label, ikm []byte) []byte {
  86. suiteID := suite.getSuiteID()
  87. labeledIKM := append(append(append(append(
  88. make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(ikm)),
  89. versionLabel...),
  90. suiteID[:]...),
  91. label...),
  92. ikm...)
  93. return suite.kdfID.Extract(labeledIKM, salt)
  94. }
  95. func (suite Suite) labeledExpand(prk, label, info []byte, l uint16) []byte {
  96. suiteID := suite.getSuiteID()
  97. labeledInfo := make([]byte,
  98. 2, 2+len(versionLabel)+len(suiteID)+len(label)+len(info))
  99. binary.BigEndian.PutUint16(labeledInfo[0:2], l)
  100. labeledInfo = append(append(append(append(labeledInfo,
  101. versionLabel...),
  102. suiteID[:]...),
  103. label...),
  104. info...)
  105. return suite.kdfID.Expand(prk, labeledInfo, uint(l))
  106. }