|
|
@@ -24,6 +24,8 @@ import (
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"runtime"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
@@ -68,12 +70,18 @@ func TestInterruptDials(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+ dialGoroutineFunctionNames := []string{"NewTCPDialer", "NewCustomTLSDialer"}
|
|
|
+
|
|
|
for dialerName, makeDialer := range makeDialers {
|
|
|
for _, doTimeout := range []bool{true, false} {
|
|
|
t.Run(
|
|
|
fmt.Sprintf("%s-timeout-%+v", dialerName, doTimeout),
|
|
|
func(t *testing.T) {
|
|
|
- runInterruptDials(t, doTimeout, makeDialer)
|
|
|
+ runInterruptDials(
|
|
|
+ t,
|
|
|
+ doTimeout,
|
|
|
+ makeDialer,
|
|
|
+ dialGoroutineFunctionNames)
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
@@ -83,7 +91,8 @@ func TestInterruptDials(t *testing.T) {
|
|
|
func runInterruptDials(
|
|
|
t *testing.T,
|
|
|
doTimeout bool,
|
|
|
- makeDialer func(string) Dialer) {
|
|
|
+ makeDialer func(string) Dialer,
|
|
|
+ dialGoroutineFunctionNames []string) {
|
|
|
|
|
|
t.Logf("Test timeout: %+v", doTimeout)
|
|
|
|
|
|
@@ -101,7 +110,11 @@ func runInterruptDials(
|
|
|
|
|
|
listenerAccepted := make(chan struct{}, 1)
|
|
|
|
|
|
+ noResponseListenerWaitGroup := new(sync.WaitGroup)
|
|
|
+ noResponseListenerWaitGroup.Add(1)
|
|
|
+ defer noResponseListenerWaitGroup.Wait()
|
|
|
go func() {
|
|
|
+ defer noResponseListenerWaitGroup.Done()
|
|
|
for {
|
|
|
conn, err := noResponseListener.Accept()
|
|
|
if err != nil {
|
|
|
@@ -137,8 +150,6 @@ func runInterruptDials(
|
|
|
|
|
|
dialTerminated := make(chan struct{}, len(addrs))
|
|
|
|
|
|
- startGoroutines := runtime.NumGoroutine()
|
|
|
-
|
|
|
for _, addr := range addrs {
|
|
|
go func(addr string) {
|
|
|
conn, err := makeDialer(addr)(ctx, "tcp", addr)
|
|
|
@@ -173,15 +184,44 @@ func runInterruptDials(
|
|
|
|
|
|
interruptDuration := monotime.Since(startWaiting)
|
|
|
|
|
|
- if interruptDuration > 10*time.Millisecond {
|
|
|
+ if interruptDuration > 100*time.Millisecond {
|
|
|
t.Fatalf("interrupt duration too long: %s", interruptDuration)
|
|
|
}
|
|
|
|
|
|
// Test: interrupted dialers must not leave goroutines running
|
|
|
|
|
|
- endGoroutines := runtime.NumGoroutine()
|
|
|
+ if findGoroutines(t, dialGoroutineFunctionNames) {
|
|
|
+ t.Fatalf("unexpected dial goroutines")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func findGoroutines(t *testing.T, targets []string) bool {
|
|
|
+ n, _ := runtime.GoroutineProfile(nil)
|
|
|
+ r := make([]runtime.StackRecord, n)
|
|
|
+ n, _ = runtime.GoroutineProfile(r)
|
|
|
+ found := false
|
|
|
+ for _, g := range r {
|
|
|
+ stack := g.Stack()
|
|
|
+ funcNames := make([]string, len(stack))
|
|
|
+ for i := 0; i < len(stack); i++ {
|
|
|
+ funcNames[i] = getFunctionName(stack[i])
|
|
|
+ }
|
|
|
+ s := strings.Join(funcNames, ", ")
|
|
|
+ for _, target := range targets {
|
|
|
+ if strings.Contains(s, target) {
|
|
|
+ t.Logf("found dial goroutine: %s", s)
|
|
|
+ found = true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return found
|
|
|
+}
|
|
|
|
|
|
- if endGoroutines > startGoroutines {
|
|
|
- t.Fatalf("unexpected goroutines: %d > %d", endGoroutines, startGoroutines)
|
|
|
+func getFunctionName(pc uintptr) string {
|
|
|
+ funcName := runtime.FuncForPC(pc).Name()
|
|
|
+ index := strings.LastIndex(funcName, "/")
|
|
|
+ if index != -1 {
|
|
|
+ funcName = funcName[index+1:]
|
|
|
}
|
|
|
+ return funcName
|
|
|
}
|