/* * Copyright (c) 2018, Psiphon Inc. * All rights reserved. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . * */ package tactics import ( "bytes" "context" "fmt" "io/ioutil" "net" "net/http" "os" "reflect" "strings" "testing" "time" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace" ) func TestTactics(t *testing.T) { // Server tactics configuration // Long and short region lists test both map and slice lookups // Repeated median aggregation tests aggregation memoization tacticsConfigTemplate := ` { "RequestPublicKey" : "%s", "RequestPrivateKey" : "%s", "RequestObfuscatedKey" : "%s", "DefaultTactics" : { "TTL" : "1s", "Probability" : %0.1f, "Parameters" : { "NetworkLatencyMultiplier" : %0.1f } }, "FilteredTactics" : [ { "Filter" : { "Regions": ["R1", "R2", "R3", "R4", "R5", "R6"], "APIParameters" : {"client_platform" : ["P1"]}, "SpeedTestRTTMilliseconds" : { "Aggregation" : "Median", "AtLeast" : 1 } }, "Tactics" : { "Parameters" : { "ConnectionWorkerPoolSize" : %d } } }, { "Filter" : { "Regions": ["R1"], "APIParameters" : {"client_platform" : ["P1"], "client_version": ["V1"]}, "SpeedTestRTTMilliseconds" : { "Aggregation" : "Median", "AtLeast" : 1 } }, "Tactics" : { "Parameters" : { %s } } }, { "Filter" : { "Regions": ["R2"] }, "Tactics" : { "Parameters" : { "ConnectionWorkerPoolSize" : %d } } }, { "Filter" : { "Regions": ["R7"], "ISPs": ["I1"], "Cities": ["C1"] }, "Tactics" : { "Parameters" : { "FragmentorDownstreamProbability" : 1.0, "FragmentorDownstreamMinTotalBytes" : 1, "FragmentorDownstreamMaxTotalBytes" : 1, "FragmentorDownstreamMinWriteBytes" : 1, "FragmentorDownstreamMaxWriteBytes" : 1, "FragmentorDownstreamLimitProtocols" : ["OSSH"] } } } ] } ` if stringLookupThreshold != 5 { t.Fatalf("unexpected stringLookupThreshold") } encodedRequestPublicKey, encodedRequestPrivateKey, encodedObfuscatedKey, err := GenerateKeys() if err != nil { t.Fatalf("GenerateKeys failed: %s", err) } tacticsProbability := 0.5 tacticsNetworkLatencyMultiplier := 2.0 tacticsConnectionWorkerPoolSize := 5 tacticsLimitTunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"} jsonTacticsLimitTunnelProtocols := `"LimitTunnelProtocols" : ["OSSH", "SSH"]` expectedApplyCount := 3 listenerProtocol := "OSSH" listenerFragmentedGeoIP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R7", ISP: "I1", City: "C1"} } listenerUnfragmentedGeoIPWrongRegion := func(string) common.GeoIPData { return common.GeoIPData{Country: "R8", ISP: "I1", City: "C1"} } listenerUnfragmentedGeoIPWrongISP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R7", ISP: "I2", City: "C1"} } listenerUnfragmentedGeoIPWrongCity := func(string) common.GeoIPData { return common.GeoIPData{Country: "R7", ISP: "I1", City: "C2"} } tacticsConfig := fmt.Sprintf( tacticsConfigTemplate, encodedRequestPublicKey, encodedRequestPrivateKey, encodedObfuscatedKey, tacticsProbability, tacticsNetworkLatencyMultiplier, tacticsConnectionWorkerPoolSize, jsonTacticsLimitTunnelProtocols, tacticsConnectionWorkerPoolSize+1) file, err := ioutil.TempFile("", "tactics.config") if err != nil { t.Fatalf("TempFile create failed: %s", err) } _, err = file.Write([]byte(tacticsConfig)) if err != nil { t.Fatalf("TempFile write failed: %s", err) } file.Close() configFileName := file.Name() defer os.Remove(configFileName) // Configure and run server // Mock server uses an insecure HTTP transport that exposes endpoint names clientGeoIPData := common.GeoIPData{Country: "R1"} logger := newTestLogger() validator := func( params common.APIParameters) error { expectedParams := []string{"client_platform", "client_version"} for _, name := range expectedParams { value, ok := params[name] if !ok { return fmt.Errorf("missing param: %s", name) } _, ok = value.(string) if !ok { return fmt.Errorf("invalid param type: %s", name) } } return nil } formatter := func( geoIPData common.GeoIPData, params common.APIParameters) common.LogFields { return common.LogFields(params) } server, err := NewServer( logger, formatter, validator, configFileName) if err != nil { t.Fatalf("NewServer failed: %s", err) } listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Listen failed: %s", err) } serverAddress := listener.Addr().String() go func() { serveMux := http.NewServeMux() serveMux.HandleFunc( "/", func(w http.ResponseWriter, r *http.Request) { // Ensure RTT takes at least 1 millisecond for speed test time.Sleep(1 * time.Millisecond) endPoint := strings.Trim(r.URL.Path, "/") if !server.HandleEndPoint(endPoint, clientGeoIPData, w, r) { http.NotFound(w, r) } }) httpServer := &http.Server{ Addr: serverAddress, Handler: serveMux, } httpServer.Serve(listener) }() // Configure client clientParams, err := parameters.NewClientParameters( func(err error) { t.Fatalf("ClientParameters getValue failed: %s", err) }) if err != nil { t.Fatalf("NewClientParameters failed: %s", err) } networkID := "NETWORK1" getNetworkID := func() string { return networkID } apiParams := common.APIParameters{ "client_platform": "P1", "client_version": "V1"} storer := newTestStorer() endPointRegion := "R0" endPointProtocol := "OSSH" differentEndPointProtocol := "SSH" roundTripper := func( ctx context.Context, endPoint string, requestBody []byte) ([]byte, error) { request, err := http.NewRequest( "POST", fmt.Sprintf("http://%s/%s", serverAddress, endPoint), bytes.NewReader(requestBody)) if err != nil { return nil, err } request = request.WithContext(ctx) response, err := http.DefaultClient.Do(request) if err != nil { return nil, err } defer response.Body.Close() if response.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP request failed: %d", response.StatusCode) } body, err := ioutil.ReadAll(response.Body) if err != nil { return nil, err } return body, nil } // There should be no local tactics tacticsRecord, err := UseStoredTactics(storer, networkID) if err != nil { t.Fatalf("UseStoredTactics failed: %s", err) } if tacticsRecord != nil { t.Fatalf("unexpected tactics record") } // Helper to check that expected tactics parameters are returned checkParameters := func(r *Record) { p, err := parameters.NewClientParameters(nil) if err != nil { t.Fatalf("NewClientParameters failed: %s", err) } if r.Tactics.Probability != tacticsProbability { t.Fatalf("Unexpected probability: %f", r.Tactics.Probability) } // skipOnError is true for Psiphon clients counts, err := p.Set(r.Tag, true, r.Tactics.Parameters) if err != nil { t.Fatalf("Apply failed: %s", err) } if counts[0] != expectedApplyCount { t.Fatalf("Unexpected apply count: %d", counts[0]) } multipler := p.Get().Float(parameters.NetworkLatencyMultiplier) if multipler != tacticsNetworkLatencyMultiplier { t.Fatalf("Unexpected NetworkLatencyMultiplier: %v", multipler) } connectionWorkerPoolSize := p.Get().Int(parameters.ConnectionWorkerPoolSize) if connectionWorkerPoolSize != tacticsConnectionWorkerPoolSize { t.Fatalf("Unexpected ConnectionWorkerPoolSize: %v", connectionWorkerPoolSize) } limitTunnelProtocols := p.Get().TunnelProtocols(parameters.LimitTunnelProtocols) if !reflect.DeepEqual(limitTunnelProtocols, tacticsLimitTunnelProtocols) { t.Fatalf("Unexpected LimitTunnelProtocols: %v", limitTunnelProtocols) } } // Initial tactics request; will also run a speed test // Request should complete in < 1 second ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second) initialFetchTacticsRecord, err := FetchTactics( ctx, clientParams, storer, getNetworkID, apiParams, endPointProtocol, endPointRegion, encodedRequestPublicKey, encodedObfuscatedKey, roundTripper) cancelFunc() if err != nil { t.Fatalf("FetchTactics failed: %s", err) } if initialFetchTacticsRecord == nil { t.Fatalf("expected tactics record") } checkParameters(initialFetchTacticsRecord) // There should now be cached local tactics storedTacticsRecord, err := UseStoredTactics(storer, networkID) if err != nil { t.Fatalf("UseStoredTactics failed: %s", err) } if storedTacticsRecord == nil { t.Fatalf("expected stored tactics record") } // Strip monotonic component so comparisons will work initialFetchTacticsRecord.Expiry = initialFetchTacticsRecord.Expiry.Round(0) if !reflect.DeepEqual(initialFetchTacticsRecord, storedTacticsRecord) { t.Fatalf("tactics records are not identical:\n\n%#v\n\n%#v\n\n", initialFetchTacticsRecord, storedTacticsRecord) } checkParameters(storedTacticsRecord) // There should now be a speed test sample speedTestSamples, err := getSpeedTestSamples(storer, networkID) if err != nil { t.Fatalf("getSpeedTestSamples failed: %s", err) } if len(speedTestSamples) != 1 { t.Fatalf("unexpected speed test samples count") } // Wait for tactics to expire time.Sleep(1 * time.Second) storedTacticsRecord, err = UseStoredTactics(storer, networkID) if err != nil { t.Fatalf("UseStoredTactics failed: %s", err) } if storedTacticsRecord != nil { t.Fatalf("unexpected stored tactics record") } // Next fetch should merge empty payload as tag matches // TODO: inspect tactics response payload fetchTacticsRecord, err := FetchTactics( context.Background(), clientParams, storer, getNetworkID, apiParams, endPointProtocol, endPointRegion, encodedRequestPublicKey, encodedObfuscatedKey, roundTripper) if err != nil { t.Fatalf("FetchTactics failed: %s", err) } if fetchTacticsRecord == nil { t.Fatalf("expected tactics record") } if initialFetchTacticsRecord.Tag != fetchTacticsRecord.Tag { t.Fatalf("tags are not identical") } if initialFetchTacticsRecord.Expiry.Equal(fetchTacticsRecord.Expiry) { t.Fatalf("expiries unexpectedly identical") } if !reflect.DeepEqual(initialFetchTacticsRecord.Tactics, fetchTacticsRecord.Tactics) { t.Fatalf("tactics are not identical:\n\n%#v\n\n%#v\n\n", initialFetchTacticsRecord.Tactics, fetchTacticsRecord.Tactics) } checkParameters(fetchTacticsRecord) // Modify tactics configuration to change payload tacticsConnectionWorkerPoolSize = 6 tacticsLimitTunnelProtocols = protocol.TunnelProtocols{} jsonTacticsLimitTunnelProtocols = `` expectedApplyCount = 2 // Omitting LimitTunnelProtocols entirely tests this bug fix: When a new // tactics payload is obtained, all previous parameters should be cleared. // // In the bug, any previous parameters not in the new tactics were // incorrectly retained. In this test case, LimitTunnelProtocols is // omitted in the new tactics; if FetchTactics fails to clear the old // LimitTunnelProtocols then the test will fail. tacticsConfig = fmt.Sprintf( tacticsConfigTemplate, encodedRequestPublicKey, encodedRequestPrivateKey, encodedObfuscatedKey, tacticsProbability, tacticsNetworkLatencyMultiplier, tacticsConnectionWorkerPoolSize, jsonTacticsLimitTunnelProtocols, tacticsConnectionWorkerPoolSize+1) err = ioutil.WriteFile(configFileName, []byte(tacticsConfig), 0600) if err != nil { t.Fatalf("WriteFile failed: %s", err) } reloaded, err := server.Reload() if err != nil { t.Fatalf("Reload failed: %s", err) } if !reloaded { t.Fatalf("Server config failed to reload") } // Next fetch should return a different payload fetchTacticsRecord, err = FetchTactics( context.Background(), clientParams, storer, getNetworkID, apiParams, endPointProtocol, endPointRegion, encodedRequestPublicKey, encodedObfuscatedKey, roundTripper) if err != nil { t.Fatalf("FetchTactics failed: %s", err) } if fetchTacticsRecord == nil { t.Fatalf("expected tactics record") } if initialFetchTacticsRecord.Tag == fetchTacticsRecord.Tag { t.Fatalf("tags unexpectedly identical") } if initialFetchTacticsRecord.Expiry.Equal(fetchTacticsRecord.Expiry) { t.Fatalf("expires unexpectedly identical") } if reflect.DeepEqual(initialFetchTacticsRecord.Tactics, fetchTacticsRecord.Tactics) { t.Fatalf("tactics unexpectedly identical") } checkParameters(fetchTacticsRecord) // Exercise handshake transport of tactics // Wait for tactics to expire; handshake should renew time.Sleep(1 * time.Second) handshakeParams := common.APIParameters{ "client_platform": "P1", "client_version": "V1"} err = SetTacticsAPIParameters(clientParams, storer, networkID, handshakeParams) if err != nil { t.Fatalf("SetTacticsAPIParameters failed: %s", err) } tacticsPayload, err := server.GetTacticsPayload(clientGeoIPData, handshakeParams) if err != nil { t.Fatalf("GetTacticsPayload failed: %s", err) } handshakeTacticsRecord, err := HandleTacticsPayload(storer, networkID, tacticsPayload) if err != nil { t.Fatalf("HandleTacticsPayload failed: %s", err) } if handshakeTacticsRecord == nil { t.Fatalf("expected tactics record") } if fetchTacticsRecord.Tag != handshakeTacticsRecord.Tag { t.Fatalf("tags are not identical") } if fetchTacticsRecord.Expiry.Equal(handshakeTacticsRecord.Expiry) { t.Fatalf("expiries unexpectedly identical") } if !reflect.DeepEqual(fetchTacticsRecord.Tactics, handshakeTacticsRecord.Tactics) { t.Fatalf("tactics are not identical:\n\n%#v\n\n%#v\n\n", fetchTacticsRecord.Tactics, handshakeTacticsRecord.Tactics) } checkParameters(handshakeTacticsRecord) // Now there should be stored tactics storedTacticsRecord, err = UseStoredTactics(storer, networkID) if err != nil { t.Fatalf("UseStoredTactics failed: %s", err) } if storedTacticsRecord == nil { t.Fatalf("expected stored tactics record") } handshakeTacticsRecord.Expiry = handshakeTacticsRecord.Expiry.Round(0) if !reflect.DeepEqual(handshakeTacticsRecord, storedTacticsRecord) { t.Fatalf("tactics records are not identical:\n\n%#v\n\n%#v\n\n", handshakeTacticsRecord, storedTacticsRecord) } checkParameters(storedTacticsRecord) // Change network ID, should be no stored tactics networkID = "NETWORK2" storedTacticsRecord, err = UseStoredTactics(storer, networkID) if err != nil { t.Fatalf("UseStoredTactics failed: %s", err) } if storedTacticsRecord != nil { t.Fatalf("unexpected stored tactics record") } // Exercise speed test sample truncation maxSamples := clientParams.Get().Int(parameters.SpeedTestMaxSampleCount) for i := 0; i < maxSamples*2; i++ { response, err := MakeSpeedTestResponse(0, 0) if err != nil { t.Fatalf("MakeSpeedTestResponse failed: %s", err) } err = AddSpeedTestSample( clientParams, storer, networkID, "", differentEndPointProtocol, 100*time.Millisecond, nil, response) if err != nil { t.Fatalf("AddSpeedTestSample failed: %s", err) } } speedTestSamples, err = getSpeedTestSamples(storer, networkID) if err != nil { t.Fatalf("getSpeedTestSamples failed: %s", err) } if len(speedTestSamples) != maxSamples { t.Fatalf("unexpected speed test samples count") } for _, sample := range speedTestSamples { if sample.EndPointProtocol == endPointProtocol { t.Fatalf("unexpected old speed test sample") } } // Fetch should fail when using incorrect keys encodedIncorrectRequestPublicKey, _, encodedIncorrectObfuscatedKey, err := GenerateKeys() if err != nil { t.Fatalf("GenerateKeys failed: %s", err) } _, err = FetchTactics( context.Background(), clientParams, storer, getNetworkID, apiParams, endPointProtocol, endPointRegion, encodedIncorrectRequestPublicKey, encodedObfuscatedKey, roundTripper) if err == nil { t.Fatalf("FetchTactics succeeded unexpectedly with incorrect request key") } _, err = FetchTactics( context.Background(), clientParams, storer, getNetworkID, apiParams, endPointProtocol, endPointRegion, encodedRequestPublicKey, encodedIncorrectObfuscatedKey, roundTripper) if err == nil { t.Fatalf("FetchTactics succeeded unexpectedly with incorrect obfuscated key") } // When no keys are supplied, untunneled tactics requests are not supported, but // handshake tactics (GetTacticsPayload) should still work. tacticsConfig = fmt.Sprintf( tacticsConfigTemplate, "", "", "", tacticsProbability, tacticsNetworkLatencyMultiplier, tacticsConnectionWorkerPoolSize, jsonTacticsLimitTunnelProtocols, tacticsConnectionWorkerPoolSize+1) err = ioutil.WriteFile(configFileName, []byte(tacticsConfig), 0600) if err != nil { t.Fatalf("WriteFile failed: %s", err) } reloaded, err = server.Reload() if err != nil { t.Fatalf("Reload failed: %s", err) } if !reloaded { t.Fatalf("Server config failed to reload") } _, err = server.GetTacticsPayload(clientGeoIPData, handshakeParams) if err != nil { t.Fatalf("GetTacticsPayload failed: %s", err) } handled := server.HandleEndPoint(TACTICS_END_POINT, clientGeoIPData, nil, nil) if handled { t.Fatalf("HandleEndPoint unexpectedly handled request") } handled = server.HandleEndPoint(SPEED_TEST_END_POINT, clientGeoIPData, nil, nil) if handled { t.Fatalf("HandleEndPoint unexpectedly handled request") } // Test Listener tacticsProbability = 1.0 tacticsConfig = fmt.Sprintf( tacticsConfigTemplate, "", "", "", tacticsProbability, tacticsNetworkLatencyMultiplier, tacticsConnectionWorkerPoolSize, jsonTacticsLimitTunnelProtocols, tacticsConnectionWorkerPoolSize+1) err = ioutil.WriteFile(configFileName, []byte(tacticsConfig), 0600) if err != nil { t.Fatalf("WriteFile failed: %s", err) } reloaded, err = server.Reload() if err != nil { t.Fatalf("Reload failed: %s", err) } if !reloaded { t.Fatalf("Server config failed to reload") } listenerTestCases := []struct { description string geoIPLookup func(string) common.GeoIPData expectFragmentor bool }{ { "fragmented", listenerFragmentedGeoIP, true, }, { "unfragmented-region", listenerUnfragmentedGeoIPWrongRegion, false, }, { "unfragmented-ISP", listenerUnfragmentedGeoIPWrongISP, false, }, { "unfragmented-city", listenerUnfragmentedGeoIPWrongCity, false, }, } for _, testCase := range listenerTestCases { t.Run(testCase.description, func(t *testing.T) { tcpListener, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf(" net.Listen failed: %s", err) } tacticsListener := NewListener( tcpListener, server, listenerProtocol, testCase.geoIPLookup) clientConn, err := net.Dial("tcp", tacticsListener.Addr().String()) if err != nil { t.Fatalf(" net.Dial failed: %s", err) return } result := make(chan net.Conn, 1) go func() { serverConn, err := tacticsListener.Accept() if err == nil { result <- serverConn } }() timer := time.NewTimer(3 * time.Second) defer timer.Stop() select { case serverConn := <-result: _, isFragmentor := serverConn.(*fragmentor.Conn) if testCase.expectFragmentor && !isFragmentor { t.Fatalf("unexpected non-fragmentor: %T", serverConn) } else if !testCase.expectFragmentor && isFragmentor { t.Fatalf("unexpected fragmentor: %T", serverConn) } serverConn.Close() case <-timer.C: t.Fatalf("timeout before expected accepted connection") } clientConn.Close() tacticsListener.Close() }) } // TODO: test replay attack defence // TODO: test Server.Validate with invalid tactics configurations } type testStorer struct { tacticsRecords map[string][]byte speedTestSampleRecords map[string][]byte } func newTestStorer() *testStorer { return &testStorer{ tacticsRecords: make(map[string][]byte), speedTestSampleRecords: make(map[string][]byte), } } func (s *testStorer) SetTacticsRecord(networkID string, record []byte) error { s.tacticsRecords[networkID] = record return nil } func (s *testStorer) GetTacticsRecord(networkID string) ([]byte, error) { return s.tacticsRecords[networkID], nil } func (s *testStorer) SetSpeedTestSamplesRecord(networkID string, record []byte) error { s.speedTestSampleRecords[networkID] = record return nil } func (s *testStorer) GetSpeedTestSamplesRecord(networkID string) ([]byte, error) { return s.speedTestSampleRecords[networkID], nil } type testLogger struct { } func newTestLogger() *testLogger { return &testLogger{} } func (l *testLogger) WithTrace() common.LogTrace { return &testLoggerTrace{trace: stacktrace.GetParentFunctionName()} } func (l *testLogger) WithTraceFields(fields common.LogFields) common.LogTrace { return &testLoggerTrace{ trace: stacktrace.GetParentFunctionName(), fields: fields, } } func (l *testLogger) LogMetric(metric string, fields common.LogFields) { fmt.Printf("METRIC: %s: fields=%+v\n", metric, fields) } type testLoggerTrace struct { trace string fields common.LogFields } func (l *testLoggerTrace) log(priority, message string) { fmt.Printf("%s: %s: %s fields=%+v\n", priority, l.trace, message, l.fields) } func (l *testLoggerTrace) Debug(args ...interface{}) { l.log("DEBUG", fmt.Sprint(args...)) } func (l *testLoggerTrace) Info(args ...interface{}) { l.log("INFO", fmt.Sprint(args...)) } func (l *testLoggerTrace) Warning(args ...interface{}) { l.log("WARNING", fmt.Sprint(args...)) } func (l *testLoggerTrace) Error(args ...interface{}) { l.log("ERROR", fmt.Sprint(args...)) }