Просмотр исходного кода

Add per-protocol/initial Liveness Test tactics parameters

Amir Khan 10 месяцев назад
Родитель
Сommit
2d5f73419f

+ 59 - 0
psiphon/common/parameters/livenessTest.go

@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2025, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+
+const LIVENESS_ANY = ""
+
+// LivenessTestSpec specifies the parameters for a Liveness Test.
+type LivenessTestSpec struct {
+	MinUpstreamBytes   int
+	MaxUpstreamBytes   int
+	MinDownstreamBytes int
+	MaxDownstreamBytes int
+}
+
+// LivenessTestSpecs is a map of tunnel protocol patterns to Liveness Test spec.
+// Patterns may contain the '*' wildcard.
+type LivenessTestSpecs map[string]*LivenessTestSpec
+
+func (l LivenessTestSpecs) Validate() error {
+	// Check that there is a LIVENESS_ANY entry.
+	if _, ok := l[LIVENESS_ANY]; !ok {
+		return errors.TraceNew("missing LIVENESS_ANY entry")
+	}
+	// Check that all entries are well-formed.
+	for _, spec := range l {
+		if spec.MinUpstreamBytes < 0 {
+			return errors.TraceNew("invalid MinUpstreamBytes")
+		}
+		if spec.MaxUpstreamBytes < 0 {
+			return errors.TraceNew("invalid MaxUpstreamBytes")
+		}
+		if spec.MinDownstreamBytes < 0 {
+			return errors.TraceNew("invalid MinDownstreamBytes")
+		}
+		if spec.MaxDownstreamBytes < 0 {
+			return errors.TraceNew("invalid MaxDownstreamBytes")
+		}
+	}
+	return nil
+}

+ 19 - 0
psiphon/common/parameters/parameters.go

@@ -213,6 +213,8 @@ const (
 	MeekAlternateContentTypeProbability                = "MeekAlternateContentTypeProbability"
 	MeekAlternateContentTypeProbability                = "MeekAlternateContentTypeProbability"
 	TransformHostNameProbability                       = "TransformHostNameProbability"
 	TransformHostNameProbability                       = "TransformHostNameProbability"
 	PickUserAgentProbability                           = "PickUserAgentProbability"
 	PickUserAgentProbability                           = "PickUserAgentProbability"
+	InitialLivenessTest                                = "InitialLivenessTest"
+	LivenessTest                                       = "LivenessTest"
 	LivenessTestMinUpstreamBytes                       = "LivenessTestMinUpstreamBytes"
 	LivenessTestMinUpstreamBytes                       = "LivenessTestMinUpstreamBytes"
 	LivenessTestMaxUpstreamBytes                       = "LivenessTestMaxUpstreamBytes"
 	LivenessTestMaxUpstreamBytes                       = "LivenessTestMaxUpstreamBytes"
 	LivenessTestMinDownstreamBytes                     = "LivenessTestMinDownstreamBytes"
 	LivenessTestMinDownstreamBytes                     = "LivenessTestMinDownstreamBytes"
@@ -741,6 +743,8 @@ var defaultParameters = map[string]struct {
 	TransformHostNameProbability: {value: 0.5, minimum: 0.0},
 	TransformHostNameProbability: {value: 0.5, minimum: 0.0},
 	PickUserAgentProbability:     {value: 0.5, minimum: 0.0},
 	PickUserAgentProbability:     {value: 0.5, minimum: 0.0},
 
 
+	InitialLivenessTest:            {value: make(LivenessTestSpecs)},
+	LivenessTest:                   {value: make(LivenessTestSpecs)},
 	LivenessTestMinUpstreamBytes:   {value: 0, minimum: 0},
 	LivenessTestMinUpstreamBytes:   {value: 0, minimum: 0},
 	LivenessTestMaxUpstreamBytes:   {value: 0, minimum: 0},
 	LivenessTestMaxUpstreamBytes:   {value: 0, minimum: 0},
 	LivenessTestMinDownstreamBytes: {value: 0, minimum: 0},
 	LivenessTestMinDownstreamBytes: {value: 0, minimum: 0},
@@ -1665,6 +1669,15 @@ func (p *Parameters) Set(
 					}
 					}
 					return nil, errors.Trace(err)
 					return nil, errors.Trace(err)
 				}
 				}
