Bläddra i källkod

Merge branch 'master' of https://github.com/Psiphon-Labs/psiphon-tunnel-core

Rod Hynes 11 år sedan
förälder
incheckning
45ffdde0ca
8 ändrade filer med 797 tillägg och 19 borttagningar
  1. 32 6
      psiphon/controller.go
  2. 2 1
      psiphon/meekConn.go
  3. 38 11
      psiphon/serverApi.go
  4. 264 0
      psiphon/stats_collector.go
  5. 109 0
      psiphon/stats_conn.go
  6. 75 0
      psiphon/stats_regexp.go
  7. 258 0
      psiphon/stats_test.go
  8. 19 1
      psiphon/tunnel.go

+ 32 - 6
psiphon/controller.go

@@ -83,9 +83,11 @@ func NewController(config *Config) (controller *Controller) {
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
-
 	Notice(NOTICE_VERSION, VERSION)
 
+	Stats_Start()
+	defer Stats_Stop()
+
 	socksProxy, err := NewSocksProxy(controller.config, controller)
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
@@ -378,7 +380,7 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 
 	Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
 	// TODO: NewSession server API calls may block shutdown
-	_, err = NewSession(controller.config, tunnel)
+	session, err := NewSession(controller.config, tunnel)
 	if err != nil {
 		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
 	}
@@ -390,6 +392,8 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 	// of the first candidates next time establish runs.
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
 
+	statsTimer := time.NewTimer(NextSendPeriod())
+
 	for err == nil {
 		select {
 		case failures := <-tunnel.portForwardFailures:
@@ -407,8 +411,14 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 			err = errors.New("tunnel closed unexpectedly")
 
 		case <-controller.shutdownBroadcast:
+			// Send final stats
+			sendStats(tunnel, session, true)
 			Notice(NOTICE_INFO, "shutdown operate tunnel")
 			return
+
+		case <-statsTimer.C:
+			sendStats(tunnel, session, false)
+			statsTimer.Reset(NextSendPeriod())
 		}
 	}
 
@@ -425,6 +435,18 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 	}
 }
 
+// sendStats takes the accumulated stats for the tunnel's server out of the
+// collector and posts them in a status request. When the request fails, the
+// stats are handed back to the collector via PutBack.
+func sendStats(tunnel *Tunnel, session *Session, final bool) {
+	serverID := tunnel.serverEntry.IpAddress
+
+	payload := GetForServer(serverID)
+	if payload == nil {
+		// No stats are currently stored for this server.
+		return
+	}
+
+	if err := session.DoStatusRequest(payload, final); err != nil {
+		Notice(NOTICE_ALERT, "DoStatusRequest failed for %s: %s", serverID, err)
+		PutBack(serverID, payload)
+	}
+}
+
 // TunneledConn implements net.Conn and wraps a port forward connection.
 // It is used to hook into Read and Write to observe I/O errors and
 // report these errors back to the tunnel monitor as port forward failures.
@@ -476,10 +498,14 @@ func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error)
 		}
 		return nil, ContextError(err)
 	}
-	return &TunneledConn{
-			Conn:   tunnelConn,
-			tunnel: tunnel},
-		nil
+
+	statsConn := NewStatsConn(tunnelConn, tunnel.ServerID(), tunnel.StatsRegexps())
+
+	conn = &TunneledConn{
+		Conn:   statsConn,
+		tunnel: tunnel}
+
+	return
 }
 
 // startEstablishing creates a pool of worker goroutines which will

+ 2 - 1
psiphon/meekConn.go

@@ -21,7 +21,6 @@ package psiphon
 
 import (
 	"bytes"
-	"code.google.com/p/go.crypto/nacl/box"
 	"crypto/rand"
 	"encoding/base64"
 	"encoding/json"
@@ -33,6 +32,8 @@ import (
 	"net/url"
 	"sync"
 	"time"
+
+	"code.google.com/p/go.crypto/nacl/box"
 )
 
 // MeekConn is based on meek-client.go from Tor and Psiphon:

+ 38 - 11
psiphon/serverApi.go

@@ -24,6 +24,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
@@ -71,8 +72,26 @@ func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 	return session, nil
 }
 
