session_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. /*
  2. * Copyright (c) 2023, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package inproxy
  20. import (
  21. "bytes"
  22. "context"
  23. "crypto/rand"
  24. "fmt"
  25. "math"
  26. "strings"
  27. "testing"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  31. "github.com/flynn/noise"
  32. "golang.zx2c4.com/wireguard/replay"
  33. )
  34. func TestSessions(t *testing.T) {
  35. err := runTestSessions()
  36. if err != nil {
  37. t.Errorf(errors.Trace(err).Error())
  38. }
  39. }
  40. func runTestSessions() error {
  41. // Test: basic round trip succeeds
  42. responderPrivateKey, err := GenerateSessionPrivateKey()
  43. if err != nil {
  44. return errors.Trace(err)
  45. }
  46. responderPublicKey, err := GetSessionPublicKey(responderPrivateKey)
  47. if err != nil {
  48. return errors.Trace(err)
  49. }
  50. responderRootObfuscationSecret, err := GenerateRootObfuscationSecret()
  51. if err != nil {
  52. return errors.Trace(err)
  53. }
  54. responderSessions, err := NewResponderSessions(
  55. responderPrivateKey, responderRootObfuscationSecret)
  56. if err != nil {
  57. return errors.Trace(err)
  58. }
  59. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  60. if err != nil {
  61. return errors.Trace(err)
  62. }
  63. initiatorPublicKey, err := GetSessionPublicKey(initiatorPrivateKey)
  64. if err != nil {
  65. return errors.Trace(err)
  66. }
  67. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  68. roundTripper := newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
  69. waitToShareSession := true
  70. request := roundTripper.MakeRequest()
  71. response, err := initiatorSessions.RoundTrip(
  72. context.Background(),
  73. roundTripper,
  74. responderPublicKey,
  75. responderRootObfuscationSecret,
  76. waitToShareSession,
  77. request)
  78. if err != nil {
  79. return errors.Trace(err)
  80. }
  81. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  82. return errors.TraceNew("unexpected response")
  83. }
  84. // Test: session expires; new one negotiated
  85. responderSessions.sessions.Flush()
  86. request = roundTripper.MakeRequest()
  87. response, err = initiatorSessions.RoundTrip(
  88. context.Background(),
  89. roundTripper,
  90. responderPublicKey,
  91. responderRootObfuscationSecret,
  92. waitToShareSession,
  93. request)
  94. if err != nil {
  95. return errors.Trace(err)
  96. }
  97. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  98. return errors.TraceNew("unexpected response")
  99. }
  100. // Test: expected known initiator public key
  101. initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)
  102. responderSessions, err = NewResponderSessionsForKnownInitiators(
  103. responderPrivateKey,
  104. responderRootObfuscationSecret,
  105. []SessionPublicKey{initiatorPublicKey})
  106. if err != nil {
  107. return errors.Trace(err)
  108. }
  109. roundTripper = newTestSessionRoundTripper(responderSessions, &initiatorPublicKey)
  110. request = roundTripper.MakeRequest()
  111. response, err = initiatorSessions.RoundTrip(
  112. context.Background(),
  113. roundTripper,
  114. responderPublicKey,
  115. responderRootObfuscationSecret,
  116. waitToShareSession,
  117. request)
  118. if err != nil {
  119. return errors.Trace(err)
  120. }
  121. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  122. return errors.TraceNew("unexpected response")
  123. }
  124. // Test: wrong known initiator public key
  125. unknownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
  126. if err != nil {
  127. return errors.Trace(err)
  128. }
  129. unknownInitiatorSessions := NewInitiatorSessions(unknownInitiatorPrivateKey)
  130. ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
  131. defer cancelFunc()
  132. request = roundTripper.MakeRequest()
  133. response, err = unknownInitiatorSessions.RoundTrip(
  134. ctx,
  135. roundTripper,
  136. responderPublicKey,
  137. responderRootObfuscationSecret,
  138. waitToShareSession,
  139. request)
  140. if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
  141. return errors.Tracef("unexpected result: %v", err)
  142. }
  143. // Test: many concurrent sessions
  144. responderSessions, err = NewResponderSessions(
  145. responderPrivateKey, responderRootObfuscationSecret)
  146. if err != nil {
  147. return errors.Trace(err)
  148. }
  149. roundTripper = newTestSessionRoundTripper(responderSessions, nil)
  150. clientCount := 10000
  151. requestCount := 100
  152. concurrentRequestCount := 5
  153. resultChan := make(chan error, clientCount)
  154. for i := 0; i < clientCount; i++ {
  155. // Run clients concurrently
  156. go func() {
  157. initiatorPrivateKey, err := GenerateSessionPrivateKey()
  158. if err != nil {
  159. resultChan <- errors.Trace(err)
  160. return
  161. }
  162. initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)
  163. for i := 0; i < requestCount; i += concurrentRequestCount {
  164. requestResultChan := make(chan error, concurrentRequestCount)
  165. for j := 0; j < concurrentRequestCount; j++ {
  166. // Run some of each client's requests concurrently, to
  167. // exercise waitToShareSession
  168. go func(waitToShareSession bool) {
  169. request := roundTripper.MakeRequest()
  170. response, err := initiatorSessions.RoundTrip(
  171. context.Background(),
  172. roundTripper,
  173. responderPublicKey,
  174. responderRootObfuscationSecret,
  175. waitToShareSession,
  176. request)
  177. if err != nil {
  178. requestResultChan <- errors.Trace(err)
  179. return
  180. }
  181. if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
  182. requestResultChan <- errors.TraceNew("unexpected response")
  183. return
  184. }
  185. requestResultChan <- nil
  186. }(i%2 == 0)
  187. }
  188. for i := 0; i < concurrentRequestCount; i++ {
  189. err = <-requestResultChan
  190. if err != nil {
  191. resultChan <- errors.Trace(err)
  192. return
  193. }
  194. }
  195. }
  196. resultChan <- nil
  197. }()
  198. }
  199. for i := 0; i < clientCount; i++ {
  200. err = <-resultChan
  201. if err != nil {
  202. return errors.Trace(err)
  203. }
  204. }
  205. return nil
  206. }
  207. type testSessionRoundTripper struct {
  208. sessions *ResponderSessions
  209. expectedPeerPublicKey *SessionPublicKey
  210. }
  211. func newTestSessionRoundTripper(
  212. sessions *ResponderSessions,
  213. expectedPeerPublicKey *SessionPublicKey) *testSessionRoundTripper {
  214. return &testSessionRoundTripper{
  215. sessions: sessions,
  216. expectedPeerPublicKey: expectedPeerPublicKey,
  217. }
  218. }
  219. func (t *testSessionRoundTripper) MakeRequest() []byte {
  220. return prng.Bytes(prng.Range(100, 1000))
  221. }
  222. func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte {
  223. l := len(requestPayload)
  224. responsePayload := make([]byte, l)
  225. for i, b := range requestPayload {
  226. responsePayload[l-i-1] = b
  227. }
  228. return responsePayload
  229. }
  230. func (t *testSessionRoundTripper) RoundTrip(ctx context.Context, requestPayload []byte) ([]byte, error) {
  231. err := ctx.Err()
  232. if err != nil {
  233. return nil, errors.Trace(err)
  234. }
  235. unwrappedRequestHandler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {
  236. if t.expectedPeerPublicKey != nil {
  237. if !bytes.Equal(initiatorID[:], (*t.expectedPeerPublicKey)[:]) {
  238. return nil, errors.TraceNew("unexpected initiator ID")
  239. }
  240. }
  241. return t.ExpectedResponse(unwrappedRequest), nil
  242. }
  243. responsePayload, err := t.sessions.HandlePacket(requestPayload, unwrappedRequestHandler)
  244. if err != nil {
  245. // Errors here are expected; e.g., in the session expired case.
  246. fmt.Printf("HandlePacket failed: %v\n", err)
  247. return nil, errors.Trace(err)
  248. }
  249. return responsePayload, nil
  250. }
  251. func (t *testSessionRoundTripper) Close() error {
  252. t.sessions = nil
  253. return nil
  254. }
  255. func TestNoise(t *testing.T) {
  256. err := runTestNoise()
  257. if err != nil {
  258. t.Errorf(errors.Trace(err).Error())
  259. }
  260. }
  261. func runTestNoise() error {
  262. prologue := []byte("psiphon-inproxy-session")
  263. initiatorKeys, err := noise.DH25519.GenerateKeypair(rand.Reader)
  264. if err != nil {
  265. return errors.Trace(err)
  266. }
  267. responderKeys, err := noise.DH25519.GenerateKeypair(rand.Reader)
  268. if err != nil {
  269. return errors.Trace(err)
  270. }
  271. initiatorHandshake, err := noise.NewHandshakeState(
  272. noise.Config{
  273. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  274. Pattern: noise.HandshakeXK,
  275. Initiator: true,
  276. Prologue: prologue,
  277. StaticKeypair: initiatorKeys,
  278. PeerStatic: responderKeys.Public,
  279. })
  280. if err != nil {
  281. return errors.Trace(err)
  282. }
  283. responderHandshake, err := noise.NewHandshakeState(
  284. noise.Config{
  285. CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
  286. Pattern: noise.HandshakeXK,
  287. Initiator: false,
  288. Prologue: prologue,
  289. StaticKeypair: responderKeys,
  290. })
  291. if err != nil {
  292. return errors.Trace(err)
  293. }
  294. // Noise XK: -> e, es
  295. var initiatorMsg []byte
  296. initiatorMsg, _, _, err = initiatorHandshake.WriteMessage(initiatorMsg, nil)
  297. if err != nil {
  298. return errors.Trace(err)
  299. }
  300. var receivedPayload []byte
  301. receivedPayload, _, _, err = responderHandshake.ReadMessage(nil, initiatorMsg)
  302. if err != nil {
  303. return errors.Trace(err)
  304. }
  305. if len(receivedPayload) > 0 {
  306. return errors.TraceNew("unexpected payload")
  307. }
  308. // Noise XK: <- e, ee
  309. var responderMsg []byte
  310. responderMsg, _, _, err = responderHandshake.WriteMessage(responderMsg, nil)
  311. if err != nil {
  312. return errors.Trace(err)
  313. }
  314. receivedPayload = nil
  315. receivedPayload, _, _, err = initiatorHandshake.ReadMessage(nil, responderMsg)
  316. if err != nil {
  317. return errors.Trace(err)
  318. }
  319. if len(receivedPayload) > 0 {
  320. return errors.TraceNew("unexpected payload")
  321. }
  322. // Noise XK: -> s, se + payload
  323. sendPayload := prng.Bytes(1000)
  324. var initiatorSend, initiatorReceive *noise.CipherState
  325. var initiatorReplay replay.Filter
  326. initiatorMsg = nil
  327. initiatorMsg, initiatorSend, initiatorReceive, err = initiatorHandshake.WriteMessage(initiatorMsg, sendPayload)
  328. if err != nil {
  329. return errors.Trace(err)
  330. }
  331. if initiatorSend == nil || initiatorReceive == nil {
  332. return errors.Tracef("unexpected incomplete handshake")
  333. }
  334. var responderSend, responderReceive *noise.CipherState
  335. var responderReplay replay.Filter
  336. receivedPayload = nil
  337. receivedPayload, responderReceive, responderSend, err = responderHandshake.ReadMessage(receivedPayload, initiatorMsg)
  338. if err != nil {
  339. return errors.Trace(err)
  340. }
  341. if responderReceive == nil || responderSend == nil {
  342. return errors.TraceNew("unexpected incomplete handshake")
  343. }
  344. if receivedPayload == nil {
  345. return errors.TraceNew("missing payload")
  346. }
  347. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  348. return errors.TraceNew("incorrect payload")
  349. }
  350. if bytes.Compare(responderHandshake.PeerStatic(), initiatorKeys.Public) != 0 {
  351. return errors.TraceNew("unexpected initiator static public key")
  352. }
  353. // post-handshake initiator <- responder
  354. nonce := responderSend.Nonce()
  355. responderMsg = nil
  356. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  357. if err != nil {
  358. return errors.Trace(err)
  359. }
  360. initiatorReceive.SetNonce(nonce)
  361. receivedPayload = nil
  362. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  363. if err != nil {
  364. return errors.Trace(err)
  365. }
  366. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  367. return errors.TraceNew("replay detected")
  368. }
  369. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  370. return errors.TraceNew("incorrect payload")
  371. }
  372. for i := 0; i < 100; i++ {
  373. // post-handshake initiator -> responder
  374. sendPayload = prng.Bytes(1000)
  375. nonce = initiatorSend.Nonce()
  376. initiatorMsg = nil
  377. initiatorMsg, err = initiatorSend.Encrypt(initiatorMsg, nil, sendPayload)
  378. if err != nil {
  379. return errors.Trace(err)
  380. }
  381. responderReceive.SetNonce(nonce)
  382. receivedPayload = nil
  383. receivedPayload, err = responderReceive.Decrypt(receivedPayload, nil, initiatorMsg)
  384. if err != nil {
  385. return errors.Trace(err)
  386. }
  387. if !responderReplay.ValidateCounter(nonce, math.MaxUint64) {
  388. return errors.TraceNew("replay detected")
  389. }
  390. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  391. return errors.TraceNew("incorrect payload")
  392. }
  393. // post-handshake initiator <- responder
  394. nonce = responderSend.Nonce()
  395. responderMsg = nil
  396. responderMsg, err = responderSend.Encrypt(responderMsg, nil, receivedPayload)
  397. if err != nil {
  398. return errors.Trace(err)
  399. }
  400. responderReceive.SetNonce(nonce)
  401. receivedPayload = nil
  402. receivedPayload, err = initiatorReceive.Decrypt(receivedPayload, nil, responderMsg)
  403. if err != nil {
  404. return errors.Trace(err)
  405. }
  406. if !initiatorReplay.ValidateCounter(nonce, math.MaxUint64) {
  407. return errors.TraceNew("replay detected")
  408. }
  409. if bytes.Compare(sendPayload, receivedPayload) != 0 {
  410. return errors.TraceNew("incorrect payload")
  411. }
  412. }
  413. return nil
  414. }