+
+			case LivenessTestSpecs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 			}
 
 
 			// Enforce any minimums. Assumes defaultParameters[name]
 			// Enforce any minimums. Assumes defaultParameters[name]
@@ -2284,3 +2297,9 @@ func (p ParametersAccessor) InproxyTrafficShapingParameters(
 	p.snapshot.getValue(name, &value)
 	p.snapshot.getValue(name, &value)
 	return value
 	return value
 }
 }
+
+func (p ParametersAccessor) LivenessTest(name string) LivenessTestSpecs {
+	value := make(LivenessTestSpecs)
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 5 - 0
psiphon/common/parameters/parameters_test.go

@@ -225,6 +225,11 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("ConjureTransports returned %+v expected %+v", g, v)
 				t.Fatalf("ConjureTransports returned %+v expected %+v", g, v)
 			}
 			}
+		case LivenessTestSpecs:
+			g := p.Get().LivenessTest(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("LivenessTestSpecs returned %+v expected %+v", g, v)
+			}
 		default:
 		default:
 			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 		}
 		}

+ 10 - 0
psiphon/config.go

@@ -819,6 +819,8 @@ type Config struct {
 
 
 	// LivenessTestMinUpstreamBytes and other LivenessTest fields are for
 	// LivenessTestMinUpstreamBytes and other LivenessTest fields are for
 	// testing purposes.
 	// testing purposes.
+	InitialLivenessTest            parameters.LivenessTestSpecs
+	LivenessTest                   parameters.LivenessTestSpecs
 	LivenessTestMinUpstreamBytes   *int
 	LivenessTestMinUpstreamBytes   *int
 	LivenessTestMaxUpstreamBytes   *int
 	LivenessTestMaxUpstreamBytes   *int
 	LivenessTestMinDownstreamBytes *int
 	LivenessTestMinDownstreamBytes *int
@@ -2114,6 +2116,14 @@ func (config *Config) makeConfigParameters() map[string]interface{} {
 		applyParameters[parameters.ObfuscatedSSHMaxPadding] = *config.ObfuscatedSSHMaxPadding
 		applyParameters[parameters.ObfuscatedSSHMaxPadding] = *config.ObfuscatedSSHMaxPadding
 	}
 	}
 
 
+	if len(config.InitialLivenessTest) > 0 {
+		applyParameters[parameters.InitialLivenessTest] = config.InitialLivenessTest
+	}
+
+	if len(config.LivenessTest) > 0 {
+		applyParameters[parameters.LivenessTest] = config.LivenessTest
+	}
+
 	if config.LivenessTestMinUpstreamBytes != nil {
 	if config.LivenessTestMinUpstreamBytes != nil {
 		applyParameters[parameters.LivenessTestMinUpstreamBytes] = *config.LivenessTestMinUpstreamBytes
 		applyParameters[parameters.LivenessTestMinUpstreamBytes] = *config.LivenessTestMinUpstreamBytes
 	}
 	}

+ 16 - 4
psiphon/server/server_test.go

