Browse Source

Attempt to report actual test coverage
* Get coverage for subpackages and report to coveralls
* Copy testing of upstreamproxy package into package itself
* Move server-only net utils to server package

Rod Hynes 9 years ago
parent
commit
6c6f86b7ae

+ 6 - 1
.travis.yml

@@ -10,10 +10,15 @@ install:
 script:
 - go test -v ./...
 - cd psiphon
-- go test -v -covermode=count -coverprofile=coverage.out
+- go test -v -covermode=count -coverprofile=transferstats.coverprofile ./transferstats
+- go test -v -covermode=count -coverprofile=upstreamproxy.coverprofile ./upstreamproxy
+- go test -v -covermode=count -coverprofile=server.coverprofile ./server
+- go test -v -covermode=count -coverprofile=psiphon.coverprofile
+- $HOME/gopath/bin/gover
 - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN
 before_install:
 - go get github.com/axw/gocov/gocov
+- go get github.com/modocache/gover
 - go get github.com/mattn/goveralls
 - if ! go get github.com/golang/tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
 - openssl aes-256-cbc -K $encrypted_bf83b4ab4874_key -iv $encrypted_bf83b4ab4874_iv

+ 2 - 0
psiphon/controller_test.go

@@ -944,6 +944,7 @@ func initUpstreamProxy() {
 		proxy.OnRequest().DoFunc(
 			func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
 				if !hasExpectedCustomHeaders(r.Header) {
+					ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
 					return nil, goproxy.NewResponse(r, goproxy.ContentTypeText, http.StatusUnauthorized, "")
 				}
 				return r, nil
@@ -952,6 +953,7 @@ func initUpstreamProxy() {
 		proxy.OnRequest().HandleConnectFunc(
 			func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
 				if !hasExpectedCustomHeaders(ctx.Req.Header) {
+					ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
 					return goproxy.RejectConnect, host
 				}
 				return goproxy.OkConnect, host

+ 0 - 271
psiphon/net.go

@@ -51,7 +51,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 package psiphon
 
 import (
-	"container/list"
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
@@ -64,11 +63,9 @@ import (
 	"os"
 	"reflect"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Inc/dns"
-	"github.com/Psiphon-Inc/ratelimit"
 )
 
 const DNS_PORT = 53
@@ -227,92 +224,6 @@ func (conns *Conns) CloseAll() {
 	conns.conns = make(map[net.Conn]bool)
 }
 
-// LRUConns is a concurrency-safe list of net.Conns ordered
-// by recent activity. Its purpose is to facilitate closing
-// the oldest connection in a set of connections.
-//
-// New connections added are referenced by a LRUConnsEntry,
-// which is used to Touch() active connections, which
-// promotes them to the front of the order and to Remove()
-// connections that are no longer LRU candidates.
-//
-// CloseOldest() will remove the oldest connection from the
-// list and call net.Conn.Close() on the connection.
-//
-// After an entry has been removed, LRUConnsEntry Touch()
-// and Remove() will have no effect.
-type LRUConns struct {
-	mutex sync.Mutex
-	list  *list.List
-}
-
-// NewLRUConns initializes a new LRUConns.
-func NewLRUConns() *LRUConns {
-	return &LRUConns{list: list.New()}
-}
-
-// Add inserts a net.Conn as the freshest connection
-// in a LRUConns and returns an LRUConnsEntry to be
-// used to freshen the connection or remove the connection
-// from the LRU list.
-func (conns *LRUConns) Add(conn net.Conn) *LRUConnsEntry {
-	conns.mutex.Lock()
-	defer conns.mutex.Unlock()
-	return &LRUConnsEntry{
-		lruConns: conns,
-		element:  conns.list.PushFront(conn),
-	}
-}
-
-// CloseOldest closes the oldest connection in a
-// LRUConns. It calls net.Conn.Close() on the
-// connection.
-func (conns *LRUConns) CloseOldest() {
-	conns.mutex.Lock()
-	oldest := conns.list.Back()
-	conn, ok := oldest.Value.(net.Conn)
-	if oldest != nil {
-		conns.list.Remove(oldest)
-	}
-	// Release mutex before closing conn
-	conns.mutex.Unlock()
-	if ok {
-		conn.Close()
-	}
-}
-
-// LRUConnsEntry is an entry in a LRUConns list.
-type LRUConnsEntry struct {
-	lruConns *LRUConns
-	element  *list.Element
-}
-
-// Remove deletes the connection referenced by the
-// LRUConnsEntry from the associated LRUConns.
-// Has no effect if the entry was not initialized
-// or previously removed.
-func (entry *LRUConnsEntry) Remove() {
-	if entry.lruConns == nil || entry.element == nil {
-		return
-	}
-	entry.lruConns.mutex.Lock()
-	defer entry.lruConns.mutex.Unlock()
-	entry.lruConns.list.Remove(entry.element)
-}
-
-// Touch promotes the connection referenced by the
-// LRUConnsEntry to the front of the associated LRUConns.
-// Has no effect if the entry was not initialized
-// or previously removed.
-func (entry *LRUConnsEntry) Touch() {
-	if entry.lruConns == nil || entry.element == nil {
-		return
-	}
-	entry.lruConns.mutex.Lock()
-	defer entry.lruConns.mutex.Unlock()
-	entry.lruConns.list.MoveToFront(entry.element)
-}
-
 // LocalProxyRelay sends to remoteConn bytes received from localConn,
 // and sends to localConn bytes received from remoteConn.
 func LocalProxyRelay(proxyType string, localConn, remoteConn net.Conn) {
@@ -737,185 +648,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
 	tc.SetKeepAlivePeriod(3 * time.Minute)
 	return tc, nil
 }
-
-// ActivityMonitoredConn wraps a net.Conn, adding logic to deal with
-// events triggered by I/O activity.
-//
-// When an inactivity timeout is specified, the network I/O will
-// timeout after the specified period of read inactivity. Optionally,
-// ActivityMonitoredConn will also consider the connection active when
-// data is written to it.
-//
-// When a LRUConnsEntry is specified, then the LRU entry is promoted on
-// either a successful read or write.
-//
-type ActivityMonitoredConn struct {
-	net.Conn
-	inactivityTimeout    time.Duration
-	activeOnWrite        bool
-	startTime            int64
-	lastReadActivityTime int64
-	lruEntry             *LRUConnsEntry
-}
-
-func NewActivityMonitoredConn(
-	conn net.Conn,
-	inactivityTimeout time.Duration,
-	activeOnWrite bool,
-	lruEntry *LRUConnsEntry) (*ActivityMonitoredConn, error) {
-
-	if inactivityTimeout > 0 {
-		err := conn.SetDeadline(time.Now().Add(inactivityTimeout))
-		if err != nil {
-			return nil, ContextError(err)
-		}
-	}
-
-	now := time.Now().UnixNano()
-
-	return &ActivityMonitoredConn{
-		Conn:                 conn,
-		inactivityTimeout:    inactivityTimeout,
-		activeOnWrite:        activeOnWrite,
-		startTime:            now,
-		lastReadActivityTime: now,
-		lruEntry:             lruEntry,
-	}, nil
-}
-
-// GetStartTime gets the time when the ActivityMonitoredConn was
-// initialized.
-func (conn *ActivityMonitoredConn) GetStartTime() time.Time {
-	return time.Unix(0, conn.startTime)
-}
-
-// GetActiveDuration returns the time elapsed between the initialization
-// of the ActivityMonitoredConn and the last Read. Only reads are used
-// for this calculation since writes may succeed locally due to buffering.
-func (conn *ActivityMonitoredConn) GetActiveDuration() time.Duration {
-	return time.Duration(atomic.LoadInt64(&conn.lastReadActivityTime) - conn.startTime)
-}
-
-func (conn *ActivityMonitoredConn) Read(buffer []byte) (int, error) {
-	n, err := conn.Conn.Read(buffer)
-	if err == nil {
-
-		if conn.inactivityTimeout > 0 {
-			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
-			if err != nil {
-				return n, ContextError(err)
-			}
-		}
-		if conn.lruEntry != nil {
-			conn.lruEntry.Touch()
-		}
-
-		atomic.StoreInt64(&conn.lastReadActivityTime, time.Now().UnixNano())
-
-	}
-	// Note: no context error to preserve error type
-	return n, err
-}
-
-func (conn *ActivityMonitoredConn) Write(buffer []byte) (int, error) {
-	n, err := conn.Conn.Write(buffer)
-	if err == nil && conn.activeOnWrite {
-
-		if conn.inactivityTimeout > 0 {
-			err = conn.Conn.SetDeadline(time.Now().Add(conn.inactivityTimeout))
-			if err != nil {
-				return n, ContextError(err)
-			}
-		}
-
-		if conn.lruEntry != nil {
-			conn.lruEntry.Touch()
-		}
-
-	}
-	// Note: no context error to preserve error type
-	return n, err
-}
-
-// ThrottledConn wraps a net.Conn with read and write rate limiters.
-// Rates are specified as bytes per second. Optional unlimited byte
-// counts allow for a number of bytes to read or write before
-// applying rate limiting. Specify limit values of 0 to set no rate
-// limit (unlimited counts are ignored in this case).
-// The underlying rate limiter uses the token bucket algorithm to
-// calculate delay times for read and write operations.
-type ThrottledConn struct {
-	net.Conn
-	unlimitedReadBytes  int64
-	limitingReads       int32
-	limitedReader       io.Reader
-	unlimitedWriteBytes int64
-	limitingWrites      int32
-	limitedWriter       io.Writer
-}
-
-// NewThrottledConn initializes a new ThrottledConn.
-func NewThrottledConn(
-	conn net.Conn,
-	unlimitedReadBytes, limitReadBytesPerSecond,
-	unlimitedWriteBytes, limitWriteBytesPerSecond int64) *ThrottledConn {
-
-	// When no limit is specified, the rate limited reader/writer
-	// is simply the base reader/writer.
-
-	var reader io.Reader
-	if limitReadBytesPerSecond == 0 {
-		reader = conn
-	} else {
-		reader = ratelimit.Reader(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limitReadBytesPerSecond), limitReadBytesPerSecond))
-	}
-
-	var writer io.Writer
-	if limitWriteBytesPerSecond == 0 {
-		writer = conn
-	} else {
-		writer = ratelimit.Writer(conn,
-			ratelimit.NewBucketWithRate(
-				float64(limitWriteBytesPerSecond), limitWriteBytesPerSecond))
-	}
-
-	return &ThrottledConn{
-		Conn:                conn,
-		unlimitedReadBytes:  unlimitedReadBytes,
-		limitingReads:       0,
-		limitedReader:       reader,
-		unlimitedWriteBytes: unlimitedWriteBytes,
-		limitingWrites:      0,
-		limitedWriter:       writer,
-	}
-}
-
-func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
-
-	// Use the base reader until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingReads) == 0 {
-		if atomic.AddInt64(&conn.unlimitedReadBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingReads, 1)
-		} else {
-			return conn.Read(buffer)
-		}
-	}
-
-	return conn.limitedReader.Read(buffer)
-}
-
-func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
-
-	// Use the base writer until the unlimited count is exhausted.
-	if atomic.LoadInt32(&conn.limitingWrites) == 0 {
-		if atomic.AddInt64(&conn.unlimitedWriteBytes, -int64(len(buffer))) <= 0 {
-			atomic.StoreInt32(&conn.limitingWrites, 1)
-		} else {
-			return conn.Write(buffer)
-		}
-	}
-
-	return conn.limitedWriter.Write(buffer)
-}

