net.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910
  1. /*
  2. * Copyright (c) 2015, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. // for HTTPSServer.ServeTLS:
  20. /*
  21. Copyright (c) 2012 The Go Authors. All rights reserved.
  22. Redistribution and use in source and binary forms, with or without
  23. modification, are permitted provided that the following conditions are
  24. met:
  25. * Redistributions of source code must retain the above copyright
  26. notice, this list of conditions and the following disclaimer.
  27. * Redistributions in binary form must reproduce the above
  28. copyright notice, this list of conditions and the following disclaimer
  29. in the documentation and/or other materials provided with the
  30. distribution.
  31. * Neither the name of Google Inc. nor the names of its
  32. contributors may be used to endorse or promote products derived from
  33. this software without specific prior written permission.
  34. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  35. "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  36. LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  37. A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  38. OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  39. SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  40. LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  41. DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  42. THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  43. (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  44. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  45. */
  46. package psiphon
  47. import (
  48. "container/list"
  49. "crypto/tls"
  50. "crypto/x509"
  51. "errors"
  52. "fmt"
  53. "io"
  54. "io/ioutil"
  55. "net"
  56. "net/http"
  57. "net/url"
  58. "os"
  59. "reflect"
  60. "sync"
  61. "sync/atomic"
  62. "time"
  63. "github.com/Psiphon-Inc/dns"
  64. "github.com/Psiphon-Inc/ratelimit"
  65. )
  // DNS_PORT is the well-known DNS server port. It is an untyped constant,
// so it can be used directly as an int, uint16, etc.
const (
	DNS_PORT = 53
)
  67. // DialConfig contains parameters to determine the behavior
  68. // of a Psiphon dialer (TCPDial, MeekDial, etc.)
  69. type DialConfig struct {
  70. // UpstreamProxyUrl specifies a proxy to connect through.
  71. // E.g., "http://proxyhost:8080"
  72. // "socks5://user:password@proxyhost:1080"
  73. // "socks4a://proxyhost:1080"
  74. // "http://NTDOMAIN\NTUser:password@proxyhost:3375"
  75. //
  76. // Certain tunnel protocols require HTTP CONNECT support
  77. // when a HTTP proxy is specified. If CONNECT is not
  78. // supported, those protocols will not connect.
  79. UpstreamProxyUrl string
  80. ConnectTimeout time.Duration
  81. // PendingConns is used to track and interrupt dials in progress.
  82. // Dials may be interrupted using PendingConns.CloseAll(). Once instantiated,
  83. // a conn is added to pendingConns before the network connect begins and
  84. // removed from pendingConns once the connect succeeds or fails.
  85. // May be nil.
  86. PendingConns *Conns
  87. // BindToDevice parameters are used to exclude connections and
  88. // associated DNS requests from VPN routing.
  89. // When DeviceBinder is set, any underlying socket is
  90. // submitted to the device binding servicebefore connecting.
  91. // The service should bind the socket to a device so that it doesn't route
  92. // through a VPN interface. This service is also used to bind UDP sockets used
  93. // for DNS requests, in which case DnsServerGetter is used to get the
  94. // current active untunneled network DNS server.
  95. DeviceBinder DeviceBinder
  96. DnsServerGetter DnsServerGetter
  97. // UseIndistinguishableTLS specifies whether to try to use an
  98. // alternative stack for TLS. From a circumvention perspective,
  99. // Go's TLS has a distinct fingerprint that may be used for blocking.
  100. // Only applies to TLS connections.
  101. UseIndistinguishableTLS bool
  102. // TrustedCACertificatesFilename specifies a file containing trusted
  103. // CA certs. The file contents should be compatible with OpenSSL's
  104. // SSL_CTX_load_verify_locations.
  105. // Only applies to UseIndistinguishableTLS connections.
  106. TrustedCACertificatesFilename string
  107. // DeviceRegion is the reported region the host device is running in.
  108. // When set, this value may be used, pre-connection, to select performance
  109. // or circumvention optimization strategies for the given region.
  110. DeviceRegion string
  111. // ResolvedIPCallback, when set, is called with the IP address that was
  112. // dialed. This is either the specified IP address in the dial address,
  113. // or the resolved IP address in the case where the dial address is a
  114. // domain name.
  115. // The callback may be invoked by a concurrent goroutine.
  116. ResolvedIPCallback func(string)
  117. }
  // NetworkConnectivityChecker is the interface to the external provider
