Przeglądaj źródła

Add server-side tactics enforcement

- A listener wrapper immediately closes new connections
  when their GeoIP attributes map to tactics that prohibit
  the listener's associated tunnel protocol.

  This implements the LimitTunnelProtocol tactics server-side
  for clients that don't yet support tactics. These clients
  will still attempt these connections, but they will not
  establish using the tunnel protocol.

- This change also removes TCP keepalive which was applied
  to HTTPS meek but not HTTP meek TCP connections. Naive TCP
  keepalive adds fingerprintable network traffic and is not
  necessary as other connection activity monitors exist.
Rod Hynes 7 lat temu
rodzic
commit
bb66648f53

+ 1 - 1
psiphon/common/parameters/clientParameters.go

@@ -418,7 +418,7 @@ func (p *ClientParameters) Set(
 
 			// A JSON remarshal resolves cases where applyParameters is a
 			// result of unmarshal-into-interface, in which case non-scalar
-			// values will not have the expecte types; see:
+			// values will not have the expected types; see:
 			// https://golang.org/pkg/encoding/json/#Unmarshal. This remarshal
 			// also results in a deep copy.
 

+ 139 - 39
psiphon/common/tactics/tactics.go

@@ -162,6 +162,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"sort"
 	"time"
@@ -608,6 +609,60 @@ func (server *Server) GetTacticsPayload(
 	geoIPData common.GeoIPData,
 	apiParams common.APIParameters) (*Payload, error) {
 
+	tactics, err := server.getTactics(geoIPData, apiParams)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	if tactics == nil {
+		return nil, nil
+	}
+
+	marshaledTactics, err := json.Marshal(tactics)
+	if err != nil {
+		return nil, common.ContextError(err)
+	}
+
+	// MD5 hash is used solely as a data checksum and not for any security purpose.
+	digest := md5.Sum(marshaledTactics)
+	tag := hex.EncodeToString(digest[:])
+
+	payload := &Payload{
+		Tag: tag,
+	}
+
+	// New clients should always send STORED_TACTICS_TAG_PARAMETER_NAME. When they have no
+	// stored tactics, the stored tag will be "" and not match payload.Tag and payload.Tactics
+	// will be sent.
+	//
+	// When new clients send a stored tag that matches payload.Tag, the client already has
+	// the correct data and payload.Tactics is not sent.
+	//
+	// Old clients will not send STORED_TACTICS_TAG_PARAMETER_NAME. In this case, do not
+	// send payload.Tactics as the client will not use it, will not store it, will not send
+	// back the new tag and so the handshake response will always contain wasteful tactics
+	// data.
+
+	sendPayloadTactics := true
+
+	clientStoredTag, err := getStringRequestParam(apiParams, STORED_TACTICS_TAG_PARAMETER_NAME)
+
+	// Old client or new client with same tag.
+	if err != nil || payload.Tag == clientStoredTag {
+		sendPayloadTactics = false
+	}
+
+	if sendPayloadTactics {
+		payload.Tactics = marshaledTactics
+	}
+
+	return payload, nil
+}
+
+func (server *Server) getTactics(
+	geoIPData common.GeoIPData,
+	apiParams common.APIParameters) (*Tactics, error) {
+
 	server.ReloadableFile.RLock()
 	defer server.ReloadableFile.RUnlock()
 
@@ -703,45 +758,7 @@ func (server *Server) GetTacticsPayload(
 		// Continue to apply more matches. Last matching tactics has priority for any field.
 	}
 
-	marshaledTactics, err := json.Marshal(tactics)
-	if err != nil {
-		return nil, common.ContextError(err)
-	}
-
-	// MD5 hash is used solely as a data checksum and not for any security purpose.
-	digest := md5.Sum(marshaledTactics)
-	tag := hex.EncodeToString(digest[:])
-
-	payload := &Payload{
-		Tag: tag,
-	}
-
-	// New clients should always send STORED_TACTICS_TAG_PARAMETER_NAME. When they have no
-	// stored tactics, the stored tag will be "" and not match payload.Tag and payload.Tactics
-	// will be sent.
-	//
-	// When new clients send a stored tag that matches payload.Tag, the client already has
-	// the correct data and payload.Tactics is not sent.
-	//
-	// Old clients will not send STORED_TACTICS_TAG_PARAMETER_NAME. In this case, do not
-	// send payload.Tactics as the client will not use it, will not store it, will not send
-	// back the new tag and so the handshake response will always contain wasteful tactics
-	// data.
-
-	sendPayloadTactics := true
-
-	clientStoredTag, err := getStringRequestParam(apiParams, STORED_TACTICS_TAG_PARAMETER_NAME)
-
-	// Old client or new client with same tag.
-	if err != nil || payload.Tag == clientStoredTag {
-		sendPayloadTactics = false
-	}
-
-	if sendPayloadTactics {
-		payload.Tactics = marshaledTactics
-	}
-
-	return payload, nil
+	return tactics, nil
 }
 
 // TODO: refactor this copy of psiphon/server.getStringRequestParam into common?
@@ -1036,6 +1053,89 @@ func (server *Server) handleTacticsRequest(
 	server.logger.LogMetric(TACTICS_METRIC_EVENT_NAME, logFields)
 }
 
+// Listener wraps a net.Listener and applies server-side enforcement of
+// certain tactics parameters to accepted connections. Tactics filtering is
+// limited to GeoIP attributes as the client has not yet sent API paramaters.
+type Listener struct {
+	net.Listener
+	server         *Server
+	tunnelProtocol string
+	geoIPLookup    func(IPaddress string) common.GeoIPData
+}
+
+// NewListener creates a new Listener.
+func NewListener(
+	listener net.Listener,
+	server *Server,
+	tunnelProtocol string,
+	geoIPLookup func(IPaddress string) common.GeoIPData) *Listener {
+
+	return &Listener{
+		Listener:       listener,
+		server:         server,
+		tunnelProtocol: tunnelProtocol,
+		geoIPLookup:    geoIPLookup,
+	}
+}
+
+// Close calls the underlying listener's Accept, and then
+// checks if tactics for the connection set LimitTunnelProtocols.
+// If LimitTunnelProtocols is set and does not include the
+// tunnel protocol the listener is running, the accepted
+// connection is immediately closed and the underlying
+// Accept is called again.
+func (listener *Listener) Accept() (net.Conn, error) {
+	for {
+
+		conn, err := listener.Listener.Accept()
+		if err != nil {
+			// Don't modify error from net.Listener
+			return nil, err
+		}
+
+		geoIPData := listener.geoIPLookup(common.IPAddressFromAddr(conn.RemoteAddr()))
+
+		tactics, err := listener.server.getTactics(geoIPData, make(common.APIParameters))
+		if err != nil {
+			listener.server.logger.WithContextFields(
+				common.LogFields{"error": err}).Warning("failed to get tactics for connection")
+			// If tactics is somehow misconfigured, keep handling connections.
+			// Other error cases that follow below take the same approach.
+			return conn, nil
+		}
+
+		if tactics == nil {
+			// This server isn't configured with tactics.
+			return conn, nil
+		}
+
+		limitTunnelProtocolsParameter, ok := tactics.Parameters[parameters.LimitTunnelProtocols]
+		if !ok {
+			// The tactics for the connection don't set LimitTunnelProtocols.
+			return conn, nil
+		}
+
+		if !common.FlipWeightedCoin(tactics.Probability) {
+			// Skip tactics with the configured probability.
+			return conn, nil
+		}
+
+		limitTunnelProtocols, ok := common.GetStringSlice(limitTunnelProtocolsParameter)
+		if !ok ||
+			len(limitTunnelProtocols) == 0 ||
+			common.Contains(limitTunnelProtocols, listener.tunnelProtocol) {
+
+			// The parameter is invalid; or no limit is set; or the
+			// listener protocol is not prohibited.
+			return conn, nil
+		}
+
+		// Don't accept this connection as its tactics prohibits the
+		// listener's tunnel protocol.
+		conn.Close()
+	}
+}
+
 // RoundTripper performs a round trip to the specified endpoint, sending the
 // request body and returning the response body. The context may be used to
 // set a timeout or cancel the rount trip.

+ 105 - 1
psiphon/common/tactics/tactics_test.go

@@ -97,6 +97,16 @@ func TestTactics(t *testing.T) {
               "ConnectionWorkerPoolSize" : %d
             }
           }
