Просмотр исходного кода

Fix server entry tests

- Restore ValidateServerEntry, since "protocol" test cases
  exercise the IP address validation
- Add streaming test case
Rod Hynes 8 лет назад
Родитель
Сommit
ab17175189

+ 42 - 13
psiphon/common/protocol/serverEntry.go

@@ -27,6 +27,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"strings"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -188,8 +189,22 @@ func DecodeServerEntry(
 	return serverEntry, nil
 }
 
+// ValidateServerEntry checks for malformed server entries.
+// Currently, it checks for a valid ipAddress. This is important since
+// the IP address is the key used to store/lookup the server entry.
+// TODO: validate more fields?
+func ValidateServerEntry(serverEntry *ServerEntry) error {
+	ipAddr := net.ParseIP(serverEntry.IpAddress)
+	if ipAddr == nil {
+		errMsg := fmt.Sprintf("server entry has invalid ipAddress: '%s'", serverEntry.IpAddress)
+		return common.ContextError(errors.New(errMsg))
+	}
+	return nil
+}
+
 // DecodeServerEntryList extracts server entries from the list encoding
 // used by remote server lists and Psiphon server handshake requests.
+// Each server entry is validated and invalid entries are skipped.
 // See DecodeServerEntry for note on serverEntrySource/timestamp.
 func DecodeServerEntryList(
 	encodedServerEntryList, timestamp,
@@ -207,6 +222,12 @@ func DecodeServerEntryList(
 			return nil, common.ContextError(err)
 		}
 
+		if ValidateServerEntry(serverEntry) != nil {
+			// Skip this entry and continue with the next one
+			// TODO: invoke a logging callback
+			continue
+		}
+
 		serverEntries = append(serverEntries, serverEntry)
 	}
 	return serverEntries, nil
@@ -232,8 +253,8 @@ func NewStreamingServerEntryDecoder(
 	}
 }
 
-// Next reads and decodes the next server entry from the input stream,
-// returning a nil server entry when the stream is complete.
+// Next reads and decodes, and validates the next server entry from the
+// input stream, returning a nil server entry when the stream is complete.
 //
 // Limitations:
 // - Each encoded server entry line cannot exceed bufio.MaxScanTokenSize,
