Просмотр исходного кода

Fix: detect only dial goroutines launched within TestInterruptDials

Rod Hynes 11 месяцев назад
Родитель
Сommit
0478021bd2
1 измененных файлов с 21 добавлено и 10 удалено
  1. 21 10
      psiphon/interrupt_dials_test.go

+ 21 - 10
psiphon/interrupt_dials_test.go

@@ -43,11 +43,11 @@ func TestInterruptDials(t *testing.T) {
 	makeDialers := make(map[string]func(string) common.Dialer)
 
 	makeDialers["TCP"] = func(string) common.Dialer {
-		return NewTCPDialer(&DialConfig{ResolveIP: resolveIP})
+		return interruptDialsNewTCPDialer(&DialConfig{ResolveIP: resolveIP})
 	}
 
 	makeDialers["SOCKS4-Proxied"] = func(mockServerAddr string) common.Dialer {
-		return NewTCPDialer(
+		return interruptDialsNewTCPDialer(
 			&DialConfig{
 				ResolveIP:        resolveIP,
 				UpstreamProxyURL: "socks4a://" + mockServerAddr,
@@ -55,7 +55,7 @@ func TestInterruptDials(t *testing.T) {
 	}
 
 	makeDialers["SOCKS5-Proxied"] = func(mockServerAddr string) common.Dialer {
-		return NewTCPDialer(
+		return interruptDialsNewTCPDialer(
 			&DialConfig{
 				ResolveIP:        resolveIP,
 				UpstreamProxyURL: "socks5://" + mockServerAddr,
@@ -63,7 +63,7 @@ func TestInterruptDials(t *testing.T) {
 	}
 
 	makeDialers["HTTP-CONNECT-Proxied"] = func(mockServerAddr string) common.Dialer {
-		return NewTCPDialer(
+		return interruptDialsNewTCPDialer(
 			&DialConfig{
 				ResolveIP:        resolveIP,
 				UpstreamProxyURL: "http://" + mockServerAddr,
@@ -85,16 +85,18 @@ func TestInterruptDials(t *testing.T) {
 	makeDialers["TLS"] = func(string) common.Dialer {
 		// Cast CustomTLSDialer to common.Dialer.
 		return func(context context.Context, network, addr string) (net.Conn, error) {
-			return NewCustomTLSDialer(
+			return interruptDialsNewCustomTLSDialer(
 				&CustomTLSConfig{
-					Parameters:               params,
-					Dial:                     NewTCPDialer(&DialConfig{ResolveIP: resolveIP}),
+					Parameters: params,
+					Dial: interruptDialsNewTCPDialer(
+						&DialConfig{ResolveIP: resolveIP}),
 					RandomizedTLSProfileSeed: seed,
 				})(context, network, addr)
 		}
 	}
 
-	dialGoroutineFunctionNames := []string{"NewTCPDialer", "NewCustomTLSDialer"}
+	dialGoroutineFunctionNames := []string{
+		"interruptDialsNewTCPDialer", "interruptDialsNewCustomTLSDialer"}
 
 	for dialerName, makeDialer := range makeDialers {
 		for _, doTimeout := range []bool{true, false} {
@@ -112,6 +114,14 @@ func TestInterruptDials(t *testing.T) {
 
 }
 
+func interruptDialsNewTCPDialer(config *DialConfig) common.Dialer {
+	return NewTCPDialer(config)
+}
+
+func interruptDialsNewCustomTLSDialer(config *CustomTLSConfig) common.Dialer {
+	return NewCustomTLSDialer(config)
+}
+
 func runInterruptDials(
 	t *testing.T,
 	doTimeout bool,
@@ -160,7 +170,7 @@ func runInterruptDials(
 	var ctx context.Context
 	var cancelFunc context.CancelFunc
 
-	timeout := 100 * time.Millisecond
+	timeout := 1 * time.Second
 
 	if doTimeout {
 		ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
@@ -191,7 +201,7 @@ func runInterruptDials(
 	<-listenerAccepted
 
 	if doTimeout {
-		time.Sleep(timeout)
+		time.Sleep(timeout + 100*time.Millisecond)
 		defer cancelFunc()
 	} else {
 		// No timeout, so interrupt with cancel
@@ -224,6 +234,7 @@ func findGoroutines(t *testing.T, targets []string) bool {
 	r := make([]runtime.StackRecord, n)
 	runtime.GoroutineProfile(r)
 	found := false
+
 	for _, g := range r {
 		stack := g.Stack()
 		funcNames := make([]string, len(stack))