session_test.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. /*
  2. * Copyright (c) 2023, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package inproxy
  20. import (
  21. "bytes"
  22. "context"
  23. "fmt"
  24. "math"
  25. "strings"
  26. "testing"
  27. "time"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  31. "github.com/flynn/noise"
  32. "golang.zx2c4.com/wireguard/replay"
  33. )
  34. func TestSessions(t *testing.T) {
  35. err := runTestSessions()
  36. if err != nil {
  37. t.Error(errors.Trace(err).Error())
  38. }
  39. }
  40. func runTestSessions() error {
  41. // Test: basic round trip succeeds
  42. responderPrivateKey, err := GenerateSessionPrivateKey()
  43. if err != nil {
  44. return errors.Trace(err)
  45. }
  46. responderPublicKey, err := responderPrivateKey.GetPublicKey()
  47. if err != nil {
  48. return errors.Trace(err)
  49. }
  50. responderRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  51. if err != nil {
  52. return errors.Trace(err)
  53. }
  54. responderSessions, err := NewResponderSessions(
  55. responderPrivateKey, responderRootObfuscationSecret)
  56. if err != nil {
  57. return errors.Trace(err)
  58. }
  59. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  60. if err != nil {
  61. return errors.Trace(err)
  62. }
  63. initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
  64. if err != nil {
  65. return errors.Trace(err)
  66. }
  67. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  68. waitToShareSession := true
  69. sessionHandshakeTimeout := 100 * time.Millisecond
  70. requestDelay := 1 * time.Microsecond
  71. requestTimeout := 200 * time.Millisecond
  72. roundTripper := newTestSessionRoundTripper(
  73. responderSessions,
  74. &initiatorPublicKey,
  75. sessionHandshakeTimeout,
  76. requestDelay,
  77. requestTimeout)
  78. request := roundTripper.MakeRequest()
  79. response, err := initiatorSessions.RoundTrip(
  80. context.Background(),
  81. roundTripper,
  82. responderPublicKey,
  83. responderRootObfuscationSecret,
  84. waitToShareSession,
  85. sessionHandshakeTimeout,
  86. requestDelay,
  87. requestTimeout,
  88. request)
  89. if err != nil {
  90. return errors.Trace(err)
  91. }
  92. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  93. return errors.TraceNew("unexpected response")
  94. }
  95. // Test: session expires; new one negotiated
  96. //
  97. // sessionStateResponder_XK_recv_e_es_send_e_ee case, when Nonce = 0
  98. responderSessions.sessions.Flush()
  99. request = roundTripper.MakeRequest()
  100. response, err = initiatorSessions.RoundTrip(
  101. context.Background(),
  102. roundTripper,
  103. responderPublicKey,
  104. responderRootObfuscationSecret,
  105. waitToShareSession,
  106. sessionHandshakeTimeout,
  107. requestDelay,
  108. requestTimeout,
  109. request)
  110. if err != nil {
  111. return errors.Trace(err)
  112. }
  113. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  114. return errors.TraceNew("unexpected response")
  115. }
  116. // Test: session expires; new one negotiated
  117. //
  118. // "unexpected nonce" case, when Nonce > 0
  119. for i := 0; i < 10; i++ {
  120. _, err = initiatorSessions.RoundTrip(
  121. context.Background(),
  122. roundTripper,
  123. responderPublicKey,
  124. responderRootObfuscationSecret,
  125. waitToShareSession,
  126. sessionHandshakeTimeout,
  127. requestDelay,
  128. requestTimeout,
  129. roundTripper.MakeRequest())
  130. if err != nil {
  131. return errors.Trace(err)
  132. }
  133. }
  134. responderSessions.sessions.Flush()
  135. request = roundTripper.MakeRequest()
  136. response, err = initiatorSessions.RoundTrip(
  137. context.Background(),
  138. roundTripper,
  139. responderPublicKey,
  140. responderRootObfuscationSecret,
  141. waitToShareSession,
  142. sessionHandshakeTimeout,
  143. requestDelay,
  144. requestTimeout,
  145. request)
  146. if err != nil {
  147. return errors.Trace(err)
  148. }
  149. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  150. return errors.TraceNew("unexpected response")
  151. }
  152. // Test: RoundTrips with waitToShareSession are interrupted when session
  153. // fails
  154. responderSessions.sessions.Flush()
  155. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  156. failingRoundTripper := newTestSessionRoundTripper(
  157. nil,
  158. &initiatorPublicKey,
  159. sessionHandshakeTimeout,
  160. requestDelay,
  161. requestTimeout)
  162. roundTripCount := 100
  163. results := make(chan error, roundTripCount)
  164. for i := 0; i < roundTripCount; i++ {
  165. go func() {
  166. time.Sleep(prng.DefaultPRNG().Period(0, 10*time.Millisecond))
  167. waitToShareSession := true
  168. _, err := initiatorSessions.RoundTrip(
  169. context.Background(),
  170. failingRoundTripper,
  171. responderPublicKey,
  172. responderRootObfuscationSecret,
  173. waitToShareSession,
  174. sessionHandshakeTimeout,
  175. requestDelay,
  176. requestTimeout,
  177. roundTripper.MakeRequest())
  178. results <- err
  179. }()
  180. }
  181. waitToShareSessionFailed := false
  182. for i := 0; i < roundTripCount; i++ {
  183. err := <-results
  184. if err == nil {
  185. return errors.TraceNew("unexpected success")
  186. }
  187. if strings.HasSuffix(err.Error(), "waitToShareSession failed") {
  188. waitToShareSessionFailed = true
  189. }
  190. }
  191. if !waitToShareSessionFailed {
  192. return errors.TraceNew("missing waitToShareSession failed error")
  193. }
  194. // Test: expected known initiator public key
  195. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  196. responderSessions, err = NewResponderSessionsForKnownInitiators(
  197. responderPrivateKey,
  198. responderRootObfuscationSecret,
  199. []SessionPublicKey{initiatorPublicKey})
  200. if err != nil {
  201. return errors.Trace(err)
  202. }
  203. roundTripper = newTestSessionRoundTripper(
  204. responderSessions,
  205. &initiatorPublicKey,
  206. sessionHandshakeTimeout,
  207. requestDelay,
  208. requestTimeout)
  209. request = roundTripper.MakeRequest()
  210. response, err = initiatorSessions.RoundTrip(
  211. context.Background(),
  212. roundTripper,
  213. responderPublicKey,
  214. responderRootObfuscationSecret,
  215. waitToShareSession,
  216. sessionHandshakeTimeout,
  217. requestDelay,
  218. requestTimeout,
  219. request)
  220. if err != nil {
  221. return errors.Trace(err)
  222. }
  223. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  224. return errors.TraceNew("unexpected response")
  225. }
  226. // Test: expected known initiator public key using SetKnownInitiatorPublicKeys
  227. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  228. responderSessions, err = NewResponderSessionsForKnownInitiators(
  229. responderPrivateKey,
  230. responderRootObfuscationSecret,
  231. []SessionPublicKey{})
  232. if err != nil {
  233. return errors.Trace(err)
  234. }
  235. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})
  236. roundTripper = newTestSessionRoundTripper(
  237. responderSessions,
  238. &initiatorPublicKey,
  239. sessionHandshakeTimeout,
  240. requestDelay,
  241. requestTimeout)
  242. request = roundTripper.MakeRequest()
  243. response, err = initiatorSessions.RoundTrip(
  244. context.Background(),
  245. roundTripper,
  246. responderPublicKey,
  247. responderRootObfuscationSecret,
  248. waitToShareSession,
  249. sessionHandshakeTimeout,
  250. requestDelay,
  251. requestTimeout,
  252. request)
  253. if err != nil {
  254. return errors.Trace(err)
  255. }
  256. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  257. return errors.TraceNew("unexpected response")
  258. }
  259. // The existing session should not be dropped as the original key remains valid.
  260. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})
  261. if responderSessions.sessions.ItemCount() != 1 {
  262. return errors.TraceNew("unexpected session cache state")
  263. }
  264. otherKnownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
  265. if err != nil {
  266. return errors.Trace(err)
  267. }
  268. otherKnownInitiatorPublicKey, err := otherKnownInitiatorPrivateKey.GetPublicKey()
  269. if err != nil {
  270. return errors.Trace(err)
  271. }
  272. // The existing session should be dropped as the original key is not longer valid.
  273. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{otherKnownInitiatorPublicKey})
  274. if responderSessions.sessions.ItemCount() != 0 {
  275. return errors.TraceNew("unexpected session cache state")
  276. }
  277. // Test: wrong known initiator public key
  278. unknownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
  279. if err != nil {
  280. return errors.Trace(err)
  281. }
  282. unknownInitiatorSessions := NewInitiatorSessions(unknownInitiatorPrivateKey)
  283. ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
  284. defer cancelFunc()
  285. request = roundTripper.MakeRequest()
  286. _, err = unknownInitiatorSessions.RoundTrip(
  287. ctx,
  288. roundTripper,
  289. responderPublicKey,
  290. responderRootObfuscationSecret,
  291. waitToShareSession,
  292. sessionHandshakeTimeout,
  293. requestDelay,
  294. requestTimeout,
  295. request)
  296. if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
  297. return errors.Tracef("unexpected result: %v", err)
  298. }
  299. // Test: many concurrent sessions
  300. responderSessions, err = NewResponderSessions(
  301. responderPrivateKey, responderRootObfuscationSecret)
  302. if err != nil {
  303. return errors.Trace(err)
  304. }
  305. roundTripper = newTestSessionRoundTripper(
  306. responderSessions,
  307. nil,
  308. sessionHandshakeTimeout,
  309. requestDelay,
  310. requestTimeout)
  311. clientCount := 10000
  312. requestCount := 100
  313. concurrentRequestCount := 5
  314. if common.IsRaceDetectorEnabled {
  315. // Workaround for very high memory usage and OOM that occurs only with
  316. // the race detector enabled.
  317. clientCount = 100
  318. }
  319. resultChan := make(chan error, clientCount)
  320. for i := 0; i < clientCount; i++ {
  321. // Run clients concurrently
  322. go func() {
  323. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  324. if err != nil {
  325. resultChan <- errors.Trace(err)
  326. return
  327. }
  328. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  329. for i := 0; i < requestCount; i += concurrentRequestCount {
  330. requestResultChan := make(chan error, concurrentRequestCount)
  331. for j := 0; j < concurrentRequestCount; j++ {
  332. // Run some of each client's requests concurrently, to
  333. // exercise waitToShareSession
  334. go func(waitToShareSession bool) {
  335. request := roundTripper.MakeRequest()
  336. response, err := initiatorSessions.RoundTrip(
  337. context.Background(),
  338. roundTripper,
  339. responderPublicKey,
  340. responderRootObfuscationSecret,
  341. waitToShareSession,
  342. sessionHandshakeTimeout,
  343. requestDelay,
  344. requestTimeout,
  345. request)
  346. if err != nil {
  347. requestResultChan <- errors.Trace(err)
  348. return
  349. }
  350. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  351. requestResultChan <- errors.TraceNew("unexpected response")
  352. return
  353. }
  354. requestResultChan <- nil
  355. }(i%2 == 0)
  356. }
  357. for i := 0; i < concurrentRequestCount; i++ {
  358. err = <-requestResultChan
  359. if err != nil {
  360. resultChan <- errors.Trace(err)
  361. return
  362. }
  363. }
  364. }
  365. resultChan <- nil
  366. }()
  367. }
  368. for i := 0; i < clientCount; i++ {
  369. err = <-resultChan
  370. if err != nil {
  371. return errors.Trace(err)
  372. }
  373. }
  374. return nil
  375. }
  376. type testSessionRoundTripper struct {
  377. sessions *ResponderSessions
  378. expectedPeerPublicKey *SessionPublicKey
  379. expectedSessionHandshakeTimeout time.Duration
  380. expectedRequestDelay time.Duration
  381. expectedRequestTimeout time.Duration
  382. }
  383. func newTestSessionRoundTripper(
  384. sessions *ResponderSessions,
  385. expectedPeerPublicKey *SessionPublicKey,
  386. expectedSessionHandshakeTimeout time.Duration,
  387. expectedRequestDelay time.Duration,
  388. expectedRequestTimeout time.Duration) *testSessionRoundTripper {
  389. return &testSessionRoundTripper{
  390. sessions: sessions,
  391. expectedPeerPublicKey: expectedPeerPublicKey,
  392. expectedSessionHandshakeTimeout: expectedSessionHandshakeTimeout,
  393. expectedRequestDelay: expectedRequestDelay,
  394. expectedRequestTimeout: expectedRequestTimeout,
  395. }
  396. }
  397. func (t *testSessionRoundTripper) MakeRequest() []byte {
  398. return prng.Bytes(prng.Range(100, 1000))
  399. }
  400. func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte {
  401. l := len(requestPayload)
  402. responsePayload := make([]byte, l)
  403. for i, b := range requestPayload {
  404. responsePayload[l-i-1] = b
  405. }
  406. return responsePayload
  407. }
  408. func (t *testSessionRoundTripper) RoundTrip(
  409. ctx context.Context,
  410. roundTripDelay time.Duration,
  411. roundTripTimeout time.Duration,
  412. requestPayload []byte) ([]byte, error) {
  413. err := ctx.Err()
  414. if err != nil {
  415. return nil, errors.Trace(err)
  416. }
  417. if t.sessions == nil {
  418. return nil, errors.TraceNew("closed")
  419. }
  420. if roundTripDelay > 0 {
  421. common.SleepWithContext(ctx, roundTripDelay)
  422. }
  423. _, requestCancelFunc := context.WithTimeout(ctx, roundTripTimeout)
  424. defer requestCancelFunc()
  425. isRequestRoundTrip := false
  426. unwrappedRequestHandler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {
  427. if t.expectedPeerPublicKey != nil {
  428. curve25519, err := (*t.expectedPeerPublicKey).ToCurve25519()
  429. if err != nil {
  430. return nil, errors.Trace(err)
  431. }
  432. if !bytes.Equal(initiatorID[:], curve25519[:]) {
  433. return nil, errors.TraceNew("unexpected initiator ID")
  434. }
  435. }
  436. isRequestRoundTrip = true
  437. return t.ExpectedResponse(unwrappedRequest), nil
  438. }
  439. responsePayload, err := t.sessions.HandlePacket(requestPayload, unwrappedRequestHandler)
  440. if err != nil {
  441. if responsePayload == nil {
  442. return nil, errors.Trace(err)
  443. } else {
  444. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  445. // Continue to relay packets
  446. }
  447. } else {
  448. // Handshake round trips and request payload round trips should have the
  449. // appropriate delays/timeouts.
  450. if isRequestRoundTrip {
  451. if roundTripDelay != t.expectedRequestDelay {
  452. return nil, errors.TraceNew("unexpected round trip delay")
  453. }
  454. if roundTripTimeout != t.expectedRequestTimeout {
  455. return nil, errors.TraceNew("unexpected round trip timeout")
  456. }
  457. } else {
  458. if roundTripDelay != time.Duration(0) {
  459. return nil, errors.TraceNew("unexpected round trip delay")
  460. }
  461. if roundTripTimeout != t.expectedSessionHandshakeTimeout {
  462. return nil, errors.TraceNew("unexpected round trip timeout")
  463. }
  464. }
  465. }
  466. return responsePayload, nil
  467. }
  468. func (t *testSessionRoundTripper) Close() error {
  469. t.sessions = nil
  470. return nil
  471. }
  472. func TestNoise(t *testing.T) {
  473. err := runTestNoise()
  474. if err != nil {
  475. t.Error(errors.Trace(err).Error())
  476. }
  477. }
  478. func runTestNoise() error {
  479. prologue := []byte("psiphon-inproxy-session")
  480. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  481. if err != nil {
  482. return errors.Trace(err)
  483. }
  484. initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
  485. if err != nil {
  486. return errors.Trace(err)
  487. }
  488. curve25519InitiatorPublicKey, err := initiatorPublicKey.ToCurve25519()
  489. if err != nil {
  490. return errors.Trace(err)
  491. }
  492. initiatorKeys := noise.DHKey{
  493. Public: curve25519InitiatorPublicKey[:],
  494. Private: initiatorPrivateKey.ToCurve25519()[:],
  495. }
  496. responderPrivateKey, err := GenerateSessionPrivateKey()
  497. if err != nil {
  498. return errors.Trace(err)
  499. }
  500. responderPublicKey, err := responderPrivateKey.GetPublicKey()
  501. if err != nil {
  502. return errors.Trace(err)
  503. }
  504. curve25519ResponderPublicKey, err := responderPublicKey.ToCurve25519()
  505. if err != nil {
  506. return errors.Trace(err)
  507. }
  508. responderKeys := noise.DHKey{
  509. Public: curve25519ResponderPublicKey[:],
  510. Private: responderPrivateKey.ToCurve25519()[:],
  511. }
  512. initiatorHandshake, err := noise.NewHandshakeState(
  513. noise.Config{
  514. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  515. Pattern: noise.HandshakeXK,
  516. Initiator: true,
  517. Prologue: prologue,
  518. StaticKeypair: initiatorKeys,
  519. PeerStatic: responderKeys.Public,
  520. })
  521. if err != nil {
  522. return errors.Trace(err)
  523. }
  524. responderHandshake, err := noise.NewHandshakeState(
  525. noise.Config{
  526. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  527. Pattern: noise.HandshakeXK,
  528. Initiator: false,
  529. Prologue: prologue,
  530. StaticKeypair: responderKeys,
  531. })
  532. if err != nil {
  533. return errors.Trace(err)
  534. }
  535. // Noise XK: -> e, es
  536. var initiatorMsg []byte
  537. initiatorMsg, _, _, err = initiatorHandshake.WriteMessage(initiatorMsg, nil)
  538. if err != nil {
  539. return errors.Trace(err)
  540. }
  541. var receivedPayload []byte
  542. receivedPayload, _, _, err = responderHandshake.ReadMessage(nil, initiatorMsg)
  543. if err != nil {
  544. return errors.Trace(err)
  545. }
  546. if len(receivedPayload) > 0 {
  547. return errors.TraceNew("unexpected payload")
  548. }
  549. // Noise XK: <- e, ee
  550. var responderMsg []byte
  551. responderMsg, _, _, err = responderHandshake.WriteMessage(responderMsg, nil)
  552. if err != nil {
  553. return errors.Trace(err)
  554. }
  555. receivedPayload, _, _, err = initiatorHandshake.ReadMessage(nil, responderMsg)
  556. if err != nil {
  557. return errors.Trace(err)
  558. }
  559. if len(receivedPayload) > 0 {
  560. return errors.TraceNew("unexpected payload")
  561. }
  562. // Noise XK: -> s, se + payload
  563. sendPayload := prng.Bytes(1000)
  564. var initiatorSend, initiatorReceive *noise.CipherState
  565. var initiatorReplay replay.Filter
  566. initiatorMsg = nil
  567. initiatorMsg, initiatorSend, initiatorReceive, err = initiatorHandshake.WriteMessage(initiatorMsg, sendPayload)
  568. if err != nil {
  569. return errors.Trace(err)
  570. }
  571. if initiatorSend == nil || initiatorReceive == nil {
  572. return errors.Tracef("unexpected incomplete handshake")
  573. }
  574. var responderSend, responderReceive *noise.CipherState
  575. var responderReplay replay.Filter
  576. receivedPayload = nil
  577. receivedPayload, responderReceive, responderSend, err = responderHandshake.ReadMessage(receivedPayload, initiatorMsg)
  578. if err != nil {
  579. return errors.Trace(err)
  580. }
  581. if responderReceive == nil || responderSend == nil {
  582. return errors.TraceNew("unexpected incomplete handshake")
  583. }
  584. if receivedPayload == nil {
  585. return errors.TraceNew("missing payload")
  586. }
  587. if !bytes.Equal(sendPayload, receivedPayload) {
  588. return errors.TraceNew("incorrect payload")
  589. }
  590. if !bytes.Equal(responderHandshake.PeerStatic(), initiatorKeys.Public) {
  591. return errors.TraceNew("unexpected initiator static public key")
  592. }
  593. // post-handshake initiator <- responder
  594. nonce := responderSend.Nonce()
  595. responderMsg = nil
  596. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  597. if err != nil {
  598. return errors.Trace(err)
  599. }
  600. initiatorReceive.SetNonce(nonce)
  601. receivedPayload = nil
  602. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  603. if err != nil {
  604. return errors.Trace(err)
  605. }
  606. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  607. return errors.TraceNew("replay detected")
  608. }
  609. if !bytes.Equal(sendPayload, receivedPayload) {
  610. return errors.TraceNew("incorrect payload")
  611. }
  612. for i := 0; i < 100; i++ {
  613. // post-handshake initiator -> responder
  614. sendPayload = prng.Bytes(1000)
  615. nonce = initiatorSend.Nonce()
  616. initiatorMsg = nil
  617. initiatorMsg, err = initiatorSend.Encrypt(initiatorMsg, nil, sendPayload)
  618. if err != nil {
  619. return errors.Trace(err)
  620. }
  621. responderReceive.SetNonce(nonce)
  622. receivedPayload = nil
  623. receivedPayload, err = responderReceive.Decrypt(receivedPayload, nil, initiatorMsg)
  624. if err != nil {
  625. return errors.Trace(err)
  626. }
  627. if !responderReplay.ValidateCounter(nonce, math.MaxUint64) {
  628. return errors.TraceNew("replay detected")
  629. }
  630. if !bytes.Equal(sendPayload, receivedPayload) {
  631. return errors.TraceNew("incorrect payload")
  632. }
  633. // post-handshake initiator <- responder
  634. nonce = responderSend.Nonce()
  635. responderMsg = nil
  636. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  637. if err != nil {
  638. return errors.Trace(err)
  639. }
  640. responderReceive.SetNonce(nonce)
  641. receivedPayload = nil
  642. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  643. if err != nil {
  644. return errors.Trace(err)
  645. }
  646. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  647. return errors.TraceNew("replay detected")
  648. }
  649. if !bytes.Equal(sendPayload, receivedPayload) {
  650. return errors.TraceNew("incorrect payload")
  651. }
  652. }
  653. return nil
  654. }