Просмотр исходного кода

Initial integration of packetman and psiphond

Rod Hynes 5 лет назад
Родитель
Сommit
bbbc498bff

+ 53 - 29
psiphon/common/packetman/packetman_linux.go

@@ -80,47 +80,37 @@ const (
 // NFQUEUE with queue-bypass requires Linux kernel 2.6.39; 3.16 or later is
 // NFQUEUE with queue-bypass requires Linux kernel 2.6.39; 3.16 or later is
 // validated and recommended.
 // validated and recommended.
 type Manipulator struct {
 type Manipulator struct {
-	config           *Config
-	mutex            sync.Mutex
-	runContext       context.Context
-	stopRunning      context.CancelFunc
-	waitGroup        *sync.WaitGroup
-	injectIPv4FD     int
-	injectIPv6FD     int
-	nfqueue          *nfqueue.Nfqueue
-	compiledSpecs    map[string]*compiledSpec
-	appliedSpecCache *cache.Cache
+	config             *Config
+	mutex              sync.Mutex
+	runContext         context.Context
+	stopRunning        context.CancelFunc
+	waitGroup          *sync.WaitGroup
+	injectIPv4FD       int
+	injectIPv6FD       int
+	nfqueue            *nfqueue.Nfqueue
+	compiledSpecsMutex sync.Mutex
+	compiledSpecs      map[string]*compiledSpec
+	appliedSpecCache   *cache.Cache
 }
 }
 
 
 // NewManipulator creates a new Manipulator.
 // NewManipulator creates a new Manipulator.
 func NewManipulator(config *Config) (*Manipulator, error) {
 func NewManipulator(config *Config) (*Manipulator, error) {
 
 
-	compiledSpecs := make(map[string]*compiledSpec)
+	m := &Manipulator{
+		config: config,
+	}
 
 
-	for _, spec := range config.Specs {
-		if spec.Name == "" {
-			return nil, errors.TraceNew("invalid spec name")
-		}
-		if _, ok := compiledSpecs[spec.Name]; ok {
-			return nil, errors.TraceNew("duplicate spec name")
-		}
-		compiledSpec, err := compileSpec(spec)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
-		compiledSpecs[spec.Name] = compiledSpec
+	err := m.SetSpecs(config.Specs)
+	if err != nil {
+		return nil, errors.Trace(err)
 	}
 	}
 
 
 	// To avoid memory exhaustion, do not retain unconsumed appliedSpecCache
 	// To avoid memory exhaustion, do not retain unconsumed appliedSpecCache
 	// entries for a longer time than it may reasonably take to complete the TCP
 	// entries for a longer time than it may reasonably take to complete the TCP
 	// handshake.
 	// handshake.
-	appliedSpecCache := cache.New(appliedSpecCacheTTL, appliedSpecCacheTTL/2)
+	m.appliedSpecCache = cache.New(appliedSpecCacheTTL, appliedSpecCacheTTL/2)
 
 
-	return &Manipulator{
-		config:           config,
-		compiledSpecs:    compiledSpecs,
-		appliedSpecCache: appliedSpecCache,
-	}, nil
+	return m, nil
 }
 }
 
 
 // Start initializes NFQUEUEs and raw sockets for packet manipulation. Start
 // Start initializes NFQUEUEs and raw sockets for packet manipulation. Start
@@ -283,6 +273,33 @@ func (m *Manipulator) Stop() {
 	m.configureIPTables(false)
 	m.configureIPTables(false)
 }
 }
 
 
