فهرست منبع

Merge branch 'master' into android-jni

Rod Hynes 11 سال پیش
والد
کامیت
c2a3eada36
9 فایلهای تغییر یافته به همراه 892 افزوده شده و 67 حذف شده
  1. 44 17
      psiphon/controller.go
  2. 43 22
      psiphon/dataStore.go
  3. 42 16
      psiphon/meekConn.go
  4. 38 11
      psiphon/serverApi.go
  5. 264 0
      psiphon/stats_collector.go
  6. 109 0
      psiphon/stats_conn.go
  7. 75 0
      psiphon/stats_regexp.go
  8. 258 0
      psiphon/stats_test.go
  9. 19 1
      psiphon/tunnel.go

+ 44 - 17
psiphon/controller.go

@@ -83,9 +83,11 @@ func NewController(config *Config) (controller *Controller) {
 // - a local SOCKS proxy that port forwards through the pool of tunnels
 // - a local HTTP proxy that port forwards through the pool of tunnels
 func (controller *Controller) Run(shutdownBroadcast <-chan struct{}) {
-
 	Notice(NOTICE_VERSION, VERSION)
 
+	Stats_Start()
+	defer Stats_Stop()
+
 	socksProxy, err := NewSocksProxy(controller.config, controller)
 	if err != nil {
 		Notice(NOTICE_ALERT, "error initializing local SOCKS proxy: %s", err)
@@ -378,7 +380,7 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 
 	Notice(NOTICE_INFO, "starting session for %s", tunnel.serverEntry.IpAddress)
 	// TODO: NewSession server API calls may block shutdown
-	_, err = NewSession(controller.config, tunnel)
+	session, err := NewSession(controller.config, tunnel)
 	if err != nil {
 		err = fmt.Errorf("error starting session for %s: %s", tunnel.serverEntry.IpAddress, err)
 	}
@@ -390,6 +392,8 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 	// of the first candidates next time establish runs.
 	PromoteServerEntry(tunnel.serverEntry.IpAddress)
 
+	statsTimer := time.NewTimer(NextSendPeriod())
+
 	for err == nil {
 		select {
 		case failures := <-tunnel.portForwardFailures:
@@ -407,8 +411,14 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 			err = errors.New("tunnel closed unexpectedly")
 
 		case <-controller.shutdownBroadcast:
+			// Send final stats
+			sendStats(tunnel, session, true)
 			Notice(NOTICE_INFO, "shutdown operate tunnel")
 			return
+
+		case <-statsTimer.C:
+			sendStats(tunnel, session, false)
+			statsTimer.Reset(NextSendPeriod())
 		}
 	}
 
@@ -425,6 +435,18 @@ func (controller *Controller) operateTunnel(tunnel *Tunnel) {
 	}
 }
 
+// sendStats is a helper for sending session stats to the server.
+func sendStats(tunnel *Tunnel, session *Session, final bool) {
+	payload := GetForServer(tunnel.serverEntry.IpAddress)
+	if payload != nil {
+		err := session.DoStatusRequest(payload, final)
+		if err != nil {
+			Notice(NOTICE_ALERT, "DoStatusRequest failed for %s: %s", tunnel.serverEntry.IpAddress, err)
+			PutBack(tunnel.serverEntry.IpAddress, payload)
+		}
+	}
+}
+
 // TunneledConn implements net.Conn and wraps a port forward connection.
 // It is used to hook into Read and Write to observe I/O errors and
 // report these errors back to the tunnel monitor as port forward failures.
@@ -476,10 +498,14 @@ func (controller *Controller) Dial(remoteAddr string) (conn net.Conn, err error)
 		}
 		return nil, ContextError(err)
 	}
-	return &TunneledConn{
-			Conn:   tunnelConn,
-			tunnel: tunnel},
-		nil
+
+	statsConn := NewStatsConn(tunnelConn, tunnel.ServerID(), tunnel.StatsRegexps())
+
+	conn = &TunneledConn{
+		Conn:   statsConn,
+		tunnel: tunnel}
+
+	return
 }
 
 // startEstablishing creates a pool of worker goroutines which will
