session_test.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. /*
  2. * Copyright (c) 2023, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package inproxy
  20. import (
  21. "bytes"
  22. "context"
  23. "fmt"
  24. "math"
  25. "strings"
  26. "testing"
  27. "time"
  28. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  31. "github.com/flynn/noise"
  32. "golang.zx2c4.com/wireguard/replay"
  33. )
  34. func TestSessions(t *testing.T) {
  35. err := runTestSessions()
  36. if err != nil {
  37. t.Errorf(errors.Trace(err).Error())
  38. }
  39. }
  40. func runTestSessions() error {
  41. // Test: basic round trip succeeds
  42. responderPrivateKey, err := GenerateSessionPrivateKey()
  43. if err != nil {
  44. return errors.Trace(err)
  45. }
  46. responderPublicKey, err := responderPrivateKey.GetPublicKey()
  47. if err != nil {
  48. return errors.Trace(err)
  49. }
  50. responderRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  51. if err != nil {
  52. return errors.Trace(err)
  53. }
  54. responderSessions, err := NewResponderSessions(
  55. responderPrivateKey, responderRootObfuscationSecret)
  56. if err != nil {
  57. return errors.Trace(err)
  58. }
  59. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  60. if err != nil {
  61. return errors.Trace(err)
  62. }
  63. initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
  64. if err != nil {
  65. return errors.Trace(err)
  66. }
  67. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  68. roundTripper := newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
  69. waitToShareSession := true
  70. request := roundTripper.MakeRequest()
  71. response, err := initiatorSessions.RoundTrip(
  72. context.Background(),
  73. roundTripper,
  74. responderPublicKey,
  75. responderRootObfuscationSecret,
  76. waitToShareSession,
  77. request)
  78. if err != nil {
  79. return errors.Trace(err)
  80. }
  81. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  82. return errors.TraceNew("unexpected response")
  83. }
  84. // Test: session expires; new one negotiated
  85. //
  86. // sessionStateResponder_XK_recv_e_es_send_e_ee case, when Nonce = 0
  87. responderSessions.sessions.Flush()
  88. request = roundTripper.MakeRequest()
  89. response, err = initiatorSessions.RoundTrip(
  90. context.Background(),
  91. roundTripper,
  92. responderPublicKey,
  93. responderRootObfuscationSecret,
  94. waitToShareSession,
  95. request)
  96. if err != nil {
  97. return errors.Trace(err)
  98. }
  99. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  100. return errors.TraceNew("unexpected response")
  101. }
  102. // Test: session expires; new one negotiated
  103. //
  104. // "unexpected nonce" case, when Nonce > 0
  105. for i := 0; i < 10; i++ {
  106. _, err = initiatorSessions.RoundTrip(
  107. context.Background(),
  108. roundTripper,
  109. responderPublicKey,
  110. responderRootObfuscationSecret,
  111. waitToShareSession,
  112. roundTripper.MakeRequest())
  113. if err != nil {
  114. return errors.Trace(err)
  115. }
  116. }
  117. responderSessions.sessions.Flush()
  118. request = roundTripper.MakeRequest()
  119. response, err = initiatorSessions.RoundTrip(
  120. context.Background(),
  121. roundTripper,
  122. responderPublicKey,
  123. responderRootObfuscationSecret,
  124. waitToShareSession,
  125. request)
  126. if err != nil {
  127. return errors.Trace(err)
  128. }
  129. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  130. return errors.TraceNew("unexpected response")
  131. }
  132. // Test: RoundTrips with waitToShareSession are interrupted when session
  133. // fails
  134. responderSessions.sessions.Flush()
  135. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  136. failingRoundTripper := newTestSessionRoundTripper(nil, &initiatorPublicKey)
  137. roundTripCount := 100
  138. results := make(chan error, roundTripCount)
  139. for i := 0; i < roundTripCount; i++ {
  140. go func() {
  141. time.Sleep(prng.DefaultPRNG().Period(0, 10*time.Millisecond))
  142. waitToShareSession := true
  143. _, err := initiatorSessions.RoundTrip(
  144. context.Background(),
  145. failingRoundTripper,
  146. responderPublicKey,
  147. responderRootObfuscationSecret,
  148. waitToShareSession,
  149. roundTripper.MakeRequest())
  150. results <- err
  151. }()
  152. }
  153. waitToShareSessionFailed := false
  154. for i := 0; i < roundTripCount; i++ {
  155. err := <-results
  156. if err == nil {
  157. return errors.TraceNew("unexpected success")
  158. }
  159. if strings.HasSuffix(err.Error(), "waitToShareSession failed") {
  160. waitToShareSessionFailed = true
  161. }
  162. }
  163. if !waitToShareSessionFailed {
  164. return errors.TraceNew("missing waitToShareSession failed error")
  165. }
  166. // Test: expected known initiator public key
  167. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  168. responderSessions, err = NewResponderSessionsForKnownInitiators(
  169. responderPrivateKey,
  170. responderRootObfuscationSecret,
  171. []SessionPublicKey{initiatorPublicKey})
  172. if err != nil {
  173. return errors.Trace(err)
  174. }
  175. roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
  176. request = roundTripper.MakeRequest()
  177. response, err = initiatorSessions.RoundTrip(
  178. context.Background(),
  179. roundTripper,
  180. responderPublicKey,
  181. responderRootObfuscationSecret,
  182. waitToShareSession,
  183. request)
  184. if err != nil {
  185. return errors.Trace(err)
  186. }
  187. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  188. return errors.TraceNew("unexpected response")
  189. }
  190. // Test: expected known initiator public key using SetKnownInitiatorPublicKeys
  191. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  192. responderSessions, err = NewResponderSessionsForKnownInitiators(
  193. responderPrivateKey,
  194. responderRootObfuscationSecret,
  195. []SessionPublicKey{})
  196. if err != nil {
  197. return errors.Trace(err)
  198. }
  199. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})
  200. roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
  201. request = roundTripper.MakeRequest()
  202. response, err = initiatorSessions.RoundTrip(
  203. context.Background(),
  204. roundTripper,
  205. responderPublicKey,
  206. responderRootObfuscationSecret,
  207. waitToShareSession,
  208. request)
  209. if err != nil {
  210. return errors.Trace(err)
  211. }
  212. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  213. return errors.TraceNew("unexpected response")
  214. }
  215. // The existing session should not be dropped as the original key remains valid.
  216. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})
  217. if responderSessions.sessions.ItemCount() != 1 {
  218. return errors.TraceNew("unexpected session cache state")
  219. }
  220. otherKnownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
  221. if err != nil {
  222. return errors.Trace(err)
  223. }
  224. otherKnownInitiatorPublicKey, err := otherKnownInitiatorPrivateKey.GetPublicKey()
  225. if err != nil {
  226. return errors.Trace(err)
  227. }
  228. // The existing session should be dropped as the original key is not longer valid.
  229. responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{otherKnownInitiatorPublicKey})
  230. if responderSessions.sessions.ItemCount() != 0 {
  231. return errors.TraceNew("unexpected session cache state")
  232. }
  233. // Test: wrong known initiator public key
  234. unknownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
  235. if err != nil {
  236. return errors.Trace(err)
  237. }
  238. unknownInitiatorSessions := NewInitiatorSessions(unknownInitiatorPrivateKey)
  239. ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
  240. defer cancelFunc()
  241. request = roundTripper.MakeRequest()
  242. response, err = unknownInitiatorSessions.RoundTrip(
  243. ctx,
  244. roundTripper,
  245. responderPublicKey,
  246. responderRootObfuscationSecret,
  247. waitToShareSession,
  248. request)
  249. if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
  250. return errors.Tracef("unexpected result: %v", err)
  251. }
  252. // Test: many concurrent sessions
  253. responderSessions, err = NewResponderSessions(
  254. responderPrivateKey, responderRootObfuscationSecret)
  255. if err != nil {
  256. return errors.Trace(err)
  257. }
  258. roundTripper = newTestSessionRoundTripper(responderSessions, nil)
  259. clientCount := 10000
  260. requestCount := 100
  261. concurrentRequestCount := 5
  262. if common.IsRaceDetectorEnabled {
  263. // Workaround for very high memory usage and OOM that occurs only with
  264. // the race detector enabled.
  265. clientCount = 100
  266. }
  267. resultChan := make(chan error, clientCount)
  268. for i := 0; i < clientCount; i++ {
  269. // Run clients concurrently
  270. go func() {
  271. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  272. if err != nil {
  273. resultChan <- errors.Trace(err)
  274. return
  275. }
  276. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  277. for i := 0; i < requestCount; i += concurrentRequestCount {
  278. requestResultChan := make(chan error, concurrentRequestCount)
  279. for j := 0; j < concurrentRequestCount; j++ {
  280. // Run some of each client's requests concurrently, to
  281. // exercise waitToShareSession
  282. go func(waitToShareSession bool) {
  283. request := roundTripper.MakeRequest()
  284. response, err := initiatorSessions.RoundTrip(
  285. context.Background(),
  286. roundTripper,
  287. responderPublicKey,
  288. responderRootObfuscationSecret,
  289. waitToShareSession,
  290. request)
  291. if err != nil {
  292. requestResultChan <- errors.Trace(err)
  293. return
  294. }
  295. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  296. requestResultChan <- errors.TraceNew("unexpected response")
  297. return
  298. }
  299. requestResultChan <- nil
  300. }(i%2 == 0)
  301. }
  302. for i := 0; i < concurrentRequestCount; i++ {
  303. err = <-requestResultChan
  304. if err != nil {
  305. resultChan <- errors.Trace(err)
  306. return
  307. }
  308. }
  309. }
  310. resultChan <- nil
  311. }()
  312. }
  313. for i := 0; i < clientCount; i++ {
  314. err = <-resultChan
  315. if err != nil {
  316. return errors.Trace(err)
  317. }
  318. }
  319. return nil
  320. }
  321. type testSessionRoundTripper struct {
  322. sessions *ResponderSessions
  323. expectedPeerPublicKey *SessionPublicKey
  324. }
  325. func newTestSessionRoundTripper(
  326. sessions *ResponderSessions,
  327. expectedPeerPublicKey *SessionPublicKey) *testSessionRoundTripper {
  328. return &testSessionRoundTripper{
  329. sessions: sessions,
  330. expectedPeerPublicKey: expectedPeerPublicKey,
  331. }
  332. }
  333. func (t *testSessionRoundTripper) MakeRequest() []byte {
  334. return prng.Bytes(prng.Range(100, 1000))
  335. }
  336. func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte {
  337. l := len(requestPayload)
  338. responsePayload := make([]byte, l)
  339. for i, b := range requestPayload {
  340. responsePayload[l-i-1] = b
  341. }
  342. return responsePayload
  343. }
  344. func (t *testSessionRoundTripper) RoundTrip(ctx context.Context, requestPayload []byte) ([]byte, error) {
  345. err := ctx.Err()
  346. if err != nil {
  347. return nil, errors.Trace(err)
  348. }
  349. if t.sessions == nil {
  350. return nil, errors.TraceNew("closed")
  351. }
  352. unwrappedRequestHandler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {
  353. if t.expectedPeerPublicKey != nil {
  354. curve25519, err := (*t.expectedPeerPublicKey).ToCurve25519()
  355. if err != nil {
  356. return nil, errors.Trace(err)
  357. }
  358. if !bytes.Equal(initiatorID[:], curve25519[:]) {
  359. return nil, errors.TraceNew("unexpected initiator ID")
  360. }
  361. }
  362. return t.ExpectedResponse(unwrappedRequest), nil
  363. }
  364. responsePayload, err := t.sessions.HandlePacket(requestPayload, unwrappedRequestHandler)
  365. if err != nil {
  366. if responsePayload == nil {
  367. return nil, errors.Trace(err)
  368. } else {
  369. fmt.Printf("HandlePacket returned packet and error: %v\n", err)
  370. // Continue to relay packets
  371. }
  372. }
  373. return responsePayload, nil
  374. }
  375. func (t *testSessionRoundTripper) Close() error {
  376. t.sessions = nil
  377. return nil
  378. }
  379. func TestNoise(t *testing.T) {
  380. err := runTestNoise()
  381. if err != nil {
  382. t.Errorf(errors.Trace(err).Error())
  383. }
  384. }
  385. func runTestNoise() error {
  386. prologue := []byte("psiphon-inproxy-session")
  387. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  388. if err != nil {
  389. return errors.Trace(err)
  390. }
  391. initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
  392. if err != nil {
  393. return errors.Trace(err)
  394. }
  395. curve25519InitiatorPublicKey, err := initiatorPublicKey.ToCurve25519()
  396. if err != nil {
  397. return errors.Trace(err)
  398. }
  399. initiatorKeys := noise.DHKey{
  400. Public: curve25519InitiatorPublicKey[:],
  401. Private: initiatorPrivateKey.ToCurve25519()[:],
  402. }
  403. responderPrivateKey, err := GenerateSessionPrivateKey()
  404. if err != nil {
  405. return errors.Trace(err)
  406. }
  407. responderPublicKey, err := responderPrivateKey.GetPublicKey()
  408. if err != nil {
  409. return errors.Trace(err)
  410. }
  411. curve25519ResponderPublicKey, err := responderPublicKey.ToCurve25519()
  412. if err != nil {
  413. return errors.Trace(err)
  414. }
  415. responderKeys := noise.DHKey{
  416. Public: curve25519ResponderPublicKey[:],
  417. Private: responderPrivateKey.ToCurve25519()[:],
  418. }
  419. initiatorHandshake, err := noise.NewHandshakeState(
  420. noise.Config{
  421. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  422. Pattern: noise.HandshakeXK,
  423. Initiator: true,
  424. Prologue: prologue,
  425. StaticKeypair: initiatorKeys,
  426. PeerStatic: responderKeys.Public,
  427. })
  428. if err != nil {
  429. return errors.Trace(err)
  430. }
  431. responderHandshake, err := noise.NewHandshakeState(
  432. noise.Config{
  433. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  434. Pattern: noise.HandshakeXK,
  435. Initiator: false,
  436. Prologue: prologue,
  437. StaticKeypair: responderKeys,
  438. })
  439. if err != nil {
  440. return errors.Trace(err)
  441. }
  442. // Noise XK: -> e, es
  443. var initiatorMsg []byte
  444. initiatorMsg, _, _, err = initiatorHandshake.WriteMessage(initiatorMsg, nil)
  445. if err != nil {
  446. return errors.Trace(err)
  447. }
  448. var receivedPayload []byte
  449. receivedPayload, _, _, err = responderHandshake.ReadMessage(nil, initiatorMsg)
  450. if err != nil {
  451. return errors.Trace(err)
  452. }
  453. if len(receivedPayload) > 0 {
  454. return errors.TraceNew("unexpected payload")
  455. }
  456. // Noise XK: <- e, ee
  457. var responderMsg []byte
  458. responderMsg, _, _, err = responderHandshake.WriteMessage(responderMsg, nil)
  459. if err != nil {
  460. return errors.Trace(err)
  461. }
  462. receivedPayload = nil
  463. receivedPayload, _, _, err = initiatorHandshake.ReadMessage(nil, responderMsg)
  464. if err != nil {
  465. return errors.Trace(err)
  466. }
  467. if len(receivedPayload) > 0 {
  468. return errors.TraceNew("unexpected payload")
  469. }
  470. // Noise XK: -> s, se + payload
  471. sendPayload := prng.Bytes(1000)
  472. var initiatorSend, initiatorReceive *noise.CipherState
  473. var initiatorReplay replay.Filter
  474. initiatorMsg = nil
  475. initiatorMsg, initiatorSend, initiatorReceive, err = initiatorHandshake.WriteMessage(initiatorMsg, sendPayload)
  476. if err != nil {
  477. return errors.Trace(err)
  478. }
  479. if initiatorSend == nil || initiatorReceive == nil {
  480. return errors.Tracef("unexpected incomplete handshake")
  481. }
  482. var responderSend, responderReceive *noise.CipherState
  483. var responderReplay replay.Filter
  484. receivedPayload = nil
  485. receivedPayload, responderReceive, responderSend, err = responderHandshake.ReadMessage(receivedPayload, initiatorMsg)
  486. if err != nil {
  487. return errors.Trace(err)
  488. }
  489. if responderReceive == nil || responderSend == nil {
  490. return errors.TraceNew("unexpected incomplete handshake")
  491. }
  492. if receivedPayload == nil {
  493. return errors.TraceNew("missing payload")
  494. }
  495. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  496. return errors.TraceNew("incorrect payload")
  497. }
  498. if bytes.Compare(responderHandshake.PeerStatic(), initiatorKeys.Public) != 0 {
  499. return errors.TraceNew("unexpected initiator static public key")
  500. }
  501. // post-handshake initiator <- responder
  502. nonce := responderSend.Nonce()
  503. responderMsg = nil
  504. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  505. if err != nil {
  506. return errors.Trace(err)
  507. }
  508. initiatorReceive.SetNonce(nonce)
  509. receivedPayload = nil
  510. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  511. if err != nil {
  512. return errors.Trace(err)
  513. }
  514. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  515. return errors.TraceNew("replay detected")
  516. }
  517. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  518. return errors.TraceNew("incorrect payload")
  519. }
  520. for i := 0; i < 100; i++ {
  521. // post-handshake initiator -> responder
  522. sendPayload = prng.Bytes(1000)
  523. nonce = initiatorSend.Nonce()
  524. initiatorMsg = nil
  525. initiatorMsg, err = initiatorSend.Encrypt(initiatorMsg, nil, sendPayload)
  526. if err != nil {
  527. return errors.Trace(err)
  528. }
  529. responderReceive.SetNonce(nonce)
  530. receivedPayload = nil
  531. receivedPayload, err = responderReceive.Decrypt(receivedPayload, nil, initiatorMsg)
  532. if err != nil {
  533. return errors.Trace(err)
  534. }
  535. if !responderReplay.ValidateCounter(nonce, math.MaxUint64) {
  536. return errors.TraceNew("replay detected")
  537. }
  538. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  539. return errors.TraceNew("incorrect payload")
  540. }
  541. // post-handshake initiator <- responder
  542. nonce = responderSend.Nonce()
  543. responderMsg = nil
  544. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  545. if err != nil {
  546. return errors.Trace(err)
  547. }
  548. responderReceive.SetNonce(nonce)
  549. receivedPayload = nil
  550. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  551. if err != nil {
  552. return errors.Trace(err)
  553. }
  554. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  555. return errors.TraceNew("replay detected")
  556. }
  557. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  558. return errors.TraceNew("incorrect payload")
  559. }
  560. }
  561. return nil
  562. }