Просмотр исходного кода

Add API protocol to traffic rules filters

Rod Hynes 9 лет назад
Родитель
Сommit
0e22795721

+ 3 - 0
psiphon/common/protocol.go

@@ -41,6 +41,9 @@ const (
 	PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME = "psiphon-client-verification"
 
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH = 16
+
+	PSIPHON_SSH_API_PROTOCOL = "ssh"
+	PSIPHON_WEB_API_PROTOCOL = "web"
 )
 
 var SupportedTunnelProtocols = []string{

+ 4 - 1
psiphon/config.go

@@ -476,7 +476,10 @@ func LoadConfig(configJson []byte) (*Config, error) {
 			errors.New("HostNameTransformer interface must be set at runtime"))
 	}
 
-	if !common.Contains([]string{"", "ssh", "web"}, config.TargetApiProtocol) {
+	if !common.Contains(
+		[]string{"", common.PSIPHON_SSH_API_PROTOCOL, common.PSIPHON_WEB_API_PROTOCOL},
+		config.TargetApiProtocol) {
+
 		return nil, common.ContextError(
 			errors.New("invalid TargetApiProtocol"))
 	}

+ 1 - 1
psiphon/controller.go

@@ -1045,7 +1045,7 @@ loop:
 				break
 			}
 
-			if controller.config.TargetApiProtocol == "ssh" &&
+			if controller.config.TargetApiProtocol == common.PSIPHON_SSH_API_PROTOCOL &&
 				!serverEntry.SupportsSSHAPIRequests() {
 				continue
 			}

+ 12 - 4
psiphon/server/api.go

@@ -72,13 +72,19 @@ func sshAPIRequestHandler(
 			fmt.Errorf("invalid payload for request name: %s: %s", name, err))
 	}
 
-	return dispatchAPIRequestHandler(support, geoIPData, name, params)
+	return dispatchAPIRequestHandler(
+		support,
+		common.PSIPHON_SSH_API_PROTOCOL,
+		geoIPData,
+		name,
+		params)
 }
 
 // dispatchAPIRequestHandler is the common dispatch point for both
 // web and SSH API requests.
 func dispatchAPIRequestHandler(
 	support *SupportServices,
+	apiProtocol string,
 	geoIPData GeoIPData,
 	name string,
 	params requestJSONObject) (response []byte, reterr error) {
@@ -97,7 +103,7 @@ func dispatchAPIRequestHandler(
 
 	switch name {
 	case common.PSIPHON_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(support, geoIPData, params)
+		return handshakeAPIRequestHandler(support, apiProtocol, geoIPData, params)
 	case common.PSIPHON_API_CONNECTED_REQUEST_NAME:
 		return connectedAPIRequestHandler(support, geoIPData, params)
 	case common.PSIPHON_API_STATUS_REQUEST_NAME:
@@ -115,6 +121,7 @@ func dispatchAPIRequestHandler(
 // stats to record, etc.
 func handshakeAPIRequestHandler(
 	support *SupportServices,
+	apiProtocol string,
 	geoIPData GeoIPData,
 	params requestJSONObject) ([]byte, error) {
 
@@ -151,8 +158,9 @@ func handshakeAPIRequestHandler(
 	err = support.TunnelServer.SetClientHandshakeState(
 		sessionID,
 		handshakeState{
-			completed: true,
-			apiParams: copyBaseRequestParams(params),
+			completed:   true,
+			apiProtocol: apiProtocol,
+			apiParams:   copyBaseRequestParams(params),
 		})
 	if err != nil {
 		return nil, common.ContextError(err)

+ 17 - 0
psiphon/server/trafficRules.go

@@ -64,6 +64,11 @@ type TrafficRulesFilter struct {
 	// region matches.
 	Regions []string
 
+	// APIProtocol specifies whether the client must use the SSH
+	// API protocol (when "ssh") or the web API protocol (when "web").
+	// When omitted or blank, any API protocol matches.
+	APIProtocol string
+
 	// SponsorIDs is a list of client handshake sponsor IDs that must be
 	// specified to match this filter. When omitted or empty, any client
 	// sponsor ID matches.
@@ -284,6 +289,15 @@ func (set *TrafficRulesSet) GetTrafficRules(
 			}
 		}
 
+		if filteredRules.Filter.APIProtocol != "" {
+			if !state.completed {
+				continue
+			}
+			if state.apiProtocol != filteredRules.Filter.APIProtocol {
+				continue
+			}
+		}
+
 		// Note: ignoring param format errors as params have been validated
 
 		if len(filteredRules.Filter.SponsorIDs) > 0 {
@@ -307,6 +321,9 @@ func (set *TrafficRulesSet) GetTrafficRules(
 		}
 
 		if filteredRules.Filter.MinClientVersion != nil || filteredRules.Filter.MaxClientVersion != nil {
+			if !state.completed {
+				continue
+			}
 			clientVersionStr, _ := getStringRequestParam(state.apiParams, "client_version")
 			clientVersion, _ := strconv.Atoi(clientVersionStr)
 			if filteredRules.Filter.MinClientVersion != nil && clientVersion < *filteredRules.Filter.MinClientVersion {

+ 3 - 2
psiphon/server/tunnelServer.go

@@ -632,8 +632,9 @@ type trafficState struct {
 }
 
 type handshakeState struct {
-	completed bool
-	apiParams requestJSONObject
+	completed   bool
+	apiProtocol string
+	apiParams   requestJSONObject
 }
 
 func newSshClient(

+ 4 - 0
psiphon/server/webServer.go

@@ -234,6 +234,7 @@ func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Requ
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_HANDSHAKE_REQUEST_NAME,
 			params)
@@ -263,6 +264,7 @@ func (webServer *webServer) connectedHandler(w http.ResponseWriter, r *http.Requ
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_CONNECTED_REQUEST_NAME,
 			params)
@@ -285,6 +287,7 @@ func (webServer *webServer) statusHandler(w http.ResponseWriter, r *http.Request
 	if err == nil {
 		_, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_STATUS_REQUEST_NAME,
 			params)
@@ -307,6 +310,7 @@ func (webServer *webServer) clientVerificationHandler(w http.ResponseWriter, r *
 	if err == nil {
 		responsePayload, err = dispatchAPIRequestHandler(
 			webServer.support,
+			common.PSIPHON_WEB_API_PROTOCOL,
 			webServer.lookupGeoIPData(params),
 			common.PSIPHON_API_CLIENT_VERIFICATION_REQUEST_NAME,
 			params)

+ 3 - 1
psiphon/serverApi.go

@@ -87,7 +87,9 @@ func NewServerContext(tunnel *Tunnel, sessionId string) (*ServerContext, error)
 	// For legacy servers, set up psiphonHttpsClient for
 	// accessing the Psiphon API via the web service.
 	var psiphonHttpsClient *http.Client
-	if !tunnel.serverEntry.SupportsSSHAPIRequests() || tunnel.config.TargetApiProtocol == "web" {
+	if !tunnel.serverEntry.SupportsSSHAPIRequests() ||
+		tunnel.config.TargetApiProtocol == common.PSIPHON_WEB_API_PROTOCOL {
+
 		var err error
 		psiphonHttpsClient, err = makePsiphonHttpsClient(tunnel)
 		if err != nil {