Эх сурвалжийг харах

inproxy MustUpgrade changes

- Fix precedence bug
- Add MustUpgrade response based on protocol version
- Add test case
Rod Hynes 1 жил өмнө
parent
commit
5e62b3a3a1

+ 15 - 4
psiphon/common/inproxy/api.go

@@ -30,10 +30,21 @@ import (
 )
 )
 
 
 const (
 const (
-	ProxyProtocolVersion1 = 1
-	MaxCompartmentIDs     = 10
+
+	// ProxyProtocolVersion1 represents protocol version 1.
+	ProxyProtocolVersion1 = int32(1)
+
+	// MinimumProxyProtocolVersion is the minimum supported version number.
+	MinimumProxyProtocolVersion = ProxyProtocolVersion1
+
+	MaxCompartmentIDs = 10
 )
 )
 
 
+// proxyProtocolVersion is the current protocol version number.
+// proxyProtocolVersion is variable, to enable overriding the value in tests.
+// This value should not be overridden outside of test cases.
+var proxyProtocolVersion = ProxyProtocolVersion1
+
 // ID is a unique identifier used to identify inproxy connections and actors.
 // ID is a unique identifier used to identify inproxy connections and actors.
 type ID [32]byte
 type ID [32]byte
 
 
@@ -468,7 +479,7 @@ func (metrics *ProxyMetrics) ValidateAndGetParametersAndLogFields(
 		return nil, nil, errors.Trace(err)
 		return nil, nil, errors.Trace(err)
 	}
 	}
 
 
-	if metrics.ProxyProtocolVersion != ProxyProtocolVersion1 {
+	if metrics.ProxyProtocolVersion < 0 || metrics.ProxyProtocolVersion > proxyProtocolVersion {
 		return nil, nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
 		return nil, nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
 	}
 	}
 
 
