Просмотр исходного кода

Bug fixes and enhancements

- Add InproxyBrokerAllowBogonWebRTCConnections broker debug config param
- Add CPU percent to server_load
- Record and log broker match metrics
- Log proxy announcement delay and elapsed time
- Exclude the broker endpoint from meek rate limiting
- Fix resolver race conditions which led to dropped A or AAAA results
- Replace juju/ratelimit with golang.org/x/time/rate
- Document personal pairing mode limitations
Rod Hynes 1 год назад
Родитель
Commit
dde8c33dc9

+ 12 - 0
psiphon/common/inproxy/api.go

@@ -162,6 +162,18 @@ func (p NetworkProtocol) String() string {
 	return ""
 }
 
+// IsStream indicates if the NetworkProtocol is stream-oriented (e.g., TCP)
+// and not packet-oriented (e.g., UDP).
+func (p NetworkProtocol) IsStream() bool {
+	switch p {
+	case NetworkProtocolTCP:
+		return true
+	case NetworkProtocolUDP:
+		return false
+	}
+	return false
+}
+
 // ProxyMetrics are network topology and resource metrics provided by a
 // proxy to a broker. The broker uses this information when matching proxies
 // and clients.

+ 6 - 2
psiphon/common/inproxy/broker.go

@@ -442,6 +442,7 @@ func (b *Broker) handleProxyAnnounce(
 	var logFields common.LogFields
 	var newTacticsTag string
 	var clientOffer *MatchOffer
+	var matchMetrics *MatchMetrics
 	var timedOut bool
 	var limitedErr error
 
@@ -503,6 +504,7 @@ func (b *Broker) handleProxyAnnounce(
 			logFields["error"] = limitedErr.Error()
 		}
 		logFields.Add(transportLogFields)
+		logFields.Add(matchMetrics.GetMetrics())
 		b.config.Logger.LogMetric(brokerMetricName, logFields)
 	}()
 
