|
|
@@ -75,10 +75,7 @@ import (
|
|
|
"crypto/tls"
|
|
|
"crypto/x509"
|
|
|
"errors"
|
|
|
- "fmt"
|
|
|
- "io"
|
|
|
"net"
|
|
|
- "strings"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
@@ -91,25 +88,28 @@ func (timeoutError) Temporary() bool { return true }
|
|
|
// CustomTLSConfig contains parameters to determine the behavior
|
|
|
// of CustomTLSDial.
|
|
|
type CustomTLSConfig struct {
|
|
|
+
|
|
|
// Dial is the network connection dialer. TLS is layered on
|
|
|
// top of a new network connection created with dialer.
|
|
|
Dial Dialer
|
|
|
+
|
|
|
// Timeout is and optional timeout for combined network
|
|
|
// connection dial and TLS handshake.
|
|
|
Timeout time.Duration
|
|
|
+
|
|
|
// FrontingAddr overrides the "addr" input to Dial when specified
|
|
|
FrontingAddr string
|
|
|
- // HttpProxyAddress specifies an HTTP proxy to be used
|
|
|
- // (with HTTP CONNECT).
|
|
|
- HttpProxyAddress string
|
|
|
+
|
|
|
// SendServerName specifies whether to use SNI
|
|
|
// (tlsdialer functionality)
|
|
|
SendServerName bool
|
|
|
+
|
|
|
// VerifyLegacyCertificate is a special case self-signed server
|
|
|
// certificate case. Ignores IP SANs and basic constraints. No
|
|
|
// certificate chain. Just checks that the server presented the
|
|
|
// specified certificate.
|
|
|
VerifyLegacyCertificate *x509.Certificate
|
|
|
+
|
|
|
// TlsConfig is a tls.Config to use in the
|
|
|
// non-verifyLegacyCertificate case.
|
|
|
TlsConfig *tls.Config
|
|
|
@@ -141,9 +141,7 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
|
|
|
}
|
|
|
|
|
|
dialAddr := addr
|
|
|
- if config.HttpProxyAddress != "" {
|
|
|
- dialAddr = config.HttpProxyAddress
|
|
|
- } else if config.FrontingAddr != "" {
|
|
|
+ if config.FrontingAddr != "" {
|
|
|
dialAddr = config.FrontingAddr
|
|
|
}
|
|
|
|
|
|
@@ -152,34 +150,27 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
|
|
|
return nil, ContextError(err)
|
|
|
}
|
|
|
|
|
|
- targetAddr := addr
|
|
|
- if config.FrontingAddr != "" {
|
|
|
- targetAddr = config.FrontingAddr
|
|
|
- }
|
|
|
-
|
|
|
- colonPos := strings.LastIndex(targetAddr, ":")
|
|
|
- if colonPos == -1 {
|
|
|
- colonPos = len(targetAddr)
|
|
|
+ hostname, _, err := net.SplitHostPort(dialAddr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, ContextError(err)
|
|
|
}
|
|
|
- hostname := targetAddr[:colonPos]
|
|
|
|
|
|
tlsConfig := config.TlsConfig
|
|
|
if tlsConfig == nil {
|
|
|
tlsConfig = &tls.Config{}
|
|
|
}
|
|
|
|
|
|
- serverName := tlsConfig.ServerName
|
|
|
+ // Copy config so we can tweak it
|
|
|
+ tlsConfigCopy := new(tls.Config)
|
|
|
+ *tlsConfigCopy = *tlsConfig
|
|
|
|
|
|
+ serverName := tlsConfig.ServerName
|
|
|
// If no ServerName is set, infer the ServerName
|
|
|
// from the hostname we're connecting to.
|
|
|
if serverName == "" {
|
|
|
serverName = hostname
|
|
|
}
|
|
|
|
|
|
- // Copy config so we can tweak it
|
|
|
- tlsConfigCopy := new(tls.Config)
|
|
|
- *tlsConfigCopy = *tlsConfig
|
|
|
-
|
|
|
if config.SendServerName {
|
|
|
// Set the ServerName and rely on the usual logic in
|
|
|
// tls.Conn.Handshake() to do its verification
|
|
|
@@ -192,34 +183,11 @@ func CustomTLSDial(network, addr string, config *CustomTLSConfig) (*tls.Conn, er
|
|
|
|
|
|
conn := tls.Client(rawConn, tlsConfigCopy)
|
|
|
|
|
|
- establishConnection := func(rawConn net.Conn, conn *tls.Conn) error {
|
|
|
- // TODO: use the proxy request/response code from net/http/transport.go
|
|
|
- if config.HttpProxyAddress != "" {
|
|
|
- connectRequest := fmt.Sprintf(
|
|
|
- "CONNECT %s HTTP/1.1\r\nHost: %s\r\nConnection: Keep-Alive\r\n\r\n",
|
|
|
- targetAddr, hostname)
|
|
|
- _, err := rawConn.Write([]byte(connectRequest))
|
|
|
- if err != nil {
|
|
|
- return ContextError(err)
|
|
|
- }
|
|
|
- expectedResponse := []byte("HTTP/1.1 200 OK\r\n\r\n")
|
|
|
- readBuffer := make([]byte, len(expectedResponse))
|
|
|
- _, err = io.ReadFull(rawConn, readBuffer)
|
|
|
- if err != nil {
|
|
|
- return ContextError(err)
|
|
|
- }
|
|
|
- if !bytes.Equal(readBuffer, expectedResponse) {
|
|
|
- return fmt.Errorf("unexpected HTTP proxy response: %s", string(readBuffer))
|
|
|
- }
|
|
|
- }
|
|
|
- return conn.Handshake()
|
|
|
- }
|
|
|
-
|
|
|
if config.Timeout == 0 {
|
|
|
- err = establishConnection(rawConn, conn)
|
|
|
+ err = conn.Handshake()
|
|
|
} else {
|
|
|
go func() {
|
|
|
- errChannel <- establishConnection(rawConn, conn)
|
|
|
+ errChannel <- conn.Handshake()
|
|
|
}()
|
|
|
err = <-errChannel
|
|
|
}
|