@@ -521,7 +532,7 @@ func (metrics *ClientMetrics) ValidateAndGetLogFields(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
-	if metrics.ProxyProtocolVersion != ProxyProtocolVersion1 {
+	if metrics.ProxyProtocolVersion < 0 || metrics.ProxyProtocolVersion > proxyProtocolVersion {
 		return nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
 		return nil, errors.Tracef("invalid proxy protocol version: %v", metrics.ProxyProtocolVersion)
 	}
 	}
 
 

+ 31 - 0
psiphon/common/inproxy/broker.go

@@ -526,6 +526,22 @@ func (b *Broker) handleProxyAnnounce(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
+	// Return MustUpgrade when the proxy's protocol version is less than the
+	// minimum required.
+	if announceRequest.Metrics.ProxyProtocolVersion < MinimumProxyProtocolVersion {
+		responsePayload, err := MarshalProxyAnnounceResponse(
+			&ProxyAnnounceResponse{
+				NoMatch:     true,
+				MustUpgrade: true,
+			})
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		return responsePayload, nil
+
+	}
+
 	// Fetch new tactics for the proxy, if required, using the tactics tag
 	// Fetch new tactics for the proxy, if required, using the tactics tag
 	// that should be included with the API parameters. A tacticsPayload may
 	// that should be included with the API parameters. A tacticsPayload may
 	// be returned when there are no new tactics, and this is relayed back to
 	// be returned when there are no new tactics, and this is relayed back to
@@ -818,6 +834,21 @@ func (b *Broker) handleClientOffer(
 		return nil, errors.Trace(err)
 		return nil, errors.Trace(err)
 	}
 	}
 
 
+	// Return MustUpgrade when the client's protocol version is less than the
+	// minimum required.
+	if offerRequest.Metrics.ProxyProtocolVersion < MinimumProxyProtocolVersion {
+		responsePayload, err := MarshalClientOfferResponse(
+			&ClientOfferResponse{
+				NoMatch:     true,
+				MustUpgrade: true,
+			})
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		return responsePayload, nil
+	}
+
 	// Enqueue the client offer and await a proxy matching and subsequent
 	// Enqueue the client offer and await a proxy matching and subsequent
 	// proxy answer.
 	// proxy answer.
 
 

+ 15 - 15
psiphon/common/inproxy/client.go

@@ -377,7 +377,7 @@ func dialClientWebRTCConn(
 		&ClientOfferRequest{
 		&ClientOfferRequest{
 			Metrics: &ClientMetrics{
 			Metrics: &ClientMetrics{
 				BaseAPIParameters:    packedBaseParams,
 				BaseAPIParameters:    packedBaseParams,
-				ProxyProtocolVersion: ProxyProtocolVersion1,
+				ProxyProtocolVersion: proxyProtocolVersion,
 				NATType:              config.WebRTCDialCoordinator.NATType(),
 				NATType:              config.WebRTCDialCoordinator.NATType(),
 				PortMappingTypes:     config.WebRTCDialCoordinator.PortMappingTypes(),
 				PortMappingTypes:     config.WebRTCDialCoordinator.PortMappingTypes(),
 			},
 			},
@@ -396,28 +396,28 @@ func dialClientWebRTCConn(
 		return nil, false, errors.Trace(err)
 		return nil, false, errors.Trace(err)
 	}
 	}
 
 
-	// No retry when rate/entry limited or must upgrade; do retry on no-match,
-	// as a match may soon appear.
+	// MustUpgrade has precedence over other cases to ensure the callback is
+	// invoked. No retry when rate/entry limited or must upgrade; do retry on
+	// no-match, as a match may soon appear.
 
 
-	if offerResponse.Limited {
-		return nil, false, errors.TraceNew("limited")
-
-	} else if offerResponse.NoMatch {
-
-		return nil, true, errors.TraceNew("no proxy match")
-
-	} else if offerResponse.MustUpgrade {
+	if offerResponse.MustUpgrade {
 
 
 		if config.MustUpgrade != nil {
 		if config.MustUpgrade != nil {
 			config.MustUpgrade()
 			config.MustUpgrade()
 		}
 		}
-
 		return nil, false, errors.TraceNew("must upgrade")
 		return nil, false, errors.TraceNew("must upgrade")
+
+	} else if offerResponse.Limited {
+
+		return nil, false, errors.TraceNew("limited")
+
+	} else if offerResponse.NoMatch {
+
+		return nil, true, errors.TraceNew("no match")
 	}
 	}
 
 
-	if offerResponse.SelectedProxyProtocolVersion != ProxyProtocolVersion1 {
-		// This case is currently unexpected, as all clients and proxies use
-		// ProxyProtocolVersion1.
+	if offerResponse.SelectedProxyProtocolVersion < MinimumProxyProtocolVersion ||
+		offerResponse.SelectedProxyProtocolVersion > proxyProtocolVersion {
 		return nil, false, errors.Tracef(
 		return nil, false, errors.Tracef(
 			"Unsupported proxy protocol version: %d",
 			"Unsupported proxy protocol version: %d",
 			offerResponse.SelectedProxyProtocolVersion)
 			offerResponse.SelectedProxyProtocolVersion)

+ 91 - 38
psiphon/common/inproxy/inproxy_test.go

@@ -48,13 +48,20 @@ import (
 )
 )
 
 
 func TestInproxy(t *testing.T) {
 func TestInproxy(t *testing.T) {
-	err := runTestInproxy()
+	err := runTestInproxy(false)
 	if err != nil {
 	if err != nil {
 		t.Errorf(errors.Trace(err).Error())
 		t.Errorf(errors.Trace(err).Error())
 	}
 	}
 }
 }
 
 
-func runTestInproxy() error {
+func TestInproxyMustUpgrade(t *testing.T) {
+	err := runTestInproxy(true)
+	if err != nil {
+		t.Errorf(errors.Trace(err).Error())
+	}
+}
+
+func runTestInproxy(doMustUpgrade bool) error {
 
 
 	// Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
 	// Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
 
 
@@ -95,6 +102,23 @@ func runTestInproxy() error {
 	roundTripperFailedCount := int32(0)
 	roundTripperFailedCount := int32(0)
 	roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
 	roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
 
 
+	var receivedProxyMustUpgrade chan struct{}
+	var receivedClientMustUpgrade chan struct{}
+	if doMustUpgrade {
+
+		receivedProxyMustUpgrade = make(chan struct{})
+		receivedClientMustUpgrade = make(chan struct{})
+
+		// trigger MustUpgrade
+		proxyProtocolVersion = 0
+
+		// Minimize test parameters for MustUpgrade case
+		numProxies = 1
+		proxyMaxClients = 1
+		numClients = 1
+		testDisableSTUN = true
+	}
+
 	testCtx, stopTest := context.WithCancel(context.Background())
 	testCtx, stopTest := context.WithCancel(context.Background())
 	defer stopTest()
 	defer stopTest()
 
 
@@ -394,6 +418,9 @@ func runTestInproxy() error {
 
 
 		tacticsNetworkID := prng.HexString(32)
 		tacticsNetworkID := prng.HexString(32)
 
 
+		runCtx, cancelRun := context.WithCancel(testCtx)
+		// No deferred cancelRun due to testGroup.Go below
+
 		proxy, err := NewProxy(&ProxyConfig{
 		proxy, err := NewProxy(&ProxyConfig{
 
 
 			Logger: logger,
 			Logger: logger,
@@ -427,6 +454,11 @@ func runTestInproxy() error {
 					time.Now().UTC().Format(time.RFC3339),
 					time.Now().UTC().Format(time.RFC3339),
 					connectingClients, connectedClients, bytesUp, bytesDown)
 					connectingClients, connectedClients, bytesUp, bytesDown)
 			},
 			},
+
+			MustUpgrade: func() {
+				close(receivedProxyMustUpgrade)
+				cancelRun()
+			},
 		})
 		})
 		if err != nil {
 		if err != nil {
 			return errors.Trace(err)
 			return errors.Trace(err)
@@ -435,7 +467,7 @@ func runTestInproxy() error {
 		addPendingProxyTacticsCallback(proxyPrivateKey)
 		addPendingProxyTacticsCallback(proxyPrivateKey)
 
 
 		testGroup.Go(func() error {
 		testGroup.Go(func() error {
-			proxy.Run(testCtx)
+			proxy.Run(runCtx)
 			return nil
 			return nil
 		})
 		})
 	}
 	}
@@ -448,13 +480,15 @@ func runTestInproxy() error {
 	// - Don't wait for > numProxies announcements due to
 	// - Don't wait for > numProxies announcements due to
 	//   InitiatorSessions.NewRoundTrip waitToShareSession limitation
 	//   InitiatorSessions.NewRoundTrip waitToShareSession limitation
 
 
-	for {
-		time.Sleep(100 * time.Millisecond)
-		broker.matcher.announcementQueueMutex.Lock()
-		n := broker.matcher.announcementQueue.getLen()
-		broker.matcher.announcementQueueMutex.Unlock()
-		if n >= numProxies {
-			break
+	if !doMustUpgrade {
+		for {
+			time.Sleep(100 * time.Millisecond)
+			broker.matcher.announcementQueueMutex.Lock()
+			n := broker.matcher.announcementQueue.getLen()
+			broker.matcher.announcementQueueMutex.Unlock()
+			if n >= numProxies {
+				break
+			}
 		}
 		}
 	}
 	}
 
 
@@ -498,6 +532,11 @@ func runTestInproxy() error {
 					DialNetworkProtocol:          networkProtocol,
 					DialNetworkProtocol:          networkProtocol,
 					DialAddress:                  addr,
 					DialAddress:                  addr,
 					PackedDestinationServerEntry: packedDestinationServerEntry,
 					PackedDestinationServerEntry: packedDestinationServerEntry,
+					MustUpgrade: func() {
+						fmt.Printf("HI!\n")
+						close(receivedClientMustUpgrade)
+						cancelDial()
+					},
 				})
 				})
 			if err != nil {
 			if err != nil {
 				return errors.Trace(err)
 				return errors.Trace(err)
@@ -718,43 +757,57 @@ func runTestInproxy() error {
 		clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
 		clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
 	}
 	}
 
 
-	// Await client transfers complete
+	if doMustUpgrade {
 
 
-	logger.WithTrace().Info("AWAIT DATA TRANSFER")
+		// Await MustUpgrade callbacks
 
 
-	err = clientsGroup.Wait()
-	if err != nil {
-		return errors.Trace(err)
-	}
+		logger.WithTrace().Info("AWAIT MUST UPGRADE")
 
 
-	logger.WithTrace().Info("DONE DATA TRANSFER")
+		<-receivedProxyMustUpgrade
+		<-receivedClientMustUpgrade
 
 
-	if hasPendingBrokerServerReports() {
-		return errors.TraceNew("unexpected pending broker server requests")
-	}
+		_ = clientsGroup.Wait()
 
 
-	if hasPendingProxyTacticsCallbacks() {
-		return errors.TraceNew("unexpected pending proxy tactics callback")
-	}
+	} else {
+
+		// Await client transfers complete
 
 
-	// TODO: check that elapsed time is consistent with rate limit (+/-)
+		logger.WithTrace().Info("AWAIT DATA TRANSFER")
 
 
-	// Check if STUN server replay callbacks were triggered
-	if !testDisableSTUN {
-		if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
-			return errors.TraceNew("unexpected STUN server succeeded count")
+		err = clientsGroup.Wait()
+		if err != nil {
+			return errors.Trace(err)
 		}
 		}
-	}
-	if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
-		return errors.TraceNew("unexpected STUN server failed count")
-	}
 
 
-	// Check if RoundTripper server replay callbacks were triggered
-	if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
-		return errors.TraceNew("unexpected round tripper succeeded count")
-	}
-	if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
-		return errors.TraceNew("unexpected round tripper failed count")
+		logger.WithTrace().Info("DONE DATA TRANSFER")
+
+		if hasPendingBrokerServerReports() {
+			return errors.TraceNew("unexpected pending broker server requests")
+		}
+
+		if hasPendingProxyTacticsCallbacks() {
+			return errors.TraceNew("unexpected pending proxy tactics callback")
+		}
+
+		// TODO: check that elapsed time is consistent with rate limit (+/-)
+
+		// Check if STUN server replay callbacks were triggered
+		if !testDisableSTUN {
+			if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
+				return errors.TraceNew("unexpected STUN server succeeded count")
+			}
+		}
+		if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
+			return errors.TraceNew("unexpected STUN server failed count")
+		}
+
+		// Check if RoundTripper server replay callbacks were triggered
+		if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
+			return errors.TraceNew("unexpected round tripper succeeded count")
+		}
+		if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
+			return errors.TraceNew("unexpected round tripper failed count")
+		}
 	}
 	}
 
 
 	// Await shutdowns
 	// Await shutdowns