@@ -531,18 +557,18 @@ func (controller *Controller) stopEstablishing() {
 // servers with higher rank are priority candidates.
 func (controller *Controller) establishCandidateGenerator() {
 	defer controller.establishWaitGroup.Done()
+
+	iterator, err := NewServerEntryIterator(
+		controller.config.EgressRegion, controller.config.TunnelProtocol)
+	if err != nil {
+		Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
+		controller.SignalFailure()
+		return
+	}
+	defer iterator.Close()
+
 loop:
 	for {
-		// Note: it's possible that an active tunnel in excludeServerEntries will
-		// fail during this iteration of server entries and in that case the
-		// cooresponding server will not be retried (within the same iteration).
-		iterator, err := NewServerEntryIterator(
-			controller.config.EgressRegion, controller.config.TunnelProtocol)
-		if err != nil {
-			Notice(NOTICE_ALERT, "failed to iterate over candidates: %s", err)
-			controller.SignalFailure()
-			break loop
-		}
 		for {
 			serverEntry, err := iterator.Next()
 			if err != nil {
@@ -562,7 +588,8 @@ loop:
 				break loop
 			}
 		}
-		iterator.Close()
+		iterator.Reset()
+
 		// After a complete iteration of candidate servers, pause before iterating again.
 		// This helps avoid some busy wait loop conditions, and also allows some time for
 		// network conditions to change.

+ 43 - 22
psiphon/dataStore.go

@@ -53,13 +53,14 @@ func InitDataStore(filename string) (err error) {
              rank integer not null unique,
              region text not null,
              data blob not null);
-	    create table if not exists serverEntryProtocol
-	        (serverEntryId text not null,
-	         protocol text not null);
+        create table if not exists serverEntryProtocol
+            (serverEntryId text not null,
+             protocol text not null,
+             primary key (serverEntryId, protocol));
         create table if not exists keyValue
-            (key text not null,
+            (key text not null primary key,
              value text not null);
-		pragma journal_mode=WAL;
+        pragma journal_mode=WAL;
         `
 		var db *sql.DB
 		db, err = sql.Open(
@@ -130,27 +131,35 @@ func transactionWithRetry(updater func(*sql.Tx) error) error {
 
 // serverEntryExists returns true if a serverEntry with the
 // given ipAddress id already exists.
-func serverEntryExists(transaction *sql.Tx, ipAddress string) bool {
+func serverEntryExists(transaction *sql.Tx, ipAddress string) (bool, error) {
 	query := "select count(*) from serverEntry where id  = ?;"
 	var count int
 	err := singleton.db.QueryRow(query, ipAddress).Scan(&count)
-	return err == nil && count > 0
+	if err != nil {
+		return false, ContextError(err)
+	}
+	return count > 0, nil
 }
 
-// StoreServerEntry adds the server entry to the data store. A newly
-// stored (or re-stored) server entry is assigned the next-to-top rank
-// for cycle order (the previous top ranked entry is promoted). The
-// purpose of this is to keep the last selected server as the top
-// ranked server.
+// StoreServerEntry adds the server entry to the data store.
+// A newly stored (or re-stored) server entry is assigned the next-to-top
+// rank for iteration order (the previous top ranked entry is promoted). The
+// purpose of inserting at next-to-top is to keep the last selected server
+// as the top ranked server. Note, server candidates are iterated in descending
+// rank order, so the largest rank is top rank.
 // When replaceIfExists is true, an existing server entry record is
 // overwritten; otherwise, the existing record is unchanged.
 func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 	return transactionWithRetry(func(transaction *sql.Tx) error {
-		serverEntryExists := serverEntryExists(transaction, serverEntry.IpAddress)
+		serverEntryExists, err := serverEntryExists(transaction, serverEntry.IpAddress)
+		if err != nil {
+			return ContextError(err)
+		}
 		if serverEntryExists && !replaceIfExists {
+			// Nothing more to do
 			return nil
 		}
-		_, err := transaction.Exec(`
+		_, err = transaction.Exec(`
             update serverEntry set rank = rank + 1
                 where id = (select id from serverEntry order by rank desc limit 1);
             `)
@@ -166,6 +175,12 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
             insert or replace into serverEntry (id, rank, region, data)
             values (?, (select coalesce(max(rank)-1, 0) from serverEntry), ?, ?);
             `, serverEntry.IpAddress, serverEntry.Region, data)
+		if err != nil {
+			return err
+		}
+		_, err = transaction.Exec(`
+            delete from serverEntryProtocol where serverEntryId = ?;
+            `, serverEntry.IpAddress)
 		if err != nil {
 			return err
 		}
@@ -175,9 +190,9 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 			requiredCapability := strings.TrimSuffix(protocol, "-OSSH")
 			if Contains(serverEntry.Capabilities, requiredCapability) {
 				_, err = transaction.Exec(`
-		            insert or ignore into serverEntryProtocol (serverEntryId, protocol)
-		            values (?, ?);
-		            `, serverEntry.IpAddress, protocol)
+                    insert into serverEntryProtocol (serverEntryId, protocol)
+                    values (?, ?);
+                    `, serverEntry.IpAddress, protocol)
 				if err != nil {
 					return err
 				}
@@ -191,9 +206,10 @@ func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
 	})
 }
 