// of HasNetworkConnectivity. Callers in this file treat a return value of
// 1 as "connected".
type NetworkConnectivityChecker interface {
	// TODO: change to bool return value once gobind supports that type
	HasNetworkConnectivity() int
}
  // DeviceBinder is the interface to the external BindToDevice provider,
// which is handed a socket file descriptor to bind.
type DeviceBinder interface {
	BindToDevice(fileDescriptor int) error
}
  // DnsServerGetter is the interface to the external GetDnsServer provider,
// which reports the primary and secondary DNS servers.
type DnsServerGetter interface {
	GetPrimaryDnsServer() string
	GetSecondaryDnsServer() string
}
  // HostNameTransformer is the interface for pluggable hostname
// transformation circumvention strategies. TransformHostName returns the
// hostname to use and a flag reporting whether it was transformed.
type HostNameTransformer interface {
	TransformHostName(hostname string) (string, bool)
}
  // IdentityHostNameTransformer is the default HostNameTransformer. It never
// transforms: the hostname is returned as-is along with false.
type IdentityHostNameTransformer struct{}

// TransformHostName returns hostname unchanged and false.
func (IdentityHostNameTransformer) TransformHostName(hostname string) (string, bool) {
	const transformed = false
	return hostname, transformed
}
  // TimeoutError is an error value that reports itself as both a timeout
// and a temporary condition, satisfying net.Error.
type TimeoutError struct{}

// Compile-time check that TimeoutError satisfies net.Error.
var _ net.Error = TimeoutError{}

// Error returns the fixed message "timed out".
func (TimeoutError) Error() string { return "timed out" }

