|
|
@@ -48,13 +48,20 @@ import (
|
|
|
)
|
|
|
|
|
|
func TestInproxy(t *testing.T) {
|
|
|
- err := runTestInproxy()
|
|
|
+ err := runTestInproxy(false)
|
|
|
if err != nil {
|
|
|
t.Errorf(errors.Trace(err).Error())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func runTestInproxy() error {
|
|
|
+func TestInproxyMustUpgrade(t *testing.T) {
|
|
|
+ err := runTestInproxy(true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf(errors.Trace(err).Error())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func runTestInproxy(doMustUpgrade bool) error {
|
|
|
|
|
|
// Note: use the environment variable PION_LOG_TRACE=all to emit WebRTC logging.
|
|
|
|
|
|
@@ -95,6 +102,23 @@ func runTestInproxy() error {
|
|
|
roundTripperFailedCount := int32(0)
|
|
|
roundTripperFailed := func(RoundTripper) { atomic.AddInt32(&roundTripperFailedCount, 1) }
|
|
|
|
|
|
+ var receivedProxyMustUpgrade chan struct{}
|
|
|
+ var receivedClientMustUpgrade chan struct{}
|
|
|
+ if doMustUpgrade {
|
|
|
+
|
|
|
+ receivedProxyMustUpgrade = make(chan struct{})
|
|
|
+ receivedClientMustUpgrade = make(chan struct{})
|
|
|
+
|
|
|
+ // trigger MustUpgrade
|
|
|
+ proxyProtocolVersion = 0
|
|
|
+
|
|
|
+ // Minimize test parameters for MustUpgrade case
|
|
|
+ numProxies = 1
|
|
|
+ proxyMaxClients = 1
|
|
|
+ numClients = 1
|
|
|
+ testDisableSTUN = true
|
|
|
+ }
|
|
|
+
|
|
|
testCtx, stopTest := context.WithCancel(context.Background())
|
|
|
defer stopTest()
|
|
|
|
|
|
@@ -394,6 +418,9 @@ func runTestInproxy() error {
|
|
|
|
|
|
tacticsNetworkID := prng.HexString(32)
|
|
|
|
|
|
+ runCtx, cancelRun := context.WithCancel(testCtx)
|
|
|
+ // No deferred cancelRun due to testGroup.Go below
|
|
|
+
|
|
|
proxy, err := NewProxy(&ProxyConfig{
|
|
|
|
|
|
Logger: logger,
|
|
|
@@ -427,6 +454,11 @@ func runTestInproxy() error {
|
|
|
time.Now().UTC().Format(time.RFC3339),
|
|
|
connectingClients, connectedClients, bytesUp, bytesDown)
|
|
|
},
|
|
|
+
|
|
|
+ MustUpgrade: func() {
|
|
|
+ close(receivedProxyMustUpgrade)
|
|
|
+ cancelRun()
|
|
|
+ },
|
|
|
})
|
|
|
if err != nil {
|
|
|
return errors.Trace(err)
|
|
|
@@ -435,7 +467,7 @@ func runTestInproxy() error {
|
|
|
addPendingProxyTacticsCallback(proxyPrivateKey)
|
|
|
|
|
|
testGroup.Go(func() error {
|
|
|
- proxy.Run(testCtx)
|
|
|
+ proxy.Run(runCtx)
|
|
|
return nil
|
|
|
})
|
|
|
}
|
|
|
@@ -448,13 +480,15 @@ func runTestInproxy() error {
|
|
|
// - Don't wait for > numProxies announcements due to
|
|
|
// InitiatorSessions.NewRoundTrip waitToShareSession limitation
|
|
|
|
|
|
- for {
|
|
|
- time.Sleep(100 * time.Millisecond)
|
|
|
- broker.matcher.announcementQueueMutex.Lock()
|
|
|
- n := broker.matcher.announcementQueue.getLen()
|
|
|
- broker.matcher.announcementQueueMutex.Unlock()
|
|
|
- if n >= numProxies {
|
|
|
- break
|
|
|
+ if !doMustUpgrade {
|
|
|
+ for {
|
|
|
+ time.Sleep(100 * time.Millisecond)
|
|
|
+ broker.matcher.announcementQueueMutex.Lock()
|
|
|
+ n := broker.matcher.announcementQueue.getLen()
|
|
|
+ broker.matcher.announcementQueueMutex.Unlock()
|
|
|
+ if n >= numProxies {
|
|
|
+ break
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -498,6 +532,11 @@ func runTestInproxy() error {
|
|
|
DialNetworkProtocol: networkProtocol,
|
|
|
DialAddress: addr,
|
|
|
PackedDestinationServerEntry: packedDestinationServerEntry,
|
|
|
+ MustUpgrade: func() {
|
|
|
+ fmt.Printf("HI!\n")
|
|
|
+ close(receivedClientMustUpgrade)
|
|
|
+ cancelDial()
|
|
|
+ },
|
|
|
})
|
|
|
if err != nil {
|
|
|
return errors.Trace(err)
|
|
|
@@ -718,43 +757,57 @@ func runTestInproxy() error {
|
|
|
clientsGroup.Go(makeClientFunc(isTCP, isMobile, brokerClient, webRTCCoordinator))
|
|
|
}
|
|
|
|
|
|
- // Await client transfers complete
|
|
|
+ if doMustUpgrade {
|
|
|
|
|
|
- logger.WithTrace().Info("AWAIT DATA TRANSFER")
|
|
|
+ // Await MustUpgrade callbacks
|
|
|
|
|
|
- err = clientsGroup.Wait()
|
|
|
- if err != nil {
|
|
|
- return errors.Trace(err)
|
|
|
- }
|
|
|
+ logger.WithTrace().Info("AWAIT MUST UPGRADE")
|
|
|
|
|
|
- logger.WithTrace().Info("DONE DATA TRANSFER")
|
|
|
+ <-receivedProxyMustUpgrade
|
|
|
+ <-receivedClientMustUpgrade
|
|
|
|
|
|
- if hasPendingBrokerServerReports() {
|
|
|
- return errors.TraceNew("unexpected pending broker server requests")
|
|
|
- }
|
|
|
+ _ = clientsGroup.Wait()
|
|
|
|
|
|
- if hasPendingProxyTacticsCallbacks() {
|
|
|
- return errors.TraceNew("unexpected pending proxy tactics callback")
|
|
|
- }
|
|
|
+ } else {
|
|
|
+
|
|
|
+ // Await client transfers complete
|
|
|
|
|
|
- // TODO: check that elapsed time is consistent with rate limit (+/-)
|
|
|
+ logger.WithTrace().Info("AWAIT DATA TRANSFER")
|
|
|
|
|
|
- // Check if STUN server replay callbacks were triggered
|
|
|
- if !testDisableSTUN {
|
|
|
- if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
|
|
|
- return errors.TraceNew("unexpected STUN server succeeded count")
|
|
|
+ err = clientsGroup.Wait()
|
|
|
+ if err != nil {
|
|
|
+ return errors.Trace(err)
|
|
|
}
|
|
|
- }
|
|
|
- if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
|
|
|
- return errors.TraceNew("unexpected STUN server failed count")
|
|
|
- }
|
|
|
|
|
|
- // Check if RoundTripper server replay callbacks were triggered
|
|
|
- if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
|
|
|
- return errors.TraceNew("unexpected round tripper succeeded count")
|
|
|
- }
|
|
|
- if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
|
|
|
- return errors.TraceNew("unexpected round tripper failed count")
|
|
|
+ logger.WithTrace().Info("DONE DATA TRANSFER")
|
|
|
+
|
|
|
+ if hasPendingBrokerServerReports() {
|
|
|
+ return errors.TraceNew("unexpected pending broker server requests")
|
|
|
+ }
|
|
|
+
|
|
|
+ if hasPendingProxyTacticsCallbacks() {
|
|
|
+ return errors.TraceNew("unexpected pending proxy tactics callback")
|
|
|
+ }
|
|
|
+
|
|
|
+ // TODO: check that elapsed time is consistent with rate limit (+/-)
|
|
|
+
|
|
|
+ // Check if STUN server replay callbacks were triggered
|
|
|
+ if !testDisableSTUN {
|
|
|
+ if atomic.LoadInt32(&stunServerAddressSucceededCount) < 1 {
|
|
|
+ return errors.TraceNew("unexpected STUN server succeeded count")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if atomic.LoadInt32(&stunServerAddressFailedCount) > 0 {
|
|
|
+ return errors.TraceNew("unexpected STUN server failed count")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if RoundTripper server replay callbacks were triggered
|
|
|
+ if atomic.LoadInt32(&roundTripperSucceededCount) < 1 {
|
|
|
+ return errors.TraceNew("unexpected round tripper succeeded count")
|
|
|
+ }
|
|
|
+ if atomic.LoadInt32(&roundTripperFailedCount) > 0 {
|
|
|
+ return errors.TraceNew("unexpected round tripper failed count")
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// Await shutdowns
|