@@ -1561,10 +1561,22 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		}
 		}
 
 
 		if runConfig.forceLivenessTest {
 		if runConfig.forceLivenessTest {
-			applyParameters[parameters.LivenessTestMinUpstreamBytes] = livenessTestSize
-			applyParameters[parameters.LivenessTestMaxUpstreamBytes] = livenessTestSize
-			applyParameters[parameters.LivenessTestMinDownstreamBytes] = livenessTestSize
-			applyParameters[parameters.LivenessTestMaxDownstreamBytes] = livenessTestSize
+			applyParameters[parameters.InitialLivenessTest] = parameters.LivenessTestSpecs{
+				"": &parameters.LivenessTestSpec{
+					MinUpstreamBytes:   livenessTestSize,
+					MaxUpstreamBytes:   livenessTestSize,
+					MinDownstreamBytes: livenessTestSize,
+					MaxDownstreamBytes: livenessTestSize,
+				},
+			}
+			applyParameters[parameters.LivenessTest] = parameters.LivenessTestSpecs{
+				"": &parameters.LivenessTestSpec{
+					MinUpstreamBytes:   livenessTestSize,
+					MaxUpstreamBytes:   livenessTestSize,
+					MinDownstreamBytes: livenessTestSize,
+					MaxDownstreamBytes: livenessTestSize,
+				},
+			}
 		}
 		}
 
 
 		if runConfig.doPruneServerEntries {
 		if runConfig.doPruneServerEntries {

+ 67 - 11
psiphon/tunnel.go

@@ -31,6 +31,7 @@ import (
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"slices"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -49,6 +50,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/refraction"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/refraction"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/wildcard"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/transferstats"
 	"github.com/fxamacker/cbor/v2"
 	"github.com/fxamacker/cbor/v2"
 )
 )
@@ -801,10 +803,7 @@ func dialTunnel(
 	rateLimits := p.RateLimits(parameters.TunnelRateLimits)
 	rateLimits := p.RateLimits(parameters.TunnelRateLimits)
 	obfuscatedSSHMinPadding := p.Int(parameters.ObfuscatedSSHMinPadding)
 	obfuscatedSSHMinPadding := p.Int(parameters.ObfuscatedSSHMinPadding)
 	obfuscatedSSHMaxPadding := p.Int(parameters.ObfuscatedSSHMaxPadding)
 	obfuscatedSSHMaxPadding := p.Int(parameters.ObfuscatedSSHMaxPadding)
-	livenessTestMinUpstreamBytes := p.Int(parameters.LivenessTestMinUpstreamBytes)
-	livenessTestMaxUpstreamBytes := p.Int(parameters.LivenessTestMaxUpstreamBytes)
-	livenessTestMinDownstreamBytes := p.Int(parameters.LivenessTestMinDownstreamBytes)
-	livenessTestMaxDownstreamBytes := p.Int(parameters.LivenessTestMaxDownstreamBytes)
+	livenessTestSpec := getLivenessTestSpec(p, dialParams.TunnelProtocol, dialParams.EstablishedTunnelsCount)
 	burstUpstreamTargetBytes := int64(p.Int(parameters.ClientBurstUpstreamTargetBytes))
 	burstUpstreamTargetBytes := int64(p.Int(parameters.ClientBurstUpstreamTargetBytes))
 	burstUpstreamDeadline := p.Duration(parameters.ClientBurstUpstreamDeadline)
 	burstUpstreamDeadline := p.Duration(parameters.ClientBurstUpstreamDeadline)
 	burstDownstreamTargetBytes := int64(p.Int(parameters.ClientBurstDownstreamTargetBytes))
 	burstDownstreamTargetBytes := int64(p.Int(parameters.ClientBurstDownstreamTargetBytes))
@@ -1194,7 +1193,7 @@ func dialTunnel(
 
 
 			sshClient = ssh.NewClient(sshClientConn, sshChannels, noRequests)
 			sshClient = ssh.NewClient(sshClientConn, sshChannels, noRequests)
 
 
-			if livenessTestMaxUpstreamBytes > 0 || livenessTestMaxDownstreamBytes > 0 {
+			if livenessTestSpec.MaxUpstreamBytes > 0 || livenessTestSpec.MaxDownstreamBytes > 0 {
 
 
 				// When configured, perform a liveness test which sends and
 				// When configured, perform a liveness test which sends and
 				// receives bytes through the tunnel to ensure the tunnel had
 				// receives bytes through the tunnel to ensure the tunnel had
@@ -1208,8 +1207,7 @@ func dialTunnel(
 
 
 				metrics, err = performLivenessTest(
 				metrics, err = performLivenessTest(
 					sshClient,
 					sshClient,
-					livenessTestMinUpstreamBytes, livenessTestMaxUpstreamBytes,
-					livenessTestMinDownstreamBytes, livenessTestMaxDownstreamBytes,
+					livenessTestSpec,
 					dialParams.LivenessTestSeed)
 					dialParams.LivenessTestSeed)
 
 
 				// Skip notice when cancelling.
 				// Skip notice when cancelling.
@@ -1632,8 +1630,7 @@ type livenessTestMetrics struct {
 
 
 func performLivenessTest(
 func performLivenessTest(
 	sshClient *ssh.Client,
 	sshClient *ssh.Client,
-	minUpstreamBytes, maxUpstreamBytes int,
-	minDownstreamBytes, maxDownstreamBytes int,
+	spec *parameters.LivenessTestSpec,
 	livenessTestPRNGSeed *prng.Seed) (*livenessTestMetrics, error) {
 	livenessTestPRNGSeed *prng.Seed) (*livenessTestMetrics, error) {
 
 
 	metrics := new(livenessTestMetrics)
 	metrics := new(livenessTestMetrics)
@@ -1644,8 +1641,8 @@ func performLivenessTest(
 
 
 	PRNG := prng.NewPRNGWithSeed(livenessTestPRNGSeed)
 	PRNG := prng.NewPRNGWithSeed(livenessTestPRNGSeed)
 
 
-	metrics.UpstreamBytes = PRNG.Range(minUpstreamBytes, maxUpstreamBytes)
-	metrics.DownstreamBytes = PRNG.Range(minDownstreamBytes, maxDownstreamBytes)
+	metrics.UpstreamBytes = PRNG.Range(spec.MinUpstreamBytes, spec.MaxUpstreamBytes)
+	metrics.DownstreamBytes = PRNG.Range(spec.MinDownstreamBytes, spec.MaxDownstreamBytes)
 
 
 	request := &protocol.RandomStreamRequest{
 	request := &protocol.RandomStreamRequest{
 		UpstreamBytes:   metrics.UpstreamBytes,
 		UpstreamBytes:   metrics.UpstreamBytes,
@@ -2257,3 +2254,62 @@ func sendStats(tunnel *Tunnel) bool {
 
 
 	return err == nil
 	return err == nil
 }
 }
+
+// getLivenessTestSpec returns the LivenessTestSpec for the given tunnel protocol.
+func getLivenessTestSpec(
+	p parameters.ParametersAccessor,
+	tunnelProtocol string,
+	establishedTunnelsCount int) *parameters.LivenessTestSpec {
+
+	// matchingSpec returns the first matching LivenessTestSpec for the given
+	// tunnelProtocol.
+	matchingSpec := func(
+		spec parameters.LivenessTestSpecs,
+		tunnelProtocol string) *parameters.LivenessTestSpec {
+		if len(spec) != 0 {
+			// Sort the patterns by length, longest first, so that the most specific
+			// match is found first.
+			patterns := make([]string, 0, len(spec))
+			for p := range spec {
+				patterns = append(patterns, p)
+			}
+			slices.SortFunc(patterns, func(i, j string) int {
+				return len(j) - len(i)
+			})
+			// Find the first and longest pattern that matches the tunnel protocol.
+			for _, p := range patterns {
+				if wildcard.Match(p, tunnelProtocol) {
+					return spec[p]
+				}
+			}
+			// Default to LIVENESS_ANY if no pattern matches.
+			if v, ok := spec[parameters.LIVENESS_ANY]; ok {
+				return v
+			}
+		}
+		return nil
+	}
+
+	// If EstablishedTunnelsCount is 0, attempt the InitialLivenessTest specification.
+	// If no match is found, or if this is a subsequent connection, proceed to LivenessTest.
+
+	if establishedTunnelsCount == 0 {
+		spec := matchingSpec(p.LivenessTest(parameters.InitialLivenessTest), tunnelProtocol)
+		if spec != nil {
+			return spec
+		}
+	}
+
+	spec := matchingSpec(p.LivenessTest(parameters.LivenessTest), tunnelProtocol)
+	if spec != nil {
+		return spec
+	}
+
+	// Return legacy values as a last resort.
+	return &parameters.LivenessTestSpec{
+		MinUpstreamBytes:   p.Int(parameters.LivenessTestMinUpstreamBytes),
+		MaxUpstreamBytes:   p.Int(parameters.LivenessTestMaxUpstreamBytes),
+		MinDownstreamBytes: p.Int(parameters.LivenessTestMinDownstreamBytes),
+		MaxDownstreamBytes: p.Int(parameters.LivenessTestMaxDownstreamBytes),
+	}
+}