-func (session *Session) DoStatusRequest() {
-	// TODO: implement (required for page view stats)
+// DoStatusRequest makes a /status request to the server, sending session stats.
+// final should be true if this is the last such request before disconnecting.
+func (session *Session) DoStatusRequest(statsPayload json.Marshaler, final bool) error {
+	statsPayloadJSON, err := json.Marshal(statsPayload)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	connected := "1"
+	if final {
+		connected = "0"
+	}
+
+	url := session.buildRequestUrl(
+		"status",
+		&ExtraParam{"session_id", session.tunnel.sessionId},
+		&ExtraParam{"connected", connected})
+
+	err = session.doPostRequest(url, "application/json", bytes.NewReader(statsPayloadJSON))
+	return ContextError(err)
 }
 
 // doHandshakeRequest performs the handshake API request. The handshake
@@ -142,15 +161,9 @@ func (session *Session) doHandshakeRequest() error {
 	if upgradeClientVersion > session.config.ClientVersion {
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 	}
-	// TODO: remove regex notices -- regexes will be used internally
-	/*
-		for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
-			Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
-		}
-		for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
-			Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
-		}
-	*/
+	session.tunnel.SetStatsRegexps(MakeRegexps(
+		handshakeConfig.PageViewRegexes,
+		handshakeConfig.HttpsRequestRegexes))
 	return nil
 }
 
@@ -250,6 +263,20 @@ func (session *Session) doGetRequest(requestUrl string) (responseBody []byte, er
 	return body, nil
 }
 
+// doPostRequest makes a tunneled HTTPS POST request. The response body is
+// closed without being inspected; any status other than 200 OK is reported
+// as an error.
+func (session *Session) doPostRequest(requestUrl string, bodyType string, body io.Reader) (err error) {
+	response, postErr := session.psiphonHttpsClient.Post(requestUrl, bodyType, body)
+	if postErr != nil {
+		// Trim this error since it may include long URLs
+		return ContextError(TrimError(postErr))
+	}
+	response.Body.Close()
+
+	if status := response.StatusCode; status != http.StatusOK {
+		return ContextError(fmt.Errorf("HTTP POST request failed with response code: %d", status))
+	}
+	return nil
+}
+
 // makeHttpsClient creates a Psiphon HTTPS client that tunnels requests and which validates
 // the web server using the Psiphon server entry web server certificate.
 // This is not a general purpose HTTPS client.

+ 264 - 0
psiphon/stats_collector.go

@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"sync"
+	"time"
+)
+
+// TODO: Stats for a server are only removed when they are sent in a status
+// update to that server. So if there's an unexpected disconnect from serverA
+// and then a reconnect to serverB, the stats for serverA will never get sent
+// (unless there's later a reconnect to serverA). That means the stats for
+// serverA will never get deleted and the memory won't get freed. This is only
+// a small amount of memory (< 1KB, probably), but we should still probably add
+// some kind of stale-stats cleanup.
+
+// _CHANNEL_CAPACITY is the buffer size of the channel that connections use to
+// send stats bundles to the collector/processor (see Stats_Start).
+var _CHANNEL_CAPACITY = 1000
+
+// hostStats holds the byte counts accumulated for one host/domain.
+// Note that the bytes we're counting are the ones going into the tunnel, so
+// they do not include transport overhead.
+type hostStats struct {
+	numBytesSent     int64
+	numBytesReceived int64
+}
+
+// newHostStats returns a hostStats with both byte counters at zero.
+func newHostStats() *hostStats {
+	return new(hostStats)
+}
+
+// serverStats holds per-server stats, keyed by the hostname used for stats.
+type serverStats struct {
+	hostnameToStats map[string]*hostStats
+}
+
+// newServerStats returns a serverStats whose host map is allocated and empty.
+func newServerStats() *serverStats {
+	stats := new(serverStats)
+	stats.hostnameToStats = make(map[string]*hostStats)
+	return stats
+}
+
+// allStats is the root object that holds stats for all servers and all hosts,
+// as well as the mutex to access them, the channel to update them, etc.
+var allStats struct {
+	serverIDtoStats    map[string]*serverStats // accumulated stats, keyed by server ID
+	statsMutex         sync.RWMutex            // held while serverIDtoStats is read or modified
+	stopSignal         chan struct{}           // closed by Stats_Stop; nil while collection is stopped
+	statsChan          chan []*statsUpdate     // carries update bundles to processStats
+	processorWaitGroup sync.WaitGroup          // lets Stats_Stop wait for processStats to exit
+}
+
+// Stats_Start initializes and begins stats collection. Must be called once,
+// when the application starts. Calling it while collection is already running
+// does nothing.
+func Stats_Start() {
+	if running := allStats.stopSignal != nil; running {
+		return
+	}
+
+	allStats.statsChan = make(chan []*statsUpdate, _CHANNEL_CAPACITY)
+	allStats.stopSignal = make(chan struct{})
+	allStats.serverIDtoStats = make(map[string]*serverStats)
+
+	// Register the processor with the wait group before launching it, so that
+	// Stats_Stop can wait for it to exit.
+	allStats.processorWaitGroup.Add(1)
+	go processStats()
+}
+
+// Stats_Stop ends stats collection. Must be called once, before the
+// application terminates. It signals the processor goroutine, waits for it to
+// exit, and does nothing when collection is not running.
+func Stats_Stop() {
+	if allStats.stopSignal == nil {
+		return
+	}
+
+	close(allStats.stopSignal)
+	allStats.processorWaitGroup.Wait()
+
+	// A nil stopSignal is what marks collection as stopped.
+	allStats.stopSignal = nil
+}
+
+// Instances of statsUpdate will be sent through the connection-to-collector
+// channel. Field order matters: StatsConn builds these with unkeyed literals.
+type statsUpdate struct {
+	serverID         string
+	hostname         string
+	numBytesSent     int64
+	numBytesReceived int64
+}
+
+// recordStat makes sure the given stats update is added to the global
+// collection. Guaranteed to not block.
+// Callers of this function should assume that it "takes control" of the
+// statsUpdate object.
+func recordStat(newStat *statsUpdate) {
+	bundle := []*statsUpdate{newStat}
+
+	// Priority: don't block connections when updating stats. A plain send
+	// would block when statsChan is full, and a goroutine per update seems
+	// like unnecessary overhead, so first attempt a non-blocking send.
+	select {
+	case allStats.statsChan <- bundle:
+		return
+	default:
+	}
+
+	// The channel could not take the bundle right now; let a goroutine wait
+	// for room instead of the caller.
+	go func() {
+		allStats.statsChan <- bundle
+	}()
+}
+
+// processStats is a goro started by Stats_Start() and runs until Stats_Stop().
+// It merges the update bundles arriving on allStats.statsChan (sent by
+// StatsConn via recordStat, and by PutBack) into allStats.serverIDtoStats.
+func processStats() {
+	defer allStats.processorWaitGroup.Done()
+
+	for {
+		// Block until there is either work or a stop request. Polling both
+		// channels with default cases would spin a CPU core for as long as
+		// no stats are arriving.
+		select {
+		case statSlice := <-allStats.statsChan:
+			mergeStatsUpdates(statSlice)
+
+		case <-allStats.stopSignal:
+			// Updates that are already queued should not be lost, so drain
+			// whatever is buffered in the channel and only then stop.
+			for {
+				select {
+				case statSlice := <-allStats.statsChan:
+					mergeStatsUpdates(statSlice)
+				default:
+					return
+				}
+			}
+		}
+	}
+}
+
+// mergeStatsUpdates adds every update in statSlice to the per-server,
+// per-host totals, creating the server and host entries on first use.
+func mergeStatsUpdates(statSlice []*statsUpdate) {
+	allStats.statsMutex.Lock()
+	defer allStats.statsMutex.Unlock()
+
+	for _, stat := range statSlice {
+		// Updates without a hostname are accounted under a catch-all name.
+		hostname := stat.hostname
+		if hostname == "" {
+			hostname = "(OTHER)"
+		}
+
+		storedServerStats := allStats.serverIDtoStats[stat.serverID]
+		if storedServerStats == nil {
+			storedServerStats = newServerStats()
+			allStats.serverIDtoStats[stat.serverID] = storedServerStats
+		}
+
+		storedHostStats := storedServerStats.hostnameToStats[hostname]
+		if storedHostStats == nil {
+			storedHostStats = newHostStats()
+			storedServerStats.hostnameToStats[hostname] = storedHostStats
+		}
+
+		storedHostStats.numBytesSent += stat.numBytesSent
+		storedHostStats.numBytesReceived += stat.numBytesReceived
+	}
+}
+
+// NextSendPeriod returns the amount of time that should be waited before the
+// next time stats are sent: a 5 minute base plus a random jitter obtained from
+// MakeSecureRandomInt, which is given the base (in milliseconds) as argument.
+func NextSendPeriod() (duration time.Duration) {
+	defaultStatsSendDuration := 5 * 60 * 1000 // 5 minutes in millis
+
+	// We include a random component to make the stats send less fingerprintable.
+	jitter, err := MakeSecureRandomInt(defaultStatsSendDuration)
+	if err != nil {
+		// Fall back to zero jitter explicitly rather than trusting whatever
+		// value accompanied the error.
+		Notice(NOTICE_ALERT, "stats.NextSendPeriod: MakeSecureRandomInt failed")
+		jitter = 0
+	}
+
+	return time.Duration(defaultStatsSendDuration+jitter) * time.Millisecond
+}
+
+// MarshalJSON implements the json.Marshaler interface. The output holds the
+// total and per-host byte counts, a random amount of base64 padding, and
+// empty page_views/https_requests lists. A notice with the unpadded stats is
+// emitted each time this is called.
+func (ss serverStats) MarshalJSON() ([]byte, error) {
+	// Add a random amount of padding to help prevent stats updates from being
+	// a predictable size (which often happens when the connection is quiet).
+	// In case of randomness fail, we're going to proceed with zero padding.
+	// TODO: Is this okay?
+	padding := make([]byte, 0)
+	paddingSize, err := MakeSecureRandomInt(256)
+	if err != nil {
+		Notice(NOTICE_ALERT, "stats.serverStats.MarshalJSON: MakeSecureRandomInt failed")
+	} else if randomBytes, err := MakeSecureRandomBytes(paddingSize); err != nil {
+		Notice(NOTICE_ALERT, "stats.serverStats.MarshalJSON: MakeSecureRandomBytes failed")
+	} else {
+		padding = randomBytes
+	}
+
+	// Sum sent and received bytes per host, and across all hosts.
+	var bytesTransferred int64
+	hostBytes := make(map[string]int64)
+	for hostname, stats := range ss.hostnameToStats {
+		hostTotal := stats.numBytesReceived + stats.numBytesSent
+		hostBytes[hostname] = hostTotal
+		bytesTransferred += hostTotal
+	}
+
+	out := map[string]interface{}{
+		"bytes_transferred": bytesTransferred,
+		"host_bytes":        hostBytes,
+	}
+
+	// Print the notice before adding the padding, since it's not interesting
+	noticeJSON, _ := json.Marshal(out)
+	Notice(NOTICE_INFO, "sending stats: %s", noticeJSON)
+
+	out["padding"] = base64.StdEncoding.EncodeToString(padding)
+
+	// We're not using these fields, but the server requires them
+	out["page_views"] = []string{}
+	out["https_requests"] = []string{}
+
+	return json.Marshal(out)
+}
+
+// GetForServer returns the json-able stats package for the given server and
+// removes it from the collection, so the caller becomes its only holder.
+// If there are no stats, nil will be returned.
+func GetForServer(serverID string) (payload *serverStats) {
+	allStats.statsMutex.Lock()
+	stored := allStats.serverIDtoStats[serverID]
+	delete(allStats.serverIDtoStats, serverID)
+	allStats.statsMutex.Unlock()
+
+	return stored
+}
+
+// PutBack re-adds a set of server stats to the collection. The stats are
+// converted back into updates and queued for the processor, so they are merged
+// asynchronously with anything recorded in the meantime.
+//
+// Like recordStat, PutBack does not block the caller: when the channel cannot
+// take the bundle immediately, a goroutine waits for room instead. A nil or
+// empty ss is ignored.
+func PutBack(serverID string, ss *serverStats) {
+	if ss == nil || len(ss.hostnameToStats) == 0 {
+		return
+	}
+
+	statSlice := make([]*statsUpdate, 0, len(ss.hostnameToStats))
+	for hostname, hoststats := range ss.hostnameToStats {
+		statSlice = append(statSlice, &statsUpdate{
+			serverID:         serverID,
+			hostname:         hostname,
+			numBytesSent:     hoststats.numBytesSent,
+			numBytesReceived: hoststats.numBytesReceived,
+		})
+	}
+
+	select {
+	case allStats.statsChan <- statSlice:
+	default:
+		go func() {
+			allStats.statsChan <- statSlice
+		}()
+	}
+}

