| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605 |
- // Package noise implements the Noise Protocol Framework.
- //
- // Noise is a low-level framework for building crypto protocols. Noise protocols
- // support mutual and optional authentication, identity hiding, forward secrecy,
- // zero round-trip encryption, and other advanced features. For more details,
- // visit https://noiseprotocol.org.
- package noise
- import (
- "crypto/rand"
- "errors"
- "fmt"
- "io"
- "math"
- )
- // A CipherState provides symmetric encryption and decryption after a successful
- // handshake.
- type CipherState struct {
- cs CipherSuite
- c Cipher
- k [32]byte
- n uint64
- invalid bool
- }
- // MaxNonce is the maximum value of n that is allowed. ErrMaxNonce is returned
- // by Encrypt and Decrypt after this has been reached. 2^64-1 is reserved for rekeys.
- const MaxNonce = uint64(math.MaxUint64) - 1
- var ErrMaxNonce = errors.New("noise: cipherstate has reached maximum n, a new handshake must be performed")
- var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid")
- // Encrypt encrypts the plaintext and then appends the ciphertext and an
- // authentication tag across the ciphertext and optional authenticated data to
- // out. This method automatically increments the nonce after every call, so
- // messages must be decrypted in the same order. ErrMaxNonce is returned after
- // the maximum nonce of 2^64-2 is reached.
- func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) {
- if s.invalid {
- return nil, ErrCipherSuiteCopied
- }
- if s.n > MaxNonce {
- return nil, ErrMaxNonce
- }
- out = s.c.Encrypt(out, s.n, ad, plaintext)
- s.n++
- return out, nil
- }
- // Decrypt checks the authenticity of the ciphertext and authenticated data and
- // then decrypts and appends the plaintext to out. This method automatically
- // increments the nonce after every call, messages must be provided in the same
- // order that they were encrypted with no missing messages. ErrMaxNonce is
- // returned after the maximum nonce of 2^64-2 is reached.
- func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
- if s.invalid {
- return nil, ErrCipherSuiteCopied
- }
- if s.n > MaxNonce {
- return nil, ErrMaxNonce
- }
- out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
- if err != nil {
- return nil, err
- }
- s.n++
- return out, nil
- }
- // Cipher returns the low-level symmetric encryption primitive. It should only
- // be used if nonces need to be managed manually, for example with a network
- // protocol that can deliver out-of-order messages. This is dangerous, users
- // must ensure that they are incrementing a nonce after every encrypt operation.
- // After calling this method, it is an error to call Encrypt/Decrypt on the
- // CipherState.
- func (s *CipherState) Cipher() Cipher {
- s.invalid = true
- return s.c
- }
- // Nonce returns the current value of n. This can be used to determine if a
- // new handshake should be performed due to approaching MaxNonce.
- func (s *CipherState) Nonce() uint64 {
- return s.n
- }
- // SetNonce sets the current value of n.
- func (s *CipherState) SetNonce(n uint64) {
- s.n = n
- }
- func (s *CipherState) Rekey() {
- var zeros [32]byte
- var out []byte
- out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:])
- copy(s.k[:], out[:32])
- s.c = s.cs.Cipher(s.k)
- }
- type symmetricState struct {
- CipherState
- hasK bool
- ck []byte
- h []byte
- prevCK []byte
- prevH []byte
- }
- func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
- h := s.cs.Hash()
- if len(handshakeName) <= h.Size() {
- s.h = make([]byte, h.Size())
- copy(s.h, handshakeName)
- } else {
- h.Write(handshakeName)
- s.h = h.Sum(nil)
- }
- s.ck = make([]byte, len(s.h))
- copy(s.ck, s.h)
- }
- func (s *symmetricState) MixKey(dhOutput []byte) {
- s.n = 0
- s.hasK = true
- var hk []byte
- s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
- copy(s.k[:], hk)
- s.c = s.cs.Cipher(s.k)
- }
- func (s *symmetricState) MixHash(data []byte) {
- h := s.cs.Hash()
- h.Write(s.h)
- h.Write(data)
- s.h = h.Sum(s.h[:0])
- }
- func (s *symmetricState) MixKeyAndHash(data []byte) {
- var hk []byte
- var temp []byte
- s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
- s.MixHash(temp)
- copy(s.k[:], hk)
- s.c = s.cs.Cipher(s.k)
- s.n = 0
- s.hasK = true
- }
- func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
- if !s.hasK {
- s.MixHash(plaintext)
- return append(out, plaintext...), nil
- }
- ciphertext, err := s.Encrypt(out, s.h, plaintext)
- if err != nil {
- return nil, err
- }
- s.MixHash(ciphertext[len(out):])
- return ciphertext, nil
- }
- func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
- if !s.hasK {
- s.MixHash(data)
- return append(out, data...), nil
- }
- plaintext, err := s.Decrypt(out, s.h, data)
- if err != nil {
- return nil, err
- }
- s.MixHash(data)
- return plaintext, nil
- }
- func (s *symmetricState) Split() (*CipherState, *CipherState) {
- s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
- hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
- copy(s1.k[:], hk1)
- copy(s2.k[:], hk2)
- s1.c = s.cs.Cipher(s1.k)
- s2.c = s.cs.Cipher(s2.k)
- return s1, s2
- }
- func (s *symmetricState) Checkpoint() {
- if len(s.ck) > cap(s.prevCK) {
- s.prevCK = make([]byte, len(s.ck))
- }
- s.prevCK = s.prevCK[:len(s.ck)]
- copy(s.prevCK, s.ck)
- if len(s.h) > cap(s.prevH) {
- s.prevH = make([]byte, len(s.h))
- }
- s.prevH = s.prevH[:len(s.h)]
- copy(s.prevH, s.h)
- }
- func (s *symmetricState) Rollback() {
- s.ck = s.ck[:len(s.prevCK)]
- copy(s.ck, s.prevCK)
- s.h = s.h[:len(s.prevH)]
- copy(s.h, s.prevH)
- }
- // A MessagePattern is a single message or operation used in a Noise handshake.
- type MessagePattern int
- // A HandshakePattern is a list of messages and operations that are used to
- // perform a specific Noise handshake.
- type HandshakePattern struct {
- Name string
- InitiatorPreMessages []MessagePattern
- ResponderPreMessages []MessagePattern
- Messages [][]MessagePattern
- }
- const (
- MessagePatternS MessagePattern = iota
- MessagePatternE
- MessagePatternDHEE
- MessagePatternDHES
- MessagePatternDHSE
- MessagePatternDHSS
- MessagePatternPSK
- )
- // MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
- // message.
- const MaxMsgLen = 65535
- // A HandshakeState tracks the state of a Noise handshake. It may be discarded
- // after the handshake is complete.
- type HandshakeState struct {
- ss symmetricState
- s DHKey // local static keypair
- e DHKey // local ephemeral keypair
- rs []byte // remote party's static public key
- re []byte // remote party's ephemeral public key
- psk []byte // preshared key, maybe zero length
- messagePatterns [][]MessagePattern
- shouldWrite bool
- initiator bool
- msgIdx int
- rng io.Reader
- }
- // A Config provides the details necessary to process a Noise handshake. It is
- // never modified by this package, and can be reused.
- type Config struct {
- // CipherSuite is the set of cryptographic primitives that will be used.
- CipherSuite CipherSuite
- // Random is the source for cryptographically appropriate random bytes. If
- // zero, it is automatically configured.
- Random io.Reader
- // Pattern is the pattern for the handshake.
- Pattern HandshakePattern
- // Initiator must be true if the first message in the handshake will be sent
- // by this peer.
- Initiator bool
- // Prologue is an optional message that has already be communicated and must
- // be identical on both sides for the handshake to succeed.
- Prologue []byte
- // PresharedKey is the optional preshared key for the handshake.
- PresharedKey []byte
- // PresharedKeyPlacement specifies the placement position of the PSK token
- // when PresharedKey is specified
- PresharedKeyPlacement int
- // StaticKeypair is this peer's static keypair, required if part of the
- // handshake.
- StaticKeypair DHKey
- // EphemeralKeypair is this peer's ephemeral keypair that was provided as
- // a pre-message in the handshake.
- EphemeralKeypair DHKey
- // PeerStatic is the static public key of the remote peer that was provided
- // as a pre-message in the handshake.
- PeerStatic []byte
- // PeerEphemeral is the ephemeral public key of the remote peer that was
- // provided as a pre-message in the handshake.
- PeerEphemeral []byte
- }
- // NewHandshakeState starts a new handshake using the provided configuration.
- func NewHandshakeState(c Config) (*HandshakeState, error) {
- hs := &HandshakeState{
- s: c.StaticKeypair,
- e: c.EphemeralKeypair,
- rs: c.PeerStatic,
- psk: c.PresharedKey,
- messagePatterns: c.Pattern.Messages,
- shouldWrite: c.Initiator,
- initiator: c.Initiator,
- rng: c.Random,
- }
- if hs.rng == nil {
- hs.rng = rand.Reader
- }
- if len(c.PeerEphemeral) > 0 {
- hs.re = make([]byte, len(c.PeerEphemeral))
- copy(hs.re, c.PeerEphemeral)
- }
- hs.ss.cs = c.CipherSuite
- pskModifier := ""
- if len(hs.psk) > 0 {
- if len(hs.psk) != 32 {
- return nil, errors.New("noise: specification mandates 256-bit preshared keys")
- }
- pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
- hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
- if c.PresharedKeyPlacement == 0 {
- hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
- } else {
- hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK)
- }
- }
- hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
- hs.ss.MixHash(c.Prologue)
- for _, m := range c.Pattern.InitiatorPreMessages {
- switch {
- case c.Initiator && m == MessagePatternS:
- hs.ss.MixHash(hs.s.Public)
- case c.Initiator && m == MessagePatternE:
- hs.ss.MixHash(hs.e.Public)
- case !c.Initiator && m == MessagePatternS:
- hs.ss.MixHash(hs.rs)
- case !c.Initiator && m == MessagePatternE:
- hs.ss.MixHash(hs.re)
- }
- }
- for _, m := range c.Pattern.ResponderPreMessages {
- switch {
- case !c.Initiator && m == MessagePatternS:
- hs.ss.MixHash(hs.s.Public)
- case !c.Initiator && m == MessagePatternE:
- hs.ss.MixHash(hs.e.Public)
- case c.Initiator && m == MessagePatternS:
- hs.ss.MixHash(hs.rs)
- case c.Initiator && m == MessagePatternE:
- hs.ss.MixHash(hs.re)
- }
- }
- return hs, nil
- }
- // WriteMessage appends a handshake message to out. The message will include the
- // optional payload if provided. If the handshake is completed by the call, two
- // CipherStates will be returned, one is used for encryption of messages to the
- // remote peer, the other is used for decryption of messages from the remote
- // peer. It is an error to call this method out of sync with the handshake
- // pattern.
- func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
- if !s.shouldWrite {
- return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
- }
- if s.msgIdx > len(s.messagePatterns)-1 {
- return nil, nil, nil, errors.New("noise: no handshake messages left")
- }
- if len(payload) > MaxMsgLen {
- return nil, nil, nil, errors.New("noise: message is too long")
- }
- var err error
- for _, msg := range s.messagePatterns[s.msgIdx] {
- switch msg {
- case MessagePatternE:
- e, err := s.ss.cs.GenerateKeypair(s.rng)
- if err != nil {
- return nil, nil, nil, err
- }
- s.e = e
- out = append(out, s.e.Public...)
- s.ss.MixHash(s.e.Public)
- if len(s.psk) > 0 {
- s.ss.MixKey(s.e.Public)
- }
- case MessagePatternS:
- if len(s.s.Public) == 0 {
- return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
- }
- out, err = s.ss.EncryptAndHash(out, s.s.Public)
- if err != nil {
- return nil, nil, nil, err
- }
- case MessagePatternDHEE:
- dh, err := s.ss.cs.DH(s.e.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- case MessagePatternDHES:
- if s.initiator {
- dh, err := s.ss.cs.DH(s.e.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- } else {
- dh, err := s.ss.cs.DH(s.s.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- }
- case MessagePatternDHSE:
- if s.initiator {
- dh, err := s.ss.cs.DH(s.s.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- } else {
- dh, err := s.ss.cs.DH(s.e.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- }
- case MessagePatternDHSS:
- dh, err := s.ss.cs.DH(s.s.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- case MessagePatternPSK:
- s.ss.MixKeyAndHash(s.psk)
- }
- }
- s.shouldWrite = false
- s.msgIdx++
- out, err = s.ss.EncryptAndHash(out, payload)
- if err != nil {
- return nil, nil, nil, err
- }
- if s.msgIdx >= len(s.messagePatterns) {
- cs1, cs2 := s.ss.Split()
- return out, cs1, cs2, nil
- }
- return out, nil, nil, nil
- }
- // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
- var ErrShortMessage = errors.New("noise: message is too short")
- // ReadMessage processes a received handshake message and appends the payload,
- // if any to out. If the handshake is completed by the call, two CipherStates
- // will be returned, one is used for encryption of messages to the remote peer,
- // the other is used for decryption of messages from the remote peer. It is an
- // error to call this method out of sync with the handshake pattern.
- func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
- if s.shouldWrite {
- return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
- }
- if s.msgIdx > len(s.messagePatterns)-1 {
- return nil, nil, nil, errors.New("noise: no handshake messages left")
- }
- rsSet := false
- s.ss.Checkpoint()
- var err error
- for _, msg := range s.messagePatterns[s.msgIdx] {
- switch msg {
- case MessagePatternE, MessagePatternS:
- expected := s.ss.cs.DHLen()
- if msg == MessagePatternS && s.ss.hasK {
- expected += 16
- }
- if len(message) < expected {
- return nil, nil, nil, ErrShortMessage
- }
- switch msg {
- case MessagePatternE:
- if cap(s.re) < s.ss.cs.DHLen() {
- s.re = make([]byte, s.ss.cs.DHLen())
- }
- s.re = s.re[:s.ss.cs.DHLen()]
- copy(s.re, message)
- s.ss.MixHash(s.re)
- if len(s.psk) > 0 {
- s.ss.MixKey(s.re)
- }
- case MessagePatternS:
- if len(s.rs) > 0 {
- return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
- }
- s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
- rsSet = true
- }
- if err != nil {
- s.ss.Rollback()
- if rsSet {
- s.rs = nil
- }
- return nil, nil, nil, err
- }
- message = message[expected:]
- case MessagePatternDHEE:
- dh, err := s.ss.cs.DH(s.e.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- case MessagePatternDHES:
- if s.initiator {
- dh, err := s.ss.cs.DH(s.e.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- } else {
- dh, err := s.ss.cs.DH(s.s.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- }
- case MessagePatternDHSE:
- if s.initiator {
- dh, err := s.ss.cs.DH(s.s.Private, s.re)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- } else {
- dh, err := s.ss.cs.DH(s.e.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- }
- case MessagePatternDHSS:
- dh, err := s.ss.cs.DH(s.s.Private, s.rs)
- if err != nil {
- return nil, nil, nil, err
- }
- s.ss.MixKey(dh)
- case MessagePatternPSK:
- s.ss.MixKeyAndHash(s.psk)
- }
- }
- out, err = s.ss.DecryptAndHash(out, message)
- if err != nil {
- s.ss.Rollback()
- if rsSet {
- s.rs = nil
- }
- return nil, nil, nil, err
- }
- s.shouldWrite = true
- s.msgIdx++
- if s.msgIdx >= len(s.messagePatterns) {
- cs1, cs2 := s.ss.Split()
- return out, cs1, cs2, nil
- }
- return out, nil, nil, nil
- }
- // ChannelBinding provides a value that uniquely identifies the session and can
- // be used as a channel binding. It is an error to call this method before the
- // handshake is complete.
- func (s *HandshakeState) ChannelBinding() []byte {
- return s.ss.h
- }
- // PeerStatic returns the static key provided by the remote peer during
- // a handshake. It is an error to call this method if a handshake message
- // containing a static key has not been read.
- func (s *HandshakeState) PeerStatic() []byte {
- return s.rs
- }
- // MessageIndex returns the current handshake message id
- func (s *HandshakeState) MessageIndex() int {
- return s.msgIdx
- }
- // PeerEphemeral returns the ephemeral key provided by the remote peer during
- // a handshake. It is an error to call this method if a handshake message
- // containing a static key has not been read.
- func (s *HandshakeState) PeerEphemeral() []byte {
- return s.re
- }
- // LocalEphemeral returns the local ephemeral key pair generated during
- // a handshake.
- func (s *HandshakeState) LocalEphemeral() DHKey {
- return s.e
- }
|