Преглед на файлове

Use FrontingSpecs for tunneled downloads

mirokuratczyk преди 2 години
родител
ревизия
e39606e1d8
променени са 2 файла, в които са добавени 73 реда и са изтрити 29 реда
  1. 1 0
      psiphon/feedback.go
  2. 72 29
      psiphon/net.go

+ 1 - 0
psiphon/feedback.go

@@ -166,6 +166,7 @@ func SendFeedback(ctx context.Context, config *Config, diagnostics, uploadPath s
 			// redefines ResolveIP such that the corresponding fronting
 			// provider ID is passed into UntunneledResolveIP to enable the use
 			// of pre-resolved IPs.
+			// TODO: do not use pre-resolved IPs when tunneled.
 			IPs, err := UntunneledResolveIP(
 				ctx, config, resolver, hostname, "")
 			if err != nil {

+ 72 - 29
psiphon/net.go

@@ -391,7 +391,7 @@ func UntunneledResolveIP(
 	return IPs, nil
 }
 
-// makeUntunneledFrontedHTTPClient returns a net/http.Client which is
+// makeFrontedHTTPClient returns a net/http.Client which is
 // configured to use domain fronting and custom dialing features -- including
 // BindToDevice, etc. One or more fronting specs must be provided, i.e.
 // len(frontingSpecs) must be greater than 0. A function is returned which,
@@ -400,10 +400,11 @@ func UntunneledResolveIP(
 //
 // The context is applied to underlying TCP dials. The caller is responsible
 // for applying the context to requests made with the returned http.Client.
-func makeUntunneledFrontedHTTPClient(
+func makeFrontedHTTPClient(
 	ctx context.Context,
 	config *Config,
-	untunneledDialConfig *DialConfig,
+	tunneled bool,
+	dialConfig *DialConfig,
 	frontingSpecs parameters.FrontingSpecs,
 	selectedFrontingProviderID func(string),
 	skipVerify bool,
@@ -499,26 +500,31 @@ func makeUntunneledFrontedHTTPClient(
 	var resolvedIPAddress atomic.Value
 	resolvedIPAddress.Store("")
 
-	// The default untunneled dial config does not support pre-resolved IPs so
-	// redefine the dial config to override ResolveIP with an implementation
-	// that enables their use by passing the fronting provider ID into
-	// UntunneledResolveIP.
-	meekDialConfig := &DialConfig{
-		UpstreamProxyURL: untunneledDialConfig.UpstreamProxyURL,
-		CustomHeaders:    untunneledDialConfig.CustomHeaders,
-		DeviceBinder:     untunneledDialConfig.DeviceBinder,
-		IPv6Synthesizer:  untunneledDialConfig.IPv6Synthesizer,
-		ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) {
-			IPs, err := UntunneledResolveIP(
-				ctx, config, config.GetResolver(), hostname, frontingProviderID)
-			if err != nil {
-				return nil, errors.Trace(err)
-			}
-			return IPs, nil
-		},
-		ResolvedIPCallback: func(IPAddress string) {
-			resolvedIPAddress.Store(IPAddress)
-		},
+	var meekDialConfig *DialConfig
+	if tunneled {
+		meekDialConfig = dialConfig
+	} else {
+		// The default untunneled dial config does not support pre-resolved IPs so
+		// redefine the dial config to override ResolveIP with an implementation
+		// that enables their use by passing the fronting provider ID into
+		// UntunneledResolveIP.
+		meekDialConfig = &DialConfig{
+			UpstreamProxyURL: dialConfig.UpstreamProxyURL,
+			CustomHeaders:    dialConfig.CustomHeaders,
+			DeviceBinder:     dialConfig.DeviceBinder,
+			IPv6Synthesizer:  dialConfig.IPv6Synthesizer,
+			ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) {
+				IPs, err := UntunneledResolveIP(
+					ctx, config, config.GetResolver(), hostname, frontingProviderID)
+				if err != nil {
+					return nil, errors.Trace(err)
+				}
+				return IPs, nil
+			},
+			ResolvedIPCallback: func(IPAddress string) {
+				resolvedIPAddress.Store(IPAddress)
+			},
+		}
 	}
 
 	selectedUserAgent, userAgent := selectUserAgentIfUnset(p, meekDialConfig.CustomHeaders)
@@ -654,9 +660,10 @@ func MakeUntunneledHTTPClient(
 
 		// Ignore skipVerify because it only applies when there are no
 		// fronting specs.
-		httpClient, getParams, err := makeUntunneledFrontedHTTPClient(
+		httpClient, getParams, err := makeFrontedHTTPClient(
 			ctx,
 			config,
+			false,
 			untunneledDialConfig,
 			frontingSpecs,
 			selectedFrontingProviderID,
@@ -704,9 +711,13 @@ func MakeUntunneledHTTPClient(
 // dialing and, optionally, UseTrustedCACertificatesForStockTLS.
 // This http.Client uses stock TLS for HTTPS.
 func MakeTunneledHTTPClient(
+	ctx context.Context,
 	config *Config,
 	tunnel *Tunnel,
-	skipVerify bool) (*http.Client, error) {
+	skipVerify bool,
+	disableSystemRootCAs bool,
+	frontingSpecs parameters.FrontingSpecs,
+	selectedFrontingProviderID func(string)) (*http.Client, func() common.APIParameters, error) {
 
 	// Note: there is no dial context since SSH port forward dials cannot
 	// be interrupted directly. Closing the tunnel will interrupt the dials.
@@ -718,6 +729,32 @@ func MakeTunneledHTTPClient(
 		return conn, errors.Trace(err)
 	}
 
+	if len(frontingSpecs) > 0 {
+
+		dialConfig := &DialConfig{
+			TrustedCACertificatesFilename: config.TrustedCACertificatesFilename,
+			CustomDialer: func(_ context.Context, _, addr string) (net.Conn, error) {
+				return tunneledDialer("", addr)
+			},
+		}
+
+		// Ignore skipVerify because it only applies when there are no
+		// fronting specs.
+		httpClient, getParams, err := makeFrontedHTTPClient(
+			ctx,
+			config,
+			true,
+			dialConfig,
+			frontingSpecs,
+			selectedFrontingProviderID,
+			false,
+			disableSystemRootCAs)
+		if err != nil {
+			return nil, nil, errors.Trace(err)
+		}
+		return httpClient, getParams, nil
+	}
+
 	transport := &http.Transport{
 		Dial: tunneledDialer,
 	}
@@ -731,7 +768,7 @@ func MakeTunneledHTTPClient(
 		rootCAs := x509.NewCertPool()
 		certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, nil, errors.Trace(err)
 		}
 		rootCAs.AppendCertsFromPEM(certData)
 		transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs}
@@ -739,7 +776,7 @@ func MakeTunneledHTTPClient(
 
 	return &http.Client{
 		Transport: transport,
-	}, nil
+	}, nil, nil
 }
 
 // MakeDownloadHTTPClient is a helper that sets up a http.Client for use either
@@ -766,8 +803,14 @@ func MakeDownloadHTTPClient(
 
 	if tunneled {
 
-		httpClient, err = MakeTunneledHTTPClient(
-			config, tunnel, skipVerify || disableSystemRootCAs)
+		httpClient, getParams, err = MakeTunneledHTTPClient(
+			ctx,
+			config,
+			tunnel,
+			skipVerify || disableSystemRootCAs,
+			disableSystemRootCAs,
+			frontingSpecs,
+			selectedFrontingProviderID)
 		if err != nil {
 			return nil, false, nil, errors.Trace(err)
 		}