Browse Source

Report previously established tunnel count

- Indicates whether the tunnel is the first in a session
- Used in traffic rule selection
Rod Hynes 6 years ago
parent
commit
b6e5eed856

+ 7 - 1
psiphon/controller.go

@@ -30,6 +30,7 @@ import (
 	"math/rand"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -67,6 +68,7 @@ type Controller struct {
 	establishCtx                            context.Context
 	stopEstablish                           context.CancelFunc
 	establishWaitGroup                      *sync.WaitGroup
+	establishedTunnelCount                  int32
 	candidateServerEntries                  chan *candidateServerEntry
 	untunneledDialConfig                    *DialConfig
 	splitTunnelClassifier                   *SplitTunnelClassifier
@@ -731,6 +733,8 @@ loop:
 				break
 			}
 
+			atomic.AddInt32(&controller.establishedTunnelCount, 1)
+
 			NoticeActiveTunnel(
 				connectedTunnel.dialParams.ServerEntry.GetDiagnosticID(),
 				connectedTunnel.dialParams.TunnelProtocol,
@@ -1556,6 +1560,7 @@ func (controller *Controller) doFetchTactics(
 		selectProtocol,
 		serverEntry,
 		true,
+		0,
 		0)
 	if dialParams == nil {
 		// MakeDialParameters may return nil, nil when the server entry can't
@@ -1924,7 +1929,8 @@ loop:
 			selectProtocol,
 			candidateServerEntry.serverEntry,
 			false,
-			controller.establishConnectTunnelCount)
+			controller.establishConnectTunnelCount,
+			int(atomic.LoadInt32(&controller.establishedTunnelCount)))
 		if dialParams == nil || err != nil {
 
 			controller.concurrentEstablishTunnelsMutex.Unlock()

+ 8 - 5
psiphon/dialParameters.go

@@ -58,10 +58,11 @@ import (
 //
 // DialParameters is not safe for concurrent use.
 type DialParameters struct {
-	ServerEntry     *protocol.ServerEntry `json:"-"`
-	NetworkID       string                `json:"-"`
-	IsReplay        bool                  `json:"-"`
-	CandidateNumber int                   `json:"-"`
+	ServerEntry            *protocol.ServerEntry `json:"-"`
+	NetworkID              string                `json:"-"`
+	IsReplay               bool                  `json:"-"`
+	CandidateNumber        int                   `json:"-"`
+	EstablishedTunnelCount int                   `json:"-"`
 
 	IsExchanged bool
 
@@ -149,7 +150,8 @@ func MakeDialParameters(
 	selectProtocol func(serverEntry *protocol.ServerEntry) (string, bool),
 	serverEntry *protocol.ServerEntry,
 	isTactics bool,
-	candidateNumber int) (*DialParameters, error) {
+	candidateNumber int,
+	establishedTunnelCount int) (*DialParameters, error) {
 
 	networkID := config.GetNetworkID()
 
@@ -267,6 +269,7 @@ func MakeDialParameters(
 	dialParams.NetworkID = networkID
 	dialParams.IsReplay = isReplay
 	dialParams.CandidateNumber = candidateNumber
+	dialParams.EstablishedTunnelCount = establishedTunnelCount
 
 	// Even when replaying, LastUsedTimestamp is updated to extend the TTL of
 	// replayed dial parameters which will be updated in the datastore upon

+ 12 - 12
psiphon/dialParameters_test.go

@@ -108,7 +108,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	// Test: expected dial parameter fields set
 
-	dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -205,7 +205,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	dialParams.Failed(clientConfig)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -220,7 +220,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	testNetworkID = prng.HexString(8)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -237,7 +237,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	dialParams.Succeeded()
 
-	replayDialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	replayDialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -323,7 +323,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetClientParameters failed: %s", err)
 	}
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -338,7 +338,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	time.Sleep(1 * time.Second)
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -353,7 +353,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 	serverEntries[0].ConfigurationVersion += 1
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -377,14 +377,14 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 		t.Fatalf("SetClientParameters failed: %s", err)
 	}
 
-	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	dialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
 
 	dialParams.Succeeded()
 
-	replayDialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0)
+	replayDialParams, err = MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntries[0], false, 0, 0)
 	if err != nil {
 		t.Fatalf("MakeDialParameters failed: %s", err)
 	}
@@ -432,7 +432,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 
 		if i%10 == 0 {
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}
@@ -461,7 +461,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 				t.Fatalf("ServerEntryIterator.Next failed: %s", err)
 			}
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}
@@ -483,7 +483,7 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 				t.Fatalf("ServerEntryIterator.Next failed: %s", err)
 			}
 
