Jelajahi Sumber

Inproxy: broker matcher fixes

- TestMatcherMultiQueue exited early without reaching most test cases.

- Since it took a value receiver instead of a pointer receiver, the guard
  against multiple calls to announcementQueueReference.dequeue had no effect,
  leading to potential underflows in the tracked counts.

- Empty compartment queues were left in place by dequeue, leading to
  potentially unbounded memory consumption in the case of high compartment ID
  churn.
Rod Hynes 1 tahun lalu
induk
melakukan
e222c2156d
2 mengubah file dengan 84 tambahan dan 33 penghapusan
  1. 22 7
      psiphon/common/inproxy/matcher.go
  2. 62 26
      psiphon/common/inproxy/matcher_test.go

+ 22 - 7
psiphon/common/inproxy/matcher.go

@@ -977,7 +977,7 @@ func (m *Matcher) addAnnouncementEntry(announcementEntry *announcementEntry) err
 
 func (m *Matcher) removeAnnouncementEntry(aborting bool, announcementEntry *announcementEntry) {
 
-	// In the aborting case, the queue isn't already locked. Otherise, assume
+	// In the aborting case, the queue isn't already locked. Otherwise, assume
 	// it is locked.
 	if aborting {
 		m.announcementQueueMutex.Lock()
@@ -1119,6 +1119,8 @@ type announcementMultiQueue struct {
 // announcements, which are used, when matching, to determine when better NAT
 // matches may be possible.
 type announcementCompartmentQueue struct {
+	isCommonCompartment      bool
+	compartmentID            ID
 	entries                  *list.List
 	unlimitedNATCount        int
 	partiallyLimitedNATCount int
@@ -1177,12 +1179,14 @@ func (q *announcementMultiQueue) enqueue(announcementEntry *announcementEntry) e
 		return errors.TraceNew("announcement must specify exactly one compartment ID")
 	}
 
+	isCommonCompartment := true
 	var compartmentID ID
 	var compartmentQueues map[ID]*announcementCompartmentQueue
 	if len(commonCompartmentIDs) > 0 {
 		compartmentID = commonCompartmentIDs[0]
 		compartmentQueues = q.commonCompartmentQueues
 	} else {
+		isCommonCompartment = false
 		compartmentID = personalCompartmentIDs[0]
 		compartmentQueues = q.personalCompartmentQueues
 	}
@@ -1190,7 +1194,9 @@ func (q *announcementMultiQueue) enqueue(announcementEntry *announcementEntry) e
 	compartmentQueue, ok := compartmentQueues[compartmentID]
 	if !ok {
 		compartmentQueue = &announcementCompartmentQueue{
-			entries: list.New(),
+			isCommonCompartment: isCommonCompartment,
+			compartmentID:       compartmentID,
+			entries:             list.New(),
 		}
 		compartmentQueues[compartmentID] = compartmentQueue
 	}
@@ -1221,7 +1227,7 @@ func (q *announcementMultiQueue) enqueue(announcementEntry *announcementEntry) e
 }
 
 // announcementQueueReference returns false if the item is already dequeued.
-func (r announcementQueueReference) dequeue() bool {
+func (r *announcementQueueReference) dequeue() bool {
 
 	if r.entry == nil {
 		// Already dequeued.
@@ -1242,6 +1248,15 @@ func (r announcementQueueReference) dequeue() bool {
 
 	r.compartmentQueue.entries.Remove(r.entry)
 
+	if r.compartmentQueue.entries.Len() == 0 {
+		// Remove empty compartment queue.
+		queues := r.multiQueue.commonCompartmentQueues
+		if !r.compartmentQueue.isCommonCompartment {
+			queues = r.multiQueue.personalCompartmentQueues
+		}
+		delete(queues, r.compartmentQueue.compartmentID)
+	}
+
 	r.multiQueue.totalEntries -= 1
 
 	// Mark as dequeued.
@@ -1319,14 +1334,14 @@ func (iter *announcementMatchIterator) getNext() *announcementEntry {
 
 	// Select the oldest item, by deadline, from all the candidate queue head
 	// items. This operation is linear in the number of matching compartment
-	// ID queues, which is currently bounded by This is a linear time
-	// operation, bounded by the length of matching compartment IDs (no more
-	// than maxCompartmentIDs, as enforced in
+	// ID queues, which is currently bounded by the length of matching
+	// compartment IDs (no more than maxCompartmentIDs, as enforced in
 	// ClientOfferRequest.ValidateAndGetLogFields).
 	//
 	// A potential future enhancement is to add more iterator state to track
 	// which queue has the next oldest time to select on the following
-	// getNext call.
+	// getNext call. Another potential enhancement is to remove fully
+	// consumed queues from compartmentQueues/compartmentIDs/nextEntries.
 
 	var selectedCandidate *announcementEntry
 	selectedIndex := -1

+ 62 - 26
psiphon/common/inproxy/matcher_test.go

@@ -671,7 +671,6 @@ func TestMatcherMultiQueue(t *testing.T) {
 	if err != nil {
 		t.Errorf(errors.Trace(err).Error())
 	}
-
 }
 
 func runTestMatcherMultiQueue() error {
@@ -725,7 +724,7 @@ func runTestMatcherMultiQueue() error {
 						otherCommonCompartmentIDs[i%numOtherCompartmentIDs]},
 					NATType: NATTypeSymmetric,
 				}}})
-		if err == nil {
+		if err != nil {
 			return errors.Trace(err)
 		}
 		err = q.enqueue(&announcementEntry{
@@ -736,38 +735,41 @@ func runTestMatcherMultiQueue() error {
 						otherPersonalCompartmentIDs[i%numOtherCompartmentIDs]},
 					NATType: NATTypeSymmetric,
 				}}})
-		if err == nil {
+		if err != nil {
 			return errors.Trace(err)
 		}
 	}
 
 	var matchingCommonCompartmentIDs []ID
 	numMatchingCompartmentIDs := 2
+	numMatchingEntries := 2
 	var expectedMatches []*announcementEntry
 	for i := 0; i < numMatchingCompartmentIDs; i++ {
-		commonCompartmentID, _ := MakeID()
-		matchingCommonCompartmentIDs = append(
-			matchingCommonCompartmentIDs, commonCompartmentID)
-		ctx, cancel := context.WithDeadline(
-			context.Background(), time.Now().Add(time.Duration(i+1)*time.Minute))
-		defer cancel()
-		a := &announcementEntry{
-			ctx: ctx,
-			announcement: &MatchAnnouncement{
-				Properties: MatchProperties{
-					CommonCompartmentIDs: matchingCommonCompartmentIDs[i:i],
-					NATType:              NATTypeNone,
-				}}}
-		expectedMatches = append(expectedMatches, a)
-		err := q.enqueue(a)
-		if err == nil {
-			return errors.Trace(err)
+		for j := 0; j < numMatchingEntries; j++ {
+			commonCompartmentID, _ := MakeID()
+			matchingCommonCompartmentIDs = append(
+				matchingCommonCompartmentIDs, commonCompartmentID)
+			ctx, cancel := context.WithDeadline(
+				context.Background(), time.Now().Add(time.Duration(i+1)*time.Minute))
+			defer cancel()
+			a := &announcementEntry{
+				ctx: ctx,
+				announcement: &MatchAnnouncement{
+					Properties: MatchProperties{
+						CommonCompartmentIDs: matchingCommonCompartmentIDs[i : i+1],
+						NATType:              NATTypeNone,
+					}}}
+			expectedMatches = append(expectedMatches, a)
+			err := q.enqueue(a)
+			if err != nil {
+				return errors.Trace(err)
+			}
 		}
 	}
 
 	// Test: inspect queue state
 
-	if q.getLen() != numOtherEntries*2+numMatchingCompartmentIDs {
+	if q.getLen() != numOtherEntries*2+numMatchingCompartmentIDs*numMatchingEntries {
 		return errors.TraceNew("unexpected total entries count")
 	}
 
@@ -789,7 +791,9 @@ func runTestMatcherMultiQueue() error {
 	}
 
 	unlimited, partiallyLimited, strictlyLimited := iter.getNATCounts()
-	if unlimited != numMatchingCompartmentIDs || partiallyLimited != 0 || strictlyLimited != 0 {
+	if unlimited != numMatchingCompartmentIDs*numMatchingEntries ||
+		partiallyLimited != 0 ||
+		strictlyLimited != 0 {
 		return errors.TraceNew("unexpected NAT counts")
 	}
 
@@ -797,7 +801,7 @@ func runTestMatcherMultiQueue() error {
 	if match == nil {
 		return errors.TraceNew("unexpected missing match")
 	}
-	if match == expectedMatches[0] {
+	if match != expectedMatches[0] {
 		return errors.TraceNew("unexpected match")
 	}
 
@@ -811,12 +815,14 @@ func runTestMatcherMultiQueue() error {
 
 	iter = q.startMatching(true, matchingCommonCompartmentIDs)
 
-	if len(iter.compartmentQueues) != numMatchingCompartmentIDs-1 {
+	if len(iter.compartmentQueues) != numMatchingCompartmentIDs {
 		return errors.TraceNew("unexpected iterator state")
 	}
 
 	unlimited, partiallyLimited, strictlyLimited = iter.getNATCounts()
-	if unlimited != numMatchingCompartmentIDs-1 || partiallyLimited != 0 || strictlyLimited != 0 {
+	if unlimited != numMatchingEntries*numMatchingCompartmentIDs-1 ||
+		partiallyLimited != 0 ||
+		strictlyLimited != 0 {
 		return errors.TraceNew("unexpected NAT counts")
 	}
 
@@ -824,7 +830,37 @@ func runTestMatcherMultiQueue() error {
 	if match == nil {
 		return errors.TraceNew("unexpected missing match")
 	}
-	if match == expectedMatches[1] {
+	if match != expectedMatches[1] {
+		return errors.TraceNew("unexpected match")
+	}
+
+	if !match.queueReference.dequeue() {
+		return errors.TraceNew("unexpected already dequeued")
+	}
+
+	if len(iter.compartmentQueues) != numMatchingCompartmentIDs {
+		return errors.TraceNew("unexpected iterator state")
+	}
+
+	// Test: getNext after dequeue
+
+	match = iter.getNext()
+	if match == nil {
+		return errors.TraceNew("unexpected missing match")
+	}
+	if match != expectedMatches[2] {
+		return errors.TraceNew("unexpected match")
+	}
+
+	if !match.queueReference.dequeue() {
+		return errors.TraceNew("unexpected already dequeued")
+	}
+
+	match = iter.getNext()
+	if match == nil {
+		return errors.TraceNew("unexpected missing match")
+	}
+	if match != expectedMatches[3] {
 		return errors.TraceNew("unexpected match")
 	}