Преглед изворни кода

Bug fixes and additional backwards compatibility allowances:
* "client_session_id" should be optional for backwards compatibility
* connected and status should use "session_id" for backwards compatibility
* allow "host_bytes" and "tunnel_stats" to be omitted
* fixed type assertions and handling in "host_bytes" and "tunnel_stats"
* allow "tunnel_whole_device" to be omitted and provide default for logging

Rod Hynes пре 9 година
родитељ
комит
8abcd1121c
1 измењених фајлова са 119 додато и 74 уклоњено
  1. 119 74
      psiphon/server/api.go

+ 119 - 74
psiphon/server/api.go

@@ -147,7 +147,9 @@ func handshakeAPIRequestHandler(
 }
 
 var connectedRequestParams = append(
-	[]requestParamSpec{requestParamSpec{"last_connected", isLastConnected, 0}},
+	[]requestParamSpec{
+		requestParamSpec{"session_id", isHexDigits, 0},
+		requestParamSpec{"last_connected", isLastConnected, 0}},
 	baseRequestParams...)
 
 // connectedAPIRequestHandler implements the "connected" API request.
@@ -188,7 +190,9 @@ func connectedAPIRequestHandler(
 }
 
 var statusRequestParams = append(
-	[]requestParamSpec{requestParamSpec{"connected", isBooleanFlag, 0}},
+	[]requestParamSpec{
+		requestParamSpec{"session_id", isHexDigits, 0},
+		requestParamSpec{"connected", isBooleanFlag, 0}},
 	baseRequestParams...)
 
 // statusAPIRequestHandler implements the "status" API request.
