| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796 |
- /*
- * Copyright (c) 2023, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
- package inproxy
- import (
- "bytes"
- "context"
- "fmt"
- "math"
- "strings"
- "testing"
- "time"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
- "github.com/flynn/noise"
- "golang.zx2c4.com/wireguard/replay"
- )
// TestSessions runs the session round-trip scenarios implemented by
// runTestSessions and reports the first failure, if any.
func TestSessions(t *testing.T) {
	if err := runTestSessions(); err != nil {
		// t.Error, not t.Errorf: the traced message is data, not a format
		// string, and must not be interpreted for '%' verbs.
		t.Error(errors.Trace(err).Error())
	}
}
// runTestSessions exercises InitiatorSessions and ResponderSessions end to
// end over an in-memory testSessionRoundTripper. It returns the first
// failure encountered, or nil when every scenario passes. Scenarios, in
// order:
//
//   - a basic request/response round trip;
//   - a new session being negotiated after the responder flushes its session
//     cache, both right after the handshake (nonce 0) and after several
//     post-handshake requests (nonce > 0);
//   - concurrent waitToShareSession callers all failing, at least one with a
//     "waitToShareSession failed" error, when the transport always fails;
//   - responders restricted to known initiator public keys, configured at
//     construction or via SetKnownInitiatorPublicKeys, including session
//     cache eviction when the allowed set changes and rejection of an
//     unknown initiator;
//   - many concurrent clients, each issuing batches of concurrent requests.
func runTestSessions() error {

	// Test: basic round trip succeeds

	responderPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	responderPublicKey, err := responderPrivateKey.GetPublicKey()
	if err != nil {
		return errors.Trace(err)
	}

	responderRootObfuscationSecret, err := GenerateRootObfuscationSecret()
	if err != nil {
		return errors.Trace(err)
	}

	responderSessions, err := NewResponderSessions(
		responderPrivateKey, responderRootObfuscationSecret)
	if err != nil {
		return errors.Trace(err)
	}

	initiatorPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
	if err != nil {
		return errors.Trace(err)
	}

	initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)

	waitToShareSession := true

	sessionHandshakeTimeout := 100 * time.Millisecond
	requestDelay := 1 * time.Microsecond
	requestTimeout := 200 * time.Millisecond

	roundTripper := newTestSessionRoundTripper(
		responderSessions,
		&initiatorPublicKey,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout)

	// checkedRoundTrip sends one fresh request from the current
	// initiatorSessions through the current roundTripper and verifies that
	// the response matches ExpectedResponse. Both variables are captured by
	// reference, so the reassignments made by later scenarios are picked up.
	checkedRoundTrip := func() error {
		request := roundTripper.MakeRequest()
		response, err := initiatorSessions.RoundTrip(
			context.Background(),
			roundTripper,
			responderPublicKey,
			responderRootObfuscationSecret,
			waitToShareSession,
			sessionHandshakeTimeout,
			requestDelay,
			requestTimeout,
			request)
		if err != nil {
			return errors.Trace(err)
		}
		if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
			return errors.TraceNew("unexpected response")
		}
		return nil
	}

	err = checkedRoundTrip()
	if err != nil {
		return errors.Trace(err)
	}

	// Test: session expires; new one negotiated
	//
	// sessionStateResponder_XK_recv_e_es_send_e_ee case, when Nonce = 0

	responderSessions.sessions.Flush()

	err = checkedRoundTrip()
	if err != nil {
		return errors.Trace(err)
	}

	// Test: session expires; new one negotiated
	//
	// "unexpected nonce" case, when Nonce > 0

	for i := 0; i < 10; i++ {
		err = checkedRoundTrip()
		if err != nil {
			return errors.Trace(err)
		}
	}

	responderSessions.sessions.Flush()

	err = checkedRoundTrip()
	if err != nil {
		return errors.Trace(err)
	}

	// Test: RoundTrips with waitToShareSession are interrupted when session
	// fails

	responderSessions.sessions.Flush()

	initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)

	// A nil ResponderSessions makes every RoundTrip on this transport fail.
	failingRoundTripper := newTestSessionRoundTripper(
		nil,
		&initiatorPublicKey,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout)

	roundTripCount := 100

	results := make(chan error, roundTripCount)

	for i := 0; i < roundTripCount; i++ {
		go func() {
			// Randomly stagger the callers (presumably so that some of them
			// join a handshake that is already in progress).
			time.Sleep(prng.DefaultPRNG().Period(0, 10*time.Millisecond))
			_, err := initiatorSessions.RoundTrip(
				context.Background(),
				failingRoundTripper,
				responderPublicKey,
				responderRootObfuscationSecret,
				waitToShareSession,
				sessionHandshakeTimeout,
				requestDelay,
				requestTimeout,
				failingRoundTripper.MakeRequest())
			results <- err
		}()
	}

	waitToShareSessionFailed := false
	for i := 0; i < roundTripCount; i++ {
		err := <-results
		if err == nil {
			return errors.TraceNew("unexpected success")
		}
		if strings.HasSuffix(err.Error(), "waitToShareSession failed") {
			waitToShareSessionFailed = true
		}
	}
	if !waitToShareSessionFailed {
		return errors.TraceNew("missing waitToShareSession failed error")
	}

	// Test: expected known initiator public key

	initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)

	responderSessions, err = NewResponderSessionsForKnownInitiators(
		responderPrivateKey,
		responderRootObfuscationSecret,
		[]SessionPublicKey{initiatorPublicKey})
	if err != nil {
		return errors.Trace(err)
	}

	roundTripper = newTestSessionRoundTripper(
		responderSessions,
		&initiatorPublicKey,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout)

	err = checkedRoundTrip()
	if err != nil {
		return errors.Trace(err)
	}

	// Test: expected known initiator public key using SetKnownInitiatorPublicKeys

	initiatorSessions = NewInitiatorSessions(initiatorPrivateKey)

	responderSessions, err = NewResponderSessionsForKnownInitiators(
		responderPrivateKey,
		responderRootObfuscationSecret,
		[]SessionPublicKey{})
	if err != nil {
		return errors.Trace(err)
	}

	responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})

	roundTripper = newTestSessionRoundTripper(
		responderSessions,
		&initiatorPublicKey,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout)

	err = checkedRoundTrip()
	if err != nil {
		return errors.Trace(err)
	}

	// The existing session should not be dropped as the original key remains valid.

	responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{initiatorPublicKey})

	if responderSessions.sessions.ItemCount() != 1 {
		return errors.TraceNew("unexpected session cache state")
	}

	otherKnownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	otherKnownInitiatorPublicKey, err := otherKnownInitiatorPrivateKey.GetPublicKey()
	if err != nil {
		return errors.Trace(err)
	}

	// The existing session should be dropped as the original key is no longer valid.

	responderSessions.SetKnownInitiatorPublicKeys([]SessionPublicKey{otherKnownInitiatorPublicKey})

	if responderSessions.sessions.ItemCount() != 0 {
		return errors.TraceNew("unexpected session cache state")
	}

	// Test: wrong known initiator public key

	unknownInitiatorPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	unknownInitiatorSessions := NewInitiatorSessions(unknownInitiatorPrivateKey)

	ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancelFunc()

	_, err = unknownInitiatorSessions.RoundTrip(
		ctx,
		roundTripper,
		responderPublicKey,
		responderRootObfuscationSecret,
		waitToShareSession,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout,
		roundTripper.MakeRequest())
	if err == nil || !strings.HasSuffix(err.Error(), "unexpected initiator public key") {
		return errors.Tracef("unexpected result: %v", err)
	}

	// Test: many concurrent sessions

	responderSessions, err = NewResponderSessions(
		responderPrivateKey, responderRootObfuscationSecret)
	if err != nil {
		return errors.Trace(err)
	}

	roundTripper = newTestSessionRoundTripper(
		responderSessions,
		nil,
		sessionHandshakeTimeout,
		requestDelay,
		requestTimeout)

	clientCount := 10000
	requestCount := 100
	concurrentRequestCount := 5

	if common.IsRaceDetectorEnabled {
		// Workaround for very high memory usage and OOM that occurs only with
		// the race detector enabled.
		clientCount = 100
	}

	resultChan := make(chan error, clientCount)

	for client := 0; client < clientCount; client++ {

		// Run clients concurrently; each client has its own initiator key
		// and InitiatorSessions.

		go func() {

			initiatorPrivateKey, err := GenerateSessionPrivateKey()
			if err != nil {
				resultChan <- errors.Trace(err)
				return
			}

			initiatorSessions := NewInitiatorSessions(initiatorPrivateKey)

			for batchStart := 0; batchStart < requestCount; batchStart += concurrentRequestCount {

				requestResultChan := make(chan error, concurrentRequestCount)

				for j := 0; j < concurrentRequestCount; j++ {

					// Run each batch of a client's requests concurrently, to
					// exercise waitToShareSession. The batch start index's
					// parity selects waitToShareSession for the whole batch.

					go func(waitToShareSession bool) {
						request := roundTripper.MakeRequest()
						response, err := initiatorSessions.RoundTrip(
							context.Background(),
							roundTripper,
							responderPublicKey,
							responderRootObfuscationSecret,
							waitToShareSession,
							sessionHandshakeTimeout,
							requestDelay,
							requestTimeout,
							request)
						if err != nil {
							requestResultChan <- errors.Trace(err)
							return
						}
						if !bytes.Equal(response, roundTripper.ExpectedResponse(request)) {
							requestResultChan <- errors.TraceNew("unexpected response")
							return
						}
						requestResultChan <- nil
					}(batchStart%2 == 0)
				}

				// requestResultChan is buffered, so returning early on the
				// first error does not block the remaining request goroutines.
				for j := 0; j < concurrentRequestCount; j++ {
					if err := <-requestResultChan; err != nil {
						resultChan <- errors.Trace(err)
						return
					}
				}
			}

			resultChan <- nil
		}()
	}

	for client := 0; client < clientCount; client++ {
		err = <-resultChan
		if err != nil {
			return errors.Trace(err)
		}
	}

	return nil
}
- type testSessionRoundTripper struct {
- sessions *ResponderSessions
- expectedPeerPublicKey *SessionPublicKey
- expectedSessionHandshakeTimeout time.Duration
- expectedRequestDelay time.Duration
- expectedRequestTimeout time.Duration
- }
// newTestSessionRoundTripper builds a testSessionRoundTripper that relays
// packets to sessions and checks each round trip against the given
// expectations. A nil sessions yields a round tripper whose RoundTrip always
// fails; a nil expectedPeerPublicKey disables the initiator ID check.
func newTestSessionRoundTripper(
	sessions *ResponderSessions,
	expectedPeerPublicKey *SessionPublicKey,
	expectedSessionHandshakeTimeout time.Duration,
	expectedRequestDelay time.Duration,
	expectedRequestTimeout time.Duration) *testSessionRoundTripper {

	rt := new(testSessionRoundTripper)
	rt.sessions = sessions
	rt.expectedPeerPublicKey = expectedPeerPublicKey
	rt.expectedSessionHandshakeTimeout = expectedSessionHandshakeTimeout
	rt.expectedRequestDelay = expectedRequestDelay
	rt.expectedRequestTimeout = expectedRequestTimeout
	return rt
}
- func (t *testSessionRoundTripper) MakeRequest() []byte {
- return prng.Bytes(prng.Range(100, 1000))
- }
- func (t *testSessionRoundTripper) ExpectedResponse(requestPayload []byte) []byte {
- l := len(requestPayload)
- responsePayload := make([]byte, l)
- for i, b := range requestPayload {
- responsePayload[l-i-1] = b
- }
- return responsePayload
- }
// RoundTrip relays one packet from the initiator to the responder sessions
// in-process and returns the responder's reply packet.
//
// It fails immediately when ctx is already done or when the round tripper
// has been closed (sessions is nil). Otherwise it waits roundTripDelay using
// common.SleepWithContext and passes requestPayload to
// ResponderSessions.HandlePacket. The unwrapped-request handler checks the
// initiator ID against expectedPeerPublicKey (when set) and answers with
// ExpectedResponse of the request.
//
// When HandlePacket returns an error with no reply packet, that error is
// returned. When it returns both a reply and an error, the error is printed
// and the reply is still relayed. On success, the call is classified as a
// request round trip if the unwrapped-request handler ran, and as a
// handshake round trip otherwise. The caller's roundTripDelay and
// roundTripTimeout must then match the values expected for that kind.
//
// This transport is in-memory and synchronous, so roundTripTimeout is only
// validated, not enforced.
func (t *testSessionRoundTripper) RoundTrip(
	ctx context.Context,
	roundTripDelay time.Duration,
	roundTripTimeout time.Duration,
	requestPayload []byte) ([]byte, error) {

	err := ctx.Err()
	if err != nil {
		return nil, errors.Trace(err)
	}

	if t.sessions == nil {
		return nil, errors.TraceNew("closed")
	}

	if roundTripDelay > 0 {
		common.SleepWithContext(ctx, roundTripDelay)
	}

	isRequestRoundTrip := false

	unwrappedRequestHandler := func(initiatorID ID, unwrappedRequest []byte) ([]byte, error) {

		if t.expectedPeerPublicKey != nil {
			curve25519, err := (*t.expectedPeerPublicKey).ToCurve25519()
			if err != nil {
				return nil, errors.Trace(err)
			}
			if !bytes.Equal(initiatorID[:], curve25519[:]) {
				return nil, errors.TraceNew("unexpected initiator ID")
			}
		}

		// Only request (not handshake) packets reach this handler.
		isRequestRoundTrip = true

		return t.ExpectedResponse(unwrappedRequest), nil
	}

	responsePayload, err := t.sessions.HandlePacket(requestPayload, unwrappedRequestHandler)
	if err != nil {
		if responsePayload == nil {
			return nil, errors.Trace(err)
		}
		// HandlePacket may return both a packet and an error; report the
		// error and continue to relay the packet.
		fmt.Printf("HandlePacket returned packet and error: %v\n", err)
		return responsePayload, nil
	}

	// Handshake round trips and request payload round trips should have the
	// appropriate delays/timeouts.
	expectedDelay, expectedTimeout := time.Duration(0), t.expectedSessionHandshakeTimeout
	if isRequestRoundTrip {
		expectedDelay, expectedTimeout = t.expectedRequestDelay, t.expectedRequestTimeout
	}
	if roundTripDelay != expectedDelay {
		return nil, errors.TraceNew("unexpected round trip delay")
	}
	if roundTripTimeout != expectedTimeout {
		return nil, errors.TraceNew("unexpected round trip timeout")
	}

	return responsePayload, nil
}
- func (t *testSessionRoundTripper) Close() error {
- t.sessions = nil
- return nil
- }
// TestNoise runs the Noise protocol checks implemented by runTestNoise and
// reports the first failure, if any.
func TestNoise(t *testing.T) {
	if err := runTestNoise(); err != nil {
		// t.Error, not t.Errorf: the traced message is data, not a format
		// string, and must not be interpreted for '%' verbs.
		t.Error(errors.Trace(err).Error())
	}
}
// runTestNoise directly exercises the flynn/noise primitives. It performs a
// Noise XK handshake (DH25519, ChaChaPoly, BLAKE2b) between freshly
// generated initiator and responder keys, carrying a payload in the final
// handshake message. It then runs one responder-to-initiator message
// followed by 100 initiator-to-responder-and-echo exchanges. For every
// post-handshake message, the receiver explicitly sets its decrypt nonce to
// the sender's counter and checks that nonce with a replay.Filter. It
// returns the first failure, or nil.
func runTestNoise() error {

	prologue := []byte("psiphon-inproxy-session")

	initiatorPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	initiatorPublicKey, err := initiatorPrivateKey.GetPublicKey()
	if err != nil {
		return errors.Trace(err)
	}

	curve25519InitiatorPublicKey, err := initiatorPublicKey.ToCurve25519()
	if err != nil {
		return errors.Trace(err)
	}

	initiatorKeys := noise.DHKey{
		Public:  curve25519InitiatorPublicKey[:],
		Private: initiatorPrivateKey.ToCurve25519()[:],
	}

	responderPrivateKey, err := GenerateSessionPrivateKey()
	if err != nil {
		return errors.Trace(err)
	}

	responderPublicKey, err := responderPrivateKey.GetPublicKey()
	if err != nil {
		return errors.Trace(err)
	}

	curve25519ResponderPublicKey, err := responderPublicKey.ToCurve25519()
	if err != nil {
		return errors.Trace(err)
	}

	responderKeys := noise.DHKey{
		Public:  curve25519ResponderPublicKey[:],
		Private: responderPrivateKey.ToCurve25519()[:],
	}

	// In XK, the initiator knows the responder's static key in advance.
	initiatorHandshake, err := noise.NewHandshakeState(
		noise.Config{
			CipherSuite:   noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
			Pattern:       noise.HandshakeXK,
			Initiator:     true,
			Prologue:      prologue,
			StaticKeypair: initiatorKeys,
			PeerStatic:    responderKeys.Public,
		})
	if err != nil {
		return errors.Trace(err)
	}

	responderHandshake, err := noise.NewHandshakeState(
		noise.Config{
			CipherSuite:   noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2b),
			Pattern:       noise.HandshakeXK,
			Initiator:     false,
			Prologue:      prologue,
			StaticKeypair: responderKeys,
		})
	if err != nil {
		return errors.Trace(err)
	}

	// Noise XK: -> e, es

	initiatorMsg, _, _, err := initiatorHandshake.WriteMessage(nil, nil)
	if err != nil {
		return errors.Trace(err)
	}

	receivedPayload, _, _, err := responderHandshake.ReadMessage(nil, initiatorMsg)
	if err != nil {
		return errors.Trace(err)
	}
	if len(receivedPayload) > 0 {
		return errors.TraceNew("unexpected payload")
	}

	// Noise XK: <- e, ee

	responderMsg, _, _, err := responderHandshake.WriteMessage(nil, nil)
	if err != nil {
		return errors.Trace(err)
	}

	receivedPayload, _, _, err = initiatorHandshake.ReadMessage(nil, responderMsg)
	if err != nil {
		return errors.Trace(err)
	}
	if len(receivedPayload) > 0 {
		return errors.TraceNew("unexpected payload")
	}

	// Noise XK: -> s, se + payload

	sendPayload := prng.Bytes(1000)

	initiatorMsg, initiatorSend, initiatorReceive, err := initiatorHandshake.WriteMessage(nil, sendPayload)
	if err != nil {
		return errors.Trace(err)
	}
	if initiatorSend == nil || initiatorReceive == nil {
		return errors.TraceNew("unexpected incomplete handshake")
	}

	receivedPayload, responderReceive, responderSend, err := responderHandshake.ReadMessage(nil, initiatorMsg)
	if err != nil {
		return errors.Trace(err)
	}
	if responderReceive == nil || responderSend == nil {
		return errors.TraceNew("unexpected incomplete handshake")
	}
	if receivedPayload == nil {
		return errors.TraceNew("missing payload")
	}
	if !bytes.Equal(sendPayload, receivedPayload) {
		return errors.TraceNew("incorrect payload")
	}
	if !bytes.Equal(responderHandshake.PeerStatic(), initiatorKeys.Public) {
		return errors.TraceNew("unexpected initiator static public key")
	}

	var initiatorReplay, responderReplay replay.Filter

	// transfer encrypts payload with send, then decrypts it with receive
	// after explicitly setting receive's nonce to the nonce send used. The
	// nonce must also pass filter's replay check. Pairing the sender and
	// receiver in one place ensures the nonce is always set on the cipher
	// state that performs the decryption.
	transfer := func(
		send, receive *noise.CipherState,
		filter *replay.Filter,
		payload []byte) ([]byte, error) {

		nonce := send.Nonce()
		ciphertext, err := send.Encrypt(nil, nil, payload)
		if err != nil {
			return nil, errors.Trace(err)
		}

		receive.SetNonce(nonce)
		plaintext, err := receive.Decrypt(nil, nil, ciphertext)
		if err != nil {
			return nil, errors.Trace(err)
		}

		if !filter.ValidateCounter(nonce, math.MaxUint64) {
			return nil, errors.TraceNew("replay detected")
		}

		return plaintext, nil
	}

	// post-handshake initiator <- responder (echo of the handshake payload)

	receivedPayload, err = transfer(responderSend, initiatorReceive, &initiatorReplay, receivedPayload)
	if err != nil {
		return errors.Trace(err)
	}
	if !bytes.Equal(sendPayload, receivedPayload) {
		return errors.TraceNew("incorrect payload")
	}

	for i := 0; i < 100; i++ {

		// post-handshake initiator -> responder

		sendPayload = prng.Bytes(1000)

		receivedPayload, err = transfer(initiatorSend, responderReceive, &responderReplay, sendPayload)
		if err != nil {
			return errors.Trace(err)
		}
		if !bytes.Equal(sendPayload, receivedPayload) {
			return errors.TraceNew("incorrect payload")
		}

		// post-handshake initiator <- responder (echo)

		receivedPayload, err = transfer(responderSend, initiatorReceive, &initiatorReplay, receivedPayload)
		if err != nil {
			return errors.Trace(err)
		}
		if !bytes.Equal(sendPayload, receivedPayload) {
			return errors.TraceNew("incorrect payload")
		}
	}

	return nil
}
|