Browse Source

Add packet flow tracking

- Implements OSL seeding for packet tunnel, with once per flow
  initialization

- Traffic rules checked only once per flow

- Also relaxed traffic rule checking in transparent DNS case

- Prerequisite for domain bytes transferred metrics, to follow
Rod Hynes 8 years ago
parent
commit
f007795e56
3 changed files with 432 additions and 103 deletions
  1. 365 97
      psiphon/common/tun/tun.go
  2. 53 5
      psiphon/common/tun/tun_test.go
  3. 14 1
      psiphon/server/tunnelServer.go

+ 365 - 97
psiphon/common/tun/tun.go

@@ -146,6 +146,7 @@ const (
 	DEFAULT_UPSTREAM_PACKET_QUEUE_SIZE   = 32768
 	DEFAULT_IDLE_SESSION_EXPIRY_SECONDS  = 300
 	ORPHAN_METRICS_CHECKPOINTER_PERIOD   = 30 * time.Minute
+	FLOW_IDLE_EXPIRY                     = 60 * time.Second
 )
 
 // ServerConfig specifies the configuration of a packet tunnel server.
@@ -331,8 +332,24 @@ func (server *Server) Stop() {
 	server.config.Logger.WithContext().Info("stopped")
 }
 
+// AllowedPortChecker is a function which returns true when it is
+// permitted to relay packets to the specified upstream IP address
+// and/or port.
 type AllowedPortChecker func(upstreamIPAddress net.IP, port int) bool
 
+// FlowActivityUpdater defines an interface for receiving updates for
+// flow activity. Values passed to UpdateProgress are bytes transferred
+// and flow duration since the previous UpdateProgress.
+type FlowActivityUpdater interface {
+	UpdateProgress(upstreamBytes, downstreamBytes int64, durationNanoseconds int64)
+}
+
+// FlowActivityUpdaterMaker is a function which returns a list of
+// appropriate updaters for a new flow to the specified upstream
+// hostname (if known -- may be ""), and IP address.
+type FlowActivityUpdaterMaker func(
+	upstreamHostname string, upstreamIPAddress net.IP) []FlowActivityUpdater
+
 // ClientConnected handles new client connections, creating or resuming
 // a session and returns with client packet handlers running.
 //
@@ -358,7 +375,8 @@ type AllowedPortChecker func(upstreamIPAddress net.IP, port int) bool
 func (server *Server) ClientConnected(
 	sessionID string,
 	transport io.ReadWriteCloser,
-	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker) error {
+	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
+	flowActivityUpdaterMaker FlowActivityUpdaterMaker) error {
 
 	// It's unusual to call both sync.WaitGroup.Add() _and_ Done() in the same
 	// goroutine. There's no other place to call Add() since ClientConnected is
@@ -421,6 +439,7 @@ func (server *Server) ClientConnected(
 			DNSResolverIPv6Addresses: append([]net.IP(nil), server.config.GetDNSResolverIPv6Addresses()...),
 			checkAllowedTCPPortFunc:  checkAllowedTCPPortFunc,
 			checkAllowedUDPPortFunc:  checkAllowedUDPPortFunc,
+			flowActivityUpdaterMaker: flowActivityUpdaterMaker,
 			downstreamPackets:        NewPacketQueue(downstreamPacketQueueSize),
 			workers:                  new(sync.WaitGroup),
 		}
@@ -936,6 +955,7 @@ type session struct {
 	// at the start of struct to ensure 64-bit alignment.
 	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
 	lastActivity             int64
+	lastFlowReapIndex        int64
 	metrics                  *packetMetrics
 	sessionID                string
 	index                    int32
@@ -949,7 +969,9 @@ type session struct {
 	originalIPv6Address      net.IP
 	checkAllowedTCPPortFunc  AllowedPortChecker
 	checkAllowedUDPPortFunc  AllowedPortChecker
+	flowActivityUpdaterMaker FlowActivityUpdaterMaker
 	downstreamPackets        *PacketQueue
+	flows                    syncmap.Map
 	workers                  *sync.WaitGroup
 	mutex                    sync.Mutex
 	channel                  *Channel
@@ -957,6 +979,31 @@ type session struct {
 	stopRunning              context.CancelFunc
 }
 
+// flowID identifies an IP traffic flow using the conventional
+// network 5-tuple. flowIDs track bidirectional flows.
+type flowID struct {
+	downstreamIPAddress net.IP
+	downstreamPort      uint16
+	upstreamIPAddress   net.IP
+	upstreamPort        uint16
+	protocol            internetProtocol
+}
+
+type flowState struct {
+	// Note: 64-bit ints used with atomic operations are placed
+	// at the start of struct to ensure 64-bit alignment.
+	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
+	lastUpstreamPacketTime   int64
+	lastDownstreamPacketTime int64
+	activityUpdaters         []FlowActivityUpdater
+}
+
+func (flowState *flowState) expired(idleExpiry time.Duration) bool {
+	now := monotime.Now()
+	return (now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastUpstreamPacketTime))) > idleExpiry) ||
+		(now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastDownstreamPacketTime))) > idleExpiry)
+}
+
 func (session *session) touch() {
 	atomic.StoreInt64(&session.lastActivity, int64(monotime.Now()))
 }
