Просмотр исходного кода

Add retries to unreliable test

Rod Hynes 1 месяц назад
Родитель
Сommit
1593dc604d
1 измененных файлов с 41 добавлено и 16 удалено
  1. 41 16
      psiphon/common/inproxy/discovery_test.go

+ 41 - 16
psiphon/common/inproxy/discovery_test.go

@@ -23,13 +23,27 @@ package inproxy
 
 import (
 	"context"
+	"fmt"
 	"sync/atomic"
 	"testing"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/internal/testutils"
 )
 
 func TestNATDiscovery(t *testing.T) {
+	// Since this test can fail due to external network conditions, retry.
+	var err error
+	for try := 0; try < 2; try++ {
+		err = runTestNATDiscovery()
+		if err == nil {
+			return
+		}
+	}
+	t.Error(err.Error())
+}
+
+func runTestNATDiscovery() error {
 
 	// TODO: run local STUN and port mapping servers to test against, along
 	// with iptables rules to simulate NAT conditions
@@ -54,48 +68,48 @@ func TestNATDiscovery(t *testing.T) {
 		},
 
 		stunServerAddressSucceeded: func(RFC5780 bool, address string) {
-			atomic.AddInt32(&stunServerAddressSucceededCallCount, 1)
-			if address != stunServerAddress {
-				t.Errorf("unexpected STUN server address")
+			if address == stunServerAddress {
+				atomic.AddInt32(&stunServerAddressSucceededCallCount, 1)
 			}
 		},
 
 		stunServerAddressFailed: func(RFC5780 bool, address string) {
-			atomic.AddInt32(&stunServerAddressFailedCallCount, 1)
-			if address != stunServerAddress {
-				t.Errorf("unexpected STUN server address")
+			if address == stunServerAddress {
+				atomic.AddInt32(&stunServerAddressFailedCallCount, 1)
 			}
 		},
 	}
 
-	checkCallCounts := func(a, b, c, d int32) {
+	checkCallCounts := func(a, b, c, d int32) error {
 		callCount := atomic.LoadInt32(&setNATTypeCallCount)
 		if callCount != a {
-			t.Errorf(
+			return errors.Tracef(
 				"unexpected setNATType call count: %d",
 				callCount)
 		}
 
 		callCount = atomic.LoadInt32(&setPortMappingTypesCallCount)
 		if callCount != b {
-			t.Errorf(
+			return errors.Tracef(
 				"unexpected setPortMappingTypes call count: %d",
 				callCount)
 		}
 
 		callCount = atomic.LoadInt32(&stunServerAddressSucceededCallCount)
 		if callCount != c {
-			t.Errorf(
+			return errors.Tracef(
 				"unexpected stunServerAddressSucceeded call count: %d",
 				callCount)
 		}
 
 		callCount = atomic.LoadInt32(&stunServerAddressFailedCallCount)
 		if callCount != d {
-			t.Errorf(
+			return errors.Tracef(
 				"unexpected stunServerAddressFailedCallCount call count: %d",
 				callCount)
 		}
+
+		return nil
 	}
 
 	config := &NATDiscoverConfig{
@@ -109,7 +123,10 @@ func TestNATDiscovery(t *testing.T) {
 
 	NATDiscover(context.Background(), config)
 
-	checkCallCounts(1, 0, 1, 0)
+	err := checkCallCounts(1, 0, 1, 0)
+	if err != nil {
+		return errors.Trace(err)
+	}
 
 	// Should do port mapping only
 
@@ -118,7 +135,10 @@ func TestNATDiscovery(t *testing.T) {
 
 	NATDiscover(context.Background(), config)
 
-	checkCallCounts(1, 1, 1, 0)
+	err = checkCallCounts(1, 1, 1, 0)
+	if err != nil {
+		return errors.Trace(err)
+	}
 
 	// Should skip both and use values cached in WebRTCDialCoordinator
 
@@ -127,8 +147,13 @@ func TestNATDiscovery(t *testing.T) {
 
 	NATDiscover(context.Background(), config)
 
-	checkCallCounts(1, 1, 1, 0)
+	err = checkCallCounts(1, 1, 1, 0)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	fmt.Printf("NAT Type: %s\n", coordinator.NATType())
+	fmt.Printf("Port Mapping Types: %s\n", coordinator.PortMappingTypes())
 
-	t.Logf("NAT Type: %s", coordinator.NATType())
-	t.Logf("Port Mapping Types: %s", coordinator.PortMappingTypes())
+	return nil
 }