Procházet zdrojové kódy

Add ValidateAndGetProtobufBaseParams

Rod Hynes před 3 měsíci
rodič
revize
55612895f2

+ 89 - 50
psiphon/server/api.go

@@ -35,6 +35,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/inproxy"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
+	pb "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/pb/psiphond"
 	"github.com/fxamacker/cbor/v2"
 )
 
@@ -250,7 +251,7 @@ func handshakeAPIRequestHandler(
 
 	// Note: ignoring legacy "known_servers" params
 
-	err := validateRequestParams(support.Config, params, handshakeRequestParams)
+	err := validateRequestParams(params, handshakeRequestParams)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -617,7 +618,7 @@ func connectedAPIRequestHandler(
 	sshClient *sshClient,
 	params common.APIParameters) ([]byte, error) {
 
-	err := validateRequestParams(support.Config, params, connectedRequestParams)
+	err := validateRequestParams(params, connectedRequestParams)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -760,7 +761,7 @@ func statusAPIRequestHandler(
 	sshClient *sshClient,
 	params common.APIParameters) ([]byte, error) {
 
-	err := validateRequestParams(support.Config, params, statusRequestParams)
+	err := validateRequestParams(params, statusRequestParams)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -835,7 +836,7 @@ func statusAPIRequestHandler(
 				}
 			}
 
-			err := validateRequestParams(support.Config, remoteServerListStat, remoteServerListStatParams)
+			err := validateRequestParams(remoteServerListStat, remoteServerListStatParams)
 			if err != nil {
 				// Occasionally, clients may send corrupt persistent stat data. Do not
 				// fail the status request, as this will lead to endless retries.
@@ -873,7 +874,7 @@ func statusAPIRequestHandler(
 
 		for _, failedTunnelStat := range failedTunnelStats {
 
-			err := validateRequestParams(support.Config, failedTunnelStat, failedTunnelStatParams)
+			err := validateRequestParams(failedTunnelStat, failedTunnelStatParams)
 			if err != nil {
 				// Occasionally, clients may send corrupt persistent stat data. Do not
 				// fail the status request, as this will lead to endless retries.
@@ -1061,9 +1062,9 @@ var tacticsRequestParams = append(
 		tacticsParams...),
 	baseAndDialParams...)
 
-func getTacticsAPIParameterValidator(config *Config) common.APIParameterValidator {
+func getTacticsAPIParameterValidator() common.APIParameterValidator {
 	return func(params common.APIParameters) error {
-		return validateRequestParams(config, params, tacticsRequestParams)
+		return validateRequestParams(params, tacticsRequestParams)
 	}
 }
 
@@ -1092,9 +1093,9 @@ var inproxyBrokerRequestParams = append(
 		tacticsParams...),
 	baseParams...)
 
-func getInproxyBrokerAPIParameterValidator(config *Config) common.APIParameterValidator {
+func getInproxyBrokerAPIParameterValidator() common.APIParameterValidator {
 	return func(params common.APIParameters) error {
-		return validateRequestParams(config, params, inproxyBrokerRequestParams)
+		return validateRequestParams(params, inproxyBrokerRequestParams)
 	}
 }
 
@@ -1140,9 +1141,9 @@ var dslRequestParams = append(
 		tacticsParams...),
 	baseParams...)
 
-func getDSLAPIParameterValidator(config *Config) common.APIParameterValidator {
+func getDSLAPIParameterValidator() common.APIParameterValidator {
 	return func(params common.APIParameters) error {
-		return validateRequestParams(config, params, dslRequestParams)
+		return validateRequestParams(params, dslRequestParams)
 	}
 }
 
@@ -1163,12 +1164,53 @@ func getDSLAPIParameterLogFieldFormatter() common.APIParameterLogFieldFormatter
 	}
 }
 
+// TODO: add session_id to baseParams?
+var sessionBaseParams = append(
+	[]requestParamSpec{
+		{"session_id", isHexDigits, 0}},
+	baseParams...)
+
+// ValidateAndGetProtobufBaseParams unpacks and validates the input base API
+// parameters and returns the parameters in a pb/psiphond.BaseParams struct.
+//
+// Not all fields in pb/psiphond.BaseParams are populated; some fields, such
+// as GeoIP fields and authorized_access_types, are not accepted from clients
+// and are populated by the server; other fields, such as last_connected, are
+// only sent by the client for certain API requests.
+//
+// Note that the underlying protobuf converter code may intentionally panic in
+// unexpected cases such as a mismatch in log field and protobuf struct field
+// types.
+func ValidateAndGetProtobufBaseParams(
+	packedParams protocol.PackedAPIParameters) (*pb.BaseParams, error) {
+
+	params, err := protocol.DecodePackedAPIParameters(packedParams)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	expectedParams := sessionBaseParams
+
+	err = validateRequestParams(
+		params, expectedParams)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	logFields := getRequestLogFields(
+		"", "", "", GeoIPData{}, nil, params, expectedParams)
+
+	baseParams := protobufPopulateBaseParams(logFields)
+
+	return baseParams, nil
+}
+
 // requestParamSpec defines a request parameter. Each param is expected to be
 // a string, unless requestParamArray is specified, in which case an array of
 // strings is expected.
 type requestParamSpec struct {
 	name      string
-	validator func(*Config, string) bool
+	validator func(string) bool
 	flags     uint32
 }
 
@@ -1353,7 +1395,6 @@ var baseAndDialParams = append(
 	inproxyDialParams...)
 
 func validateRequestParams(
-	config *Config,
 	params common.APIParameters,
 	expectedParams []requestParamSpec) error {
 
@@ -1368,7 +1409,7 @@ func validateRequestParams(
 		var err error
 		switch {
 		case expectedParam.flags&requestParamArray != 0:
-			err = validateStringArrayRequestParam(config, expectedParam, value)
+			err = validateStringArrayRequestParam(expectedParam, value)
 		case expectedParam.flags&requestParamJSON != 0:
 			// No validation: the JSON already unmarshalled; the parameter
 			// user will validate that the JSON contains the expected
@@ -1379,7 +1420,7 @@ func validateRequestParams(
 			// and rejects the parameter.
 
 		default:
-			err = validateStringRequestParam(config, expectedParam, value)
+			err = validateStringRequestParam(expectedParam, value)
 		}
 		if err != nil {
 			return errors.Trace(err)
@@ -1421,7 +1462,6 @@ func copyUpdateOnConnectedParams(params common.APIParameters) common.APIParamete
 }
 
 func validateStringRequestParam(
-	config *Config,
 	expectedParam requestParamSpec,
 	value interface{}) error {
 
@@ -1429,14 +1469,13 @@ func validateStringRequestParam(
 	if !ok {
 		return errors.Tracef("unexpected string param type: %s", expectedParam.name)
 	}
-	if !expectedParam.validator(config, strValue) {
+	if !expectedParam.validator(strValue) {
 		return errors.Tracef("invalid param: %s: %s", expectedParam.name, strValue)
 	}
 	return nil
 }
 
 func validateStringArrayRequestParam(
-	config *Config,
 	expectedParam requestParamSpec,
 	value interface{}) error {
 
@@ -1445,7 +1484,7 @@ func validateStringArrayRequestParam(
 		return errors.Tracef("unexpected array param type: %s", expectedParam.name)
 	}
 	for _, value := range arrayValue {
-		err := validateStringRequestParam(config, expectedParam, value)
+		err := validateStringRequestParam(expectedParam, value)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -1565,7 +1604,7 @@ func getRequestLogFields(
 
 			case "meek_dial_address":
 				host, _, _ := net.SplitHostPort(strValue)
-				if isIPAddress(nil, host) {
+				if isIPAddress(host) {
 					name = "meek_dial_ip_address"
 				} else {
 					name = "meek_dial_domain"
@@ -1832,76 +1871,76 @@ func normalizeClientPlatform(clientPlatform string) string {
 	return CLIENT_PLATFORM_WINDOWS
 }
 
-func isAnyString(config *Config, value string) bool {
-	return true
-}
-
 func isMobileClientPlatform(clientPlatform string) bool {
 	normalizedClientPlatform := normalizeClientPlatform(clientPlatform)
 	return normalizedClientPlatform == CLIENT_PLATFORM_ANDROID ||
 		normalizedClientPlatform == CLIENT_PLATFORM_IOS
 }
 
+func isAnyString(value string) bool {
+	return true
+}
+
 // Input validators follow the legacy validations rules in psi_web.
 
-func isSponsorID(config *Config, value string) bool {
-	return len(value) == SPONSOR_ID_LENGTH && isHexDigits(config, value)
+func isSponsorID(value string) bool {
+	return len(value) == SPONSOR_ID_LENGTH && isHexDigits(value)
 }
 
-func isHexDigits(_ *Config, value string) bool {
+func isHexDigits(value string) bool {
 	// Allows both uppercase in addition to lowercase, for legacy support.
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		return !unicode.Is(unicode.ASCII_Hex_Digit, c)
 	})
 }
 
-func isBase64String(_ *Config, value string) bool {
+func isBase64String(value string) bool {
 	_, err := base64.StdEncoding.DecodeString(value)
 	return err == nil
 }
 
-func isUnpaddedBase64String(_ *Config, value string) bool {
+func isUnpaddedBase64String(value string) bool {
 	_, err := base64.RawStdEncoding.DecodeString(value)
 	return err == nil
 }
 
-func isDigits(_ *Config, value string) bool {
+func isDigits(value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		return c < '0' || c > '9'
 	})
 }
 
-func isIntString(_ *Config, value string) bool {
+func isIntString(value string) bool {
 	_, err := strconv.Atoi(value)
 	return err == nil
 }
 
-func isFloatString(_ *Config, value string) bool {
+func isFloatString(value string) bool {
 	_, err := strconv.ParseFloat(value, 64)
 	return err == nil
 }
 
-func isClientPlatform(_ *Config, value string) bool {
+func isClientPlatform(value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		// Note: stricter than psi_web's Python string.whitespace
 		return unicode.Is(unicode.White_Space, c)
 	})
 }
 
-func isRelayProtocol(_ *Config, value string) bool {
+func isRelayProtocol(value string) bool {
 	return common.Contains(protocol.SupportedTunnelProtocols, value)
 }
 
-func isBooleanFlag(_ *Config, value string) bool {
+func isBooleanFlag(value string) bool {
 	return value == "0" || value == "1"
 }
 
-func isUpstreamProxyType(_ *Config, value string) bool {
+func isUpstreamProxyType(value string) bool {
 	value = strings.ToLower(value)
 	return value == "http" || value == "socks5" || value == "socks4a"
 }
 
-func isRegionCode(_ *Config, value string) bool {
+func isRegionCode(value string) bool {
 	if len(value) != 2 {
 		return false
 	}
@@ -1910,16 +1949,16 @@ func isRegionCode(_ *Config, value string) bool {
 	})
 }
 
-func isDialAddress(_ *Config, value string) bool {
+func isDialAddress(value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address
 	parts := strings.Split(value, ":")
 	if len(parts) != 2 {
 		return false
 	}
-	if !isIPAddress(nil, parts[0]) && !isDomain(nil, parts[0]) {
+	if !isIPAddress(parts[0]) && !isDomain(parts[0]) {
 		return false
 	}
-	if !isDigits(nil, parts[1]) {
+	if !isDigits(parts[1]) {
 		return false
 	}
 	_, err := strconv.Atoi(parts[1])
@@ -1930,13 +1969,13 @@ func isDialAddress(_ *Config, value string) bool {
 	return true
 }
 
-func isIPAddress(_ *Config, value string) bool {
+func isIPAddress(value string) bool {
 	return net.ParseIP(value) != nil
 }
 
 var isDomainRegex = regexp.MustCompile(`[a-zA-Z\d-]{1,63}$`)
 
-func isDomain(_ *Config, value string) bool {
+func isDomain(value string) bool {
 
 	// From: http://stackoverflow.com/questions/2532053/validate-a-hostname-string
 	//
@@ -1963,32 +2002,32 @@ func isDomain(_ *Config, value string) bool {
 	return true
 }
 
-func isHostHeader(_ *Config, value string) bool {
+func isHostHeader(value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address and ":<port>" is optional
 	if strings.Contains(value, ":") {
-		return isDialAddress(nil, value)
+		return isDialAddress(value)
 	}
-	return isIPAddress(nil, value) || isDomain(nil, value)
+	return isIPAddress(value) || isDomain(value)
 }
 
-func isServerEntrySource(_ *Config, value string) bool {
+func isServerEntrySource(value string) bool {
 	return common.ContainsWildcard(protocol.SupportedServerEntrySources, value)
 }
 
 var isISO8601DateRegex = regexp.MustCompile(
 	`(?P<year>[0-9]{4})-(?P<month>[0-9]{1,2})-(?P<day>[0-9]{1,2})T(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2}):(?P<second>[0-9]{2})(\.(?P<fraction>[0-9]+))?(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))`)
 
-func isISO8601Date(_ *Config, value string) bool {
+func isISO8601Date(value string) bool {
 	return isISO8601DateRegex.Match([]byte(value))
 }
 
-func isLastConnected(_ *Config, value string) bool {
-	return value == "None" || isISO8601Date(nil, value)
+func isLastConnected(value string) bool {
+	return value == "None" || isISO8601Date(value)
 }
 
 const geohashAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"
 
-func isGeoHashString(_ *Config, value string) bool {
+func isGeoHashString(value string) bool {
 	// Verify that the string is between 1 and 12 characters long
 	// and contains only characters from the geohash alphabet.
 	if len(value) < 1 || len(value) > 12 {

+ 85 - 0
psiphon/server/api_test.go

@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2025, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+func TestValidateAndGetProtobufBaseParams(t *testing.T) {
+
+	params := make(common.APIParameters)
+
+	params["session_id"] = prng.HexString(8)
+	params["propagation_channel_id"] = strings.ToUpper(prng.HexString(8))
+	params["sponsor_id"] = strings.ToUpper(prng.HexString(8))
+	params["client_version"] = "1"
+	params["client_platform"] = prng.HexString(8)
+	params["client_features"] = []any{prng.HexString(8), prng.HexString(8)}
+	params["client_build_rev"] = prng.HexString(8)
+	params["device_region"] = "US"
+	params["device_location"] = "gzzzz"
+	params["egress_region"] = "US"
+	params["network_type"] = prng.HexString(8)
+	params["applied_tactics_tag"] = prng.HexString(8)
+
+	packedParams, err := protocol.EncodePackedAPIParameters(params)
+	if err != nil {
+		t.Fatalf("protocol.EncodePackedAPIParameters failed: %v", err)
+	}
+
+	protoBaseParams, err := ValidateAndGetProtobufBaseParams(packedParams)
+	if err != nil {
+		t.Fatalf("ValidateAndGetProtobufBaseParams failed: %v", err)
+	}
+
+	if protoBaseParams.ClientAsn != nil ||
+		protoBaseParams.ClientAso != nil ||
+		protoBaseParams.ClientCity != nil ||
+		protoBaseParams.ClientIsp != nil ||
+		protoBaseParams.ClientRegion != nil ||
+		protoBaseParams.LastConnected != nil ||
+		protoBaseParams.AuthorizedAccessTypes != nil {
+
+		t.Fatalf("unexpected non-nil field: %+v", protoBaseParams)
+	}
+
+	if *protoBaseParams.SessionId != params["session_id"].(string) ||
+		*protoBaseParams.PropagationChannelId != params["propagation_channel_id"].(string) ||
+		*protoBaseParams.SponsorId != params["sponsor_id"].(string) ||
+		fmt.Sprintf("%+v", *protoBaseParams.ClientVersion) != fmt.Sprintf("%+v", params["client_version"]) ||
+		*protoBaseParams.ClientPlatform != params["client_platform"].(string) ||
+		fmt.Sprintf("%+v", protoBaseParams.ClientFeatures) != fmt.Sprintf("%+v", params["client_features"]) ||
+		*protoBaseParams.ClientBuildRev != params["client_build_rev"].(string) ||
+		*protoBaseParams.DeviceRegion != params["device_region"].(string) ||
+		*protoBaseParams.DeviceLocation != params["device_location"].(string) ||
+		*protoBaseParams.EgressRegion != params["egress_region"].(string) ||
+		*protoBaseParams.NetworkType != params["network_type"].(string) ||
+		*protoBaseParams.AppliedTacticsTag != params["applied_tactics_tag"].(string) {
+
+		t.Fatalf("unexpected field: %+v", protoBaseParams)
+	}
+}

+ 1 - 1
psiphon/server/meek.go

@@ -357,7 +357,7 @@ func NewMeekServer(
 				AllowDomainFrontedDestinations: meekServer.inproxyBrokerAllowDomainFrontedDestinations,
 				AllowMatch:                     meekServer.inproxyBrokerAllowMatch,
 				LookupGeoIP:                    lookupGeoIPData,
-				APIParameterValidator:          getInproxyBrokerAPIParameterValidator(support.Config),
+				APIParameterValidator:          getInproxyBrokerAPIParameterValidator(),
 				APIParameterLogFieldFormatter:  getInproxyBrokerAPIParameterLogFieldFormatter(),
 				IsValidServerEntryTag:          support.PsinetDatabase.IsValidServerEntryTag,
 				GetTacticsPayload:              meekServer.inproxyBrokerGetTacticsPayload,

+ 2 - 2
psiphon/server/services.go

@@ -157,7 +157,7 @@ func RunServices(configJSON []byte) (retErr error) {
 			HostKeyFilename:               config.DSLRelayHostKeyFilename,
 			GetServiceAddress:             dslMakeGetServiceAddress(support),
 			HostID:                        config.HostID,
-			APIParameterValidator:         getDSLAPIParameterValidator(config),
+			APIParameterValidator:         getDSLAPIParameterValidator(),
 			APIParameterLogFieldFormatter: getDSLAPIParameterLogFieldFormatter(),
 		})
 		if err != nil {
@@ -663,7 +663,7 @@ func NewSupportServices(config *Config) (*SupportServices, error) {
 	tacticsServer, err := tactics.NewServer(
 		CommonLogger(log),
 		getTacticsAPIParameterLogFieldFormatter(),
-		getTacticsAPIParameterValidator(config),
+		getTacticsAPIParameterValidator(),
 		config.TacticsConfigFilename,
 		config.TacticsRequestPublicKey,
 		config.TacticsRequestPrivateKey,

+ 3 - 3
psiphon/server/tunnelServer.go

@@ -548,7 +548,7 @@ func newSSHServer(
 			ServerPrivateKey:            inproxyPrivateKey,
 			ServerRootObfuscationSecret: inproxyObfuscationSecret,
 			BrokerRoundTripperMaker:     makeRoundTripper,
-			ProxyMetricsValidator:       getInproxyBrokerAPIParameterValidator(support.Config),
+			ProxyMetricsValidator:       getInproxyBrokerAPIParameterValidator(),
 			ProxyMetricsFormatter:       getInproxyBrokerAPIParameterLogFieldFormatter(),
 
 			// Prefix for proxy metrics log fields in server_tunnel
@@ -2903,7 +2903,7 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		}
 	}
 
-	if !isHexDigits(sshClient.sshServer.support.Config, sshPasswordPayload.SessionId) ||
+	if !isHexDigits(sshPasswordPayload.SessionId) ||
 		len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
 		return nil, errors.Tracef("invalid session ID for %q", conn.User())
 	}
@@ -2930,7 +2930,7 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 	// This optional, early sponsor ID will be logged with server_tunnel if
 	// the tunnel doesn't reach handshakeState.completed.
 	sponsorID := sshPasswordPayload.SponsorID
-	if sponsorID != "" && !isSponsorID(sshClient.sshServer.support.Config, sponsorID) {
+	if sponsorID != "" && !isSponsorID(sponsorID) {
 		return nil, errors.Tracef("invalid sponsor ID")
 	}