|
|
@@ -28,6 +28,7 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "os"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
@@ -51,18 +52,19 @@ const (
|
|
|
|
|
|
// RelayConfig specifies the configuration for a Relay.
|
|
|
//
|
|
|
-// The CACertificates and HostCertificate parameters are used for mutually
|
|
|
+// The CACertificates and HostCertificate/Key parameters are used for mutually
|
|
|
// authenticated TLS between the Relay and the DSL backend. The HostID value
|
|
|
// is sent to the DSL backend for logging, and should be populated with the
|
|
|
// HostID in psiphond.config.
|
|
|
type RelayConfig struct {
|
|
|
Logger common.Logger
|
|
|
|
|
|
- CACertificates *x509.CertPool
|
|
|
+ CACertificatesFilename string
|
|
|
+ HostCertificateFilename string
|
|
|
+ HostKeyFilename string
|
|
|
|
|
|
- HostCertificate *tls.Certificate
|
|
|
-
|
|
|
- DynamicServerListServiceURL string
|
|
|
+ GetServiceAddress func(
|
|
|
+ clientGeoIPData common.GeoIPData) (string, error)
|
|
|
|
|
|
HostID string
|
|
|
|
|
|
@@ -85,10 +87,14 @@ type RelayConfig struct {
|
|
|
// GetServerEntriesRequest requests may be fully or partially served out of
|
|
|
// the local cache.
|
|
|
type Relay struct {
|
|
|
- config *RelayConfig
|
|
|
- tlsConfig *tls.Config
|
|
|
+ config *RelayConfig
|
|
|
+
|
|
|
+ caCertificatesFile common.ReloadableFile
|
|
|
+ hostCertificateFile common.ReloadableFile
|
|
|
+ hostKeyFile common.ReloadableFile
|
|
|
|
|
|
mutex sync.Mutex
|
|
|
+ tlsConfig *tls.Config
|
|
|
httpClient *http.Client
|
|
|
requestTimeout time.Duration
|
|
|
requestRetryCount int
|
|
|
@@ -98,16 +104,19 @@ type Relay struct {
|
|
|
}
|
|
|
|
|
|
// NewRelay creates a new Relay.
|
|
|
-func NewRelay(config *RelayConfig) *Relay {
|
|
|
+func NewRelay(config *RelayConfig) (*Relay, error) {
|
|
|
+
|
|
|
+ relay := &Relay{
|
|
|
+ config: config,
|
|
|
|
|
|
- tlsConfig := &tls.Config{
|
|
|
- RootCAs: config.CACertificates,
|
|
|
- Certificates: []tls.Certificate{*config.HostCertificate},
|
|
|
+ caCertificatesFile: common.NewReloadableFile(config.CACertificatesFilename, false, nil),
|
|
|
+ hostCertificateFile: common.NewReloadableFile(config.HostCertificateFilename, false, nil),
|
|
|
+ hostKeyFile: common.NewReloadableFile(config.HostKeyFilename, false, nil),
|
|
|
}
|
|
|
|
|
|
- relay := &Relay{
|
|
|
- config: config,
|
|
|
- tlsConfig: tlsConfig,
|
|
|
+ _, err := relay.Reload()
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
relay.SetRequestParameters(
|
|
|
@@ -121,7 +130,98 @@ func NewRelay(config *RelayConfig) *Relay {
|
|
|
defaultServerEntryCacheTTL,
|
|
|
defaultServerEntryCacheMaxSize)
|
|
|
|
|
|
- return relay
|
|
|
+ return relay, nil
|
|
|
+}
|
|
|
+
|
|
|
+// Reload reloads the TLS configuration when the file contents have changed.
|
|
|
+//
|
|
|
+// Reload implements the common.Reloader interface.
|
|
|
+func (r *Relay) Reload() (bool, error) {
|
|
|
+
|
|
|
+ // The common.ReloadableFile.reloadAction callback not used; instead,
|
|
|
+ // ReloadableFiles are used to check for changed file contents. When any
|
|
|
+ // file has changed, all TLS configuration files are reloaded and the TLS
|
|
|
+ // configuration is reinitialized.
|
|
|
+
|
|
|
+ reloadedAny := false
|
|
|
+
|
|
|
+ reloaded, err := r.caCertificatesFile.Reload()
|
|
|
+ if err != nil {
|
|
|
+ return false, errors.Trace(err)
|
|
|
+ }
|
|
|
+ reloadedAny = reloadedAny || reloaded
|
|
|
+
|
|
|
+ reloaded, err = r.hostCertificateFile.Reload()
|
|
|
+ if err != nil {
|
|
|
+ return false, errors.Trace(err)
|
|
|
+ }
|
|
|
+ reloadedAny = reloadedAny || reloaded
|
|
|
+
|
|
|
+ reloaded, err = r.hostKeyFile.Reload()
|
|
|
+ if err != nil {
|
|
|
+ return false, errors.Trace(err)
|
|
|
+ }
|
|
|
+ reloadedAny = reloadedAny || reloaded
|
|
|
+
|
|
|
+ if !reloadedAny {
|
|
|
+ return false, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ caCertsPEM, err := os.ReadFile(r.config.CACertificatesFilename)
|
|
|
+ if err != nil {
|
|
|
+ return false, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ caCertificates := x509.NewCertPool()
|
|
|
+ if !caCertificates.AppendCertsFromPEM(caCertsPEM) {
|
|
|
+ return false, errors.TraceNew("AppendCertsFromPEM failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ hostCertificate, err := tls.LoadX509KeyPair(
|
|
|
+ r.config.HostCertificateFilename,
|
|
|
+ r.config.HostKeyFilename)
|
|
|
+ if err != nil {
|
|
|
+ return false, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ r.mutex.Lock()
|
|
|
+ defer r.mutex.Unlock()
|
|
|
+
|
|
|
+ r.tlsConfig = &tls.Config{
|
|
|
+ RootCAs: caCertificates,
|
|
|
+ Certificates: []tls.Certificate{hostCertificate},
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.httpClient != nil {
|
|
|
+
|
|
|
+ // Replace the http.Client if it exists. See the comment in
|
|
|
+ // SetRequestParameters regarding in-flight requests and idle timeout
|
|
|
+ // limitations.
|
|
|
+
|
|
|
+ httpTransport := r.httpClient.Transport.(*http.Transport)
|
|
|
+
|
|
|
+ r.httpClient = &http.Client{
|
|
|
+ Transport: &http.Transport{
|
|
|
+ TLSClientConfig: r.tlsConfig,
|
|
|
+ MaxConnsPerHost: httpTransport.MaxConnsPerHost,
|
|
|
+ MaxIdleConns: httpTransport.MaxIdleConns,
|
|
|
+ MaxIdleConnsPerHost: httpTransport.MaxIdleConnsPerHost,
|
|
|
+ IdleConnTimeout: httpTransport.IdleConnTimeout,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return true, nil
|
|
|
+}
|
|
|
+
|
|
|
+// WillReload implements the common.Reloader interface.
|
|
|
+func (r *Relay) WillReload() bool {
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+// ReloadLogDescription implements the common.Reloader interface.
|
|
|
+func (r *Relay) ReloadLogDescription() string {
|
|
|
+ return "DSL Relay TLS configuration"
|
|
|
}
|
|
|
|
|
|
// SetRequestParameters updates the HTTP request parameters used for upstream
|
|
|
@@ -144,7 +244,7 @@ func (r *Relay) SetRequestParameters(
|
|
|
// continue until complete and eventually the previous http.Client will
|
|
|
// be garbage collected.
|
|
|
//
|
|
|
- // TODO: don't retain the previous http.Client as long as
|
|
|
+ // TODO: don't retain the previous http.Client for as long as
|
|
|
// http.Transport.IdleConnTimeout.
|
|
|
|
|
|
var httpTransport *http.Transport
|
|
|
@@ -166,7 +266,6 @@ func (r *Relay) SetRequestParameters(
|
|
|
IdleConnTimeout: httpIdleConnTimeout,
|
|
|
},
|
|
|
}
|
|
|
-
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -323,7 +422,12 @@ func (r *Relay) HandleRequest(
|
|
|
defer requestCancelFunc()
|
|
|
}
|
|
|
|
|
|
- url := fmt.Sprintf("https://%s%s", r.config.DynamicServerListServiceURL, path)
|
|
|
+ serviceAddress, err := r.config.GetServiceAddress(clientGeoIPData)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ url := fmt.Sprintf("https://%s%s", serviceAddress, path)
|
|
|
|
|
|
httpRequest, err := http.NewRequestWithContext(
|
|
|
requestCtx, "POST", url, bytes.NewBuffer(relayedRequest.Request))
|