state.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. // Package noise implements the Noise Protocol Framework.
  2. //
  3. // Noise is a low-level framework for building crypto protocols. Noise protocols
  4. // support mutual and optional authentication, identity hiding, forward secrecy,
  5. // zero round-trip encryption, and other advanced features. For more details,
  6. // visit https://noiseprotocol.org.
  7. package noise
  8. import (
  9. "crypto/rand"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "math"
  14. )
// A CipherState provides symmetric encryption and decryption after a successful
// handshake.
type CipherState struct {
	cs CipherSuite // suite used to (re)build c from k, e.g. in Rekey
	c  Cipher      // cipher instance keyed with k
	k  [32]byte    // current symmetric key
	n  uint64      // nonce for the next Encrypt/Decrypt; incremented after each success
	// invalid is set once Cipher has handed out c; Encrypt and Decrypt then
	// return ErrCipherSuiteCopied.
	invalid bool
}
  24. // MaxNonce is the maximum value of n that is allowed. ErrMaxNonce is returned
  25. // by Encrypt and Decrypt after this has been reached. 2^64-1 is reserved for rekeys.
  26. const MaxNonce = uint64(math.MaxUint64) - 1
  27. var ErrMaxNonce = errors.New("noise: cipherstate has reached maximum n, a new handshake must be performed")
  28. var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid")
  29. // Encrypt encrypts the plaintext and then appends the ciphertext and an
  30. // authentication tag across the ciphertext and optional authenticated data to
  31. // out. This method automatically increments the nonce after every call, so
  32. // messages must be decrypted in the same order. ErrMaxNonce is returned after
  33. // the maximum nonce of 2^64-2 is reached.
  34. func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) {
  35. if s.invalid {
  36. return nil, ErrCipherSuiteCopied
  37. }
  38. if s.n > MaxNonce {
  39. return nil, ErrMaxNonce
  40. }
  41. out = s.c.Encrypt(out, s.n, ad, plaintext)
  42. s.n++
  43. return out, nil
  44. }
  45. // Decrypt checks the authenticity of the ciphertext and authenticated data and
  46. // then decrypts and appends the plaintext to out. This method automatically
  47. // increments the nonce after every call, messages must be provided in the same
  48. // order that they were encrypted with no missing messages. ErrMaxNonce is
  49. // returned after the maximum nonce of 2^64-2 is reached.
  50. func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
  51. if s.invalid {
  52. return nil, ErrCipherSuiteCopied
  53. }
  54. if s.n > MaxNonce {
  55. return nil, ErrMaxNonce
  56. }
  57. out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
  58. if err != nil {
  59. return nil, err
  60. }
  61. s.n++
  62. return out, nil
  63. }
  64. // Cipher returns the low-level symmetric encryption primitive. It should only
  65. // be used if nonces need to be managed manually, for example with a network
  66. // protocol that can deliver out-of-order messages. This is dangerous, users
  67. // must ensure that they are incrementing a nonce after every encrypt operation.
  68. // After calling this method, it is an error to call Encrypt/Decrypt on the
  69. // CipherState.
  70. func (s *CipherState) Cipher() Cipher {
  71. s.invalid = true
  72. return s.c
  73. }
  74. // Nonce returns the current value of n. This can be used to determine if a
  75. // new handshake should be performed due to approaching MaxNonce.
  76. func (s *CipherState) Nonce() uint64 {
  77. return s.n
  78. }
  79. func (s *CipherState) Rekey() {
  80. var zeros [32]byte
  81. var out []byte
  82. out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:])
  83. copy(s.k[:], out[:32])
  84. s.c = s.cs.Cipher(s.k)
  85. }
  86. type symmetricState struct {
  87. CipherState
  88. hasK bool
  89. ck []byte
  90. h []byte
  91. prevCK []byte
  92. prevH []byte
  93. }
  94. func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
  95. h := s.cs.Hash()
  96. if len(handshakeName) <= h.Size() {
  97. s.h = make([]byte, h.Size())
  98. copy(s.h, handshakeName)
  99. } else {
  100. h.Write(handshakeName)
  101. s.h = h.Sum(nil)
  102. }
  103. s.ck = make([]byte, len(s.h))
  104. copy(s.ck, s.h)
  105. }
  106. func (s *symmetricState) MixKey(dhOutput []byte) {
  107. s.n = 0
  108. s.hasK = true
  109. var hk []byte
  110. s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
  111. copy(s.k[:], hk)
  112. s.c = s.cs.Cipher(s.k)
  113. }
  114. func (s *symmetricState) MixHash(data []byte) {
  115. h := s.cs.Hash()
  116. h.Write(s.h)
  117. h.Write(data)
  118. s.h = h.Sum(s.h[:0])
  119. }
  120. func (s *symmetricState) MixKeyAndHash(data []byte) {
  121. var hk []byte
  122. var temp []byte
  123. s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
  124. s.MixHash(temp)
  125. copy(s.k[:], hk)
  126. s.c = s.cs.Cipher(s.k)
  127. s.n = 0
  128. s.hasK = true
  129. }
  130. func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
  131. if !s.hasK {
  132. s.MixHash(plaintext)
  133. return append(out, plaintext...), nil
  134. }
  135. ciphertext, err := s.Encrypt(out, s.h, plaintext)
  136. if err != nil {
  137. return nil, err
  138. }
  139. s.MixHash(ciphertext[len(out):])
  140. return ciphertext, nil
  141. }
  142. func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
  143. if !s.hasK {
  144. s.MixHash(data)
  145. return append(out, data...), nil
  146. }
  147. plaintext, err := s.Decrypt(out, s.h, data)
  148. if err != nil {
  149. return nil, err
  150. }
  151. s.MixHash(data)
  152. return plaintext, nil
  153. }
  154. func (s *symmetricState) Split() (*CipherState, *CipherState) {
  155. s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
  156. hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
  157. copy(s1.k[:], hk1)
  158. copy(s2.k[:], hk2)
  159. s1.c = s.cs.Cipher(s1.k)
  160. s2.c = s.cs.Cipher(s2.k)
  161. return s1, s2
  162. }
  163. func (s *symmetricState) Checkpoint() {
  164. if len(s.ck) > cap(s.prevCK) {
  165. s.prevCK = make([]byte, len(s.ck))
  166. }
  167. s.prevCK = s.prevCK[:len(s.ck)]
  168. copy(s.prevCK, s.ck)
  169. if len(s.h) > cap(s.prevH) {
  170. s.prevH = make([]byte, len(s.h))
  171. }
  172. s.prevH = s.prevH[:len(s.h)]
  173. copy(s.prevH, s.h)
  174. }
  175. func (s *symmetricState) Rollback() {
  176. s.ck = s.ck[:len(s.prevCK)]
  177. copy(s.ck, s.prevCK)
  178. s.h = s.h[:len(s.prevH)]
  179. copy(s.h, s.prevH)
  180. }