// Timeout always reports true.
func (TimeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (TimeoutError) Temporary() bool { return true }
  // Dialer is a custom dial function. Its signature, (network, address),
// matches http.Transport.Dial, so a Dialer can be assigned to it directly.
type Dialer func(string, string) (net.Conn, error)
  // Conns is a synchronized set of net.Conns. It is used to coordinate
// interrupting a group of goroutines that are establishing connections,
// or to close a group of open connections. After CloseAll, further Add
// calls are rejected until Reset is called. The zero value is ready to use.
type Conns struct {
	mutex    sync.Mutex
	isClosed bool
	conns    map[net.Conn]bool
}

// Reset reopens the set and forgets any tracked connections. It does not
// close them.
func (conns *Conns) Reset() {
	conns.mutex.Lock()
	defer conns.mutex.Unlock()
	conns.conns = make(map[net.Conn]bool)
	conns.isClosed = false
}

// Add tracks conn and reports whether it was accepted. Once the set is
// closed, Add returns false without tracking conn.
func (conns *Conns) Add(conn net.Conn) bool {
	conns.mutex.Lock()
	defer conns.mutex.Unlock()
	accepted := !conns.isClosed
	if accepted {
		if conns.conns == nil {
			conns.conns = make(map[net.Conn]bool)
		}
		conns.conns[conn] = true
	}
	return accepted
}

// Remove stops tracking conn without closing it. Removing an untracked
// conn is a no-op.
func (conns *Conns) Remove(conn net.Conn) {
	conns.mutex.Lock()
	defer conns.mutex.Unlock()
	delete(conns.conns, conn)
}

// CloseAll marks the set closed, calls Close on every tracked connection
// while holding the set's lock, and then forgets them.
func (conns *Conns) CloseAll() {
	conns.mutex.Lock()
	defer conns.mutex.Unlock()
	conns.isClosed = true
	for conn := range conns.conns {
		conn.Close()
	}
	conns.conns = make(map[net.Conn]bool)
}
  // LRUConns is a concurrency-safe list of net.Conns ordered by recent
// activity. Its purpose is to facilitate closing the least recently
// active connection in a set of connections.
//
// Add returns an LRUConnsEntry. The entry is used to Touch() an active
// connection, which promotes it to the front of the order, and to
// Remove() a connection that is no longer an eviction candidate.
//
// CloseOldest() removes the oldest connection from the list and calls
// net.Conn.Close() on it. It is a no-op on an empty list.
//
// After an entry has been removed, its Touch() and Remove() have no effect.
type LRUConns struct {
	mutex sync.Mutex
	list  *list.List
}

// NewLRUConns initializes a new, empty LRUConns.
func NewLRUConns() *LRUConns {
	return &LRUConns{list: list.New()}
}

// Add inserts conn as the freshest connection in the LRUConns. It returns
// an LRUConnsEntry used to freshen the connection or to remove it from
// the LRU list.
func (conns *LRUConns) Add(conn net.Conn) *LRUConnsEntry {
	conns.mutex.Lock()
	defer conns.mutex.Unlock()
	return &LRUConnsEntry{
		lruConns: conns,
		element:  conns.list.PushFront(conn),
	}
}

// CloseOldest removes the oldest connection from the LRUConns and calls
// net.Conn.Close() on it. It does nothing when the list is empty.
func (conns *LRUConns) CloseOldest() {
	conns.mutex.Lock()
	var oldestConn net.Conn
	// Back() returns nil for an empty list, so it must be checked before
	// its Value is read.
	if oldest := conns.list.Back(); oldest != nil {
		oldestConn, _ = conns.list.Remove(oldest).(net.Conn)
	}
	// Close is called after the lock is released. This keeps a slow Close
	// from stalling other LRUConns operations, and a Close that calls back
	// into this LRUConns cannot deadlock.
	conns.mutex.Unlock()
	if oldestConn != nil {
		oldestConn.Close()
	}
}

// LRUConnsEntry is an entry in a LRUConns list.
type LRUConnsEntry struct {
	lruConns *LRUConns
	element  *list.Element
}

// Remove deletes the connection referenced by the LRUConnsEntry from the
// associated LRUConns. It has no effect if the entry was not initialized
// or was previously removed.
func (entry *LRUConnsEntry) Remove() {
	if entry.lruConns == nil || entry.element == nil {
		return
	}
	entry.lruConns.mutex.Lock()
	defer entry.lruConns.mutex.Unlock()
	// list.Remove is a no-op for an element that is no longer in the list.
	entry.lruConns.list.Remove(entry.element)
}

// Touch promotes the connection referenced by the LRUConnsEntry to the
// front of the associated LRUConns. It has no effect if the entry was not
// initialized or was previously removed.
func (entry *LRUConnsEntry) Touch() {
	if entry.lruConns == nil || entry.element == nil {
		return
	}
	entry.lruConns.mutex.Lock()
	defer entry.lruConns.mutex.Unlock()
	// MoveToFront is a no-op for an element that is no longer in the list.
	entry.lruConns.list.MoveToFront(entry.element)
}
  272. // LocalProxyRelay sends to remoteConn bytes received from localConn,
  273. // and sends to localConn bytes received from remoteConn.
  274. func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
  275. copyWaitGroup := new(sync.WaitGroup)
  276. copyWaitGroup.Add(1)
  277. go func() {
  278. defer copyWaitGroup.Done()
  279. _, err := io.Copy(localConn, remoteConn)
  280. if err != nil {
  281. err = fmt.Errorf("Relay failed: %s", ContextError(err))
  282. NoticeLocalProxyError(proxyType, err)
  283. }
  284. }()
  285. _, err := io.Copy(remoteConn, localConn)
  286. if err != nil {
  287. err = fmt.Errorf("Relay failed: %s", ContextError(err))
  288. NoticeLocalProxyError(proxyType, err)
  289. }
  290. copyWaitGroup.Wait()
  291. }
  292. // WaitForNetworkConnectivity uses a NetworkConnectivityChecker to
  293. // periodically check for network connectivity. It returns true if
  294. // no NetworkConnectivityChecker is provided (waiting is disabled)
  295. // or when NetworkConnectivityChecker.HasNetworkConnectivity()
  296. // indicates connectivity. It waits and polls the checker once a second.
  297. // If any stop is broadcast, false is returned immediately.
  298. func WaitForNetworkConnectivity(
  299. connectivityChecker NetworkConnectivityChecker, stopBroadcasts ...<-chan struct{}) bool {
  300. if connectivityChecker == nil || 1 == connectivityChecker.HasNetworkConnectivity() {
  301. return true
  302. }
  303. NoticeInfo("waiting for network connectivity")
  304. ticker := time.NewTicker(1 * time.Second)
  305. for {
  306. if 1 == connectivityChecker.HasNetworkConnectivity() {
  307. return true
  308. }
  309. selectCases := make([]reflect.SelectCase, 1+len(stopBroadcasts))
  310. selectCases[0] = reflect.SelectCase{
  311. Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ticker.C)}
  312. for i, stopBroadcast := range stopBroadcasts {
  313. selectCases[i+1] = reflect.SelectCase{
  314. Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stopBroadcast)}
  315. }
  316. chosen, _, ok := reflect.Select(selectCases)
  317. if chosen == 0 && ok {
  318. // Ticker case, so check again
  319. } else {
  320. // Stop case
  321. return false
  322. }
  323. }
  324. }
  325. // ResolveIP uses a custom dns stack to make a DNS query over the
  326. // given TCP or UDP conn. This is used, e.g., when we need to ensure
  327. // that a DNS connection bypasses a VPN interface (BindToDevice) or
  328. // when we need to ensure that a DNS connection is tunneled.
  329. // Caller must set timeouts or interruptibility as required for conn.
  330. func ResolveIP(host string, conn net.Conn) (addrs []net.IP, ttls []time.Duration, err error) {
  331. // Send the DNS query
  332. dnsConn := &dns.Conn{Conn: conn}
  333. defer dnsConn.Close()
  334. query := new(dns.Msg)
  335. query.SetQuestion(dns.Fqdn(host), dns.TypeA)
  336. query.RecursionDesired = true
  337. dnsConn.WriteMsg(query)
  338. // Process the response
  339. response, err := dnsConn.ReadMsg()
  340. if err != nil {
  341. return nil, nil, ContextError(err)
  342. }
  343. addrs = make([]net.IP, 0)
  344. ttls = make([]time.Duration, 0)
  345. for _, answer := range response.Answer {
  346. if a, ok := answer.(*dns.A); ok {
  347. addrs = append(addrs, a.A)
  348. ttl := time.Duration(a.Hdr.Ttl) * time.Second
  349. ttls = append(ttls, ttl)
  350. }
  351. }
  352. return addrs, ttls, nil
  353. }
  354. // MakeUntunneledHttpsClient returns a net/http.Client which is
  355. // configured to use custom dialing features -- including BindToDevice,
  356. // UseIndistinguishableTLS, etc. -- for a specific HTTPS request URL.
  357. // If verifyLegacyCertificate is not nil, it's used for certificate
  358. // verification.
  359. // Because UseIndistinguishableTLS requires a hack to work with
  360. // net/http, MakeUntunneledHttpClient may return a modified request URL
  361. // to be used. Callers should always use this return value to make
  362. // requests, not the input value.
  363. func MakeUntunneledHttpsClient(
  364. dialConfig *DialConfig,
  365. verifyLegacyCertificate *x509.Certificate,
  366. requestUrl string,
  367. requestTimeout time.Duration) (*http.Client, string, error) {
  368. // Change the scheme to "http"; otherwise http.Transport will try to do
  369. // another TLS handshake inside the explicit TLS session. Also need to
  370. // force an explicit port, as the default for "http", 80, won't talk TLS.
  371. urlComponents, err := url.Parse(requestUrl)
  372. if err != nil {
  373. return nil, "", ContextError(err)
  374. }
  375. urlComponents.Scheme = "http"
  376. host, port, err := net.SplitHostPort(urlComponents.Host)
  377. if err != nil {
  378. // Assume there's no port
  379. host = urlComponents.Host
  380. port = ""
  381. }
  382. if port == "" {
  383. port = "443"
  384. }
  385. urlComponents.Host = net.JoinHostPort(host, port)
  386. // Note: IndistinguishableTLS mode doesn't support VerifyLegacyCertificate
  387. useIndistinguishableTLS := dialConfig.UseIndistinguishableTLS && verifyLegacyCertificate == nil
  388. dialer := NewCustomTLSDialer(
  389. // Note: when verifyLegacyCertificate is not nil, some
  390. // of the other CustomTLSConfig is overridden.
  391. &CustomTLSConfig{
  392. Dial: NewTCPDialer(dialConfig),
  393. VerifyLegacyCertificate: verifyLegacyCertificate,
  394. SNIServerName: host,
  395. SkipVerify: false,
  396. UseIndistinguishableTLS: useIndistinguishableTLS,
  397. TrustedCACertificatesFilename: dialConfig.TrustedCACertificatesFilename,
  398. })
  399. transport := &http.Transport{
  400. Dial: dialer,
  401. }
  402. httpClient := &http.Client{
  403. Timeout: requestTimeout,
  404. Transport: transport,
  405. }
  406. return httpClient, urlComponents.String(), nil
  407. }
  408. // MakeTunneledHttpClient returns a net/http.Client which is
  409. // configured to use custom dialing features including tunneled
  410. // dialing and, optionally, UseTrustedCACertificatesForStockTLS.
  411. // Unlike MakeUntunneledHttpsClient and makePsiphonHttpsClient,
  412. // This http.Client uses stock TLS and no scheme transformation
  413. // hack is required.
  414. func MakeTunneledHttpClient(
  415. config *Config,
  416. tunnel *Tunnel,
  417. requestTimeout time.Duration) (*http.Client, error) {
  418. tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
  419. return tunnel.sshClient.Dial("tcp", addr)
  420. }
  421. transport := &http.Transport{
  422. Dial: tunneledDialer,
  423. ResponseHeaderTimeout: requestTimeout,
  424. }
  425. if config.UseTrustedCACertificatesForStockTLS {
  426. if config.TrustedCACertificatesFilename == "" {
  427. return nil, ContextError(errors.New(
  428. "UseTrustedCACertificatesForStockTLS requires TrustedCACertificatesFilename"))
  429. }
  430. rootCAs := x509.NewCertPool()
  431. certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename)
  432. if err != nil {
  433. return nil, ContextError(err)
  434. }
  435. rootCAs.AppendCertsFromPEM(certData)
  436. transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs}
  437. }
  438. return &http.Client{
  439. Transport: transport,
  440. Timeout: requestTimeout,
  441. }, nil
  442. }
  443. // MakeDownloadHttpClient is a resusable helper that sets up a
  444. // http.Client for use either untunneled or through a tunnel.
  445. // See MakeUntunneledHttpsClient for a note about request URL
  446. // rewritting.
  447. func MakeDownloadHttpClient(
  448. config *Config,
  449. tunnel *Tunnel,
  450. untunneledDialConfig *DialConfig,
  451. requestUrl string,
  452. requestTimeout time.Duration) (*http.Client, string, error) {
  453. var httpClient *http.Client
  454. var err error
  455. if tunnel != nil {
  456. httpClient, err = MakeTunneledHttpClient(config, tunnel, requestTimeout)
  457. if err != nil {
  458. return nil, "", ContextError(err)
  459. }
  460. } else {
  461. httpClient, requestUrl, err = MakeUntunneledHttpsClient(
  462. untunneledDialConfig, nil, requestUrl, requestTimeout)
  463. if err != nil {
  464. return nil, "", ContextError(err)
  465. }
  466. }
  467. return httpClient, requestUrl, nil
  468. }
  469. // ResumeDownload is a resuable helper that downloads requestUrl via the
  470. // httpClient, storing the result in downloadFilename when the download is
  471. // complete. Intermediate, partial downloads state is stored in
  472. // downloadFilename.part and downloadFilename.part.etag.
  473. // Any existing downloadFilename file will be overwritten.
  474. //
  475. // In the case where the remote object has change while a partial download
  476. // is to be resumed, the partial state is reset and resumeDownload fails.
  477. // The caller must restart the download.
  478. //
  479. // When ifNoneMatchETag is specified, no download is made if the remote
  480. // object has the same ETag. ifNoneMatchETag has an effect only when no
  481. // partial download is in progress.
  482. //
  483. func ResumeDownload(
  484. httpClient *http.Client,
  485. requestUrl string,
  486. downloadFilename string,
  487. ifNoneMatchETag string) (int64, string, error) {
  488. partialFilename := fmt.Sprintf("%s.part", downloadFilename)
  489. partialETagFilename := fmt.Sprintf("%s.part.etag", downloadFilename)
  490. file, err := os.OpenFile(partialFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
  491. if err != nil {
  492. return 0, "", ContextError(err)
  493. }
  494. defer file.Close()
  495. fileInfo, err := file.Stat()
  496. if err != nil {
  497. return 0, "", ContextError(err)
  498. }
  499. // A partial download should have an ETag which is to be sent with the
  500. // Range request to ensure that the source object is the same as the
  501. // one that is partially downloaded.
  502. var partialETag []byte
  503. if fileInfo.Size() > 0 {
  504. partialETag, err = ioutil.ReadFile(partialETagFilename)
  505. // When the ETag can't be loaded, delete the partial download. To keep the
  506. // code simple, there is no immediate, inline retry here, on the assumption
  507. // that the controller's upgradeDownloader will shortly call DownloadUpgrade
  508. // again.
  509. if err != nil {
  510. os.Remove(partialFilename)
  511. os.Remove(partialETagFilename)
  512. return 0, "", ContextError(
  513. fmt.Errorf("failed to load partial download ETag: %s", err))
  514. }
  515. }
  516. request, err := http.NewRequest("GET", requestUrl, nil)
  517. if err != nil {
  518. return 0, "", ContextError(err)
  519. }
  520. request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size()))
  521. if partialETag != nil {
  522. // Note: not using If-Range, since not all host servers support it.
  523. // Using If-Match means we need to check for status code 412 and reset
  524. // when the ETag has changed since the last partial download.
  525. request.Header.Add("If-Match", string(partialETag))
  526. } else if ifNoneMatchETag != "" {
  527. // Can't specify both If-Match and If-None-Match. Behavior is undefined.
  528. // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.26
  529. // So for downloaders that store an ETag and wish to use that to prevent
  530. // redundant downloads, that ETag is sent as If-None-Match in the case
  531. // where a partial download is not in progress. When a partial download
  532. // is in progress, the partial ETag is sent as If-Match: either that's
  533. // a version that was never fully received, or it's no longer current in
  534. // which case the response will be StatusPreconditionFailed, the partial
  535. // download will be discarded, and then the next retry will use
  536. // If-None-Match.
  537. // Note: in this case, fileInfo.Size() == 0
  538. request.Header.Add("If-None-Match", ifNoneMatchETag)
  539. }
  540. response, err := httpClient.Do(request)
  541. // The resumeable download may ask for bytes past the resource range
  542. // since it doesn't store the "completed download" state. In this case,
  543. // the HTTP server returns 416. Otherwise, we expect 206. We may also
  544. // receive 412 on ETag mismatch.
  545. if err == nil &&
  546. (response.StatusCode != http.StatusPartialContent &&
  547. response.StatusCode != http.StatusRequestedRangeNotSatisfiable &&
  548. response.StatusCode != http.StatusPreconditionFailed &&
  549. response.StatusCode != http.StatusNotModified) {
  550. response.Body.Close()
  551. err = fmt.Errorf("unexpected response status code: %d", response.StatusCode)
  552. }
  553. if err != nil {
  554. return 0, "", ContextError(err)
  555. }
  556. defer response.Body.Close()
  557. responseETag := response.Header.Get("ETag")
  558. if response.StatusCode == http.StatusPreconditionFailed {
  559. // When the ETag no longer matches, delete the partial download. As above,
  560. // simply failing and relying on the caller's retry schedule.
  561. os.Remove(partialFilename)
  562. os.Remove(partialETagFilename)
  563. return 0, "", ContextError(errors.New("partial download ETag mismatch"))
  564. } else if response.StatusCode == http.StatusNotModified {
  565. // This status code is possible in the "If-None-Match" case. Don't leave
  566. // any partial download in progress. Caller should check that responseETag
  567. // matches ifNoneMatchETag.
  568. os.Remove(partialFilename)
  569. os.Remove(partialETagFilename)
  570. return 0, responseETag, nil
  571. }
  572. // Not making failure to write ETag file fatal, in case the entire download
  573. // succeeds in this one request.
  574. ioutil.WriteFile(partialETagFilename, []byte(responseETag), 0600)
  575. // A partial download occurs when this copy is interrupted. The io.Copy
  576. // will fail, leaving a partial download in place (.part and .part.etag).
  577. n, err := io.Copy(NewSyncFileWriter(file), response.Body)
  578. // From this point, n bytes are indicated as downloaded, even if there is
  579. // an error; the caller may use this to report partial download progress.
  580. if err != nil {
  581. return n, "", ContextError(err)
  582. }
  583. // Ensure the file is flushed to disk. The deferred close
  584. // will be a noop when this succeeds.
  585. err = file.Close()
  586. if err != nil {
  587. return n, "", ContextError(err)
  588. }
  589. // Remove if exists, to enable rename
  590. os.Remove(downloadFilename)
  591. err = os.Rename(partialFilename, downloadFilename)
  592. if err != nil {
  593. return n, "", ContextError(err)
  594. }
  595. os.Remove(partialETagFilename)
  596. return n, responseETag, nil
  597. }
  // IPAddressFromAddr returns the host part of addr's "host:port" string