@@ -582,7 +584,7 @@ func (b *Broker) handleProxyAnnounce(
 	defer cancelFunc()
 	extendTransportTimeout(timeout)
 
-	clientOffer, err = b.matcher.Announce(
+	clientOffer, matchMetrics, err = b.matcher.Announce(
 		announceCtx,
 		proxyIP,
 		&MatchAnnouncement{
@@ -708,6 +710,7 @@ func (b *Broker) handleClientOffer(
 	var clientMatchOffer *MatchOffer
 	var proxyMatchAnnouncement *MatchAnnouncement
 	var proxyAnswer *MatchAnswer
+	var matchMetrics *MatchMetrics
 	var timedOut bool
 	var limitedErr error
 
@@ -747,6 +750,7 @@ func (b *Broker) handleClientOffer(
 			logFields["error"] = limitedErr.Error()
 		}
 		logFields.Add(transportLogFields)
+		logFields.Add(matchMetrics.GetMetrics())
 		b.config.Logger.LogMetric(brokerMetricName, logFields)
 	}()
 
@@ -823,7 +827,7 @@ func (b *Broker) handleClientOffer(
 		DestinationServerID:         serverParams.serverID,
 	}
 
-	proxyAnswer, proxyMatchAnnouncement, err = b.matcher.Offer(
+	proxyAnswer, proxyMatchAnnouncement, matchMetrics, err = b.matcher.Offer(
 		offerCtx,
 		clientIP,
 		clientMatchOffer)

+ 96 - 49
psiphon/common/inproxy/matcher.go

@@ -23,13 +23,14 @@ import (
 	std_errors "errors"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	lrucache "github.com/cognusion/go-cache-lru"
 	"github.com/gammazero/deque"
-	"github.com/juju/ratelimit"
+	"golang.org/x/time/rate"
 )
 
 // TTLs should be aligned with STUN hole punch lifetimes.
@@ -48,7 +49,7 @@ const (
 // coordinates pending proxy answers and routes answers to the awaiting
 // client offer handler.
 //
-// Matching prioritizes selecting the oldest announcments and client offers,
+// Matching prioritizes selecting the oldest announcements and client offers,
 // as they are closest to timing out.
 //
 // The client and proxy must supply matching personal or common compartment
@@ -65,7 +66,7 @@ const (
 // Candidates with unknown NAT types and mobile network types are assumed to
 // have the most limited NAT traversal capability.
 //
-// Preferred matchings take priority over announcment age.
+// Preferred matchings take priority over announcement age.
 //
 // The client and proxy will not match if they are in the same country and
 // ASN, as it's assumed that doesn't provide any blocking circumvention
@@ -202,6 +203,28 @@ type MatchAnswer struct {
 	ProxyAnswerSDP               WebRTCSessionDescription
 }
 
+// MatchMetrics records statistics about the match queue state at the time a
+// match is made.
+type MatchMetrics struct {
+	OfferMatchIndex        int
+	OfferQueueSize         int
+	AnnouncementMatchIndex int
+	AnnouncementQueueSize  int
+}
+
+// GetMetrics converts MatchMetrics to loggable fields.
+func (metrics *MatchMetrics) GetMetrics() common.LogFields {
+	if metrics == nil {
+		return nil
+	}
+	return common.LogFields{
+		"offer_match_index":        metrics.OfferMatchIndex,
+		"offer_queue_size":         metrics.OfferQueueSize,
+		"announcement_match_index": metrics.AnnouncementMatchIndex,
+		"announcement_queue_size":  metrics.AnnouncementQueueSize,
+	}
+}
+
 // announcementEntry is an announcement queue entry, an announcement with its
 // associated lifetime context and signaling channel.
 type announcementEntry struct {
@@ -209,15 +232,27 @@ type announcementEntry struct {
 	limitIP      string
 	announcement *MatchAnnouncement
 	offerChan    chan *MatchOffer
+	matchMetrics atomic.Value
+}
+
+func (announcementEntry *announcementEntry) getMatchMetrics() *MatchMetrics {
+	matchMetrics, _ := announcementEntry.matchMetrics.Load().(*MatchMetrics)
+	return matchMetrics
 }
 
 // offerEntry is an offer queue entry, an offer with its associated lifetime
 // context and signaling channel.
 type offerEntry struct {
-	ctx        context.Context
-	limitIP    string
-	offer      *MatchOffer
-	answerChan chan *answerInfo
+	ctx          context.Context
+	limitIP      string
+	offer        *MatchOffer
+	answerChan   chan *answerInfo
+	matchMetrics atomic.Value
+}
+
+func (offerEntry *offerEntry) getMatchMetrics() *MatchMetrics {
+	matchMetrics, _ := offerEntry.matchMetrics.Load().(*MatchMetrics)
+	return matchMetrics
 }
 
 // answerInfo is an answer and its associated announcement.
@@ -373,10 +408,13 @@ func (m *Matcher) Stop() {
 //
 // The offer is sent to the proxy by the broker, and then the proxy sends its
 // answer back to the broker, which calls Answer with that value.
+//
+// The returned MatchMetrics is nil unless a match is made; and non-nil if a
+// match is made, even if there is a later error.
 func (m *Matcher) Announce(
 	ctx context.Context,
 	proxyIP string,
-	proxyAnnouncement *MatchAnnouncement) (*MatchOffer, error) {
+	proxyAnnouncement *MatchAnnouncement) (*MatchOffer, *MatchMetrics, error) {
 
 	announcementEntry := &announcementEntry{
 		ctx:          ctx,
@@ -387,7 +425,7 @@ func (m *Matcher) Announce(
 
 	err := m.addAnnouncementEntry(announcementEntry)
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, nil, errors.Trace(err)
 	}
 
 	// Await client offer.
@@ -397,12 +435,12 @@ func (m *Matcher) Announce(
 	select {
 	case <-ctx.Done():
 		m.removeAnnouncementEntry(announcementEntry)
-		return nil, errors.Trace(ctx.Err())
+		return nil, announcementEntry.getMatchMetrics(), errors.Trace(ctx.Err())
 
 	case clientOffer = <-announcementEntry.offerChan:
 	}
 
-	return clientOffer, nil
+	return clientOffer, announcementEntry.getMatchMetrics(), nil
 }
 
 // Offer enqueues the client offer and blocks until it is matched with a
@@ -412,10 +450,13 @@ func (m *Matcher) Announce(
 // The answer is returned to the client by the broker, and the WebRTC
 // connection is dialed. The original announcement is also returned, so its
 // match properties can be logged.
+//
+// The returned MatchMetrics is nil unless a match is made; and non-nil if a
+// match is made, even if there is a later error.
 func (m *Matcher) Offer(
 	ctx context.Context,
 	clientIP string,
-	clientOffer *MatchOffer) (*MatchAnswer, *MatchAnnouncement, error) {
+	clientOffer *MatchOffer) (*MatchAnswer, *MatchAnnouncement, *MatchMetrics, error) {
 
 	offerEntry := &offerEntry{
 		ctx:        ctx,
@@ -426,7 +467,7 @@ func (m *Matcher) Offer(
 
 	err := m.addOfferEntry(offerEntry)
 	if err != nil {
-		return nil, nil, errors.Trace(err)
+		return nil, nil, nil, errors.Trace(err)
 	}
 
 	// Await proxy answer.
@@ -442,7 +483,8 @@ func (m *Matcher) Offer(
 		// get removed. But a client may abort its request earlier than the
 		// timeout.
 
-		return nil, nil, errors.Trace(ctx.Err())
+		return nil, nil,
+			offerEntry.getMatchMetrics(), errors.Trace(ctx.Err())
 
 	case proxyAnswerInfo = <-offerEntry.answerChan:
 	}
@@ -450,18 +492,23 @@ func (m *Matcher) Offer(
 	if proxyAnswerInfo == nil {
 
 		// nil will be delivered to the channel when either the proxy
-		// announcment request concurrently timed out, or the answer
+		// announcement request concurrently timed out, or the answer
 		// indicated a proxy error, or the answer did not arrive in time.
-		return nil, nil, errors.TraceNew("no answer")
+		return nil, nil,
+			offerEntry.getMatchMetrics(), errors.TraceNew("no answer")
 	}
 
 	// This is a sanity check and not expected to fail.
 	if !proxyAnswerInfo.answer.ConnectionID.Equal(
 		proxyAnswerInfo.announcement.ConnectionID) {
-		return nil, nil, errors.TraceNew("unexpected connection ID")
+		return nil, nil,
+			offerEntry.getMatchMetrics(), errors.TraceNew("unexpected connection ID")
 	}
 
-	return proxyAnswerInfo.answer, proxyAnswerInfo.announcement, nil
+	return proxyAnswerInfo.answer,
+		proxyAnswerInfo.announcement,
+		offerEntry.getMatchMetrics(),
+		nil
 }
 
 // Answer delivers an answer from the proxy for a previously matched offer.
@@ -569,6 +616,22 @@ func (m *Matcher) matchAllOffers() {
 			continue
 		}
 
+		// Get the matched announcement entry.
+
+		announcementEntry := m.announcementQueue.At(j)
+
+		// Record match metrics.
+
+		matchMetrics := &MatchMetrics{
+			OfferMatchIndex:        i,
+			OfferQueueSize:         m.offerQueue.Len(),
+			AnnouncementMatchIndex: j,
+			AnnouncementQueueSize:  m.announcementQueue.Len(),
+		}
+
+		offerEntry.matchMetrics.Store(matchMetrics)
+		announcementEntry.matchMetrics.Store(matchMetrics)
+
 		// Remove the matched announcement from the queue. Send the offer to
 		// the announcement entry's offerChan, which will deliver it to the
 		// blocked Announce call. Add a pending answers entry to await the
@@ -576,30 +639,6 @@ func (m *Matcher) matchAllOffers() {
 		// entry is set to the matched Offer call's ctx, as the answer is
 		// only useful as long as the client is still waiting.
 
-		announcementEntry := m.announcementQueue.At(j)
-
-		if m.config.Logger.IsLogLevelDebug() {
-
-			announcementProxyID :=
-				announcementEntry.announcement.ProxyID
-			announcementConnectionID :=
-				announcementEntry.announcement.ConnectionID
-			announcementCommonCompartmentIDs :=
-				announcementEntry.announcement.Properties.CommonCompartmentIDs
-			offerCommonCompartmentIDs :=
-				offerEntry.offer.Properties.CommonCompartmentIDs
-
-			m.config.Logger.WithTraceFields(common.LogFields{
-				"announcement_proxy_id":               announcementProxyID,
-				"announcement_connection_id":          announcementConnectionID,
-				"announcement_common_compartment_ids": announcementCommonCompartmentIDs,
-				"offer_common_compartment_ids":        offerCommonCompartmentIDs,
-				"match_index":                         j,
-				"announcement_queue_size":             m.announcementQueue.Len(),
-				"offer_queue_size":                    m.offerQueue.Len(),
-			}).Debug("match metrics")
-		}
-
 		expiry := lrucache.DefaultExpiration
 		deadline, ok := offerEntry.ctx.Deadline()
 		if ok {
@@ -636,8 +675,8 @@ func (m *Matcher) matchOffer(offerEntry *offerEntry) (int, bool) {
 	// Assumes the caller has the queue mutexed locked.
 
 	// Check each announcement in turn, and select a match. There is an
-	// implicit preference for older proxy announcments, sooner to timeout, at the
-	// front of the queue.
+	// implicit preference for older proxy announcements, sooner to timeout,
+	// at the front of the queue.
 	//
 	// Limitation: since this logic matches each enqueued client in turn, it will
 	// only make the optimal NAT match for the oldest enqueued client vs. all
@@ -668,6 +707,14 @@ func (m *Matcher) matchOffer(offerEntry *offerEntry) (int, bool) {
 
 	end := m.announcementQueue.Len()
 
+	// TODO: add queue indexing to facilitate skipping ahead to a matching
+	// personal compartment ID, if any, when personal-only matching is
+	// required. Personal matching may often require near-full queue scans
+	// when looking for a match. Common compartment matching may also benefit
+	// from indexing, although with a handful of common compartment IDs more
+	// or less uniformly distributed, frequent long scans are not expected in
+	// practise.
+
 	for i := 0; i < end; i++ {
 
 		announcementEntry := m.announcementQueue.At(i)
@@ -815,19 +862,19 @@ func (m *Matcher) applyLimits(isAnnouncement bool, limitIP string, proxyID ID) e
 
 	if quantity > 0 && interval > 0 {
 
-		var rateLimiter *ratelimit.Bucket
+		var rateLimiter *rate.Limiter
 
 		entry, ok := queueRateLimiters.Get(limitIP)
 		if ok {
-			rateLimiter = entry.(*ratelimit.Bucket)
+			rateLimiter = entry.(*rate.Limiter)
 		} else {
-			rateLimiter = ratelimit.NewBucketWithQuantum(
-				interval, int64(quantity), int64(quantity))
+			limit := float64(quantity) / interval.Seconds()
+			rateLimiter = rate.NewLimiter(rate.Limit(limit), quantity)
 			queueRateLimiters.Set(
 				limitIP, rateLimiter, interval)
 		}
 
-		if rateLimiter.TakeAvailable(1) < 1 {
+		if !rateLimiter.Allow() {
 			return errors.Trace(
 				NewMatcherLimitError(std_errors.New("rate exceeded for IP")))
 		}

+ 40 - 7
psiphon/common/inproxy/matcher_test.go

@@ -23,6 +23,7 @@ import (
 	"context"
 	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -43,7 +44,7 @@ func runTestMatcher() error {
 
 	limitEntryCount := 50
 	rateLimitQuantity := 100
-	rateLimitInterval := 500 * time.Millisecond
+	rateLimitInterval := 1000 * time.Millisecond
 
 	logger := newTestLogger()
 
@@ -88,6 +89,13 @@ func runTestMatcher() error {
 		}
 	}
 
+	checkMatchMetrics := func(metrics *MatchMetrics) error {
+		if metrics.OfferQueueSize < 1 || metrics.AnnouncementQueueSize < 1 {
+			return errors.TraceNew("unexpected match metrics")
+		}
+		return nil
+	}
+
 	proxyIP := randomIPAddress()
 
 	proxyFunc := func(
@@ -102,10 +110,16 @@ func runTestMatcher() error {
 		defer cancelFunc()
 
 		announcement := makeAnnouncement(matchProperties)
-		offer, err := m.Announce(ctx, proxyIP, announcement)
+		offer, matchMetrics, err := m.Announce(ctx, proxyIP, announcement)
 		if err != nil {
 			resultChan <- errors.Trace(err)
 			return
+		} else {
+			err := checkMatchMetrics(matchMetrics)
+			if err != nil {
+				resultChan <- errors.Trace(err)
+				return
+			}
 		}
 
 		if waitBeforeAnswer != nil {
@@ -137,7 +151,7 @@ func runTestMatcher() error {
 		defer cancelFunc()
 
 		offer := makeOffer(matchProperties)
-		answer, _, err := m.Offer(ctx, clientIP, offer)
+		answer, _, matchMetrics, err := m.Offer(ctx, clientIP, offer)
 		if err != nil {
 			resultChan <- errors.Trace(err)
 			return
@@ -145,6 +159,12 @@ func runTestMatcher() error {
 		if answer.SelectedProxyProtocolVersion != offer.ClientProxyProtocolVersion {
 			resultChan <- errors.TraceNew("unexpected selected proxy protocol version")
 			return
+		} else {
+			err := checkMatchMetrics(matchMetrics)
+			if err != nil {
+				resultChan <- errors.Trace(err)
+				return
+			}
 		}
 		resultChan <- nil
 	}
@@ -293,11 +313,19 @@ func runTestMatcher() error {
 	maxEntries = rateLimitQuantity
 	maxEntriesProxyResultChan = make(chan error, maxEntries)
 
+	waitGroup := new(sync.WaitGroup)
 	for i := 0; i < maxEntries; i++ {
-		go proxyFunc(maxEntriesProxyResultChan, proxyIP, &MatchProperties{}, 1*time.Microsecond, nil, true)
+		waitGroup.Add(1)
+		go func() {
+			defer waitGroup.Done()
+			proxyFunc(maxEntriesProxyResultChan, proxyIP, &MatchProperties{}, 1*time.Microsecond, nil, true)
+		}()
 	}
 
-	time.Sleep(rateLimitInterval / 2)
+	// Use a wait group to ensure all maxEntries have hit the rate limiter
+	// without sleeping before the next attempt, as any sleep can increase
+	// the rate limiter token count.
+	waitGroup.Wait()
 
 	// the next enqueue should fail with "rate exceeded"
 	go proxyFunc(proxyResultChan, proxyIP, &MatchProperties{}, 10*time.Millisecond, nil, true)
@@ -311,11 +339,16 @@ func runTestMatcher() error {
 	maxEntries = rateLimitQuantity
 	maxEntriesClientResultChan = make(chan error, maxEntries)
 
+	waitGroup = new(sync.WaitGroup)
 	for i := 0; i < rateLimitQuantity; i++ {
-		go clientFunc(maxEntriesClientResultChan, clientIP, &MatchProperties{}, 1*time.Microsecond)
+		waitGroup.Add(1)
+		go func() {
+			defer waitGroup.Done()
+			clientFunc(maxEntriesClientResultChan, clientIP, &MatchProperties{}, 1*time.Microsecond)
+		}()
 	}
 
-	time.Sleep(rateLimitInterval / 2)
+	waitGroup.Wait()
 
 	// enqueue should fail with "rate exceeded"
 	go clientFunc(clientResultChan, clientIP, &MatchProperties{}, 10*time.Millisecond)

+ 8 - 0
psiphon/common/inproxy/proxy.go

@@ -571,6 +571,7 @@ func (p *Proxy) proxyOneClient(
 	//
 	// ProxyAnnounce applies an additional request timeout to facilitate
 	// long-polling.
+	announceStartTime := time.Now()
 	announceResponse, err := brokerClient.ProxyAnnounce(
 		ctx,
 		requestDelay,
@@ -578,6 +579,12 @@ func (p *Proxy) proxyOneClient(
 			PersonalCompartmentIDs: brokerCoordinator.PersonalCompartmentIDs(),
 			Metrics:                metrics,
 		})
+
+	p.config.Logger.WithTraceFields(common.LogFields{
+		"delay":       requestDelay,
+		"elapsedTime": time.Since(announceStartTime),
+	}).Info("announcement request")
+
 	if err != nil {
 		return backOff, errors.Trace(err)
 	}
@@ -787,6 +794,7 @@ func (p *Proxy) proxyOneClient(
 
 	destinationConn = common.NewThrottledConn(
 		destinationConn,
+		announceResponse.NetworkProtocol.IsStream(),
 		common.RateLimits{
 			ReadBytesPerSecond:  int64(p.config.LimitUpstreamBytesPerSecond),
 			WriteBytesPerSecond: int64(p.config.LimitDownstreamBytesPerSecond),

+ 69 - 46
psiphon/common/resolver/resolver.go

@@ -272,7 +272,7 @@ type resolverMetrics struct {
 	responsesIPv6           int
 	defaultResolves         int
 	defaultSuccesses        int
-	peakInFlight            int64
+	peakInFlight            int
 	minRTT                  time.Duration
 	maxRTT                  time.Duration
 }
@@ -673,9 +673,10 @@ func (r *Resolver) ResolveIP(
 	waitGroup := new(sync.WaitGroup)
 	conns := common.NewConns[net.Conn]()
 	type answer struct {
-		attempt int
-		IPs     []net.IP
-		TTLs    []time.Duration
+		attempt      int
+		questionType resolverQuestionType
+		IPs          []net.IP
+		TTLs         []time.Duration
 	}
 	var maxAttempts int
 	if params.PreferAlternateDNSServer {
@@ -685,15 +686,32 @@ func (r *Resolver) ResolveIP(
 		maxAttempts = len(servers) * params.AttemptsPerServer
 	}
 	answerChan := make(chan *answer, maxAttempts*2)
-	inFlight := int64(0)
-	awaitA := int32(1)
-	awaitAAAA := int32(1)
-	if !hasIPv6Route {
-		awaitAAAA = 0
-	}
+	inFlight := 0
+	awaitA := true
+	awaitAAAA := hasIPv6Route
 	var result *answer
 	var lastErr atomic.Value
 
+	trackResult := func(a *answer) {
+
+		// A result is sent from every attempt goroutine that is launched,
+		// even in the case of an error, in which case the result is nil.
+		// Update the number of in-flight attempts as results are received.
+		// Mark no longer awaiting A or AAAA as long as there is a valid
+		// response, even if there are no IPs in the IPv6 case.
+		if inFlight > 0 {
+			inFlight -= 1
+		}
+		if a != nil {
+			switch a.questionType {
+			case resolverQuestionTypeA:
+				awaitA = false
+			case resolverQuestionTypeAAAA:
+				awaitAAAA = false
+			}
+		}
+	}
+
 	stop := false
 	for i := 0; !stop && i < maxAttempts; i++ {
 
@@ -731,32 +749,29 @@ func (r *Resolver) ResolveIP(
 			// correct, we must increment inFlight in this outer goroutine to
 			// ensure the await logic sees either inFlight > 0 or an answer
 			// in the channel.
-			r.updateMetricPeakInFlight(atomic.AddInt64(&inFlight, 1))
+			inFlight += 1
+			r.updateMetricPeakInFlight(inFlight)
 
 			go func(attempt int, questionType resolverQuestionType, useProtocolTransform bool) {
 				defer waitGroup.Done()
 
 				// Always send a result back to the main loop, even if this
 				// attempt fails, so the main loop proceeds to the next
-				// iteration immediately. Nil is sent in failure cases.
+				// iteration immediately. Nil is sent in failure cases. When
+				// the answer is not nil, it's already been sent.
 				var a *answer
 				defer func() {
-					select {
-					case answerChan <- a:
-					default:
+					if a == nil {
+						// The channel should have sufficient buffering for
+						// the send to never block; the default case is used
+						// to avoid a hang in the case of a bug.
+						select {
+						case answerChan <- a:
+						default:
+						}
 					}
 				}()
 
-				// We must decrement inFlight only after sending an answer and
-				// setting awaitA or awaitAAAA to ensure that the await logic
-				// in the outer goroutine will see inFlight 0 only once those
-				// operations are complete.
-				//
-				// We cannot wait and decrement inFlight when the outer
-				// goroutine receives answers, as no answer is sent in some
-				// cases, such as when the resolve fails due to NXDOMAIN.
-				defer atomic.AddInt64(&inFlight, -1)
-
 				// The request count metric counts the _intention_ to send
 				// requests, as there's a possibility that newResolverConn or
 				// performDNSQuery fail locally before sending a request packet.
@@ -840,21 +855,29 @@ func (r *Resolver) ResolveIP(
 					return
 				}
 
-				// Mark no longer awaiting A or AAAA as long as there is a
-				// valid response, even if there are no IPs in the IPv6 case.
+				// Update response stats.
 				switch questionType {
 				case resolverQuestionTypeA:
 					r.updateMetricResponsesIPv4()
-					atomic.StoreInt32(&awaitA, 0)
 				case resolverQuestionTypeAAAA:
 					r.updateMetricResponsesIPv6()
-					atomic.StoreInt32(&awaitAAAA, 0)
-				default:
 				}
 
 				// Send the answer back to the main loop.
-				if len(IPs) > 0 {
-					a = &answer{attempt: attempt, IPs: IPs, TTLs: TTLs}
+				if len(IPs) > 0 || questionType == resolverQuestionTypeAAAA {
+					a = &answer{
+						attempt:      attempt,
+						questionType: questionType,
+						IPs:          IPs,
+						TTLs:         TTLs}
+
+					// The channel should have sufficient buffering for
+					// the send to never block; the default case is used
+					// to avoid a hang in the case of a bug.
+					select {
+					case answerChan <- a:
+					default:
+					}
 				}
 
 			}(i+1, questionType, useProtocolTransform)
@@ -864,15 +887,14 @@ func (r *Resolver) ResolveIP(
 
 		select {
 		case result = <-answerChan:
-			if result == nil {
-				// The attempt failed with an error.
-				break
+			trackResult(result)
+			if result != nil {
+				// When the first answer, a response with valid IPs, arrives, exit
+				// the attempts loop. The following await branch may collect
+				// additional answers.
+				params.setFirstAttemptWithAnswer(result.attempt)
+				stop = true
 			}
-			// When the first answer, a response with valid IPs, arrives, exit
-			// the attempts loop. The following await branch may collect
-			// additional answers.
-			params.setFirstAttemptWithAnswer(result.attempt)
-			stop = true
 		case <-timer.C:
 			// When requestTimeout arrives, loop around and launch the next
 			// attempt; leave the existing requests running in case they
@@ -902,6 +924,7 @@ func (r *Resolver) ResolveIP(
 		for loop := true; loop; {
 			select {
 			case nextAnswer := <-answerChan:
+				trackResult(nextAnswer)
 				if nextAnswer != nil {
 					result.IPs = append(result.IPs, nextAnswer.IPs...)
 					result.TTLs = append(result.TTLs, nextAnswer.TTLs...)
@@ -921,8 +944,8 @@ func (r *Resolver) ResolveIP(
 	// have an answer.
 	if result != nil &&
 		resolveCtx.Err() == nil &&
-		atomic.LoadInt64(&inFlight) > 0 &&
-		(atomic.LoadInt32(&awaitA) != 0 || atomic.LoadInt32(&awaitAAAA) != 0) &&
+		inFlight > 0 &&
+		(awaitA || awaitAAAA) &&
 		params.AwaitTimeout > 0 {
 
 		resetTimer(params.AwaitTimeout)
@@ -932,6 +955,7 @@ func (r *Resolver) ResolveIP(
 			stop := false
 			select {
 			case nextAnswer := <-answerChan:
+				trackResult(nextAnswer)
 				if nextAnswer != nil {
 					result.IPs = append(result.IPs, nextAnswer.IPs...)
 					result.TTLs = append(result.TTLs, nextAnswer.TTLs...)
@@ -943,9 +967,8 @@ func (r *Resolver) ResolveIP(
 				stop = true
 			}
 
-			if stop ||
-				atomic.LoadInt64(&inFlight) == 0 ||
-				(atomic.LoadInt32(&awaitA) == 0 && atomic.LoadInt32(&awaitAAAA) == 0) {
+			if stop || inFlight == 0 || (!awaitA && !awaitAAAA) {
+
 				break
 			}
 		}
@@ -1352,7 +1375,7 @@ func (r *Resolver) updateMetricDefaultResolver(success bool) {
 	}
 }
 
-func (r *Resolver) updateMetricPeakInFlight(inFlight int64) {
+func (r *Resolver) updateMetricPeakInFlight(inFlight int) {
 	r.mutex.Lock()
 	defer r.mutex.Unlock()
 

+ 115 - 32
psiphon/common/throttled.go

@@ -26,7 +26,7 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
-	"github.com/juju/ratelimit"
+	"golang.org/x/time/rate"
 )
 
 // RateLimits specify the rate limits for a ThrottledConn.
@@ -72,20 +72,28 @@ type ThrottledConn struct {
 	writeBytesPerSecond   int64
 	closeAfterExhausted   int32
 	readLock              sync.Mutex
-	readRateLimiter       *ratelimit.Bucket
+	readRateLimiter       *rate.Limiter
 	readDelayTimer        *time.Timer
 	writeLock             sync.Mutex
-	writeRateLimiter      *ratelimit.Bucket
+	writeRateLimiter      *rate.Limiter
 	writeDelayTimer       *time.Timer
 	isClosed              int32
 	stopBroadcast         chan struct{}
+	isStream              bool
 	net.Conn
 }
 
 // NewThrottledConn initializes a new ThrottledConn.
-func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
+//
+// Set isStreamConn to true when conn is stream-oriented, such as TCP, and
+// false when the conn is packet-oriented, such as UDP. When conn is a
+// stream, reads and writes may be split to accommodate rate limits.
+func NewThrottledConn(
+	conn net.Conn, isStream bool, limits RateLimits) *ThrottledConn {
+
 	throttledConn := &ThrottledConn{
 		Conn:          conn,
+		isStream:      isStream,
 		stopBroadcast: make(chan struct{}),
 	}
 	throttledConn.SetLimits(limits)
@@ -137,10 +145,8 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 	conn.readLock.Lock()
 	defer conn.readLock.Unlock()
 
-	select {
-	case <-conn.stopBroadcast:
+	if atomic.LoadInt32(&conn.isClosed) == 1 {
 		return 0, errors.TraceNew("throttled conn closed")
-	default:
 	}
 
 	// Use the base conn until the unthrottled count is
@@ -158,34 +164,68 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 		return 0, errors.TraceNew("throttled conn exhausted")
 	}
 
-	rate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
+	readRate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
 
-	if rate != -1 {
+	if readRate != -1 {
 		// SetLimits has been called and a new rate limiter
 		// must be initialized. When no limit is specified,
 		// the reader/writer is simply the base conn.
 		// No state is retained from the previous rate limiter,
 		// so a pending I/O throttle sleep may be skipped when
 		// the old and new rate are similar.
-		if rate == 0 {
+		if readRate == 0 {
 			conn.readRateLimiter = nil
 		} else {
 			conn.readRateLimiter =
-				ratelimit.NewBucketWithRate(float64(rate), rate)
+				rate.NewLimiter(rate.Limit(readRate), int(readRate))
+		}
+	}
+
+	// The number of bytes read cannot exceed the rate limiter burst size,
+	// which is enforced by rate.Limiter.ReserveN. Reduce any read buffer
+	// size to be at most the burst size.
+	//
+	// Read should still return as soon as read bytes are available; and the
+	// number of bytes that will be received is unknown; so there is no loop
+	// here to read more bytes. Reducing the read buffer size minimizes
+	// latency for the up-to-burst-size bytes read, whereas allowing a full
+	// read followed by multiple ReserveN calls and sleeps would increase
+	// latency.
+	//
+	// In practise, with Psiphon tunnels, throttling is not applied until
+	// after the Psiphon API handshake, so read buffer reductions won't
+	// impact early obfuscation traffic shaping; and reads are on the order
+	// of one SSH "packet", up to 32K, unlikely to be split for all but the
+	// most restrictive of rate limits.
+
+	if conn.readRateLimiter != nil {
+		burst := conn.readRateLimiter.Burst()
+		if len(buffer) > burst {
+			if !conn.isStream {
+				return 0, errors.TraceNew("non-stream read buffer exceeds burst")
+			}
+			buffer = buffer[:burst]
 		}
 	}
 
 	n, err := conn.Conn.Read(buffer)
 
-	// Sleep to enforce the rate limit. This is the same logic as implemented in
-	// ratelimit.Reader, but using a timer and a close signal instead of an
-	// uninterruptible time.Sleep.
-	//
-	// The readDelayTimer is always expired/stopped and drained after this code
-	// block and is ready to be Reset on the next call.
+	if n > 0 && conn.readRateLimiter != nil {
+
+		// While rate.Limiter.WaitN would be simpler to use, internally Wait
+		// creates a new timer for every call which must sleep, which is
+		// expected to be most calls. Instead, call ReserveN to get the sleep
+		// time and reuse one timer without allocation.
+		//
+		// TODO: avoid allocation: ReserveN allocates a *Reservation; while
+		// the internal reserveN returns a struct, not a pointer.
 
-	if n >= 0 && conn.readRateLimiter != nil {
-		sleepDuration := conn.readRateLimiter.Take(int64(n))
+		reservation := conn.readRateLimiter.ReserveN(time.Now(), n)
+		if !reservation.OK() {
+			// This error is not expected, given the buffer size adjustment.
+			return 0, errors.TraceNew("burst size exceeded")
+		}
+		sleepDuration := reservation.Delay()
 		if sleepDuration > 0 {
 			if conn.readDelayTimer == nil {
 				conn.readDelayTimer = time.NewTimer(sleepDuration)
@@ -202,7 +242,8 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
 		}
 	}
 
-	return n, errors.Trace(err)
+	// Don't wrap I/O errors
+	return n, err
 }
 
 func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
@@ -212,10 +253,8 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 	conn.writeLock.Lock()
 	defer conn.writeLock.Unlock()
 
-	select {
-	case <-conn.stopBroadcast:
+	if atomic.LoadInt32(&conn.isClosed) == 1 {
 		return 0, errors.TraceNew("throttled conn closed")
-	default:
 	}
 
 	if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
@@ -229,19 +268,58 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 		return 0, errors.TraceNew("throttled conn exhausted")
 	}
 
-	rate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
+	writeRate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
 
-	if rate != -1 {
-		if rate == 0 {
+	if writeRate != -1 {
+		if writeRate == 0 {
 			conn.writeRateLimiter = nil
 		} else {
 			conn.writeRateLimiter =
-				ratelimit.NewBucketWithRate(float64(rate), rate)
+				rate.NewLimiter(rate.Limit(writeRate), int(writeRate))
 		}
 	}
 
-	if len(buffer) >= 0 && conn.writeRateLimiter != nil {
-		sleepDuration := conn.writeRateLimiter.Take(int64(len(buffer)))
+	if conn.writeRateLimiter == nil {
+		n, err := conn.Conn.Write(buffer)
+		// Don't wrap I/O errors
+		return n, err
+	}
+
+	// The number of bytes written cannot exceed the rate limiter burst size,
+	// which is enforced by rate.Limiter.ReserveN. Split writes to be at most
+	// the burst size.
+	//
+	// Splitting writes may have some effect on the shape of TCP packets sent
+	// on the network.
+	//
+	// In practise, with Psiphon tunnels, throttling is not applied until
+	// after the Psiphon API handshake, so write splits won't impact early
+	// obfuscation traffic shaping; and writes are on the order of one
+	// SSH "packet", up to 32K, unlikely to be split for all but the most
+	// restrictive of rate limits.
+
+	burst := conn.writeRateLimiter.Burst()
+	if !conn.isStream && len(buffer) > burst {
+		return 0, errors.TraceNew("non-stream write exceeds burst")
+	}
+	totalWritten := 0
+	for i := 0; i < len(buffer); i += burst {
+
+		j := i + burst
+		if j > len(buffer) {
+			j = len(buffer)
+		}
+		b := buffer[i:j]
+
+		// See comment in Read regarding rate.Limiter.ReserveN vs.
+		// rate.Limiter.WaitN.
+
+		reservation := conn.writeRateLimiter.ReserveN(time.Now(), len(b))
+		if !reservation.OK() {
+			// This error is not expected, given the write split adjustments.
+			return 0, errors.TraceNew("burst size exceeded")
+		}
+		sleepDuration := reservation.Delay()
 		if sleepDuration > 0 {
 			if conn.writeDelayTimer == nil {
 				conn.writeDelayTimer = time.NewTimer(sleepDuration)
@@ -256,11 +334,16 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
 				}
 			}
 		}
-	}
 
-	n, err := conn.Conn.Write(buffer)
+		n, err := conn.Conn.Write(b)
+		totalWritten += n
+		if err != nil {
+			// Don't wrap I/O errors
+			return totalWritten, err
+		}
+	}
 
-	return n, errors.Trace(err)
+	return totalWritten, nil
 }
 
 func (conn *ThrottledConn) Close() error {

+ 28 - 10
psiphon/common/throttled_test.go

@@ -22,6 +22,7 @@ package common
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"math"
 	"net"
@@ -113,7 +114,7 @@ func runRateLimitsTest(t *testing.T, rateLimits RateLimits) {
 		if err != nil {
 			return conn, err
 		}
-		return NewThrottledConn(conn, rateLimits), nil
+		return NewThrottledConn(conn, true, rateLimits), nil
 	}
 
 	client := &http.Client{
@@ -204,27 +205,27 @@ func TestThrottledConnClose(t *testing.T) {
 	n := 4
 	b := make([]byte, n+1)
 
-	throttledConn := NewThrottledConn(&testConn{}, rateLimits)
+	throttledConn := NewThrottledConn(&testConn{}, true, rateLimits)
 
 	now := time.Now()
-	_, err := throttledConn.Read(b)
+	_, err := io.ReadFull(throttledConn, b)
 	elapsed := time.Since(now)
 	if err != nil || elapsed < time.Duration(n)*time.Second {
 		t.Errorf("unexpected interrupted read: %s, %v", elapsed, err)
 	}
 
 	now = time.Now()
-	go func() {
+	go func(conn net.Conn) {
 		time.Sleep(500 * time.Millisecond)
-		throttledConn.Close()
-	}()
+		conn.Close()
+	}(throttledConn)
 	_, err = throttledConn.Read(b)
 	elapsed = time.Since(now)
 	if elapsed > 1*time.Second {
 		t.Errorf("unexpected uninterrupted read: %s, %v", elapsed, err)
 	}
 
-	throttledConn = NewThrottledConn(&testConn{}, rateLimits)
+	throttledConn = NewThrottledConn(&testConn{}, true, rateLimits)
 
 	now = time.Now()
 	_, err = throttledConn.Write(b)
@@ -234,10 +235,10 @@ func TestThrottledConnClose(t *testing.T) {
 	}
 
 	now = time.Now()
-	go func() {
+	go func(conn net.Conn) {
 		time.Sleep(500 * time.Millisecond)
-		throttledConn.Close()
-	}()
+		conn.Close()
+	}(throttledConn)
 	_, err = throttledConn.Write(b)
 	elapsed = time.Since(now)
 	if elapsed > 1*time.Second {
@@ -245,6 +246,23 @@ func TestThrottledConnClose(t *testing.T) {
 	}
 }
 
+func TestNonStreamThrottledConn(t *testing.T) {
+
+	MTU := int64(1500)
+
+	rateLimits := RateLimits{
+		ReadBytesPerSecond:  MTU - 1,
+		WriteBytesPerSecond: MTU - 1,
+	}
+
+	throttledConn := NewThrottledConn(&testConn{}, false, rateLimits)
+
+	_, err := throttledConn.Write(make([]byte, MTU))
+	if err == nil {
+		t.Errorf("unexpected split write")
+	}
+}
+
 type testConn struct {
 }
 

+ 42 - 1
psiphon/config.go

@@ -647,16 +647,57 @@ type Config struct {
 	// IDs used by an in-proxy proxy. Personal compartment IDs are
 	// distributed from proxy operators to client users out-of-band and
 	// provide a mechanism to allow only certain clients to use a proxy.
+	//
+	// See InproxyClientPersonalCompartmentIDs comment for limitations.
 	InproxyProxyPersonalCompartmentIDs []string
 
 	// InproxyClientPersonalCompartmentIDs specifies the personal compartment
 	// IDs used by an in-proxy client. Personal compartment IDs are
 	// distributed from proxy operators to client users out-of-band and
-	// provide a mechanism to ensure a client only uses a certain proxy.
+	// provide a mechanism to ensure a client uses only a certain proxy for
+	// all tunnel connections.
 	//
 	// When InproxyClientPersonalCompartmentIDs is set, the client will use
 	// only in-proxy protocols, ensuring that all connections go through the
 	// proxy or proxies with the same personal compartment IDs.
+	//
+	// Limitations:
+	//
+	// - While fully functional, the personal pairing mode has a number of
+	//   limitations that make the current implementation less suitable for
+	//   large scale deployment.
+	//
+	// - Since the mode requires an in-proxy connection to a proxy, announcing
+	//   with the corresponding personal compartment ID, not only must that
+	//   proxy be available, but also a broker, and both the client and proxy
+	//   must rendezvous at the same broker.
+	//
+	// - Currently, the client tunnel establishment algorithm does not launch
+	//   an untunneled tactics request as long as there is a cached tactics
+	//   with a valid TTL. The assumption, in regular mode, is that the
+	//   cached tactics will suffice, and any new tactics will be obtained
+	//   from any Psiphon server connection. Since broker specs are obtained
+	//   solely from tactics, if brokers are removed, reconfigured, or even
+	//   if the order is changed, personal mode may fail to connect until
+	//   cached tactics expire.
+	//
+	// - In personal mode, clients and proxies use a simplistic approach to
+	//   rendezvous: always select the first broker spec. This works, but is
+	//   not robust in terms of load balancing, and fails if the first broker
+	//   is unreachable or overloaded. Non-personal in-proxy dials can simply
+	//   use any available broker.
+	//
+	// - The broker matching queues lack compartment ID indexing. For a
+	//   handful of common compartment IDs, this is not expected to be an
+	//   issue. For personal compartment IDs, this may lead to frequent
+	//   near-full scans of the queues when looking for a match.
+	//
+	// - In personal mode, all establishment candidates must be in-proxy
+	//   dials, all using the same broker. Many concurrent, fronted broker
+	//   requests may result in CDN rate limiting, requiring some mechanism
+	//   to delay or spread the requests, as is currently done only for
+	//   batches of proxy announcements.
+	//
 	InproxyClientPersonalCompartmentIDs []string
 
 	// EmitInproxyProxyActivity indicates whether to emit frequent notices

+ 5 - 0
psiphon/server/config.go

@@ -471,6 +471,11 @@ type Config struct {
 	// and proxies from the same ASN. This parameter is for testing only.
 	InproxyBrokerAllowCommonASNMatching bool
 
+	// InproxyBrokerAllowBogonWebRTCConnections overrides the default broker
+	// SDP validation behavior, which doesn't allow private network WebRTC
+	// candidates. This parameter is for testing only.
+	InproxyBrokerAllowBogonWebRTCConnections bool
+
 	// InproxyServerSessionPrivateKey specifies the server's in-proxy session
 	// private key and derived public key used by brokers. This value is
 	// required when running in-proxy tunnel protocols.

+ 24 - 10
psiphon/server/meek.go

@@ -54,8 +54,8 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
 	lrucache "github.com/cognusion/go-cache-lru"
-	"github.com/juju/ratelimit"
 	"golang.org/x/crypto/nacl/box"
+	"golang.org/x/time/rate"
 )
 
 // MeekServer is based on meek-server.go from Tor and Psiphon:
@@ -313,6 +313,10 @@ func NewMeekServer(
 			inproxy.SetAllowCommonASNMatching(true)
 		}
 
+		if support.Config.InproxyBrokerAllowBogonWebRTCConnections {
+			inproxy.SetAllowBogonWebRTCConnections(true)
+		}
+
 		sessionPrivateKey, err := inproxy.SessionPrivateKeyFromString(
 			support.Config.InproxyBrokerSessionPrivateKey)
 		if err != nil {
@@ -1048,8 +1052,15 @@ func (server *MeekServer) getSessionOrEndpoint(
 	// based on response time combined with the rate limit configuration. The
 	// rate limit is primarily intended to limit memory resource consumption and
 	// not the overhead incurred by cookie validation.
-
-	if server.rateLimit(clientIP, geoIPData, server.listenerTunnelProtocol) {
+	//
+	// The meek rate limit is applied to new meek tunnel sessions and tactics
+	// requests, both of which may reasonably be limited to as low as 1 event
+	// per time period. The in-proxy broker is excluded from meek rate
+	// limiting since it has its own rate limiter and in-proxy requests are
+	// allowed to be more frequent.
+
+	if clientSessionData.EndPoint != inproxy.BrokerEndPointName &&
+		server.rateLimit(clientIP, geoIPData, server.listenerTunnelProtocol) {
 		return "", nil, nil, "", "", nil, errors.TraceNew("rate limit exceeded")
 	}
 
@@ -1273,22 +1284,25 @@ func (server *MeekServer) rateLimit(
 	// (as well as synchronizing access to rateLimitCount).
 	server.rateLimitLock.Lock()
 
-	var rateLimiter *ratelimit.Bucket
+	var rateLimiter *rate.Limiter
 	entry, ok := server.rateLimitHistory.Get(rateLimitIP)
 	if ok {
-		rateLimiter = entry.(*ratelimit.Bucket)
+		rateLimiter = entry.(*rate.Limiter)
 	} else {
-		rateLimiter = ratelimit.NewBucketWithQuantum(
-			time.Duration(thresholdSeconds)*time.Second,
-			int64(historySize),
-			int64(historySize))
+
+		// Set bursts to 1, which is appropriate for new meek tunnels and
+		// tactics requests.
+
+		limit := float64(historySize) / float64(thresholdSeconds)
+		bursts := 1
+		rateLimiter = rate.NewLimiter(rate.Limit(limit), bursts)
 		server.rateLimitHistory.Set(
 			rateLimitIP,
 			rateLimiter,
 			time.Duration(thresholdSeconds)*time.Second)
 	}
 
-	limit := rateLimiter.TakeAvailable(1) < 1
+	limit := !rateLimiter.Allow()
 
 	triggerGC := false
 	if limit {

+ 3 - 3
psiphon/server/server_test.go

@@ -964,6 +964,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			inproxyTestConfig.brokerServerEntrySignaturePublicKey
 
 		serverConfig["InproxyBrokerAllowCommonASNMatching"] = true
+		serverConfig["InproxyBrokerAllowBogonWebRTCConnections"] = true
 	}
 
 	// Uncomment to enable SIGUSR2 profile dumps
@@ -3388,7 +3389,8 @@ func generateInproxyTestConfig(
 	// To minimize external dependencies, STUN testing is disabled here; it is
 	// exercised in the common/inproxy package tests.
 	//
-	// InproxyBrokerAllowCommonASNMatching must be set to true in the
+	// InproxyBrokerAllowCommonASNMatching and
+	// InproxyBrokerAllowBogonWebRTCConnections must be set to true in the
 	// server/broker config, to allow matches with the same local network
 	// address. InproxyDisableIPv6ICECandidates is turned on, in tactics,
 	// since the test GeoIP database is IPv4-only (see paveGeoIPDatabaseFiles).
@@ -3553,8 +3555,6 @@ func generateInproxyTestConfig(
 		proxySessionPrivateKey:              proxySessionPrivateKeyStr,
 	}
 
-	inproxy.SetAllowBogonWebRTCConnections(true)
-
 	return config, nil
 }
 

+ 59 - 7
psiphon/server/services.go

@@ -42,6 +42,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
+	"github.com/shirou/gopsutil/v4/cpu"
 )
 
 // RunServices initializes support functions including logging and GeoIP services;
@@ -182,10 +183,12 @@ func RunServices(configJSON []byte) (retErr error) {
 			defer ticker.Stop()
 
 			logNetworkBytes := true
+			logCPU := true
 
 			previousNetworkBytesReceived, previousNetworkBytesSent, err := getNetworkBytesTransferred()
 			if err != nil {
-				log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error("failed to get initial network bytes transferred")
+				log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error(
+					"failed to get initial network bytes transferred")
 
 				// If getNetworkBytesTransferred fails, stop logging network
 				// bytes for the lifetime of this process, in case there's a
@@ -194,24 +197,49 @@ func RunServices(configJSON []byte) (retErr error) {
 				logNetworkBytes = false
 			}
 
+			// Establish initial previous CPU stats. The previous CPU stats
+			// are stored internally by gopsutil/cpu.
+			_, err = getCPUPercent()
+			if err != nil {
+				log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error(
+					"failed to get initial CPU percent")
+
+				logCPU = false
+			}
+
 			for {
 				select {
 				case <-shutdownBroadcast:
 					return
 				case <-ticker.C:
-					var networkBytesReceived, networkBytesSent int64
 
+					var networkBytesReceived, networkBytesSent int64
 					if logNetworkBytes {
 						currentNetworkBytesReceived, currentNetworkBytesSent, err := getNetworkBytesTransferred()
 						if err != nil {
-							log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error("failed to get current network bytes transferred")
+							log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error(
+								"failed to get current network bytes transferred")
 							logNetworkBytes = false
 
 						} else {
 							networkBytesReceived = currentNetworkBytesReceived - previousNetworkBytesReceived
 							networkBytesSent = currentNetworkBytesSent - previousNetworkBytesSent
 
-							previousNetworkBytesReceived, previousNetworkBytesSent = currentNetworkBytesReceived, currentNetworkBytesSent
+							previousNetworkBytesReceived, previousNetworkBytesSent =
+								currentNetworkBytesReceived, currentNetworkBytesSent
+						}
+					}
+
+					var CPUPercent float64
+					if logCPU {
+						recentCPUPercent, err := getCPUPercent()
+						if err != nil {
+							log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Error(
+								"failed to get recent CPU percent")
+							logCPU = false
+
+						} else {
+							CPUPercent = recentCPUPercent
 						}
 					}
 
@@ -220,7 +248,8 @@ func RunServices(configJSON []byte) (retErr error) {
 					// networkBytesSent may be < 0. logServerLoad will not
 					// log these negative values.
 
-					logServerLoad(support, logNetworkBytes, networkBytesReceived, networkBytesSent)
+					logServerLoad(
+						support, logNetworkBytes, networkBytesReceived, networkBytesSent, logCPU, CPUPercent)
 				}
 			}
 		}()
@@ -324,7 +353,7 @@ loop:
 			case signalProcessProfiles <- struct{}{}:
 			default:
 			}
-			logServerLoad(support, false, 0, 0)
+			logServerLoad(support, false, 0, 0, false, 0)
 
 		case <-systemStopSignal:
 			log.WithTrace().Info("shutdown by system")
@@ -402,7 +431,26 @@ func outputProcessProfiles(config *Config, filenameSuffix string) {
 	}
 }
 
-func logServerLoad(support *SupportServices, logNetworkBytes bool, networkBytesReceived int64, networkBytesSent int64) {
+// getCPUPercent returns the overall system CPU percent (not the percent used
+// by this process), across all cores.
+func getCPUPercent() (float64, error) {
+	values, err := cpu.Percent(0, false)
+	if err != nil {
+		return 0, errors.Trace(err)
+	}
+	if len(values) != 1 {
+		return 0, errors.TraceNew("unexpected cpu.Percent return value")
+	}
+	return values[0], nil
+}
+
+func logServerLoad(
+	support *SupportServices,
+	logNetworkBytes bool,
+	networkBytesReceived int64,
+	networkBytesSent int64,
+	logCPU bool,
+	CPUPercent float64) {
 
 	serverLoad := getRuntimeMetrics()
 
@@ -421,6 +469,10 @@ func logServerLoad(support *SupportServices, logNetworkBytes bool, networkBytesR
 		}
 	}
 
+	if logCPU {
+		serverLoad["cpu_percent"] = CPUPercent
+	}
+
 	establishTunnels, establishLimitedCount :=
 		support.TunnelServer.GetEstablishTunnelsMetrics()
 	serverLoad["establish_tunnels"] = establishTunnels

+ 4 - 2
psiphon/server/tunnelServer.go

@@ -2089,9 +2089,11 @@ func (sshClient *sshClient) run(
 	// Allow garbage collection.
 	p.Close()
 
-	// Further wrap the connection in a rate limiting ThrottledConn.
+	// Further wrap the connection in a rate limiting ThrottledConn. The
+	// underlying dialConn is always a stream, even when the network conn
+	// uses UDP.
 
-	throttledConn := common.NewThrottledConn(conn, sshClient.rateLimits())
+	throttledConn := common.NewThrottledConn(conn, true, sshClient.rateLimits())
 	conn = throttledConn
 
 	// Replay of server-side parameters is set or extended after a new tunnel

+ 3 - 1
psiphon/tunnel.go

@@ -1017,9 +1017,11 @@ func dialTunnel(
 		burstUpstreamTargetBytes, burstUpstreamDeadline,
 		burstDownstreamTargetBytes, burstDownstreamDeadline)
 
-	// Apply throttling (if configured)
+	// Apply throttling (if configured). The underlying dialConn is always a
+	// stream, even when the network conn uses UDP.
 	throttledConn := common.NewThrottledConn(
 		monitoredConn,
+		true,
 		rateLimits)
 
 	// Add obfuscated SSH layer