Просмотр исходного кода

Merge pull request #169 from rod-hynes/master

Resumable fetch remote server list
Rod Hynes 10 лет назад
Родитель
Сommit
dd05c45f19

+ 52 - 16
psiphon/config.go

@@ -47,12 +47,12 @@ const (
 	TUNNEL_SSH_KEEP_ALIVE_PROBE_INACTIVE_PERIOD    = 10 * time.Second
 	ESTABLISH_TUNNEL_TIMEOUT_SECONDS               = 300
 	ESTABLISH_TUNNEL_WORK_TIME                     = 60 * time.Second
-	ESTABLISH_TUNNEL_PAUSE_PERIOD                  = 5 * time.Second
+	ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS          = 5
 	ESTABLISH_TUNNEL_SERVER_AFFINITY_GRACE_PERIOD  = 1 * time.Second
 	HTTP_PROXY_ORIGIN_SERVER_TIMEOUT_SECONDS       = 15
 	HTTP_PROXY_MAX_IDLE_CONNECTIONS_PER_HOST       = 50
 	FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS       = 30
-	FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD          = 5 * time.Second
+	FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS  = 30
 	FETCH_REMOTE_SERVER_LIST_STALE_PERIOD          = 6 * time.Hour
 	PSIPHON_API_CLIENT_SESSION_ID_LENGTH           = 16
 	PSIPHON_API_SERVER_TIMEOUT_SECONDS             = 20
@@ -67,7 +67,7 @@ const (
 	PSIPHON_API_TUNNEL_STATS_MAX_COUNT             = 1000
 	FETCH_ROUTES_TIMEOUT_SECONDS                   = 60
 	DOWNLOAD_UPGRADE_TIMEOUT                       = 15 * time.Minute
-	DOWNLOAD_UPGRADE_RETRY_PERIOD                  = 5 * time.Second
+	DOWNLOAD_UPGRADE_RETRY_PERIOD_SECONDS          = 30
 	DOWNLOAD_UPGRADE_STALE_PERIOD                  = 6 * time.Hour
 	IMPAIRED_PROTOCOL_CLASSIFICATION_DURATION      = 2 * time.Minute
 	IMPAIRED_PROTOCOL_CLASSIFICATION_THRESHOLD     = 3
@@ -95,11 +95,6 @@ type Config struct {
 	// continue running.
 	DataStoreDirectory string
 
-	// DataStoreTempDirectory is the directory in which to store temporary
-	// work files associated with the persistent database.
-	// This parameter is deprecated and may be removed.
-	DataStoreTempDirectory string
-
 	// PropagationChannelId is a string identifier which indicates how the
 	// Psiphon client was distributed. This parameter is required.
 	// This value is supplied by and depends on the Psiphon Network, and is
@@ -120,6 +115,14 @@ type Config struct {
 	// typically embedded in the client binary.
 	RemoteServerListUrl string
 
+	// RemoteServerListDownloadFilename specifies a target filename for
+	// storing the remote server list download. Data is stored in co-located
+	// files (RemoteServerListDownloadFilename.part*) to allow for resumable
+	// downloading. If not specified, the default is to use the
+	// remote object name as the filename, stored in the current working
+	// directory.
+	RemoteServerListDownloadFilename string
+
 	// RemoteServerListSignaturePublicKey specifies a public key that's
 	// used to authenticate the remote server list payload.
 	// This value is supplied by and depends on the Psiphon Network, and is
@@ -264,6 +267,8 @@ type Config struct {
 
 	// UpgradeDownloadFilename is the local target filename for an upgrade download.
 	// This parameter is required when UpgradeDownloadUrl is specified.
+	// Data is stored in co-located files (UpgradeDownloadFilename.part*) to allow
+	// for resumable downloading.
 	UpgradeDownloadFilename string
 
 	// EmitBytesTransferred indicates whether to emit periodic notices showing
@@ -312,49 +317,65 @@ type Config struct {
 
 	// TunnelConnectTimeoutSeconds specifies a single tunnel connection sequence timeout.
 	// Zero value means that connection process will not time out.
-	// If omitted default value is TUNNEL_CONNECT_TIMEOUT_SECONDS.
+	// If omitted, the default value is TUNNEL_CONNECT_TIMEOUT_SECONDS.
 	TunnelConnectTimeoutSeconds *int
 
 	// TunnelPortForwardTimeoutSeconds specifies a timeout per SSH port forward.
 	// Zero value means a port forward will not time out.
-	// If omitted default value is TUNNEL_PORT_FORWARD_DIAL_TIMEOUT_SECONDS.
+	// If omitted, the default value is TUNNEL_PORT_FORWARD_DIAL_TIMEOUT_SECONDS.
 	TunnelPortForwardTimeoutSeconds *int
 
 	// TunnelSshKeepAliveProbeTimeoutSeconds specifies a timeout value for "probe"
 	// SSH keep-alive that is sent upon port forward failure.
 	// Zero value means keep-alive request will not time out.
-	// If omitted default value is TUNNEL_SSH_KEEP_ALIVE_PROBE_TIMEOUT_SECONDS.
+	// If omitted, the default value is TUNNEL_SSH_KEEP_ALIVE_PROBE_TIMEOUT_SECONDS.
 	TunnelSshKeepAliveProbeTimeoutSeconds *int
 
 	// TunnelSshKeepAlivePeriodicTimeoutSeconds specifies a timeout value for regular
 	// SSH keep-alives that are sent periodically.
 	// Zero value means keep-alive request will not time out.
-	// If omitted default value is TUNNEL_SSH_KEEP_ALIVE_PERIODIC_TIMEOUT_SECONDS.
+	// If omitted, the default value is TUNNEL_SSH_KEEP_ALIVE_PERIODIC_TIMEOUT_SECONDS.
 	TunnelSshKeepAlivePeriodicTimeoutSeconds *int
 
 	// FetchRemoteServerListTimeoutSeconds specifies a timeout value for remote server list
 	// HTTP request. Zero value means that request will not time out.
-	// If omitted default value is FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS.
+	// If omitted, the default value is FETCH_REMOTE_SERVER_LIST_TIMEOUT_SECONDS.
 	FetchRemoteServerListTimeoutSeconds *int
 
 	// PsiphonApiServerTimeoutSeconds specifies a timeout for periodic API HTTP
 	// requests to Psiphon server such as stats, home pages, etc.
 	// Zero value means that request will not time out.
-	// If omitted default value is PSIPHON_API_SERVER_TIMEOUT_SECONDS.
+	// If omitted, the default value is PSIPHON_API_SERVER_TIMEOUT_SECONDS.
 	// Note that this value is overridden for final stats requests during shutdown
 	// process in order to prevent hangs.
 	PsiphonApiServerTimeoutSeconds *int
 
 	// FetchRoutesTimeoutSeconds specifies a timeout value for split tunnel routes
 	// HTTP request. Zero value means that request will not time out.
-	// If omitted default value is FETCH_ROUTES_TIMEOUT_SECONDS.
+	// If omitted, the default value is FETCH_ROUTES_TIMEOUT_SECONDS.
 	FetchRoutesTimeoutSeconds *int
 
 	// HttpProxyOriginServerTimeoutSeconds specifies an HTTP response header timeout
 	// value in various HTTP relays found in httpProxy.
 	// Zero value means that request will not time out.
-	// If omitted default value  HTTP_PROXY_ORIGIN_SERVER_TIMEOUT_SECONDS.
+	// If omitted, the default value is HTTP_PROXY_ORIGIN_SERVER_TIMEOUT_SECONDS.
 	HttpProxyOriginServerTimeoutSeconds *int
+
+	// FetchRemoteServerListRetryPeriodSeconds specifies the delay before
+	// resuming a remote server list download after a failure.
+	// If omitted, the default value FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS.
+	FetchRemoteServerListRetryPeriodSeconds *int
+
+	// DownloadUpgradestRetryPeriodSeconds specifies the delay before
+	// resuming a client upgrade download after a failure.
+	// If omitted, the default value DOWNLOAD_UPGRADE_RETRY_PERIOD_SECONDS.
+	DownloadUpgradeRetryPeriodSeconds *int
+
+	// EstablishTunnelPausePeriodSeconds specifies the delay between attempts
+	// to establish tunnels. Briefly pausing allows for network conditions to improve
+	// and for asynchronous operations such as fetch remote server list to complete.
+	// If omitted, the default value is ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS.
+	EstablishTunnelPausePeriodSeconds *int
 }
 
 // LoadConfig parses and validates a JSON format Psiphon config JSON
@@ -480,5 +501,20 @@ func LoadConfig(configJson []byte) (*Config, error) {
 		config.HttpProxyOriginServerTimeoutSeconds = &defaultHttpProxyOriginServerTimeoutSeconds
 	}
 
+	if config.FetchRemoteServerListRetryPeriodSeconds == nil {
+		defaultFetchRemoteServerListRetryPeriodSeconds := FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD_SECONDS
+		config.FetchRemoteServerListRetryPeriodSeconds = &defaultFetchRemoteServerListRetryPeriodSeconds
+	}
+
+	if config.DownloadUpgradeRetryPeriodSeconds == nil {
+		defaultDownloadUpgradeRetryPeriodSeconds := DOWNLOAD_UPGRADE_RETRY_PERIOD_SECONDS
+		config.DownloadUpgradeRetryPeriodSeconds = &defaultDownloadUpgradeRetryPeriodSeconds
+	}
+
+	if config.EstablishTunnelPausePeriodSeconds == nil {
+		defaultEstablishTunnelPausePeriodSeconds := ESTABLISH_TUNNEL_PAUSE_PERIOD_SECONDS
+		config.EstablishTunnelPausePeriodSeconds = &defaultEstablishTunnelPausePeriodSeconds
+	}
+
 	return &config, nil
 }

+ 22 - 4
psiphon/controller.go

@@ -248,6 +248,15 @@ func (controller *Controller) SignalComponentFailure() {
 func (controller *Controller) remoteServerListFetcher() {
 	defer controller.runWaitGroup.Done()
 
+	if controller.config.RemoteServerListUrl == "" {
+		NoticeAlert("remote server list URL is blank")
+		return
+	}
+	if controller.config.RemoteServerListSignaturePublicKey == "" {
+		NoticeAlert("remote server list signature public key blank")
+		return
+	}
+
 	var lastFetchTime time.Time
 
 fetcherLoop:
@@ -275,8 +284,14 @@ fetcherLoop:
 				break fetcherLoop
 			}
 
+			// Pick any active tunnel and make the next fetch attempt. If there's
+			// no active tunnel, the untunneledDialConfig will be used.
+			tunnel := controller.getNextActiveTunnel()
+
 			err := FetchRemoteServerList(
-				controller.config, controller.untunneledDialConfig)
+				controller.config,
+				tunnel,
+				controller.untunneledDialConfig)
 
 			if err == nil {
 				lastFetchTime = time.Now()
@@ -285,7 +300,8 @@ fetcherLoop:
 
 			NoticeAlert("failed to fetch remote server list: %s", err)
 
-			timeout := time.After(FETCH_REMOTE_SERVER_LIST_RETRY_PERIOD)
+			timeout := time.After(
+				time.Duration(*controller.config.FetchRemoteServerListRetryPeriodSeconds) * time.Second)
 			select {
 			case <-timeout:
 			case <-controller.shutdownBroadcast:
@@ -452,7 +468,8 @@ downloadLoop:
 
 			NoticeAlert("failed to download upgrade: %s", err)
 
-			timeout := time.After(DOWNLOAD_UPGRADE_RETRY_PERIOD)
+			timeout := time.After(
+				time.Duration(*controller.config.DownloadUpgradeRetryPeriodSeconds) * time.Second)
 			select {
 			case <-timeout:
 			case <-controller.shutdownBroadcast:
@@ -1032,7 +1049,8 @@ loop:
 		// network conditions to change. Also allows for fetch remote to complete,
 		// in typical conditions (it isn't strictly necessary to wait for this, there will
 		// be more rounds if required).
-		timeout := time.After(ESTABLISH_TUNNEL_PAUSE_PERIOD)
+		timeout := time.After(
+			time.Duration(*controller.config.EstablishTunnelPausePeriodSeconds) * time.Second)
 		select {
 		case <-timeout:
 			// Retry iterating

BIN
psiphon/controller_test.config.enc


+ 99 - 10
psiphon/controller_test.go

@@ -45,15 +45,34 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
-// Note: untunneled upgrade tests must execute before
-// the other tests to ensure no tunnel is established.
-// We need a way to reset the datastore after it's been
-// initialized in order to to clear out its data entries
-// and be able to arbitrarily order the tests.
+// Test case notes/limitations/dependencies:
+//
+// * Untunneled upgrade tests must execute before
+//   the other tests to ensure no tunnel is established.
+//   We need a way to reset the datastore after it's been
+//   initialized in order to to clear out its data entries
+//   and be able to arbitrarily order the tests.
+//
+// * The resumable download tests using disruptNetwork
+//   depend on the download object being larger than the
+//   disruptorMax limits so that the disruptor will actually
+//   interrupt the first download attempt. Specifically, the
+//   upgrade and remote server list at the URLs specified in
+//   controller_test.config.enc.
+//
+// * The protocol tests assume there is at least one server
+//   supporting each protocol in the server list at the URL
+//   specified in controller_test.config.enc, and that these
+//   servers are not overloaded.
+//
+// * fetchAndVerifyWebsite depends on the target URL being
+//   available and responding.
+//
 
 func TestUntunneledUpgradeDownload(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    true,
 			protocol:                 "",
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: false,
@@ -68,6 +87,7 @@ func TestUntunneledUpgradeDownload(t *testing.T) {
 func TestUntunneledResumableUpgradeDownload(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    true,
 			protocol:                 "",
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: false,
@@ -82,6 +102,7 @@ func TestUntunneledResumableUpgradeDownload(t *testing.T) {
 func TestUntunneledUpgradeClientIsLatestVersion(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    true,
 			protocol:                 "",
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: false,
@@ -93,9 +114,25 @@ func TestUntunneledUpgradeClientIsLatestVersion(t *testing.T) {
 		})
 }
 
+func TestUntunneledResumableFetchRemoveServerList(t *testing.T) {
+	controllerRun(t,
+		&controllerRunConfig{
+			expectNoServerEntries:    true,
+			protocol:                 "",
+			clientIsLatestVersion:    true,
+			disableUntunneledUpgrade: false,
+			disableEstablishing:      false,
+			tunnelPoolSize:           1,
+			disruptNetwork:           true,
+			useHostNameTransformer:   false,
+			runDuration:              0,
+		})
+}
+
 func TestTunneledUpgradeClientIsLatestVersion(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 "",
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -118,6 +155,7 @@ func TestImpairedProtocols(t *testing.T) {
 
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 "",
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -132,6 +170,7 @@ func TestImpairedProtocols(t *testing.T) {
 func TestSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
@@ -146,6 +185,7 @@ func TestSSH(t *testing.T) {
 func TestObfuscatedSSH(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_OBFUSCATED_SSH,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
@@ -160,6 +200,7 @@ func TestObfuscatedSSH(t *testing.T) {
 func TestUnfrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
@@ -174,6 +215,7 @@ func TestUnfrontedMeek(t *testing.T) {
 func TestUnfrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -188,6 +230,7 @@ func TestUnfrontedMeekWithTransformer(t *testing.T) {
 func TestFrontedMeek(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
@@ -202,6 +245,7 @@ func TestFrontedMeek(t *testing.T) {
 func TestFrontedMeekWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -216,6 +260,7 @@ func TestFrontedMeekWithTransformer(t *testing.T) {
 func TestFrontedMeekHTTP(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -230,6 +275,7 @@ func TestFrontedMeekHTTP(t *testing.T) {
 func TestUnfrontedMeekHTTPS(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    false,
 			disableUntunneledUpgrade: true,
@@ -244,6 +290,7 @@ func TestUnfrontedMeekHTTPS(t *testing.T) {
 func TestUnfrontedMeekHTTPSWithTransformer(t *testing.T) {
 	controllerRun(t,
 		&controllerRunConfig{
+			expectNoServerEntries:    false,
 			protocol:                 TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 			clientIsLatestVersion:    true,
 			disableUntunneledUpgrade: true,
@@ -256,6 +303,7 @@ func TestUnfrontedMeekHTTPSWithTransformer(t *testing.T) {
 }
 
 type controllerRunConfig struct {
+	expectNoServerEntries    bool
 	protocol                 string
 	clientIsLatestVersion    bool
 	disableUntunneledUpgrade bool
@@ -312,6 +360,14 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 		t.Fatalf("error initializing datastore: %s", err)
 	}
 
+	serverEntryCount := CountServerEntries("", "")
+
+	if runConfig.expectNoServerEntries && serverEntryCount > 0 {
+		// TODO: replace expectNoServerEntries with resetServerEntries
+		// so tests can run in arbitrary order
+		t.Fatalf("unexpected server entries")
+	}
+
 	controller, err := NewController(config)
 	if err != nil {
 		t.Fatalf("error creating controller: %s", err)
@@ -326,9 +382,11 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 
 	tunnelEstablished := make(chan struct{}, 1)
 	upgradeDownloaded := make(chan struct{}, 1)
+	remoteServerListDownloaded := make(chan struct{}, 1)
 	confirmedLatestVersion := make(chan struct{}, 1)
 
 	var clientUpgradeDownloadedBytesCount int32
+	var remoteServerListDownloadedBytesCount int32
 	var impairedProtocolCount int32
 	var impairedProtocolClassification = struct {
 		sync.RWMutex
@@ -391,6 +449,18 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 				default:
 				}
 
+			case "RemoteServerListDownloadedBytes":
+
+				atomic.AddInt32(&remoteServerListDownloadedBytesCount, 1)
+				t.Logf("RemoteServerListDownloadedBytes: %d", int(payload["bytes"].(float64)))
+
+			case "RemoteServerListDownloaded":
+
+				select {
+				case remoteServerListDownloaded <- *new(struct{}):
+				default:
+				}
+
 			case "ImpairedProtocolClassification":
 
 				classification := payload["classification"].(map[string]interface{})
@@ -459,9 +529,9 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 
 	if !runConfig.disableEstablishing {
 
-		// Test: tunnel must be established within 60 seconds
+		// Test: tunnel must be established within 120 seconds
 
-		establishTimeout := time.NewTimer(60 * time.Second)
+		establishTimeout := time.NewTimer(120 * time.Second)
 
 		select {
 		case <-tunnelEstablished:
@@ -470,6 +540,25 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 			t.Fatalf("tunnel establish timeout exceeded")
 		}
 
+		// Test: if starting with no server entries, a fetch remote
+		// server list must have succeeded. With disruptNetwork, the
+		// fetch must have been resumed at least once.
+
+		if serverEntryCount == 0 {
+			select {
+			case <-remoteServerListDownloaded:
+			default:
+				t.Fatalf("expected remote server list downloaded")
+			}
+
+			if runConfig.disruptNetwork {
+				count := atomic.LoadInt32(&remoteServerListDownloadedBytesCount)
+				if count <= 1 {
+					t.Fatalf("unexpected remote server list download progress: %d", count)
+				}
+			}
+		}
+
 		// Test: fetch website through tunnel
 
 		// Allow for known race condition described in NewHttpProxy():
@@ -508,9 +597,9 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 		}
 	}
 
-	// Test: upgrade check/download must be downloaded within 120 seconds
+	// Test: upgrade check/download must be downloaded within 180 seconds
 
-	upgradeTimeout := time.NewTimer(120 * time.Second)
+	upgradeTimeout := time.NewTimer(180 * time.Second)
 
 	select {
 	case <-upgradeDownloaded:
@@ -653,7 +742,7 @@ func useTunnel(t *testing.T, httpProxyPort int) {
 
 const disruptorProxyAddress = "127.0.0.1:2160"
 const disruptorProxyURL = "socks4a://" + disruptorProxyAddress
-const disruptorMaxConnectionBytes = 2000000
+const disruptorMaxConnectionBytes = 625000
 const disruptorMaxConnectionTime = 10 * time.Second
 
 func initDisruptor() {

+ 190 - 0
psiphon/net.go

@@ -29,6 +29,7 @@ import (
 	"net"
 	"net/http"
 	"net/url"
+	"os"
 	"reflect"
 	"sync"
 	"time"
@@ -378,6 +379,195 @@ func MakeTunneledHttpClient(
 	}, nil
 }
 
+// MakeDownloadHttpClient is a resusable helper that sets up a
+// http.Client for use either untunneled or through a tunnel.
+// See MakeUntunneledHttpsClient for a note about request URL
+// rewritting.
+func MakeDownloadHttpClient(
+	config *Config,
+	tunnel *Tunnel,
+	untunneledDialConfig *DialConfig,
+	requestUrl string,
+	requestTimeout time.Duration) (*http.Client, string, error) {
+
+	var httpClient *http.Client
+	var err error
+
+	if tunnel != nil {
+		httpClient, err = MakeTunneledHttpClient(config, tunnel, requestTimeout)
+		if err != nil {
+			return nil, "", ContextError(err)
+		}
+	} else {
+		httpClient, requestUrl, err = MakeUntunneledHttpsClient(
+			untunneledDialConfig, nil, requestUrl, requestTimeout)
+		if err != nil {
+			return nil, "", ContextError(err)
+		}
+	}
+
+	return httpClient, requestUrl, nil
+}
+
+// ResumeDownload is a resuable helper that downloads requestUrl via the
+// httpClient, storing the result in downloadFilename when the download is
+// complete. Intermediate, partial downloads state is stored in
+// downloadFilename.part and downloadFilename.part.etag.
+// Any existing downloadFilename file will be overwritten.
+//
+// In the case where the remote object has change while a partial download
+// is to be resumed, the partial state is reset and resumeDownload fails.
+// The caller must restart the download.
+//
+// When ifNoneMatchETag is specified, no download is made if the remote
+// object has the same ETag. ifNoneMatchETag has an effect only when no
+// partial download is in progress.
+//
+func ResumeDownload(
+	httpClient *http.Client,
+	requestUrl string,
+	downloadFilename string,
+	ifNoneMatchETag string) (int64, string, error) {
+
+	partialFilename := fmt.Sprintf("%s.part", downloadFilename)
+
+	partialETagFilename := fmt.Sprintf("%s.part.etag", downloadFilename)
+
+	file, err := os.OpenFile(partialFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
+	if err != nil {
+		return 0, "", ContextError(err)
+	}
+	defer file.Close()
+
+	fileInfo, err := file.Stat()
+	if err != nil {
+		return 0, "", ContextError(err)
+	}
+
+	// A partial download should have an ETag which is to be sent with the
+	// Range request to ensure that the source object is the same as the
+	// one that is partially downloaded.
+	var partialETag []byte
+	if fileInfo.Size() > 0 {
+
+		partialETag, err = ioutil.ReadFile(partialETagFilename)
+
+		// When the ETag can't be loaded, delete the partial download. To keep the
+		// code simple, there is no immediate, inline retry here, on the assumption
+		// that the controller's upgradeDownloader will shortly call DownloadUpgrade
+		// again.
+		if err != nil {
+			os.Remove(partialFilename)
+			os.Remove(partialETagFilename)
+			return 0, "", ContextError(
+				fmt.Errorf("failed to load partial download ETag: %s", err))
+		}
+	}
+
+	request, err := http.NewRequest("GET", requestUrl, nil)
+	if err != nil {
+		return 0, "", ContextError(err)
+	}
+
+	request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size()))
+
+	if partialETag != nil {
+
+		// Note: not using If-Range, since not all host servers support it.
+		// Using If-Match means we need to check for status code 412 and reset
+		// when the ETag has changed since the last partial download.
+		request.Header.Add("If-Match", string(partialETag))
+
+	} else if ifNoneMatchETag != "" {
+
+		// Can't specify both If-Match and If-None-Match. Behavior is undefined.
+		// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.26
+		// So for downloaders that store an ETag and wish to use that to prevent
+		// redundant downloads, that ETag is sent as If-None-Match in the case
+		// where a partial download is not in progress. When a partial download
+		// is in progress, the partial ETag is sent as If-Match: either that's
+		// a version that was never fully received, or it's no longer current in
+		// which case the response will be StatusPreconditionFailed, the partial
+		// download will be discarded, and then the next retry will use
+		// If-None-Match.
+
+		// Note: in this case, fileInfo.Size() == 0
+
+		request.Header.Add("If-None-Match", ifNoneMatchETag)
+	}
+
+	response, err := httpClient.Do(request)
+
+	// The resumeable download may ask for bytes past the resource range
+	// since it doesn't store the "completed download" state. In this case,
+	// the HTTP server returns 416. Otherwise, we expect 206. We may also
+	// receive 412 on ETag mismatch.
+	if err == nil &&
+		(response.StatusCode != http.StatusPartialContent &&
+			response.StatusCode != http.StatusRequestedRangeNotSatisfiable &&
+			response.StatusCode != http.StatusPreconditionFailed &&
+			response.StatusCode != http.StatusNotModified) {
+		response.Body.Close()
+		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
+	}
+	if err != nil {
+		return 0, "", ContextError(err)
+	}
+	defer response.Body.Close()
+
+	responseETag := response.Header.Get("ETag")
+
+	if response.StatusCode == http.StatusPreconditionFailed {
+		// When the ETag no longer matches, delete the partial download. As above,
+		// simply failing and relying on the caller's retry schedule.
+		os.Remove(partialFilename)
+		os.Remove(partialETagFilename)
+		return 0, "", ContextError(errors.New("partial download ETag mismatch"))
+
+	} else if response.StatusCode == http.StatusNotModified {
+		// This status code is possible in the "If-None-Match" case. Don't leave
+		// any partial download in progress. Caller should check that responseETag
+		// matches ifNoneMatchETag.
+		os.Remove(partialFilename)
+		os.Remove(partialETagFilename)
+		return 0, responseETag, nil
+	}
+
+	// Not making failure to write ETag file fatal, in case the entire download
+	// succeeds in this one request.
+	ioutil.WriteFile(partialETagFilename, []byte(responseETag), 0600)
+
+	// A partial download occurs when this copy is interrupted. The io.Copy
+	// will fail, leaving a partial download in place (.part and .part.etag).
+	n, err := io.Copy(NewSyncFileWriter(file), response.Body)
+
+	// From this point, n bytes are indicated as downloaded, even if there is
+	// an error; the caller may use this to report partial download progress.
+
+	if err != nil {
+		return n, "", ContextError(err)
+	}
+
+	// Ensure the file is flushed to disk. The deferred close
+	// will be a noop when this succeeds.
+	err = file.Close()
+	if err != nil {
+		return n, "", ContextError(err)
+	}
+
+	// Remove if exists, to enable rename
+	os.Remove(downloadFilename)
+
+	err = os.Rename(partialFilename, downloadFilename)
+	if err != nil {
+		return n, "", ContextError(err)
+	}
+
+	os.Remove(partialETagFilename)
+
+	return n, responseETag, nil
+}
+
 // IPAddressFromAddr is a helper which extracts an IP address
 // from a net.Addr or returns "" if there is no IP address.
 func IPAddressFromAddr(addr net.Addr) string {

+ 12 - 1
psiphon/notice.go

@@ -323,7 +323,18 @@ func NoticeBuildInfo(buildDate, buildRepo, buildRev, goVersion, gomobileVersion
 
 // NoticeExiting indicates that tunnel-core is exiting imminently.
 func NoticeExiting() {
-	outputNotice("Exiting", false, true)
+	outputNotice("Exiting", false, false)
+}
+
+// NoticeRemoteServerListDownloadedBytes reports remote server list download progress.
+func NoticeRemoteServerListDownloadedBytes(bytes int64) {
+	outputNotice("RemoteServerListDownloadedBytes", true, false, "bytes", bytes)
+}
+
+// NoticeRemoteServerListDownloaded indicates that a remote server list download
+// completed successfully.
+func NoticeRemoteServerListDownloaded(filename string) {
+	outputNotice("RemoteServerListDownloaded", false, false, "filename", filename)
 }
 
 type repetitiveNoticeState struct {

+ 51 - 31
psiphon/remoteServerList.go

@@ -20,10 +20,10 @@
 package psiphon
 
 import (
-	"errors"
-	"fmt"
+	"compress/zlib"
 	"io/ioutil"
-	"net/http"
+	"os"
+	"strings"
 	"time"
 )
 
@@ -31,58 +31,76 @@ import (
 // config.RemoteServerListUrl; validates its digital signature using the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // data field into ServerEntry records.
-func FetchRemoteServerList(config *Config, dialConfig *DialConfig) (err error) {
+func FetchRemoteServerList(
+	config *Config,
+	tunnel *Tunnel,
+	untunneledDialConfig *DialConfig) error {
+
 	NoticeInfo("fetching remote server list")
 
-	if config.RemoteServerListUrl == "" {
-		return ContextError(errors.New("remote server list URL is blank"))
-	}
-	if config.RemoteServerListSignaturePublicKey == "" {
-		return ContextError(errors.New("remote server list signature public key blank"))
-	}
+	// Select tunneled or untunneled configuration
 
-	httpClient, requestUrl, err := MakeUntunneledHttpsClient(
-		dialConfig, nil, config.RemoteServerListUrl, time.Duration(*config.FetchRemoteServerListTimeoutSeconds)*time.Second)
+	httpClient, requestUrl, err := MakeDownloadHttpClient(
+		config,
+		tunnel,
+		untunneledDialConfig,
+		config.RemoteServerListUrl,
+		time.Duration(*config.FetchRemoteServerListTimeoutSeconds)*time.Second)
 	if err != nil {
 		return ContextError(err)
 	}
 
-	request, err := http.NewRequest("GET", requestUrl, nil)
+	// Proceed with download
+
+	downloadFilename := config.RemoteServerListDownloadFilename
+	if downloadFilename == "" {
+		splitPath := strings.Split(config.RemoteServerListUrl, "/")
+		downloadFilename = splitPath[len(splitPath)-1]
+	}
+
+	lastETag, err := GetUrlETag(config.RemoteServerListUrl)
 	if err != nil {
 		return ContextError(err)
 	}
 
-	etag, err := GetUrlETag(config.RemoteServerListUrl)
+	n, responseETag, err := ResumeDownload(
+		httpClient, requestUrl, downloadFilename, lastETag)
+
+	NoticeRemoteServerListDownloadedBytes(n)
+
 	if err != nil {
 		return ContextError(err)
 	}
-	if etag != "" {
-		request.Header.Add("If-None-Match", etag)
+
+	if responseETag == lastETag {
+		// The remote server list is unchanged and no data was downloaded
+		return nil
 	}
 
-	response, err := httpClient.Do(request)
+	NoticeRemoteServerListDownloaded(downloadFilename)
 
-	if err == nil &&
-		(response.StatusCode != http.StatusOK && response.StatusCode != http.StatusNotModified) {
-		response.Body.Close()
-		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
-	}
+	// The downloaded content is a zlib compressed authenticated
+	// data package containing a list of encoded server entries.
+
+	downloadContent, err := os.Open(downloadFilename)
 	if err != nil {
 		return ContextError(err)
 	}
-	defer response.Body.Close()
+	defer downloadContent.Close()
 
-	if response.StatusCode == http.StatusNotModified {
-		return nil
+	zlibReader, err := zlib.NewReader(downloadContent)
+	if err != nil {
+		return ContextError(err)
 	}
 
-	body, err := ioutil.ReadAll(response.Body)
+	dataPackage, err := ioutil.ReadAll(zlibReader)
+	zlibReader.Close()
 	if err != nil {
 		return ContextError(err)
 	}
 
 	remoteServerList, err := ReadAuthenticatedDataPackage(
-		body, config.RemoteServerListSignaturePublicKey)
+		dataPackage, config.RemoteServerListSignaturePublicKey)
 	if err != nil {
 		return ContextError(err)
 	}
@@ -100,11 +118,13 @@ func FetchRemoteServerList(config *Config, dialConfig *DialConfig) (err error) {
 		return ContextError(err)
 	}
 
-	etag = response.Header.Get("ETag")
-	if etag != "" {
-		err := SetUrlETag(config.RemoteServerListUrl, etag)
+	// Now that the server entries are successfully imported, store the response
+	// ETag so we won't re-download this same data again.
+
+	if responseETag != "" {
+		err := SetUrlETag(config.RemoteServerListUrl, responseETag)
 		if err != nil {
-			NoticeAlert("failed to set remote server list etag: %s", ContextError(err))
+			NoticeAlert("failed to set remote server list ETag: %s", ContextError(err))
 			// This fetch is still reported as a success, even if we can't store the etag
 		}
 	}

+ 29 - 8
psiphon/server/config.go

@@ -73,18 +73,29 @@ type Config struct {
 	// panic, fatal, error, warn, info, debug
 	LogLevel string
 
-	// SyslogAddress specifies the UDP address of a syslog
-	// service. When set, syslog is used for message logging.
-	SyslogAddress string
-
 	// SyslogFacility specifies the syslog facility to log to.
+	// When set, the local syslog service is used for message
+	// logging.
 	// Valid values include: "user", "local0", "local1", etc.
 	SyslogFacility string
 
 	// SyslogTag specifies an optional tag for syslog log
-	// messages. The default tag is "psiphon-server".
+	// messages. The default tag is "psiphon-server". The
+	// fail2ban logs, if enabled, also use this tag.
 	SyslogTag string
 
+	// Fail2BanFormat is a string format specifier for the
+	// log message format to use for fail2ban integration for
+	// blocking abusive clients by source IP address.
+	// When set, logs with this format are made to the AUTH
+	// facility with INFO severity in the local syslog server
+	// if clients fail to authenticate.
+	// The client's IP address is included with the log message.
+	// An example format specifier, which should be compatible
+	// with default SSH fail2ban configuration, is
+	// "Authentication failure for psiphon-client from %s".
+	Fail2BanFormat string
+
 	// DiscoveryValueHMACKey is the network-wide secret value
 	// used to determine a unique discovery strategy.
 	DiscoveryValueHMACKey string
@@ -200,12 +211,18 @@ func (config *Config) RunObfuscatedSSHServer() bool {
 	return config.ObfuscatedSSHServerPort > 0
 }
 
-// RunObfuscatedSSHServer indicates whether to store per-session GeoIP information in
+// UseRedis indicates whether to store per-session GeoIP information in
 // redis. This is for integration with the legacy psi_web component.
 func (config *Config) UseRedis() bool {
 	return config.RedisServerAddress != ""
 }
 
+// UseFail2Ban indicates whether to log client IP addresses, in authentication
+// failure cases, to the local syslog service AUTH facility for use by fail2ban.
+func (config *Config) UseFail2Ban() bool {
+	return config.Fail2BanFormat != ""
+}
+
 // GetTrafficRules looks up the traffic rules for the specified country. If there
 // are no RegionalTrafficRules for the country, DefaultTrafficRules are returned.
 func (config *Config) GetTrafficRules(targetCountryCode string) TrafficRules {
@@ -237,8 +254,12 @@ func LoadConfig(configJSONs [][]byte) (*Config, error) {
 		}
 	}
 
+	if config.Fail2BanFormat != "" && strings.Count(config.Fail2BanFormat, "%s") != 1 {
+		return nil, errors.New("Fail2BanFormat must have one '%%s' placeholder")
+	}
+
 	if config.ServerIPAddress == "" {
-		return nil, errors.New("server IP address is missing from config file")
+		return nil, errors.New("ServerIPAddress is missing from config file")
 	}
 
 	if config.WebServerPort > 0 && (config.WebServerSecret == "" || config.WebServerCertificate == "" ||
@@ -374,9 +395,9 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 	config := &Config{
 		LogLevel:                DEFAULT_LOG_LEVEL,
-		SyslogAddress:           "",
 		SyslogFacility:          "",
 		SyslogTag:               DEFAULT_SYSLOG_TAG,
+		Fail2BanFormat:          "",
 		DiscoveryValueHMACKey:   "",
 		GeoIPDatabaseFilename:   DEFAULT_GEO_IP_DATABASE_FILENAME,
 		ServerIPAddress:         serverIPaddress,

+ 27 - 6
psiphon/server/log.go

@@ -20,6 +20,7 @@
 package server
 
 import (
+	"fmt"
 	"io"
 	"log/syslog"
 	"os"
@@ -69,10 +70,14 @@ func NewLogWriter() *io.PipeWriter {
 }
 
 var log *ContextLogger
+var fail2BanFormat string
+var fail2BanWriter *syslog.Writer
 
 // InitLogging configures a logger according to the specified
 // config params. If not called, the default logger set by the
 // package init() is used.
+// When configured, InitLogging also establishes a local syslog
+// logger specifically for fail2ban integration.
 // Concurrenty note: should only be called from the main
 // goroutine.
 func InitLogging(config *Config) error {
@@ -86,14 +91,10 @@ func InitLogging(config *Config) error {
 
 	var syslogHook *logrus_syslog.SyslogHook
 
-	if config.SyslogAddress != "" {
+	if config.SyslogFacility != "" {
 
 		syslogHook, err = logrus_syslog.NewSyslogHook(
-			"udp",
-			config.SyslogAddress,
-			getSyslogPriority(config),
-			config.SyslogTag)
-
+			"", "", getSyslogPriority(config), config.SyslogTag)
 		if err != nil {
 			return psiphon.ContextError(err)
 		}
@@ -110,9 +111,29 @@ func InitLogging(config *Config) error {
 		},
 	}
 
+	if config.Fail2BanFormat != "" {
+		fail2BanFormat = config.Fail2BanFormat
+		fail2BanWriter, err = syslog.Dial(
+			"", "", syslog.LOG_AUTH|syslog.LOG_INFO, config.SyslogTag)
+		if err != nil {
+			return psiphon.ContextError(err)
+		}
+	}
+
 	return nil
 }
 
+// LogFail2Ban logs a message to the local syslog service AUTH
+// facility with INFO severity using the format specified by
+// config.Fail2BanFormat and the given client IP address. This
+// is for integration with fail2ban for blocking abusive
+// clients by source IP address. When set, the tag in
+// config.SyslogTag is used.
+func LogFail2Ban(clientIPAddress string) {
+	fail2BanWriter.Info(
+		fmt.Sprintf(fail2BanFormat, clientIPAddress))
+}
+
 // getSyslogPriority determines golang's syslog "priority" value
 // based on the provided config.
 func getSyslogPriority(config *Config) syslog.Priority {

+ 6 - 0
psiphon/server/sshService.go

@@ -544,6 +544,12 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 
 func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
 	if err != nil {
+		if sshClient.sshServer.config.UseFail2Ban() {
+			clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
+			if clientIPAddress != "" {
+				LogFail2Ban(clientIPAddress)
+			}
+		}
 		log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
 	} else {
 		log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")

+ 1 - 2
psiphon/splitTunnel.go

@@ -295,8 +295,7 @@ func (classifier *SplitTunnelClassifier) getRoutes(tunnel *Tunnel) (routesData [
 	}
 
 	if !useCachedRoutes {
-		bytesReader := bytes.NewReader(compressedRoutesData)
-		zlibReader, err := zlib.NewReader(bytesReader)
+		zlibReader, err := zlib.NewReader(bytes.NewReader(compressedRoutesData))
 		if err == nil {
 			routesData, err = ioutil.ReadAll(zlibReader)
 			zlibReader.Close()

+ 15 - 113
psiphon/upgradeDownload.go

@@ -20,10 +20,7 @@
 package psiphon
 
 import (
-	"errors"
 	"fmt"
-	"io"
-	"io/ioutil"
 	"net/http"
 	"os"
 	"strconv"
@@ -67,24 +64,14 @@ func DownloadUpgrade(
 		return nil
 	}
 
-	requestUrl := config.UpgradeDownloadUrl
-	var httpClient *http.Client
-	var err error
-
 	// Select tunneled or untunneled configuration
 
-	if tunnel != nil {
-		httpClient, err = MakeTunneledHttpClient(config, tunnel, DOWNLOAD_UPGRADE_TIMEOUT)
-		if err != nil {
-			return ContextError(err)
-		}
-	} else {
-		httpClient, requestUrl, err = MakeUntunneledHttpsClient(
-			untunneledDialConfig, nil, requestUrl, DOWNLOAD_UPGRADE_TIMEOUT)
-		if err != nil {
-			return ContextError(err)
-		}
-	}
+	httpClient, requestUrl, err := MakeDownloadHttpClient(
+		config,
+		tunnel,
+		untunneledDialConfig,
+		config.UpgradeDownloadUrl,
+		DOWNLOAD_UPGRADE_TIMEOUT)
 
 	// If no handshake version is supplied, make an initial HEAD request
 	// to get the current version from the version header.
@@ -112,7 +99,7 @@ func DownloadUpgrade(
 
 		// Note: if the header is missing, Header.Get returns "" and then
 		// strconv.Atoi returns a parse error.
-		availableClientVersion := response.Header.Get(config.UpgradeDownloadClientVersionHeader)
+		availableClientVersion = response.Header.Get(config.UpgradeDownloadClientVersionHeader)
 		checkAvailableClientVersion, err := strconv.Atoi(availableClientVersion)
 		if err != nil {
 			// If the header is missing or malformed, we can't determine the available
@@ -133,92 +120,16 @@ func DownloadUpgrade(
 		}
 	}
 
-	// Proceed with full download
+	// Proceed with download
 
-	partialFilename := fmt.Sprintf(
-		"%s.%s.part", config.UpgradeDownloadFilename, availableClientVersion)
+	// An intermediate filename is used since the presence of
+	// config.UpgradeDownloadFilename indicates a completed download.
 
-	partialETagFilename := fmt.Sprintf(
-		"%s.%s.part.etag", config.UpgradeDownloadFilename, availableClientVersion)
+	downloadFilename := fmt.Sprintf(
+		"%s.%s", config.UpgradeDownloadFilename, availableClientVersion)
 
-	file, err := os.OpenFile(partialFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
-	if err != nil {
-		return ContextError(err)
-	}
-	defer file.Close()
-
-	fileInfo, err := file.Stat()
-	if err != nil {
-		return ContextError(err)
-	}
-
-	// A partial download should have an ETag which is to be sent with the
-	// Range request to ensure that the source object is the same as the
-	// one that is partially downloaded.
-	var partialETag []byte
-	if fileInfo.Size() > 0 {
-
-		partialETag, err = ioutil.ReadFile(partialETagFilename)
-
-		// When the ETag can't be loaded, delete the partial download. To keep the
-		// code simple, there is no immediate, inline retry here, on the assumption
-		// that the controller's upgradeDownloader will shortly call DownloadUpgrade
-		// again.
-		if err != nil {
-			os.Remove(partialFilename)
-			os.Remove(partialETagFilename)
-			return ContextError(
-				fmt.Errorf("failed to load partial download ETag: %s", err))
-		}
-
-	}
-
-	request, err := http.NewRequest("GET", requestUrl, nil)
-	if err != nil {
-		return ContextError(err)
-	}
-	request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size()))
-
-	// Note: not using If-Range, since not all remote server list host servers
-	// support it. Using If-Match means we need to check for status code 412
-	// and reset when the ETag has changed since the last partial download.
-	if partialETag != nil {
-		request.Header.Add("If-Match", string(partialETag))
-	}
-
-	response, err := httpClient.Do(request)
-
-	// The resumeable download may ask for bytes past the resource range
-	// since it doesn't store the "completed download" state. In this case,
-	// the HTTP server returns 416. Otherwise, we expect 206. We may also
-	// receive 412 on ETag mismatch.
-	if err == nil &&
-		(response.StatusCode != http.StatusPartialContent &&
-			response.StatusCode != http.StatusRequestedRangeNotSatisfiable &&
-			response.StatusCode != http.StatusPreconditionFailed) {
-		response.Body.Close()
-		err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
-	}
-	if err != nil {
-		return ContextError(err)
-	}
-	defer response.Body.Close()
-
-	if response.StatusCode == http.StatusPreconditionFailed {
-		// When the ETag no longer matches, delete the partial download. As above,
-		// simply failing and relying on the controller's upgradeDownloader retry.
-		os.Remove(partialFilename)
-		os.Remove(partialETagFilename)
-		return ContextError(errors.New("partial download ETag mismatch"))
-	}
-
-	// Not making failure to write ETag file fatal, in case the entire download
-	// succeeds in this one request.
-	ioutil.WriteFile(partialETagFilename, []byte(response.Header.Get("ETag")), 0600)
-
-	// A partial download occurs when this copy is interrupted. The io.Copy
-	// will fail, leaving a partial download in place (.part and .part.etag).
-	n, err := io.Copy(NewSyncFileWriter(file), response.Body)
+	n, _, err := ResumeDownload(
+		httpClient, requestUrl, downloadFilename, "")
 
 	NoticeClientUpgradeDownloadedBytes(n)
 
@@ -226,20 +137,11 @@ func DownloadUpgrade(
 		return ContextError(err)
 	}
 
-	// Ensure the file is flushed to disk. The deferred close
-	// will be a noop when this succeeds.
-	err = file.Close()
+	err = os.Rename(downloadFilename, config.UpgradeDownloadFilename)
 	if err != nil {
 		return ContextError(err)
 	}
 
-	err = os.Rename(partialFilename, config.UpgradeDownloadFilename)
-	if err != nil {
-		return ContextError(err)
-	}
-
-	os.Remove(partialETagFilename)
-
 	NoticeClientUpgradeDownloaded(config.UpgradeDownloadFilename)
 
 	return nil