Просмотр исходного кода

Add packet tunnel bytes transferred to server_tunnel

- Add application bytes transferred to packet_metrics
  log.

- Add packet tunnel application bytes transferred to
  server_tunnel log.

- Fix: access control and flow updater callbacks are
  reset when client reconnects.
Rod Hynes 8 лет назад
Родитель
Сommit
bb5b39a0c5
3 измененных файлов с 309 добавлено и 107 удалено
  1. 248 86
      psiphon/common/tun/tun.go
  2. 47 20
      psiphon/common/tun/tun_test.go
  3. 14 1
      psiphon/server/tunnelServer.go

+ 248 - 86
psiphon/common/tun/tun.go

@@ -342,6 +342,12 @@ type FlowActivityUpdater interface {
 type FlowActivityUpdaterMaker func(
 	upstreamHostname string, upstreamIPAddress net.IP) []FlowActivityUpdater
 
+// MetricsUpdater is a function which receives a checkpoint summary
+// of application bytes transferred through a packet tunnel.
+type MetricsUpdater func(
+	TCPApplicationBytesUp, TCPApplicationBytesDown,
+	UDPApplicationBytesUp, UDPApplicationBytesDown int64)
+
 // ClientConnected handles new client connections, creating or resuming
 // a session and returns with client packet handlers running.
 //
@@ -356,6 +362,13 @@ type FlowActivityUpdaterMaker func(
 // permitted. These callbacks must be efficient and safe for concurrent
 // calls.
 //
+// flowActivityUpdaterMaker is a callback invoked for each new packet
+// flow; it may create updaters to track flow activity.
+//
+// metricsUpdater is a callback invoked at metrics checkpoints (usually
+// when the client disconnects) with a summary of application bytes
+// transferred.
+//
 // It is safe to make concurrent calls to ClientConnected for distinct
 // session IDs. The caller is responsible for serializing calls with the
 // same session ID. Further, the caller must ensure, in the case of a client
@@ -368,7 +381,8 @@ func (server *Server) ClientConnected(
 	sessionID string,
 	transport io.ReadWriteCloser,
 	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
-	flowActivityUpdaterMaker FlowActivityUpdaterMaker) error {
+	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
+	metricsUpdater MetricsUpdater) error {
 
 	// It's unusual to call both sync.WaitGroup.Add() _and_ Done() in the same
 	// goroutine. There's no other place to call Add() since ClientConnected is
@@ -427,9 +441,6 @@ func (server *Server) ClientConnected(
 			metrics:                  new(packetMetrics),
 			DNSResolverIPv4Addresses: append([]net.IP(nil), DNSResolverIPv4Addresses...),
 			DNSResolverIPv6Addresses: append([]net.IP(nil), server.config.GetDNSResolverIPv6Addresses()...),
-			checkAllowedTCPPortFunc:  checkAllowedTCPPortFunc,
-			checkAllowedUDPPortFunc:  checkAllowedUDPPortFunc,
-			flowActivityUpdaterMaker: flowActivityUpdaterMaker,
 			workers:                  new(sync.WaitGroup),
 		}
 
@@ -448,7 +459,13 @@ func (server *Server) ClientConnected(
 	// allocateIndex and resumeSession calls here, so interruptSession and
 	// related code must not assume resumeSession has been called.
 
-	server.resumeSession(clientSession, NewChannel(transport, MTU))
+	server.resumeSession(
+		clientSession,
+		NewChannel(transport, MTU),
+		checkAllowedTCPPortFunc,
+		checkAllowedUDPPortFunc,
+		flowActivityUpdaterMaker,
+		metricsUpdater)
 
 	return nil
 }
@@ -481,17 +498,50 @@ func (server *Server) getSession(sessionID string) *session {
 	return nil
 }
 
-func (server *Server) resumeSession(session *session, channel *Channel) {
+func (server *Server) resumeSession(
+	session *session,
+	channel *Channel,
+	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
+	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
+	metricsUpdater MetricsUpdater) {
 
 	session.mutex.Lock()
 	defer session.mutex.Unlock()
 
+	// Performance/concurrency note: the downstream packet queue
+	// and various packet event callbacks may be accessed while
+	// the session is idle, via the runDeviceDownstream goroutine,
+	// which runs concurrent to resumeSession/interruptSession calls.
+	// Consequently, all accesses to these fields must be
+	// synchronized.
+	//
+	// Benchmarking indicates the atomic.LoadPointer mechanism
+	// outperforms a mutex; approx. 2 ns/op vs. 20 ns/op in the case
+	// of getCheckAllowedTCPPortFunc. Since these accesses occur
+	// multiple times per packet, atomic.LoadPointer is used and so
+	// each of these fields is an unsafe.Pointer in the session
+	// struct.
+
+	// Begin buffering downstream packets.
+
 	downstreamPacketQueueSize := DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE
 	if server.config.DownstreamPacketQueueSize > 0 {
 		downstreamPacketQueueSize = server.config.DownstreamPacketQueueSize
 	}
 	downstreamPackets := NewPacketQueue(downstreamPacketQueueSize)
-	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(downstreamPackets))
+
+	session.setDownstreamPackets(downstreamPackets)
+
+	// Set new access control, flow monitoring, and metrics
+	// callbacks; all associated with the new client connection.
+
+	session.setCheckAllowedTCPPortFunc(checkAllowedTCPPortFunc)
+
+	session.setCheckAllowedUDPPortFunc(checkAllowedUDPPortFunc)
+
+	session.setFlowActivityUpdaterMaker(flowActivityUpdaterMaker)
+
+	session.setMetricsUpdater(metricsUpdater)
 
 	session.channel = channel
 
@@ -540,22 +590,34 @@ func (server *Server) interruptSession(session *session) {
 		session.channel = nil
 	}
 
-	// Release the downstream packet buffer, so the associated
-	// memory is not consumed while no client is connected.
-	//
-	// Since runDeviceDownstream continues to run and will access
-	// session.downstreamPackets, an atomic pointer is used to
-	// synchronize access.
-	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(nil))
+	metricsUpdater := session.getMetricsUpdater()
 
 	// interruptSession may be called for idle sessions, to ensure
 	// the session is in an expected state: in ClientConnected,
 	// and in server.Stop(); don't log in those cases.
 	if wasRunning {
 		session.metrics.checkpoint(
-			server.config.Logger, "packet_metrics", packetMetricsAll)
+			server.config.Logger,
+			metricsUpdater,
+			"packet_metrics",
+			packetMetricsAll)
 	}
 
+	// Release the downstream packet buffer, so the associated
+	// memory is not consumed while no client is connected.
+	//
+	// Since runDeviceDownstream continues to run and will access
+	// session.downstreamPackets, an atomic pointer is used to
+	// synchronize access.
+	session.setDownstreamPackets(nil)
+
+	session.setCheckAllowedTCPPortFunc(nil)
+
+	session.setCheckAllowedUDPPortFunc(nil)
+
+	session.setFlowActivityUpdaterMaker(nil)
+
+	session.setMetricsUpdater(nil)
 }
 
 func (server *Server) runSessionReaper() {
@@ -630,7 +692,7 @@ func (server *Server) runOrphanMetricsCheckpointer() {
 
 		// TODO: skip log if all zeros?
 		server.orphanMetrics.checkpoint(
-			server.config.Logger, "orphan_packet_metrics", packetMetricsRejected)
+			server.config.Logger, nil, "orphan_packet_metrics", packetMetricsRejected)
 		if done {
 			return
 		}
@@ -694,7 +756,7 @@ func (server *Server) runDeviceDownstream() {
 
 		session := s.(*session)
 
-		downstreamPackets := (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
+		downstreamPackets := session.getDownstreamPackets()
 
 		// No downstreamPackets buffer is maintained when no client is
 		// connected, so the packet is dropped.
@@ -811,8 +873,7 @@ func (server *Server) runClientDownstream(session *session) {
 
 	for {
 
-		downstreamPackets := (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
-
+		downstreamPackets := session.getDownstreamPackets()
 		// Note: downstreamPackets will not be nil, since this goroutine only
 		// runs while the session has a connected client.
 
@@ -1002,9 +1063,10 @@ type session struct {
 	assignedIPv6Address      net.IP
 	setOriginalIPv6Address   int32
 	originalIPv6Address      net.IP
-	checkAllowedTCPPortFunc  AllowedPortChecker
-	checkAllowedUDPPortFunc  AllowedPortChecker
-	flowActivityUpdaterMaker FlowActivityUpdaterMaker
+	checkAllowedTCPPortFunc  unsafe.Pointer
+	checkAllowedUDPPortFunc  unsafe.Pointer
+	flowActivityUpdaterMaker unsafe.Pointer
+	metricsUpdater           unsafe.Pointer
 	downstreamPackets        unsafe.Pointer
 	flows                    sync.Map
 	workers                  *sync.WaitGroup
@@ -1014,6 +1076,106 @@ type session struct {
 	stopRunning              context.CancelFunc
 }
 
+func (session *session) touch() {
+	atomic.StoreInt64(&session.lastActivity, int64(monotime.Now()))
+}
+
+func (session *session) expired(idleExpiry time.Duration) bool {
+	lastActivity := monotime.Time(atomic.LoadInt64(&session.lastActivity))
+	return monotime.Since(lastActivity) > idleExpiry
+}
+
+func (session *session) setOriginalIPv4AddressIfNotSet(IPAddress net.IP) {
+	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv4Address, 0, 1) {
+		return
+	}
+	// Make a copy of IPAddress; don't reference a slice of a reusable
+	// packet buffer, which will be overwritten.
+	session.originalIPv4Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
+}
+
+func (session *session) getOriginalIPv4Address() net.IP {
+	if atomic.LoadInt32(&session.setOriginalIPv4Address) == 0 {
+		return nil
+	}
+	return session.originalIPv4Address
+}
+
+func (session *session) setOriginalIPv6AddressIfNotSet(IPAddress net.IP) {
+	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv6Address, 0, 1) {
+		return
+	}
+	// Make a copy of IPAddress.
+	session.originalIPv6Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
+}
+
+func (session *session) getOriginalIPv6Address() net.IP {
+	if atomic.LoadInt32(&session.setOriginalIPv6Address) == 0 {
+		return nil
+	}
+	return session.originalIPv6Address
+}
+
+func (session *session) setCheckAllowedTCPPortFunc(f AllowedPortChecker) {
+	g := f
+	atomic.StorePointer(&session.checkAllowedTCPPortFunc, unsafe.Pointer(&g))
+}
+
+func (session *session) getCheckAllowedTCPPortFunc() AllowedPortChecker {
+	f := (*AllowedPortChecker)(atomic.LoadPointer(&session.checkAllowedTCPPortFunc))
+	if f == nil {
+		return nil
+	}
+	return *f
+}
+
+func (session *session) setCheckAllowedUDPPortFunc(f AllowedPortChecker) {
+	g := f
+	atomic.StorePointer(&session.checkAllowedUDPPortFunc, unsafe.Pointer(&g))
+}
+
+func (session *session) getCheckAllowedUDPPortFunc() AllowedPortChecker {
+	f := (*AllowedPortChecker)(atomic.LoadPointer(&session.checkAllowedUDPPortFunc))
+	if f == nil {
+		return nil
+	}
+	return *f
+}
+
+func (session *session) setFlowActivityUpdaterMaker(f FlowActivityUpdaterMaker) {
+	g := f
+	atomic.StorePointer(&session.flowActivityUpdaterMaker, unsafe.Pointer(&g))
+}
+
+func (session *session) getFlowActivityUpdaterMaker() FlowActivityUpdaterMaker {
+	f := (*FlowActivityUpdaterMaker)(atomic.LoadPointer(&session.flowActivityUpdaterMaker))
+	if f == nil {
+		return nil
+	}
+	return *f
+}
+
+func (session *session) setMetricsUpdater(f MetricsUpdater) {
+	g := f
+	atomic.StorePointer(&session.flowActivityUpdaterMaker, unsafe.Pointer(&g))
+}
+
+func (session *session) getMetricsUpdater() MetricsUpdater {
+	f := (*MetricsUpdater)(atomic.LoadPointer(&session.metricsUpdater))
+	if f == nil {
+		return nil
+	}
+	return *f
+}
+
+func (session *session) setDownstreamPackets(p *PacketQueue) {
+	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(p))
+}
+
+func (session *session) getDownstreamPackets() *PacketQueue {
+	return (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
+}
+
 // flowID identifies an IP traffic flow using the conventional
 // network 5-tuple. flowIDs track bidirectional flows.
 type flowID struct {
@@ -1068,46 +1230,6 @@ func (flowState *flowState) expired(idleExpiry time.Duration) bool {
 		(now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastDownstreamPacketTime))) > idleExpiry)
 }
 
-func (session *session) touch() {
-	atomic.StoreInt64(&session.lastActivity, int64(monotime.Now()))
-}
-
-func (session *session) expired(idleExpiry time.Duration) bool {
-	lastActivity := monotime.Time(atomic.LoadInt64(&session.lastActivity))
-	return monotime.Since(lastActivity) > idleExpiry
-}
-
-func (session *session) setOriginalIPv4AddressIfNotSet(IPAddress net.IP) {
-	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv4Address, 0, 1) {
-		return
-	}
-	// Make a copy of IPAddress; don't reference a slice of a reusable
-	// packet buffer, which will be overwritten.
-	session.originalIPv4Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
-}
-
-func (session *session) getOriginalIPv4Address() net.IP {
-	if atomic.LoadInt32(&session.setOriginalIPv4Address) == 0 {
-		return nil
-	}
-	return session.originalIPv4Address
-}
-
-func (session *session) setOriginalIPv6AddressIfNotSet(IPAddress net.IP) {
-	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv6Address, 0, 1) {
-		return
-	}
-	// Make a copy of IPAddress.
-	session.originalIPv6Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
-}
-
-func (session *session) getOriginalIPv6Address() net.IP {
-	if atomic.LoadInt32(&session.setOriginalIPv6Address) == 0 {
-		return nil
-	}
-	return session.originalIPv6Address
-}
-
 // isTrackingFlow checks if a flow is being tracked.
 func (session *session) isTrackingFlow(ID flowID) bool {
 
@@ -1171,10 +1293,17 @@ func (session *session) startTrackingFlow(
 		// hostname = common.ExtractHostnameFromTCPFlow(applicationData)
 	}
 
-	flowState := &flowState{
-		activityUpdaters: session.flowActivityUpdaterMaker(
+	var activityUpdaters []FlowActivityUpdater
+
+	flowActivityUpdaterMaker := session.getFlowActivityUpdaterMaker()
+	if flowActivityUpdaterMaker != nil {
+		activityUpdaters = flowActivityUpdaterMaker(
 			hostname,
-			net.IP(ID.upstreamIPAddress[:])),
+			net.IP(ID.upstreamIPAddress[:]))
+	}
+
+	flowState := &flowState{
+		activityUpdaters: activityUpdaters,
 	}
 
 	if direction == packetDirectionServerUpstream {
@@ -1245,10 +1374,12 @@ type packetMetrics struct {
 }
 
 type relayedPacketMetrics struct {
-	packetsUp   int64
-	packetsDown int64
-	bytesUp     int64
-	bytesDown   int64
+	packetsUp            int64
+	packetsDown          int64
+	bytesUp              int64
+	bytesDown            int64
+	applicationBytesUp   int64
+	applicationBytesDown int64
 }
 
 func (metrics *packetMetrics) rejectedPacket(
@@ -1271,9 +1402,9 @@ func (metrics *packetMetrics) relayedPacket(
 	direction packetDirection,
 	version int,
 	protocol internetProtocol,
-	packetLength int) {
+	packetLength, applicationDataLength int) {
 
-	var packetsMetric, bytesMetric *int64
+	var packetsMetric, bytesMetric, applicationBytesMetric *int64
 
 	if direction == packetDirectionServerUpstream ||
 		direction == packetDirectionClientUpstream {
@@ -1283,9 +1414,11 @@ func (metrics *packetMetrics) relayedPacket(
 			if protocol == internetProtocolTCP {
 				packetsMetric = &metrics.TCPIPv4.packetsUp
 				bytesMetric = &metrics.TCPIPv4.bytesUp
+				applicationBytesMetric = &metrics.TCPIPv4.applicationBytesUp
 			} else { // UDP
 				packetsMetric = &metrics.UDPIPv4.packetsUp
 				bytesMetric = &metrics.UDPIPv4.bytesUp
+				applicationBytesMetric = &metrics.UDPIPv4.applicationBytesUp
 			}
 
 		} else { // IPv6
@@ -1293,9 +1426,11 @@ func (metrics *packetMetrics) relayedPacket(
 			if protocol == internetProtocolTCP {
 				packetsMetric = &metrics.TCPIPv6.packetsUp
 				bytesMetric = &metrics.TCPIPv6.bytesUp
+				applicationBytesMetric = &metrics.TCPIPv6.applicationBytesUp
 			} else { // UDP
 				packetsMetric = &metrics.UDPIPv6.packetsUp
 				bytesMetric = &metrics.UDPIPv6.bytesUp
+				applicationBytesMetric = &metrics.UDPIPv6.applicationBytesUp
 			}
 		}
 
@@ -1306,9 +1441,11 @@ func (metrics *packetMetrics) relayedPacket(
 			if protocol == internetProtocolTCP {
 				packetsMetric = &metrics.TCPIPv4.packetsDown
 				bytesMetric = &metrics.TCPIPv4.bytesDown
+				applicationBytesMetric = &metrics.TCPIPv4.applicationBytesDown
 			} else { // UDP
 				packetsMetric = &metrics.UDPIPv4.packetsDown
 				bytesMetric = &metrics.UDPIPv4.bytesDown
+				applicationBytesMetric = &metrics.UDPIPv4.applicationBytesDown
 			}
 
 		} else { // IPv6
@@ -1316,19 +1453,18 @@ func (metrics *packetMetrics) relayedPacket(
 			if protocol == internetProtocolTCP {
 				packetsMetric = &metrics.TCPIPv6.packetsDown
 				bytesMetric = &metrics.TCPIPv6.bytesDown
+				applicationBytesMetric = &metrics.TCPIPv6.applicationBytesDown
 			} else { // UDP
 				packetsMetric = &metrics.UDPIPv6.packetsDown
 				bytesMetric = &metrics.UDPIPv6.bytesDown
+				applicationBytesMetric = &metrics.UDPIPv6.applicationBytesDown
 			}
 		}
 	}
 
-	// Note: packet length, and so bytes transferred, includes IP and TCP/UDP
-	// headers, not just payload data, as is counted in port forwarding. It
-	// makes sense to include this packet overhead, since we have to tunnel it.
-
 	atomic.AddInt64(packetsMetric, 1)
 	atomic.AddInt64(bytesMetric, int64(packetLength))
+	atomic.AddInt64(applicationBytesMetric, int64(applicationDataLength))
 }
 
 const (
@@ -1338,7 +1474,7 @@ const (
 )
 
 func (metrics *packetMetrics) checkpoint(
-	logger common.Logger, logName string, whichMetrics int) {
+	logger common.Logger, updater MetricsUpdater, logName string, whichMetrics int) {
 
 	// Report all metric counters in a single log message. Each
 	// counter is reset to 0 when added to the log.
@@ -1357,21 +1493,41 @@ func (metrics *packetMetrics) checkpoint(
 
 	if whichMetrics&packetMetricsRelayed != 0 {
 
+		var TCPApplicationBytesUp, TCPApplicationBytesDown,
+			UDPApplicationBytesUp, UDPApplicationBytesDown int64
+
 		relayedMetrics := []struct {
-			prefix  string
-			metrics *relayedPacketMetrics
+			prefix           string
+			metrics          *relayedPacketMetrics
+			updaterBytesUp   *int64
+			updaterBytesDown *int64
 		}{
-			{"tcp_ipv4_", &metrics.TCPIPv4},
-			{"tcp_ipv6_", &metrics.TCPIPv6},
-			{"udp_ipv4_", &metrics.UDPIPv4},
-			{"udp_ipv6_", &metrics.UDPIPv6},
+			{"tcp_ipv4_", &metrics.TCPIPv4, &TCPApplicationBytesUp, &TCPApplicationBytesDown},
+			{"tcp_ipv6_", &metrics.TCPIPv6, &TCPApplicationBytesUp, &TCPApplicationBytesDown},
+			{"udp_ipv4_", &metrics.UDPIPv4, &UDPApplicationBytesUp, &UDPApplicationBytesDown},
+			{"udp_ipv6_", &metrics.UDPIPv6, &UDPApplicationBytesUp, &UDPApplicationBytesDown},
 		}
 
 		for _, r := range relayedMetrics {
+
+			applicationBytesUp := atomic.SwapInt64(&r.metrics.applicationBytesUp, 0)
+			applicationBytesDown := atomic.SwapInt64(&r.metrics.applicationBytesDown, 0)
+
+			*r.updaterBytesUp += applicationBytesUp
+			*r.updaterBytesDown += applicationBytesDown
+
 			logFields[r.prefix+"packets_up"] = atomic.SwapInt64(&r.metrics.packetsUp, 0)
 			logFields[r.prefix+"packets_down"] = atomic.SwapInt64(&r.metrics.packetsDown, 0)
 			logFields[r.prefix+"bytes_up"] = atomic.SwapInt64(&r.metrics.bytesUp, 0)
 			logFields[r.prefix+"bytes_down"] = atomic.SwapInt64(&r.metrics.bytesDown, 0)
+			logFields[r.prefix+"application_bytes_up"] = applicationBytesUp
+			logFields[r.prefix+"application_bytes_down"] = applicationBytesDown
+		}
+
+		if updater != nil {
+			updater(
+				TCPApplicationBytesUp, TCPApplicationBytesDown,
+				UDPApplicationBytesUp, UDPApplicationBytesDown)
 		}
 	}
 
@@ -1726,7 +1882,7 @@ func (client *Client) Stop() {
 	client.workers.Wait()
 
 	client.metrics.checkpoint(
-		client.config.Logger, "packet_metrics", packetMetricsAll)
+		client.config.Logger, nil, "packet_metrics", packetMetricsAll)
 
 	client.config.Logger.WithContext().Info("stopped")
 }
@@ -2205,9 +2361,12 @@ func processPacket(
 
 		if protocol == internetProtocolTCP {
 
+			checkAllowedTCPPortFunc := session.getCheckAllowedTCPPortFunc()
+
 			if checkPort == 0 ||
 				(isServer &&
-					!session.checkAllowedTCPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort)) {
+					(checkAllowedTCPPortFunc == nil ||
+						!checkAllowedTCPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort))) {
 
 				metrics.rejectedPacket(direction, packetRejectTCPPort)
 				return false
@@ -2215,9 +2374,12 @@ func processPacket(
 
 		} else if protocol == internetProtocolUDP {
 
+			checkAllowedUDPPortFunc := session.getCheckAllowedUDPPortFunc()
+
 			if checkPort == 0 ||
 				(isServer &&
-					!session.checkAllowedUDPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort)) {
+					(checkAllowedUDPPortFunc == nil ||
+						!checkAllowedUDPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort))) {
 
 				metrics.rejectedPacket(direction, packetRejectUDPPort)
 				return false
@@ -2341,7 +2503,7 @@ func processPacket(
 		}
 	}
 
-	metrics.relayedPacket(direction, int(version), protocol, len(packet))
+	metrics.relayedPacket(direction, int(version), protocol, len(packet), len(applicationData))
 
 	return true
 }

+ 47 - 20
psiphon/common/tun/tun_test.go

@@ -108,12 +108,20 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 		t.Fatalf("startTestTCPServer failed: %s", err)
 	}
 
-	var counter bytesTransferredCounter
+	var flowCounter bytesTransferredCounter
+
 	flowActivityUpdaterMaker := func(_ string, _ net.IP) []FlowActivityUpdater {
-		return []FlowActivityUpdater{&counter}
+		return []FlowActivityUpdater{&flowCounter}
+	}
+
+	var metricsCounter bytesTransferredCounter
+
+	metricsUpdater := func(TCPApplicationBytesUp, TCPApplicationBytesDown, _, _ int64) {
+		metricsCounter.UpdateProgress(
+			TCPApplicationBytesUp, TCPApplicationBytesDown, 0)
 	}
 
-	testServer, err := startTestServer(useIPv6, MTU, flowActivityUpdaterMaker)
+	testServer, err := startTestServer(useIPv6, MTU, flowActivityUpdaterMaker, metricsUpdater)
 	if err != nil {
 		t.Fatalf("startTestServer failed: %s", err)
 	}
@@ -201,6 +209,8 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 				{"packets_down", TCP_RELAY_TOTAL_SIZE / int64(MTU)},
 				{"bytes_up", TCP_RELAY_TOTAL_SIZE},
 				{"bytes_down", TCP_RELAY_TOTAL_SIZE},
+				{"application_bytes_up", TCP_RELAY_TOTAL_SIZE},
+				{"application_bytes_down", TCP_RELAY_TOTAL_SIZE},
 			}
 
 			for _, expectedField := range expectedFields {
@@ -240,14 +250,25 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 	// Note: reported bytes transferred can exceed expected bytes
 	// transferred due to retransmission of packets.
 
-	upstreamBytesTransferred, downstreamBytesTransferred, _ := counter.Get()
 	expectedBytesTransferred := CONCURRENT_CLIENT_COUNT * TCP_RELAY_TOTAL_SIZE
+
+	upstreamBytesTransferred, downstreamBytesTransferred, _ := flowCounter.Get()
+	if upstreamBytesTransferred < expectedBytesTransferred {
+		t.Fatalf("unexpected flow upstreamBytesTransferred: %d; expected at least %d",
+			upstreamBytesTransferred, expectedBytesTransferred)
+	}
+	if downstreamBytesTransferred < expectedBytesTransferred {
+		t.Fatalf("unexpected flow downstreamBytesTransferred: %d; expected at least %d",
+			downstreamBytesTransferred, expectedBytesTransferred)
+	}
+
+	upstreamBytesTransferred, downstreamBytesTransferred, _ = metricsCounter.Get()
 	if upstreamBytesTransferred < expectedBytesTransferred {
-		t.Fatalf("unexpected upstreamBytesTransferred: %d; expected at least %d",
+		t.Fatalf("unexpected metrics upstreamBytesTransferred: %d; expected at least %d",
 			upstreamBytesTransferred, expectedBytesTransferred)
 	}
 	if downstreamBytesTransferred < expectedBytesTransferred {
-		t.Fatalf("unexpected downstreamBytesTransferred: %d; expected at least %d",
+		t.Fatalf("unexpected metrics downstreamBytesTransferred: %d; expected at least %d",
 			downstreamBytesTransferred, expectedBytesTransferred)
 	}
 
@@ -280,16 +301,20 @@ func (counter *bytesTransferredCounter) Get() (int64, int64, int64) {
 }
 
 type testServer struct {
-	logger       *testLogger
-	updaterMaker FlowActivityUpdaterMaker
-	tunServer    *Server
-	unixListener net.Listener
-	clientConns  *common.Conns
-	workers      *sync.WaitGroup
+	logger         *testLogger
+	updaterMaker   FlowActivityUpdaterMaker
+	metricsUpdater MetricsUpdater
+	tunServer      *Server
+	unixListener   net.Listener
+	clientConns    *common.Conns
+	workers        *sync.WaitGroup
 }
 
 func startTestServer(
-	useIPv6 bool, MTU int, updaterMaker FlowActivityUpdaterMaker) (*testServer, error) {
+	useIPv6 bool,
+	MTU int,
+	updaterMaker FlowActivityUpdaterMaker,
+	metricsUpdater MetricsUpdater) (*testServer, error) {
 
 	logger := newTestLogger(true)
 
@@ -319,12 +344,13 @@ func startTestServer(
 	}
 
 	server := &testServer{
-		logger:       logger,
-		updaterMaker: updaterMaker,
-		tunServer:    tunServer,
-		unixListener: unixListener,
-		clientConns:  new(common.Conns),
-		workers:      new(sync.WaitGroup),
+		logger:         logger,
+		updaterMaker:   updaterMaker,
+		metricsUpdater: metricsUpdater,
+		tunServer:      tunServer,
+		unixListener:   unixListener,
+		clientConns:    new(common.Conns),
+		workers:        new(sync.WaitGroup),
 	}
 
 	server.workers.Add(1)
@@ -367,7 +393,8 @@ func (server *testServer) run() {
 				signalConn,
 				checkAllowedPortFunc,
 				checkAllowedPortFunc,
-				server.updaterMaker)
+				server.updaterMaker,
+				server.metricsUpdater)
 
 			signalConn.Wait()
 

+ 14 - 1
psiphon/server/tunnelServer.go

@@ -1476,12 +1476,25 @@ func (sshClient *sshClient) runTunnel(
 				return updaters
 			}
 
+			metricUpdater := func(
+				TCPApplicationBytesUp, TCPApplicationBytesDown,
+				UDPApplicationBytesUp, UDPApplicationBytesDown int64) {
+
+				sshClient.Lock()
+				sshClient.tcpTrafficState.bytesUp += TCPApplicationBytesUp
+				sshClient.tcpTrafficState.bytesDown += TCPApplicationBytesDown
+				sshClient.udpTrafficState.bytesUp += UDPApplicationBytesUp
+				sshClient.udpTrafficState.bytesDown += UDPApplicationBytesDown
+				sshClient.Unlock()
+			}
+
 			err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
 				sshClient.sessionID,
 				packetTunnelChannel,
 				checkAllowedTCPPortFunc,
 				checkAllowedUDPPortFunc,
-				flowActivityUpdaterMaker)
+				flowActivityUpdaterMaker,
+				metricUpdater)
 			if err != nil {
 				log.WithContextFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
 				sshClient.setPacketTunnelChannel(nil)