-// PromoteServerEntry assigns the top cycle rank to the specified
-// server entry. This server entry will be the first candidate in
-// a subsequent tunnel establishment.
+// PromoteServerEntry assigns the top rank (one more than current
+// max rank) to the specified server entry. Server candidates are
+// iterated in descending rank order, so this server entry will be
+// the first candidate in a subsequent tunnel establishment.
 func PromoteServerEntry(ipAddress string) error {
 	return transactionWithRetry(func(transaction *sql.Tx) error {
 		_, err := transaction.Exec(`
@@ -344,6 +360,11 @@ func HasServerEntries(region, protocol string) bool {
 	query := "select count(*) from serverEntry" + whereClause
 	err := singleton.db.QueryRow(query, whereParams...).Scan(&count)
 
+	if err != nil {
+		Notice(NOTICE_ALERT, "HasServerEntries failed: %s", err)
+		return false
+	}
+
 	if region == "" {
 		region = "(any)"
 	}
@@ -353,7 +374,7 @@ func HasServerEntries(region, protocol string) bool {
 	Notice(NOTICE_INFO, "servers for region %s and protocol %s: %d",
 		region, protocol, count)
 
-	return err == nil && count > 0
+	return count > 0
 }
 
 // GetServerEntryIpAddresses returns an array containing
@@ -395,7 +416,7 @@ func SetKeyValue(key, value string) error {
 	})
 }
 
-// GetLastConnected retrieves a key/value pair. If not found,
+// GetKeyValue retrieves the value for a given key. If not found,
 // it returns an empty string value.
 func GetKeyValue(key string) (value string, err error) {
 	checkInitDataStore()

+ 42 - 16
psiphon/meekConn.go

@@ -21,7 +21,6 @@ package psiphon
 
 import (
 	"bytes"
-	"code.google.com/p/go.crypto/nacl/box"
 	"crypto/rand"
 	"encoding/base64"
 	"encoding/json"
@@ -33,6 +32,8 @@ import (
 	"net/url"
 	"sync"
 	"time"
+
+	"code.google.com/p/go.crypto/nacl/box"
 )
 
 // MeekConn is based on meek-client.go from Tor and Psiphon:
@@ -193,7 +194,6 @@ func (meek *MeekConn) SetClosedSignal(closedSignal chan struct{}) (err error) {
 // Close terminates the meek connection. Close waits for the relay processing goroutine
 // to stop and releases HTTP transport resources.
 // A mutex is required to support psiphon.Conn.SetClosedSignal concurrency semantics.
-// NOTE: currently doesn't interrupt any HTTP request in flight.
 func (meek *MeekConn) Close() (err error) {
 	meek.mutex.Lock()
 	defer meek.mutex.Unlock()
@@ -201,9 +201,6 @@ func (meek *MeekConn) Close() (err error) {
 		close(meek.broadcastClosed)
 		meek.pendingConns.CloseAll()
 		meek.relayWaitGroup.Wait()
-		// TODO: meek.transport.CancelRequest() for current in-flight request?
-		// (currently pendingConns will abort establishing connections, but not
-		// established persistent connections)
 		meek.transport.CloseIdleConnections()
 		meek.isClosed = true
 		select {
@@ -337,6 +334,7 @@ func (meek *MeekConn) relay() {
 		case <-timeout.C:
 			// In the polling case, send an empty payload
 		case <-meek.broadcastClosed:
+			// TODO: timeout case may be selected when broadcastClosed is set?
 			return
 		}
 		sendPayloadSize := 0
@@ -356,6 +354,10 @@ func (meek *MeekConn) relay() {
 			go meek.Close()
 			return
 		}
+		if receivedPayload == nil {
+			// In this case, meek.roundTrip encountered broadcastClosed. Exit without error.
+			return
+		}
 		receivedPayloadSize, err := meek.readPayload(receivedPayload)
 		if err != nil {
 			Notice(NOTICE_ALERT, "%s", ContextError(err))
@@ -417,13 +419,37 @@ func (meek *MeekConn) roundTrip(sendPayload []byte) (receivedPayload io.ReadClos
 	request.Header.Set("User-Agent", "")
 	request.Header.Set("Content-Type", "application/octet-stream")
 	request.AddCookie(meek.cookie)
-	// This retry mitigates intermittent failures between the client and front/server.
+
+	// The retry mitigates intermittent failures between the client and front/server.
 	// Note: Retry will only be effective if entire request failed (underlying transport protocol
 	// such as SSH will fail if extra bytes are replayed in either direction due to partial relay
 	// success followed by retry).
 	var response *http.Response
-	for i := 0; i <= 1; i++ {
-		response, err = meek.transport.RoundTrip(request)
+	for retry := 0; retry <= 1; retry++ {
+
+		// The http.Transport.RoundTrip is run in a goroutine to enable cancelling a request in-flight.
+		type roundTripResponse struct {
+			response *http.Response
+			err      error
+		}
+		roundTripResponseChannel := make(chan *roundTripResponse, 1)
+		roundTripWaitGroup := new(sync.WaitGroup)
+		roundTripWaitGroup.Add(1)
+		go func() {
+			defer roundTripWaitGroup.Done()
+			r, err := meek.transport.RoundTrip(request)
+			roundTripResponseChannel <- &roundTripResponse{r, err}
+		}()
+		select {
+		case roundTripResponse := <-roundTripResponseChannel:
+			response = roundTripResponse.response
+			err = roundTripResponse.err
+		case <-meek.broadcastClosed:
+			meek.transport.CancelRequest(request)
+			return nil, nil
+		}
+		roundTripWaitGroup.Wait()
+
 		if err == nil {
 			break
 		}
@@ -434,14 +460,14 @@ func (meek *MeekConn) roundTrip(sendPayload []byte) (receivedPayload io.ReadClos
 	if response.StatusCode != http.StatusOK {
 		return nil, ContextError(fmt.Errorf("http request failed %d", response.StatusCode))
 	}
-        // observe response cookies for meek session key token.
-        // Once found it must be used for all consecutive requests made to the server
-        for _, c := range response.Cookies() {
-            if meek.cookie.Name == c.Name {
-                meek.cookie.Value = c.Value
-                break
-            }
-        }
+	// observe response cookies for meek session key token.
+	// Once found it must be used for all consecutive requests made to the server
+	for _, c := range response.Cookies() {
+		if meek.cookie.Name == c.Name {
+			meek.cookie.Value = c.Value
+			break
+		}
+	}
 	return response.Body, nil
 }
 

+ 38 - 11
psiphon/serverApi.go

@@ -24,6 +24,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
@@ -71,8 +72,26 @@ func NewSession(config *Config, tunnel *Tunnel) (session *Session, err error) {
 	return session, nil
 }
 
-func (session *Session) DoStatusRequest() {
-	// TODO: implement (required for page view stats)
+// DoStatusRequest makes a /status request to the server, sending session stats.
+// final should be true if this is the last such request before disconnecting.
+func (session *Session) DoStatusRequest(statsPayload json.Marshaler, final bool) error {
+	statsPayloadJSON, err := json.Marshal(statsPayload)
+	if err != nil {
+		return ContextError(err)
+	}
+
+	connected := "1"
+	if final {
+		connected = "0"
+	}
+
+	url := session.buildRequestUrl(
+		"status",
+		&ExtraParam{"session_id", session.tunnel.sessionId},
+		&ExtraParam{"connected", connected})
+
+	err = session.doPostRequest(url, "application/json", bytes.NewReader(statsPayloadJSON))
+	return ContextError(err)
 }
 
 // doHandshakeRequest performs the handshake API request. The handshake
@@ -142,15 +161,9 @@ func (session *Session) doHandshakeRequest() error {
 	if upgradeClientVersion > session.config.ClientVersion {
 		Notice(NOTICE_UPGRADE, "%d", upgradeClientVersion)
 	}
-	// TODO: remove regex notices -- regexes will be used internally
-	/*
-		for _, pageViewRegex := range handshakeConfig.PageViewRegexes {
-			Notice(NOTICE_PAGE_VIEW_REGEX, "%s %s", pageViewRegex["regex"], pageViewRegex["replace"])
-		}
-		for _, httpsRequestRegex := range handshakeConfig.HttpsRequestRegexes {
-			Notice(NOTICE_HTTPS_REGEX, "%s %s", httpsRequestRegex["regex"], httpsRequestRegex["replace"])
-		}
-	*/
+	session.tunnel.SetStatsRegexps(MakeRegexps(
+		handshakeConfig.PageViewRegexes,
+		handshakeConfig.HttpsRequestRegexes))
 	return nil
 }
 
@@ -250,6 +263,20 @@ func (session *Session) doGetRequest(requestUrl string) (responseBody []byte, er
 	return body, nil
 }
 
+// doPostRequest makes a tunneled HTTPS POST request.
+func (session *Session) doPostRequest(requestUrl string, bodyType string, body io.Reader) (err error) {
+	response, err := session.psiphonHttpsClient.Post(requestUrl, bodyType, body)
+	if err != nil {
+		// Trim this error since it may include long URLs
+		return ContextError(TrimError(err))
+	}
+	response.Body.Close()
+	if response.StatusCode != http.StatusOK {
+		return ContextError(fmt.Errorf("HTTP POST request failed with response code: %d", response.StatusCode))
+	}
+	return
+}
+
 // makeHttpsClient creates a Psiphon HTTPS client that tunnels requests and which validates
 // the web server using the Psiphon server entry web server certificate.
 // This is not a general purpose HTTPS client.

+ 264 - 0
psiphon/stats_collector.go

@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"sync"
+	"time"
+)
+
+// TODO: Stats for a server are only removed when they are sent in a status
+// update to that server. So if there's an unexpected disconnect from serverA
+// and then a reconnect to serverB, the stats for serverA will never get sent
+// (unless there's later a reconnect to serverA). That means the stats for
+// serverA will never get deleted and the memory won't get freed. This is only
+// a small amount of memory (< 1KB, probably), but we should still probably add
+// some kind of stale-stats cleanup.
+
+// _CHANNEL_CAPACITY is the size of the channel that connections use to send stats
+// bundles to the collector/processor.
+var _CHANNEL_CAPACITY = 1000
+
+// Per-host/domain stats.
+// Note that the bytes we're counting are the ones going into the tunnel, so do
+// not include transport overhead.
+type hostStats struct {
+	numBytesSent     int64
+	numBytesReceived int64
+}
+
+func newHostStats() *hostStats {
+	return &hostStats{}
+}
+
+// serverStats holds per-server stats.
+type serverStats struct {
+	hostnameToStats map[string]*hostStats
+}
+
+func newServerStats() *serverStats {
+	return &serverStats{
+		hostnameToStats: make(map[string]*hostStats),
+	}
+}
+
+// allStats is the root object that holds stats for all servers and all hosts,
+// as well as the mutex to access them, the channel to update them, etc.
+var allStats struct {
+	serverIDtoStats    map[string]*serverStats
+	statsMutex         sync.RWMutex
+	stopSignal         chan struct{}
+	statsChan          chan []*statsUpdate
+	processorWaitGroup sync.WaitGroup
+}
+
+// Stats_Start initializes and begins stats collection. Must be called once, when the
+// application starts.
+func Stats_Start() {
+	if allStats.stopSignal != nil {
+		return
+	}
+
+	allStats.serverIDtoStats = make(map[string]*serverStats)
+	allStats.stopSignal = make(chan struct{})
+	allStats.statsChan = make(chan []*statsUpdate, _CHANNEL_CAPACITY)
+
+	allStats.processorWaitGroup.Add(1)
+	go processStats()
+}
+
+// Stats_Stop ends stats collection. Must be called once, before the application
+// terminates.
+func Stats_Stop() {
+	if allStats.stopSignal != nil {
+		close(allStats.stopSignal)
+		allStats.processorWaitGroup.Wait()
+		allStats.stopSignal = nil
+	}
+}
+
+// Instances of statsUpdate will be sent through the connection-to-collector
+// channel.
+type statsUpdate struct {
+	serverID         string
+	hostname         string
+	numBytesSent     int64
+	numBytesReceived int64
+}
+
+// recordStat makes sure the given stats update is added to the global
+// collection. Guaranteed to not block.
+// Callers of this function should assume that it "takes control" of the
+// statsUpdate object.
+func recordStat(newStat *statsUpdate) {
+	statSlice := []*statsUpdate{newStat}
+	// Priority: Don't block connections when updating stats. We can't just
+	// write to the statsChan, since that will block if it's full. We could
+	// launch a goroutine for each update, but that seems like unnecessary
+	// overhead. So we'll try to write to the channel, and launch a goroutine if it
+	// fails.
+	select {
+	case allStats.statsChan <- statSlice:
+	default:
+		go func() {
+			allStats.statsChan <- statSlice
+		}()
+	}
+}
+
+// processStats is a goroutine started by Stats_Start() and runs until Stats_Stop(). It collects
+// stats provided by StatsConn.
+func processStats() {
+	defer allStats.processorWaitGroup.Done()
+
+	for {
+		select {
+		case statSlice := <-allStats.statsChan:
+			allStats.statsMutex.Lock()
+
+			for _, stat := range statSlice {
+				if stat.hostname == "" {
+					stat.hostname = "(OTHER)"
+				}
+
+				storedServerStats := allStats.serverIDtoStats[stat.serverID]
+				if storedServerStats == nil {
+					storedServerStats = newServerStats()
+					allStats.serverIDtoStats[stat.serverID] = storedServerStats
+				}
+
+				storedHostStats := storedServerStats.hostnameToStats[stat.hostname]
+				if storedHostStats == nil {
+					storedHostStats = newHostStats()
+					storedServerStats.hostnameToStats[stat.hostname] = storedHostStats
+				}
+
+				storedHostStats.numBytesSent += stat.numBytesSent
+				storedHostStats.numBytesReceived += stat.numBytesReceived
+
+				//fmt.Println("server:", stat.serverID, "host:", stat.hostname, "sent:", storedHostStats.numBytesSent, "received:", storedHostStats.numBytesReceived)
+			}
+
+			allStats.statsMutex.Unlock()
+
+		default:
+			// Note that we are only checking the stopSignal in the default case. This is
+			// because we don't want the statsChan to fill and block the connections
+			// sending to it. The connections have their own signals, so they will
+			// stop themselves, we will drain the channel, and then we will stop.
+			select {
+			case <-allStats.stopSignal:
+				return
+			default:
+			}
+		}
+	}
+}
+
+// NextSendPeriod returns the amount of time that should be waited before the
+// next time stats are sent.
+func NextSendPeriod() (duration time.Duration) {
+	defaultStatsSendDuration := 5 * 60 * 1000 // 5 minutes in millis
+
+	// We include a random component to make the stats send less fingerprintable.
+	jitter, err := MakeSecureRandomInt(defaultStatsSendDuration)
+
+	// In case of error we're just going to use zero jitter.
+	if err != nil {
+		Notice(NOTICE_ALERT, "stats.NextSendPeriod: MakeSecureRandomInt failed")
+	}
+
+	duration = time.Duration(defaultStatsSendDuration+jitter) * time.Millisecond
+	return
+}
+
+// Implement the json.Marshaler interface
+func (ss serverStats) MarshalJSON() ([]byte, error) {
+	out := make(map[string]interface{})
+
+	// Add a random amount of padding to help prevent stats updates from being
+	// a predictable size (which often happens when the connection is quiet).
+	var padding []byte
+	paddingSize, err := MakeSecureRandomInt(256)
+	// In case of randomness fail, we're going to proceed with zero padding.
+	// TODO: Is this okay?
+	if err != nil {
+		Notice(NOTICE_ALERT, "stats.serverStats.MarshalJSON: MakeSecureRandomInt failed")
+		padding = make([]byte, 0)
+	} else {
+		padding, err = MakeSecureRandomBytes(paddingSize)
+		if err != nil {
+			Notice(NOTICE_ALERT, "stats.serverStats.MarshalJSON: MakeSecureRandomBytes failed")
+			padding = make([]byte, 0)
+		}
+	}
+
+	hostBytes := make(map[string]int64)
+	bytesTransferred := int64(0)
+
+	for hostname, hostStats := range ss.hostnameToStats {
+		totalBytes := hostStats.numBytesReceived + hostStats.numBytesSent
+		bytesTransferred += totalBytes
+		hostBytes[hostname] = totalBytes
+	}
+
+	out["bytes_transferred"] = bytesTransferred
+	out["host_bytes"] = hostBytes
+
+	// Print the notice before adding the padding, since it's not interesting
+	noticeJSON, _ := json.Marshal(out)
+	Notice(NOTICE_INFO, "sending stats: %s", noticeJSON)
+
+	out["padding"] = base64.StdEncoding.EncodeToString(padding)
+
+	// We're not using these fields, but the server requires them
+	out["page_views"] = make([]string, 0)
+	out["https_requests"] = make([]string, 0)
+
+	return json.Marshal(out)
+}
+
+// GetForServer returns the json-able stats package for the given server.
+// If there are no stats, nil will be returned.
+func GetForServer(serverID string) (payload *serverStats) {
+	allStats.statsMutex.Lock()
+	defer allStats.statsMutex.Unlock()
+
+	payload = allStats.serverIDtoStats[serverID]
+	delete(allStats.serverIDtoStats, serverID)
+	return
+}
+
+// PutBack re-adds a set of server stats to the collection.
+func PutBack(serverID string, ss *serverStats) {
+	statSlice := make([]*statsUpdate, 0, len(ss.hostnameToStats))
+	for hostname, hoststats := range ss.hostnameToStats {
+		statSlice = append(statSlice, &statsUpdate{
+			serverID:         serverID,
+			hostname:         hostname,
+			numBytesSent:     hoststats.numBytesSent,
+			numBytesReceived: hoststats.numBytesReceived,
+		})
+	}
+
+	allStats.statsChan <- statSlice
+}

+ 109 - 0
psiphon/stats_conn.go

@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+// Package stats counts and keeps track of session stats. These are per-domain
+// bytes transferred and total bytes transferred.
+package psiphon
+
+/*
+Assumption: The same connection will not be used to access different hostnames
+	(even if, say, those hostnames map to the same server). If this does occur, we
+	will mis-attribute some bytes.
+Assumption: Enough of the first HTTP request will be present in the first Write() call
+	for us to a) recognize that it is HTTP, and b) parse the hostname.
+		- If this turns out to not be generally true we will need to add buffering.
+*/
+
+import (
+	"bufio"
+	"bytes"
+	"net"
+	"net/http"
+)
+
+// StatsConn is to be used as an intermediate link in a chain of net.Conn objects.
+// It inspects requests and responses and derives stats from them.
+type StatsConn struct {
+	net.Conn
+	serverID   string
+	hostname   string
+	firstWrite bool
+	regexps    *Regexps
+}
+
+// NewStatsConn creates a StatsConn. serverID can be anything that uniquely
+// identifies the server; it will be passed to GetForServer() when retrieving
+// the accumulated stats.
+func NewStatsConn(nextConn net.Conn, serverID string, regexps *Regexps) *StatsConn {
+	return &StatsConn{
+		Conn:       nextConn,
+		serverID:   serverID,
+		firstWrite: true,
+		regexps:    regexps,
+	}
+}
+
+// Write is called when requests are being written out through the tunnel to
+// the remote server.
+func (conn *StatsConn) Write(buffer []byte) (n int, err error) {
+	// First pass the data down the chain.
+	n, err = conn.Conn.Write(buffer)
+
+	// Count stats before we check the error condition. It could happen that the
+	// buffer was partially written and then an error occurred.
+	if n > 0 {
+		// If this is the first request, try to determine the hostname to associate
+		// with this connection.
+		if conn.firstWrite {
+			conn.firstWrite = false
+
+			// Check if this is a HTTP request
+			bufferReader := bufio.NewReader(bytes.NewReader(buffer))
+			httpReq, httpErr := http.ReadRequest(bufferReader)
+			if httpErr == nil {
+				// Get the hostname value that will be stored in stats by
+				// regexing the real hostname.
+				conn.hostname = regexHostname(httpReq.Host, conn.regexps)
+			}
+		}
+
+		recordStat(&statsUpdate{
+			conn.serverID,
+			conn.hostname,
+			int64(n),
+			0})
+	}
+
+	return
+}
+
+// Read is called when responses to requests are being read from the remote server.
+func (conn *StatsConn) Read(buffer []byte) (n int, err error) {
+	n, err = conn.Conn.Read(buffer)
+
+	// Count bytes without checking the error condition. It could happen that the
+	// buffer was partially read and then an error occurred.
+	recordStat(&statsUpdate{
+		conn.serverID,
+		conn.hostname,
+		0,
+		int64(n)})
+
+	return
+}

+ 75 - 0
psiphon/stats_regexp.go

@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import "regexp"
+
+type regexpReplace struct {
+	regexp  *regexp.Regexp
+	replace string
+}
+
+// Regexps holds the regular expressions and replacement strings used for
+// transforming URLs and hostnames into stats-appropriate forms.
+type Regexps []regexpReplace
+
+// MakeRegexps takes the raw string-map form of the regex-replace pairs
+// returned by the server handshake and turns them into a usable object.
+func MakeRegexps(pageViewRegexes, httpsRequestRegexes []map[string]string) *Regexps {
+	regexps := make(Regexps, 0)
+
+	// We aren't doing page view stats anymore, so we won't process those regexps.
+	for _, rr := range httpsRequestRegexes {
+		regexString := rr["regex"]
+		if regexString == "" {
+			Notice(NOTICE_ALERT, "MakeRegexps: empty regex")
+			continue
+		}
+
+		replace := rr["replace"]
+		if replace == "" {
+			Notice(NOTICE_ALERT, "MakeRegexps: empty replace")
+			continue
+		}
+
+		regex, err := regexp.Compile(regexString)
+		if err != nil {
+			Notice(NOTICE_ALERT, "MakeRegexps: failed to compile regex: %s: %s", regexString, err)
+			continue
+		}
+
+		regexps = append(regexps, regexpReplace{regex, replace})
+	}
+
+	return &regexps
+}
+
+// regexHostname processes hostname through the given regexps and returns the
+// string that should be used for stats.
+func regexHostname(hostname string, regexps *Regexps) (statsHostname string) {
+	statsHostname = "(OTHER)"
+	for _, rr := range *regexps {
+		if rr.regexp.MatchString(hostname) {
+			statsHostname = rr.regexp.ReplaceAllString(hostname, rr.replace)
+			break
+		}
+	}
+	return
+}

+ 258 - 0
psiphon/stats_test.go

@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"encoding/json"
+	"errors"
+	"net"
+	"net/http"
+	"testing"
+	"time"
+
+	mapset "github.com/deckarep/golang-set"
+	"github.com/stretchr/testify/suite"
+)
+
+// _SERVER_ID is the server identifier that the tests pass to makeStatsDialer
+// and GetForServer/PutBack.
+var _SERVER_ID = "myserverid"
+
+// StatsTestSuite is the testify suite for the stats code. httpClient is
+// rebuilt in SetupTest with a transport whose Dial comes from makeStatsDialer.
+type StatsTestSuite struct {
+	suite.Suite
+	httpClient *http.Client
+}
+
+// TestStatsTestSuite is the `go test` entry point that runs StatsTestSuite.
+func TestStatsTestSuite(t *testing.T) {
+	suite.Run(t, &StatsTestSuite{})
+}
+
+// SetupTest starts stats collection and installs an HTTP client whose
+// connections are created by makeStatsDialer with an empty Regexps set.
+func (suite *StatsTestSuite) SetupTest() {
+	Stats_Start()
+
+	noRegexps := Regexps{}
+	transport := &http.Transport{Dial: makeStatsDialer(_SERVER_ID, &noRegexps)}
+	suite.httpClient = &http.Client{Transport: transport}
+}
+
+// TearDownTest drops the HTTP client created by SetupTest and then stops
+// stats collection.
+func (suite *StatsTestSuite) TearDownTest() {
+	suite.httpClient = nil
+	Stats_Stop()
+}
+
+func makeStatsDialer(serverID string, regexps *Regexps) func(network, addr string) (conn net.Conn, err error) {
+	return func(network, addr string) (conn net.Conn, err error) {
+		var subConn net.Conn
+
+		switch network {
+		case "tcp", "tcp4", "tcp6":
+			tcpAddr, err := net.ResolveTCPAddr(network, addr)
+			if err != nil {
+				return nil, err
+			}
+			subConn, err = net.DialTCP(network, nil, tcpAddr)
+			if err != nil {
+				return nil, err
+			}
+		default:
+			err = errors.New("using an unsupported testing network type")
+			return
+		}
+
+		conn = NewStatsConn(subConn, serverID, regexps)
+		err = nil
+		return
+	}
+}
+
+// Test_StartStop calls Stats_Start and Stats_Stop repeatedly and in unbalanced
+// order. It makes no assertions; the test passes if nothing panics.
+func (suite *StatsTestSuite) Test_StartStop() {
+	// Make sure Start and Stop calls don't crash
+	Stats_Start()
+	Stats_Start()
+	Stats_Stop()
+	Stats_Stop()
+	Stats_Start()
+	Stats_Stop()
+}
+
+// Test_NextSendPeriod checks that NextSendPeriod yields a positive duration
+// and that two consecutive calls do not return the same value.
+func (suite *StatsTestSuite) Test_NextSendPeriod() {
+	first := NextSendPeriod()
+	isPositive := first > time.Duration(0)
+	suite.True(isPositive, "duration should not be zero")
+
+	second := NextSendPeriod()
+	suite.NotEqual(first, second, "duration should have randomness difference between calls")
+}
+
+// Test_StatsConn checks that plain HTTP requests succeed when the client's
+// connections come from the dialer installed in SetupTest.
+func (suite *StatsTestSuite) Test_StatsConn() {
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	// The assertion does not abort the test, and resp is nil when the request
+	// failed, so only close the body when the assertion held.
+	if suite.Nil(err, "basic HTTP requests should succeed (1)") {
+		resp.Body.Close()
+	}
+
+	resp, err = suite.httpClient.Get("http://example.org/index.html")
+	if suite.Nil(err, "basic HTTP requests should succeed (2)") {
+		resp.Body.Close()
+	}
+}
+
+// Test_GetForServer checks that GetForServer returns nil before any traffic
+// and for an unknown server ID, returns a JSON-serializable payload after
+// traffic, and returns nil again once that payload has been retrieved.
+func (suite *StatsTestSuite) Test_GetForServer() {
+	payload := GetForServer(_SERVER_ID)
+	suite.Nil(payload, "should get nil stats before any traffic (but not crash)")
+
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	// resp is nil when the request failed; stop here instead of dereferencing
+	// it, since the rest of the test needs traffic to have gone through.
+	if !suite.Nil(err, "need successful http to proceed with tests") {
+		return
+	}
+	resp.Body.Close()
+
+	// Make sure there aren't stats returned for a bad server ID
+	payload = GetForServer("INVALID")
+	suite.Nil(payload, "should get nil stats for invalid server ID")
+
+	payload = GetForServer(_SERVER_ID)
+	suite.NotNil(payload, "should receive valid payload for valid server ID")
+
+	// Check the Marshal error too; previously it was overwritten unchecked.
+	payloadJSON, err := json.Marshal(payload)
+	suite.Nil(err, "payload should marshal to JSON successfully")
+	var parsedJSON interface{}
+	err = json.Unmarshal(payloadJSON, &parsedJSON)
+	suite.Nil(err, "payload JSON should parse successfully")
+
+	// After we retrieve the stats for a server, they should be cleared out of the tracked stats
+	payload = GetForServer(_SERVER_ID)
+	suite.Nil(payload, "after retrieving stats for a server, there should be no more stats (until more data goes through)")
+}
+
+// Test_PutBack checks that a payload taken with GetForServer and handed back
+// with PutBack is returned unchanged by the next GetForServer call.
+func (suite *StatsTestSuite) Test_PutBack() {
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	// resp is nil when the request failed; stop here instead of dereferencing
+	// it, since the rest of the test needs traffic to have gone through.
+	if !suite.Nil(err, "need successful http to proceed with tests") {
+		return
+	}
+	resp.Body.Close()
+
+	payloadToPutBack := GetForServer(_SERVER_ID)
+	suite.NotNil(payloadToPutBack, "should receive valid payload for valid server ID")
+
+	payload := GetForServer(_SERVER_ID)
+	suite.Nil(payload, "should not be any remaining stats after getting them")
+
+	PutBack(_SERVER_ID, payloadToPutBack)
+	// PutBack is asynchronous, so we'll need to wait a moment for it to do its thing
+	<-time.After(100 * time.Millisecond)
+
+	payload = GetForServer(_SERVER_ID)
+	suite.NotNil(payload, "stats should be re-added after putting back")
+	// Expected value first, actual second, so a failure report labels them correctly.
+	suite.Equal(payloadToPutBack, payload, "stats should be the same as after the first retrieval")
+}
+
+func (suite *StatsTestSuite) Test_MakeRegexps() {
+	pageViewRegexes := []map[string]string{make(map[string]string)}
+	pageViewRegexes[0]["regex"] = `(^http://[a-z0-9\.]*\.example\.[a-z\.]*)/.*`
+	pageViewRegexes[0]["replace"] = "$1"
+
+	httpsRequestRegexes := []map[string]string{make(map[string]string), make(map[string]string)}
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com)$`
+	httpsRequestRegexes[0]["replace"] = "$1"
+	httpsRequestRegexes[1]["regex"] = `^.*example\.org$`
+	httpsRequestRegexes[1]["replace"] = "replacement"
+
+	regexps := MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 2, "should only have processed httpsRequestRegexes")
+
+	//
+	// Test some bad regexps
+	//
+
+	httpsRequestRegexes[0]["regex"] = ""
+	httpsRequestRegexes[0]["replace"] = "$1"
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com)$`
+	httpsRequestRegexes[0]["replace"] = ""
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com$` // missing closing paren
+	httpsRequestRegexes[0]["replace"] = "$1"
+	regexps = MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+	suite.NotNil(regexps, "should return a valid object")
+	suite.Len(*regexps, 1, "should have discarded one regexp")
+}
+
+// Test_Regex sends requests through a client whose dialer uses real regexps
+// and checks that the hostnames recorded in the stats payload are the
+// post-regex names: "(OTHER)" for the unmatched host and the replacement
+// output for each matched host.
+func (suite *StatsTestSuite) Test_Regex() {
+	// We'll make a new client with actual regexps.
+	pageViewRegexes := make([]map[string]string, 0)
+	httpsRequestRegexes := []map[string]string{make(map[string]string), make(map[string]string)}
+	httpsRequestRegexes[0]["regex"] = `^[a-z0-9\.]*\.(example\.com)$`
+	httpsRequestRegexes[0]["replace"] = "$1"
+	httpsRequestRegexes[1]["regex"] = `^.*example\.org$`
+	httpsRequestRegexes[1]["replace"] = "replacement"
+	regexps := MakeRegexps(pageViewRegexes, httpsRequestRegexes)
+
+	suite.httpClient = &http.Client{
+		Transport: &http.Transport{
+			Dial: makeStatsDialer(_SERVER_ID, regexps),
+		},
+	}
+
+	// The assertions do not abort the test and resp is nil when a request
+	// failed, so bail out on failure instead of dereferencing resp.
+
+	// No subdomain, so won't match regex
+	resp, err := suite.httpClient.Get("http://example.com/index.html")
+	if !suite.Nil(err) {
+		return
+	}
+	resp.Body.Close()
+
+	// Will match the first regex
+	resp, err = suite.httpClient.Get("http://www.example.com/index.html")
+	if !suite.Nil(err) {
+		return
+	}
+	resp.Body.Close()
+
+	// Will match the second regex
+	resp, err = suite.httpClient.Get("http://example.org/index.html")
+	if !suite.Nil(err) {
+		return
+	}
+	resp.Body.Close()
+
+	payload := GetForServer(_SERVER_ID)
+	// payload.hostnameToStats is read below; stop if there is no payload.
+	if !suite.NotNil(payload, "should get stats because we made HTTP reqs") {
+		return
+	}
+
+	expectedHostnames := mapset.NewSet()
+	expectedHostnames.Add("(OTHER)")
+	expectedHostnames.Add("example.com")
+	expectedHostnames.Add("replacement")
+
+	hostnames := make([]interface{}, 0)
+	for hostname := range payload.hostnameToStats {
+		hostnames = append(hostnames, hostname)
+	}
+
+	actualHostnames := mapset.NewSetFromSlice(hostnames)
+
+	suite.Equal(expectedHostnames, actualHostnames, "post-regex hostnames should be processed as expected")
+}
+
+// Test_recordStat calls recordStat _CHANNEL_CAPACITY*2 times while holding
+// allStats.statsMutex. It makes no assertions; the test passes if the calls
+// return without panicking.
+func (suite *StatsTestSuite) Test_recordStat() {
+	// The normal operation of this function will get exercised during the
+	// other tests, but there is a code branch that only gets hit when the
+	// allStats.statsChan is filled. To make sure we fill the channel, we will
+	// lock the stats access mutex, try to record a bunch of stats, and then
+	// release it.
+	allStats.statsMutex.Lock()
+	stat := statsUpdate{"test", "test", 1, 1}
+	// Twice the channel capacity, so the channel is certain to be full for
+	// at least half of the calls.
+	for i := 0; i < _CHANNEL_CAPACITY*2; i++ {
+		recordStat(&stat)
+	}
+	allStats.statsMutex.Unlock()
+}

+ 19 - 1
psiphon/tunnel.go

@@ -21,7 +21,6 @@ package psiphon
 
 import (
 	"bytes"
-	"code.google.com/p/go.crypto/ssh"
 	"encoding/base64"
 	"encoding/json"
 	"errors"
@@ -30,6 +29,8 @@ import (
 	"strings"
 	"sync/atomic"
 	"time"
+
+	"code.google.com/p/go.crypto/ssh"
 )
 
 // Tunneler specifies the interface required by components that use a tunnel.
@@ -69,6 +70,7 @@ type Tunnel struct {
 	sshKeepAliveQuit        chan struct{}
 	portForwardFailures     chan int
 	portForwardFailureTotal int
+	regexps                 *Regexps
 }
 
 // EstablishTunnel first makes a network transport connection to the
@@ -282,3 +284,19 @@ func (tunnel *Tunnel) SignalFailure() {
 	Notice(NOTICE_ALERT, "tunnel received failure signal")
 	tunnel.Close()
 }
+
+// ServerID provides a unique identifier for the server the tunnel connects to.
+// This ID is consistent between multiple tunnels connected to that server.
+// The value returned is the IpAddress of the tunnel's server entry.
+func (tunnel *Tunnel) ServerID() string {
+	return tunnel.serverEntry.IpAddress
+}
+
+// StatsRegexps gets the Regexps used for the statistics for this tunnel.
+// It returns the stored pointer as-is, which is whatever was last passed to
+// SetStatsRegexps.
+func (tunnel *Tunnel) StatsRegexps() *Regexps {
+	return tunnel.regexps
+}
+
+// SetStatsRegexps sets the Regexps used for the statistics for this tunnel.
+// The pointer is stored directly, without copying the underlying Regexps.
+func (tunnel *Tunnel) SetStatsRegexps(regexps *Regexps) {
+	tunnel.regexps = regexps
+}