+        },
+        {
+          "Filter" : {
+            "Regions": ["R7"]
+          },
+          "Tactics" : {
+            "Parameters" : {
+              "LimitTunnelProtocols" : ["SSH"]
+            }
+          }
         }
       ]
     }
@@ -116,6 +126,10 @@ func TestTactics(t *testing.T) {
 	tacticsLimitTunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"}
 	jsonTacticsLimitTunnelProtocols, _ := json.Marshal(tacticsLimitTunnelProtocols)
 
+	listenerProtocol := "OSSH"
+	listenerProhibitedGeoIP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R7"} }
+	listenerAllowedGeoIP := func(string) common.GeoIPData { return common.GeoIPData{Country: "R8"} }
+
 	tacticsConfig := fmt.Sprintf(
 		tacticsConfigTemplate,
 		encodedRequestPublicKey,
@@ -680,8 +694,98 @@ func TestTactics(t *testing.T) {
 		t.Fatalf("HandleEndPoint unexpectedly handled request")
 	}
 
-	// TODO: test replay attack defence
+	// Test Listener
 
+	tacticsProbability = 1.0
+
+	tacticsConfig = fmt.Sprintf(
+		tacticsConfigTemplate,
+		"",
+		"",
+		"",
+		tacticsProbability,
+		tacticsNetworkLatencyMultiplier,
+		tacticsConnectionWorkerPoolSize,
+		jsonTacticsLimitTunnelProtocols,
+		tacticsConnectionWorkerPoolSize+1)
+
+	err = ioutil.WriteFile(configFileName, []byte(tacticsConfig), 0600)
+	if err != nil {
+		t.Fatalf("WriteFile failed: %s", err)
+	}
+
+	reloaded, err = server.Reload()
+	if err != nil {
+		t.Fatalf("Reload failed: %s", err)
+	}
+
+	listenerTestCases := []struct {
+		description      string
+		geoIPLookup      func(string) common.GeoIPData
+		expectConnection bool
+	}{
+		{
+			"connection prohibited",
+			listenerProhibitedGeoIP,
+			false,
+		},
+		{
+			"connection allowed",
+			listenerAllowedGeoIP,
+			true,
+		},
+	}
+
+	for _, testCase := range listenerTestCases {
+		t.Run(testCase.description, func(t *testing.T) {
+
+			tcpListener, err := net.Listen("tcp", ":0")
+			if err != nil {
+				t.Fatalf(" net.Listen failed: %s", err)
+			}
+
+			tacticsListener := NewListener(
+				tcpListener,
+				server,
+				listenerProtocol,
+				testCase.geoIPLookup)
+
+			clientConn, err := net.Dial("tcp", tacticsListener.Addr().String())
+			if err != nil {
+				t.Fatalf(" net.Dial failed: %s", err)
+				return
+			}
+
+			result := make(chan struct{}, 1)
+
+			go func() {
+				serverConn, err := tacticsListener.Accept()
+				if err == nil {
+					result <- *new(struct{})
+					serverConn.Close()
+				}
+			}()
+
+			timer := time.NewTimer(3 * time.Second)
+			defer timer.Stop()
+
+			select {
+			case <-result:
+				if !testCase.expectConnection {
+					t.Fatalf("unexpected accepted connection")
+				}
+			case <-timer.C:
+				if testCase.expectConnection {
+					t.Fatalf("timeout before expected accepted connection")
+				}
+			}
+
+			clientConn.Close()
+			tacticsListener.Close()
+		})
+	}
+
+	// TODO: test replay attack defence
 	// TODO: test Server.Validate with invalid tactics configurations
 }
 

+ 19 - 0
psiphon/common/utils.go

@@ -70,6 +70,25 @@ func ContainsInt(list []int, target int) bool {
 	return false
 }
 
+// GetStringSlice converts an interface{} which is
+// of type []interace{}, and with the type of each
+// element a string, to []string.
+func GetStringSlice(value interface{}) ([]string, bool) {
+	slice, ok := value.([]interface{})
+	if !ok {
+		return nil, false
+	}
+	strSlice := make([]string, len(slice))
+	for index, element := range slice {
+		str, ok := element.(string)
+		if !ok {
+			return nil, false
+		}
+		strSlice[index] = str
+	}
+	return strSlice, true
+}
+
 // FlipCoin is a helper function that randomly
 // returns true or false.
 //

+ 28 - 0
psiphon/common/utils_test.go

@@ -21,12 +21,40 @@ package common
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"math"
+	"reflect"
 	"testing"
 	"time"
 )
 