+ 109 - 0
psiphon/stats_conn.go

@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package stats counts and keeps track of session stats. These are per-domain
+// bytes transferred and total bytes transferred.
+package psiphon
+
+/*
+Assumption: The same connection will not be used to access different hostnames
+	(even if, say, those hostnames map to the same server). If this does occur, we
+	will mis-attribute some bytes.
+Assumption: Enough of the first HTTP request will be present in the first Write() call
+	for us to a) recognize that it is HTTP, and b) parse the hostname.
+		- If this turns out to not be generally true we will need to add buffering.
+*/
+
+import (
+	"bufio"
+	"bytes"
+	"net"
+	"net/http"
+)
+
+// StatsConn is to be used as an intermediate link in a chain of net.Conn objects.
+// It inspects requests and responses and derives stats from them.
+type StatsConn struct {
+	net.Conn
+	serverID   string   // copied into every recorded statsUpdate
+	hostname   string   // stats hostname; empty until derived from the first write
+	firstWrite bool     // true until the first Write that transfers bytes has been inspected
+	regexps    *Regexps // hostname transformations passed to regexHostname
+}
+
+// NewStatsConn creates a StatsConn wrapping nextConn. serverID can be anything
+// that uniquely identifies the server; it will be passed to GetForServer() when
+// retrieving the accumulated stats. regexps is used to turn the request's
+// hostname into the name stored in the stats.
+func NewStatsConn(nextConn net.Conn, serverID string, regexps *Regexps) *StatsConn {
+	conn := new(StatsConn)
+	conn.Conn = nextConn
+	conn.serverID = serverID
+	conn.regexps = regexps
+	// The first write is the one inspected for an HTTP request.
+	conn.firstWrite = true
+	return conn
+}
+
+// Write is called when requests are being written out through the tunnel to
+// the remote server. The data is passed down the chain first; the bytes that
+// were written are then recorded as sent, even if an error is also returned.
+func (conn *StatsConn) Write(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Write(buffer)
+
+	// Count stats before looking at the error condition: the buffer may have
+	// been partially written before the error occurred.
+	if n <= 0 {
+		return n, err
+	}
+
+	// On the first write that transfers data, try to determine the hostname
+	// to associate with this connection.
+	if conn.firstWrite {
+		conn.firstWrite = false
+
+		// Only a buffer that parses as a HTTP request yields a hostname.
+		request, parseErr := http.ReadRequest(bufio.NewReader(bytes.NewReader(buffer)))
+		if parseErr == nil {
+			// Store the stats form of the hostname, not the real one.
+			conn.hostname = regexHostname(request.Host, conn.regexps)
+		}
+	}
+
+	recordStat(&statsUpdate{
+		serverID:         conn.serverID,
+		hostname:         conn.hostname,
+		numBytesSent:     int64(n),
+		numBytesReceived: 0,
+	})
+
+	return n, err
+}
+
+// Read is called when responses to requests are being read from the remote
+// server. Bytes that were read are recorded as received for this connection's
+// hostname; reads that return no data record nothing, matching Write.
+func (conn *StatsConn) Read(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Read(buffer)
+
+	// Count bytes without checking the error condition. It could happen that the
+	// buffer was partially read and then an error occurred. A zero-byte read
+	// carries no information, so don't queue an empty update for it.
+	if n > 0 {
+		recordStat(&statsUpdate{
+			serverID:         conn.serverID,
+			hostname:         conn.hostname,
+			numBytesSent:     0,
+			numBytesReceived: int64(n),
+		})
+	}
+
+	return n, err
+}