+// SetSpecs installs a new set of packet transformation Spec values, replacing
+// the initial specs from Config.Specs, or any previous SetSpecs call. When
+// SetSpecs returns an error, the previous set of specs is retained.
+func (m *Manipulator) SetSpecs(specs []*Spec) error {
+
+	compiledSpecs := make(map[string]*compiledSpec)
+	for _, spec := range config.Specs {
+		if spec.Name == "" {
+			return errors.TraceNew("invalid spec name")
+		}
+		if _, ok := compiledSpecs[spec.Name]; ok {
+			return errors.TraceNew("duplicate spec name")
+		}
+		compiledSpec, err := compileSpec(spec)
+		if err != nil {
+			return errors.Trace(err)
+		}
+		compiledSpecs[spec.Name] = compiledSpec
+	}
+
+	m.compiledSpecsMutex.Lock()
+	m.compiledSpecs = compiledSpecs
+	m.compiledSpecsMutex.Unlock()
+
+	return nil
+}
+
 func makeConnectionID(
 func makeConnectionID(
 	srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) string {
 	srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) string {
 
 
@@ -515,7 +532,14 @@ func (m *Manipulator) getCompiledSpec(interceptedPacket gopacket.Packet) (*compi
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
+	// Concurrency note: m.compiledSpecs may be replaced by SetSpecs, but any
+	// reference to an individual compiledSpec remains valid; each compiledSpec
+	// is read-only.
+
+	m.compiledSpecsMutex.Lock()
 	spec, ok := m.compiledSpecs[specName]
 	spec, ok := m.compiledSpecs[specName]
+	m.compiledSpecsMutex.Unlock()
+
 	if !ok {
 	if !ok {
 		return nil, errors.Tracef("invalid spec name: %s", specName)
 		return nil, errors.Tracef("invalid spec name: %s", specName)
 	}
 	}

+ 4 - 0
psiphon/common/packetman/packetman_unsupported.go

@@ -48,6 +48,10 @@ func (m *Manipulator) Start() error {
 func (m *Manipulator) Stop() {
 func (m *Manipulator) Stop() {
 }
 }
 
 
+func (m *Manipulator) SetSpecs(_ []*Spec) error {
+	return errors.Trace(errUnsupported)
+}
+
 func (m *Manipulator) GetAppliedSpecName(_, _ *net.TCPAddr) (string, error) {
 func (m *Manipulator) GetAppliedSpecName(_, _ *net.TCPAddr) (string, error) {
 	return "", errors.Trace(errUnsupported)
 	return "", errors.Trace(errUnsupported)
 }
 }

+ 22 - 1
psiphon/common/parameters/clientParameters.go

@@ -239,6 +239,9 @@ const (
 	BPFServerTCPProbability                          = "BPFServerTCPProbability"
 	BPFServerTCPProbability                          = "BPFServerTCPProbability"
 	BPFClientTCPProgram                              = "BPFClientTCPProgram"
 	BPFClientTCPProgram                              = "BPFClientTCPProgram"
 	BPFClientTCPProbability                          = "BPFClientTCPProbability"
 	BPFClientTCPProbability                          = "BPFClientTCPProbability"
+	ServerPacketManipulationSpecs                    = "ServerPacketManipulationSpecs"
+	ServerProtocolPacketManipulations                = "ServerProtocolPacketManipulations"
+	ServerPacketManipulationProbability              = "ServerPacketManipulationProbability"
 )
 )
 
 
 const (
 const (
@@ -495,6 +498,10 @@ var defaultClientParameters = map[string]struct {
 	BPFServerTCPProbability: {value: 0.5, minimum: 0.0, flags: serverSideOnly},
 	BPFServerTCPProbability: {value: 0.5, minimum: 0.0, flags: serverSideOnly},
 	BPFClientTCPProgram:     {value: (*BPFProgramSpec)(nil)},
 	BPFClientTCPProgram:     {value: (*BPFProgramSpec)(nil)},
 	BPFClientTCPProbability: {value: 0.5, minimum: 0.0},
 	BPFClientTCPProbability: {value: 0.5, minimum: 0.0},
+
+	ServerPacketManipulationSpecs:       {value: PacketManipulationSpecs{}, flags: serverSideOnly},
+	ServerProtocolPacketManipulations:   {value: make(ProtocolPacketManipulations), flags: serverSideOnly},
+	ServerPacketManipulationProbability: {value: 0.5, minimum: 0.0, flags: serverSideOnly},
 }
 }
 
 
 // IsServerSideOnly indicates if the parameter specified by name is used
 // IsServerSideOnly indicates if the parameter specified by name is used
@@ -1078,7 +1085,7 @@ func (p ClientParametersAccessor) QUICVersions(name string) protocol.QUICVersion
 // corresponding to the specified labeled set and label value. The return
 // corresponding to the specified labeled set and label value. The return
 // value is nil when no set is found.
 // value is nil when no set is found.
 func (p ClientParametersAccessor) LabeledQUICVersions(name, label string) protocol.QUICVersions {
 func (p ClientParametersAccessor) LabeledQUICVersions(name, label string) protocol.QUICVersions {
-	var value protocol.LabeledQUICVersions
+	value := protocol.LabeledQUICVersions{}
 	p.snapshot.getValue(name, &value)
 	p.snapshot.getValue(name, &value)
 	return value[label]
 	return value[label]
 }
 }
@@ -1153,3 +1160,17 @@ func (p ClientParametersAccessor) BPFProgram(name string) (bool, string, []bpf.R
 	rawInstructions, _ := value.Assemble()
 	rawInstructions, _ := value.Assemble()
 	return true, value.Name, rawInstructions
 	return true, value.Name, rawInstructions
 }
 }
+
+// PacketManipulationSpecs returns a PacketManipulationSpecs parameter value.
+func (p ClientParametersAccessor) PacketManipulationSpecs(name string) PacketManipulationSpecs {
+	value := PacketManipulationSpecs{}
+	p.snapshot.getValue(name, &value)
+	return value
+}
+
+// ProtocolPacketManipulations returns a ProtocolPacketManipulations parameter value.
+func (p ClientParametersAccessor) ProtocolPacketManipulations(name string) ProtocolPacketManipulations {
+	value := make(ProtocolPacketManipulations)
+	p.snapshot.getValue(name, &value)
+	return value
+}

+ 10 - 0
psiphon/common/parameters/clientParameters_test.go

@@ -129,6 +129,16 @@ func TestGetDefaultParameters(t *testing.T) {
 					"BPFProgramSpec returned %+v %+v %+v expected %+v",
 					"BPFProgramSpec returned %+v %+v %+v expected %+v",
 					ok, name, rawInstructions, v)
 					ok, name, rawInstructions, v)
 			}
 			}