// form, or "" when addr is nil or its string form has no port.
func IPAddressFromAddr(addr net.Addr) string {
	if addr == nil {
		return ""
	}
	host, _, err := net.SplitHostPort(addr.String())
	if err != nil {
		return ""
	}
	return host
}
  // HTTPSServer is a wrapper around http.Server which adds the
// ServeTLS function.
type HTTPSServer struct {
	http.Server
}

// ServeTLS offers the TLS equivalent of http.Server.Serve, using
// server.TLSConfig. The http package has both ListenAndServe and
// ListenAndServeTLS higher-level interfaces, but only Serve (not TLS)
// offers a lower-level interface that lets the caller keep a reference to
// the Listener, which allows for external shutdown. ListenAndServeTLS
// also requires the TLS cert and key to be in files, which we avoid here.
//
// When listener is a *net.TCPListener, accepted connections get TCP
// keep-alives, as in http.ListenAndServeTLS. Any other net.Listener is
// used as-is.
func (server *HTTPSServer) ServeTLS(listener net.Listener) error {
	// A checked assertion: a non-TCP listener (e.g. a wrapped or in-memory
	// listener) must not panic here.
	if tcpListener, ok := listener.(*net.TCPListener); ok {
		listener = tcpKeepAliveListener{tcpListener}
	}
	tlsListener := tls.NewListener(listener, server.TLSConfig)
	return server.Serve(tlsListener)
}