+ 75 - 0
psiphon/stats_regexp.go

@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import "regexp"
+
+type regexpReplace struct {
+	regexp  *regexp.Regexp // pattern matched against the hostname
+	replace string         // replacement passed to ReplaceAllString on a match
+}
+
+// Regexps holds the regular expressions and replacement strings used for
+// transforming URLs and hostnames into stats-appropriate forms.
+type Regexps []regexpReplace
+
+// MakeRegexps takes the raw string-map form of the regex-replace pairs
+// returned by the server handshake and turns them into a usable object.
+// Entries with an empty "regex", an empty "replace", or a pattern that does
+// not compile are reported with a notice and left out of the result.
+func MakeRegexps(pageViewRegexes, httpsRequestRegexes []map[string]string) *Regexps {
+	compiled := Regexps{}
+
+	// We aren't doing page view stats anymore, so we won't process those regexps.
+	for _, entry := range httpsRequestRegexes {
+		pattern, replacement := entry["regex"], entry["replace"]
+
+		if pattern == "" {
+			Notice(NOTICE_ALERT, "MakeRegexps: empty regex")
+			continue
+		}
+		if replacement == "" {
+			Notice(NOTICE_ALERT, "MakeRegexps: empty replace")
+			continue
+		}
+
+		re, err := regexp.Compile(pattern)
+		if err != nil {
+			Notice(NOTICE_ALERT, "MakeRegexps: failed to compile regex: %s: %s", pattern, err)
+			continue
+		}
+
+		compiled = append(compiled, regexpReplace{regexp: re, replace: replacement})
+	}
+
+	return &compiled
+}
+
+// regexHostname processes hostname through the given regexps and returns the
+// string that should be used for stats. The first regexp that matches decides
+// the result; when none matches, or when regexps is nil, "(OTHER)" is returned.
+func regexHostname(hostname string, regexps *Regexps) (statsHostname string) {
+	statsHostname = "(OTHER)"
+
+	// A connection may have been given no regexps at all; treat that the
+	// same as an empty list instead of dereferencing nil.
+	if regexps == nil {
+		return statsHostname
+	}
+
+	for _, rr := range *regexps {
+		if rr.regexp.MatchString(hostname) {
+			statsHostname = rr.regexp.ReplaceAllString(hostname, rr.replace)
+			break
+		}
+	}
+	return statsHostname
+}

