state.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. // Package noise implements the Noise Protocol Framework.
  2. //
  3. // Noise is a low-level framework for building crypto protocols. Noise protocols
  4. // support mutual and optional authentication, identity hiding, forward secrecy,
  5. // zero round-trip encryption, and other advanced features. For more details,
  6. // visit https://noiseprotocol.org.
  7. package noise
  8. import (
  9. "crypto/rand"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "math"
  14. )
  15. // A CipherState provides symmetric encryption and decryption after a successful
  16. // handshake.
  17. type CipherState struct {
  18. cs CipherSuite
  19. c Cipher
  20. k [32]byte
  21. n uint64
  22. invalid bool
  23. }
  24. // MaxNonce is the maximum value of n that is allowed. ErrMaxNonce is returned
  25. // by Encrypt and Decrypt after this has been reached. 2^64-1 is reserved for rekeys.
  26. const MaxNonce = uint64(math.MaxUint64) - 1
  27. var ErrMaxNonce = errors.New("noise: cipherstate has reached maximum n, a new handshake must be performed")
  28. var ErrCipherSuiteCopied = errors.New("noise: CipherSuite has been copied, state is invalid")
  29. // Encrypt encrypts the plaintext and then appends the ciphertext and an
  30. // authentication tag across the ciphertext and optional authenticated data to
  31. // out. This method automatically increments the nonce after every call, so
  32. // messages must be decrypted in the same order. ErrMaxNonce is returned after
  33. // the maximum nonce of 2^64-2 is reached.
  34. func (s *CipherState) Encrypt(out, ad, plaintext []byte) ([]byte, error) {
  35. if s.invalid {
  36. return nil, ErrCipherSuiteCopied
  37. }
  38. if s.n > MaxNonce {
  39. return nil, ErrMaxNonce
  40. }
  41. out = s.c.Encrypt(out, s.n, ad, plaintext)
  42. s.n++
  43. return out, nil
  44. }
  45. // Decrypt checks the authenticity of the ciphertext and authenticated data and
  46. // then decrypts and appends the plaintext to out. This method automatically
  47. // increments the nonce after every call, messages must be provided in the same
  48. // order that they were encrypted with no missing messages. ErrMaxNonce is
  49. // returned after the maximum nonce of 2^64-2 is reached.
  50. func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
  51. if s.invalid {
  52. return nil, ErrCipherSuiteCopied
  53. }
  54. if s.n > MaxNonce {
  55. return nil, ErrMaxNonce
  56. }
  57. out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
  58. if err != nil {
  59. return nil, err
  60. }
  61. s.n++
  62. return out, nil
  63. }
  64. // Cipher returns the low-level symmetric encryption primitive. It should only
  65. // be used if nonces need to be managed manually, for example with a network
  66. // protocol that can deliver out-of-order messages. This is dangerous, users
  67. // must ensure that they are incrementing a nonce after every encrypt operation.
  68. // After calling this method, it is an error to call Encrypt/Decrypt on the
  69. // CipherState.
  70. func (s *CipherState) Cipher() Cipher {
  71. s.invalid = true
  72. return s.c
  73. }
  74. // Nonce returns the current value of n. This can be used to determine if a
  75. // new handshake should be performed due to approaching MaxNonce.
  76. func (s *CipherState) Nonce() uint64 {
  77. return s.n
  78. }
  79. // SetNonce sets the current value of n.
  80. func (s *CipherState) SetNonce(n uint64) {
  81. s.n = n
  82. }
  83. func (s *CipherState) Rekey() {
  84. var zeros [32]byte
  85. var out []byte
  86. out = s.c.Encrypt(out, math.MaxUint64, []byte{}, zeros[:])
  87. copy(s.k[:], out[:32])
  88. s.c = s.cs.Cipher(s.k)
  89. }
  90. type symmetricState struct {
  91. CipherState
  92. hasK bool
  93. ck []byte
  94. h []byte
  95. prevCK []byte
  96. prevH []byte
  97. }
  98. func (s *symmetricState) InitializeSymmetric(handshakeName []byte) {
  99. h := s.cs.Hash()
  100. if len(handshakeName) <= h.Size() {
  101. s.h = make([]byte, h.Size())
  102. copy(s.h, handshakeName)
  103. } else {
  104. h.Write(handshakeName)
  105. s.h = h.Sum(nil)
  106. }
  107. s.ck = make([]byte, len(s.h))
  108. copy(s.ck, s.h)
  109. }
  110. func (s *symmetricState) MixKey(dhOutput []byte) {
  111. s.n = 0
  112. s.hasK = true
  113. var hk []byte
  114. s.ck, hk, _ = hkdf(s.cs.Hash, 2, s.ck[:0], s.k[:0], nil, s.ck, dhOutput)
  115. copy(s.k[:], hk)
  116. s.c = s.cs.Cipher(s.k)
  117. }
  118. func (s *symmetricState) MixHash(data []byte) {
  119. h := s.cs.Hash()
  120. h.Write(s.h)
  121. h.Write(data)
  122. s.h = h.Sum(s.h[:0])
  123. }
  124. func (s *symmetricState) MixKeyAndHash(data []byte) {
  125. var hk []byte
  126. var temp []byte
  127. s.ck, temp, hk = hkdf(s.cs.Hash, 3, s.ck[:0], temp, s.k[:0], s.ck, data)
  128. s.MixHash(temp)
  129. copy(s.k[:], hk)
  130. s.c = s.cs.Cipher(s.k)
  131. s.n = 0
  132. s.hasK = true
  133. }
  134. func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
  135. if !s.hasK {
  136. s.MixHash(plaintext)
  137. return append(out, plaintext...), nil
  138. }
  139. ciphertext, err := s.Encrypt(out, s.h, plaintext)
  140. if err != nil {
  141. return nil, err
  142. }
  143. s.MixHash(ciphertext[len(out):])
  144. return ciphertext, nil
  145. }
  146. func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
  147. if !s.hasK {
  148. s.MixHash(data)
  149. return append(out, data...), nil
  150. }
  151. plaintext, err := s.Decrypt(out, s.h, data)
  152. if err != nil {
  153. return nil, err
  154. }
  155. s.MixHash(data)
  156. return plaintext, nil
  157. }
  158. func (s *symmetricState) Split() (*CipherState, *CipherState) {
  159. s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs}
  160. hk1, hk2, _ := hkdf(s.cs.Hash, 2, s1.k[:0], s2.k[:0], nil, s.ck, nil)
  161. copy(s1.k[:], hk1)
  162. copy(s2.k[:], hk2)
  163. s1.c = s.cs.Cipher(s1.k)
  164. s2.c = s.cs.Cipher(s2.k)
  165. return s1, s2
  166. }
  167. func (s *symmetricState) Checkpoint() {
  168. if len(s.ck) > cap(s.prevCK) {
  169. s.prevCK = make([]byte, len(s.ck))
  170. }
  171. s.prevCK = s.prevCK[:len(s.ck)]
  172. copy(s.prevCK, s.ck)
  173. if len(s.h) > cap(s.prevH) {
  174. s.prevH = make([]byte, len(s.h))
  175. }
  176. s.prevH = s.prevH[:len(s.h)]
  177. copy(s.prevH, s.h)
  178. }
  179. func (s *symmetricState) Rollback() {
  180. s.ck = s.ck[:len(s.prevCK)]
  181. copy(s.ck, s.prevCK)
  182. s.h = s.h[:len(s.prevH)]
  183. copy(s.h, s.prevH)
  184. }
  185. // A MessagePattern is a single message or operation used in a Noise handshake.
  186. type MessagePattern int
  187. // A HandshakePattern is a list of messages and operations that are used to
  188. // perform a specific Noise handshake.
  189. type HandshakePattern struct {
  190. Name string
  191. InitiatorPreMessages []MessagePattern
  192. ResponderPreMessages []MessagePattern
  193. Messages [][]MessagePattern
  194. }
  195. const (
  196. MessagePatternS MessagePattern = iota
  197. MessagePatternE
  198. MessagePatternDHEE
  199. MessagePatternDHES
  200. MessagePatternDHSE
  201. MessagePatternDHSS
  202. MessagePatternPSK
  203. )
  204. // MaxMsgLen is the maximum number of bytes that can be sent in a single Noise
  205. // message.
  206. const MaxMsgLen = 65535
  207. // A HandshakeState tracks the state of a Noise handshake. It may be discarded
  208. // after the handshake is complete.
  209. type HandshakeState struct {
  210. ss symmetricState
  211. s DHKey // local static keypair
  212. e DHKey // local ephemeral keypair
  213. rs []byte // remote party's static public key
  214. re []byte // remote party's ephemeral public key
  215. psk []byte // preshared key, maybe zero length
  216. messagePatterns [][]MessagePattern
  217. shouldWrite bool
  218. initiator bool
  219. msgIdx int
  220. rng io.Reader
  221. }
  222. // A Config provides the details necessary to process a Noise handshake. It is
  223. // never modified by this package, and can be reused.
  224. type Config struct {
  225. // CipherSuite is the set of cryptographic primitives that will be used.
  226. CipherSuite CipherSuite
  227. // Random is the source for cryptographically appropriate random bytes. If
  228. // zero, it is automatically configured.
  229. Random io.Reader
  230. // Pattern is the pattern for the handshake.
  231. Pattern HandshakePattern
  232. // Initiator must be true if the first message in the handshake will be sent
  233. // by this peer.
  234. Initiator bool
  235. // Prologue is an optional message that has already be communicated and must
  236. // be identical on both sides for the handshake to succeed.
  237. Prologue []byte
  238. // PresharedKey is the optional preshared key for the handshake.
  239. PresharedKey []byte
  240. // PresharedKeyPlacement specifies the placement position of the PSK token
  241. // when PresharedKey is specified
  242. PresharedKeyPlacement int
  243. // StaticKeypair is this peer's static keypair, required if part of the
  244. // handshake.
  245. StaticKeypair DHKey
  246. // EphemeralKeypair is this peer's ephemeral keypair that was provided as
  247. // a pre-message in the handshake.
  248. EphemeralKeypair DHKey
  249. // PeerStatic is the static public key of the remote peer that was provided
  250. // as a pre-message in the handshake.
  251. PeerStatic []byte
  252. // PeerEphemeral is the ephemeral public key of the remote peer that was
  253. // provided as a pre-message in the handshake.
  254. PeerEphemeral []byte
  255. }
  256. // NewHandshakeState starts a new handshake using the provided configuration.
  257. func NewHandshakeState(c Config) (*HandshakeState, error) {
  258. hs := &HandshakeState{
  259. s: c.StaticKeypair,
  260. e: c.EphemeralKeypair,
  261. rs: c.PeerStatic,
  262. psk: c.PresharedKey,
  263. messagePatterns: c.Pattern.Messages,
  264. shouldWrite: c.Initiator,
  265. initiator: c.Initiator,
  266. rng: c.Random,
  267. }
  268. if hs.rng == nil {
  269. hs.rng = rand.Reader
  270. }
  271. if len(c.PeerEphemeral) > 0 {
  272. hs.re = make([]byte, len(c.PeerEphemeral))
  273. copy(hs.re, c.PeerEphemeral)
  274. }
  275. hs.ss.cs = c.CipherSuite
  276. pskModifier := ""
  277. if len(hs.psk) > 0 {
  278. if len(hs.psk) != 32 {
  279. return nil, errors.New("noise: specification mandates 256-bit preshared keys")
  280. }
  281. pskModifier = fmt.Sprintf("psk%d", c.PresharedKeyPlacement)
  282. hs.messagePatterns = append([][]MessagePattern(nil), hs.messagePatterns...)
  283. if c.PresharedKeyPlacement == 0 {
  284. hs.messagePatterns[0] = append([]MessagePattern{MessagePatternPSK}, hs.messagePatterns[0]...)
  285. } else {
  286. hs.messagePatterns[c.PresharedKeyPlacement-1] = append(hs.messagePatterns[c.PresharedKeyPlacement-1], MessagePatternPSK)
  287. }
  288. }
  289. hs.ss.InitializeSymmetric([]byte("Noise_" + c.Pattern.Name + pskModifier + "_" + string(hs.ss.cs.Name())))
  290. hs.ss.MixHash(c.Prologue)
  291. for _, m := range c.Pattern.InitiatorPreMessages {
  292. switch {
  293. case c.Initiator && m == MessagePatternS:
  294. hs.ss.MixHash(hs.s.Public)
  295. case c.Initiator && m == MessagePatternE:
  296. hs.ss.MixHash(hs.e.Public)
  297. case !c.Initiator && m == MessagePatternS:
  298. hs.ss.MixHash(hs.rs)
  299. case !c.Initiator && m == MessagePatternE:
  300. hs.ss.MixHash(hs.re)
  301. }
  302. }
  303. for _, m := range c.Pattern.ResponderPreMessages {
  304. switch {
  305. case !c.Initiator && m == MessagePatternS:
  306. hs.ss.MixHash(hs.s.Public)
  307. case !c.Initiator && m == MessagePatternE:
  308. hs.ss.MixHash(hs.e.Public)
  309. case c.Initiator && m == MessagePatternS:
  310. hs.ss.MixHash(hs.rs)
  311. case c.Initiator && m == MessagePatternE:
  312. hs.ss.MixHash(hs.re)
  313. }
  314. }
  315. return hs, nil
  316. }
  317. // WriteMessage appends a handshake message to out. The message will include the
  318. // optional payload if provided. If the handshake is completed by the call, two
  319. // CipherStates will be returned, one is used for encryption of messages to the
  320. // remote peer, the other is used for decryption of messages from the remote
  321. // peer. It is an error to call this method out of sync with the handshake
  322. // pattern.
  323. func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
  324. if !s.shouldWrite {
  325. return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
  326. }
  327. if s.msgIdx > len(s.messagePatterns)-1 {
  328. return nil, nil, nil, errors.New("noise: no handshake messages left")
  329. }
  330. if len(payload) > MaxMsgLen {
  331. return nil, nil, nil, errors.New("noise: message is too long")
  332. }
  333. var err error
  334. for _, msg := range s.messagePatterns[s.msgIdx] {
  335. switch msg {
  336. case MessagePatternE:
  337. e, err := s.ss.cs.GenerateKeypair(s.rng)
  338. if err != nil {
  339. return nil, nil, nil, err
  340. }
  341. s.e = e
  342. out = append(out, s.e.Public...)
  343. s.ss.MixHash(s.e.Public)
  344. if len(s.psk) > 0 {
  345. s.ss.MixKey(s.e.Public)
  346. }
  347. case MessagePatternS:
  348. if len(s.s.Public) == 0 {
  349. return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
  350. }
  351. out, err = s.ss.EncryptAndHash(out, s.s.Public)
  352. if err != nil {
  353. return nil, nil, nil, err
  354. }
  355. case MessagePatternDHEE:
  356. dh, err := s.ss.cs.DH(s.e.Private, s.re)
  357. if err != nil {
  358. return nil, nil, nil, err
  359. }
  360. s.ss.MixKey(dh)
  361. case MessagePatternDHES:
  362. if s.initiator {
  363. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  364. if err != nil {
  365. return nil, nil, nil, err
  366. }
  367. s.ss.MixKey(dh)
  368. } else {
  369. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  370. if err != nil {
  371. return nil, nil, nil, err
  372. }
  373. s.ss.MixKey(dh)
  374. }
  375. case MessagePatternDHSE:
  376. if s.initiator {
  377. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  378. if err != nil {
  379. return nil, nil, nil, err
  380. }
  381. s.ss.MixKey(dh)
  382. } else {
  383. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  384. if err != nil {
  385. return nil, nil, nil, err
  386. }
  387. s.ss.MixKey(dh)
  388. }
  389. case MessagePatternDHSS:
  390. dh, err := s.ss.cs.DH(s.s.Private, s.rs)
  391. if err != nil {
  392. return nil, nil, nil, err
  393. }
  394. s.ss.MixKey(dh)
  395. case MessagePatternPSK:
  396. s.ss.MixKeyAndHash(s.psk)
  397. }
  398. }
  399. s.shouldWrite = false
  400. s.msgIdx++
  401. out, err = s.ss.EncryptAndHash(out, payload)
  402. if err != nil {
  403. return nil, nil, nil, err
  404. }
  405. if s.msgIdx >= len(s.messagePatterns) {
  406. cs1, cs2 := s.ss.Split()
  407. return out, cs1, cs2, nil
  408. }
  409. return out, nil, nil, nil
  410. }
  411. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
  412. var ErrShortMessage = errors.New("noise: message is too short")
  413. // ReadMessage processes a received handshake message and appends the payload,
  414. // if any to out. If the handshake is completed by the call, two CipherStates
  415. // will be returned, one is used for encryption of messages to the remote peer,
  416. // the other is used for decryption of messages from the remote peer. It is an
  417. // error to call this method out of sync with the handshake pattern.
  418. func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
  419. if s.shouldWrite {
  420. return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
  421. }
  422. if s.msgIdx > len(s.messagePatterns)-1 {
  423. return nil, nil, nil, errors.New("noise: no handshake messages left")
  424. }
  425. rsSet := false
  426. s.ss.Checkpoint()
  427. var err error
  428. for _, msg := range s.messagePatterns[s.msgIdx] {
  429. switch msg {
  430. case MessagePatternE, MessagePatternS:
  431. expected := s.ss.cs.DHLen()
  432. if msg == MessagePatternS && s.ss.hasK {
  433. expected += 16
  434. }
  435. if len(message) < expected {
  436. return nil, nil, nil, ErrShortMessage
  437. }
  438. switch msg {
  439. case MessagePatternE:
  440. if cap(s.re) < s.ss.cs.DHLen() {
  441. s.re = make([]byte, s.ss.cs.DHLen())
  442. }
  443. s.re = s.re[:s.ss.cs.DHLen()]
  444. copy(s.re, message)
  445. s.ss.MixHash(s.re)
  446. if len(s.psk) > 0 {
  447. s.ss.MixKey(s.re)
  448. }
  449. case MessagePatternS:
  450. if len(s.rs) > 0 {
  451. return nil, nil, nil, errors.New("noise: invalid state, rs is not nil")
  452. }
  453. s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected])
  454. rsSet = true
  455. }
  456. if err != nil {
  457. s.ss.Rollback()
  458. if rsSet {
  459. s.rs = nil
  460. }
  461. return nil, nil, nil, err
  462. }
  463. message = message[expected:]
  464. case MessagePatternDHEE:
  465. dh, err := s.ss.cs.DH(s.e.Private, s.re)
  466. if err != nil {
  467. return nil, nil, nil, err
  468. }
  469. s.ss.MixKey(dh)
  470. case MessagePatternDHES:
  471. if s.initiator {
  472. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  473. if err != nil {
  474. return nil, nil, nil, err
  475. }
  476. s.ss.MixKey(dh)
  477. } else {
  478. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  479. if err != nil {
  480. return nil, nil, nil, err
  481. }
  482. s.ss.MixKey(dh)
  483. }
  484. case MessagePatternDHSE:
  485. if s.initiator {
  486. dh, err := s.ss.cs.DH(s.s.Private, s.re)
  487. if err != nil {
  488. return nil, nil, nil, err
  489. }
  490. s.ss.MixKey(dh)
  491. } else {
  492. dh, err := s.ss.cs.DH(s.e.Private, s.rs)
  493. if err != nil {
  494. return nil, nil, nil, err
  495. }
  496. s.ss.MixKey(dh)
  497. }
  498. case MessagePatternDHSS:
  499. dh, err := s.ss.cs.DH(s.s.Private, s.rs)
  500. if err != nil {
  501. return nil, nil, nil, err
  502. }
  503. s.ss.MixKey(dh)
  504. case MessagePatternPSK:
  505. s.ss.MixKeyAndHash(s.psk)
  506. }
  507. }
  508. out, err = s.ss.DecryptAndHash(out, message)
  509. if err != nil {
  510. s.ss.Rollback()
  511. if rsSet {
  512. s.rs = nil
  513. }
  514. return nil, nil, nil, err
  515. }
  516. s.shouldWrite = true
  517. s.msgIdx++
  518. if s.msgIdx >= len(s.messagePatterns) {
  519. cs1, cs2 := s.ss.Split()
  520. return out, cs1, cs2, nil
  521. }
  522. return out, nil, nil, nil
  523. }
  524. // ChannelBinding provides a value that uniquely identifies the session and can
  525. // be used as a channel binding. It is an error to call this method before the
  526. // handshake is complete.
  527. func (s *HandshakeState) ChannelBinding() []byte {
  528. return s.ss.h
  529. }
  530. // PeerStatic returns the static key provided by the remote peer during
  531. // a handshake. It is an error to call this method if a handshake message
  532. // containing a static key has not been read.
  533. func (s *HandshakeState) PeerStatic() []byte {
  534. return s.rs
  535. }
  536. // MessageIndex returns the current handshake message id
  537. func (s *HandshakeState) MessageIndex() int {
  538. return s.msgIdx
  539. }
  540. // PeerEphemeral returns the ephemeral key provided by the remote peer during
  541. // a handshake. It is an error to call this method if a handshake message
  542. // containing a static key has not been read.
  543. func (s *HandshakeState) PeerEphemeral() []byte {
  544. return s.re
  545. }
  546. // LocalEphemeral returns the local ephemeral key pair generated during
  547. // a handshake.
  548. func (s *HandshakeState) LocalEphemeral() DHKey {
  549. return s.e
  550. }