+		case PacketManipulationSpecs:
+			g := p.Get().PacketManipulationSpecs(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("PacketManipulationSpecs returned %+v expected %+v", g, v)
+			}
+		case ProtocolPacketManipulations:
+			g := p.Get().ProtocolPacketManipulations(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("ProtocolPacketManipulations returned %+v expected %+v", g, v)
+			}
 		default:
 		default:
 			t.Fatalf("Unhandled default type: %s", name)
 			t.Fatalf("Unhandled default type: %s", name)
 		}
 		}

+ 75 - 0
psiphon/common/parameters/packetman.go

@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2020, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package parameters
+
+import (
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/packetman"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+// PacketManipulationSpecs is a list of packet manipulation specs.
+type PacketManipulationSpecs []*packetman.Spec
+
+// Validate checks that each spec name is unique and that each spec compiles.
+func (specs PacketManipulationSpecs) Validate() error {
+	specNames := make(map[string]bool)
+	for _, spec := range specs {
+		if spec.Name == "" {
+			return errors.TraceNew("missing spec name")
+		}
+		if ok, _ := specNames[spec.Name]; ok {
+			return errors.TraceNew("duplicate spec name")
+		}
+		specNames[spec.Name] = true
+		err := spec.Validate()
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+	return nil
+}
+
+// ProtocolPacketManipulations is a map from tunnel protocol names (or "All")
+// to a list of packet manipulation spec names.
+type ProtocolPacketManipulations map[string][]string
+
+// Validate checks that tunnel protocol and spec names are valid. Duplicate
+// spec names are allowed in each entry, enabling weighted selection.
+func (manipulations ProtocolPacketManipulations) Validate(specs PacketManipulationSpecs) error {
+	validSpecNames := make(map[string]bool)
+	for _, spec := range specs {
+		validSpecNames[spec.Name] = true
+	}
+	for tunnelProtocol, specNames := range manipulations {
+		if tunnelProtocol != protocol.TUNNEL_PROTOCOLS_ALL {
+			if !protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) {
+				return errors.TraceNew("invalid tunnel protocol for packet manipulation")
+			}
+		}
+
+		for _, specName := range specNames {
+			if ok, _ := validSpecNames[specName]; !ok {
+				return errors.TraceNew("invalid spec name")
+			}
+		}
+	}
+	return nil
+}

+ 10 - 0
psiphon/common/protocol/protocol.go

@@ -43,6 +43,8 @@ const (
 	TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH          = "TAPDANCE-OSSH"
 	TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH          = "TAPDANCE-OSSH"
 	TUNNEL_PROTOCOL_CONJOUR_OBFUSCATED_SSH           = "CONJOUR-OSSH"
 	TUNNEL_PROTOCOL_CONJOUR_OBFUSCATED_SSH           = "CONJOUR-OSSH"
 
 
+	TUNNEL_PROTOCOLS_ALL = "All"
+
 	SERVER_ENTRY_SOURCE_EMBEDDED   = "EMBEDDED"
 	SERVER_ENTRY_SOURCE_EMBEDDED   = "EMBEDDED"
 	SERVER_ENTRY_SOURCE_REMOTE     = "REMOTE"
 	SERVER_ENTRY_SOURCE_REMOTE     = "REMOTE"
 	SERVER_ENTRY_SOURCE_DISCOVERY  = "DISCOVERY"
 	SERVER_ENTRY_SOURCE_DISCOVERY  = "DISCOVERY"
@@ -228,6 +230,14 @@ func TunnelProtocolSupportsUpstreamProxy(protocol string) bool {
 	return !TunnelProtocolUsesQUIC(protocol)
 	return !TunnelProtocolUsesQUIC(protocol)
 }
 }
 
 
+func TunnelProtocolMayUseServerPacketManipulation(protocol string) bool {
+	return protocol == TUNNEL_PROTOCOL_SSH ||
+		protocol == TUNNEL_PROTOCOL_OBFUSCATED_SSH ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS ||
+		protocol == TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
+}
+
 func UseClientTunnelProtocol(
 func UseClientTunnelProtocol(
 	clientProtocol string,
 	clientProtocol string,
 	serverProtocols TunnelProtocols) bool {
 	serverProtocols TunnelProtocols) bool {

+ 3 - 0
psiphon/server/config.go

@@ -310,6 +310,9 @@ type Config struct {
 	// tun.ServerConfig.SudoNetworkConfigCommands.
 	// tun.ServerConfig.SudoNetworkConfigCommands.
 	PacketTunnelSudoNetworkConfigCommands bool
 	PacketTunnelSudoNetworkConfigCommands bool
 
 
+	// RunPacketManipulator specifies whether to run a packet manipulator.
+	RunPacketManipulator bool
+
 	// MaxConcurrentSSHHandshakes specifies a limit on the number of concurrent
 	// MaxConcurrentSSHHandshakes specifies a limit on the number of concurrent
 	// SSH handshake negotiations. This is set to mitigate spikes in memory
 	// SSH handshake negotiations. This is set to mitigate spikes in memory
 	// allocations and CPU usage associated with SSH handshakes when many clients
 	// allocations and CPU usage associated with SSH handshakes when many clients

+ 47 - 14
psiphon/server/meek.go

@@ -21,6 +21,7 @@ package server
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"context"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/tls"
 	"crypto/tls"
 	"encoding/base64"
 	"encoding/base64"
@@ -172,6 +173,12 @@ func NewMeekServer(
 	return meekServer, nil
 	return meekServer, nil
 }
 }
 
 
+type meekContextKey struct {
+	key string
+}
+
+var meekNetConnContextKey = &meekContextKey{"net.Conn"}
+
 // Run runs the meek server; this function blocks while serving HTTP or
 // Run runs the meek server; this function blocks while serving HTTP or
 // HTTPS connections on the specified listener. This function also runs
 // HTTPS connections on the specified listener. This function also runs
 // a goroutine which cleans up expired meek client sessions.
 // a goroutine which cleans up expired meek client sessions.
@@ -217,6 +224,9 @@ func (server *MeekServer) Run() error {
 		WriteTimeout: MEEK_HTTP_CLIENT_IO_TIMEOUT,
 		WriteTimeout: MEEK_HTTP_CLIENT_IO_TIMEOUT,
 		Handler:      server,
 		Handler:      server,
 		ConnState:    server.httpConnStateCallback,
 		ConnState:    server.httpConnStateCallback,
+		ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
+			return context.WithValue(ctx, meekNetConnContextKey, conn)
+		},
 
 
 		// Disable auto HTTP/2 (https://golang.org/doc/go1.6)
 		// Disable auto HTTP/2 (https://golang.org/doc/go1.6)
 		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
 		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
@@ -621,6 +631,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 
 
 	session.touch()
 	session.touch()
 
 
+	underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
+
 	// Create a new meek conn that will relay the payload
 	// Create a new meek conn that will relay the payload
 	// between meek request/responses and the tunnel server client
 	// between meek request/responses and the tunnel server client
 	// handler. The client IP is also used to initialize the
 	// handler. The client IP is also used to initialize the
@@ -632,6 +644,8 @@ func (server *MeekServer) getSessionOrEndpoint(
 	clientConn := newMeekConn(
 	clientConn := newMeekConn(
 		server,
 		server,
 		session,
 		session,
+		underlyingConn.LocalAddr(),
+		underlyingConn.RemoteAddr(),
 		&net.TCPAddr{
 		&net.TCPAddr{
 			IP:   net.ParseIP(clientIP),
 			IP:   net.ParseIP(clientIP),
 			Port: 0,
 			Port: 0,
@@ -1195,26 +1209,33 @@ func makeMeekSessionID() (string, error) {
 // connection by the tunnel server (being passed to sshServer.handleClient).
 // connection by the tunnel server (being passed to sshServer.handleClient).
 // meekConn bridges net/http request/response payload readers and writers
 // meekConn bridges net/http request/response payload readers and writers
 // and goroutines calling Read()s and Write()s.
 // and goroutines calling Read()s and Write()s.
+//
+// meekConn implements the UnderlyingTCPAddrSource, returning the TCP
+// addresses for the _first_ underlying TCP connection in the meek tunnel.
 type meekConn struct {
 type meekConn struct {
-	meekServer        *MeekServer
-	meekSession       *meekSession
-	remoteAddr        net.Addr
-	protocolVersion   int
-	closeBroadcast    chan struct{}
-	closed            int32
-	lastReadChecksum  *uint64
-	readLock          sync.Mutex
-	emptyReadBuffer   chan *bytes.Buffer
-	partialReadBuffer chan *bytes.Buffer
-	fullReadBuffer    chan *bytes.Buffer
-	writeLock         sync.Mutex
-	nextWriteBuffer   chan []byte
-	writeResult       chan error
+	meekServer           *MeekServer
+	meekSession          *meekSession
+	remoteAddr           net.Addr
+	underlyingLocalAddr  net.Addr
+	underlyingRemoteAddr net.Addr
+	protocolVersion      int
+	closeBroadcast       chan struct{}
+	closed               int32
+	lastReadChecksum     *uint64
+	readLock             sync.Mutex
+	emptyReadBuffer      chan *bytes.Buffer
+	partialReadBuffer    chan *bytes.Buffer
+	fullReadBuffer       chan *bytes.Buffer
+	writeLock            sync.Mutex
+	nextWriteBuffer      chan []byte
+	writeResult          chan error
 }
 }
 
 
 func newMeekConn(
 func newMeekConn(
 	meekServer *MeekServer,
 	meekServer *MeekServer,
 	meekSession *meekSession,
 	meekSession *meekSession,
+	underlyingLocalAddr net.Addr,
+	underlyingRemoteAddr net.Addr,
 	remoteAddr net.Addr,
 	remoteAddr net.Addr,
 	protocolVersion int) *meekConn {
 	protocolVersion int) *meekConn {
 
 
@@ -1238,6 +1259,18 @@ func newMeekConn(
 	return conn
 	return conn
 }
 }
 
 
+func (conn *meekConn) GetUnderlyingTCPAddrs() (*net.TCPAddr, *net.TCPAddr, bool) {
+	localAddr, ok := conn.underlyingLocalAddr.(*net.TCPAddr)
+	if !ok {
+		return nil, nil, false
+	}
+	remoteAddr, ok := conn.underlyingRemoteAddr.(*net.TCPAddr)
+	if !ok {
+		return nil, nil, false
+	}
+	return localAddr, remoteAddr, true
+}
+
 // pumpReads causes goroutines blocking on meekConn.Read() to read
 // pumpReads causes goroutines blocking on meekConn.Read() to read
 // from the specified reader. This function blocks until the reader
 // from the specified reader. This function blocks until the reader
 // is fully consumed or the meekConn is closed. A read buffer allows
 // is fully consumed or the meekConn is closed. A read buffer allows

+ 4 - 0
psiphon/server/net.go

@@ -77,3 +77,7 @@ func (server *HTTPSServer) ServeTLS(listener net.Listener, config *tris.Config)
 	tlsListener := tris.NewListener(listener, config)
 	tlsListener := tris.NewListener(listener, config)
 	return server.Serve(tlsListener)
 	return server.Serve(tlsListener)
 }
 }
+
+type UnderlyingTCPAddrSource interface {
+	GetUnderlyingTCPAddrs() (*net.TCPAddr, *net.TCPAddr, bool)
+}

+ 197 - 0
psiphon/server/packetman.go

@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2020, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"net"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/packetman"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+)
+
+func makePacketManipulatorConfig(
+	support *SupportServices) (*packetman.Config, error) {
+
+	// Packet interception is configured for any tunnel protocol port that _may_
+	// use packet manipulation. A future hot reload of tactics may apply specs to
+	// any of these protocols.
+
+	var ports []int
+	for tunnelProtocol, port := range support.Config.TunnelProtocolPorts {
+		if protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) {
+			ports = append(ports, port)
+		}
+	}
+
+	getSpecName := func(protocolPort int, clientIP net.IP) string {
+
+		specName, err := selectPacketManipulationSpec(support, protocolPort, clientIP)
+		if err != nil {
+			log.WithTraceFields(
+				LogFields{"error": err}).Warning(
+				"failed to get tactics for packet manipulation")
+			return ""
+		}
+
+		return specName
+	}
+
+	specs, err := getPacketManipulationSpecs(support)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	return &packetman.Config{
+		Logger:                    CommonLogger(log),
+		SudoNetworkConfigCommands: support.Config.PacketTunnelSudoNetworkConfigCommands,
+		QueueNumber:               1,
+		ProtocolPorts:             ports,
+		Specs:                     specs,
+		GetSpecName:               getSpecName,
+	}, nil
+}
+
+func getPacketManipulationSpecs(support *SupportServices) ([]*packetman.Spec, error) {
+
+	// By convention, parameters.ServerPacketManipulationSpecs should be in
+	// DefaultTactics, not FilteredTactics; and Tactics.Probability is ignored.
+
+	tactics, err := support.TacticsServer.GetTactics(
+		true, common.GeoIPData(NewGeoIPData()), make(common.APIParameters))
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if tactics == nil {
+		// This server isn't configured with tactics.
+		return []*packetman.Spec{}, nil
+	}
+
+	clientParameters, err := parameters.NewClientParameters(nil)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	_, err = clientParameters.Set("", false, tactics.Parameters)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	p := clientParameters.Get()
+
+	specs := p.PacketManipulationSpecs(parameters.ServerPacketManipulationSpecs)
+
+	return specs, nil
+}
+
+func reloadPacketManipulationSpecs(support *SupportServices) error {
+
+	specs, err := getPacketManipulationSpecs(support)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	err = support.PacketManipulator.SetSpecs(specs)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	return nil
+}
+
+func selectPacketManipulationSpec(
+	support *SupportServices, protocolPort int, clientIP net.IP) (string, error) {
+
+	geoIPData := support.GeoIPService.Lookup(clientIP.String())
+
+	tactics, err := support.TacticsServer.GetTactics(
+		true, common.GeoIPData(geoIPData), make(common.APIParameters))
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+
+	if tactics == nil {
+		// This server isn't configured with tactics.
+		return "", nil
+	}
+
+	if !prng.FlipWeightedCoin(tactics.Probability) {
+		// Skip tactics with the configured probability.
+		return "", nil
+	}
+
+	clientParameters, err := parameters.NewClientParameters(nil)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+	_, err = clientParameters.Set("", false, tactics.Parameters)
+	if err != nil {
+		return "", errors.Trace(err)
+	}
+	p := clientParameters.Get()
+
+	// GeoIP tactics filtering is applied before getting
+	// ServerPacketManipulationProbability and ServerProtocolPacketManipulations.
+	//
+	// The intercepted packet source/protocol port is used to determine the
+	// tunnel protocol name, which is used to lookup enabled packet manipulation
+	// specs in ServerProtocolPacketManipulations.
+	//
+	// When there are multiple enabled specs, one is selected at random.
+	//
+	// Specs under the key "All" apply to all protocols. Duplicate specs per
+	// entry are allowed, enabling weighted selection. If a spec appears in both
+	// "All" and a specific protocol, the duplicate(s) are retained.
+
+	if !p.WeightedCoinFlip(parameters.ServerPacketManipulationProbability) {
+		return "", nil
+	}
+
+	targetTunnelProtocol := ""
+	for tunnelProtocol, port := range support.Config.TunnelProtocolPorts {
+		if port == protocolPort {
+			targetTunnelProtocol = tunnelProtocol
+			break
+		}
+	}
+	if targetTunnelProtocol == "" {
+		return "", errors.Tracef(
+			"packet manipulation protocol port not found: %d", protocolPort)
+	}
+
+	protocolSpecs := p.ProtocolPacketManipulations(
+		parameters.ServerProtocolPacketManipulations)
+
+	// TODO: cache merged per-protocol + "All" lists?
+
+	specNames, ok := protocolSpecs[targetTunnelProtocol]
+	if !ok {
+		specNames = []string{}
+	}
+
+	allProtocolsSpecNames, ok := protocolSpecs[protocol.TUNNEL_PROTOCOLS_ALL]
+	if ok {
+		specNames = append(specNames, allProtocolsSpecNames...)
+	}
+
+	return specNames[prng.Range(0, len(specNames)-1)], nil
+}

+ 55 - 7
psiphon/server/services.go

@@ -38,6 +38,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/buildinfo"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/buildinfo"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/packetman"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
@@ -46,25 +47,32 @@ import (
 // RunServices initializes support functions including logging and GeoIP services;
 // RunServices initializes support functions including logging and GeoIP services;
 // and then starts the server components and runs them until os.Interrupt or
 // and then starts the server components and runs them until os.Interrupt or
 // os.Kill signals are received. The config determines which components are run.
 // os.Kill signals are received. The config determines which components are run.
-func RunServices(configJSON []byte) error {
+func RunServices(configJSON []byte) (retErr error) {
+
+	loggingInitialized := false
+
+	defer func() {
+		if retErr != nil && loggingInitialized {
+			log.WithTraceFields(LogFields{"error": retErr}).Error("RunServices failed")
+		}
+	}()
 
 
 	rand.Seed(int64(time.Now().Nanosecond()))
 	rand.Seed(int64(time.Now().Nanosecond()))
 
 
 	config, err := LoadConfig(configJSON)
 	config, err := LoadConfig(configJSON)
 	if err != nil {
 	if err != nil {
-		log.WithTraceFields(LogFields{"error": err}).Error("load config failed")
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
 	err = InitLogging(config)
 	err = InitLogging(config)
 	if err != nil {
 	if err != nil {
-		log.WithTraceFields(LogFields{"error": err}).Error("init logging failed")
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
+	loggingInitialized = true
+
 	supportServices, err := NewSupportServices(config)
 	supportServices, err := NewSupportServices(config)
 	if err != nil {
 	if err != nil {
-		log.WithTraceFields(LogFields{"error": err}).Error("init support services failed")
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
@@ -78,7 +86,6 @@ func RunServices(configJSON []byte) error {
 
 
 	tunnelServer, err := NewTunnelServer(supportServices, shutdownBroadcast)
 	tunnelServer, err := NewTunnelServer(supportServices, shutdownBroadcast)
 	if err != nil {
 	if err != nil {
-		log.WithTraceFields(LogFields{"error": err}).Error("init tunnel server failed")
 		return errors.Trace(err)
 		return errors.Trace(err)
 	}
 	}
 
 
@@ -97,13 +104,27 @@ func RunServices(configJSON []byte) error {
 			AllowBogons:                 config.AllowBogons,
 			AllowBogons:                 config.AllowBogons,
 		})
 		})
 		if err != nil {
 		if err != nil {
-			log.WithTraceFields(LogFields{"error": err}).Error("init packet tunnel failed")
 			return errors.Trace(err)
 			return errors.Trace(err)
 		}
 		}
 
 
 		supportServices.PacketTunnelServer = packetTunnelServer
 		supportServices.PacketTunnelServer = packetTunnelServer
 	}
 	}
 
 
+	if config.RunPacketManipulator {
+
+		packetManipulatorConfig, err := makePacketManipulatorConfig(supportServices)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		packetManipulator, err := packetman.NewManipulator(packetManipulatorConfig)
+		if err != nil {
+			return errors.Trace(err)
+		}
+
+		supportServices.PacketManipulator = packetManipulator
+	}
+
 	// After this point, errors should be delivered to the errors channel and
 	// After this point, errors should be delivered to the errors channel and
 	// orderly shutdown should flow through to the end of the function to ensure
 	// orderly shutdown should flow through to the end of the function to ensure
 	// all workers are synchronously stopped.
 	// all workers are synchronously stopped.
@@ -118,6 +139,23 @@ func RunServices(configJSON []byte) error {
 		}()
 		}()
 	}
 	}
 
 
+	if config.RunPacketManipulator {
+		err := supportServices.PacketManipulator.Start()
+		if err != nil {
+			select {
+			case errorChannel <- err:
+			default:
+			}
+		} else {
+			waitGroup.Add(1)
+			go func() {
+				defer waitGroup.Done()
+				<-shutdownBroadcast
+				supportServices.PacketManipulator.Stop()
+			}()
+		}
+	}
+
 	if config.RunLoadMonitor() {
 	if config.RunLoadMonitor() {
 		waitGroup.Add(1)
 		waitGroup.Add(1)
 		go func() {
 		go func() {
@@ -256,7 +294,6 @@ loop:
 			break loop
 			break loop
 
 
 		case err = <-errorChannel:
 		case err = <-errorChannel:
-			log.WithTraceFields(LogFields{"error": err}).Error("service failed")
 			break loop
 			break loop
 		}
 		}
 	}
 	}
@@ -397,6 +434,7 @@ type SupportServices struct {
 	PacketTunnelServer *tun.Server
 	PacketTunnelServer *tun.Server
 	TacticsServer      *tactics.Server
 	TacticsServer      *tactics.Server
 	Blocklist          *Blocklist
 	Blocklist          *Blocklist
+	PacketManipulator  *packetman.Manipulator
 }
 }
 
 
 // NewSupportServices initializes a new SupportServices.
 // NewSupportServices initializes a new SupportServices.
@@ -472,12 +510,22 @@ func (support *SupportServices) Reload() {
 	// reload; new tactics will be obtained on the next client handshake or
 	// reload; new tactics will be obtained on the next client handshake or
 	// tactics request.
 	// tactics request.
 
 
+	reloadSpecs := func() {
+		err := reloadPacketManipulationSpecs(support)
+		if err != nil {
+			log.WithTraceFields(
+				LogFields{"error": errors.Trace(err)}).Warning(
+				"failed to reload packet manipulation specs")
+		}
+	}
+
 	// Take these actions only after the corresponding Reloader has reloaded.
 	// Take these actions only after the corresponding Reloader has reloaded.
 	// In both the traffic rules and OSL cases, there is some impact from state
 	// In both the traffic rules and OSL cases, there is some impact from state
 	// reset, so the reset should be avoided where possible.
 	// reset, so the reset should be avoided where possible.
 	reloadPostActions := map[common.Reloader]func(){
 	reloadPostActions := map[common.Reloader]func(){
 		support.TrafficRulesSet: func() { support.TunnelServer.ResetAllClientTrafficRules() },
 		support.TrafficRulesSet: func() { support.TunnelServer.ResetAllClientTrafficRules() },
 		support.OSLConfig:       func() { support.TunnelServer.ResetAllClientOSLConfigs() },
 		support.OSLConfig:       func() { support.TunnelServer.ResetAllClientOSLConfigs() },
+		support.TacticsServer:   reloadSpecs,
 	}
 	}
 
 
 	for _, reloader := range reloaders {
 	for _, reloader := range reloaders {

+ 52 - 12
psiphon/server/tunnelServer.go

@@ -1057,6 +1057,39 @@ func (sshServer *sshServer) handleClient(
 		}
 		}
 	}
 	}
 
 
+	serverPacketManipulation := ""
+	if protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) {
+
+		// A meekConn has synthetic address values, including the original client
+		// address in cases where the client uses an upstream proxy to connect to
+		// Psiphon. For meekConn, and any other conn implementing
+		// UnderlyingTCPAddrSource, get the underlying TCP connection addresses.
+		//
+		// Limitation: a meek tunnel may consist of several TCP connections. The
+		// server_packet_manipulation metric will reflect the packet manipulation
+		// applied to the _first_ TCP connection only.
+
+		var localAddr, remoteAddr *net.TCPAddr
+		var ok bool
+		underlying, ok := clientConn.(UnderlyingTCPAddrSource)
+		if ok {
+			localAddr, remoteAddr, ok = underlying.GetUnderlyingTCPAddrs()
+		} else {
+			localAddr, ok = clientConn.LocalAddr().(*net.TCPAddr)
+			if ok {
+				remoteAddr, ok = clientConn.RemoteAddr().(*net.TCPAddr)
+			}
+		}
+
+		if ok {
+			specName, err := sshServer.support.PacketManipulator.
+				GetAppliedSpecName(localAddr, remoteAddr)
+			if err == nil {
+				serverPacketManipulation = specName
+			}
+		}
+	}
+
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 	geoIPData := sshServer.support.GeoIPService.Lookup(
 		common.IPAddressFromAddr(clientAddr))
 		common.IPAddressFromAddr(clientAddr))
 
 
@@ -1112,6 +1145,7 @@ func (sshServer *sshServer) handleClient(
 		sshServer,
 		sshServer,
 		sshListener,
 		sshListener,
 		tunnelProtocol,
 		tunnelProtocol,
+		serverPacketManipulation,
 		geoIPData)
 		geoIPData)
 
 
 	// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
 	// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
@@ -1156,6 +1190,7 @@ type sshClient struct {
 	sshConn                              ssh.Conn
 	sshConn                              ssh.Conn
 	activityConn                         *common.ActivityMonitoredConn
 	activityConn                         *common.ActivityMonitoredConn
 	throttledConn                        *common.ThrottledConn
 	throttledConn                        *common.ThrottledConn
+	serverPacketManipulation             string
 	geoIPData                            GeoIPData
 	geoIPData                            GeoIPData
 	sessionID                            string
 	sessionID                            string
 	isFirstTunnelInSession               bool
 	isFirstTunnelInSession               bool
@@ -1246,6 +1281,7 @@ func newSshClient(
 	sshServer *sshServer,
 	sshServer *sshServer,
 	sshListener *sshListener,
 	sshListener *sshListener,
 	tunnelProtocol string,
 	tunnelProtocol string,
+	serverPacketManipulation string,
 	geoIPData GeoIPData) *sshClient {
 	geoIPData GeoIPData) *sshClient {
 
 
 	runCtx, stopRunning := context.WithCancel(context.Background())
 	runCtx, stopRunning := context.WithCancel(context.Background())
@@ -1255,18 +1291,19 @@ func newSshClient(
 	// unthrottled bytes during the initial protocol negotiation.
 	// unthrottled bytes during the initial protocol negotiation.
 
 
 	client := &sshClient{
 	client := &sshClient{
-		sshServer:              sshServer,
-		sshListener:            sshListener,
-		tunnelProtocol:         tunnelProtocol,
-		geoIPData:              geoIPData,
-		isFirstTunnelInSession: true,
-		tcpPortForwardLRU:      common.NewLRUConns(),
-		signalIssueSLOKs:       make(chan struct{}, 1),
-		runCtx:                 runCtx,
-		stopRunning:            stopRunning,
-		stopped:                make(chan struct{}),
-		sendAlertRequests:      make(chan protocol.AlertRequest, ALERT_REQUEST_QUEUE_BUFFER_SIZE),
-		sentAlertRequests:      make(map[protocol.AlertRequest]bool),
+		sshServer:                sshServer,
+		sshListener:              sshListener,
+		tunnelProtocol:           tunnelProtocol,
+		serverPacketManipulation: serverPacketManipulation,
+		geoIPData:                geoIPData,
+		isFirstTunnelInSession:   true,
+		tcpPortForwardLRU:        common.NewLRUConns(),
+		signalIssueSLOKs:         make(chan struct{}, 1),
+		runCtx:                   runCtx,
+		stopRunning:              stopRunning,
+		stopped:                  make(chan struct{}),
+		sendAlertRequests:        make(chan protocol.AlertRequest, ALERT_REQUEST_QUEUE_BUFFER_SIZE),
+		sentAlertRequests:        make(map[protocol.AlertRequest]bool),
 	}
 	}
 
 
 	client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
 	client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
@@ -2261,6 +2298,9 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
 	// unconditionally, overwriting any value from handshake.
 	// unconditionally, overwriting any value from handshake.
 	logFields["relay_protocol"] = sshClient.tunnelProtocol
 	logFields["relay_protocol"] = sshClient.tunnelProtocol
 
 
+	if sshClient.serverPacketManipulation != "" {
+		logFields["server_packet_manipulation"] = sshClient.serverPacketManipulation
+	}
 	if sshClient.sshListener.BPFProgramName != "" {
 	if sshClient.sshListener.BPFProgramName != "" {
 		logFields["server_bpf"] = sshClient.sshListener.BPFProgramName
 		logFields["server_bpf"] = sshClient.sshListener.BPFProgramName
 	}
 	}