|
@@ -8,10 +8,13 @@ import (
|
|
|
"log"
|
|
"log"
|
|
|
"net"
|
|
"net"
|
|
|
"sync"
|
|
"sync"
|
|
|
|
|
+ "time"
|
|
|
|
|
|
|
|
"www.bamsoftware.com/git/dnstt.git/turbotunnel"
|
|
"www.bamsoftware.com/git/dnstt.git/turbotunnel"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+const dialTimeout = 30 * time.Second
|
|
|
|
|
+
|
|
|
// TLSPacketConn is a TLS- and TCP-based transport for DNS messages, used for
|
|
// TLSPacketConn is a TLS- and TCP-based transport for DNS messages, used for
|
|
|
// DNS over TLS (DoT). Its WriteTo and ReadFrom methods exchange DNS messages
|
|
// DNS over TLS (DoT). Its WriteTo and ReadFrom methods exchange DNS messages
|
|
|
// over a TLS channel, prefixing each message with a two-octet length field as
|
|
// over a TLS channel, prefixing each message with a two-octet length field as
|
|
@@ -41,8 +44,11 @@ func NewTLSPacketConn(addr string) (*TLSPacketConn, error) {
|
|
|
// becomes disconnected. We do the first dial here, outside the
|
|
// becomes disconnected. We do the first dial here, outside the
|
|
|
// goroutine, so that any immediate and permanent connection errors are
|
|
// goroutine, so that any immediate and permanent connection errors are
|
|
|
// reported directly to the caller of NewTLSPacketConn.
|
|
// reported directly to the caller of NewTLSPacketConn.
|
|
|
|
|
+ dialer := &net.Dialer{
|
|
|
|
|
+ Timeout: dialTimeout,
|
|
|
|
|
+ }
|
|
|
tlsConfig := &tls.Config{}
|
|
tlsConfig := &tls.Config{}
|
|
|
- conn, err := tls.Dial("tcp", addr, tlsConfig)
|
|
|
|
|
|
|
+ conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
@@ -69,7 +75,7 @@ func NewTLSPacketConn(addr string) (*TLSPacketConn, error) {
|
|
|
conn.Close()
|
|
conn.Close()
|
|
|
|
|
|
|
|
// Whenever the TLS connection dies, redial a new one.
|
|
// Whenever the TLS connection dies, redial a new one.
|
|
|
- conn, err = tls.Dial("tcp", addr, tlsConfig)
|
|
|
|
|
|
|
+ conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.Printf("tls.Dial: %v", err)
|
|
log.Printf("tls.Dial: %v", err)
|
|
|
break
|
|
break
|