Procházet zdrojové kódy

Add tests checking for expected logs, fields, and field values

Rod Hynes před 7 roky
rodič
revize
6eddfab663
2 změnil soubory, kde provedl 289 přidání a 4 odebrání
  1. 21 0
      psiphon/server/log.go
  2. 268 4
      psiphon/server/server_test.go

+ 21 - 0
psiphon/server/log.go

@@ -27,6 +27,7 @@ import (
 	go_log "log"
 	"os"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Inc/rotate-safe-writer"
@@ -149,6 +150,22 @@ func NewLogWriter() *io.PipeWriter {
 type CustomJSONFormatter struct {
 }
 
+var (
+	useLogCallback int32
+	logCallback    atomic.Value
+)
+
+// setLogCallback sets a callback that is invoked with each JSON log message.
+// This facility is intended for use in testing only.
+func setLogCallback(callback func([]byte)) {
+	if callback == nil {
+		atomic.StoreInt32(&useLogCallback, 0)
+		return
+	}
+	atomic.StoreInt32(&useLogCallback, 1)
+	logCallback.Store(callback)
+}
+
 const customJSONFormatterLogRawFieldsWithTimestamp = "CustomJSONFormatter.LogRawFieldsWithTimestamp"
 
 // Format implements logrus.Formatter. This is a customized version
@@ -197,6 +214,10 @@ func (f *CustomJSONFormatter) Format(entry *logrus.Entry) ([]byte, error) {
 		return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
 	}
 
+	if atomic.LoadInt32(&useLogCallback) == 1 {
+		logCallback.Load().(func([]byte))(serialized)
+	}
+
 	return append(serialized, '\n'), nil
 }
 

+ 268 - 4
psiphon/server/server_test.go

@@ -480,6 +480,11 @@ type runServerConfig struct {
 	forceLivenessTest    bool
 }
 
+var (
+	testSSHClientVersions = []string{"SSH-2.0-A", "SSH-2.0-B", "SSH-2.0-C"}
+	testUserAgents        = []string{"ua1", "ua2", "ua3"}
+)
+
 func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// configure authorized access
@@ -618,6 +623,36 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	serverConfigJSON, _ = json.Marshal(serverConfig)
 
+	serverConnectedLog := make(chan map[string]interface{}, 1)
+	serverTunnelLog := make(chan map[string]interface{}, 1)
+
+	setLogCallback(func(log []byte) {
+
+		logFields := make(map[string]interface{})
+
+		err := json.Unmarshal(log, &logFields)
+		if err != nil {
+			return
+		}
+
+		if logFields["event_name"] == nil {
+			return
+		}
+
+		switch logFields["event_name"].(string) {
+		case "connected":
+			select {
+			case serverConnectedLog <- logFields:
+			default:
+			}
+		case "server_tunnel":
+			select {
+			case serverTunnelLog <- logFields:
+			default:
+			}
+		}
+	})
+
 	// run server
 
 	serverWaitGroup := new(sync.WaitGroup)
@@ -630,7 +665,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			t.Fatalf("error running server: %s", err)
 		}
 	}()
-	defer func() {
+
+	stopServer := func() {
 
 		// Test: orderly server shutdown
 
@@ -650,6 +686,13 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		case <-shutdownTimeout.C:
 			t.Fatalf("server shutdown timeout exceeded")
 		}
+	}
+
+	// Stop server on early exits due to failure.
+	defer func() {
+		if stopServer != nil {
+			stopServer()
+		}
 	}()
 
 	// TODO: monitor logs for more robust wait-until-loaded. For example,
@@ -687,7 +730,15 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	p, _ := os.FindProcess(os.Getpid())
 	p.Signal(syscall.SIGUSR2)
 
-	// connect to server with client
+	// configure client
+
+	psiphon.RegisterSSHClientVersionPicker(func() string {
+		return testSSHClientVersions[prng.Intn(len(testSSHClientVersions))]
+	})
+
+	psiphon.RegisterUserAgentPicker(func() string {
+		return testUserAgents[prng.Intn(len(testUserAgents))]
+	})
 
 	// TODO: currently, TargetServerEntry only works with one tunnel
 	numTunnels := 1