-			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0)
+			dialParams, err := MakeDialParameters(clientConfig, canReplay, selectProtocol, serverEntry, false, 0, 0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)
 			}

+ 1 - 0
psiphon/exchange_test.go

@@ -184,6 +184,7 @@ func TestServerEntryExchange(t *testing.T) {
 				selectProtocol,
 				serverEntry,
 				false,
+				0,
 				0)
 			if err != nil {
 				t.Fatalf("MakeDialParameters failed: %s", err)

+ 1 - 0
psiphon/notice.go

@@ -433,6 +433,7 @@ func noticeWithDialParameters(noticeType string, dialParams *DialParameters) {
 		"protocol", dialParams.TunnelProtocol,
 		"isReplay", dialParams.IsReplay,
 		"candidateNumber", dialParams.CandidateNumber,
+		"establishedTunnelCount", dialParams.EstablishedTunnelCount,
 		"networkType", dialParams.GetNetworkType(),
 	}
 

+ 27 - 10
psiphon/server/api.go

@@ -205,6 +205,10 @@ func handshakeAPIRequestHandler(
 	isMobile := isMobileClientPlatform(clientPlatform)
 	normalizedPlatform := normalizeClientPlatform(clientPlatform)
 
+	// establishedTunnelCount is used in traffic rule selection. When omitted by
+	// the client, a value of 0 will be used.
+	establishedTunnelCount, _ := getOptionalIntRequestParam(params, "established_tunnel_count")
+
 	var authorizations []string
 	if params[protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS] != nil {
 		authorizations, err = getStringArrayRequestParam(params, protocol.PSIPHON_API_HANDSHAKE_AUTHORIZATIONS)
@@ -224,13 +228,14 @@ func handshakeAPIRequestHandler(
 	// TODO: in the case of SSH API requests, the actual sshClient could
 	// be passed in and used here. The session ID lookup is only strictly
 	// necessary to support web API requests.
-	clientHandshakeStateInfo, err := support.TunnelServer.SetClientHandshakeState(
+	handshakeStateInfo, err := support.TunnelServer.SetClientHandshakeState(
 		sessionID,
 		handshakeState{
-			completed:         true,
-			apiProtocol:       apiProtocol,
-			apiParams:         copyBaseRequestParams(params),
-			expectDomainBytes: len(httpsRequestRegexes) > 0,
+			completed:              true,
+			apiProtocol:            apiProtocol,
+			apiParams:              copyBaseRequestParams(params),
+			expectDomainBytes:      len(httpsRequestRegexes) > 0,
+			establishedTunnelCount: establishedTunnelCount,
 		},
 		authorizations)
 	if err != nil {
@@ -261,7 +266,7 @@ func handshakeAPIRequestHandler(
 			logFields := getRequestLogFields(
 				tactics.TACTICS_METRIC_EVENT_NAME,
 				geoIPData,
-				clientHandshakeStateInfo.AuthorizedAccessTypes,
+				handshakeStateInfo.authorizedAccessTypes,
 				params,
 				handshakeRequestParams)
 
@@ -284,7 +289,7 @@ func handshakeAPIRequestHandler(
 		getRequestLogFields(
 			"",
 			geoIPData,
-			clientHandshakeStateInfo.AuthorizedAccessTypes,
+			handshakeStateInfo.authorizedAccessTypes,
 			params,
 			baseRequestParams)).Debug("handshake")
 
@@ -319,10 +324,10 @@ func handshakeAPIRequestHandler(
 		EncodedServerList:        encodedServerList,
 		ClientRegion:             geoIPData.Country,
 		ServerTimestamp:          common.GetCurrentTimestamp(),
-		ActiveAuthorizationIDs:   clientHandshakeStateInfo.ActiveAuthorizationIDs,
+		ActiveAuthorizationIDs:   handshakeStateInfo.activeAuthorizationIDs,
 		TacticsPayload:           marshaledTacticsPayload,
-		UpstreamBytesPerSecond:   clientHandshakeStateInfo.UpstreamBytesPerSecond,
-		DownstreamBytesPerSecond: clientHandshakeStateInfo.DownstreamBytesPerSecond,
+		UpstreamBytesPerSecond:   handshakeStateInfo.upstreamBytesPerSecond,
+		DownstreamBytesPerSecond: handshakeStateInfo.downstreamBytesPerSecond,
 		Padding:                  strings.Repeat(" ", pad_response),
 	}
 
@@ -770,6 +775,7 @@ var baseRequestParams = []requestParamSpec{
 	{"egress_region", isRegionCode, requestParamOptional},
 	{"dial_duration", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"candidate_number", isIntString, requestParamOptional | requestParamLogStringAsInt},
+	{"established_tunnel_count", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"upstream_ossh_padding", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"meek_cookie_size", isIntString, requestParamOptional | requestParamLogStringAsInt},
 	{"meek_limit_request", isIntString, requestParamOptional | requestParamLogStringAsInt},
@@ -1101,6 +1107,17 @@ func getStringRequestParam(params common.APIParameters, name string) (string, er
 	return value, nil
 }
 
+func getOptionalIntRequestParam(params common.APIParameters, name string) (int, bool) {
+	if params[name] == nil {
+		return 0, false
+	}
+	value, ok := params[name].(float64)
+	if !ok {
+		return 0, false
+	}
+	return int(value), true
+}
+
 func getInt64RequestParam(params common.APIParameters, name string) (int64, error) {
 	if params[name] == nil {
 		return 0, errors.Tracef("missing param: %s", name)

+ 2 - 0
psiphon/server/server_test.go

@@ -1185,6 +1185,7 @@ func checkExpectedLogFields(
 		"is_replay",
 		"dial_duration",
 		"candidate_number",
+		"established_tunnel_count",
 		"network_latency_multiplier",
 		"network_type",
 		"client_app_id",
@@ -2121,6 +2122,7 @@ func storePruneServerEntriesTest(
 			},
 			serverEntry,
 			false,
+			0,
 			0)
 		if err != nil {
 			t.Fatalf("MakeDialParameters failed: %s", err)

+ 2 - 1
psiphon/server/sessionID_test.go

@@ -166,7 +166,8 @@ func TestDuplicateSessionID(t *testing.T) {
 			func(_ *protocol.ServerEntry) (string, bool) { return "OSSH", true },
 			serverEntry,
 			false,
-			1)
+			0,
+			0)
 		if err != nil {
 			t.Fatalf("MakeDialParameters failed: %s", err)
 		}

+ 28 - 24
psiphon/server/tunnelServer.go

@@ -284,13 +284,6 @@ func (server *TunnelServer) ResetAllClientOSLConfigs() {
 	server.sshServer.resetAllClientOSLConfigs()
 }
 
-type ClientHandshakeStateInfo struct {
-	ActiveAuthorizationIDs   []string
-	AuthorizedAccessTypes    []string
-	UpstreamBytesPerSecond   int64
-	DownstreamBytesPerSecond int64
-}
-
 // SetClientHandshakeState sets the handshake state -- that it completed and
 // what parameters were passed -- in sshClient. This state is used for allowing
 // port forwards and for future traffic rule selection. SetClientHandshakeState
@@ -307,7 +300,7 @@ type ClientHandshakeStateInfo struct {
 func (server *TunnelServer) SetClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) (*ClientHandshakeStateInfo, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
 }
@@ -917,7 +910,7 @@ func (sshServer *sshServer) resetAllClientOSLConfigs() {
 func (sshServer *sshServer) setClientHandshakeState(
 	sessionID string,
 	state handshakeState,
-	authorizations []string) (*ClientHandshakeStateInfo, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	sshServer.clientsMutex.Lock()
 	client := sshServer.clients[sessionID]
@@ -927,13 +920,13 @@ func (sshServer *sshServer) setClientHandshakeState(
 		return nil, errors.TraceNew("unknown session ID")
 	}
 
-	clientHandshakeStateInfo, err := client.setHandshakeState(
+	handshakeStateInfo, err := client.setHandshakeState(
 		state, authorizations)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	return clientHandshakeStateInfo, nil
+	return handshakeStateInfo, nil
 }
 
 func (sshServer *sshServer) getClientHandshaked(
@@ -1232,12 +1225,20 @@ type qualityMetrics struct {
 }
 
 type handshakeState struct {
-	completed             bool
-	apiProtocol           string
-	apiParams             common.APIParameters
-	authorizedAccessTypes []string
-	authorizationsRevoked bool
-	expectDomainBytes     bool
+	completed              bool
+	apiProtocol            string
+	apiParams              common.APIParameters
+	authorizedAccessTypes  []string
+	authorizationsRevoked  bool
+	expectDomainBytes      bool
+	establishedTunnelCount int
+}
+
+type handshakeStateInfo struct {
+	activeAuthorizationIDs   []string
+	authorizedAccessTypes    []string
+	upstreamBytesPerSecond   int64
+	downstreamBytesPerSecond int64
 }
 
 func newSshClient(
@@ -2517,7 +2518,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessa
 // sshClient.stop().
 func (sshClient *sshClient) setHandshakeState(
 	state handshakeState,
-	authorizations []string) (*ClientHandshakeStateInfo, error) {
+	authorizations []string) (*handshakeStateInfo, error) {
 
 	sshClient.Lock()
 	completed := sshClient.handshakeState.completed
@@ -2666,11 +2667,11 @@ func (sshClient *sshClient) setHandshakeState(
 
 	sshClient.setOSLConfig()
 
-	return &ClientHandshakeStateInfo{
-		ActiveAuthorizationIDs:   authorizationIDs,
-		AuthorizedAccessTypes:    authorizedAccessTypes,
-		UpstreamBytesPerSecond:   upstreamBytesPerSecond,
-		DownstreamBytesPerSecond: downstreamBytesPerSecond,
+	return &handshakeStateInfo{
+		activeAuthorizationIDs:   authorizationIDs,
+		authorizedAccessTypes:    authorizedAccessTypes,
+		upstreamBytesPerSecond:   upstreamBytesPerSecond,
+		downstreamBytesPerSecond: downstreamBytesPerSecond,
 	}, nil
 }
 
@@ -2738,8 +2739,11 @@ func (sshClient *sshClient) setTrafficRules() (int64, int64) {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
+	isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
+		sshClient.handshakeState.establishedTunnelCount == 0
+
 	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
-		sshClient.isFirstTunnelInSession,
+		isFirstTunnelInSession,
 		sshClient.tunnelProtocol,
 		sshClient.geoIPData,
 		sshClient.handshakeState)

+ 2 - 0
psiphon/serverApi.go

@@ -942,6 +942,8 @@ func getBaseAPIParameters(
 
 	params["candidate_number"] = strconv.Itoa(dialParams.CandidateNumber)
 
+	params["established_tunnel_count"] = strconv.Itoa(dialParams.EstablishedTunnelCount)
+
 	if dialParams.NetworkLatencyMultiplier != 0.0 {
 		params["network_latency_multiplier"] =
 			fmt.Sprintf("%f", dialParams.NetworkLatencyMultiplier)