+ 258 - 0
psiphon/stats_test.go

@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"encoding/json"
+	"errors"
+	"net"
+	"net/http"
+	"testing"
+	"time"
+
+	mapset "github.com/deckarep/golang-set"
+	"github.com/stretchr/testify/suite"
+)
+
+var _SERVER_ID = "myserverid" // server ID under which the test dialer records stats
+
+type StatsTestSuite struct {
+	suite.Suite
+	httpClient *http.Client // client whose connections are wrapped in StatsConn (see SetupTest)
+}
+
+// TestStatsTestSuite is the go test entry point that runs every method of
+// StatsTestSuite.
+func TestStatsTestSuite(t *testing.T) {
+	suite.Run(t, &StatsTestSuite{})
+}
+
+// SetupTest starts stats collection and installs a HTTP client whose
+// connections record stats under _SERVER_ID, using an empty regexp list.
+func (suite *StatsTestSuite) SetupTest() {
+	Stats_Start()
+
+	emptyRegexps := make(Regexps, 0)
+	transport := &http.Transport{
+		Dial: makeStatsDialer(_SERVER_ID, &emptyRegexps),
+	}
+	suite.httpClient = &http.Client{Transport: transport}
+}
+
+// TearDownTest drops the HTTP client created by SetupTest and stops stats
+// collection.
+func (suite *StatsTestSuite) TearDownTest() {
+	// Release the client first, then stop the collector it was feeding.
+	suite.httpClient = nil
+	Stats_Stop()
+}
+
+func makeStatsDialer(serverID string, regexps *Regexps) func(network, addr string) (conn net.Conn, err error) {
+	return func(network, addr string) (conn net.Conn, err error) {
+		var subConn net.Conn
+
+		switch network {
+		case "tcp", "tcp4", "tcp6":
+			tcpAddr, err := net.ResolveTCPAddr(network, addr)
+			if err != nil {
+				return nil, err
+			}
+			subConn, err = net.DialTCP(network, nil, tcpAddr)
+			if err != nil {
+				return nil, err
+			}
+		default:
+			err = errors.New("using an unsupported testing network type")
+			return
+		}
+
+		conn = NewStatsConn(subConn, serverID, regexps)
+		err = nil
+		return
+	}
+}
+
+func (suite *StatsTestSuite) Test_StartStop() {
+	// Make sure repeated, doubled-up and back-to-back Start and Stop calls don't crash
+	Stats_Start()
+	Stats_Start()
+	Stats_Stop()
+	Stats_Stop()
+	Stats_Start()
+	Stats_Stop()
+}
+
+// Test_NextSendPeriod checks that the send period is positive and that two
+// consecutive calls give different values.
+func (suite *StatsTestSuite) Test_NextSendPeriod() {
+	first := NextSendPeriod()
+	suite.True(first > time.Duration(0), "duration should not be zero")
+
+	second := NextSendPeriod()
+	suite.NotEqual(first, second, "duration should have randomness difference between calls")
+}
+
+// Test_StatsConn checks that plain HTTP requests still succeed when their
+// connections are wrapped in StatsConn.
+func (suite *StatsTestSuite) Test_StatsConn() {
+	// resp is nil when the request fails, so only close the body once the
+	// assertion has confirmed there was no error.
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	if suite.Nil(err, "basic HTTP requests should succeed (1)") {
+		resp.Body.Close()
+	}
+
+	resp, err = suite.httpClient.Get("http://example.org/index.html")
+	if suite.Nil(err, "basic HTTP requests should succeed (2)") {
+		resp.Body.Close()
+	}
+}
+
+// Test_GetForServer checks that stats only exist for servers that carried
+// traffic, that the payload marshals to valid JSON, and that retrieving the
+// stats removes them from the collection.
+func (suite *StatsTestSuite) Test_GetForServer() {
+	payload := GetForServer(_SERVER_ID)
+	suite.Nil(payload, "should get nil stats before any traffic (but not crash)")
+
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	if !suite.Nil(err, "need successful http to proceed with tests") {
+		// resp is nil on failure and the remaining checks need traffic.
+		return
+	}
+	resp.Body.Close()
+
+	// Make sure there aren't stats returned for a bad server ID
+	payload = GetForServer("INVALID")
+	suite.Nil(payload, "should get nil stats for invalid server ID")
+
+	payload = GetForServer(_SERVER_ID)
+	suite.NotNil(payload, "should receive valid payload for valid server ID")
+
+	// Check the marshal error on its own; it used to be overwritten by the
+	// unmarshal error without ever being looked at.
+	payloadJSON, err := json.Marshal(payload)
+	suite.Nil(err, "payload should marshal to JSON successfully")
+
+	var parsedJSON interface{}
+	err = json.Unmarshal(payloadJSON, &parsedJSON)
+	suite.Nil(err, "payload JSON should parse successfully")
+
+	// After we retrieve the stats for a server, they should be cleared out of the tracked stats
+	payload = GetForServer(_SERVER_ID)
+	suite.Nil(payload, "after retrieving stats for a server, there should be no more stats (until more data goes through)")
+}
+
+// Test_PutBack checks that stats taken out with GetForServer reappear,
+// unchanged, after being handed back with PutBack.
+func (suite *StatsTestSuite) Test_PutBack() {
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	if !suite.Nil(err, "need successful http to proceed with tests") {
+		// resp is nil on failure and the remaining checks need traffic.
+		return
+	}
+	resp.Body.Close()
+
+	payloadToPutBack := GetForServer(_SERVER_ID)
+	if !suite.NotNil(payloadToPutBack, "should receive valid payload for valid server ID") {
+		// Without a payload there is nothing to put back.
+		return
+	}
+
+	payload := GetForServer(_SERVER_ID)
+	suite.Nil(payload, "should not be any remaining stats after getting them")
+
+	PutBack(_SERVER_ID, payloadToPutBack)
+	// PutBack is asynchronous, so we'll need to wait a moment for it to do its thing
+	<-time.After(100 * time.Millisecond)
+
+	payload = GetForServer(_SERVER_ID)
+	suite.NotNil(payload, "stats should be re-added after putting back")
+	suite.Equal(payload, payloadToPutBack, "stats should be the same as after the first retrieval")
+}
+
+func (suite *StatsTestSuite) Test_MakeRegexps() {
+	pageViewRegexes := []map[string]string{make(map[string]string)}
+	pageViewRegexes[0]["regex"] = `(^http://[a-z0-9\.]*\.example\.[a-z\.]*)/.*`
+	pageViewRegexes[0]["replace"] = "$1"
+
+	httpsRequestRegexes := []map[string]string{make(map[string]string), make(map[string]string)}
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com)$`
+	httpsRequestRegexes[0]["replace"] = "$1"
+	httpsRequestRegexes[1]["regex"] = `^.*example\.org$`
+	httpsRequestRegexes[1]["replace"] = "replacement"
+
+	regexps := MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 2, "should only have processed httpsRequestRegexes")
+
+	// Each case below breaks the first httpsRequestRegexes entry in a
+	// different way; MakeRegexps should discard it and keep only the
+	// second entry, leaving one compiled regexp.
+
+	httpsRequestRegexes[0]["regex"] = "" // empty regex
+	httpsRequestRegexes[0]["replace"] = "$1"
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com)$`
+	httpsRequestRegexes[0]["replace"] = "" // empty replacement
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com$` // missing closing paren, so it does not compile
+	httpsRequestRegexes[0]["replace"] = "$1"
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+}
+
+// Test_Regex sends requests through a client configured with real regexps and
+// checks that the recorded hostnames are the post-regex ones.
+func (suite *StatsTestSuite) Test_Regex() {
+	// We'll make a new client with actual regexps.
+	pageViewRegexes := make([]map[string]string, 0)
+	httpsRequestRegexes := []map[string]string{
+		{"regex": `^[a-z0-9\.]*\.(example\.com)$`, "replace": "$1"},
+		{"regex": `^.*example\.org$`, "replace": "replacement"},
+	}
+	regexps := MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+
+	suite.httpClient = &http.Client{
+		Transport: &http.Transport{
+			Dial: makeStatsDialer(_SERVER_ID, regexps),
+		},
+	}
+
+	urls := []string{
+		// No subdomain, so won't match regex
+		"http://example.com/index.html",
+		// Will match the first regex
+		"http://www.example.com/index.html",
+		// Will match the second regex
+		"http://example.org/index.html",
+	}
+	for _, url := range urls {
+		resp, err := suite.httpClient.Get(url)
+		suite.Nil(err)
+		resp.Body.Close()
+	}
+
+	payload := GetForServer(_SERVER_ID)
+	suite.NotNil(payload, "should get stats because we made HTTP reqs")
+
+	expectedHostnames := mapset.NewSet()
+	for _, expected := range []string{"(OTHER)", "example.com", "replacement"} {
+		expectedHostnames.Add(expected)
+	}
+
+	hostnames := make([]interface{}, 0)
+	for hostname := range payload.hostnameToStats {
+		hostnames = append(hostnames, hostname)
+	}
+	actualHostnames := mapset.NewSetFromSlice(hostnames)
+
+	suite.Equal(expectedHostnames, actualHostnames, "post-regex hostnames should be processed as expecteds")
+}
+
+// Test_recordStat exercises the branch of recordStat that is taken when
+// allStats.statsChan is full.
+func (suite *StatsTestSuite) Test_recordStat() {
+	// The normal operation of recordStat gets exercised during the other
+	// tests. To make sure the channel fills up, hold the stats access mutex
+	// (which stalls the processor), record twice the channel's capacity, and
+	// then release the mutex.
+	allStats.statsMutex.Lock()
+
+	stat := &statsUpdate{"test", "test", 1, 1}
+	total := _CHANNEL_CAPACITY * 2
+	for sent := 0; sent < total; sent++ {
+		recordStat(stat)
+	}
+
+	allStats.statsMutex.Unlock()
+}

+ 19 - 1
psiphon/tunnel.go

@@ -21,7 +21,6 @@ package psiphon
 
 import (
 	"bytes"
-	"code.google.com/p/go.crypto/ssh"
 	"encoding/base64"
 	"encoding/json"
 	"errors"
@@ -30,6 +29,8 @@ import (
 	"strings"
 	"sync/atomic"
 	"time"
+
+	"code.google.com/p/go.crypto/ssh"
 )
 
 // Tunneler specifies the interface required by components that use a tunnel.
@@ -69,6 +70,7 @@ type Tunnel struct {
 	sshKeepAliveQuit        chan struct{}
 	portForwardFailures     chan int
 	portForwardFailureTotal int
+	regexps                 *Regexps
 }
 
 // EstablishTunnel first makes a network transport connection to the
@@ -282,3 +284,19 @@ func (tunnel *Tunnel) SignalFailure() {
 	Notice(NOTICE_ALERT, "tunnel received failure signal")
 	tunnel.Close()
 }
+
+// ServerID provides a unique identifier for the server the tunnel connects to:
+// the server entry's IP address, so it is consistent between tunnels to that server.
+func (tunnel *Tunnel) ServerID() string {
+	return tunnel.serverEntry.IpAddress
+}
+
+// StatsRegexps returns the Regexps used for the statistics for this tunnel.
+func (tunnel *Tunnel) StatsRegexps() *Regexps {
+	return tunnel.regexps
+}
+
+// SetStatsRegexps stores the Regexps used for the statistics for this tunnel.
+func (tunnel *Tunnel) SetStatsRegexps(regexps *Regexps) {
+	tunnel.regexps = regexps
+}