@@ -710,6 +761,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
         "ClientVersion" : "0",
         "SponsorId" : "0",
         "PropagationChannelId" : "0",
+        "TunnelWholeDevice" : 1,
+        "DeviceRegion" : "US",
         "DisableRemoteServerListFetcher" : true,
         "EstablishTunnelPausePeriodSeconds" : 1,
         "ConnectionWorkerPoolSize" : %d,
@@ -787,6 +840,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		}
 	}
 
+	// connect to server with client
+
 	err = psiphon.OpenDataStore(clientConfig)
 	if err != nil {
 		t.Fatalf("error initializing client datastore: %s", err)
@@ -803,6 +858,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	tunnelsEstablished := make(chan struct{}, 1)
 	homepageReceived := make(chan struct{}, 1)
 	slokSeeded := make(chan struct{}, 1)
+	clientConnectedNotice := make(chan map[string]interface{}, 1)
 
 	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
@@ -815,11 +871,13 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			}
 
 			switch noticeType {
+
 			case "Tunnels":
 				count := int(payload["count"].(float64))
 				if count >= numTunnels {
 					sendNotificationReceived(tunnelsEstablished)
 				}
+
 			case "Homepage":
 				homepageURL := payload["url"].(string)
 				if homepageURL != expectedHomepageURL {
@@ -827,8 +885,15 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 					t.Fatalf("unexpected homepage: %s", homepageURL)
 				}
 				sendNotificationReceived(homepageReceived)
+
 			case "SLOKSeeded":
 				sendNotificationReceived(slokSeeded)
+
+			case "ConnectedServer":
+				select {
+				case clientConnectedNotice <- payload:
+				default:
+				}
 			}
 		}))
 
@@ -842,7 +907,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		controller.Run(ctx)
 	}()
 
-	defer func() {
+	stopClient := func() {
 		cancelFunc()
 
 		shutdownTimeout := time.NewTimer(20 * time.Second)
@@ -858,6 +923,13 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		case <-shutdownTimeout.C:
 			t.Fatalf("controller shutdown timeout exceeded")
 		}
+	}
+
+	// Stop client on early exits due to failure.
+	defer func() {
+		if stopClient != nil {
+			stopClient()
+		}
 	}()
 
 	// Test: tunnels must be established, and correct homepage
@@ -924,6 +996,196 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			t.Fatalf("unexpected number of SLOKs: %d", numSLOKs)
 		}
 	}
