Просмотр исходного кода

Cap / rate limit relayed DSL requests

Rod Hynes 4 месяцев назад
Родитель
Сommit
0120178259

+ 108 - 24
psiphon/common/inproxy/broker.go

@@ -35,6 +35,7 @@ import (
 	"github.com/cespare/xxhash"
 	lrucache "github.com/cognusion/go-cache-lru"
 	"github.com/fxamacker/cbor/v2"
+	"golang.org/x/time/rate"
 )
 
 const (
@@ -52,6 +53,9 @@ const (
 	brokerPendingServerReportsTTL     = 60 * time.Second
 	brokerPendingServerReportsMaxSize = 100000
 	brokerMetricName                  = "inproxy_broker"
+
+	brokerRateLimiterReapHistoryFrequencySeconds = 300
+	brokerRateLimiterMaxCacheEntries             = 1000000
 )
 
 // LookupGeoIP is a callback for providing GeoIP lookup service.
@@ -100,15 +104,18 @@ type Broker struct {
 	commonCompartmentsMutex sync.Mutex
 	commonCompartments      *consistent.Consistent
 
-	proxyAnnounceTimeout       int64
-	clientOfferTimeout         int64
-	clientOfferPersonalTimeout int64
-	pendingServerReportsTTL    int64
+	proxyAnnounceTimeout       atomic.Int64
+	clientOfferTimeout         atomic.Int64
+	clientOfferPersonalTimeout atomic.Int64
+	pendingServerReportsTTL    atomic.Int64
 	maxRequestTimeouts         atomic.Value
-	maxCompartmentIDs          int64
+	maxCompartmentIDs          atomic.Int64
 
 	enableProxyQualityMutex sync.Mutex
 	enableProxyQuality      atomic.Bool
+
+	dslRequestRateLimiters    *lrucache.Cache
+	dslRequestRateLimitParams atomic.Value
 }
 
 // BrokerConfig specifies the configuration for a Broker.
@@ -226,6 +233,10 @@ type BrokerConfig struct {
 	MatcherOfferRateLimitQuantity int
 	MatcherOfferRateLimitInterval time.Duration
 
+	// DSL request relay rate limit configuration.
+	DSLRequestRateLimitQuantity int
+	DSLRequestRateLimitInterval time.Duration
+
 	// MaxCompartmentIDs specifies the maximum number of compartment IDs that
 	// can be included, per list, in one request. If 0, the value
 	// MaxCompartmentIDs is used.
@@ -306,12 +317,10 @@ func NewBroker(config *BrokerConfig) (*Broker, error) {
 
 		proxyQualityState: proxyQuality,
 
-		proxyAnnounceTimeout:       int64(config.ProxyAnnounceTimeout),
-		clientOfferTimeout:         int64(config.ClientOfferTimeout),
-		clientOfferPersonalTimeout: int64(config.ClientOfferPersonalTimeout),
-		pendingServerReportsTTL:    int64(config.PendingServerReportsTTL),
-
-		maxCompartmentIDs: int64(common.ValueOrDefault(config.MaxCompartmentIDs, MaxCompartmentIDs)),
+		dslRequestRateLimiters: lrucache.NewWithLRU(
+			0,
+			time.Duration(brokerRateLimiterReapHistoryFrequencySeconds)*time.Second,
+			brokerRateLimiterMaxCacheEntries),
 	}
 
 	b.pendingServerReports = lrucache.NewWithLRU(
@@ -319,6 +328,20 @@ func NewBroker(config *BrokerConfig) (*Broker, error) {
 		1*time.Minute,
 		brokerPendingServerReportsMaxSize)
 
+	b.proxyAnnounceTimeout.Store(int64(config.ProxyAnnounceTimeout))
+	b.clientOfferTimeout.Store(int64(config.ClientOfferTimeout))
+	b.clientOfferPersonalTimeout.Store(int64(config.ClientOfferPersonalTimeout))
+	b.pendingServerReportsTTL.Store(int64(config.PendingServerReportsTTL))
+
+	b.maxCompartmentIDs.Store(
+		int64(common.ValueOrDefault(config.MaxCompartmentIDs, MaxCompartmentIDs)))
+
+	b.dslRequestRateLimitParams.Store(
+		&brokerRateLimitParams{
+			quantity: config.DSLRequestRateLimitQuantity,
+			interval: config.DSLRequestRateLimitInterval,
+		})
+
 	if len(config.CommonCompartmentIDs) > 0 {
 		err = b.initializeCommonCompartmentIDHashing(config.CommonCompartmentIDs)
 		if err != nil {
@@ -365,10 +388,10 @@ func (b *Broker) SetTimeouts(
 	pendingServerReportsTTL time.Duration,
 	maxRequestTimeouts map[string]time.Duration) {
 
-	atomic.StoreInt64(&b.proxyAnnounceTimeout, int64(proxyAnnounceTimeout))
-	atomic.StoreInt64(&b.clientOfferTimeout, int64(clientOfferTimeout))
-	atomic.StoreInt64(&b.clientOfferPersonalTimeout, int64(clientOfferPersonalTimeout))
-	atomic.StoreInt64(&b.pendingServerReportsTTL, int64(pendingServerReportsTTL))
+	b.proxyAnnounceTimeout.Store(int64(proxyAnnounceTimeout))
+	b.clientOfferTimeout.Store(int64(clientOfferTimeout))
+	b.clientOfferPersonalTimeout.Store(int64(clientOfferPersonalTimeout))
+	b.pendingServerReportsTTL.Store(int64(pendingServerReportsTTL))
 	b.maxRequestTimeouts.Store(maxRequestTimeouts)
 }
 
@@ -383,7 +406,9 @@ func (b *Broker) SetLimits(
 	matcherOfferLimitEntryCount int,
 	matcherOfferRateLimitQuantity int,
 	matcherOfferRateLimitInterval time.Duration,
-	maxCompartmentIDs int) {
+	maxCompartmentIDs int,
+	dslRequestRateLimitQuantity int,
+	dslRequestRateLimitInterval time.Duration) {
 
 	b.matcher.SetLimits(
 		matcherAnnouncementLimitEntryCount,
@@ -394,9 +419,14 @@ func (b *Broker) SetLimits(
 		matcherOfferRateLimitQuantity,
 		matcherOfferRateLimitInterval)
 
-	atomic.StoreInt64(
-		&b.maxCompartmentIDs,
+	b.maxCompartmentIDs.Store(
 		int64(common.ValueOrDefault(maxCompartmentIDs, MaxCompartmentIDs)))
+
+	b.dslRequestRateLimitParams.Store(
+		&brokerRateLimitParams{
+			quantity: dslRequestRateLimitQuantity,
+			interval: dslRequestRateLimitInterval,
+		})
 }
 
 func (b *Broker) SetProxyQualityParameters(
@@ -666,7 +696,7 @@ func (b *Broker) handleProxyAnnounce(
 
 	var apiParams common.APIParameters
 	apiParams, logFields, err = announceRequest.ValidateAndGetParametersAndLogFields(
-		int(atomic.LoadInt64(&b.maxCompartmentIDs)),
+		int(b.maxCompartmentIDs.Load()),
 		b.config.APIParameterValidator,
 		b.config.APIParameterLogFieldFormatter,
 		geoIPData)
@@ -807,7 +837,7 @@ func (b *Broker) handleProxyAnnounce(
 	// Await client offer.
 
 	timeout := common.ValueOrDefault(
-		time.Duration(atomic.LoadInt64(&b.proxyAnnounceTimeout)),
+		time.Duration(b.proxyAnnounceTimeout.Load()),
 		brokerProxyAnnounceTimeout)
 
 	// Adjust the timeout to respect any shorter maximum request timeouts for
@@ -1038,7 +1068,7 @@ func (b *Broker) handleClientOffer(
 
 	var filteredSDP []byte
 	filteredSDP, logFields, err = offerRequest.ValidateAndGetLogFields(
-		int(atomic.LoadInt64(&b.maxCompartmentIDs)),
+		int(b.maxCompartmentIDs.Load()),
 		b.config.LookupGeoIP,
 		b.config.APIParameterValidator,
 		b.config.APIParameterLogFieldFormatter,
@@ -1111,9 +1141,9 @@ func (b *Broker) handleClientOffer(
 	// resulting broker rotation.
 	var timeout time.Duration
 	if hasPersonalCompartmentIDs {
-		timeout = time.Duration(atomic.LoadInt64(&b.clientOfferPersonalTimeout))
+		timeout = time.Duration(b.clientOfferPersonalTimeout.Load())
 	} else {
-		timeout = time.Duration(atomic.LoadInt64(&b.clientOfferTimeout))
+		timeout = time.Duration(b.clientOfferTimeout.Load())
 	}
 	timeout = common.ValueOrDefault(timeout, brokerClientOfferTimeout)
 
@@ -1680,6 +1710,26 @@ func (b *Broker) handleClientDSL(
 		}
 	}()
 
+	// Rate limit the number of relayed DSL requests. The DSL backend has its
+	// own rate limit enforcement, but avoiding excess requests here saves on
+	// resources consumed between the relay and backend.
+	//
+	// Unlike the announce/offer rate limit cases, there's no "limited" error
+	// flag returned to the client in this case, since this rate limiter is
+	// purely for abuse prevention and is expected to be configured with
+	// limits that won't be exceeded by legitimate clients.
+
+	rateLimitParams := b.dslRequestRateLimitParams.Load().(*brokerRateLimitParams)
+	err := brokerRateLimit(
+		b.dslRequestRateLimiters,
+		clientIP,
+		rateLimitParams.quantity,
+		rateLimitParams.interval)
+	if err != nil {
+		return nil, errors.Trace(err)
+
+	}
+
 	dslRequest, err := UnmarshalClientDSLRequest(requestPayload)
 	if err != nil {
 		return nil, errors.Trace(err)
@@ -1790,7 +1840,7 @@ func (b *Broker) initiateRelayedServerReport(
 			serverReport: serverReport,
 			roundTrip:    roundTrip,
 		},
-		time.Duration(atomic.LoadInt64(&b.pendingServerReportsTTL)))
+		time.Duration(b.pendingServerReportsTTL.Load()))
 
 	return relayPacket, nil
 }
@@ -2038,3 +2088,37 @@ func (b *Broker) selectCommonCompartmentID(proxyID ID) (ID, error) {
 
 	return compartmentID, nil
 }
+
+type brokerRateLimitParams struct {
+	quantity int
+	interval time.Duration
+}
+
+func brokerRateLimit(
+	rateLimiters *lrucache.Cache,
+	limitIP string,
+	quantity int,
+	interval time.Duration) error {
+
+	if quantity <= 0 || interval <= 0 {
+		return nil
+	}
+
+	var rateLimiter *rate.Limiter
+
+	entry, ok := rateLimiters.Get(limitIP)
+	if ok {
+		rateLimiter = entry.(*rate.Limiter)
+	} else {
+		limit := float64(quantity) / interval.Seconds()
+		rateLimiter = rate.NewLimiter(rate.Limit(limit), quantity)
+		rateLimiters.Set(
+			limitIP, rateLimiter, interval)
+	}
+
+	if !rateLimiter.Allow() {
+		return errors.TraceNew("rate exceeded for IP")
+	}
+
+	return nil
+}

+ 11 - 26
psiphon/common/inproxy/matcher.go

@@ -30,7 +30,6 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	lrucache "github.com/cognusion/go-cache-lru"
-	"golang.org/x/time/rate"
 )
 
 // TTLs should be aligned with STUN hole punch lifetimes.
@@ -42,9 +41,6 @@ const (
 	matcherPendingAnswersMaxSize    = 5000000
 	matcherMaxPreferredNATProbe     = 100
 	matcherMaxProbe                 = 1000
-
-	matcherRateLimiterReapHistoryFrequencySeconds = 300
-	matcherRateLimiterMaxCacheEntries             = 1000000
 )
 
 // Matcher matches proxy announcements with client offers. Matcher also
@@ -308,15 +304,15 @@ func NewMatcher(config *MatcherConfig) *Matcher {
 		announcementQueueEntryCountByIP: make(map[string]int),
 		announcementQueueRateLimiters: lrucache.NewWithLRU(
 			0,
-			time.Duration(matcherRateLimiterReapHistoryFrequencySeconds)*time.Second,
-			matcherRateLimiterMaxCacheEntries),
+			time.Duration(brokerRateLimiterReapHistoryFrequencySeconds)*time.Second,
+			brokerRateLimiterMaxCacheEntries),
 
 		offerQueue:               list.New(),
 		offerQueueEntryCountByIP: make(map[string]int),
 		offerQueueRateLimiters: lrucache.NewWithLRU(
 			0,
-			time.Duration(matcherRateLimiterReapHistoryFrequencySeconds)*time.Second,
-			matcherRateLimiterMaxCacheEntries),
+			time.Duration(brokerRateLimiterReapHistoryFrequencySeconds)*time.Second,
+			brokerRateLimiterMaxCacheEntries),
 
 		matchSignal: make(chan struct{}, 1),
 
@@ -1064,24 +1060,13 @@ func (m *Matcher) applyIPLimits(isAnnouncement bool, limitIP string, proxyID ID)
 	// that the rate limit state is updated regardless of the max count check
 	// outcome.
 
-	if quantity > 0 && interval > 0 {
-
-		var rateLimiter *rate.Limiter
-
-		entry, ok := queueRateLimiters.Get(limitIP)
-		if ok {
-			rateLimiter = entry.(*rate.Limiter)
-		} else {
-			limit := float64(quantity) / interval.Seconds()
-			rateLimiter = rate.NewLimiter(rate.Limit(limit), quantity)
-			queueRateLimiters.Set(
-				limitIP, rateLimiter, interval)
-		}
-
-		if !rateLimiter.Allow() {
-			return errors.Trace(
-				NewMatcherLimitError(std_errors.New("rate exceeded for IP")))
-		}
+	err := brokerRateLimit(
+		queueRateLimiters,
+		limitIP,
+		quantity,
+		interval)
+	if err != nil {
+		return errors.Trace(NewMatcherLimitError(err))
 	}
 
 	if limitEntryCount > 0 {

+ 4 - 0
psiphon/common/parameters/parameters.go

@@ -434,6 +434,8 @@ const (
 	InproxyBrokerClientOfferTimeout                    = "InproxyBrokerClientOfferTimeout"
 	InproxyBrokerClientOfferPersonalTimeout            = "InproxyBrokerClientOfferPersonalTimeout"
 	InproxyBrokerPendingServerRequestsTTL              = "InproxyBrokerPendingServerRequestsTTL"
+	InproxyBrokerDSLRequestRateLimitQuantity           = "InproxyBrokerDSLRequestRateLimitQuantity"
+	InproxyBrokerDSLRequestRateLimitInterval           = "InproxyBrokerDSLRequestRateLimitInterval"
 	InproxySessionHandshakeRoundTripTimeout            = "InproxySessionHandshakeRoundTripTimeout"
 	InproxyProxyAnnounceRequestTimeout                 = "InproxyProxyAnnounceRequestTimeout"
 	InproxyProxyAnnounceDelay                          = "InproxyProxyAnnounceDelay"
@@ -1054,6 +1056,8 @@ var defaultParameters = map[string]struct {
 	InproxyBrokerClientOfferTimeout:                    {value: 10 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
 	InproxyBrokerClientOfferPersonalTimeout:            {value: 5 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
 	InproxyBrokerPendingServerRequestsTTL:              {value: 60 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
+	InproxyBrokerDSLRequestRateLimitQuantity:           {value: 20, minimum: 0, flags: serverSideOnly},
+	InproxyBrokerDSLRequestRateLimitInterval:           {value: 1 * time.Second, minimum: time.Duration(0), flags: serverSideOnly},
 	InproxySessionHandshakeRoundTripTimeout:            {value: 10 * time.Second, minimum: time.Duration(0), flags: useNetworkLatencyMultiplier},
 	InproxyProxyAnnounceRequestTimeout:                 {value: 2*time.Minute + 10*time.Second, minimum: time.Duration(0)},
 	InproxyProxyAnnounceDelay:                          {value: 100 * time.Millisecond, minimum: time.Duration(0)},

+ 20 - 0
psiphon/server/api.go

@@ -1018,6 +1018,26 @@ func dslAPIRequestHandler(
 	sshClient *sshClient,
 	requestPayload []byte) ([]byte, error) {
 
+	// Sanity check: don't relay more than the modest number of DSL requests
+	// expected in the tunneled case. The DSL backend has its own rate limit
+	// enforcement, but avoiding excess requests here saves on resources
+	// consumed between the relay and backend.
+	//
+	// The equivalent pre-relay check in the in-proxy broker uses an explicit
+	// rate limiter; here a simpler hard limit per tunnel suffices due to the
+	// low limit size and the fact that tunnel dials are themselves rate
+	// limited.
+	ok := false
+	sshClient.Lock()
+	if sshClient.dslRequestCount < SSH_CLIENT_MAX_DSL_REQUEST_COUNT {
+		ok = true
+		sshClient.dslRequestCount += 1
+	}
+	sshClient.Unlock()
+	if !ok {
+		return nil, errors.TraceNew("too many DSL requests")
+	}
+
 	responsePayload, err := dslHandleRequest(
 		sshClient.runCtx,
 		support,

+ 3 - 1
psiphon/server/meek.go

@@ -1895,7 +1895,9 @@ func (server *MeekServer) inproxyReloadTactics() error {
 		p.Int(parameters.InproxyBrokerMatcherOfferLimitEntryCount),
 		p.Int(parameters.InproxyBrokerMatcherOfferRateLimitQuantity),
 		p.Duration(parameters.InproxyBrokerMatcherOfferRateLimitInterval),
-		p.Int(parameters.InproxyMaxCompartmentIDListLength))
+		p.Int(parameters.InproxyMaxCompartmentIDListLength),
+		p.Int(parameters.InproxyBrokerDSLRequestRateLimitQuantity),
+		p.Duration(parameters.InproxyBrokerDSLRequestRateLimitInterval))
 
 	server.inproxyBroker.SetProxyQualityParameters(
 		p.Bool(parameters.InproxyEnableProxyQuality),

+ 2 - 0
psiphon/server/tunnelServer.go

@@ -78,6 +78,7 @@ const (
 	RANDOM_STREAM_MAX_BYTES               = 10485760
 	ALERT_REQUEST_QUEUE_BUFFER_SIZE       = 16
 	SSH_MAX_CLIENT_COUNT                  = 100000
+	SSH_CLIENT_MAX_DSL_REQUEST_COUNT      = 32
 )
 
 // TunnelServer is the main server that accepts Psiphon client
@@ -1967,6 +1968,7 @@ type sshClient struct {
 	checkedServerEntryTags               int
 	invalidServerEntryTags               int
 	sshProtocolBytesTracker              *sshProtocolBytesTracker
+	dslRequestCount                      int
 }
 
 type trafficState struct {