Răsfoiți Sursa

Add destination bytes metrics

Rod Hynes 3 ani în urmă
părinte
comite
00acd35446

+ 9 - 9
psiphon/common/activity.go

@@ -57,7 +57,7 @@ type ActivityMonitoredConn struct {
 	net.Conn
 	inactivityTimeout time.Duration
 	activeOnWrite     bool
-	activityUpdater   ActivityUpdater
+	activityUpdaters  []ActivityUpdater
 	lruEntry          *LRUConnsEntry
 }
 
@@ -65,7 +65,7 @@ type ActivityMonitoredConn struct {
 // ActivityMonitoredConn activity. Values passed to UpdateProgress are bytes
 // transferred and conn duration since the previous UpdateProgress.
 type ActivityUpdater interface {
-	UpdateProgress(bytesRead, bytesWritten int64, durationNanoseconds int64)
+	UpdateProgress(bytesRead, bytesWritten, durationNanoseconds int64)
 }
 
 // NewActivityMonitoredConn creates a new ActivityMonitoredConn.
@@ -73,8 +73,8 @@ func NewActivityMonitoredConn(
 	conn net.Conn,
 	inactivityTimeout time.Duration,
 	activeOnWrite bool,
-	activityUpdater ActivityUpdater,
-	lruEntry *LRUConnsEntry) (*ActivityMonitoredConn, error) {
+	lruEntry *LRUConnsEntry,
+	activityUpdaters ...ActivityUpdater) (*ActivityMonitoredConn, error) {
 
 	if inactivityTimeout > 0 {
 		err := conn.SetDeadline(time.Now().Add(inactivityTimeout))
@@ -95,8 +95,8 @@ func NewActivityMonitoredConn(
 		realStartTime:        time.Now(),
 		monotonicStartTime:   now,
 		lastReadActivityTime: now,
-		activityUpdater:      activityUpdater,
 		lruEntry:             lruEntry,
+		activityUpdaters:     activityUpdaters,
 	}, nil
 }
 
@@ -129,8 +129,8 @@ func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
 
 		atomic.StoreInt64(&conn.lastReadActivityTime, readActivityTime)
 
-		if conn.activityUpdater != nil {
-			conn.activityUpdater.UpdateProgress(
+		for _, activityUpdater := range conn.activityUpdaters {
+			activityUpdater.UpdateProgress(
 				int64(n), 0, readActivityTime-lastReadActivityTime)
 		}
 
@@ -153,8 +153,8 @@ func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
 			}
 		}
 
-		if conn.activityUpdater != nil {
-			conn.activityUpdater.UpdateProgress(0, int64(n), 0)
+		for _, activityUpdater := range conn.activityUpdaters {
+			activityUpdater.UpdateProgress(0, int64(n), 0)
 		}
 
 		if conn.lruEntry != nil {

+ 3 - 4
psiphon/common/activity_test.go

@@ -34,7 +34,6 @@ func TestActivityMonitoredConn(t *testing.T) {
 		&dummyConn{},
 		200*time.Millisecond,
 		true,
-		nil,
 		nil)
 	if err != nil {
 		t.Fatalf("NewActivityMonitoredConn failed")
@@ -106,19 +105,19 @@ func TestActivityMonitoredLRUConns(t *testing.T) {
 	lruConns := NewLRUConns()
 
 	dummy1 := &dummyConn{}
-	conn1, err := NewActivityMonitoredConn(dummy1, 0, true, nil, lruConns.Add(dummy1))
+	conn1, err := NewActivityMonitoredConn(dummy1, 0, true, lruConns.Add(dummy1))
 	if err != nil {
 		t.Fatalf("NewActivityMonitoredConn failed")
 	}
 
 	dummy2 := &dummyConn{}
-	conn2, err := NewActivityMonitoredConn(dummy2, 0, true, nil, lruConns.Add(dummy2))
+	conn2, err := NewActivityMonitoredConn(dummy2, 0, true, lruConns.Add(dummy2))
 	if err != nil {
 		t.Fatalf("NewActivityMonitoredConn failed")
 	}
 
 	dummy3 := &dummyConn{}
-	conn3, err := NewActivityMonitoredConn(dummy3, 0, true, nil, lruConns.Add(dummy3))
+	conn3, err := NewActivityMonitoredConn(dummy3, 0, true, lruConns.Add(dummy3))
 	if err != nil {
 		t.Fatalf("NewActivityMonitoredConn failed")
 	}

+ 1 - 1
psiphon/common/osl/osl.go

@@ -521,7 +521,7 @@ func (state *ClientSeedState) sendIssueSLOKsSignal() {
 // design is that progress reported at the exact time of SLOK time period
 // rollover may be dropped.
 func (portForward *ClientSeedPortForward) UpdateProgress(
-	bytesRead, bytesWritten int64, durationNanoseconds int64) {
+	bytesRead, bytesWritten, durationNanoseconds int64) {
 
 	// Concurrency: non-blocking -- access to ClientSeedState is unsynchronized
 	// to read-only fields, atomic, or channels, except in the case of a time

+ 3 - 0
psiphon/common/parameters/parameters.go

@@ -300,6 +300,7 @@ const (
 	RestrictFrontingProviderIDsServerProbability     = "RestrictFrontingProviderIDsServerProbability"
 	RestrictFrontingProviderIDsClientProbability     = "RestrictFrontingProviderIDsClientProbability"
 	UpstreamProxyAllowAllServerEntrySources          = "UpstreamProxyAllowAllServerEntrySources"
+	DestinationBytesMetricsASN                       = "DestinationBytesMetricsASN"
 )
 
 const (
@@ -634,6 +635,8 @@ var defaultParameters = map[string]struct {
 	RestrictFrontingProviderIDsClientProbability: {value: 0.0, minimum: 0.0},
 
 	UpstreamProxyAllowAllServerEntrySources: {value: false},
+
+	DestinationBytesMetricsASN: {value: "", flags: serverSideOnly},
 }
 
 // IsServerSideOnly indicates if the parameter specified by name is used

+ 14 - 8
psiphon/common/tun/tun.go

@@ -348,14 +348,15 @@ type AllowedDomainChecker func(string) bool
 // flow activity. Values passed to UpdateProgress are bytes transferred
 // and flow duration since the previous UpdateProgress.
 type FlowActivityUpdater interface {
-	UpdateProgress(downstreamBytes, upstreamBytes int64, durationNanoseconds int64)
+	UpdateProgress(downstreamBytes, upstreamBytes, durationNanoseconds int64)
 }
 
 // FlowActivityUpdaterMaker is a function which returns a list of
 // appropriate updaters for a new flow to the specified upstream
 // hostname (if known -- may be ""), and IP address.
+// The flow is TCP when isTCP is true, and UDP otherwise.
 type FlowActivityUpdaterMaker func(
-	upstreamHostname string, upstreamIPAddress net.IP) []FlowActivityUpdater
+	isTCP bool, upstreamHostname string, upstreamIPAddress net.IP) []FlowActivityUpdater
 
 // MetricsUpdater is a function which receives a checkpoint summary
 // of application bytes transferred through a packet tunnel.
@@ -1389,20 +1390,25 @@ func (session *session) startTrackingFlow(
 		session.reapFlows()
 	}
 
+	var isTCP bool
 	var hostname string
-	//lint:ignore SA9003 intentionally empty branch
 	if ID.protocol == internetProtocolTCP {
 		// TODO: implement
 		// hostname = common.ExtractHostnameFromTCPFlow(applicationData)
+		isTCP = true
 	}
 
 	var activityUpdaters []FlowActivityUpdater
 
-	flowActivityUpdaterMaker := session.getFlowActivityUpdaterMaker()
-	if flowActivityUpdaterMaker != nil {
-		activityUpdaters = flowActivityUpdaterMaker(
-			hostname,
-			net.IP(ID.upstreamIPAddress[:]))
+	// Don't incur activity monitor overhead for DNS requests
+	if !isDNS {
+		flowActivityUpdaterMaker := session.getFlowActivityUpdaterMaker()
+		if flowActivityUpdaterMaker != nil {
+			activityUpdaters = flowActivityUpdaterMaker(
+				isTCP,
+				hostname,
+				net.IP(ID.upstreamIPAddress[:]))
+		}
 	}
 
 	flowState := &flowState{

+ 22 - 1
psiphon/server/geoip.go

@@ -38,6 +38,7 @@ import (
 const (
 	GEOIP_SESSION_CACHE_TTL = 60 * time.Minute
 	GEOIP_UNKNOWN_VALUE     = "None"
+	GEOIP_DATABASE_TYPE_ISP = "GeoIP2-ISP"
 )
 
 // GeoIPData is GeoIP data for a client session. Individual client
@@ -96,6 +97,7 @@ type geoIPDatabase struct {
 	filename       string
 	tempFilename   string
 	tempFileSuffix int64
+	isISPType      bool
 	maxMindReader  *maxminddb.Reader
 }
 
@@ -163,7 +165,10 @@ func NewGeoIPService(databaseFilenames []string) (*GeoIPService, error) {
 					_ = os.Remove(database.tempFilename)
 				}
 
+				isISPType := (maxMindReader.Metadata.DatabaseType == GEOIP_DATABASE_TYPE_ISP)
+
 				database.maxMindReader = maxMindReader
+				database.isISPType = isISPType
 				database.tempFilename = tempFilename
 				database.tempFileSuffix = tempFileSuffix
 
@@ -199,6 +204,17 @@ func (geoIP *GeoIPService) Lookup(strIP string) GeoIPData {
 
 // LookupIP determines a GeoIPData for a given client IP address.
 func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
+	return geoIP.lookupIP(IP, false)
+}
+
+// LookupISPForIP determines a GeoIPData for a given client IP address. Only
+// ISP, ASN, and ASO fields will be populated. This lookup is faster than a
+// full lookup.
+func (geoIP *GeoIPService) LookupISPForIP(IP net.IP) GeoIPData {
+	return geoIP.lookupIP(IP, true)
+}
+
+func (geoIP *GeoIPService) lookupIP(IP net.IP, ISPOnly bool) GeoIPData {
 
 	result := NewGeoIPData()
 
@@ -227,7 +243,12 @@ func (geoIP *GeoIPService) LookupIP(IP net.IP) GeoIPData {
 	// the separate ISP database populates ISP.
 	for _, database := range geoIP.databases {
 		database.ReloadableFile.RLock()
-		err := database.maxMindReader.Lookup(IP, &geoIPFields)
+		var err error
+		// Don't lookup the City database when only ISP fields are required;
+		// skipping the City lookup is 5-10x faster.
+		if !ISPOnly || database.isISPType {
+			err = database.maxMindReader.Lookup(IP, &geoIPFields)
+		}
 		database.ReloadableFile.RUnlock()
 		if err != nil {
 			log.WithTraceFields(LogFields{"error": err}).Warning("GeoIP lookup failed")

+ 142 - 45
psiphon/server/tunnelServer.go

@@ -1435,6 +1435,9 @@ type sshClient struct {
 	sendAlertRequests                    chan protocol.AlertRequest
 	sentAlertRequests                    map[string]bool
 	peakMetrics                          peakMetrics
+	destinationBytesMetricsASN           string
+	tcpDestinationBytesMetrics           destinationBytesMetrics
+	udpDestinationBytesMetrics           destinationBytesMetrics
 }
 
 type trafficState struct {
@@ -1558,6 +1561,18 @@ type handshakeState struct {
 	splitTunnelLookup       *splitTunnelLookup
 }
 
+type destinationBytesMetrics struct {
+	bytesUp   int64
+	bytesDown int64
+}
+
+func (d *destinationBytesMetrics) UpdateProgress(
+	downstreamBytes, upstreamBytes, _ int64) {
+
+	atomic.AddInt64(&d.bytesUp, upstreamBytes)
+	atomic.AddInt64(&d.bytesDown, downstreamBytes)
+}
+
 type splitTunnelLookup struct {
 	regions       []string
 	regionsLookup map[string]bool
@@ -1685,7 +1700,6 @@ func (sshClient *sshClient) run(
 		conn,
 		SSH_CONNECTION_READ_DEADLINE,
 		false,
-		nil,
 		nil)
 	if err != nil {
 		conn.Close()
@@ -2649,14 +2663,21 @@ func (sshClient *sshClient) handleNewPacketTunnelChannel(
 	}
 
 	flowActivityUpdaterMaker := func(
-		upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
+		isTCP bool, upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
+
+		trafficType := portForwardTypeTCP
+		if !isTCP {
+			trafficType = portForwardTypeUDP
+		}
 
-		var updaters []tun.FlowActivityUpdater
-		oslUpdater := sshClient.newClientSeedPortForward(upstreamIPAddress)
-		if oslUpdater != nil {
-			updaters = append(updaters, oslUpdater)
+		activityUpdaters := sshClient.getActivityUpdaters(trafficType, upstreamIPAddress)
+
+		flowUpdaters := make([]tun.FlowActivityUpdater, len(activityUpdaters))
+		for i, activityUpdater := range activityUpdaters {
+			flowUpdaters[i] = activityUpdater
 		}
-		return updaters
+
+		return flowUpdaters
 	}
 
 	metricUpdater := func(
@@ -2873,6 +2894,18 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	logFields["random_stream_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.downstreamBytes
 	logFields["random_stream_sent_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.sentDownstreamBytes
 
+	if sshClient.destinationBytesMetricsASN != "" {
+		logFields["dest_bytes_asn"] = sshClient.destinationBytesMetricsASN
+		logFields["dest_bytes_up_tcp"] = sshClient.tcpDestinationBytesMetrics.bytesUp
+		logFields["dest_bytes_down_tcp"] = sshClient.tcpDestinationBytesMetrics.bytesDown
+		logFields["dest_bytes_up_udp"] = sshClient.udpDestinationBytesMetrics.bytesUp
+		logFields["dest_bytes_down_udp"] = sshClient.udpDestinationBytesMetrics.bytesDown
+		logFields["dest_bytes"] = sshClient.tcpDestinationBytesMetrics.bytesUp +
+			sshClient.tcpDestinationBytesMetrics.bytesDown +
+			sshClient.udpDestinationBytesMetrics.bytesUp +
+			sshClient.udpDestinationBytesMetrics.bytesDown
+	}
+
 	// Only log fields for peakMetrics when there is data recorded, otherwise
 	// omit the field.
 	if sshClient.peakMetrics.concurrentProximateAcceptedClients != nil {
@@ -3300,6 +3333,16 @@ func (sshClient *sshClient) setHandshakeState(
 
 	sshClient.setOSLConfig()
 
+	// Set destination bytes metrics.
+	//
+	// Limitation: this is a one-time operation and doesn't get reset when
+	// tactics are hot-reloaded. This allows us to simply retain any
+	// destination byte counts accumulated and eventually log in
+	// server_tunnel, without having to deal with a destination change
+	// mid-tunnel. As typical tunnels are short, and destination changes can
+	// be applied gradually, handling mid-tunnel changes is not a priority.
+	sshClient.setDestinationBytesMetrics()
+
 	return &handshakeStateInfo{
 		activeAuthorizationIDs:   authorizationIDs,
 		authorizedAccessTypes:    authorizedAccessTypes,
@@ -3372,33 +3415,6 @@ func (sshClient *sshClient) expectDomainBytes() bool {
 	return sshClient.handshakeState.expectDomainBytes
 }
 
-// setTrafficRules resets the client's traffic rules based on the latest server config
-// and client properties. As sshClient.trafficRules may be reset by a concurrent
-// goroutine, trafficRules must only be accessed within the sshClient mutex.
-func (sshClient *sshClient) setTrafficRules() (int64, int64) {
-	sshClient.Lock()
-	defer sshClient.Unlock()
-
-	isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
-		sshClient.handshakeState.establishedTunnelsCount == 0
-
-	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
-		isFirstTunnelInSession,
-		sshClient.tunnelProtocol,
-		sshClient.geoIPData,
-		sshClient.handshakeState)
-
-	if sshClient.throttledConn != nil {
-		// Any existing throttling state is reset.
-		sshClient.throttledConn.SetLimits(
-			sshClient.trafficRules.RateLimits.CommonRateLimits(
-				sshClient.handshakeState.completed))
-	}
-
-	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
-		*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
-}
-
 // setOSLConfig resets the client's OSL seed state based on the latest OSL config
 // As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
 // oslClientSeedState must only be accessed within the sshClient mutex.
@@ -3449,7 +3465,7 @@ func (sshClient *sshClient) setOSLConfig() {
 
 // newClientSeedPortForward will return nil when no seeding is
 // associated with the specified ipAddress.
-func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
+func (sshClient *sshClient) newClientSeedPortForward(IPAddress net.IP) *osl.ClientSeedPortForward {
 	sshClient.Lock()
 	defer sshClient.Unlock()
 
@@ -3458,7 +3474,7 @@ func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.Clie
 		return nil
 	}
 
-	return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
+	return sshClient.oslClientSeedState.NewClientSeedPortForward(IPAddress)
 }
 
 // getOSLSeedPayload returns a payload containing all seeded SLOKs for
@@ -3482,6 +3498,94 @@ func (sshClient *sshClient) clearOSLSeedPayload() {
 	sshClient.oslClientSeedState.ClearSeedPayload()
 }
 
+func (sshClient *sshClient) setDestinationBytesMetrics() {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	// Limitation: the server-side tactics cache is used to avoid the overhead
+	// of an additional tactics filtering per tunnel. As this cache is
+	// designed for GeoIP filtering only, handshake API parameters are not
+	// applied to tactics filtering in this case.
+
+	tacticsCache := sshClient.sshServer.support.ServerTacticsParametersCache
+	if tacticsCache == nil {
+		return
+	}
+
+	p, err := tacticsCache.Get(sshClient.geoIPData)
+	if err != nil {
+		log.WithTraceFields(LogFields{"error": err}).Warning("get tactics failed")
+		return
+	}
+	if p.IsNil() {
+		return
+	}
+
+	sshClient.destinationBytesMetricsASN = p.String(parameters.DestinationBytesMetricsASN)
+}
+
+func (sshClient *sshClient) newDestinationBytesMetricsUpdater(portForwardType int, IPAddress net.IP) *destinationBytesMetrics {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	if sshClient.destinationBytesMetricsASN == "" {
+		return nil
+	}
+
+	if sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN != sshClient.destinationBytesMetricsASN {
+		return nil
+	}
+
+	if portForwardType == portForwardTypeTCP {
+		return &sshClient.tcpDestinationBytesMetrics
+	}
+
+	return &sshClient.udpDestinationBytesMetrics
+}
+
+func (sshClient *sshClient) getActivityUpdaters(portForwardType int, IPAddress net.IP) []common.ActivityUpdater {
+	var updaters []common.ActivityUpdater
+
+	clientSeedPortForward := sshClient.newClientSeedPortForward(IPAddress)
+	if clientSeedPortForward != nil {
+		updaters = append(updaters, clientSeedPortForward)
+	}
+
+	destinationBytesMetrics := sshClient.newDestinationBytesMetricsUpdater(portForwardType, IPAddress)
+	if destinationBytesMetrics != nil {
+		updaters = append(updaters, destinationBytesMetrics)
+	}
+
+	return updaters
+}
+
+// setTrafficRules resets the client's traffic rules based on the latest server config
+// and client properties. As sshClient.trafficRules may be reset by a concurrent
+// goroutine, trafficRules must only be accessed within the sshClient mutex.
+func (sshClient *sshClient) setTrafficRules() (int64, int64) {
+	sshClient.Lock()
+	defer sshClient.Unlock()
+
+	isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
+		sshClient.handshakeState.establishedTunnelsCount == 0
+
+	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
+		isFirstTunnelInSession,
+		sshClient.tunnelProtocol,
+		sshClient.geoIPData,
+		sshClient.handshakeState)
+
+	if sshClient.throttledConn != nil {
+		// Any existing throttling state is reset.
+		sshClient.throttledConn.SetLimits(
+			sshClient.trafficRules.RateLimits.CommonRateLimits(
+				sshClient.handshakeState.completed))
+	}
+
+	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
+		*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
+}
+
 func (sshClient *sshClient) rateLimits() common.RateLimits {
 	sshClient.Lock()
 	defer sshClient.Unlock()
@@ -4165,19 +4269,12 @@ func (sshClient *sshClient) handleTCPChannel(
 	// forward if both reads and writes have been idle for the specified
 	// duration.
 
-	// Ensure nil interface if newClientSeedPortForward returns nil
-	var updater common.ActivityUpdater
-	seedUpdater := sshClient.newClientSeedPortForward(IP)
-	if seedUpdater != nil {
-		updater = seedUpdater
-	}
-
 	fwdConn, err = common.NewActivityMonitoredConn(
 		fwdConn,
 		sshClient.idleTCPPortForwardTimeout(),
 		true,
-		updater,
-		lruEntry)
+		lruEntry,
+		sshClient.getActivityUpdaters(portForwardTypeTCP, IP)...)
 	if err != nil {
 		log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
 		return

+ 6 - 7
psiphon/server/udp.go

@@ -269,19 +269,18 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			// forward if both reads and writes have been idle for the specified
 			// duration.
 
-			// Ensure nil interface if newClientSeedPortForward returns nil
-			var updater common.ActivityUpdater
-			seedUpdater := mux.sshClient.newClientSeedPortForward(dialIP)
-			if seedUpdater != nil {
-				updater = seedUpdater
+			var activityUpdaters []common.ActivityUpdater
+			// Don't incur activity monitor overhead for DNS requests
+			if !message.forwardDNS {
+				activityUpdaters = mux.sshClient.getActivityUpdaters(portForwardTypeUDP, dialIP)
 			}
 
 			conn, err := common.NewActivityMonitoredConn(
 				udpConn,
 				mux.sshClient.idleUDPPortForwardTimeout(),
 				true,
-				updater,
-				lruEntry)
+				lruEntry,
+				activityUpdaters...)
 			if err != nil {
 				lruEntry.Remove()
 				mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0)