// A MessagePattern is a single message or operation used in a Noise handshake.
type MessagePattern int

// A HandshakePattern is a list of messages and operations that are used to
// perform a specific Noise handshake.
type HandshakePattern struct {
	// Name is the pattern name as it appears in the protocol name
	// ("Noise_" + Name + ...).
	Name string
	// InitiatorPreMessages and ResponderPreMessages list the keys each side
	// is assumed to know before the handshake starts. NewHandshakeState only
	// acts on MessagePatternS and MessagePatternE entries here.
	InitiatorPreMessages []MessagePattern
	ResponderPreMessages []MessagePattern
	// Messages holds the tokens of each handshake message, in order. The
	// initiator writes the first message and the two sides then alternate.
	Messages [][]MessagePattern
}

const (
	MessagePatternS    MessagePattern = iota // static public key; encrypted once a key is set
	MessagePatternE                          // freshly generated ephemeral public key, sent in clear
	MessagePatternDHEE                       // DH(local ephemeral, remote ephemeral)
	MessagePatternDHES                       // DH of the initiator's ephemeral and the responder's static key
	MessagePatternDHSE                       // DH of the initiator's static and the responder's ephemeral key
	MessagePatternDHSS                       // DH(local static, remote static)
	MessagePatternPSK                        // mix the preshared key in via MixKeyAndHash
)
// MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
// message.
//
// NOTE(review): WriteMessage compares only the payload length against this
// limit. The public keys and authentication tags written alongside the payload
// are not counted.
const MaxMsgLen = 65535
// A HandshakeState tracks the state of a Noise handshake. It may be discarded
// after the handshake is complete.
type HandshakeState struct {
	ss  symmetricState // chaining key, handshake hash and handshake cipher state
	s   DHKey          // local static keypair
	e   DHKey          // local ephemeral keypair
	rs  []byte         // remote party's static public key
	re  []byte         // remote party's ephemeral public key
	psk []byte         // preshared key; empty when no PSK is used
	// messagePatterns is the pattern's message list, with the PSK token
	// inserted when a preshared key is configured.
	messagePatterns [][]MessagePattern
	shouldWrite     bool      // true when the next call must be WriteMessage
	initiator       bool      // true if this side sends the first message
	msgIdx          int       // index into messagePatterns of the next message
	rng             io.Reader // randomness source for ephemeral key generation
}
// A Config provides the details necessary to process a Noise handshake. It is
// never modified by this package, and can be reused.
type Config struct {
	// CipherSuite is the set of cryptographic primitives that will be used.
	CipherSuite CipherSuite

	// Random is the source for cryptographically appropriate random bytes. If
	// nil, crypto/rand.Reader is used.
	Random io.Reader

	// Pattern is the pattern for the handshake.
	Pattern HandshakePattern

	// Initiator must be true if the first message in the handshake will be sent
	// by this peer.
	Initiator bool

	// Prologue is an optional message that has already been communicated and
	// must be identical on both sides for the handshake to succeed. It is
	// mixed into the handshake hash before any pre-message keys.
	Prologue []byte

	// PresharedKey is the optional preshared key for the handshake. If set,
	// it must be exactly 32 bytes long.
	PresharedKey []byte

	// PresharedKeyPlacement specifies where the PSK token is placed when
	// PresharedKey is set. A value of 0 places it before the tokens of the
	// first message; a value i > 0 appends it to the end of message i
	// (1-based). The value also appears in the protocol name as
	// "psk<placement>".
	PresharedKeyPlacement int

	// StaticKeypair is this peer's static keypair, required if part of the
	// handshake.
	StaticKeypair DHKey

	// EphemeralKeypair is this peer's ephemeral keypair that was provided as
	// a pre-message in the handshake. The handshake's copy is replaced by a
	// freshly generated keypair whenever this peer writes an "e" token.
	EphemeralKeypair DHKey

	// PeerStatic is the static public key of the remote peer that was provided
	// as a pre-message in the handshake.
	PeerStatic []byte

	// PeerEphemeral is the ephemeral public key of the remote peer that was
	// provided as a pre-message in the handshake.
	PeerEphemeral []byte
}
  252. // NewHandshakeState starts a new handshake using the provided configuration.
  253. func NewHandshakeState(c Config) (*HandshakeState, error) {
  254. hs := &HandshakeState{
  255. s: c.StaticKeypair,
  256. e: c.EphemeralKeypair,
  257. rs: c.PeerStatic,
  258. psk: c.PresharedKey,
  259. messagePatterns: c.Pattern.Messages,
  260. shouldWrite: c.Initiator,
  261. initiator: c.Initiator,
  262. rng: c.Random,
  263. }
  264. if hs.rng == nil {
  265. hs.rng = rand.Reader
  266. }
  267. if len(c.PeerEphemeral) > 0 {
  268. hs.re = make([]byte, len(c.PeerEphemeral))
  269. copy(hs.re, c.PeerEphemeral)
  270. }
  271. hs.ss.cs = c.CipherSuite
  272. pskModifier := ""
  273. if len(hs.psk) > 0 {
  274. if len(hs.psk) != 32 {
  275. return nil, errors.New("noise: specification mandates 256-bit preshared keys")
  276. }
  277. pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
  278. hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
  279. if c.PresharedKeyPlacement == 0 {
  280. hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
  281. } else {
  282. hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK)
  283. }
  284. }
  285. hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
  286. hs.ss.MixHash(c.Prologue)
  287. for _, m := range c.Pattern.InitiatorPreMessages {
  288. switch {
  289. case c.Initiator && m == MessagePatternS:
  290. hs.ss.MixHash(hs.s.Public)
  291. case c.Initiator && m == MessagePatternE:
  292. hs.ss.MixHash(hs.e.Public)
  293. case !c.Initiator && m == MessagePatternS:
  294. hs.ss.MixHash(hs.rs)
  295. case !c.Initiator && m == MessagePatternE:
  296. hs.ss.MixHash(hs.re)
  297. }
  298. }
  299. for _, m := range c.Pattern.ResponderPreMessages {
  300. switch {
  301. case !c.Initiator && m == MessagePatternS:
  302. hs.ss.MixHash(hs.s.Public)
  303. case !c.Initiator && m == MessagePatternE:
  304. hs.ss.MixHash(hs.e.Public)
  305. case c.Initiator && m == MessagePatternS:
  306. hs.ss.MixHash(hs.rs)
  307. case c.Initiator && m == MessagePatternE:
  308. hs.ss.MixHash(hs.re)
  309. }
  310. }
  311. return hs, nil
  312. }
  313. // WriteMessage appends a handshake message to out. The message will include the
  314. // optional payload if provided. If the handshake is completed by the call, two
  315. // CipherStates will be returned, one is used for encryption of messages to the
  316. // remote peer, the other is used for decryption of messages from the remote
  317. // peer. It is an error to call this method out of sync with the handshake
  318. // pattern.
  319. func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
  320. if !s.shouldWrite {
  321. return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
  322. }
  323. if s.msgIdx > len(s.messagePatterns)-1 {
  324. return nil, nil, nil, errors.New("noise: no handshake messages left")
  325. }
  326. if len(payload) > MaxMsgLen {
  327. return nil, nil, nil, errors.New("noise: message is too long")
  328. }
  329. var err error
  330. for _, msg := range s.messagePatterns[s.msgIdx] {
  331. switch msg {
  332. case MessagePatternE:
  333. e, err := s.ss.cs.GenerateKeypair(s.rng)
  334. if err != nil {
  335. return nil, nil, nil, err
  336. }
  337. s.e = e
  338. out = append(out, s.e.Public...)
  339. s.ss.MixHash(s.e.Public)
  340. if len(s.psk) > 0 {
  341. s.ss.MixKey(s.e.Public)
  342. }
  343. case MessagePatternS:
  344. if len(s.s.Public) == 0 {
  345. return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
  346. }
  347. out, err = s.ss.EncryptAndHash(out, s.s.Public)
  348. if err != nil {
  349. return nil, nil, nil, err
  350. }
  351. case MessagePatternDHEE:
  352. dh, err := s.ss.cs.DH(s.e.Private, s.re)
  353. if err != nil {
  354. return nil, nil, nil, err
  355. }
  356. s.ss.MixKey(dh)
  357. case MessagePatternDHES:
  358. if s.initiator {
  359. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  360. if err != nil {
  361. return nil, nil, nil, err
  362. }
  363. s.ss.MixKey(dh)
  364. } else {
  365. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  366. if err != nil {
  367. return nil, nil, nil, err
  368. }
  369. s.ss.MixKey(dh)
  370. }
  371. case MessagePatternDHSE:
  372. if s.initiator {
  373. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  374. if err != nil {
  375. return nil, nil, nil, err
  376. }
  377. s.ss.MixKey(dh)
  378. } else {
  379. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  380. if err != nil {
  381. return nil, nil, nil, err
  382. }
  383. s.ss.MixKey(dh)
  384. }
  385. case MessagePatternDHSS:
  386. dh, err := s.ss.cs.DH(s.s.Private, s.rs)
  387. if err != nil {
  388. return nil, nil, nil, err
  389. }
  390. s.ss.MixKey(dh)
  391. case MessagePatternPSK:
  392. s.ss.MixKeyAndHash(s.psk)
  393. }
  394. }
  395. s.shouldWrite = false
  396. s.msgIdx++
  397. out, err = s.ss.EncryptAndHash(out, payload)
  398. if err != nil {
  399. return nil, nil, nil, err
  400. }
  401. if s.msgIdx >= len(s.messagePatterns) {
  402. cs1, cs2 := s.ss.Split()
  403. return out, cs1, cs2, nil
  404. }
  405. return out, nil, nil, nil
  406. }
  407. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
  408. var ErrShortMessage = errors.New("noise: message is too short")
  409. // ReadMessage processes a received handshake message and appends the payload,
  410. // if any to out. If the handshake is completed by the call, two CipherStates
  411. // will be returned, one is used for encryption of messages to the remote peer,
  412. // the other is used for decryption of messages from the remote peer. It is an
  413. // error to call this method out of sync with the handshake pattern.
  414. func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
  415. if s.shouldWrite {
  416. return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
  417. }
  418. if s.msgIdx > len(s.messagePatterns)-1 {
  419. return nil, nil, nil, errors.New("noise: no handshake messages left")
  420. }
  421. rsSet := false
  422. s.ss.Checkpoint()
  423. var err error
  424. for _, msg := range s.messagePatterns[s.msgIdx] {
  425. switch msg {
  426. case MessagePatternE, MessagePatternS:
  427. expected := s.ss.cs.DHLen()
  428. if msg == MessagePatternS && s.ss.hasK {
  429. expected += 16
  430. }
  431. if len(message) < expected {
  432. return nil, nil, nil, ErrShortMessage
  433. }
  434. switch msg {
  435. case MessagePatternE:
  436. if cap(s.re) < s.ss.cs.DHLen() {
  437. s.re = make([]byte, s.ss.cs.DHLen())
  438. }
  439. s.re = s.re[:s.ss.cs.DHLen()]
  440. copy(s.re, message)
  441. s.ss.MixHash(s.re)
  442. if len(s.psk) > 0 {
  443. s.ss.MixKey(s.re)
  444. }
  445. case MessagePatternS:
  446. if len(s.rs) > 0 {
  447. return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
  448. }
  449. s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
  450. rsSet = true
  451. }
  452. if err != nil {
  453. s.ss.Rollback()
  454. if rsSet {
  455. s.rs = nil
  456. }
  457. return nil, nil, nil, err
  458. }
  459. message = message[expected:]
  460. case MessagePatternDHEE:
  461. dh, err := s.ss.cs.DH(s.e.Private, s.re)
  462. if err != nil {
  463. return nil, nil, nil, err
  464. }
  465. s.ss.MixKey(dh)
  466. case MessagePatternDHES:
  467. if s.initiator {
  468. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  469. if err != nil {
  470. return nil, nil, nil, err
  471. }
  472. s.ss.MixKey(dh)
  473. } else {
  474. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  475. if err != nil {
  476. return nil, nil, nil, err
  477. }
  478. s.ss.MixKey(dh)
  479. }
  480. case MessagePatternDHSE:
  481. if s.initiator {
  482. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  483. if err != nil {
  484. return nil, nil, nil, err
  485. }
  486. s.ss.MixKey(dh)
  487. } else {
  488. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  489. if err != nil {
  490. return nil, nil, nil, err
  491. }
  492. s.ss.MixKey(dh)
  493. }
  494. case MessagePatternDHSS:
  495. dh, err := s.ss.cs.DH(s.s.Private, s.rs)
  496. if err != nil {
  497. return nil, nil, nil, err
  498. }
  499. s.ss.MixKey(dh)
  500. case MessagePatternPSK:
  501. s.ss.MixKeyAndHash(s.psk)
  502. }
  503. }
  504. out, err = s.ss.DecryptAndHash(out, message)
  505. if err != nil {
  506. s.ss.Rollback()
  507. if rsSet {
  508. s.rs = nil
  509. }
  510. return nil, nil, nil, err
  511. }
  512. s.shouldWrite = true
  513. s.msgIdx++
  514. if s.msgIdx >= len(s.messagePatterns) {
  515. cs1, cs2 := s.ss.Split()
  516. return out, cs1, cs2, nil
  517. }
  518. return out, nil, nil, nil
  519. }
  520. // ChannelBinding provides a value that uniquely identifies the session and can
  521. // be used as a channel binding. It is an error to call this method before the
  522. // handshake is complete.
  523. func (s *HandshakeState) ChannelBinding() []byte {
  524. return s.ss.h
  525. }
  526. // PeerStatic returns the static key provided by the remote peer during
  527. // a handshake. It is an error to call this method if a handshake message
  528. // containing a static key has not been read.
  529. func (s *HandshakeState) PeerStatic() []byte {
  530. return s.rs
  531. }
  532. // MessageIndex returns the current handshake message id
  533. func (s *HandshakeState) MessageIndex() int {
  534. return s.msgIdx
  535. }
  536. // PeerEphemeral returns the ephemeral key provided by the remote peer during
  537. // a handshake. It is an error to call this method if a handshake message
  538. // containing a static key has not been read.
  539. func (s *HandshakeState) PeerEphemeral() []byte {
  540. return s.re
  541. }
  542. // LocalEphemeral returns the local ephemeral key pair generated during
  543. // a handshake.
  544. func (s *HandshakeState) LocalEphemeral() DHKey {
  545. return s.e
  546. }