浏览代码

Implemented persistent server entry data store and integrated with fetch remote server list and run tunnel. Fetch remote server list is now performed asynchronously. Run tunnel now repeatedly cycles through the stored server entries.

Rod Hynes 11 年之前
父节点
当前提交
5a37f03b77
共有 7 个文件被更改,包括 352 次插入50 次删除
  1. 14 13
      README.md
  2. 0 1
      psiphon/conn.go
  3. 258 0
      psiphon/dataStore.go
  4. 8 5
      psiphon/defaults.go
  5. 15 11
      psiphon/remoteServerList.go
  6. 48 20
      psiphon/runTunnel.go
  7. 9 0
      psiphon/utils.go

+ 14 - 13
README.md

@@ -13,26 +13,28 @@ Status
 
 This project is currently at the proof-of-concept stage. Current production Psiphon client code is available at our [main repository](https://bitbucket.org/psiphon/psiphon-circumvention-system).
 
-### TODO
+### TODO (proof-of-concept)
+
+* StoreServerEntry must assign top rank - 1
+* use ContextError in more places
+* add Psiphon web requests: handshake/connected/etc.
 * psiphon.Conn for Windows
-* more test cases
+* build/test on Android and iOS
 * integrate meek-client
+* disconnect all local SOCKS clients when tunnel disconnected
+* log levels
+
+### TODO (future)
+
+* add a HTTP proxy (chain to SOCKS)
+* SSH keepalive (+ hook into disconnectedSignal)
+* SSH compression?
 * add config options
   * protocol preference; whether to try multiple protocols for each server
   * region preference
   * platform (for upgrade download)
-* SSH keepalive (+ hook into disconnectedSignal)
-* SSH compression?
-*  local SOCKS
-  * disconnect all local clients when tunnel disconnected
-  * use InterruptableConn?
-* run fetchRemoteServerList in parallel when already have entries
-* add a HTTP proxy (chain to SOCKS)
-* persist server entries
-* add Psiphon web requests: handshake/connected/etc.
 * implement page view stats
 * implement local traffic stats (e.g., to display bytes sent/received)
-* build/test on Android and iOS
 * control interface (w/ event messages)?
 * VpnService compatibility
 * upstream proxy support
@@ -44,7 +46,6 @@ This project is currently at the proof-of-concept stage. Current production Psip
   * server can push preferred/optimized settings; client should use over defaults
   * e.g., establish worker pool size; multiplex tunnel pool size
 
-
 Licensing
 --------------------------------------------------------------------------------
 

+ 0 - 1
psiphon/conn.go

@@ -105,7 +105,6 @@ func Dial(
 // if the connection is already closed (and would never send
 // the signal).
 func (conn *Conn) SetClosedSignal(closedSignal chan bool) (err error) {
-	// TEMP **** needs comments
 	conn.mutex.Lock()
 	defer conn.mutex.Unlock()
 	if conn.isClosed {

+ 258 - 0
psiphon/dataStore.go

@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2014, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package psiphon
+
+import (
+	"database/sql"
+	"encoding/json"
+	"errors"
+	sqlite3 "github.com/mattn/go-sqlite3"
+	"log"
+	"sync"
+	"time"
+)
+
// dataStore wraps the process-wide SQLite handle used to persist
// server entries.
type dataStore struct {
	init sync.Once // guards the one-time open/schema setup in initDataStore
	db   *sql.DB   // shared connection pool; safe for concurrent use
}

// singleton is the single dataStore instance, initialized lazily by
// initDataStore.
var singleton dataStore
+
+// initDataStore initializes the singleton instance of dataStore. This
+// function uses a sync.Once and is safe for use by concurrent goroutines.
+// The underlying sql.DB connection pool is also safe.
+func initDataStore() {
+	singleton.init.Do(func() {
+		const schema = `
+        create table if not exists serverEntry
+            (id text not null primary key,
+             data blob not null,
+             rank integer not null unique);
+        `
+		db, err := sql.Open("sqlite3", DATA_STORE_FILENAME)
+		if err != nil {
+			log.Fatal("initDataStore failed to open database: %s", err)
+		}
+		_, err = db.Exec(schema)
+		if err != nil {
+			log.Fatal("initDataStore failed to initialize schema: %s", err)
+		}
+		singleton.db = db
+	})
+}
+
+// transactionWithRetry will retry a write transaction if sqlite3
+// reports ErrBusy or ErrBusySnapshot -- i.e., if the XXXXX
+func transactionWithRetry(updater func(*sql.Tx) error) error {
+	initDataStore()
+	for i := 0; i < 10; i++ {
+		transaction, err := singleton.db.Begin()
+		if err != nil {
+			return ContextError(err)
+		}
+		err = updater(transaction)
+		if err != nil {
+			transaction.Rollback()
+			if sqlError, ok := err.(sqlite3.Error); ok &&
+				(sqlError.Code == sqlite3.ErrBusy ||
+					sqlError.ExtendedCode == sqlite3.ErrBusySnapshot) {
+				time.Sleep(100)
+				continue
+			}
+			return ContextError(err)
+		}
+		err = transaction.Commit()
+		if err != nil {
+			return ContextError(err)
+		}
+		return nil
+	}
+	return ContextError(errors.New("retries exhausted"))
+}
+
// serverEntryExists returns true if a serverEntry with the
// given ipAddress id already exists. The check is issued on the
// supplied transaction so it observes that transaction's view of
// the table (the previous code queried singleton.db directly,
// bypassing the transaction it was handed).
func serverEntryExists(transaction *sql.Tx, ipAddress string) bool {
	query := "select count(*) from serverEntry where id  = ?;"
	var count int
	err := transaction.QueryRow(query, ipAddress).Scan(&count)
	return err == nil && count > 0
}
+
+// StoreServerEntry adds the server entry to the data store. A newly
+// stored (or re-stored) server entry is assigned the top rank for
+// cycle order. When replaceIfExists is true, an existing server entry
+// record is overwritten; otherwise, the existing record is unchanged.
+// TODO: should be assigned top rank - 1!
+func StoreServerEntry(serverEntry *ServerEntry, replaceIfExists bool) error {
+	insert := "insert or ignore "
+	if replaceIfExists {
+		insert = "insert or replace "
+	}
+	insert += `
+    into serverEntry (id, data, rank)
+    values (?, ?, (select coalesce(max(rank), 0)+1 from serverEntry));
+    `
+	return transactionWithRetry(func(transaction *sql.Tx) error {
+		serverEntryExists := serverEntryExists(transaction, serverEntry.IpAddress)
+		statement, err := transaction.Prepare(insert)
+		if err != nil {
+			return ContextError(err)
+		}
+		defer statement.Close()
+		data, err := json.Marshal(serverEntry)
+		if err != nil {
+			return ContextError(err)
+		}
+		_, err = statement.Exec(serverEntry.IpAddress, data)
+		if err != nil {
+			return ContextError(err)
+		}
+		if !serverEntryExists {
+			// TODO: log after commit
+			log.Printf("stored server %s", serverEntry.IpAddress)
+		}
+		return nil
+	})
+}
+
+// PromoteServerEntry assigns the top cycle rank to the specified
+// server entry. This server entry will be the first candidate in
+// a subsequent tunnel establishment.
+func PromoteServerEntry(ipAddress string) error {
+	update := `
+    update serverEntry
+    set rank = (select MAX(rank)+1 from serverEntry)
+    where id = ?;
+    `
+	return transactionWithRetry(func(transaction *sql.Tx) error {
+		statement, err := transaction.Prepare(update)
+		if err != nil {
+			return ContextError(err)
+		}
+		defer statement.Close()
+		_, err = statement.Exec(ipAddress)
+		if err != nil {
+			return ContextError(err)
+		}
+		return nil
+	})
+}
+
// ServerEntryCycler is used to continuously iterate over
// stored server entries in rank order.
type ServerEntryCycler struct {
	transaction *sql.Tx   // read transaction held open for the duration of one cycle
	cursor      *sql.Rows // cursor over serverEntry rows, highest rank first
	isReset     bool      // true until Next yields an entry after a Reset; breaks the empty-table loop
}
+
+// NewServerEntryCycler creates a new ServerEntryCycler
+func NewServerEntryCycler() (cycler *ServerEntryCycler, err error) {
+	initDataStore()
+	cycler = new(ServerEntryCycler)
+	err = cycler.Reset()
+	if err != nil {
+		return nil, err
+	}
+	return cycler, nil
+}
+
+// Reset a ServerEntryCycler to the start of its cycle. The next
+// call to Next will return the first server entry.
+func (cycler *ServerEntryCycler) Reset() error {
+	cycler.Close()
+	transaction, err := singleton.db.Begin()
+	if err != nil {
+		return ContextError(err)
+	}
+	cursor, err := transaction.Query("select * from serverEntry order by rank desc;")
+	if err != nil {
+		transaction.Rollback()
+		return ContextError(err)
+	}
+	cycler.isReset = true
+	cycler.transaction = transaction
+	cycler.cursor = cursor
+	return nil
+}
+
+// Close cleans up resources associated with a ServerEntryCycler.
+func (cycler *ServerEntryCycler) Close() {
+	if cycler.cursor != nil {
+		cycler.cursor.Close()
+	}
+	cycler.cursor = nil
+	if cycler.transaction != nil {
+		cycler.transaction.Rollback()
+	}
+	cycler.transaction = nil
+}
+
// Next returns the next server entry, by rank, for a ServerEntryCycler. When
// the ServerEntryCycler has worked through all known server entries, Next will
// call Reset and start over and return the first server entry again.
func (cycler *ServerEntryCycler) Next() (serverEntry *ServerEntry, err error) {
	// On any error, release the cursor and read transaction so the
	// cycler does not hold them open indefinitely.
	defer func() {
		if err != nil {
			cycler.Close()
		}
	}()
	// Advance the cursor; when exhausted, Reset and try again. The
	// isReset flag distinguishes "wrapped around" from "table is
	// empty", preventing an infinite Reset loop when there are no rows.
	for !cycler.cursor.Next() {
		err = cycler.cursor.Err()
		if err != nil {
			return nil, ContextError(err)
		}
		if cycler.isReset {
			// A freshly reset cursor with no rows: no entries exist.
			return nil, ContextError(errors.New("no server entries"))
		}
		err = cycler.Reset()
		if err != nil {
			return nil, ContextError(err)
		}
	}
	cycler.isReset = false
	// Scan order must match the "select *" column order: id, data, rank.
	var id string
	var data []byte
	var rank int64
	err = cycler.cursor.Scan(&id, &data, &rank)
	if err != nil {
		return nil, ContextError(err)
	}
	serverEntry = new(ServerEntry)
	err = json.Unmarshal(data, serverEntry)
	if err != nil {
		return nil, ContextError(err)
	}
	return serverEntry, nil
}
+
+// HasServerEntries returns true if the data store contains at
+// least one server entry.
+func HasServerEntries() bool {
+	initDataStore()
+	var count int
+	err := singleton.db.QueryRow("select count(*) from serverEntry;").Scan(&count)
+	if err == nil {
+		log.Printf("stored servers: %d", count)
+	}
+	return err == nil && count > 0
+}

+ 8 - 5
psiphon/defaults.go

@@ -24,9 +24,12 @@ import (
 )
 
 const (
-	FETCH_REMOTE_SERVER_LIST_TIMEOUT = 5 * time.Second
-	CONNECTION_CANDIDATE_TIMEOUT     = 10 * time.Second
-	ESTABLISH_TUNNEL_TIMEOUT         = 60 * time.Second
-	CONNECTION_WORKER_POOL_SIZE      = 10
-	TCP_KEEP_ALIVE_PERIOD_SECONDS    = 60
+	DATA_STORE_FILENAME                    = "psiphon.db"
+	FETCH_REMOTE_SERVER_LIST_TIMEOUT       = 5 * time.Second
+	CONNECTION_CANDIDATE_TIMEOUT           = 10 * time.Second
+	ESTABLISH_TUNNEL_TIMEOUT               = 60 * time.Second
+	CONNECTION_WORKER_POOL_SIZE            = 10
+	TCP_KEEP_ALIVE_PERIOD_SECONDS          = 60
+	FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT = 5 * time.Second
+	FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT = 6 * time.Hour
 )

+ 15 - 11
psiphon/remoteServerList.go

@@ -29,6 +29,7 @@ import (
 	"encoding/json"
 	"errors"
 	"io/ioutil"
+	"log"
 	"net/http"
 	"strings"
 )
@@ -46,47 +47,50 @@ type RemoteServerList struct {
 // config.RemoteServerListUrl; validates its digital signature using the
 // public key config.RemoteServerListSignaturePublicKey; and parses the
 // data field into ServerEntry records.
-func FetchRemoteServerList(config *Config) (serverList []*ServerEntry, err error) {
+func FetchRemoteServerList(config *Config) (err error) {
+	log.Printf("fetching remote server list")
 	httpClient := http.Client{
 		Timeout: FETCH_REMOTE_SERVER_LIST_TIMEOUT,
 	}
 	response, err := httpClient.Get(config.RemoteServerListUrl)
 	if err != nil {
-		return nil, err
+		return err
 	}
 	defer response.Body.Close()
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
-		return nil, err
+		return err
 	}
 	var remoteServerList *RemoteServerList
 	err = json.Unmarshal(body, &remoteServerList)
 	if err != nil {
-		return nil, err
+		return err
 	}
 	err = validateRemoteServerList(config, remoteServerList)
 	if err != nil {
-		return nil, err
+		return err
 	}
-	serverList = make([]*ServerEntry, 0)
 	for _, hexEncodedServerListItem := range strings.Split(remoteServerList.Data, "\n") {
 		decodedServerListItem, err := hex.DecodeString(hexEncodedServerListItem)
 		if err != nil {
-			return nil, err
+			return err
 		}
 		// Skip past legacy format (4 space delimited fields) and just parse the JSON config
 		fields := strings.SplitN(string(decodedServerListItem), " ", 5)
 		if len(fields) != 5 {
-			return nil, errors.New("invalid remote server list item")
+			return errors.New("invalid remote server list item")
 		}
 		var serverEntry ServerEntry
 		err = json.Unmarshal([]byte(fields[4]), &serverEntry)
 		if err != nil {
-			return nil, err
+			return err
+		}
+		err = StoreServerEntry(&serverEntry, true)
+		if err != nil {
+			return err
 		}
-		serverList = append(serverList, &serverEntry)
 	}
-	return serverList, nil
+	return nil
 }
 
 func validateRemoteServerList(config *Config, remoteServerList *RemoteServerList) (err error) {

+ 48 - 20
psiphon/runTunnel.go

@@ -57,11 +57,21 @@ func establishTunnelWorker(
 			log.Printf("failed to connect to %s: %s", serverEntry.IpAddress, err)
 		} else {
 			log.Printf("successfully connected to %s", serverEntry.IpAddress)
-			establishedTunnels <- tunnel
+			select {
+			case establishedTunnels <- tunnel:
+			default:
+				discardTunnel(tunnel)
+			}
 		}
 	}
 }
 
+func discardTunnel(tunnel *Tunnel) {
+	log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
+	PromoteServerEntry(tunnel.serverEntry.IpAddress)
+	tunnel.Close()
+}
+
 // runTunnel establishes a tunnel session and runs local proxies that make use of
 // that tunnel. The tunnel connection is monitored and this function returns an
 // error when the tunnel unexpectedly disconnects.
@@ -70,17 +80,11 @@ func establishTunnelWorker(
 // connections in parallel, and this process is stopped once the first tunnel
 // is established.
 func runTunnel(config *Config) error {
-	log.Printf("fetching remote server list")
-	// TODO: fetch in parallel goroutine (if have local server entries)
-	serverList, err := FetchRemoteServerList(config)
-	if err != nil {
-		return fmt.Errorf("failed to fetch remote server list: %s", err)
-	}
 	log.Printf("establishing tunnel")
 	waitGroup := new(sync.WaitGroup)
 	candidateServerEntries := make(chan *ServerEntry)
 	pendingConns := new(PendingConns)
-	establishedTunnels := make(chan *Tunnel, len(serverList))
+	establishedTunnels := make(chan *Tunnel, 1)
 	timeout := time.After(ESTABLISH_TUNNEL_TIMEOUT)
 	broadcastStopWorkers := make(chan bool)
 	for i := 0; i < CONNECTION_WORKER_POOL_SIZE; i++ {
@@ -89,21 +93,25 @@ func runTunnel(config *Config) error {
 			waitGroup, candidateServerEntries, broadcastStopWorkers,
 			pendingConns, establishedTunnels)
 	}
+	// TODO: add a throttle after each full cycle?
+	// Note: errors fall through to ensure worker and channel cleanup
 	var selectedTunnel *Tunnel
-	for _, serverEntry := range serverList {
+	cycler, err := NewServerEntryCycler()
+	for selectedTunnel == nil && err == nil {
+		serverEntry, err := cycler.Next()
+		if err != nil {
+			break
+		}
 		select {
 		case candidateServerEntries <- serverEntry:
 		case selectedTunnel = <-establishedTunnels:
 			defer selectedTunnel.Close()
 			log.Printf("selected connection to %s", selectedTunnel.serverEntry.IpAddress)
 		case <-timeout:
-			return errors.New("timeout establishing tunnel")
-		}
-		if selectedTunnel != nil {
-			break
+			err = errors.New("timeout establishing tunnel")
 		}
 	}
-	log.Printf("tunnel established")
+	cycler.Close()
 	close(candidateServerEntries)
 	close(broadcastStopWorkers)
 	// Interrupt any partial connections in progress, so that
@@ -113,14 +121,19 @@ func runTunnel(config *Config) error {
 	// Drain any excess tunnels
 	close(establishedTunnels)
 	for tunnel := range establishedTunnels {
-		log.Printf("discard connection to %s", tunnel.serverEntry.IpAddress)
-		tunnel.Close()
+		discardTunnel(tunnel)
+	}
+	// Note: end of error fall through
+	if err != nil {
+		return fmt.Errorf("failed to establish tunnel: %s", err)
 	}
 	// Don't hold references to candidates while running tunnel
 	candidateServerEntries = nil
 	pendingConns = nil
 	// TODO: can start SOCKS before synchronizing work group
 	if selectedTunnel != nil {
+		log.Printf("tunnel established")
+		PromoteServerEntry(selectedTunnel.serverEntry.IpAddress)
 		stopTunnelSignal := make(chan bool)
 		err = selectedTunnel.conn.SetClosedSignal(stopTunnelSignal)
 		if err != nil {
@@ -139,7 +152,7 @@ func runTunnel(config *Config) error {
 		log.Printf("monitoring tunnel")
 		<-stopTunnelSignal
 	}
-	return nil
+	return err
 }
 
 // RunTunnelForever executes the main loop of the Psiphon client. It establishes
@@ -156,10 +169,25 @@ func RunTunnelForever(config *Config) {
 		// TODO
 		//log.SetOutput(ioutil.Discard)
 	}
+	// TODO: unlike existing Psiphon clients, this code
+	// always makes the fetch remote server list request
+	go func() {
+		for {
+			err := FetchRemoteServerList(config)
+			if err != nil {
+				log.Printf("failed to fetch remote server list: %s", err)
+				time.Sleep(FETCH_REMOTE_SERVER_LIST_RETRY_TIMEOUT)
+			} else {
+				time.Sleep(FETCH_REMOTE_SERVER_LIST_STALE_TIMEOUT)
+			}
+		}
+	}()
 	for {
-		err := runTunnel(config)
-		if err != nil {
-			log.Printf("error: %s", err)
+		if HasServerEntries() {
+			err := runTunnel(config)
+			if err != nil {
+				log.Printf("run tunnel error: %s", err)
+			}
 		}
 		time.Sleep(1 * time.Second)
 	}

+ 9 - 0
psiphon/utils.go

@@ -22,7 +22,9 @@ package psiphon
 import (
 	"crypto/rand"
 	"errors"
+	"fmt"
 	"math/big"
+	"runtime"
 )
 
 // IsSignalled returns true when the signal channel yields
@@ -71,3 +73,10 @@ func MakeSecureRandomBytes(length int) ([]byte, error) {
 	}
 	return randomBytes, nil
 }
+
+// ContextError prefixes an error message with the current function name
+func ContextError(err error) error {
+	pc, _, _, _ := runtime.Caller(1)
+	funcName := runtime.FuncForPC(pc).Name()
+	return fmt.Errorf("%s: %s", funcName, err)
+}