@@ -246,18 +267,26 @@ func NewStreamingServerEntryDecoder(
 //
 func (decoder *StreamingServerEntryDecoder) Next() (*ServerEntry, error) {
 
-	if !decoder.scanner.Scan() {
-		return nil, common.ContextError(decoder.scanner.Err())
-	}
+	for {
+		if !decoder.scanner.Scan() {
+			return nil, common.ContextError(decoder.scanner.Err())
+		}
 
-	// TODO: use scanner.Bytes which doesn't allocate, instead of scanner.Text
+		// TODO: use scanner.Bytes which doesn't allocate, instead of scanner.Text
 
-	// TODO: skip this entry and continue if can't decode?
-	serverEntry, err := DecodeServerEntry(
-		decoder.scanner.Text(), decoder.timestamp, decoder.serverEntrySource)
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
+		// TODO: skip this entry and continue if can't decode?
+		serverEntry, err := DecodeServerEntry(
+			decoder.scanner.Text(), decoder.timestamp, decoder.serverEntrySource)
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
 
-	return serverEntry, nil
+		if ValidateServerEntry(serverEntry) != nil {
+			// Skip this entry and continue with the next one
+			// TODO: invoke a logging callback
+			continue
+		}
+
+		return serverEntry, nil
+	}
 }

+ 40 - 7
psiphon/common/protocol/serverEntry_test.go

@@ -20,6 +20,7 @@
 package protocol
 
 import (
+	"bytes"
 	"encoding/hex"
 	"testing"
 
@@ -34,15 +35,15 @@ const (
 	_EXPECTED_IP_ADDRESS                          = `192.168.0.1`
 )
 
-// DecodeAndValidateServerEntryList should return 2 valid decoded entries from the input list of 4
-func TestDecodeAndValidateServerEntryList(t *testing.T) {
+var testEncodedServerEntryList = hex.EncodeToString([]byte(_VALID_NORMAL_SERVER_ENTRY)) + "\n" +
+	hex.EncodeToString([]byte(_VALID_BLANK_LEGACY_SERVER_ENTRY)) + "\n" +
+	hex.EncodeToString([]byte(_INVALID_WINDOWS_REGISTRY_LEGACY_SERVER_ENTRY)) + "\n" +
+	hex.EncodeToString([]byte(_INVALID_MALFORMED_IP_ADDRESS_SERVER_ENTRY))
 
-	testEncodedServerEntryList := hex.EncodeToString([]byte(_VALID_NORMAL_SERVER_ENTRY)) + "\n" +
-		hex.EncodeToString([]byte(_VALID_BLANK_LEGACY_SERVER_ENTRY)) + "\n" +
-		hex.EncodeToString([]byte(_INVALID_WINDOWS_REGISTRY_LEGACY_SERVER_ENTRY)) + "\n" +
-		hex.EncodeToString([]byte(_INVALID_MALFORMED_IP_ADDRESS_SERVER_ENTRY))
+// DecodeServerEntryList should return 2 valid decoded entries from the input list of 4
+func TestDecodeServerEntryList(t *testing.T) {
 
-	serverEntries, err := DecodeAndValidateServerEntryList(
+	serverEntries, err := DecodeServerEntryList(
 		testEncodedServerEntryList, common.GetCurrentTimestamp(), SERVER_ENTRY_SOURCE_EMBEDDED)
 	if err != nil {
 		t.Error(err.Error())
@@ -58,6 +59,38 @@ func TestDecodeAndValidateServerEntryList(t *testing.T) {
 	}
 }
 
+func TestStreamingServerEntryDecoder(t *testing.T) {
+
+	decoder := NewStreamingServerEntryDecoder(
+		bytes.NewReader([]byte(testEncodedServerEntryList)),
+		common.GetCurrentTimestamp(), SERVER_ENTRY_SOURCE_EMBEDDED)
+
+	serverEntries := make([]*ServerEntry, 0)
+
+	for {
+		serverEntry, err := decoder.Next()
+		if err != nil {
+			t.Error(err.Error())
+			t.FailNow()
+		}
+
+		if serverEntry == nil {
+			break
+		}
+
+		serverEntries = append(serverEntries, serverEntry)
+	}
+
+	if len(serverEntries) != 2 {
+		t.Error("unexpected number of valid server entries")
+	}
+	for _, serverEntry := range serverEntries {
+		if serverEntry.IpAddress != _EXPECTED_IP_ADDRESS {
+			t.Error("unexpected IP address in decoded server entry: %s", serverEntry.IpAddress)
+		}
+	}
+}
+
 // Directly call DecodeServerEntry and ValidateServerEntry with invalid inputs
 func TestInvalidServerEntries(t *testing.T) {
 

+ 7 - 7
psiphon/dataStore.go

@@ -25,7 +25,6 @@ import (
 	"errors"
 	"fmt"
 	"math/rand"
-	"net"
 	"os"
 	"path/filepath"
 	"strings"
@@ -187,11 +186,11 @@ func checkInitDataStore() {
 func StoreServerEntry(serverEntry *protocol.ServerEntry, replaceIfExists bool) error {
 	checkInitDataStore()
 
-	ipAddr := net.ParseIP(serverEntry.IpAddress)
-	if ipAddr == nil {
-		NoticeAlert("skip storing server with invalid IP address: %s", serverEntry.IpAddress)
-		// Returns no error so callers such as StoreServerEntries won't abort
-		return nil
+	// Server entries should already be validated before this point,
+	// so instead of skipping we fail with an error.
+	err := protocol.ValidateServerEntry(serverEntry)
+	if err != nil {
+		return common.ContextError(errors.New("invalid server entry"))
 	}
 
 	// BoltDB implementation note:
@@ -202,7 +201,7 @@ func StoreServerEntry(serverEntry *protocol.ServerEntry, replaceIfExists bool) e
 	// values (e.g., many servers support all protocols), performance
 	// is expected to be acceptable.
 
-	err := singleton.db.Update(func(tx *bolt.Tx) error {
+	err = singleton.db.Update(func(tx *bolt.Tx) error {
 
 		serverEntries := tx.Bucket([]byte(serverEntriesBucket))
 
@@ -280,6 +279,7 @@ func StreamingStoreServerEntries(
 	// allocate temporary memory buffers for hex/JSON decoding/encoding,
 	// so this isn't true constant-memory streaming (it depends on garbage
 	// collection).
+
 	for {
 		serverEntry, err := serverEntries.Next()
 		if err != nil {

+ 7 - 0
psiphon/serverApi.go

@@ -198,6 +198,13 @@ func (serverContext *ServerContext) doHandshakeRequest() error {
 			return common.ContextError(err)
 		}
 
+		err = protocol.ValidateServerEntry(serverEntry)
+		if err != nil {
+			// Skip this entry and continue with the next one
+			NoticeAlert("invalid server entry: %s", err)
+			continue
+		}
+
 		decodedServerEntries = append(decodedServerEntries, serverEntry)
 	}