Просмотр исходного кода

Integrate packet tunnel server into psiphond

Rod Hynes 8 лет назад
Родитель
Сommit
2a76851430

+ 2 - 0
.travis.yml

@@ -12,6 +12,7 @@ script:
 - go test -race -v ./common
 - go test -race -v ./common/osl
 - go test -race -v ./common/protocol
+- go test -race -v ./common/tun
 - go test -race -v ./transferstats
 - go test -race -v ./server
 - go test -race -v ./server/psinet
@@ -19,6 +20,7 @@ script:
 - go test -v -covermode=count -coverprofile=common.coverprofile ./common
 - go test -v -covermode=count -coverprofile=osl.coverprofile ./common/osl
 - go test -v -covermode=count -coverprofile=protocol.coverprofile ./common/protocol
+- go test -v -covermode=count -coverprofile=tun.coverprofile ./common/tun
 - go test -v -covermode=count -coverprofile=transferstats.coverprofile ./transferstats
 - go test -v -covermode=count -coverprofile=server.coverprofile ./server
 - go test -v -covermode=count -coverprofile=psinet.coverprofile ./server/psinet

+ 6 - 2
Server/main.go

@@ -118,12 +118,16 @@ func main() {
 		serverIPaddress := generateServerIPaddress
 
 		if generateServerNetworkInterface != "" {
-			var err error
-			serverIPaddress, err = common.GetInterfaceIPAddress(generateServerNetworkInterface)
+			// TODO: IPv6 support
+			serverIPv4Address, _, err := common.GetInterfaceIPAddresses(generateServerNetworkInterface)
+			if err == nil && serverIPv4Address == nil {
+				err = fmt.Errorf("no IPv4 address for interface %s", generateServerNetworkInterface)
+			}
 			if err != nil {
 				fmt.Printf("generate failed: %s\n", err)
 				os.Exit(1)
 			}
+			serverIPaddress = serverIPv4Address.String()
 		}
 
 		tunnelProtocolPorts := make(map[string]int)

+ 4 - 4
psiphon/common/logger.go

@@ -33,10 +33,10 @@ type Logger interface {
 // LogContext is interface-compatible with the return values from
 // psiphon/server.ContextLogger.WithContext/WithContextFields.
 type LogContext interface {
-	Debug(message string)
-	Info(message string)
-	Warning(message string)
-	Error(message string)
+	Debug(args ...interface{})
+	Info(args ...interface{})
+	Warning(args ...interface{})
+	Error(args ...interface{})
 }
 
 // LogFields is type-compatible with psiphon/server.LogFields

+ 2 - 0
psiphon/common/protocol/protocol.go

@@ -54,6 +54,8 @@ const (
 
 	PSIPHON_SSH_API_PROTOCOL = "ssh"
 	PSIPHON_WEB_API_PROTOCOL = "web"
+
+	PACKET_TUNNEL_CHANNEL_TYPE = "tun@psiphon.ca"
 )
 
 var SupportedTunnelProtocols = []string{

+ 50 - 47
psiphon/common/tun/tun.go

@@ -141,7 +141,6 @@ import (
 )
 
 const (
-	CHANNEL_NAME                         = "tun@psiphon.ca"
 	DEFAULT_MTU                          = 1500
 	DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE = 64
 	DEFAULT_IDLE_SESSION_EXPIRY_SECONDS  = 300
@@ -300,12 +299,14 @@ func (server *Server) Stop() {
 	server.config.Logger.WithContext().Info("stopped")
 }
 
+type AllowedPortChecker func(upstreamIPAddress net.IP, port int) bool
+
 // ClientConnected handles new client connections, creating or resuming
 // a session and returns with client packet handlers running.
 //
 // sessionID is used to identify sessions for resumption.
 //
-// transportConn provides the channel for relaying packets to and from
+// transport provides the channel for relaying packets to and from
 // the client.
 //
 // checkAllowedTCPPortFunc/checkAllowedUDPPortFunc are callbacks used
@@ -324,8 +325,8 @@ func (server *Server) Stop() {
 // a new SSH client connection.)
 func (server *Server) ClientConnected(
 	sessionID string,
-	transportConn net.Conn,
-	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc func(port int) bool) error {
+	transport io.ReadWriteCloser,
+	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker) error {
 
 	// It's unusual to call both sync.WaitGroup.Add() _and_ Done() in the same
 	// goroutine. There's no other place to call Add() since ClientConnected is
@@ -408,7 +409,7 @@ func (server *Server) ClientConnected(
 		}
 	}
 
-	server.resumeSession(clientSession, NewChannel(transportConn, MTU))
+	server.resumeSession(clientSession, NewChannel(transport, MTU))
 
 	return nil
 }
@@ -921,8 +922,8 @@ type session struct {
 	assignedIPv6Address      net.IP
 	setOriginalIPv6Address   int32
 	originalIPv6Address      net.IP
-	checkAllowedTCPPortFunc  func(port int) bool
-	checkAllowedUDPPortFunc  func(port int) bool
+	checkAllowedTCPPortFunc  AllowedPortChecker
+	checkAllowedUDPPortFunc  AllowedPortChecker
 	downstreamPackets        chan []byte
 	freePackets              chan []byte
 	workers                  *sync.WaitGroup
@@ -1135,10 +1136,10 @@ type ClientConfig struct {
 	// should be obtained from the packet tunnel server.
 	MTU int
 
-	// TransportConn is an established transport channel
-	// that will be used to relay packets to and from a
-	// packet tunnel server.
-	TransportConn net.Conn
+	// Transport is an established transport channel that
+	// will be used to relay packets to and from a packet
+	// tunnel server.
+	Transport io.ReadWriteCloser
 
 	// TunFD specifies a file descriptor to use to read
 	// and write packets to be relayed to the client. When
@@ -1204,7 +1205,7 @@ func NewClient(config *ClientConfig) (*Client, error) {
 	return &Client{
 		config:      config,
 		device:      device,
-		channel:     NewChannel(config.TransportConn, config.MTU),
+		channel:     NewChannel(config.Transport, config.MTU),
 		metrics:     new(packetMetrics),
 		runContext:  runContext,
 		stopRunning: stopRunning,
@@ -1646,6 +1647,16 @@ func processPacket(
 		}
 	}
 
+	var upstreamIPAddress net.IP
+	if direction == packetDirectionServerUpstream {
+
+		upstreamIPAddress = destinationIPAddress
+
+	} else if direction == packetDirectionServerDownstream {
+
+		upstreamIPAddress = sourceIPAddress
+	}
+
 	// Enforce traffic rules (allowed TCP/UDP ports).
 
 	if direction == packetDirectionServerUpstream ||
@@ -1653,7 +1664,8 @@ func processPacket(
 
 		if protocol == internetProtocolTCP {
 
-			if !session.checkAllowedTCPPortFunc(int(destinationPort)) {
+			if !session.checkAllowedTCPPortFunc(
+				upstreamIPAddress, int(destinationPort)) {
 
 				metrics.rejectedPacket(direction, packetRejectTCPPort)
 				return false
@@ -1661,7 +1673,8 @@ func processPacket(
 
 		} else if protocol == internetProtocolUDP {
 
-			if !session.checkAllowedUDPPortFunc(int(destinationPort)) {
+			if !session.checkAllowedUDPPortFunc(
+				upstreamIPAddress, int(destinationPort)) {
 
 				metrics.rejectedPacket(direction, packetRejectUDPPort)
 				return false
@@ -1683,7 +1696,6 @@ func processPacket(
 
 	// Configure rewriting.
 
-	var upstreamIPAddress net.IP
 	var checksumAccumulator int32
 	var rewriteSourceIPAddress, rewriteDestinationIPAddress net.IP
 
@@ -1712,8 +1724,6 @@ func processPacket(
 			}
 		}
 
-		upstreamIPAddress = destinationIPAddress
-
 	} else if direction == packetDirectionServerDownstream {
 
 		// Destination address will be original source address.
@@ -1755,8 +1765,6 @@ func processPacket(
 				}
 			}
 		}
-
-		upstreamIPAddress = sourceIPAddress
 	}
 
 	// Apply rewrites. IP (v4 only) and TCP/UDP all have packet
@@ -1990,13 +1998,13 @@ func (device *Device) Close() error {
 	return device.deviceIO.Close()
 }
 
-// Channel manages packet transport over a communications
-// channel. Any net.Conn can provide transport. In psiphond,
-// the net.Conn will be an SSH channel. Channel I/O frames
+// Channel manages packet transport over a communications channel.
+// Any io.ReadWriteCloser can provide transport. In psiphond, the
+// io.ReadWriteCloser will be an SSH channel. Channel I/O frames
 // packets with a length header and uses static, preallocated
 // buffers to avoid GC churn.
 type Channel struct {
-	conn           net.Conn
+	transport      io.ReadWriteCloser
 	inboundBuffer  []byte
 	outboundBuffer []byte
 }
@@ -2008,9 +2016,9 @@ const (
 )
 
 // NewChannel initializes a new Channel.
-func NewChannel(conn net.Conn, MTU int) *Channel {
+func NewChannel(transport io.ReadWriteCloser, MTU int) *Channel {
 	return &Channel{
-		conn:           conn,
+		transport:      transport,
 		inboundBuffer:  make([]byte, channelHeaderSize+MTU),
 		outboundBuffer: make([]byte, channelHeaderSize+MTU),
 	}
@@ -2023,7 +2031,7 @@ func NewChannel(conn net.Conn, MTU int) *Channel {
 func (channel *Channel) ReadPacket() ([]byte, error) {
 
 	header := channel.inboundBuffer[0:channelHeaderSize]
-	_, err := io.ReadFull(channel.conn, header)
+	_, err := io.ReadFull(channel.transport, header)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -2034,7 +2042,7 @@ func (channel *Channel) ReadPacket() ([]byte, error) {
 	}
 
 	packet := channel.inboundBuffer[channelHeaderSize : channelHeaderSize+size]
-	_, err = io.ReadFull(channel.conn, packet)
+	_, err = io.ReadFull(channel.transport, packet)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
@@ -2046,36 +2054,31 @@ func (channel *Channel) ReadPacket() ([]byte, error) {
 // Concurrent calls to WritePacket are not supported.
 func (channel *Channel) WritePacket(packet []byte) error {
 
-	// Flow control assumed to be provided by the
-	// transport conn. In the case of SSH, the channel
-	// window size will determine whether the packet
-	// data is transmitted immediately or whether the
-	// conn.Write will block. When the channel window
-	// is full and conn.Write blocks, the sender's tun
-	// device will not be read (client case) or the send
-	// queue will fill (server case) and packets will
-	// be dropped. In this way, the channel window size
-	// will influence the TCP window size for tunneled
-	// traffic.
-
-	// Writes are not batched up but dispatched immediately.
-	// When the transport is an SSH channel, the overhead
-	// per tunneled packet includes:
+	// Flow control assumed to be provided by the transport. In the case
+	// of SSH, the channel window size will determine whether the packet
+	// data is transmitted immediately or whether the transport.Write will
+	// block. When the channel window is full and transport.Write blocks,
+	// the sender's tun device will not be read (client case) or the send
+	// queue will fill (server case) and packets will be dropped. In this
+	// way, the channel window size will influence the TCP window size for
+	// tunneled traffic.
+
+	// Writes are not batched up but dispatched immediately. When the
+	// transport is an SSH channel, the overhead per tunneled packet includes:
 	//
 	// - SSH_MSG_CHANNEL_DATA: 5 bytes (https://tools.ietf.org/html/rfc4254#section-5.2)
 	// - SSH packet: ~28 bytes (https://tools.ietf.org/html/rfc4253#section-5.3), with MAC
 	// - TCP/IP transport for SSH: 40 bytes for IPv4
 	//
-	// Also, when the transport in an SSH channel, batching
-	// of packets will naturally occur when the SSH channel
-	// window is full.
+	// Also, when the transport in an SSH channel, batching of packets will
+	// naturally occur when the SSH channel window is full.
 
 	// Assumes MTU <= 64K and len(packet) <= MTU
 
 	size := len(packet)
 	binary.BigEndian.PutUint16(channel.outboundBuffer, uint16(size))
 	copy(channel.outboundBuffer[channelHeaderSize:], packet)
-	_, err := channel.conn.Write(channel.outboundBuffer[0 : channelHeaderSize+size])
+	_, err := channel.transport.Write(channel.outboundBuffer[0 : channelHeaderSize+size])
 	if err != nil {
 		return common.ContextError(err)
 	}
@@ -2086,5 +2089,5 @@ func (channel *Channel) WritePacket(packet []byte) error {
 // Close interrupts any blocking Read/Write calls and
 // closes the channel transport.
 func (channel *Channel) Close() error {
-	return channel.conn.Close()
+	return channel.transport.Close()
 }

+ 10 - 10
psiphon/common/tun/tun_test.go

@@ -308,7 +308,7 @@ func (server *testServer) run() {
 				return
 			}
 
-			checkAllowedPortFunc := func(int) bool { return true }
+			checkAllowedPortFunc := func(net.IP, int) bool { return true }
 
 			server.tunServer.ClientConnected(
 				sessionID,
@@ -396,7 +396,7 @@ func startTestClient(
 		IPv6AddressCIDR:   "fd26:b6a6:4454:310a:0000:0000:0000:0001/64",
 		RouteDestinations: routeDestinations,
 		TunFD:             -1,
-		TransportConn:     unixConn,
+		Transport:         unixConn,
 		MTU:               MTU,
 	}
 
@@ -679,18 +679,18 @@ func (context *testLoggerContext) log(priority, message string) {
 	}
 }
 
-func (context *testLoggerContext) Debug(message string) {
-	context.log("DEBUG", message)
+func (context *testLoggerContext) Debug(args ...interface{}) {
+	context.log("DEBUG", fmt.Sprint(args...))
 }
 
-func (context *testLoggerContext) Info(message string) {
-	context.log("INFO", message)
+func (context *testLoggerContext) Info(args ...interface{}) {
+	context.log("INFO", fmt.Sprint(args...))
 }
 
-func (context *testLoggerContext) Warning(message string) {
-	context.log("WARNING", message)
+func (context *testLoggerContext) Warning(args ...interface{}) {
+	context.log("WARNING", fmt.Sprint(args...))
 }
 
-func (context *testLoggerContext) Error(message string) {
-	context.log("ERROR", message)
+func (context *testLoggerContext) Error(args ...interface{}) {
+	context.log("ERROR", fmt.Sprint(args...))
 }

+ 14 - 0
psiphon/server/config.go

@@ -256,6 +256,20 @@ type Config struct {
 	// OSLConfigFilename is the path of a file containing a JSON-encoded
 	// OSL Config, the OSL schemes to apply to Psiphon client tunnels.
 	OSLConfigFilename string
+
+	// RunPacketTunnel specifies whether to run a packet tunnel.
+	RunPacketTunnel bool
+
+	// PacketTunnelEgressInterface specifies tun.ServerConfig.EgressInterface.
+	PacketTunnelEgressInterface string
+
+	// PacketTunnelDownStreamPacketQueueSize specifies
+	// tun.ServerConfig.DownStreamPacketQueueSize.
+	PacketTunnelDownStreamPacketQueueSize int
+
+	// PacketTunnelSessionIdleExpirySeconds specifies
+	// tun.ServerConfig.SessionIdleExpirySeconds
+	PacketTunnelSessionIdleExpirySeconds int
 }
 
 // RunWebServer indicates whether to run a web server component.

+ 36 - 1
psiphon/server/dns.go

@@ -126,6 +126,16 @@ func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
 // the resolvers becomes unavailable.
 func (dns *DNSResolver) Get() net.IP {
 
+	dns.reloadWhenStale()
+
+	dns.ReloadableFile.RLock()
+	defer dns.ReloadableFile.RUnlock()
+
+	return dns.resolvers[rand.Intn(len(dns.resolvers))]
+}
+
+func (dns *DNSResolver) reloadWhenStale() {
+
 	// Every UDP DNS port forward frequently calls Get(), so this code
 	// is intended to minimize blocking. Most callers will hit just the
 	// atomic.LoadInt64 reload time check and the RLock (an atomic.AddInt32
@@ -159,11 +169,36 @@ func (dns *DNSResolver) Get() net.IP {
 			atomic.StoreInt32(&dns.isReloading, 0)
 		}
 	}
+}
+
+// GetAllIPv4 returns a list of all IPv4 DNS resolver addresses.
+// Cached values are updated if they're stale. If reloading fails,
+// the previous values are used.
+func (dns *DNSResolver) GetAllIPv4() []net.IP {
+	return dns.getAll(false)
+}
+
+// GetAllIPv6 returns a list of all IPv6 DNS resolver addresses.
+// Cached values are updated if they're stale. If reloading fails,
+// the previous values are used.
+func (dns *DNSResolver) GetAllIPv6() []net.IP {
+	return dns.getAll(true)
+}
+
+func (dns *DNSResolver) getAll(wantIPv6 bool) []net.IP {
+
+	dns.reloadWhenStale()
 
 	dns.ReloadableFile.RLock()
 	defer dns.ReloadableFile.RUnlock()
 
-	return dns.resolvers[rand.Intn(len(dns.resolvers))]
+	resolvers := make([]net.IP, 0)
+	for _, resolver := range dns.resolvers {
+		if (resolver.To4() == nil) == wantIPv6 {
+			resolvers = append(resolvers, resolver)
+		}
+	}
+	return resolvers
 }
 
 func parseResolveConf(fileContent []byte) ([]net.IP, error) {

+ 28 - 0
psiphon/server/log.go

@@ -118,6 +118,34 @@ func (logger *ContextLogger) LogPanicRecover(recoverValue interface{}, stack []b
 		})
 }
 
+type commonLogger struct {
+	contextLogger *ContextLogger
+}
+
+func (logger *commonLogger) WithContext() common.LogContext {
+	// Patch context to be correct parent
+	return logger.contextLogger.WithContext().WithField("context", common.GetParentContext())
+}
+
+func (logger *commonLogger) WithContextFields(fields common.LogFields) common.LogContext {
+	// Patch context to be correct parent
+	return logger.contextLogger.WithContextFields(LogFields(fields)).WithField("context", common.GetParentContext())
+}
+
+func (logger *commonLogger) LogMetric(metric string, fields common.LogFields) {
+	fields["event_name"] = metric
+	logger.contextLogger.LogRawFieldsWithTimestamp(LogFields(fields))
+}
+
+// CommonLogger wraps a ContextLogger instance with an interface
+// that conforms to common.Logger. This is used to make the ContextLogger
+// available to other packages that don't import the "server" package.
+func CommonLogger(contextLogger *ContextLogger) *commonLogger {
+	return &commonLogger{
+		contextLogger: contextLogger,
+	}
+}
+
 // NewLogWriter returns an io.PipeWriter that can be used to write
 // to the global logger. Caller must Close() the writer.
 func NewLogWriter() *io.PipeWriter {

+ 41 - 7
psiphon/server/services.go

@@ -36,6 +36,7 @@ import (
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 )
 
@@ -78,6 +79,38 @@ func RunServices(configJSON []byte) error {
 
 	supportServices.TunnelServer = tunnelServer
 
+	if config.RunPacketTunnel {
+
+		packetTunnelServer, err := tun.NewServer(&tun.ServerConfig{
+			Logger: CommonLogger(log),
+			GetDNSResolverIPv4Addresses: supportServices.DNSResolver.GetAllIPv4,
+			GetDNSResolverIPv6Addresses: supportServices.DNSResolver.GetAllIPv6,
+			EgressInterface:             config.PacketTunnelEgressInterface,
+			DownStreamPacketQueueSize:   config.PacketTunnelDownStreamPacketQueueSize,
+			SessionIdleExpirySeconds:    config.PacketTunnelSessionIdleExpirySeconds,
+		})
+		if err != nil {
+			log.WithContextFields(LogFields{"error": err}).Error("init packet tunnel failed")
+			return common.ContextError(err)
+		}
+
+		supportServices.PacketTunnelServer = packetTunnelServer
+	}
+
+	// After this point, errors should be delivered to the "errors" channel and
+	// orderly shutdown should flow through to the end of the function to ensure
+	// all workers are synchronously stopped.
+
+	if config.RunPacketTunnel {
+		supportServices.PacketTunnelServer.Start()
+		waitGroup.Add(1)
+		go func() {
+			defer waitGroup.Done()
+			<-shutdownBroadcast
+			supportServices.PacketTunnelServer.Stop()
+		}()
+	}
+
 	if config.RunLoadMonitor() {
 		waitGroup.Add(1)
 		go func() {
@@ -330,13 +363,14 @@ func logServerLoad(server *TunnelServer) {
 // components, which allows these data components to be refreshed
 // without restarting the server process.
 type SupportServices struct {
-	Config          *Config
-	TrafficRulesSet *TrafficRulesSet
-	OSLConfig       *osl.Config
-	PsinetDatabase  *psinet.Database
-	GeoIPService    *GeoIPService
-	DNSResolver     *DNSResolver
-	TunnelServer    *TunnelServer
+	Config             *Config
+	TrafficRulesSet    *TrafficRulesSet
+	OSLConfig          *osl.Config
+	PsinetDatabase     *psinet.Database
+	GeoIPService       *GeoIPService
+	DNSResolver        *DNSResolver
+	TunnelServer       *TunnelServer
+	PacketTunnelServer *tun.Server
 }
 
 // NewSupportServices initializes a new SupportServices.

+ 69 - 1
psiphon/server/tunnelServer.go

@@ -749,6 +749,7 @@ type sshClient struct {
 	supportsServerRequests               bool
 	handshakeState                       handshakeState
 	udpChannel                           ssh.Channel
+	packetTunnelChannel                  ssh.Channel
 	trafficRules                         TrafficRules
 	tcpTrafficState                      trafficState
 	udpTrafficState                      trafficState
@@ -1283,6 +1284,9 @@ func (sshClient *sshClient) runTunnel(
 
 	// Handle new channel (port forward) requests from the client.
 	//
+	// packet tunnel channels are handled by the packet tunnel server
+	// component. Each client may have at most one packet tunnel channel.
+	//
 	// udpgw client connections are dispatched immediately (clients use this for
 	// DNS, so it's essential to not block; and only one udpgw connection is
 	// retained at a time).
@@ -1292,6 +1296,39 @@ func (sshClient *sshClient) runTunnel(
 
 	for newChannel := range channels {
 
+		if newChannel.ChannelType() == protocol.PACKET_TUNNEL_CHANNEL_TYPE {
+
+			// Accept this channel immediately. This channel will replace any
+			// previously existing packet tunnel channel for this client.
+
+			packetTunnelChannel, requests, err := newChannel.Accept()
+			if err != nil {
+				log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
+				continue
+			}
+			go ssh.DiscardRequests(requests)
+
+			sshClient.setPacketTunnelChannel(packetTunnelChannel)
+
+			// PacketTunnelServer will run the client's packet tunnel. ClientDisconnected will
+			// be called by setPacketTunnelChannel: either if the client starts a new packet
+			// tunnel channel, or on exit of this function.
+
+			checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
+				return sshClient.isPortForwardPermitted(portForwardTypeTCP, false, upstreamIPAddress, port)
+			}
+
+			checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
+				return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
+			}
+
+			sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
+				sshClient.sessionID,
+				packetTunnelChannel,
+				checkAllowedTCPPortFunc,
+				checkAllowedUDPPortFunc)
+		}
+
 		if newChannel.ChannelType() != "direct-tcpip" {
 			sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
 			continue
@@ -1358,9 +1395,40 @@ func (sshClient *sshClient) runTunnel(
 	// Stop all other worker goroutines
 	sshClient.stopRunning()
 
+	// This calls PacketTunnelServer.ClientDisconnected,
+	// which stops packet tunnel workers.
+	sshClient.setPacketTunnelChannel(nil)
+
 	waitGroup.Wait()
 }
 
+// setPacketTunnelChannel sets the single packet tunnel channel
+// for this sshClient. Any existing packet tunnel channel is
+// closed and its underlying session idled.
+func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
+	sshClient.Lock()
+	if sshClient.packetTunnelChannel != nil {
+		sshClient.packetTunnelChannel.Close()
+		sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
+			sshClient.sessionID)
+	}
+	sshClient.packetTunnelChannel = channel
+	sshClient.Unlock()
+}
+
+// setUDPChannel sets the single UDP channel for this sshClient.
+// Each sshClient may have only one concurrent UDP channel. Each
+// UDP channel multiplexes many UDP port forwards via the udpgw
+// protocol. Any existing UDP channel is closed.
+func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+	sshClient.Lock()
+	if sshClient.udpChannel != nil {
+		sshClient.udpChannel.Close()
+	}
+	sshClient.udpChannel = channel
+	sshClient.Unlock()
+}
+
 func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
 
 	// Note: reporting duration based on last confirmed data transfer, which
@@ -1704,7 +1772,7 @@ func (sshClient *sshClient) isPortForwardPermitted(
 
 	// Disallow connection to loopback. This is a failsafe. The server
 	// should be run on a host with correctly configured firewall rules.
-	// And exception is made in the case of tranparent DNS forwarding,
+	// An exception is made in the case of tranparent DNS forwarding,
 	// where the remoteIP has been rewritten.
 	if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
 		return false

+ 0 - 13
psiphon/server/udp.go

@@ -34,19 +34,6 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 )
 
-// setUDPChannel sets the single UDP channel for this sshClient.
-// Each sshClient may have only one concurrent UDP channel. Each
-// UDP channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing UDP channel is closed.
-func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
-	sshClient.Lock()
-	if sshClient.udpChannel != nil {
-		sshClient.udpChannel.Close()
-	}
-	sshClient.udpChannel = channel
-	sshClient.Unlock()
-}
-
 // handleUDPChannel implements UDP port forwarding. A single UDP
 // SSH channel follows the udpgw protocol, which multiplexes many
 // UDP port forwards.