|
|
@@ -35,6 +35,7 @@ import (
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/inproxy"
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
|
|
|
+ pb "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/pb/psiphond"
|
|
|
"github.com/fxamacker/cbor/v2"
|
|
|
)
|
|
|
|
|
|
@@ -250,7 +251,7 @@ func handshakeAPIRequestHandler(
|
|
|
|
|
|
// Note: ignoring legacy "known_servers" params
|
|
|
|
|
|
- err := validateRequestParams(support.Config, params, handshakeRequestParams)
|
|
|
+ err := validateRequestParams(params, handshakeRequestParams)
|
|
|
if err != nil {
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
@@ -617,7 +618,7 @@ func connectedAPIRequestHandler(
|
|
|
sshClient *sshClient,
|
|
|
params common.APIParameters) ([]byte, error) {
|
|
|
|
|
|
- err := validateRequestParams(support.Config, params, connectedRequestParams)
|
|
|
+ err := validateRequestParams(params, connectedRequestParams)
|
|
|
if err != nil {
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
@@ -760,7 +761,7 @@ func statusAPIRequestHandler(
|
|
|
sshClient *sshClient,
|
|
|
params common.APIParameters) ([]byte, error) {
|
|
|
|
|
|
- err := validateRequestParams(support.Config, params, statusRequestParams)
|
|
|
+ err := validateRequestParams(params, statusRequestParams)
|
|
|
if err != nil {
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
@@ -835,7 +836,7 @@ func statusAPIRequestHandler(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- err := validateRequestParams(support.Config, remoteServerListStat, remoteServerListStatParams)
|
|
|
+ err := validateRequestParams(remoteServerListStat, remoteServerListStatParams)
|
|
|
if err != nil {
|
|
|
// Occasionally, clients may send corrupt persistent stat data. Do not
|
|
|
// fail the status request, as this will lead to endless retries.
|
|
|
@@ -873,7 +874,7 @@ func statusAPIRequestHandler(
|
|
|
|
|
|
for _, failedTunnelStat := range failedTunnelStats {
|
|
|
|
|
|
- err := validateRequestParams(support.Config, failedTunnelStat, failedTunnelStatParams)
|
|
|
+ err := validateRequestParams(failedTunnelStat, failedTunnelStatParams)
|
|
|
if err != nil {
|
|
|
// Occasionally, clients may send corrupt persistent stat data. Do not
|
|
|
// fail the status request, as this will lead to endless retries.
|
|
|
@@ -1061,9 +1062,9 @@ var tacticsRequestParams = append(
|
|
|
tacticsParams...),
|
|
|
baseAndDialParams...)
|
|
|
|
|
|
-func getTacticsAPIParameterValidator(config *Config) common.APIParameterValidator {
|
|
|
+func getTacticsAPIParameterValidator() common.APIParameterValidator {
|
|
|
return func(params common.APIParameters) error {
|
|
|
- return validateRequestParams(config, params, tacticsRequestParams)
|
|
|
+ return validateRequestParams(params, tacticsRequestParams)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1092,9 +1093,9 @@ var inproxyBrokerRequestParams = append(
|
|
|
tacticsParams...),
|
|
|
baseParams...)
|
|
|
|
|
|
-func getInproxyBrokerAPIParameterValidator(config *Config) common.APIParameterValidator {
|
|
|
+func getInproxyBrokerAPIParameterValidator() common.APIParameterValidator {
|
|
|
return func(params common.APIParameters) error {
|
|
|
- return validateRequestParams(config, params, inproxyBrokerRequestParams)
|
|
|
+ return validateRequestParams(params, inproxyBrokerRequestParams)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1140,9 +1141,9 @@ var dslRequestParams = append(
|
|
|
tacticsParams...),
|
|
|
baseParams...)
|
|
|
|
|
|
-func getDSLAPIParameterValidator(config *Config) common.APIParameterValidator {
|
|
|
+func getDSLAPIParameterValidator() common.APIParameterValidator {
|
|
|
return func(params common.APIParameters) error {
|
|
|
- return validateRequestParams(config, params, dslRequestParams)
|
|
|
+ return validateRequestParams(params, dslRequestParams)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1163,12 +1164,53 @@ func getDSLAPIParameterLogFieldFormatter() common.APIParameterLogFieldFormatter
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// TODO: add session_id to baseParams?
|
|
|
+var sessionBaseParams = append(
|
|
|
+ []requestParamSpec{
|
|
|
+ {"session_id", isHexDigits, 0}},
|
|
|
+ baseParams...)
|
|
|
+
|
|
|
+// ValidateAndGetProtobufBaseParams unpacks and validates the input base API
|
|
|
+// parameters and returns the parameters in a pb/psiphond.BaseParams struct.
|
|
|
+//
|
|
|
+// Not all fields in pb/psiphond.BaseParams are populated; some fields, such
|
|
|
+// as GeoIP fields and authorized_access_types, are not accepted from clients
|
|
|
+// and are populated by the server; other fields, such as last_connected, are
|
|
|
+// only sent by the client for certain API requests.
|
|
|
+//
|
|
|
+// Note that the underlying protobuf converter code may intentionally panic in
|
|
|
+// unexpected cases such as a mismatch in log field and protobuf struct field
|
|
|
+// types.
|
|
|
+func ValidateAndGetProtobufBaseParams(
|
|
|
+ packedParams protocol.PackedAPIParameters) (*pb.BaseParams, error) {
|
|
|
+
|
|
|
+ params, err := protocol.DecodePackedAPIParameters(packedParams)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedParams := sessionBaseParams
|
|
|
+
|
|
|
+ err = validateRequestParams(
|
|
|
+ params, expectedParams)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ logFields := getRequestLogFields(
|
|
|
+ "", "", "", GeoIPData{}, nil, params, expectedParams)
|
|
|
+
|
|
|
+ baseParams := protobufPopulateBaseParams(logFields)
|
|
|
+
|
|
|
+ return baseParams, nil
|
|
|
+}
|
|
|
+
|
|
|
// requestParamSpec defines a request parameter. Each param is expected to be
|
|
|
// a string, unless requestParamArray is specified, in which case an array of
|
|
|
// strings is expected.
|
|
|
type requestParamSpec struct {
|
|
|
name string
|
|
|
- validator func(*Config, string) bool
|
|
|
+ validator func(string) bool
|
|
|
flags uint32
|
|
|
}
|
|
|
|
|
|
@@ -1353,7 +1395,6 @@ var baseAndDialParams = append(
|
|
|
inproxyDialParams...)
|
|
|
|
|
|
func validateRequestParams(
|
|
|
- config *Config,
|
|
|
params common.APIParameters,
|
|
|
expectedParams []requestParamSpec) error {
|
|
|
|
|
|
@@ -1368,7 +1409,7 @@ func validateRequestParams(
|
|
|
var err error
|
|
|
switch {
|
|
|
case expectedParam.flags&requestParamArray != 0:
|
|
|
- err = validateStringArrayRequestParam(config, expectedParam, value)
|
|
|
+ err = validateStringArrayRequestParam(expectedParam, value)
|
|
|
case expectedParam.flags&requestParamJSON != 0:
|
|
|
// No validation: the JSON already unmarshalled; the parameter
|
|
|
// user will validate that the JSON contains the expected
|
|
|
@@ -1379,7 +1420,7 @@ func validateRequestParams(
|
|
|
// and rejects the parameter.
|
|
|
|
|
|
default:
|
|
|
- err = validateStringRequestParam(config, expectedParam, value)
|
|
|
+ err = validateStringRequestParam(expectedParam, value)
|
|
|
}
|
|
|
if err != nil {
|
|
|
return errors.Trace(err)
|
|
|
@@ -1421,7 +1462,6 @@ func copyUpdateOnConnectedParams(params common.APIParameters) common.APIParamete
|
|
|
}
|
|
|
|
|
|
func validateStringRequestParam(
|
|
|
- config *Config,
|
|
|
expectedParam requestParamSpec,
|
|
|
value interface{}) error {
|
|
|
|
|
|
@@ -1429,14 +1469,13 @@ func validateStringRequestParam(
|
|
|
if !ok {
|
|
|
return errors.Tracef("unexpected string param type: %s", expectedParam.name)
|
|
|
}
|
|
|
- if !expectedParam.validator(config, strValue) {
|
|
|
+ if !expectedParam.validator(strValue) {
|
|
|
return errors.Tracef("invalid param: %s: %s", expectedParam.name, strValue)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
func validateStringArrayRequestParam(
|
|
|
- config *Config,
|
|
|
expectedParam requestParamSpec,
|
|
|
value interface{}) error {
|
|
|
|
|
|
@@ -1445,7 +1484,7 @@ func validateStringArrayRequestParam(
|
|
|
return errors.Tracef("unexpected array param type: %s", expectedParam.name)
|
|
|
}
|
|
|
for _, value := range arrayValue {
|
|
|
- err := validateStringRequestParam(config, expectedParam, value)
|
|
|
+ err := validateStringRequestParam(expectedParam, value)
|
|
|
if err != nil {
|
|
|
return errors.Trace(err)
|
|
|
}
|
|
|
@@ -1565,7 +1604,7 @@ func getRequestLogFields(
|
|
|
|
|
|
case "meek_dial_address":
|
|
|
host, _, _ := net.SplitHostPort(strValue)
|
|
|
- if isIPAddress(nil, host) {
|
|
|
+ if isIPAddress(host) {
|
|
|
name = "meek_dial_ip_address"
|
|
|
} else {
|
|
|
name = "meek_dial_domain"
|
|
|
@@ -1832,76 +1871,76 @@ func normalizeClientPlatform(clientPlatform string) string {
|
|
|
return CLIENT_PLATFORM_WINDOWS
|
|
|
}
|
|
|
|
|
|
-func isAnyString(config *Config, value string) bool {
|
|
|
- return true
|
|
|
-}
|
|
|
-
|
|
|
func isMobileClientPlatform(clientPlatform string) bool {
|
|
|
normalizedClientPlatform := normalizeClientPlatform(clientPlatform)
|
|
|
return normalizedClientPlatform == CLIENT_PLATFORM_ANDROID ||
|
|
|
normalizedClientPlatform == CLIENT_PLATFORM_IOS
|
|
|
}
|
|
|
|
|
|
+func isAnyString(value string) bool {
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
// Input validators follow the legacy validations rules in psi_web.
|
|
|
|
|
|
-func isSponsorID(config *Config, value string) bool {
|
|
|
- return len(value) == SPONSOR_ID_LENGTH && isHexDigits(config, value)
|
|
|
+func isSponsorID(value string) bool {
|
|
|
+ return len(value) == SPONSOR_ID_LENGTH && isHexDigits(value)
|
|
|
}
|
|
|
|
|
|
-func isHexDigits(_ *Config, value string) bool {
|
|
|
+func isHexDigits(value string) bool {
|
|
|
// Allows both uppercase in addition to lowercase, for legacy support.
|
|
|
return -1 == strings.IndexFunc(value, func(c rune) bool {
|
|
|
return !unicode.Is(unicode.ASCII_Hex_Digit, c)
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func isBase64String(_ *Config, value string) bool {
|
|
|
+func isBase64String(value string) bool {
|
|
|
_, err := base64.StdEncoding.DecodeString(value)
|
|
|
return err == nil
|
|
|
}
|
|
|
|
|
|
-func isUnpaddedBase64String(_ *Config, value string) bool {
|
|
|
+func isUnpaddedBase64String(value string) bool {
|
|
|
_, err := base64.RawStdEncoding.DecodeString(value)
|
|
|
return err == nil
|
|
|
}
|
|
|
|
|
|
-func isDigits(_ *Config, value string) bool {
|
|
|
+func isDigits(value string) bool {
|
|
|
return -1 == strings.IndexFunc(value, func(c rune) bool {
|
|
|
return c < '0' || c > '9'
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func isIntString(_ *Config, value string) bool {
|
|
|
+func isIntString(value string) bool {
|
|
|
_, err := strconv.Atoi(value)
|
|
|
return err == nil
|
|
|
}
|
|
|
|
|
|
-func isFloatString(_ *Config, value string) bool {
|
|
|
+func isFloatString(value string) bool {
|
|
|
_, err := strconv.ParseFloat(value, 64)
|
|
|
return err == nil
|
|
|
}
|
|
|
|
|
|
-func isClientPlatform(_ *Config, value string) bool {
|
|
|
+func isClientPlatform(value string) bool {
|
|
|
return -1 == strings.IndexFunc(value, func(c rune) bool {
|
|
|
// Note: stricter than psi_web's Python string.whitespace
|
|
|
return unicode.Is(unicode.White_Space, c)
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func isRelayProtocol(_ *Config, value string) bool {
|
|
|
+func isRelayProtocol(value string) bool {
|
|
|
return common.Contains(protocol.SupportedTunnelProtocols, value)
|
|
|
}
|
|
|
|
|
|
-func isBooleanFlag(_ *Config, value string) bool {
|
|
|
+func isBooleanFlag(value string) bool {
|
|
|
return value == "0" || value == "1"
|
|
|
}
|
|
|
|
|
|
-func isUpstreamProxyType(_ *Config, value string) bool {
|
|
|
+func isUpstreamProxyType(value string) bool {
|
|
|
value = strings.ToLower(value)
|
|
|
return value == "http" || value == "socks5" || value == "socks4a"
|
|
|
}
|
|
|
|
|
|
-func isRegionCode(_ *Config, value string) bool {
|
|
|
+func isRegionCode(value string) bool {
|
|
|
if len(value) != 2 {
|
|
|
return false
|
|
|
}
|
|
|
@@ -1910,16 +1949,16 @@ func isRegionCode(_ *Config, value string) bool {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func isDialAddress(_ *Config, value string) bool {
|
|
|
+func isDialAddress(value string) bool {
|
|
|
// "<host>:<port>", where <host> is a domain or IP address
|
|
|
parts := strings.Split(value, ":")
|
|
|
if len(parts) != 2 {
|
|
|
return false
|
|
|
}
|
|
|
- if !isIPAddress(nil, parts[0]) && !isDomain(nil, parts[0]) {
|
|
|
+ if !isIPAddress(parts[0]) && !isDomain(parts[0]) {
|
|
|
return false
|
|
|
}
|
|
|
- if !isDigits(nil, parts[1]) {
|
|
|
+ if !isDigits(parts[1]) {
|
|
|
return false
|
|
|
}
|
|
|
_, err := strconv.Atoi(parts[1])
|
|
|
@@ -1930,13 +1969,13 @@ func isDialAddress(_ *Config, value string) bool {
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
-func isIPAddress(_ *Config, value string) bool {
|
|
|
+func isIPAddress(value string) bool {
|
|
|
return net.ParseIP(value) != nil
|
|
|
}
|
|
|
|
|
|
var isDomainRegex = regexp.MustCompile(`[a-zA-Z\d-]{1,63}$`)
|
|
|
|
|
|
-func isDomain(_ *Config, value string) bool {
|
|
|
+func isDomain(value string) bool {
|
|
|
|
|
|
// From: http://stackoverflow.com/questions/2532053/validate-a-hostname-string
|
|
|
//
|
|
|
@@ -1963,32 +2002,32 @@ func isDomain(_ *Config, value string) bool {
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
-func isHostHeader(_ *Config, value string) bool {
|
|
|
+func isHostHeader(value string) bool {
|
|
|
// "<host>:<port>", where <host> is a domain or IP address and ":<port>" is optional
|
|
|
if strings.Contains(value, ":") {
|
|
|
- return isDialAddress(nil, value)
|
|
|
+ return isDialAddress(value)
|
|
|
}
|
|
|
- return isIPAddress(nil, value) || isDomain(nil, value)
|
|
|
+ return isIPAddress(value) || isDomain(value)
|
|
|
}
|
|
|
|
|
|
-func isServerEntrySource(_ *Config, value string) bool {
|
|
|
+func isServerEntrySource(value string) bool {
|
|
|
return common.ContainsWildcard(protocol.SupportedServerEntrySources, value)
|
|
|
}
|
|
|
|
|
|
var isISO8601DateRegex = regexp.MustCompile(
|
|
|
`(?P<year>[0-9]{4})-(?P<month>[0-9]{1,2})-(?P<day>[0-9]{1,2})T(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2}):(?P<second>[0-9]{2})(\.(?P<fraction>[0-9]+))?(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))`)
|
|
|
|
|
|
-func isISO8601Date(_ *Config, value string) bool {
|
|
|
+func isISO8601Date(value string) bool {
|
|
|
return isISO8601DateRegex.Match([]byte(value))
|
|
|
}
|
|
|
|
|
|
-func isLastConnected(_ *Config, value string) bool {
|
|
|
- return value == "None" || isISO8601Date(nil, value)
|
|
|
+func isLastConnected(value string) bool {
|
|
|
+ return value == "None" || isISO8601Date(value)
|
|
|
}
|
|
|
|
|
|
const geohashAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"
|
|
|
|
|
|
-func isGeoHashString(_ *Config, value string) bool {
|
|
|
+func isGeoHashString(value string) bool {
|
|
|
// Verify that the string is between 1 and 12 characters long
|
|
|
// and contains only characters from the geohash alphabet.
|
|
|
if len(value) < 1 || len(value) > 12 {
|