+ 1 - 1
psiphon/common/inproxy/obfuscation_test.go

@@ -78,7 +78,7 @@ func FuzzSessionPacketDeobfuscation(f *testing.F) {
 		}
 		}
 
 
 		if (err == nil) != inOriginals {
 		if (err == nil) != inOriginals {
-			f.Errorf("unexpected deobfuscation result")
+			t.Errorf("unexpected deobfuscation result")
 		}
 		}
 	})
 	})
 }
 }

+ 13 - 12
psiphon/common/inproxy/proxy.go

@@ -654,10 +654,19 @@ func (p *Proxy) proxyOneClient(
 		signalAnnounceDone()
 		signalAnnounceDone()
 	}
 	}
 
 
-	// Trigger back-off back off when rate/entry limited or must upgrade; no
-	// back-off for no-match.
+	// MustUpgrade has precedence over other cases, to ensure the callback is
+	// invoked. Trigger back-off back off when rate/entry limited or must
+	// upgrade; no back-off for no-match.
 
 
-	if announceResponse.Limited {
+	if announceResponse.MustUpgrade {
+
+		if p.config.MustUpgrade != nil {
+			p.config.MustUpgrade()
+		}
+		backOff = true
+		return backOff, errors.TraceNew("must upgrade")
+
+	} else if announceResponse.Limited {
 
 
 		backOff = true
 		backOff = true
 		return backOff, errors.TraceNew("limited")
 		return backOff, errors.TraceNew("limited")
@@ -666,14 +675,6 @@ func (p *Proxy) proxyOneClient(
 
 
 		return backOff, errors.TraceNew("no match")
 		return backOff, errors.TraceNew("no match")
 
 
-	} else if announceResponse.MustUpgrade {
-
-		if p.config.MustUpgrade != nil {
-			p.config.MustUpgrade()
-		}
-
-		backOff = true
-		return backOff, errors.TraceNew("must upgrade")
 	}
 	}
 
 
 	if announceResponse.ClientProxyProtocolVersion != ProxyProtocolVersion1 {
 	if announceResponse.ClientProxyProtocolVersion != ProxyProtocolVersion1 {
@@ -965,7 +966,7 @@ func (p *Proxy) getMetrics(webRTCCoordinator WebRTCDialCoordinator) (*ProxyMetri
 
 
 	return &ProxyMetrics{
 	return &ProxyMetrics{
 		BaseAPIParameters:             packedBaseParams,
 		BaseAPIParameters:             packedBaseParams,
-		ProxyProtocolVersion:          ProxyProtocolVersion1,
+		ProxyProtocolVersion:          proxyProtocolVersion,
 		NATType:                       webRTCCoordinator.NATType(),
 		NATType:                       webRTCCoordinator.NATType(),
 		PortMappingTypes:              webRTCCoordinator.PortMappingTypes(),
 		PortMappingTypes:              webRTCCoordinator.PortMappingTypes(),
 		MaxClients:                    int32(p.config.MaxClients),
 		MaxClients:                    int32(p.config.MaxClients),