Преглед изворни кода

Merge pull request #659 from mirokuratczyk/passthrough-address

Fix: log passthrough address when demux used
Rod Hynes пре 2 година
родитељ
комит
b2279f7bfa
2 измењених фајлова са 29 додато и 1 уклоњено
  1. 11 0
      psiphon/server/demux.go
  2. 18 1
      psiphon/server/server_test.go

+ 11 - 0
psiphon/server/demux.go

@@ -26,6 +26,7 @@ import (
 	"net"
 	"time"
 
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/sirupsen/logrus"
 )
@@ -319,3 +320,13 @@ func (conn *bufferedConn) Read(b []byte) (n int, err error) {
 
 	return conn.Conn.Read(b)
 }
+
+// GetMetrics implements the common.MetricsSource interface.
+func (conn *bufferedConn) GetMetrics() common.LogFields {
+	// Relay any metrics from the underlying conn.
+	m, ok := conn.Conn.(common.MetricsSource)
+	if ok {
+		return m.GetMetrics()
+	}
+	return nil
+}

+ 18 - 1
psiphon/server/server_test.go

@@ -201,6 +201,7 @@ func TestTLSOSSH(t *testing.T) {
 	runServer(t,
 		&runServerConfig{
 			tunnelProtocol:       "TLS-OSSH",
+			passthrough:          true,
 			enableSSHAPIRequests: true,
 			requireAuthorization: true,
 			doTunneledWebRequest: true,
@@ -681,11 +682,15 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	}
 
 	var tunnelProtocolPassthroughAddresses map[string]string
+	var passthroughAddress *string
 
 	if runConfig.passthrough {
+		passthroughAddress = new(string)
+		*passthroughAddress = "x.x.x.x:x"
+
 		tunnelProtocolPassthroughAddresses = map[string]string{
 			// Tests do not trigger passthrough so set invalid IP and port.
-			runConfig.tunnelProtocol: "x.x.x.x:x",
+			runConfig.tunnelProtocol: *passthroughAddress,
 		}
 	}
 
@@ -1467,6 +1472,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 			expectUDPDataTransfer,
 			expectQUICVersion,
 			expectDestinationBytesFields,
+			passthroughAddress,
 			logFields)
 		if err != nil {
 			t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1604,6 +1610,7 @@ func checkExpectedServerTunnelLogFields(
 	expectUDPDataTransfer bool,
 	expectQUICVersion string,
 	expectDestinationBytesFields bool,
+	expectPassthroughAddress *string,
 	fields map[string]interface{}) error {
 
 	// Limitations:
@@ -2063,6 +2070,16 @@ func checkExpectedServerTunnelLogFields(
 		}
 	}
 
+	if expectPassthroughAddress != nil {
+		name := "passthrough_address"
+		if fields[name] == nil {
+			return fmt.Errorf("missing expected field '%s'", name)
+		}
+		if fields[name] != *expectPassthroughAddress {
+			return fmt.Errorf("unexpected field value %s: %v != %v", name, fields[name], *expectPassthroughAddress)
+		}
+	}
+
 	if runConfig.doLogHostProvider {
 		name := "provider"
 		if fields[name] == nil {