Просмотр исходного кода

Merge branch 'server-side-tactics'

Rod Hynes 7 лет назад
Родитель
Сommit
85bfa38c45

+ 1 - 1
psiphon/common/parameters/clientParameters.go

@@ -418,7 +418,7 @@ func (p *ClientParameters) Set(
 
 			// A JSON remarshal resolves cases where applyParameters is a
 			// result of unmarshal-into-interface, in which case non-scalar
-			// values will not have the expecte types; see:
+			// values will not have the expected types; see:
 			// https://golang.org/pkg/encoding/json/#Unmarshal. This remarshal
 			// also results in a deep copy.
 

+ 148 - 39
psiphon/common/tactics/tactics.go

@@ -162,6 +162,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"sort"
 	"time"
@@ -230,6 +231,10 @@ type Server struct {
 	// RequestObfuscatedKey is the tactics request obfuscation key.
 	RequestObfuscatedKey []byte
 
+	// EnforceServerSide enables server-side enforcement of certain tactics
+	// parameters via Listeners.
+	EnforceServerSide bool
+
 	// DefaultTactics is the baseline tactics for all clients. It must include a
 	// TTL and Probability.
 	DefaultTactics Tactics
@@ -448,6 +453,7 @@ func NewServer(
 			server.RequestPublicKey = newServer.RequestPublicKey
 			server.RequestPrivateKey = newServer.RequestPrivateKey
 			server.RequestObfuscatedKey = newServer.RequestObfuscatedKey
+			server.EnforceServerSide = newServer.EnforceServerSide
 			server.DefaultTactics = newServer.DefaultTactics
 			server.FilteredTactics = newServer.FilteredTactics
 
@@ -608,6 +614,60 @@ func (server *Server) GetTacticsPayload(
 	geoIPData common.GeoIPData,
 	apiParams common.APIParameters) (*Payload, error) {
 
+	tactics, err := server.getTactics(geoIPData, apiParams)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	if tactics == nil {
+		return nil, nil
+	}
+
+	marshaledTactics, err := json.Marshal(tactics)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	// MD5 hash is used solely as a data checksum and not for any security purpose.
+	digest := md5.Sum(marshaledTactics)
+	tag := hex.EncodeToString(digest[:])
+
+	payload := &Payload{
+		Tag: tag,
+	}
+
+	// New clients should always send STORED_TACTICS_TAG_PARAMETER_NAME. When they have no
+	// stored tactics, the stored tag will be "" and not match payload.Tag and payload.Tactics
+	// will be sent.
+	//
+	// When new clients send a stored tag that matches payload.Tag, the client already has
+	// the correct data and payload.Tactics is not sent.
+	//
+	// Old clients will not send STORED_TACTICS_TAG_PARAMETER_NAME. In this case, do not
+	// send payload.Tactics as the client will not use it, will not store it, will not send
+	// back the new tag and so the handshake response will always contain wasteful tactics
+	// data.
+
+	sendPayloadTactics := true
+
+	clientStoredTag, err := getStringRequestParam(apiParams, STORED_TACTICS_TAG_PARAMETER_NAME)
+
+	// Old client or new client with same tag.
+	if err != nil || payload.Tag == clientStoredTag {
+		sendPayloadTactics = false
+	}
+
+	if sendPayloadTactics {
+		payload.Tactics = marshaledTactics
+	}
+
+	return payload, nil
+}
+
+func (server *Server) getTactics(
+	geoIPData common.GeoIPData,
+	apiParams common.APIParameters) (*Tactics, error) {
+
 	server.ReloadableFile.RLock()
 	defer server.ReloadableFile.RUnlock()
 
@@ -703,45 +763,7 @@ func (server *Server) GetTacticsPayload(
 		// Continue to apply more matches. Last matching tactics has priority for any field.
 	}
 
-	marshaledTactics, err := json.Marshal(tactics)
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-
-	// MD5 hash is used solely as a data checksum and not for any security purpose.
-	digest := md5.Sum(marshaledTactics)
-	tag := hex.EncodeToString(digest[:])
-
-	payload := &Payload{
-		Tag: tag,
-	}
-
-	// New clients should always send STORED_TACTICS_TAG_PARAMETER_NAME. When they have no
-	// stored tactics, the stored tag will be "" and not match payload.Tag and payload.Tactics
-	// will be sent.
-	//
-	// When new clients send a stored tag that matches payload.Tag, the client already has
-	// the correct data and payload.Tactics is not sent.
-	//
-	// Old clients will not send STORED_TACTICS_TAG_PARAMETER_NAME. In this case, do not
-	// send payload.Tactics as the client will not use it, will not store it, will not send
-	// back the new tag and so the handshake response will always contain wasteful tactics
-	// data.
-
-	sendPayloadTactics := true
-
-	clientStoredTag, err := getStringRequestParam(apiParams, STORED_TACTICS_TAG_PARAMETER_NAME)
-
-	// Old client or new client with same tag.
-	if err != nil || payload.Tag == clientStoredTag {
-		sendPayloadTactics = false
-	}
-
-	if sendPayloadTactics {
-		payload.Tactics = marshaledTactics
-	}
-
-	return payload, nil
+	return tactics, nil
 }
 
 // TODO: refactor this copy of psiphon/server.getStringRequestParam into common?
@@ -1036,6 +1058,93 @@ func (server *Server) handleTacticsRequest(
 	server.logger.LogMetric(TACTICS_METRIC_EVENT_NAME, logFields)
 }
 
+// Listener wraps a net.Listener and applies server-side enforcement of
+// certain tactics parameters to accepted connections. Tactics filtering is
+// limited to GeoIP attributes as the client has not yet sent API paramaters.
+type Listener struct {
+	net.Listener
+	server         *Server
+	tunnelProtocol string
+	geoIPLookup    func(IPaddress string) common.GeoIPData
+}
+
+// NewListener creates a new Listener.
+func NewListener(
+	listener net.Listener,
+	server *Server,
+	tunnelProtocol string,
+	geoIPLookup func(IPaddress string) common.GeoIPData) *Listener {
+
+	return &Listener{
+		Listener:       listener,
+		server:         server,
+		tunnelProtocol: tunnelProtocol,
+		geoIPLookup:    geoIPLookup,
+	}
+}
+
+// Close calls the underlying listener's Accept, and then
+// checks if tactics for the connection set LimitTunnelProtocols.
+// If LimitTunnelProtocols is set and does not include the
+// tunnel protocol the listener is running, the accepted
+// connection is immediately closed and the underlying
+// Accept is called again.
+func (listener *Listener) Accept() (net.Conn, error) {
+	for {
+
+		conn, err := listener.Listener.Accept()
+		if err != nil {
+			// Don't modify error from net.Listener
+			return nil, err
+		}
+
+		if !listener.server.EnforceServerSide {
+			return conn, nil
+		}
+
+		geoIPData := listener.geoIPLookup(common.IPAddressFromAddr(conn.RemoteAddr()))
+
+		tactics, err := listener.server.getTactics(geoIPData, make(common.APIParameters))
+		if err != nil {
+			listener.server.logger.WithContextFields(
+				common.LogFields{"error": err}).Warning("failed to get tactics for connection")
+			// If tactics is somehow misconfigured, keep handling connections.
+			// Other error cases that follow below take the same approach.
+			return conn, nil
+		}
+
+		if tactics == nil {
+			// This server isn't configured with tactics.
+			return conn, nil
+		}
+
+		limitTunnelProtocolsParameter, ok := tactics.Parameters[parameters.LimitTunnelProtocols]
+		if !ok {
+			// The tactics for the connection don't set LimitTunnelProtocols.
+			return conn, nil
+		}
+
+		if !common.FlipWeightedCoin(tactics.Probability) {
+			// Skip tactics with the configured probability.
+			return conn, nil
+		}
+
+		limitTunnelProtocols, ok := common.GetStringSlice(limitTunnelProtocolsParameter)
+		if !ok ||
+			len(limitTunnelProtocols) == 0 ||
+			common.Contains(limitTunnelProtocols, listener.tunnelProtocol) {
+
+			// The parameter is invalid; or no limit is set; or the
+			// listener protocol is not prohibited.
+			return conn, nil
+		}
+
+		// Don't accept this connection as its tactics prohibits the
+		// listener's tunnel protocol.
+		conn.Close()
+	}
+}
+
 // RoundTripper performs a round trip to the specified endpoint, sending the
 // request body and returning the response body. The context may be used to
 // set a timeout or cancel the rount trip.

+ 106 - 1
psiphon/common/tactics/tactics_test.go

@@ -50,6 +50,7 @@ func TestTactics(t *testing.T) {
       "RequestPublicKey" : "%s",
       "RequestPrivateKey" : "%s",
       "RequestObfuscatedKey" : "%s",
+      "EnforceServerSide" : true,
       "DefaultTactics" : {
         "TTL" : "1s",
         "Probability" : %0.1f,
@@ -97,6 +98,16 @@ func TestTactics(t *testing.T) {
               "ConnectionWorkerPoolSize" : %d
             }
           }
+        },
+        {
+          "Filter" : {
+            "Regions": ["R7"]
+          },
+          "Tactics" : {
+            "Parameters" : {
+              "LimitTunnelProtocols" : ["SSH"]
+            }
+          }
         }
       ]
     }
@@ -116,6 +127,10 @@ func TestTactics(t *testing.T) {
 	tacticsLimitTunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"}
 	jsonTacticsLimitTunnelProtocols, _ := json.Marshal(tacticsLimitTunnelProtocols)
 
+	listenerProtocol := "OSSH"
+	listenerProhibitedGeoIP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R7"} }
+	listenerAllowedGeoIP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R8"} }
+
 	tacticsConfig := fmt.Sprintf(
 		tacticsConfigTemplate,
 		encodedRequestPublicKey,
@@ -680,8 +695,98 @@ func TestTactics(t *testing.T) {
 		t.Fatalf("HandleEndPoint unexpectedly handled request")
 	}
 
-	// TODO: test replay attack defence
+	// Test Listener
 
+	tacticsProbability = 1.0
+
+	tacticsConfig = fmt.Sprintf(
+		tacticsConfigTemplate,
+		"",
+		"",
+		"",
+		tacticsProbability,
+		tacticsNetworkLatencyMultiplier,
+		tacticsConnectionWorkerPoolSize,
+		jsonTacticsLimitTunnelProtocols,
+		tacticsConnectionWorkerPoolSize+1)
+
+	err = ioutil.WriteFile(configFileName, []byte(tacticsConfig), 0600)
+	if err != nil {
+		t.Fatalf("WriteFile failed: %s", err)
+	}
+
+	reloaded, err = server.Reload()
+	if err != nil {
+		t.Fatalf("Reload failed: %s", err)
+	}
+
+	listenerTestCases := []struct {
+		description      string
+		geoIPLookup      func(string) common.GeoIPData
+		expectConnection bool
+	}{
+		{
+			"connection prohibited",
+			listenerProhibitedGeoIP,
+			false,
+		},
+		{
+			"connection allowed",
+			listenerAllowedGeoIP,
+			true,
+		},
+	}
+
+	for _, testCase := range listenerTestCases {
+		t.Run(testCase.description, func(t *testing.T) {
+
+			tcpListener, err := net.Listen("tcp", ":0")
+			if err != nil {
+				t.Fatalf(" net.Listen failed: %s", err)
+			}
+
+			tacticsListener := NewListener(
+				tcpListener,
+				server,
+				listenerProtocol,
+				testCase.geoIPLookup)
+
+			clientConn, err := net.Dial("tcp", tacticsListener.Addr().String())
+			if err != nil {
+				t.Fatalf(" net.Dial failed: %s", err)
+				return
+			}
+
+			result := make(chan struct{}, 1)
+
+			go func() {
+				serverConn, err := tacticsListener.Accept()
+				if err == nil {
+					result <- *new(struct{})
+					serverConn.Close()
+				}
+			}()
+
+			timer := time.NewTimer(3 * time.Second)
+			defer timer.Stop()
+
+			select {
+			case <-result:
+				if !testCase.expectConnection {
+					t.Fatalf("unexpected accepted connection")
+				}
+			case <-timer.C:
+				if testCase.expectConnection {
+					t.Fatalf("timeout before expected accepted connection")
+				}
+			}
+
+			clientConn.Close()
+			tacticsListener.Close()
+		})
+	}
+
+	// TODO: test replay attack defence
 	// TODO: test Server.Validate with invalid tactics configurations
 }
 

+ 19 - 0
psiphon/common/utils.go

@@ -70,6 +70,25 @@ func ContainsInt(list []int, target int) bool {
 	return false
 }
 
+// GetStringSlice converts an interface{} which is
+// of type []interace{}, and with the type of each
+// element a string, to []string.
+func GetStringSlice(value interface{}) ([]string, bool) {
+	slice, ok := value.([]interface{})
+	if !ok {
+		return nil, false
+	}
+	strSlice := make([]string, len(slice))
+	for index, element := range slice {
+		str, ok := element.(string)
+		if !ok {
+			return nil, false
+		}
+		strSlice[index] = str
+	}
+	return strSlice, true
+}
+
 // FlipCoin is a helper function that randomly
 // returns true or false.
 //

+ 28 - 0
psiphon/common/utils_test.go

@@ -21,12 +21,40 @@ package common
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"math"
+	"reflect"
 	"testing"
 	"time"
 )
 
+func TestGetStringSlice(t *testing.T) {
+
+	originalSlice := []string{"a", "b", "c"}
+
+	j, err := json.Marshal(originalSlice)
+	if err != nil {
+		t.Errorf("json.Marshal failed: %s", err)
+	}
+
+	var value interface{}
+
+	err = json.Unmarshal(&value)
+	if err != nil {
+		t.Errorf("json.Unmarshal failed: %s", err)
+	}
+
+	newSlice, ok := GetStringSlice(value)
+	if !ok {
+		t.Errorf("GetStringSlice failed")
+	}
+
+	if !reflect.DeepEqual(originalSlice, newSlice) {
+		t.Errorf("unexpected GetStringSlice output")
+	}
+}
+
 func TestMakeRandomPeriod(t *testing.T) {
 	min := 1 * time.Nanosecond
 	max := 10000 * time.Nanosecond

+ 1 - 19
psiphon/server/net.go

@@ -53,7 +53,6 @@ package server
 import (
 	"net"
 	"net/http"
-	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tls"
 )
@@ -74,24 +73,7 @@ type HTTPSServer struct {
 //
 // Note that the http.Server.TLSConfig field is ignored and the
 // psiphon/common/tls.Config parameter is used intead.
-//
-// tcpKeepAliveListener is used in http.ListenAndServeTLS but not exported,
-// so we use a copy from https://golang.org/src/net/http/server.go.
 func (server *HTTPSServer) ServeTLS(listener net.Listener, config *tls.Config) error {
-	tlsListener := tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
+	tlsListener := tls.NewListener(listener, config)
 	return server.Serve(tlsListener)
 }
-
-type tcpKeepAliveListener struct {
-	*net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
-	tc, err := ln.AcceptTCP()
-	if err != nil {
-		return
-	}
-	tc.SetKeepAlive(true)
-	tc.SetKeepAlivePeriod(3 * time.Minute)
-	return tc, nil
-}

+ 22 - 5
psiphon/server/server_test.go

@@ -444,6 +444,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		paveTacticsConfigFile(
 			t, tacticsConfigFilename,
 			tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+			runConfig.tunnelProtocol,
 			propagationChannelID)
 	}
 
@@ -453,7 +454,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverConfig["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig["OSLConfigFilename"] = oslConfigFilename
-	serverConfig["TacticsConfigFilename"] = tacticsConfigFilename
+	if doTactics {
+		serverConfig["TacticsConfigFilename"] = tacticsConfigFilename
+	}
 	serverConfig["LogFilename"] = filepath.Join(testDataDirName, "psiphond.log")
 	serverConfig["LogLevel"] = "debug"
 
@@ -582,7 +585,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	}
 
 	if doTactics {
-		clientConfig.NetworkIDGetter = &testNetworkGetter{}
+		// Use a distinct prefix for network ID for each test run to
+		// ensure tactics from different runs don't apply; this is
+		// a workaround for the singleton datastore.
+		prefix := time.Now().String()
+		clientConfig.NetworkIDGetter = &testNetworkGetter{prefix: prefix}
 	}
 
 	if doTactics {
@@ -1181,16 +1188,24 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
 func paveTacticsConfigFile(
 	t *testing.T, tacticsConfigFilename string,
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey string,
+	tunnelProtocol string,
 	propagationChannelID string) {
 
+	// Setting LimitTunnelProtocols passively exercises the
+	// server-side LimitTunnelProtocols enforcement.
+
 	tacticsConfigJSONFormat := `
     {
       "RequestPublicKey" : "%s",
       "RequestPrivateKey" : "%s",
       "RequestObfuscatedKey" : "%s",
+      "EnforceServerSide" : true,
       "DefaultTactics" : {
         "TTL" : "60s",
-        "Probability" : 1.0
+        "Probability" : 1.0,
+        "Parameters" : {
+          "LimitTunnelProtocols" : ["%s"]
+        }
       },
       "FilteredTactics" : [
         {
@@ -1215,6 +1230,7 @@ func paveTacticsConfigFile(
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+		tunnelProtocol,
 		propagationChannelID)
 
 	err := ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600)
@@ -1245,8 +1261,9 @@ const dummyClientVerificationPayload = `
 }`
 
 type testNetworkGetter struct {
+	prefix string
 }
 
-func (testNetworkGetter) GetNetworkID() string {
-	return "NETWORK1"
+func (t *testNetworkGetter) GetNetworkID() string {
+	return t.prefix + "NETWORK1"
 }

+ 9 - 1
psiphon/server/tunnelServer.go

@@ -143,6 +143,14 @@ func (server *TunnelServer) Run() error {
 			return common.ContextError(err)
 		}
 
+		tacticsListener := tactics.NewListener(
+			listener,
+			support.TacticsServer,
+			tunnelProtocol,
+			func(IPAddress string) common.GeoIPData {
+				return common.GeoIPData(support.GeoIPService.Lookup(IPAddress))
+			})
+
 		log.WithContextFields(
 			LogFields{
 				"localAddress":   localAddress,
@@ -152,7 +160,7 @@ func (server *TunnelServer) Run() error {
 		listeners = append(
 			listeners,
 			&sshListener{
-				Listener:       listener,
+				Listener:       tacticsListener,
 				localAddress:   localAddress,
 				tunnelProtocol: tunnelProtocol,
 			})