// tcpKeepAliveListener enables TCP keep-alive (3 minute period) on each
// accepted connection. It is used by http.ListenAndServeTLS but not
// exported, so this is a copy from https://golang.org/src/net/http/server.go.
type tcpKeepAliveListener struct {
	*net.TCPListener
}

func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
	tc, err := ln.AcceptTCP()
	if err != nil {
		return
	}
	tc.SetKeepAlive(true)
	tc.SetKeepAlivePeriod(3 * time.Minute)
	return tc, nil
}
  639. // ActivityMonitoredConn wraps a net.Conn, adding logic to deal with
  640. // events triggered by I/O activity.
  641. //
  642. // When an inactivity timeout is specified, the net.Conn Read() will
  643. // timeout after the specified period of read inactivity. Optionally,
  644. // ActivityMonitoredConn will also consider the connection active when
  645. // data is written to it.
  646. //
  647. // When a LRUConnsEntry is specified, then the LRU entry is promoted on
  648. // either a successful read or write.
  649. //
  650. type ActivityMonitoredConn struct {
  651. net.Conn
  652. inactivityTimeout time.Duration
  653. activeOnWrite bool
  654. startTime int64
  655. lastActivityTime int64
  656. lruEntry *LRUConnsEntry
  657. }
  658. func NewActivityMonitoredConn(
  659. conn net.Conn,
  660. inactivityTimeout time.Duration,
  661. activeOnWrite bool,
  662. lruEntry *LRUConnsEntry) *ActivityMonitoredConn {
  663. if inactivityTimeout > 0 {
  664. conn.SetReadDeadline(time.Now().Add(inactivityTimeout))
  665. }
  666. now := time.Now().UnixNano()
  667. return &ActivityMonitoredConn{
  668. Conn: conn,
  669. inactivityTimeout: inactivityTimeout,
  670. activeOnWrite: activeOnWrite,
  671. startTime: now,
  672. lastActivityTime: now,
  673. lruEntry: lruEntry,
  674. }
  675. }
  676. // GetStartTime gets the time when the ActivityMonitoredConn was
  677. // initialized.
  678. func (conn *ActivityMonitoredConn) GetStartTime() time.Time {
  679. return time.Unix(0, conn.startTime)
  680. }
  681. // GetActiveDuration returns the time elapsed between the initialization
  682. // of the ActivityMonitoredConn and the last Read (or Write when
  683. // activeOnWrite is specified).
  684. func (conn *ActivityMonitoredConn) GetActiveDuration() time.Duration {
  685. return time.Duration(atomic.LoadInt64(&conn.lastActivityTime) - conn.startTime)
  686. }
  687. func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
  688. n, err := conn.Conn.Read(buffer)
  689. if err == nil {
  690. atomic.StoreInt64(&conn.lastActivityTime, time.Now().UnixNano())
  691. if conn.inactivityTimeout > 0 {
  692. conn.Conn.SetReadDeadline(time.Now().Add(conn.inactivityTimeout))
  693. }
  694. if conn.lruEntry != nil {
  695. conn.lruEntry.Touch()
  696. }
  697. }
  698. return n, err
  699. }
  700. func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
  701. n, err := conn.Conn.Write(buffer)
  702. if err == nil {
  703. if conn.activeOnWrite {
  704. atomic.StoreInt64(&conn.lastActivityTime, time.Now().UnixNano())
  705. if conn.inactivityTimeout > 0 {
  706. conn.Conn.SetReadDeadline(time.Now().Add(conn.inactivityTimeout))
  707. }
  708. }
  709. if conn.lruEntry != nil {
  710. conn.lruEntry.Touch()
  711. }
  712. }
  713. return n, err
  714. }
  // ThrottledConn wraps a net.Conn with read and write rate limiters.
  // Rates are specified as bytes per second. Optional unlimited byte
  // counts allow a number of bytes to be read or written before rate
  // limiting is applied. Specify limit values of 0 to set no rate limit
  // (unlimited counts are ignored in this case).
  // The underlying rate limiter uses the token bucket algorithm to
  // calculate delay times for read and write operations.
  type ThrottledConn struct {
  	// The int64 counters are updated with sync/atomic, so they come first.
  	// On 386, ARM and 32-bit MIPS, only the first word of an allocated
  	// struct is guaranteed to be 64-bit aligned, and a 64-bit atomic
  	// operation on an unaligned word crashes there.
  	// NOTE(review): that guarantee is lost if this struct is ever embedded
  	// in another struct at a non-8-byte-aligned offset.
  	unlimitedReadBytes  int64 // remaining budget before reads are limited
  	unlimitedWriteBytes int64 // remaining budget before writes are limited
  	limitingReads       int32 // set to 1 once the read budget is exhausted
  	limitingWrites      int32 // set to 1 once the write budget is exhausted
  	net.Conn
  	limitedReader io.Reader
  	limitedWriter io.Writer
  }
  731. // NewThrottledConn initializes a new ThrottledConn.
  732. func NewThrottledConn(
  733. conn net.Conn,
  734. unlimitedReadBytes, limitReadBytesPerSecond,
  735. unlimitedWriteBytes, limitWriteBytesPerSecond int64) *ThrottledConn {
  736. // When no limit is specified, the rate limited reader/writer
  737. // is simply the base reader/writer.
  738. var reader io.Reader
  739. if limitReadBytesPerSecond == 0 {
  740. reader = conn
  741. } else {
  742. reader = ratelimit.Reader(conn,
  743. ratelimit.NewBucketWithRate(
  744. float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
  745. }
  746. var writer io.Writer
  747. if limitWriteBytesPerSecond == 0 {
  748. writer = conn
  749. } else {
  750. writer = ratelimit.Writer(conn,
  751. ratelimit.NewBucketWithRate(
  752. float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
  753. }
  754. return &ThrottledConn{
  755. Conn: conn,
  756. unlimitedReadBytes: unlimitedReadBytes,
  757. limitingReads: 0,
  758. limitedReader: reader,
  759. unlimitedWriteBytes: unlimitedWriteBytes,
  760. limitingWrites: 0,
  761. limitedWriter: writer,
  762. }
  763. }
  764. func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
  765. // Use the base reader until the unlimited count is exhausted.
  766. if atomic.LoadInt32(&conn.limitingReads) == 0 {
  767. if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
  768. atomic.StoreInt32(&conn.limitingReads, 1)
  769. } else {
  770. return conn.Read(buffer)
  771. }
  772. }
  773. return conn.limitedReader.Read(buffer)
  774. }
  775. func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
  776. // Use the base writer until the unlimited count is exhausted.
  777. if atomic.LoadInt32(&conn.limitingWrites) == 0 {
  778. if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
  779. atomic.StoreInt32(&conn.limitingWrites, 1)
  780. } else {
  781. return conn.Write(buffer)
  782. }
  783. }
  784. return conn.limitedWriter.Write(buffer)
  785. }