+
+	// Shutdown to ensure logs/notices are flushed
+
+	stopClient()
+	stopClient = nil
+	stopServer()
+	stopServer = nil
+
+	// TODO: stops should be fully synchronous, but, intermittently,
+	// server_tunnel fails to appear ("missing server tunnel log")
+	// without this delay.
+	time.Sleep(100 * time.Millisecond)
+
+	// Test: all expected logs/notices were emitted
+
+	select {
+	case <-clientConnectedNotice:
+	default:
+		t.Fatalf("missing client connected notice")
+	}
+
+	select {
+	case logFields := <-serverConnectedLog:
+		err := checkExpectedLogFields(runConfig, logFields)
+		if err != nil {
+			t.Fatalf("invalid server connected log fields: %s", err)
+		}
+	default:
+		t.Fatalf("missing server connected log")
+	}
+
+	select {
+	case logFields := <-serverTunnelLog:
+		err := checkExpectedLogFields(runConfig, logFields)
+		if err != nil {
+			t.Fatalf("invalid server tunnel log fields: %s", err)
+		}
+	default:
+		t.Fatalf("missing server tunnel log")
+	}
+}
+
+func checkExpectedLogFields(runConfig *runServerConfig, fields map[string]interface{}) error {
+
+	// Limitations:
+	//
+	// - client_build_rev not set in test build (see common/buildinfo.go)
+	// - egress_region, upstream_proxy_type, upstream_proxy_custom_header_names not exercised in test
+	// - meek_dial_ip_address/meek_resolved_ip_address only logged for FRONTED meek protocols
+
+	for _, name := range []string{
+		"session_id",
+		"last_connected",
+		"establishment_duration",
+		"propagation_channel_id",
+		"sponsor_id",
+		"client_platform",
+		"relay_protocol",
+		"tunnel_whole_device",
+		"device_region",
+		"ssh_client_version",
+		"server_entry_region",
+		"server_entry_source",
+		"server_entry_timestamp",
+		"dial_port_number",
+		"is_replay",
+		"dial_duration",
+		"candidate_number",
+	} {
+		if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+	}
+
+	if fields["relay_protocol"] != runConfig.tunnelProtocol {
+		return fmt.Errorf("unexpected relay_protocol '%s'", fields["relay_protocol"])
+	}
+
+	if !common.Contains(testSSHClientVersions, fields["ssh_client_version"].(string)) {
+		return fmt.Errorf("unexpected relay_protocol '%s'", fields["ssh_client_version"])
+	}
+
+	if protocol.TunnelProtocolUsesObfuscatedSSH(runConfig.tunnelProtocol) {
+
+		for _, name := range []string{
+			"padding",
+			"pad_response",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+	}
+
+	if protocol.TunnelProtocolUsesMeek(runConfig.tunnelProtocol) {
+
+		for _, name := range []string{
+			"user_agent",
+			"meek_transformed_host_name",
+			tactics.APPLIED_TACTICS_TAG_PARAMETER_NAME,
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+
+		if !common.Contains(testUserAgents, fields["user_agent"].(string)) {
+			return fmt.Errorf("unexpected user_agent '%s'", fields["user_agent"])
+		}
+	}
+
+	if protocol.TunnelProtocolUsesMeekHTTP(runConfig.tunnelProtocol) {
+
+		for _, name := range []string{
+			"meek_host_header",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+
+		for _, name := range []string{
+			"meek_dial_ip_address",
+			"meek_resolved_ip_address",
+		} {
+			if fields[name] != nil {
+				return fmt.Errorf("unexpected field '%s'", name)
+			}
+		}
+	}
+
+	if protocol.TunnelProtocolUsesMeekHTTPS(runConfig.tunnelProtocol) {
+
+		for _, name := range []string{
+			"tls_profile",
+			"meek_sni_server_name",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+
+		for _, name := range []string{
+			"meek_dial_ip_address",
+			"meek_resolved_ip_address",
+			"meek_host_header",
+		} {
+			if fields[name] != nil {
+				return fmt.Errorf("unexpected field '%s'", name)
+			}
+		}
+
+		if !common.Contains(protocol.SupportedTLSProfiles, fields["tls_profile"].(string)) {
+			return fmt.Errorf("unexpected tls_profile '%s'", fields["tls_profile"])
+		}
+
+	}
+
+	if protocol.TunnelProtocolUsesQUIC(runConfig.tunnelProtocol) {
+
+		for _, name := range []string{
+			"quic_version",
+			"quic_dial_sni_address",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+
+		if !common.Contains(protocol.SupportedQUICVersions, fields["quic_version"].(string)) {
+			return fmt.Errorf("unexpected quic_version '%s'", fields["quic_version"])
+		}
+	}
+
+	if runConfig.forceFragmenting {
+
+		for _, name := range []string{
+			"upstream_bytes_fragmented",
+			"upstream_min_bytes_written",
+			"upstream_max_bytes_written",
+			"upstream_min_delayed",
+			"upstream_max_delayed",
+		} {
+			if fields[name] == nil || fmt.Sprintf("%s", fields[name]) == "" {
+				return fmt.Errorf("missing expected field '%s'", name)
+			}
+		}
+	}
+
+	return nil
 }
 
 func makeTunneledWebRequest(
@@ -1424,7 +1686,9 @@ func paveTacticsConfigFile(
           "Tactics" : {
             "Parameters" : {
               "TunnelConnectTimeout" : "20s",
-              "TunnelRateLimits" : {"WriteBytesPerSecond": 1000000}
+              "TunnelRateLimits" : {"WriteBytesPerSecond": 1000000},
+              "TransformHostNameProbability" : 1.0,
+              "PickUserAgentProbability" : 1.0
             }
           }
         }