|
|
@@ -210,8 +210,15 @@ func (server *TunnelServer) SetClientHandshakeState(
|
|
|
return server.sshServer.setClientHandshakeState(sessionID, state)
|
|
|
}
|
|
|
|
|
|
+// SetEstablishTunnels sets whether new tunnels may be established or not.
|
|
|
+// When not establishing, incoming connections are immediately closed.
|
|
|
+func (server *TunnelServer) SetEstablishTunnels(establish bool) {
|
|
|
+ server.sshServer.setEstablishTunnels(establish)
|
|
|
+}
|
|
|
+
|
|
|
type sshServer struct {
|
|
|
support *SupportServices
|
|
|
+ establishTunnels int32
|
|
|
shutdownBroadcast <-chan struct{}
|
|
|
sshHostKey ssh.Signer
|
|
|
clientsMutex sync.Mutex
|
|
|
@@ -237,6 +244,7 @@ func newSSHServer(
|
|
|
|
|
|
return &sshServer{
|
|
|
support: support,
|
|
|
+ establishTunnels: 1,
|
|
|
shutdownBroadcast: shutdownBroadcast,
|
|
|
sshHostKey: signer,
|
|
|
acceptedClientCounts: make(map[string]int64),
|
|
|
@@ -244,6 +252,17 @@ func newSSHServer(
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
+func (sshServer *sshServer) setEstablishTunnels(establish bool) {
|
|
|
+ establishFlag := int32(1)
|
|
|
+ if !establish {
|
|
|
+ establishFlag = 0
|
|
|
+ }
|
|
|
+ atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
|
|
|
+
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{"establish": establish}).Info("establishing tunnels")
|
|
|
+}
|
|
|
+
|
|
|
// runListener is intended to run an a goroutine; it blocks
|
|
|
// running a particular listener. If an unrecoverable error
|
|
|
// occurs, it will send the error to the listenerError channel.
|
|
|
@@ -253,6 +272,17 @@ func (sshServer *sshServer) runListener(
|
|
|
tunnelProtocol string) {
|
|
|
|
|
|
handleClient := func(clientConn net.Conn) {
|
|
|
+
|
|
|
+ // Note: establish tunnel limiter cannot simply stop TCP
|
|
|
+ // listeners in all cases (e.g., meek) since SSH tunnel can
|
|
|
+ // span multiple TCP connections.
|
|
|
+
|
|
|
+ if atomic.LoadInt32(&sshServer.establishTunnels) != 1 {
|
|
|
+ log.WithContext().Debug("not establishing tunnels")
|
|
|
+ clientConn.Close()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
// process each client connection concurrently
|
|
|
go sshServer.handleClient(tunnelProtocol, clientConn)
|
|
|
}
|