@@ -997,6 +1044,124 @@ func (session *session) getOriginalIPv6Address() net.IP {
 	return session.originalIPv6Address
 }
 
+// isTrackingFlow checks if a flow is being tracked.
+func (session *session) isTrackingFlow(ID flowID) bool {
+
+	f, ok := session.flows.Load(ID)
+	if !ok {
+		return false
+	}
+	flowState := f.(*flowState)
+
+	// Check if flow is expired but not yet reaped.
+	if flowState.expired(FLOW_IDLE_EXPIRY) {
+		session.flows.Delete(ID)
+		return false
+	}
+
+	return true
+}
+
+// startTrackingFlow starts flow tracking for the flow identified
+// by ID.
+//
+// Flow tracking is used to implement:
+// - one-time permissions checks for a flow
+// - OSLs
+// - domain bytes transferred [TODO]
+//
+// The applicationData from the first packet in the flow is
+// inspected to determine any associated hostname, using HTTP or
+// TLS payload. The session's FlowActivityUpdaterMaker is invoked
+// to determine a list of updaters to track flow activity.
+//
+// Flows are untracked after an idle expiry period. Transport
+// protocol indicators of end of flow, such as FIN or RST for TCP,
+// which may or may not appear in a flow, are not currently used.
+//
+// startTrackingFlow may be called from concurrent goroutines; if
+// the flow is already tracked, it is simply updated.
+func (session *session) startTrackingFlow(
+	ID flowID, direction packetDirection, applicationData []byte) {
+
+	now := int64(monotime.Now())
+
+	// Once every period, iterate over flows and reap expired entries.
+	reapIndex := now / int64(monotime.Time(FLOW_IDLE_EXPIRY/2))
+	previousReapIndex := atomic.LoadInt64(&session.lastFlowReapIndex)
+	if reapIndex != previousReapIndex &&
+		atomic.CompareAndSwapInt64(&session.lastFlowReapIndex, previousReapIndex, reapIndex) {
+		session.reapFlows()
+	}
+
+	var hostname string
+	if ID.protocol == internetProtocolTCP {
+		// TODO: implement
+		// hostname = common.ExtractHostnameFromTCPFlow(applicationData)
+	}
+
+	flowState := &flowState{
+		activityUpdaters: session.flowActivityUpdaterMaker(hostname, ID.upstreamIPAddress),
+	}
+
+	if direction == packetDirectionServerUpstream {
+		flowState.lastUpstreamPacketTime = now
+	} else {
+		flowState.lastDownstreamPacketTime = now
+	}
+
+	// LoadOrStore will retain any existing entry
+	session.flows.LoadOrStore(ID, flowState)
+
+	session.updateFlow(ID, direction, applicationData)
+}
+
+func (session *session) updateFlow(
+	ID flowID, direction packetDirection, applicationData []byte) {
+
+	f, ok := session.flows.Load(ID)
+	if !ok {
+		return
+	}
+	flowState := f.(*flowState)
+
+	// Note: no expired check here, since caller is assumed to
+	// have just called isTrackingFlow.
+
+	now := int64(monotime.Now())
+	var upstreamBytes, downstreamBytes, durationNanoseconds int64
+
+	if direction == packetDirectionServerUpstream {
+		upstreamBytes = int64(len(applicationData))
+		atomic.StoreInt64(&flowState.lastUpstreamPacketTime, now)
+	} else {
+		downstreamBytes = int64(len(applicationData))
+
+		// Follows common.ActivityMonitoredConn semantics, where
+		// duration is updated only for downstream activity. This
+		// is intened to produce equivalent behaviour for port
+		// forward clients (tracked with ActivityUpdaters) and
+		// packet tunnel clients (tracked with FlowActivityUpdaters).
+
+		durationNanoseconds = now - atomic.SwapInt64(&flowState.lastDownstreamPacketTime, now)
+	}
+
+	for _, updater := range flowState.activityUpdaters {
+		updater.UpdateProgress(upstreamBytes, downstreamBytes, durationNanoseconds)
+	}
+}
+
+// reapFlows removes expired idle flows.
+func (session *session) reapFlows() {
+	session.flows.Range(func(key, value interface{}) bool {
+		flowState := value.(*flowState)
+		if flowState.expired(FLOW_IDLE_EXPIRY) {
+			session.flows.Delete(key)
+		}
+		return true
+	})
+}
+
 type packetMetrics struct {
 	upstreamRejectReasons   [packetRejectReasonCount]int64
 	downstreamRejectReasons [packetRejectReasonCount]int64
@@ -1033,18 +1198,8 @@ func (metrics *packetMetrics) relayedPacket(
 	direction packetDirection,
 	version int,
 	protocol internetProtocol,
-	upstreamIPAddress net.IP,
 	packetLength int) {
 
-	// TODO: OSL integration
-	// - Update OSL up/down progress for upstreamIPAddress.
-	// - For port forwards, OSL progress tracking involves one SeedSpecs subnets
-	//   lookup per port forward; this may be too much overhead per packet; OSL
-	//   progress tracking also uses port forward duration as an input.
-	// - Can we do simple flow tracking to achieve the same (a) lookup rate,
-	//   (b) duration measurement? E.g., track flow via 4-tuple of source/dest
-	//   IP/port?
-
 	var packetsMetric, bytesMetric *int64
 
 	if direction == packetDirectionServerUpstream ||
@@ -1692,18 +1847,24 @@ func getPacketDestinationIPAddress(
 	}
 }
 
+// processPacket parses IP packets, applies relaying rules,
+// and rewrites packet elements as required. processPacket
+// returns true if a packet parses correctly, is accepted
+// by the relay rules, and is successfully rewritten.
+//
+// When a packet is rejected, processPacket returns false
+// and updates a reason in the supplied metrics.
+//
+// Rejection may result in partially rewritten packets.
 func processPacket(
 	metrics *packetMetrics,
 	session *session,
 	direction packetDirection,
 	packet []byte) bool {
 
-	// Parse and validate packets and perform either upstream
-	// or downstream rewriting.
-	// Failures may result in partially rewritten packets.
+	// Parse and validate IP packet structure
 
 	// Must have an IP version field.
-
 	if len(packet) < 1 {
 		metrics.rejectedPacket(direction, packetRejectLength)
 		return false
@@ -1712,7 +1873,6 @@ func processPacket(
 	version := packet[0] >> 4
 
 	// Must be IPv4 or IPv6.
-
 	if version != 4 && version != 6 {
 		metrics.rejectedPacket(direction, packetRejectVersion)
 		return false
@@ -1722,6 +1882,7 @@ func processPacket(
 	var sourceIPAddress, destinationIPAddress net.IP
 	var sourcePort, destinationPort uint16
 	var IPChecksum, TCPChecksum, UDPChecksum []byte
+	var applicationData []byte
 
 	if version == 4 {
 
@@ -1743,14 +1904,21 @@ func processPacket(
 		// Protocol must be TCP or UDP.
 
 		protocol = internetProtocol(packet[9])
+		dataOffset := 0
 
 		if protocol == internetProtocolTCP {
-			if len(packet) < 40 {
+			if len(packet) < 32 {
+				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
+				return false
+			}
+			dataOffset = 20 + 4*int(packet[32]>>4)
+			if len(packet) < dataOffset {
 				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
 				return false
 			}
 		} else if protocol == internetProtocolUDP {
-			if len(packet) < 28 {
+			dataOffset := 28
+			if len(packet) < dataOffset {
 				metrics.rejectedPacket(direction, packetRejectUDPProtocolLength)
 				return false
 			}
@@ -1759,6 +1927,8 @@ func processPacket(
 			return false
 		}
 
+		applicationData = packet[dataOffset:]
+
 		// Slices reference packet bytes to be rewritten.
 
 		sourceIPAddress = packet[12:16]
@@ -1788,14 +1958,21 @@ func processPacket(
 		nextHeader := packet[6]
 
 		protocol = internetProtocol(nextHeader)
+		dataOffset := 0
 
 		if protocol == internetProtocolTCP {
-			if len(packet) < 60 {
+			if len(packet) < 52 {
+				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
+				return false
+			}
+			dataOffset = 40 + 4*int(packet[52]>>4)
+			if len(packet) < dataOffset {
 				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
 				return false
 			}
 		} else if protocol == internetProtocolUDP {
-			if len(packet) < 48 {
+			dataOffset := 48
+			if len(packet) < dataOffset {
 				metrics.rejectedPacket(direction, packetRejectUDPProtocolLength)
 				return false
 			}
@@ -1804,6 +1981,8 @@ func processPacket(
 			return false
 		}
 
+		applicationData = packet[dataOffset:]
+
 		// Slices reference packet bytes to be rewritten.
 
 		sourceIPAddress = packet[8:24]
@@ -1821,66 +2000,170 @@ func processPacket(
 		}
 	}
 
-	var upstreamIPAddress net.IP
-	if direction == packetDirectionServerUpstream {
+	// Apply rules
+	//
+	// Most of this logic is only applied on the server, as only
+	// the server knows the traffic rules configuration, and is
+	// tracking flows.
 
-		upstreamIPAddress = destinationIPAddress
+	isServer := (direction == packetDirectionServerUpstream ||
+		direction == packetDirectionServerDownstream)
 
-	} else if direction == packetDirectionServerDownstream {
+	// Check if the packet qualifies for transparent DNS rewriting
+	//
+	// - Both TCP and UDP DNS packets may qualify
+	// - Transparent DNS flows are not tracked, as most DNS
+	//   resolutions are very-short lived exchanges
+	// - The traffic rules checks are bypassed, since transparent
+	//   DNS is essential
+
+	doTransparentDNS := false
+
+	if isServer {
+		if direction == packetDirectionServerUpstream {
+
+			// DNS packets destinated for the transparent DNS target addresses
+			// will be rewritten to go to one of the server's resolvers.
+
+			if destinationPort == portNumberDNS {
+				if version == 4 && destinationIPAddress.Equal(transparentDNSResolverIPv4Address) {
+					numResolvers := len(session.DNSResolverIPv4Addresses)
+					if numResolvers > 0 {
+						doTransparentDNS = true
+					} else {
+						metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
+						return false
+					}
+
+				} else if version == 6 && destinationIPAddress.Equal(transparentDNSResolverIPv6Address) {
+					numResolvers := len(session.DNSResolverIPv6Addresses)
+					if numResolvers > 0 {
+						doTransparentDNS = true
+					} else {
+						metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
+						return false
+					}
+				}
+			}
 
-		upstreamIPAddress = sourceIPAddress
+		} else { // packetDirectionServerDownstream
+
+			// DNS packets with a source address of any of the server's
+			// resolvers will be rewritten back to the transparent DNS target
+			// address.
+
+			// Limitation: responses to client DNS packets _originally
+			// destined_ for a resolver in GetDNSResolverIPv4Addresses will
+			// be lost. This would happen if some process on the client
+			// ignores the system set DNS values; and forces use of the same
+			// resolvers as the server.
+
+			if sourcePort == portNumberDNS {
+				if version == 4 {
+					for _, IPAddress := range session.DNSResolverIPv4Addresses {
+						if sourceIPAddress.Equal(IPAddress) {
+							doTransparentDNS = true
+							break
+						}
+					}
+				} else if version == 6 {
+					for _, IPAddress := range session.DNSResolverIPv6Addresses {
+						if sourceIPAddress.Equal(IPAddress) {
+							doTransparentDNS = true
+							break
+						}
+					}
+				}
+			}
+		}
 	}
 
-	// Enforce traffic rules (allowed TCP/UDP ports).
+	// Check if flow is tracked before checking traffic permission
 
-	checkPort := 0
-	if direction == packetDirectionServerUpstream ||
-		direction == packetDirectionClientUpstream {
+	doFlowTracking := !doTransparentDNS && isServer
+
+	// TODO: verify this struct is stack allocated
+	var flowID flowID
+
+	isTrackingFlow := false
+
+	if doFlowTracking {
+
+		flowID.protocol = protocol
+
+		if direction == packetDirectionServerUpstream {
+
+			flowID.upstreamIPAddress = destinationIPAddress
+			flowID.upstreamPort = destinationPort
+			flowID.downstreamIPAddress = sourceIPAddress
+			flowID.downstreamPort = sourcePort
 
-		checkPort = int(destinationPort)
+		} else if direction == packetDirectionServerDownstream {
 
-	} else if direction == packetDirectionServerDownstream ||
-		direction == packetDirectionClientDownstream {
+			flowID.upstreamIPAddress = sourceIPAddress
+			flowID.upstreamPort = sourcePort
+			flowID.downstreamIPAddress = destinationIPAddress
+			flowID.downstreamPort = destinationPort
+		}
 
-		checkPort = int(sourcePort)
+		isTrackingFlow = session.isTrackingFlow(flowID)
 	}
 
-	if protocol == internetProtocolTCP {
+	// Check packet source/destination is permitted; except for:
+	// - existing flows, which have already been checked
+	// - transparent DNS, which is always allowed
 
-		if checkPort == 0 ||
-			(session != nil &&
-				!session.checkAllowedTCPPortFunc(upstreamIPAddress, checkPort)) {
+	if !doTransparentDNS && !isTrackingFlow {
 
-			metrics.rejectedPacket(direction, packetRejectTCPPort)
-			return false
+		// Enforce traffic rules (allowed TCP/UDP ports).
+
+		checkPort := 0
+		if direction == packetDirectionServerUpstream ||
+			direction == packetDirectionClientUpstream {
+
+			checkPort = int(destinationPort)
+
+		} else if direction == packetDirectionServerDownstream ||
+			direction == packetDirectionClientDownstream {
+
+			checkPort = int(sourcePort)
 		}
 
-	} else if protocol == internetProtocolUDP {
+		if protocol == internetProtocolTCP {
 
-		if checkPort == 0 ||
-			(session != nil &&
-				!session.checkAllowedUDPPortFunc(upstreamIPAddress, checkPort)) {
+			if checkPort == 0 ||
+				(isServer && !session.checkAllowedTCPPortFunc(flowID.upstreamIPAddress, checkPort)) {
 
-			metrics.rejectedPacket(direction, packetRejectUDPPort)
-			return false
+				metrics.rejectedPacket(direction, packetRejectTCPPort)
+				return false
+			}
+
+		} else if protocol == internetProtocolUDP {
+
+			if checkPort == 0 ||
+				(isServer && !session.checkAllowedUDPPortFunc(flowID.upstreamIPAddress, checkPort)) {
+
+				metrics.rejectedPacket(direction, packetRejectUDPPort)
+				return false
+			}
 		}
-	}
 
-	// Enforce no localhost, multicast or broadcast packets; and
-	// no client-to-client packets.
+		// Enforce no localhost, multicast or broadcast packets; and
+		// no client-to-client packets.
 
-	if !destinationIPAddress.IsGlobalUnicast() ||
+		if !destinationIPAddress.IsGlobalUnicast() ||
 
-		(direction == packetDirectionServerUpstream &&
-			((version == 4 &&
-				!destinationIPAddress.Equal(transparentDNSResolverIPv4Address) &&
-				privateSubnetIPv4.Contains(destinationIPAddress)) ||
-				(version == 6 &&
-					!destinationIPAddress.Equal(transparentDNSResolverIPv6Address) &&
-					privateSubnetIPv6.Contains(destinationIPAddress)))) {
+			(direction == packetDirectionServerUpstream &&
+				((version == 4 &&
+					!destinationIPAddress.Equal(transparentDNSResolverIPv4Address) &&
+					privateSubnetIPv4.Contains(destinationIPAddress)) ||
+					(version == 6 &&
+						!destinationIPAddress.Equal(transparentDNSResolverIPv6Address) &&
+						privateSubnetIPv6.Contains(destinationIPAddress)))) {
 
-		metrics.rejectedPacket(direction, packetRejectDestinationAddress)
-		return false
+			metrics.rejectedPacket(direction, packetRejectDestinationAddress)
+			return false
+		}
 	}
 
 	// Configure rewriting.
@@ -1904,24 +2187,14 @@ func processPacket(
 		// Rewrite DNS packets destinated for the transparent DNS target
 		// addresses to go to one of the server's resolvers.
 
-		if destinationPort == portNumberDNS {
-			if version == 4 && destinationIPAddress.Equal(transparentDNSResolverIPv4Address) {
-				numResolvers := len(session.DNSResolverIPv4Addresses)
-				if numResolvers > 0 {
-					rewriteDestinationIPAddress = session.DNSResolverIPv4Addresses[rand.Intn(numResolvers)]
-				} else {
-					metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
-					return false
-				}
+		if doTransparentDNS {
 
-			} else if version == 6 && destinationIPAddress.Equal(transparentDNSResolverIPv6Address) {
-				numResolvers := len(session.DNSResolverIPv6Addresses)
-				if numResolvers > 0 {
-					rewriteDestinationIPAddress = session.DNSResolverIPv6Addresses[rand.Intn(numResolvers)]
-				} else {
-					metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
-					return false
-				}
+			if version == 4 {
+				rewriteDestinationIPAddress = session.DNSResolverIPv4Addresses[rand.Intn(
+					len(session.DNSResolverIPv4Addresses))]
+			} else { // version == 6
+				rewriteDestinationIPAddress = session.DNSResolverIPv6Addresses[rand.Intn(
+					len(session.DNSResolverIPv6Addresses))]
 			}
 		}
 
@@ -1940,30 +2213,15 @@ func processPacket(
 			return false
 		}
 
-		// Source address for DNS packets from the server's resolvers
-		// will be changed to transparent DNS target address.
+		// Rewrite source address  of packets from servers' resolvers
+		// to transparent DNS target address.
 
-		// Limitation: responses to client DNS packets _originally
-		// destined_ for a resolver in GetDNSResolverIPv4Addresses will
-		// be lost. This would happen if some process on the client
-		// ignores the system set DNS values; and forces use of the same
-		// resolvers as the server.
+		if doTransparentDNS {
 
-		if sourcePort == portNumberDNS {
 			if version == 4 {
-				for _, IPAddress := range session.DNSResolverIPv4Addresses {
-					if sourceIPAddress.Equal(IPAddress) {
-						rewriteSourceIPAddress = transparentDNSResolverIPv4Address
-						break
-					}
-				}
-			} else if version == 6 {
-				for _, IPAddress := range session.DNSResolverIPv6Addresses {
-					if sourceIPAddress.Equal(IPAddress) {
-						rewriteSourceIPAddress = transparentDNSResolverIPv6Address
-						break
-					}
-				}
+				rewriteSourceIPAddress = transparentDNSResolverIPv4Address
+			} else { // version == 6
+				rewriteSourceIPAddress = transparentDNSResolverIPv6Address
 			}
 		}
 	}
@@ -1997,7 +2255,17 @@ func processPacket(
 		}
 	}
 
-	metrics.relayedPacket(direction, int(version), protocol, upstreamIPAddress, len(packet))
+	// Start/update flow tracking, only once past all possible packet rejects
+
+	if doFlowTracking {
+		if !isTrackingFlow {
+			session.startTrackingFlow(flowID, direction, applicationData)
+		} else {
+			session.updateFlow(flowID, direction, applicationData)
+		}
+	}
+
+	metrics.relayedPacket(direction, int(version), protocol, len(packet))
 
 	return true
 }

+ 53 - 5
psiphon/common/tun/tun_test.go

@@ -29,6 +29,7 @@ import (
 	"os"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"testing"
 	"time"
@@ -82,6 +83,8 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 	// - each TCP client transfers TCP_RELAY_TOTAL_SIZE bytes to the TCP server
 	// - the test checks that all data echoes back correctly and that the server packet
 	//   metrics reflects the expected amount of data transferred through the tunnel
+	// - the test also checks that the flow activity updater mechanism correctly reports
+	//   the total bytes transferred
 	// - this test runs in either IPv4 or IPv6 mode
 	// - the test host's public IP address is used as the TCP server IP address; it is
 	//   expected that the server tun device will NAT to the public interface; clients
@@ -91,6 +94,10 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 	// Note: this test can modify host network configuration; in addition to tun device
 	// and routing config, see the changes made in fixBindToDevice.
 
+	if TCP_RELAY_TOTAL_SIZE%TCP_RELAY_CHUNK_SIZE != 0 {
+		t.Fatalf("startTestTCPServer failed: invalid relay size")
+	}
+
 	MTU := DEFAULT_MTU
 
 	testTCPServer, err := startTestTCPServer(useIPv6)
@@ -101,7 +108,12 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 		t.Fatalf("startTestTCPServer failed: %s", err)
 	}
 
-	testServer, err := startTestServer(useIPv6, MTU)
+	var counter bytesTransferredCounter
+	flowActivityUpdaterMaker := func(_ string, _ net.IP) []FlowActivityUpdater {
+		return []FlowActivityUpdater{&counter}
+	}
+
+	testServer, err := startTestServer(useIPv6, MTU, flowActivityUpdaterMaker)
 	if err != nil {
 		t.Fatalf("startTestServer failed: %s", err)
 	}
@@ -139,8 +151,6 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 
 			sendChunk, receiveChunk := make([]byte, TCP_RELAY_CHUNK_SIZE), make([]byte, TCP_RELAY_CHUNK_SIZE)
 
-			// Note: data transfer doesn't have to be exactly TCP_RELAY_TOTAL_SIZE,
-			// so not handling TCP_RELAY_TOTAL_SIZE%TCP_RELAY_CHUNK_SIZE != 0.
 			for i := int64(0); i < TCP_RELAY_TOTAL_SIZE; i += TCP_RELAY_CHUNK_SIZE {
 
 				_, err := rand.Read(sendChunk)
@@ -227,20 +237,56 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 		}
 	}
 
+	upstreamBytesTransferred, downstreamBytesTransferred, _ := counter.Get()
+	expectedBytesTransferred := CONCURRENT_CLIENT_COUNT * TCP_RELAY_TOTAL_SIZE
+	if upstreamBytesTransferred != expectedBytesTransferred {
+		t.Fatalf("unexpected upstreamBytesTransferred: %d: %d",
+			upstreamBytesTransferred, expectedBytesTransferred)
+	}
+	if downstreamBytesTransferred != expectedBytesTransferred {
+		t.Fatalf("unexpected downstreamBytesTransferred: %d: %d",
+			downstreamBytesTransferred, expectedBytesTransferred)
+	}
+
 	testServer.stop()
 
 	testTCPServer.stop()
 }
 
+type bytesTransferredCounter struct {
+	// Note: 64-bit ints used with atomic operations are placed
+	// at the start of struct to ensure 64-bit alignment.
+	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
+	upstreamBytes       int64
+	downstreamBytes     int64
+	durationNanoseconds int64
+}
+
+func (counter *bytesTransferredCounter) UpdateProgress(
+	upstreamBytes, downstreamBytes int64, durationNanoseconds int64) {
+
+	atomic.AddInt64(&counter.upstreamBytes, upstreamBytes)
+	atomic.AddInt64(&counter.downstreamBytes, downstreamBytes)
+	atomic.AddInt64(&counter.durationNanoseconds, durationNanoseconds)
+}
+
+func (counter *bytesTransferredCounter) Get() (int64, int64, int64) {
+	return atomic.LoadInt64(&counter.upstreamBytes),
+		atomic.LoadInt64(&counter.downstreamBytes),
+		atomic.LoadInt64(&counter.durationNanoseconds)
+}
+
 type testServer struct {
 	logger       *testLogger
+	updaterMaker FlowActivityUpdaterMaker
 	tunServer    *Server
 	unixListener net.Listener
 	clientConns  *common.Conns
 	workers      *sync.WaitGroup
 }
 
-func startTestServer(useIPv6 bool, MTU int) (*testServer, error) {
+func startTestServer(
+	useIPv6 bool, MTU int, updaterMaker FlowActivityUpdaterMaker) (*testServer, error) {
 
 	logger := newTestLogger(true)
 
@@ -271,6 +317,7 @@ func startTestServer(useIPv6 bool, MTU int) (*testServer, error) {
 
 	server := &testServer{
 		logger:       logger,
+		updaterMaker: updaterMaker,
 		tunServer:    tunServer,
 		unixListener: unixListener,
 		clientConns:  new(common.Conns),
@@ -316,7 +363,8 @@ func (server *testServer) run() {
 				sessionID,
 				signalConn,
 				checkAllowedPortFunc,
-				checkAllowedPortFunc)
+				checkAllowedPortFunc,
+				server.updaterMaker)
 
 			signalConn.Wait()
 

+ 14 - 1
psiphon/server/tunnelServer.go

@@ -39,6 +39,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 )
 
 const (
@@ -1322,11 +1323,23 @@ func (sshClient *sshClient) runTunnel(
 				return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
 			}
 
+			flowActivityUpdaterMaker := func(
+				upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
+
+				var updaters []tun.FlowActivityUpdater
+				oslUpdater := sshClient.newClientSeedPortForward(upstreamIPAddress)
+				if oslUpdater != nil {
+					updaters = append(updaters, oslUpdater)
+				}
+				return updaters
+			}
+
 			sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
 				sshClient.sessionID,
 				packetTunnelChannel,
 				checkAllowedTCPPortFunc,
-				checkAllowedUDPPortFunc)
+				checkAllowedUDPPortFunc,
+				flowActivityUpdaterMaker)
 		}
 
 		if newChannel.ChannelType() != "direct-tcpip" {