+func TestGetStringSlice(t *testing.T) {
+
+	originalSlice := []string{"a", "b", "c"}
+
+	j, err := json.Marshal(originalSlice)
+	if err != nil {
+		t.Errorf("json.Marshal failed: %s", err)
+	}
+
+	var value interface{}
+
+	err = json.Unmarshal(&value)
+	if err != nil {
+		t.Errorf("json.Unmarshal failed: %s", err)
+	}
+
+	newSlice, ok := GetStringSlice(value)
+	if !ok {
+		t.Errorf("GetStringSlice failed")
+	}
+
+	if !reflect.DeepEqual(originalSlice, newSlice) {
+		t.Errorf("unexpected GetStringSlice output")
+	}
+}
+
 func TestMakeRandomPeriod(t *testing.T) {
 	min := 1 * time.Nanosecond
 	max := 10000 * time.Nanosecond

+ 1 - 19
psiphon/server/net.go

@@ -53,7 +53,6 @@ package server
 import (
 	"net"
 	"net/http"
-	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tls"
 )
@@ -74,24 +73,7 @@ type HTTPSServer struct {
 //
 // Note that the http.Server.TLSConfig field is ignored and the
 // psiphon/common/tls.Config parameter is used intead.
-//
-// tcpKeepAliveListener is used in http.ListenAndServeTLS but not exported,
-// so we use a copy from https://golang.org/src/net/http/server.go.
 func (server *HTTPSServer) ServeTLS(listener net.Listener, config *tls.Config) error {
-	tlsListener := tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
+	tlsListener := tls.NewListener(listener, config)
 	return server.Serve(tlsListener)
 }
-
-type tcpKeepAliveListener struct {
-	*net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
-	tc, err := ln.AcceptTCP()
-	if err != nil {
-		return
-	}
-	tc.SetKeepAlive(true)
-	tc.SetKeepAlivePeriod(3 * time.Minute)
-	return tc, nil
-}

+ 21 - 5
psiphon/server/server_test.go

@@ -444,6 +444,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		paveTacticsConfigFile(
 			t, tacticsConfigFilename,
 			tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+			runConfig.tunnelProtocol,
 			propagationChannelID)
 	}
 
@@ -453,7 +454,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverConfig["PsinetDatabaseFilename"] = psinetFilename
 	serverConfig["TrafficRulesFilename"] = trafficRulesFilename
 	serverConfig["OSLConfigFilename"] = oslConfigFilename
-	serverConfig["TacticsConfigFilename"] = tacticsConfigFilename
+	if doTactics {
+		serverConfig["TacticsConfigFilename"] = tacticsConfigFilename
+	}
 	serverConfig["LogFilename"] = filepath.Join(testDataDirName, "psiphond.log")
 	serverConfig["LogLevel"] = "debug"
 
@@ -582,7 +585,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	}
 
 	if doTactics {
-		clientConfig.NetworkIDGetter = &testNetworkGetter{}
+		// Use a distinct prefix for network ID for each test run to
+		// ensure tactics from different runs don't apply; this is
+		// a workaround for the singleton datastore.
+		prefix := time.Now().String()
+		clientConfig.NetworkIDGetter = &testNetworkGetter{prefix: prefix}
 	}
 
 	if doTactics {
@@ -1181,8 +1188,12 @@ func paveOSLConfigFile(t *testing.T, oslConfigFilename string) string {
 func paveTacticsConfigFile(
 	t *testing.T, tacticsConfigFilename string,
 	tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey string,
+	tunnelProtocol string,
 	propagationChannelID string) {
 
+	// Setting LimitTunnelProtocols passively exercises the
+	// server-side LimitTunnelProtocols enforcement.
+
 	tacticsConfigJSONFormat := `
     {
       "RequestPublicKey" : "%s",
@@ -1190,7 +1201,10 @@ func paveTacticsConfigFile(
       "RequestObfuscatedKey" : "%s",
       "DefaultTactics" : {
         "TTL" : "60s",
-        "Probability" : 1.0
+        "Probability" : 1.0,
+        "Parameters" : {
+          "LimitTunnelProtocols" : ["%s"]
+        }
       },
       "FilteredTactics" : [
         {
@@ -1215,6 +1229,7 @@ func paveTacticsConfigFile(
 	tacticsConfigJSON := fmt.Sprintf(
 		tacticsConfigJSONFormat,
 		tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+		tunnelProtocol,
 		propagationChannelID)
 
 	err := ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600)
@@ -1245,8 +1260,9 @@ const dummyClientVerificationPayload = `
 }`
 
 type testNetworkGetter struct {
+	prefix string
 }
 
-func (testNetworkGetter) GetNetworkID() string {
-	return "NETWORK1"
+func (t *testNetworkGetter) GetNetworkID() string {
+	return t.prefix + "NETWORK1"
 }

+ 9 - 1
psiphon/server/tunnelServer.go

@@ -143,6 +143,14 @@ func (server *TunnelServer) Run() error {
 			return common.ContextError(err)
 		}
 
+		tacticsListener := tactics.NewListener(
+			listener,
+			support.TacticsServer,
+			tunnelProtocol,
+			func(IPAddress string) common.GeoIPData {
+				return common.GeoIPData(support.GeoIPService.Lookup(IPAddress))
+			})
+
 		log.WithContextFields(
 			LogFields{
 				"localAddress":   localAddress,
@@ -152,7 +160,7 @@ func (server *TunnelServer) Run() error {
 		listeners = append(
 			listeners,
 			&sshListener{
-				Listener:       listener,
+				Listener:       tacticsListener,
 				localAddress:   localAddress,
 				tunnelProtocol: tunnelProtocol,
 			})