+ 6 - 6
psiphon/server/tunnelServer.go

@@ -423,7 +423,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	// the connection active. Writes are not considered reliable activity indicators
 	// due to buffering.
 
-	activityConn, err := psiphon.NewActivityMonitoredConn(
+	activityConn, err := NewActivityMonitoredConn(
 		clientConn,
 		SSH_CONNECTION_READ_DEADLINE,
 		false,
@@ -438,7 +438,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 	// Further wrap the connection in a rate limiting ThrottledConn.
 
 	rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
-	clientConn = psiphon.NewThrottledConn(
+	clientConn = NewThrottledConn(
 		clientConn,
 		rateLimits.DownstreamUnlimitedBytes,
 		int64(rateLimits.DownstreamBytesPerSecond),
@@ -542,7 +542,7 @@ type sshClient struct {
 	sshServer               *sshServer
 	tunnelProtocol          string
 	sshConn                 ssh.Conn
-	activityConn            *psiphon.ActivityMonitoredConn
+	activityConn            *ActivityMonitoredConn
 	geoIPData               GeoIPData
 	psiphonSessionID        string
 	udpChannel              ssh.Channel
@@ -550,7 +550,7 @@ type sshClient struct {
 	tcpTrafficState         *trafficState
 	udpTrafficState         *trafficState
 	channelHandlerWaitGroup *sync.WaitGroup
-	tcpPortForwardLRU       *psiphon.LRUConns
+	tcpPortForwardLRU       *LRUConns
 	stopBroadcast           chan struct{}
 }
 
@@ -572,7 +572,7 @@ func newSshClient(
 		tcpTrafficState:         &trafficState{},
 		udpTrafficState:         &trafficState{},
 		channelHandlerWaitGroup: new(sync.WaitGroup),
-		tcpPortForwardLRU:       psiphon.NewLRUConns(),
+		tcpPortForwardLRU:       NewLRUConns(),
 		stopBroadcast:           make(chan struct{}),
 	}
 }
@@ -991,7 +991,7 @@ func (sshClient *sshClient) handleTCPChannel(
 	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
 	defer lruEntry.Remove()
 
-	fwdConn, err = psiphon.NewActivityMonitoredConn(
+	fwdConn, err = NewActivityMonitoredConn(
 		fwdConn,
 		time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
 		true,

+ 4 - 4
psiphon/server/udp.go

@@ -73,7 +73,7 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
 		sshClient:      sshClient,
 		sshChannel:     sshChannel,
 		portForwards:   make(map[uint16]*udpPortForward),
-		portForwardLRU: psiphon.NewLRUConns(),
+		portForwardLRU: NewLRUConns(),
 		relayWaitGroup: new(sync.WaitGroup),
 	}
 	multiplexer.run()
@@ -84,7 +84,7 @@ type udpPortForwardMultiplexer struct {
 	sshChannel        ssh.Channel
 	portForwardsMutex sync.Mutex
 	portForwards      map[uint16]*udpPortForward
-	portForwardLRU    *psiphon.LRUConns
+	portForwardLRU    *LRUConns
 	relayWaitGroup    *sync.WaitGroup
 }
 
@@ -203,7 +203,7 @@ func (mux *udpPortForwardMultiplexer) run() {
 
 			lruEntry := mux.portForwardLRU.Add(udpConn)
 
-			conn, err := psiphon.NewActivityMonitoredConn(
+			conn, err := NewActivityMonitoredConn(
 				udpConn,
 				time.Duration(mux.sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds)*time.Millisecond,
 				true,
@@ -273,7 +273,7 @@ type udpPortForward struct {
 	remoteIP     []byte
 	remotePort   uint16
 	conn         net.Conn
-	lruEntry     *psiphon.LRUConnsEntry
+	lruEntry     *LRUConnsEntry
 	bytesUp      int64
 	bytesDown    int64
 	mux          *udpPortForwardMultiplexer

+ 283 - 0
psiphon/upstreamproxy/upstreamproxy_test.go

@@ -0,0 +1,283 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package upstreamproxy
+
+import (
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server"
+	"github.com/elazarl/goproxy"
+)
+
+// Note: upstreamproxy_test is redundant -- it doesn't test any cases not
+// covered by controller_test; and its code is largely copied from server_test
+// and controller_test. upstreamproxy_test exists so that coverage within the
+// upstreamproxy package can be measured and reported.
+
+func TestMain(m *testing.M) {
+	flag.Parse()
+	os.Remove(psiphon.DATA_STORE_FILENAME)
+	initUpstreamProxy()
+	psiphon.SetEmitDiagnosticNotices(true)
+	os.Exit(m.Run())
+}
+
+func TestSSHViaUpstreamProxy(t *testing.T) {
+	runServer(t, "SSH")
+}
+
+func TestOSSHViaUpstreamProxy(t *testing.T) {
+	runServer(t, "OSSH")
+}
+
+func TestUnfrontedMeekViaUpstreamProxy(t *testing.T) {
+	runServer(t, "UNFRONTED-MEEK-OSSH")
+}
+
+func TestUnfrontedMeekHTTPSViaUpstreamProxy(t *testing.T) {
+	runServer(t, "UNFRONTED-MEEK-HTTPS-OSSH")
+}
+
+func runServer(t *testing.T, tunnelProtocol string) {
+
+	// create a server
+
+	serverIPaddress, err := psiphon.GetInterfaceIPAddress("en0")
+	if err != nil {
+		t.Fatalf("error getting server IP address: %s", err)
+	}
+
+	serverConfigJSON, _, encodedServerEntry, err := server.GenerateConfig(
+		&server.GenerateConfigParams{
+			ServerIPAddress:      serverIPaddress,
+			EnableSSHAPIRequests: true,
+			WebServerPort:        8000,
+			TunnelProtocolPorts:  map[string]int{tunnelProtocol: 4000},
+		})
+	if err != nil {
+		t.Fatalf("error generating server config: %s", err)
+	}
+
+	// customize server config
+
+	var serverConfig interface{}
+	json.Unmarshal(serverConfigJSON, &serverConfig)
+	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
+	serverConfig.(map[string]interface{})["PsinetDatabaseFilename"] = ""
+	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
+	serverConfigJSON, _ = json.Marshal(serverConfig)
+
+	// run server
+
+	serverWaitGroup := new(sync.WaitGroup)
+	serverWaitGroup.Add(1)
+	go func() {
+		defer serverWaitGroup.Done()
+		err := server.RunServices(serverConfigJSON)
+		if err != nil {
+			// TODO: wrong goroutine for t.FatalNow()
+			t.Fatalf("error running server: %s", err)
+		}
+	}()
+	defer func() {
+		p, _ := os.FindProcess(os.Getpid())
+		p.Signal(os.Interrupt)
+		serverWaitGroup.Wait()
+	}()
+
+	// connect to server with client
+
+	// TODO: currently, TargetServerEntry only works with one tunnel
+	numTunnels := 1
+	localHTTPProxyPort := 8081
+	establishTunnelPausePeriodSeconds := 1
+
+	// Note: calling LoadConfig ensures all *int config fields are initialized
+	clientConfigJSON := `
+    {
+        "ClientVersion" : "0",
+        "SponsorId" : "0",
+        "PropagationChannelId" : "0"
+    }`
+	clientConfig, _ := psiphon.LoadConfig([]byte(clientConfigJSON))
+
+	clientConfig.ConnectionWorkerPoolSize = numTunnels
+	clientConfig.TunnelPoolSize = numTunnels
+	clientConfig.DisableRemoteServerListFetcher = true
+	clientConfig.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
+	clientConfig.TargetServerEntry = string(encodedServerEntry)
+	clientConfig.TunnelProtocol = tunnelProtocol
+	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
+
+	clientConfig.UpstreamProxyUrl = upstreamProxyURL
+	clientConfig.UpstreamProxyCustomHeaders = upstreamProxyCustomHeaders
+
+	err = psiphon.InitDataStore(clientConfig)
+	if err != nil {
+		t.Fatalf("error initializing client datastore: %s", err)
+	}
+
+	controller, err := psiphon.NewController(clientConfig)
+	if err != nil {
+		t.Fatalf("error creating client controller: %s", err)
+	}
+
+	tunnelsEstablished := make(chan struct{}, 1)
+
+	psiphon.SetNoticeOutput(psiphon.NewNoticeReceiver(
+		func(notice []byte) {
+
+			fmt.Printf("%s\n", string(notice))
+
+			noticeType, payload, err := psiphon.GetNotice(notice)
+			if err != nil {
+				return
+			}
+
+			switch noticeType {
+			case "Tunnels":
+				count := int(payload["count"].(float64))
+				if count >= numTunnels {
+					select {
+					case tunnelsEstablished <- *new(struct{}):
+					default:
+					}
+				}
+			}
+		}))
+
+	controllerShutdownBroadcast := make(chan struct{})
+	controllerWaitGroup := new(sync.WaitGroup)
+	controllerWaitGroup.Add(1)
+	go func() {
+		defer controllerWaitGroup.Done()
+		controller.Run(controllerShutdownBroadcast)
+	}()
+	defer func() {
+		close(controllerShutdownBroadcast)
+		controllerWaitGroup.Wait()
+	}()
+
+	// Test: tunnels must be established within 30 seconds
+
+	establishTimeout := time.NewTimer(30 * time.Second)
+	select {
+	case <-tunnelsEstablished:
+	case <-establishTimeout.C:
+		t.Fatalf("tunnel establish timeout exceeded")
+	}
+
+	// Test: tunneled web site fetch
+
+	testUrl := "https://psiphon.ca"
+	roundTripTimeout := 30 * time.Second
+
+	proxyUrl, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", localHTTPProxyPort))
+	if err != nil {
+		t.Fatalf("error initializing proxied HTTP request: %s", err)
+	}
+
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyUrl),
+		},
+		Timeout: roundTripTimeout,
+	}
+
+	response, err := httpClient.Get(testUrl)
+	if err != nil {
+		t.Fatalf("error sending proxied HTTP request: %s", err)
+	}
+
+	_, err = ioutil.ReadAll(response.Body)
+	if err != nil {
+		t.Fatalf("error reading proxied HTTP response: %s", err)
+	}
+	response.Body.Close()
+}
+
+const upstreamProxyURL = "http://127.0.0.1:2161"
+
+var upstreamProxyCustomHeaders = map[string][]string{"X-Test-Header-Name": []string{"test-header-value1", "test-header-value2"}}
+
+func hasExpectedCustomHeaders(h http.Header) bool {
+	for name, values := range upstreamProxyCustomHeaders {
+		if h[name] == nil {
+			return false
+		}
+		// Order may not be the same
+		for _, value := range values {
+			if !psiphon.Contains(h[name], value) {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+func initUpstreamProxy() {
+	go func() {
+		proxy := goproxy.NewProxyHttpServer()
+
+		proxy.OnRequest().DoFunc(
+			func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
+				if !hasExpectedCustomHeaders(r.Header) {
+					ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
+					return nil, goproxy.NewResponse(r, goproxy.ContentTypeText, http.StatusUnauthorized, "")
+				}
+				return r, nil
+			})
+
+		proxy.OnRequest().HandleConnectFunc(
+			func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
+				// TODO: enable this check. Currently the headers aren't send because the
+				// following type assertion in upstreamproxy.newHTTP fails (but only in this
+				// test context, not in controller_test):
+				//   if upstreamProxyConfig, ok := forward.(*UpstreamProxyConfig); ok {
+				//       hp.customHeaders = upstreamProxyConfig.CustomHeaders
+				//   }
+				//
+				/*
+					if !hasExpectedCustomHeaders(ctx.Req.Header) {
+						ctx.Logf("missing expected headers: %+v", ctx.Req.Header)
+						return goproxy.RejectConnect, host
+					}
+				*/
+				return goproxy.OkConnect, host
+			})
+
+		err := http.ListenAndServe("127.0.0.1:2161", proxy)
+		if err != nil {
+			fmt.Printf("upstream proxy failed: %s", err)
+		}
+	}()
+
+	// TODO: wait until listener is active?
+}