@@ -223,73 +227,85 @@ func statusAPIRequestHandler(
 	log.WithContextFields(bytesTransferredFields).Info("API event")
 
 	// Domain bytes transferred stats
+	// Older clients may not submit this data
 
-	hostBytes, err := getMapStringInt64RequestParam(statusData, "host_bytes")
-	if err != nil {
-		return nil, psiphon.ContextError(err)
-	}
-	domainBytesFields := getRequestLogFields(
-		config, "domain_bytes", geoIPData, params, statusRequestParams)
-	for domain, bytes := range hostBytes {
-		domainBytesFields["domain"] = domain
-		domainBytesFields["bytes"] = bytes
-		log.WithContextFields(domainBytesFields).Info("API event")
+	if statusData["host_bytes"] != nil {
+
+		hostBytes, err := getMapStringInt64RequestParam(statusData, "host_bytes")
+		if err != nil {
+			return nil, psiphon.ContextError(err)
+		}
+		domainBytesFields := getRequestLogFields(
+			config, "domain_bytes", geoIPData, params, statusRequestParams)
+		for domain, bytes := range hostBytes {
+			domainBytesFields["domain"] = domain
+			domainBytesFields["bytes"] = bytes
+			log.WithContextFields(domainBytesFields).Info("API event")
+		}
 	}
 
 	// Tunnel duration and bytes transferred stats
+	// Older clients may not submit this data
 
-	tunnelStats, err := getJSONObjectArrayRequestParam(statusData, "tunnel_stats")
-	if err != nil {
-		return nil, psiphon.ContextError(err)
-	}
-	sessionFields := getRequestLogFields(
-		config, "session", geoIPData, params, statusRequestParams)
-	for _, tunnelStat := range tunnelStats {
+	if statusData["tunnel_stats"] != nil {
 
-		sessionID, err := getStringRequestParam(tunnelStat, "session_id")
+		tunnelStats, err := getJSONObjectArrayRequestParam(statusData, "tunnel_stats")
 		if err != nil {
 			return nil, psiphon.ContextError(err)
 		}
-		sessionFields["session_id"] = sessionID
+		sessionFields := getRequestLogFields(
+			config, "session", geoIPData, params, statusRequestParams)
+		for _, tunnelStat := range tunnelStats {
 
-		tunnelNumber, err := getInt64RequestParam(tunnelStat, "tunnel_number")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		sessionFields["tunnel_number"] = tunnelNumber
+			sessionID, err := getStringRequestParam(tunnelStat, "session_id")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["session_id"] = sessionID
 
-		tunnelServerIPAddress, err := getStringRequestParam(tunnelStat, "tunnel_server_ip_address")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		sessionFields["tunnel_server_ip_address"] = tunnelServerIPAddress
+			tunnelNumber, err := getInt64RequestParam(tunnelStat, "tunnel_number")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["tunnel_number"] = tunnelNumber
 
-		serverHandshakeTimestamp, err := getStringRequestParam(tunnelStat, "server_handshake_timestamp")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		sessionFields["server_handshake_timestamp"] = serverHandshakeTimestamp
+			tunnelServerIPAddress, err := getStringRequestParam(tunnelStat, "tunnel_server_ip_address")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["tunnel_server_ip_address"] = tunnelServerIPAddress
 
-		duration, err := getInt64RequestParam(tunnelStat, "duration")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		// Client reports durations in nanoseconds; divide to get to milliseconds
-		sessionFields["duration"] = duration / 1000000
+			serverHandshakeTimestamp, err := getStringRequestParam(tunnelStat, "server_handshake_timestamp")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["server_handshake_timestamp"] = serverHandshakeTimestamp
 
-		totalBytesSent, err := getInt64RequestParam(tunnelStat, "total_bytes_sent")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		sessionFields["total_bytes_sent"] = totalBytesSent
+			strDuration, err := getStringRequestParam(tunnelStat, "duration")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			duration, err := strconv.ParseInt(strDuration, 10, 64)
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			// Client reports durations in nanoseconds; divide to get to milliseconds
+			sessionFields["duration"] = duration / 1000000
 
-		totalBytesReceived, err := getInt64RequestParam(tunnelStat, "total_bytes_received")
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
-		sessionFields["total_bytes_received"] = totalBytesReceived
+			totalBytesSent, err := getInt64RequestParam(tunnelStat, "total_bytes_sent")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["total_bytes_sent"] = totalBytesSent
 
-		log.WithContextFields(sessionFields).Info("API event")
+			totalBytesReceived, err := getInt64RequestParam(tunnelStat, "total_bytes_received")
+			if err != nil {
+				return nil, psiphon.ContextError(err)
+			}
+			sessionFields["total_bytes_received"] = totalBytesReceived
+
+			log.WithContextFields(sessionFields).Info("API event")
+		}
 	}
 
 	return make([]byte, 0), nil
@@ -329,13 +345,13 @@ const (
 // OPTIONAL_COMMON_INPUTS in psi_web.
 var baseRequestParams = []requestParamSpec{
 	requestParamSpec{"server_secret", isServerSecret, requestParamNotLogged},
-	requestParamSpec{"client_session_id", isHexDigits, 0},
+	requestParamSpec{"client_session_id", isHexDigits, requestParamOptional},
 	requestParamSpec{"propagation_channel_id", isHexDigits, 0},
 	requestParamSpec{"sponsor_id", isHexDigits, 0},
 	requestParamSpec{"client_version", isDigits, 0},
 	requestParamSpec{"client_platform", isClientPlatform, 0},
 	requestParamSpec{"relay_protocol", isRelayProtocol, 0},
-	requestParamSpec{"tunnel_whole_device", isBooleanFlag, 0},
+	requestParamSpec{"tunnel_whole_device", isBooleanFlag, requestParamOptional},
 	requestParamSpec{"device_region", isRegionCode, requestParamOptional},
 	requestParamSpec{"meek_dial_address", isDialAddress, requestParamOptional},
 	requestParamSpec{"meek_resolved_ip_address", isIPAddress, requestParamOptional},
@@ -359,7 +375,7 @@ func validateRequestParams(
 				continue
 			}
 			return psiphon.ContextError(
-				fmt.Errorf("missing required param: %s", expectedParam.name))
+				fmt.Errorf("missing param: %s", expectedParam.name))
 		}
 		strValue, ok := value.(string)
 		if !ok {
@@ -404,8 +420,15 @@ func getRequestLogFields(
 
 		value := params[expectedParam.name]
 		if value == nil {
-			// Skip optional params
-			continue
+
+			// Special case: older clients don't send this value,
+			// so log a default.
+			if expectedParam.name == "tunnel_whole_device" {
+				value = "0"
+			} else {
+				// Skip omitted, optional params
+				continue
+			}
 		}
 		strValue, ok := value.(string)
 		if !ok {
@@ -443,57 +466,79 @@ func getRequestLogFields(
 
 func getStringRequestParam(params requestJSONObject, name string) (string, error) {
 	if params[name] == nil {
-		return "", psiphon.ContextError(errors.New("missing param"))
+		return "", psiphon.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	value, ok := params[name].(string)
 	if !ok {
-		return "", psiphon.ContextError(errors.New("invalid param"))
+		return "", psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 	return value, nil
 }
 
 func getInt64RequestParam(params requestJSONObject, name string) (int64, error) {
 	if params[name] == nil {
-		return 0, psiphon.ContextError(errors.New("missing param"))
+		return 0, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
 	}
-	value, ok := params[name].(int64)
+	value, ok := params[name].(float64)
 	if !ok {
-		return 0, psiphon.ContextError(errors.New("invalid param"))
+		return 0, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
-	return value, nil
+	return int64(value), nil
 }
 
 func getJSONObjectRequestParam(params requestJSONObject, name string) (requestJSONObject, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(errors.New("missing param"))
+		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
 	}
 	value, ok := params[name].(requestJSONObject)
 	if !ok {
-		return nil, psiphon.ContextError(errors.New("invalid param"))
+		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
 	return value, nil
 }
 
 func getJSONObjectArrayRequestParam(params requestJSONObject, name string) ([]requestJSONObject, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(errors.New("missing param"))
+		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
 	}
-	value, ok := params[name].([]requestJSONObject)
+	value, ok := params[name].([]interface{})
 	if !ok {
-		return nil, psiphon.ContextError(errors.New("invalid param"))
+		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
-	return value, nil
+
+	result := make([]requestJSONObject, len(value))
+	for i, item := range value {
+		// TODO: can't use requestJSONObject type?
+		resultItem, ok := item.(map[string]interface{})
+		if !ok {
+			return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		}
+		result[i] = requestJSONObject(resultItem)
+	}
+
+	return result, nil
 }
 
 func getMapStringInt64RequestParam(params requestJSONObject, name string) (map[string]int64, error) {
 	if params[name] == nil {
-		return nil, psiphon.ContextError(errors.New("missing param"))
+		return nil, psiphon.ContextError(fmt.Errorf("missing param: %s", name))
 	}
-	value, ok := params[name].(map[string]int64)
+	// TODO: can't use requestJSONObject type?
+	value, ok := params[name].(map[string]interface{})
 	if !ok {
-		return nil, psiphon.ContextError(errors.New("invalid param"))
+		return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
 	}
-	return value, nil
+
+	result := make(map[string]int64)
+	for k, v := range value {
+		numValue, ok := v.(float64)
+		if !ok {
+			return nil, psiphon.ContextError(fmt.Errorf("invalid param: %s", name))
+		}
+		result[k] = int64(numValue)
+	}
+
+	return result, nil
 }
 
 // Input validators follow the legacy validations rules in psi_web.