tunnelServer.go 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/Psiphon-Inc/crypto/ssh"
	cache "github.com/Psiphon-Inc/go-cache"
	"github.com/Psiphon-Inc/goarista/monotime"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
)
  40. const (
  41. SSH_AUTH_LOG_PERIOD = 30 * time.Minute
  42. SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
  43. SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
  44. SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
  45. SSH_TCP_PORT_FORWARD_QUEUE_SIZE = 1024
  46. SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
  47. SSH_SEND_OSL_RETRY_FACTOR = 2
  48. OSL_SESSION_CACHE_TTL = 5 * time.Minute
  49. )
  50. // TunnelServer is the main server that accepts Psiphon client
  51. // connections, via various obfuscation protocols, and provides
  52. // port forwarding (TCP and UDP) services to the Psiphon client.
  53. // At its core, TunnelServer is an SSH server. SSH is the base
  54. // protocol that provides port forward multiplexing, and transport
  55. // security. Layered on top of SSH, optionally, is Obfuscated SSH
  56. // and meek protocols, which provide further circumvention
  57. // capabilities.
  58. type TunnelServer struct {
  59. runWaitGroup *sync.WaitGroup
  60. listenerError chan error
  61. shutdownBroadcast <-chan struct{}
  62. sshServer *sshServer
  63. }
  64. // NewTunnelServer initializes a new tunnel server.
  65. func NewTunnelServer(
  66. support *SupportServices,
  67. shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
  68. sshServer, err := newSSHServer(support, shutdownBroadcast)
  69. if err != nil {
  70. return nil, common.ContextError(err)
  71. }
  72. return &TunnelServer{
  73. runWaitGroup: new(sync.WaitGroup),
  74. listenerError: make(chan error),
  75. shutdownBroadcast: shutdownBroadcast,
  76. sshServer: sshServer,
  77. }, nil
  78. }
  79. // Run runs the tunnel server; this function blocks while running a selection of
  80. // listeners that handle connection using various obfuscation protocols.
  81. //
  82. // Run listens on each designated tunnel port and spawns new goroutines to handle
  83. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  84. // clients is maintained, and when halting all clients are cleanly shutdown.
  85. //
  86. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  87. // authentication, and then looping on client new channel requests. "direct-tcpip"
  88. // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
  89. // config parameter is configured, UDP port forwards over a TCP stream, following
  90. // the udpgw protocol, are handled.
  91. //
  92. // A new goroutine is spawned to handle each port forward for each client. Each port
  93. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  94. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  95. // client shuts down.
  96. //
  97. // Note: client handler goroutines may still be shutting down after Run() returns. See
  98. // comment in sshClient.stop(). TODO: fully synchronized shutdown.
  99. func (server *TunnelServer) Run() error {
  100. type sshListener struct {
  101. net.Listener
  102. localAddress string
  103. tunnelProtocol string
  104. }
  105. // TODO: should TunnelServer hold its own support pointer?
  106. support := server.sshServer.support
  107. // First bind all listeners; once all are successful,
  108. // start accepting connections on each.
  109. var listeners []*sshListener
  110. for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
  111. localAddress := fmt.Sprintf(
  112. "%s:%d", support.Config.ServerIPAddress, listenPort)
  113. listener, err := net.Listen("tcp", localAddress)
  114. if err != nil {
  115. for _, existingListener := range listeners {
  116. existingListener.Listener.Close()
  117. }
  118. return common.ContextError(err)
  119. }
  120. log.WithContextFields(
  121. LogFields{
  122. "localAddress": localAddress,
  123. "tunnelProtocol": tunnelProtocol,
  124. }).Info("listening")
  125. listeners = append(
  126. listeners,
  127. &sshListener{
  128. Listener: listener,
  129. localAddress: localAddress,
  130. tunnelProtocol: tunnelProtocol,
  131. })
  132. }
  133. for _, listener := range listeners {
  134. server.runWaitGroup.Add(1)
  135. go func(listener *sshListener) {
  136. defer server.runWaitGroup.Done()
  137. log.WithContextFields(
  138. LogFields{
  139. "localAddress": listener.localAddress,
  140. "tunnelProtocol": listener.tunnelProtocol,
  141. }).Info("running")
  142. server.sshServer.runListener(
  143. listener.Listener,
  144. server.listenerError,
  145. listener.tunnelProtocol)
  146. log.WithContextFields(
  147. LogFields{
  148. "localAddress": listener.localAddress,
  149. "tunnelProtocol": listener.tunnelProtocol,
  150. }).Info("stopped")
  151. }(listener)
  152. }
  153. var err error
  154. select {
  155. case <-server.shutdownBroadcast:
  156. case err = <-server.listenerError:
  157. }
  158. for _, listener := range listeners {
  159. listener.Close()
  160. }
  161. server.sshServer.stopClients()
  162. server.runWaitGroup.Wait()
  163. log.WithContext().Info("stopped")
  164. return err
  165. }
  166. // GetLoadStats returns load stats for the tunnel server. The stats are
  167. // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
  168. // include current connected client count, total number of current port
  169. // forwards.
  170. func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
  171. return server.sshServer.getLoadStats()
  172. }
  173. // ResetAllClientTrafficRules resets all established client traffic rules
  174. // to use the latest config and client properties. Any existing traffic
  175. // rule state is lost, including throttling state.
  176. func (server *TunnelServer) ResetAllClientTrafficRules() {
  177. server.sshServer.resetAllClientTrafficRules()
  178. }
  179. // ResetAllClientOSLConfigs resets all established client OSL state to use
  180. // the latest OSL config. Any existing OSL state is lost, including partial
  181. // progress towards SLOKs.
  182. func (server *TunnelServer) ResetAllClientOSLConfigs() {
  183. server.sshServer.resetAllClientOSLConfigs()
  184. }
  185. // SetClientHandshakeState sets the handshake state -- that it completed and
  186. // what paramaters were passed -- in sshClient. This state is used for allowing
  187. // port forwards and for future traffic rule selection. SetClientHandshakeState
  188. // also triggers an immediate traffic rule re-selection, as the rules selected
  189. // upon tunnel establishment may no longer apply now that handshake values are
  190. // set.
  191. func (server *TunnelServer) SetClientHandshakeState(
  192. sessionID string, state handshakeState) error {
  193. return server.sshServer.setClientHandshakeState(sessionID, state)
  194. }
  195. // SetEstablishTunnels sets whether new tunnels may be established or not.
  196. // When not establishing, incoming connections are immediately closed.
  197. func (server *TunnelServer) SetEstablishTunnels(establish bool) {
  198. server.sshServer.setEstablishTunnels(establish)
  199. }
  200. // GetEstablishTunnels returns whether new tunnels may be established or not.
  201. func (server *TunnelServer) GetEstablishTunnels() bool {
  202. return server.sshServer.getEstablishTunnels()
  203. }
  204. type sshServer struct {
  205. // Note: 64-bit ints used with atomic operations are at placed
  206. // at the start of struct to ensure 64-bit alignment.
  207. // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
  208. lastAuthLog int64
  209. authFailedCount int64
  210. support *SupportServices
  211. establishTunnels int32
  212. shutdownBroadcast <-chan struct{}
  213. sshHostKey ssh.Signer
  214. clientsMutex sync.Mutex
  215. stoppingClients bool
  216. acceptedClientCounts map[string]map[string]int64
  217. clients map[string]*sshClient
  218. oslSessionCacheMutex sync.Mutex
  219. oslSessionCache *cache.Cache
  220. }
  221. func newSSHServer(
  222. support *SupportServices,
  223. shutdownBroadcast <-chan struct{}) (*sshServer, error) {
  224. privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
  225. if err != nil {
  226. return nil, common.ContextError(err)
  227. }
  228. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  229. signer, err := ssh.NewSignerFromKey(privateKey)
  230. if err != nil {
  231. return nil, common.ContextError(err)
  232. }
  233. // The OSL session cache temporarily retains OSL seed state
  234. // progress for disconnected clients. This enables clients
  235. // that disconnect and immediately reconnect to the same
  236. // server to resume their OSL progress. Cached progress
  237. // is referenced by session ID and is retained for
  238. // OSL_SESSION_CACHE_TTL after disconnect.
  239. //
  240. // Note: session IDs are assumed to be unpredictable. If a
  241. // rogue client could guess the session ID of another client,
  242. // it could resume its OSL progress and, if the OSL config
  243. // were known, infer some activity.
  244. oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
  245. return &sshServer{
  246. support: support,
  247. establishTunnels: 1,
  248. shutdownBroadcast: shutdownBroadcast,
  249. sshHostKey: signer,
  250. acceptedClientCounts: make(map[string]map[string]int64),
  251. clients: make(map[string]*sshClient),
  252. oslSessionCache: oslSessionCache,
  253. }, nil
  254. }
  255. func (sshServer *sshServer) setEstablishTunnels(establish bool) {
  256. // Do nothing when the setting is already correct. This avoids
  257. // spurious log messages when setEstablishTunnels is called
  258. // periodically with the same setting.
  259. if establish == sshServer.getEstablishTunnels() {
  260. return
  261. }
  262. establishFlag := int32(1)
  263. if !establish {
  264. establishFlag = 0
  265. }
  266. atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
  267. log.WithContextFields(
  268. LogFields{"establish": establish}).Info("establishing tunnels")
  269. }
  270. func (sshServer *sshServer) getEstablishTunnels() bool {
  271. return atomic.LoadInt32(&sshServer.establishTunnels) == 1
  272. }
  273. // runListener is intended to run an a goroutine; it blocks
  274. // running a particular listener. If an unrecoverable error
  275. // occurs, it will send the error to the listenerError channel.
  276. func (sshServer *sshServer) runListener(
  277. listener net.Listener,
  278. listenerError chan<- error,
  279. listenerTunnelProtocol string) {
  280. runningProtocols := make([]string, 0)
  281. for tunnelProtocol, _ := range sshServer.support.Config.TunnelProtocolPorts {
  282. runningProtocols = append(runningProtocols, tunnelProtocol)
  283. }
  284. handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
  285. // Note: establish tunnel limiter cannot simply stop TCP
  286. // listeners in all cases (e.g., meek) since SSH tunnel can
  287. // span multiple TCP connections.
  288. if !sshServer.getEstablishTunnels() {
  289. log.WithContext().Debug("not establishing tunnels")
  290. clientConn.Close()
  291. return
  292. }
  293. // The tunnelProtocol passed to handleClient is used for stats,
  294. // throttling, etc. When the tunnel protocol can be determined
  295. // unambiguously from the listening port, use that protocol and
  296. // don't use any client-declared value. Only use the client's
  297. // value, if present, in special cases where the listenting port
  298. // cannot distinguish the protocol.
  299. tunnelProtocol := listenerTunnelProtocol
  300. if clientTunnelProtocol != "" &&
  301. protocol.UseClientTunnelProtocol(
  302. clientTunnelProtocol, runningProtocols) {
  303. tunnelProtocol = clientTunnelProtocol
  304. }
  305. // process each client connection concurrently
  306. go sshServer.handleClient(tunnelProtocol, clientConn)
  307. }
  308. // Note: when exiting due to a unrecoverable error, be sure
  309. // to try to send the error to listenerError so that the outer
  310. // TunnelServer.Run will properly shut down instead of remaining
  311. // running.
  312. if protocol.TunnelProtocolUsesMeekHTTP(listenerTunnelProtocol) ||
  313. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol) {
  314. meekServer, err := NewMeekServer(
  315. sshServer.support,
  316. listener,
  317. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol),
  318. protocol.TunnelProtocolUsesObfuscatedSessionTickets(listenerTunnelProtocol),
  319. handleClient,
  320. sshServer.shutdownBroadcast)
  321. if err != nil {
  322. select {
  323. case listenerError <- common.ContextError(err):
  324. default:
  325. }
  326. return
  327. }
  328. meekServer.Run()
  329. } else {
  330. for {
  331. conn, err := listener.Accept()
  332. select {
  333. case <-sshServer.shutdownBroadcast:
  334. if err == nil {
  335. conn.Close()
  336. }
  337. return
  338. default:
  339. }
  340. if err != nil {
  341. if e, ok := err.(net.Error); ok && e.Temporary() {
  342. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  343. // Temporary error, keep running
  344. continue
  345. }
  346. select {
  347. case listenerError <- common.ContextError(err):
  348. default:
  349. }
  350. return
  351. }
  352. handleClient("", conn)
  353. }
  354. }
  355. }
  356. // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
  357. // is for tracking the number of connections.
  358. func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
  359. sshServer.clientsMutex.Lock()
  360. defer sshServer.clientsMutex.Unlock()
  361. if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
  362. sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
  363. }
  364. sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
  365. }
  366. func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
  367. sshServer.clientsMutex.Lock()
  368. defer sshServer.clientsMutex.Unlock()
  369. sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
  370. }
  371. // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
  372. // for tracking the number of fully established clients and for maintaining a list of running
  373. // clients (for stopping at shutdown time).
  374. func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
  375. sshServer.clientsMutex.Lock()
  376. if sshServer.stoppingClients {
  377. sshServer.clientsMutex.Unlock()
  378. return false
  379. }
  380. // In the case of a duplicate client sessionID, the previous client is closed.
  381. // - Well-behaved clients generate pick a random sessionID that should be
  382. // unique (won't accidentally conflict) and hard to guess (can't be targetted
  383. // by a malicious client).
  384. // - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
  385. // and resestablished. In this case, when the same server is selected, this logic
  386. // will be hit; closing the old, dangling client is desirable.
  387. // - Multi-tunnel clients should not normally use one server for multiple tunnels.
  388. existingClient := sshServer.clients[client.sessionID]
  389. sshServer.clients[client.sessionID] = client
  390. sshServer.clientsMutex.Unlock()
  391. // Call stop() outside the mutex to avoid deadlock.
  392. if existingClient != nil {
  393. existingClient.stop()
  394. log.WithContext().Info(
  395. "stopped existing client with duplicate session ID")
  396. }
  397. return true
  398. }
  399. func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
  400. sshServer.clientsMutex.Lock()
  401. registeredClient := sshServer.clients[client.sessionID]
  402. // registeredClient will differ from client when client
  403. // is the existingClient terminated in registerEstablishedClient.
  404. // In that case, registeredClient remains connected, and
  405. // the sshServer.clients entry should be retained.
  406. if registeredClient == client {
  407. delete(sshServer.clients, client.sessionID)
  408. }
  409. sshServer.clientsMutex.Unlock()
  410. // Call stop() outside the mutex to avoid deadlock.
  411. client.stop()
  412. }
  413. type ProtocolStats map[string]map[string]int64
  414. type RegionStats map[string]map[string]map[string]int64
  415. func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
  416. sshServer.clientsMutex.Lock()
  417. defer sshServer.clientsMutex.Unlock()
  418. // Explicitly populate with zeros to ensure 0 counts in log messages
  419. zeroStats := func() map[string]int64 {
  420. stats := make(map[string]int64)
  421. stats["accepted_clients"] = 0
  422. stats["established_clients"] = 0
  423. stats["dialing_tcp_port_forwards"] = 0
  424. stats["tcp_port_forwards"] = 0
  425. stats["total_tcp_port_forwards"] = 0
  426. stats["udp_port_forwards"] = 0
  427. stats["total_udp_port_forwards"] = 0
  428. stats["tcp_port_forward_dialed_count"] = 0
  429. stats["tcp_port_forward_dialed_duration"] = 0
  430. stats["tcp_port_forward_failed_count"] = 0
  431. stats["tcp_port_forward_failed_duration"] = 0
  432. stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
  433. return stats
  434. }
  435. zeroProtocolStats := func() map[string]map[string]int64 {
  436. stats := make(map[string]map[string]int64)
  437. stats["ALL"] = zeroStats()
  438. for tunnelProtocol, _ := range sshServer.support.Config.TunnelProtocolPorts {
  439. stats[tunnelProtocol] = zeroStats()
  440. }
  441. return stats
  442. }
  443. // [<protocol or ALL>][<stat name>] -> count
  444. protocolStats := zeroProtocolStats()
  445. // [<region][<protocol or ALL>][<stat name>] -> count
  446. regionStats := make(RegionStats)
  447. // Note: as currently tracked/counted, each established client is also an accepted client
  448. for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
  449. for region, acceptedClientCount := range regionAcceptedClientCounts {
  450. if acceptedClientCount > 0 {
  451. if regionStats[region] == nil {
  452. regionStats[region] = zeroProtocolStats()
  453. }
  454. protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
  455. protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
  456. regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
  457. regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
  458. }
  459. }
  460. }
  461. for _, client := range sshServer.clients {
  462. client.Lock()
  463. tunnelProtocol := client.tunnelProtocol
  464. region := client.geoIPData.Country
  465. if regionStats[region] == nil {
  466. regionStats[region] = zeroProtocolStats()
  467. }
  468. stats := []map[string]int64{
  469. protocolStats["ALL"],
  470. protocolStats[tunnelProtocol],
  471. regionStats[region]["ALL"],
  472. regionStats[region][tunnelProtocol]}
  473. for _, stat := range stats {
  474. stat["established_clients"] += 1
  475. // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
  476. stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
  477. stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
  478. stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
  479. // client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
  480. stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
  481. stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
  482. stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.tcpPortForwardDialedCount
  483. stat["tcp_port_forward_dialed_duration"] +=
  484. int64(client.qualityMetrics.tcpPortForwardDialedDuration / time.Millisecond)
  485. stat["tcp_port_forward_failed_count"] += client.qualityMetrics.tcpPortForwardFailedCount
  486. stat["tcp_port_forward_failed_duration"] +=
  487. int64(client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond)
  488. stat["tcp_port_forward_rejected_dialing_limit_count"] +=
  489. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
  490. }
  491. client.qualityMetrics.tcpPortForwardDialedCount = 0
  492. client.qualityMetrics.tcpPortForwardDialedDuration = 0
  493. client.qualityMetrics.tcpPortForwardFailedCount = 0
  494. client.qualityMetrics.tcpPortForwardFailedDuration = 0
  495. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
  496. client.Unlock()
  497. }
  498. return protocolStats, regionStats
  499. }
  500. func (sshServer *sshServer) resetAllClientTrafficRules() {
  501. sshServer.clientsMutex.Lock()
  502. clients := make(map[string]*sshClient)
  503. for sessionID, client := range sshServer.clients {
  504. clients[sessionID] = client
  505. }
  506. sshServer.clientsMutex.Unlock()
  507. for _, client := range clients {
  508. client.setTrafficRules()
  509. }
  510. }
  511. func (sshServer *sshServer) resetAllClientOSLConfigs() {
  512. // Flush cached seed state. This has the same effect
  513. // and same limitations as calling setOSLConfig for
  514. // currently connected clients -- all progress is lost.
  515. sshServer.oslSessionCacheMutex.Lock()
  516. sshServer.oslSessionCache.Flush()
  517. sshServer.oslSessionCacheMutex.Unlock()
  518. sshServer.clientsMutex.Lock()
  519. clients := make(map[string]*sshClient)
  520. for sessionID, client := range sshServer.clients {
  521. clients[sessionID] = client
  522. }
  523. sshServer.clientsMutex.Unlock()
  524. for _, client := range clients {
  525. client.setOSLConfig()
  526. }
  527. }
  528. func (sshServer *sshServer) setClientHandshakeState(
  529. sessionID string, state handshakeState) error {
  530. sshServer.clientsMutex.Lock()
  531. client := sshServer.clients[sessionID]
  532. sshServer.clientsMutex.Unlock()
  533. if client == nil {
  534. return common.ContextError(errors.New("unknown session ID"))
  535. }
  536. err := client.setHandshakeState(state)
  537. if err != nil {
  538. return common.ContextError(err)
  539. }
  540. return nil
  541. }
  542. func (sshServer *sshServer) stopClients() {
  543. sshServer.clientsMutex.Lock()
  544. sshServer.stoppingClients = true
  545. clients := sshServer.clients
  546. sshServer.clients = make(map[string]*sshClient)
  547. sshServer.clientsMutex.Unlock()
  548. for _, client := range clients {
  549. client.stop()
  550. }
  551. }
  552. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  553. geoIPData := sshServer.support.GeoIPService.Lookup(
  554. common.IPAddressFromAddr(clientConn.RemoteAddr()))
  555. sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
  556. defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
  557. sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
  558. sshClient.run(clientConn)
  559. }
  560. func (sshServer *sshServer) monitorPortForwardDialError(err error) {
  561. // "err" is the error returned from a failed TCP or UDP port
  562. // forward dial. Certain system error codes indicate low resource
  563. // conditions: insufficient file descriptors, ephemeral ports, or
  564. // memory. For these cases, log an alert.
  565. // TODO: also temporarily suspend new clients
  566. // Note: don't log net.OpError.Error() as the full error string
  567. // may contain client destination addresses.
  568. opErr, ok := err.(*net.OpError)
  569. if ok {
  570. if opErr.Err == syscall.EADDRNOTAVAIL ||
  571. opErr.Err == syscall.EAGAIN ||
  572. opErr.Err == syscall.ENOMEM ||
  573. opErr.Err == syscall.EMFILE ||
  574. opErr.Err == syscall.ENFILE {
  575. log.WithContextFields(
  576. LogFields{"error": opErr.Err}).Error(
  577. "port forward dial failed due to unavailable resource")
  578. }
  579. }
  580. }
  // sshClient holds the state for one client connection, from the initial
  // SSH handshake through the lifetime of the tunnel. The embedded mutex
  // guards fields that may be reset by concurrent goroutines while the
  // tunnel runs; setTrafficRules and setOSLConfig document this for
  // trafficRules and oslClientSeedState respectively.
  type sshClient struct {
  	sync.Mutex
  	// sshServer is the server that accepted this client.
  	sshServer *sshServer
  	// tunnelProtocol is the protocol the connection was accepted on; run()
  	// uses it to decide whether to deobfuscate the SSH stream.
  	tunnelProtocol string
  	// sshConn, activityConn and throttledConn are assigned in run() once
  	// the SSH handshake has succeeded.
  	sshConn ssh.Conn
  	activityConn *common.ActivityMonitoredConn
  	throttledConn *common.ThrottledConn
  	// geoIPData is the GeoIP lookup result for the client's remote address.
  	geoIPData GeoIPData
  	// sessionID and supportsServerRequests are taken from the password
  	// payload in passwordCallback.
  	sessionID string
  	supportsServerRequests bool
  	// handshakeState is recorded once by setHandshakeState.
  	handshakeState handshakeState
  	// NOTE(review): udpChannel is not referenced in this part of the file;
  	// presumably the current udpgw channel -- confirm.
  	udpChannel ssh.Channel
  	// trafficRules is recomputed by setTrafficRules.
  	trafficRules TrafficRules
  	// Per-transport counters, reported by logTunnel.
  	tcpTrafficState trafficState
  	udpTrafficState trafficState
  	qualityMetrics qualityMetrics
  	// NOTE(review): presumably tracks established TCP port forwards so the
  	// least recently used one can be closed at the limit -- confirm.
  	tcpPortForwardLRU *common.LRUConns
  	// oslClientSeedState is the client's OSL progress; run() moves it to
  	// the server's session cache when the tunnel ends.
  	oslClientSeedState *osl.ClientSeedState
  	// signalIssueSLOKs wakes runOSLSender when there may be SLOKs to send.
  	signalIssueSLOKs chan struct{}
  	// runContext is cancelled via stopRunning when runTunnel winds down.
  	runContext context.Context
  	stopRunning context.CancelFunc
  	// NOTE(review): presumably the cancel func that runTunnel installs via
  	// setTCPPortForwardDialingAvailableSignal while the TCP port forward
  	// manager waits for a dialing slot -- confirm against that setter.
  	tcpPortForwardDialingAvailableSignal context.CancelFunc
  }
  // trafficState holds port forward counters for one transport. Each
  // sshClient keeps one instance for TCP and one for UDP, and logTunnel
  // reports most of these values in the tunnel summary log.
  type trafficState struct {
  	// Bytes relayed, reported as bytes_up_* / bytes_down_*.
  	bytesUp int64
  	bytesDown int64
  	// Port forwards currently in the dialing phase, and the peak of that
  	// count (see the port forward lifecycle notes in runTunnel).
  	concurrentDialingPortForwardCount int64
  	peakConcurrentDialingPortForwardCount int64
  	// Established port forwards currently open, and the peak of that count.
  	concurrentPortForwardCount int64
  	peakConcurrentPortForwardCount int64
  	// Total port forwards, reported as total_port_forward_count_*.
  	totalPortForwardCount int64
  	// Created in newSshClient with its own mutex.
  	// NOTE(review): waiters and signallers are outside this part of the
  	// file; presumably used to wait for a free port forward slot -- confirm.
  	availablePortForwardCond *sync.Cond
  }
  // qualityMetrics records upstream TCP dial attempts and
  // elapsed time. Elapsed time includes the full TCP handshake
  // and, in aggregate, is a measure of the quality of the
  // upstream link. These stats are recorded by each sshClient
  // and then reported and reset in sshServer.getLoadStats().
  type qualityMetrics struct {
  	// Count and cumulative elapsed time of TCP port forward dials.
  	// NOTE(review): the "Dialed" pair presumably covers successful dials
  	// and the "Failed" pair failed dials -- confirm where they are updated.
  	tcpPortForwardDialedCount int64
  	tcpPortForwardDialedDuration time.Duration
  	tcpPortForwardFailedCount int64
  	tcpPortForwardFailedDuration time.Duration
  	// NOTE(review): presumably incremented by
  	// updateQualityMetricsWithRejectedDialingLimit, which runTunnel calls
  	// when a TCP port forward is rejected because the queue is full or its
  	// dial timeout expired before dialing -- confirm.
  	tcpPortForwardRejectedDialingLimitCount int64
  }
  // handshakeState is the per-client record of the handshake API request.
  // It is stored once by sshClient.setHandshakeState and is passed to
  // traffic rule selection in setTrafficRules.
  type handshakeState struct {
  	// completed is checked by setHandshakeState to reject a second
  	// handshake; logTunnel reports it as "handshake_completed".
  	completed bool
  	// NOTE(review): presumably identifies the API transport the handshake
  	// arrived on -- not read in this part of the file; confirm.
  	apiProtocol string
  	// apiParams are the handshake request parameters; they feed the tunnel
  	// summary log fields and the "propagation_channel_id" lookup in
  	// setOSLConfig.
  	apiParams requestJSONObject
  }
  631. func newSshClient(
  632. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
  633. runContext, stopRunning := context.WithCancel(context.Background())
  634. client := &sshClient{
  635. sshServer: sshServer,
  636. tunnelProtocol: tunnelProtocol,
  637. geoIPData: geoIPData,
  638. tcpPortForwardLRU: common.NewLRUConns(),
  639. signalIssueSLOKs: make(chan struct{}, 1),
  640. runContext: runContext,
  641. stopRunning: stopRunning,
  642. }
  643. client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  644. client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  645. return client
  646. }
  647. func (sshClient *sshClient) run(clientConn net.Conn) {
  648. // Some conns report additional metrics
  649. metricsSource, isMetricsSource := clientConn.(MetricsSource)
  650. // Set initial traffic rules, pre-handshake, based on currently known info.
  651. sshClient.setTrafficRules()
  652. // Wrap the base client connection with an ActivityMonitoredConn which will
  653. // terminate the connection if no data is received before the deadline. This
  654. // timeout is in effect for the entire duration of the SSH connection. Clients
  655. // must actively use the connection or send SSH keep alive requests to keep
  656. // the connection active. Writes are not considered reliable activity indicators
  657. // due to buffering.
  658. activityConn, err := common.NewActivityMonitoredConn(
  659. clientConn,
  660. SSH_CONNECTION_READ_DEADLINE,
  661. false,
  662. nil,
  663. nil)
  664. if err != nil {
  665. clientConn.Close()
  666. log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  667. return
  668. }
  669. clientConn = activityConn
  670. // Further wrap the connection in a rate limiting ThrottledConn.
  671. throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
  672. clientConn = throttledConn
  673. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  674. // respect shutdownBroadcast and implement a specific handshake timeout.
  675. // The timeout is to reclaim network resources in case the handshake takes
  676. // too long.
  677. type sshNewServerConnResult struct {
  678. conn net.Conn
  679. sshConn *ssh.ServerConn
  680. channels <-chan ssh.NewChannel
  681. requests <-chan *ssh.Request
  682. err error
  683. }
  684. resultChannel := make(chan *sshNewServerConnResult, 2)
  685. if SSH_HANDSHAKE_TIMEOUT > 0 {
  686. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  687. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  688. })
  689. }
  690. go func(conn net.Conn) {
  691. sshServerConfig := &ssh.ServerConfig{
  692. PasswordCallback: sshClient.passwordCallback,
  693. AuthLogCallback: sshClient.authLogCallback,
  694. ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion,
  695. }
  696. sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
  697. result := &sshNewServerConnResult{}
  698. // Wrap the connection in an SSH deobfuscator when required.
  699. if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
  700. // Note: NewObfuscatedSshConn blocks on network I/O
  701. // TODO: ensure this won't block shutdown
  702. conn, result.err = common.NewObfuscatedSshConn(
  703. common.OBFUSCATION_CONN_MODE_SERVER,
  704. conn,
  705. sshClient.sshServer.support.Config.ObfuscatedSSHKey)
  706. if result.err != nil {
  707. result.err = common.ContextError(result.err)
  708. }
  709. }
  710. if result.err == nil {
  711. result.sshConn, result.channels, result.requests, result.err =
  712. ssh.NewServerConn(conn, sshServerConfig)
  713. }
  714. resultChannel <- result
  715. }(clientConn)
  716. var result *sshNewServerConnResult
  717. select {
  718. case result = <-resultChannel:
  719. case <-sshClient.sshServer.shutdownBroadcast:
  720. // Close() will interrupt an ongoing handshake
  721. // TODO: wait for goroutine to exit before returning?
  722. clientConn.Close()
  723. return
  724. }
  725. if result.err != nil {
  726. clientConn.Close()
  727. // This is a Debug log due to noise. The handshake often fails due to I/O
  728. // errors as clients frequently interrupt connections in progress when
  729. // client-side load balancing completes a connection to a different server.
  730. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
  731. return
  732. }
  733. sshClient.Lock()
  734. sshClient.sshConn = result.sshConn
  735. sshClient.activityConn = activityConn
  736. sshClient.throttledConn = throttledConn
  737. sshClient.Unlock()
  738. if !sshClient.sshServer.registerEstablishedClient(sshClient) {
  739. clientConn.Close()
  740. log.WithContext().Warning("register failed")
  741. return
  742. }
  743. sshClient.runTunnel(result.channels, result.requests)
  744. // Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
  745. // which also closes underlying transport Conn.
  746. sshClient.sshServer.unregisterEstablishedClient(sshClient)
  747. var additionalMetrics LogFields
  748. if isMetricsSource {
  749. additionalMetrics = metricsSource.GetMetrics()
  750. }
  751. sshClient.logTunnel(additionalMetrics)
  752. // Transfer OSL seed state -- the OSL progress -- from the closing
  753. // client to the session cache so the client can resume its progress
  754. // if it reconnects to this same server.
  755. // Note: following setOSLConfig order of locking.
  756. sshClient.Lock()
  757. if sshClient.oslClientSeedState != nil {
  758. sshClient.sshServer.oslSessionCacheMutex.Lock()
  759. sshClient.oslClientSeedState.Hibernate()
  760. sshClient.sshServer.oslSessionCache.Set(
  761. sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
  762. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  763. sshClient.oslClientSeedState = nil
  764. }
  765. sshClient.Unlock()
  766. // Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
  767. // final status requests, the lifetime of cached GeoIP records exceeds the
  768. // lifetime of the sshClient.
  769. sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
  770. }
  771. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  772. expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
  773. expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
  774. var sshPasswordPayload protocol.SSHPasswordPayload
  775. err := json.Unmarshal(password, &sshPasswordPayload)
  776. if err != nil {
  777. // Backwards compatibility case: instead of a JSON payload, older clients
  778. // send the hex encoded session ID prepended to the SSH password.
  779. // Note: there's an even older case where clients don't send any session ID,
  780. // but that's no longer supported.
  781. if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
  782. sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
  783. sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:len(password)])
  784. } else {
  785. return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  786. }
  787. }
  788. if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
  789. len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
  790. return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
  791. }
  792. userOk := (subtle.ConstantTimeCompare(
  793. []byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
  794. passwordOk := (subtle.ConstantTimeCompare(
  795. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
  796. if !userOk || !passwordOk {
  797. return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  798. }
  799. sessionID := sshPasswordPayload.SessionId
  800. supportsServerRequests := common.Contains(
  801. sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
  802. sshClient.Lock()
  803. sshClient.sessionID = sessionID
  804. sshClient.supportsServerRequests = supportsServerRequests
  805. geoIPData := sshClient.geoIPData
  806. sshClient.Unlock()
  807. // Store the GeoIP data associated with the session ID. This makes
  808. // the GeoIP data available to the web server for web API requests.
  809. // A cache that's distinct from the sshClient record is used to allow
  810. // for or post-tunnel final status requests.
  811. // If the client is reconnecting with the same session ID, this call
  812. // will undo the expiry set by MarkSessionCacheToExpire.
  813. sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
  814. return nil, nil
  815. }
  816. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  817. if err != nil {
  818. if method == "none" && err.Error() == "no auth passed yet" {
  819. // In this case, the callback invocation is noise from auth negotiation
  820. return
  821. }
  822. // Note: here we previously logged messages for fail2ban to act on. This is no longer
  823. // done as the complexity outweighs the benefits.
  824. //
  825. // - The SSH credential is not secret -- it's in the server entry. Attackers targetting
  826. // the server likely already have the credential. On the other hand, random scanning and
  827. // brute forcing is mitigated with high entropy random passwords, rate limiting
  828. // (implemented on the host via iptables), and limited capabilities (the SSH session can
  829. // only port forward).
  830. //
  831. // - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
  832. // an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
  833. // The X-Forwarded-For header cant be used instead as it may be forged and used to get IPs
  834. // deliberately blocked; and in any case fail2ban adds iptables rules which can only block
  835. // by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
  836. //
  837. // Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
  838. // not every authentication failure is logged. A summary log is emitted periodically to
  839. // retain some record of this activity in case this is relevent to, e.g., a performance
  840. // investigation.
  841. atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
  842. lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
  843. if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
  844. now := int64(monotime.Now())
  845. if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
  846. count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
  847. log.WithContextFields(
  848. LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
  849. }
  850. }
  851. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
  852. } else {
  853. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  854. }
  855. }
  856. // stop signals the ssh connection to shutdown. After sshConn() returns,
  857. // the connection has terminated but sshClient.run() may still be
  858. // running and in the process of exiting.
  859. func (sshClient *sshClient) stop() {
  860. sshClient.sshConn.Close()
  861. sshClient.sshConn.Wait()
  862. }
  863. // runTunnel handles/dispatches new channels and new requests from the client.
  864. // When the SSH client connection closes, both the channels and requests channels
  865. // will close and runTunnel will exit.
  866. func (sshClient *sshClient) runTunnel(
  867. channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
  868. waitGroup := new(sync.WaitGroup)
  869. // Start client SSH API request handler
  870. waitGroup.Add(1)
  871. go func() {
  872. defer waitGroup.Done()
  873. for request := range requests {
  874. // Requests are processed serially; API responses must be sent in request order.
  875. var responsePayload []byte
  876. var err error
  877. if request.Type == "keepalive@openssh.com" {
  878. // Keepalive requests have an empty response.
  879. } else {
  880. // All other requests are assumed to be API requests.
  881. responsePayload, err = sshAPIRequestHandler(
  882. sshClient.sshServer.support,
  883. sshClient.geoIPData,
  884. request.Type,
  885. request.Payload)
  886. }
  887. if err == nil {
  888. err = request.Reply(true, responsePayload)
  889. } else {
  890. log.WithContextFields(LogFields{"error": err}).Warning("request failed")
  891. err = request.Reply(false, nil)
  892. }
  893. if err != nil {
  894. log.WithContextFields(LogFields{"error": err}).Warning("response failed")
  895. }
  896. }
  897. }()
  898. // Start OSL sender
  899. if sshClient.supportsServerRequests {
  900. waitGroup.Add(1)
  901. go func() {
  902. defer waitGroup.Done()
  903. sshClient.runOSLSender()
  904. }()
  905. }
  906. // Lifecycle of a TCP port forward:
  907. //
  908. // 1. A "direct-tcpip" SSH request is received from the client.
  909. //
  910. // A new TCP port forward request is enqueued. The queue delivers TCP port
  911. // forward requests to the TCP port forward manager, which enforces the TCP
  912. // port forward dial limit.
  913. //
  914. // Enqueuing new requests allows for reading further SSH requests from the
  915. // client without blocking when the dial limit is hit; this is to permit new
  916. // UDP/udpgw port forwards to be restablished without delay. The maximum size
  917. // of the queue enforces a hard cap on resources consumed by a client in the
  918. // pre-dial phase. When the queue is full, new TCP port forwards are
  919. // immediately rejected.
  920. //
  921. // 2. The TCP port forward manager dequeues the request.
  922. //
  923. // The manager calls dialingTCPPortForward(), which increments
  924. // concurrentDialingPortForwardCount, and calls
  925. // isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
  926. // count.
  927. //
  928. // The manager enforces the concurrent TCP dial limit: when at the limit, the
  929. // manager blocks waiting for the number of dials to drop below the limit before
  930. // dispatching the request to handleTCPPortForward(), which will run in its own
  931. // goroutine and will dial and relay the port forward.
  932. //
  933. // The block delays the current request and also halts dequeuing of subsequent
  934. // requests and could ultimately cause requests to be immediately rejected if
  935. // the queue fills. These actions are intended to apply back pressure when
  936. // upstream network resources are impaired.
  937. //
  938. // The time spent in the queue is deducted from the port forward's dial timeout.
  939. // The time spent blocking while at the dial limit is similarly deducted from
  940. // the dial timeout. If the dial timeout has expired before the dial begins, the
  941. // port forward is rejected and a stat is recorded.
  942. //
  943. // 3. handleTCPPortForward() performs the port forward dial and relaying.
  944. //
  945. // a. Dial the target, using the dial timeout remaining after queue and blocking
  946. // time is deducted.
  947. //
  948. // b. If the dial fails, call abortedTCPPortForward() to decrement
  949. // concurrentDialingPortForwardCount, freeing up a dial slot.
  950. //
  951. // c. If the dial succeeds, call establishedPortForward(), which decrements
  952. // concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
  953. // the "established" port forward count.
  954. //
  955. // d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
  956. // concurrentPortForwardCount, the number of _established_ TCP port forwards.
  957. // If the limit is exceeded, the LRU established TCP port forward is closed and
  958. // the newly established TCP port forward proceeds. This LRU logic allows some
  959. // dangling resource consumption (e.g., TIME_WAIT) while providing a better
  960. // experience for clients.
  961. //
  962. // e. Relay data.
  963. //
  964. // f. Call closedPortForward() which decrements concurrentPortForwardCount and
  965. // records bytes transferred.
  966. // Start the TCP port forward manager
  967. type newTCPPortForward struct {
  968. enqueueTime monotime.Time
  969. hostToConnect string
  970. portToConnect int
  971. newChannel ssh.NewChannel
  972. }
  973. // The queue size is set to the traffic rules (MaxTCPPortForwardCount +
  974. // MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
  975. // limits per client; when that value is not set, a default is used.
  976. // A limitation: this queue size is set once and doesn't change, for this client,
  977. // when traffic rules are reloaded.
  978. queueSize := sshClient.getTCPPortForwardQueueSize()
  979. if queueSize == 0 {
  980. queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
  981. }
  982. newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
  983. waitGroup.Add(1)
  984. go func() {
  985. defer waitGroup.Done()
  986. for newPortForward := range newTCPPortForwards {
  987. remainingDialTimeout :=
  988. time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
  989. monotime.Since(newPortForward.enqueueTime)
  990. if remainingDialTimeout <= 0 {
  991. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  992. sshClient.rejectNewChannel(
  993. newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out in queue")
  994. continue
  995. }
  996. // Reserve a TCP dialing slot.
  997. //
  998. // TOCTOU note: important to increment counts _before_ checking limits; otherwise,
  999. // the client could potentially consume excess resources by initiating many port
  1000. // forwards concurrently.
  1001. sshClient.dialingTCPPortForward()
  1002. // When max dials are in progress, wait up to remainingDialTimeout for dialing
  1003. // to become available. This blocks all dequeing.
  1004. if sshClient.isTCPDialingPortForwardLimitExceeded() {
  1005. blockStartTime := monotime.Now()
  1006. ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1007. sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
  1008. <-ctx.Done()
  1009. sshClient.setTCPPortForwardDialingAvailableSignal(nil)
  1010. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1011. remainingDialTimeout -= monotime.Since(blockStartTime)
  1012. }
  1013. if remainingDialTimeout <= 0 {
  1014. // Release the dialing slot here since handleTCPChannel() won't be called.
  1015. sshClient.abortedTCPPortForward()
  1016. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1017. sshClient.rejectNewChannel(
  1018. newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out before dialing")
  1019. continue
  1020. }
  1021. // Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
  1022. // handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
  1023. // will deal with remainingDialTimeout <= 0.
  1024. waitGroup.Add(1)
  1025. go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
  1026. defer waitGroup.Done()
  1027. sshClient.handleTCPChannel(
  1028. remainingDialTimeout,
  1029. newPortForward.hostToConnect,
  1030. newPortForward.portToConnect,
  1031. newPortForward.newChannel)
  1032. }(remainingDialTimeout, newPortForward)
  1033. }
  1034. }()
  1035. // Handle new channel (port forward) requests from the client.
  1036. //
  1037. // udpgw client connections are dispatched immediately (clients use this for
  1038. // DNS, so it's essential to not block; and only one udpgw connection is
  1039. // retained at a time).
  1040. //
  1041. // All other TCP port forwards are dispatched via the TCP port forward
  1042. // manager queue.
  1043. for newChannel := range channels {
  1044. if newChannel.ChannelType() != "direct-tcpip" {
  1045. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  1046. continue
  1047. }
  1048. // http://tools.ietf.org/html/rfc4254#section-7.2
  1049. var directTcpipExtraData struct {
  1050. HostToConnect string
  1051. PortToConnect uint32
  1052. OriginatorIPAddress string
  1053. OriginatorPort uint32
  1054. }
  1055. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  1056. if err != nil {
  1057. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  1058. continue
  1059. }
  1060. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  1061. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  1062. isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
  1063. sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
  1064. net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
  1065. if isUDPChannel {
  1066. // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
  1067. // own worker goroutine.
  1068. waitGroup.Add(1)
  1069. go func(channel ssh.NewChannel) {
  1070. defer waitGroup.Done()
  1071. sshClient.handleUDPChannel(channel)
  1072. }(newChannel)
  1073. } else {
  1074. // Dispatch via TCP port forward manager. When the queue is full, the channel
  1075. // is immediately rejected.
  1076. tcpPortForward := &newTCPPortForward{
  1077. enqueueTime: monotime.Now(),
  1078. hostToConnect: directTcpipExtraData.HostToConnect,
  1079. portToConnect: int(directTcpipExtraData.PortToConnect),
  1080. newChannel: newChannel,
  1081. }
  1082. select {
  1083. case newTCPPortForwards <- tcpPortForward:
  1084. default:
  1085. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1086. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "TCP port forward dial queue full")
  1087. }
  1088. }
  1089. }
  1090. // The channel loop is interrupted by a client
  1091. // disconnect or by calling sshClient.stop().
  1092. // Stop the TCP port forward manager
  1093. close(newTCPPortForwards)
  1094. // Stop all other worker goroutines
  1095. sshClient.stopRunning()
  1096. waitGroup.Wait()
  1097. }
  1098. func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
  1099. // Note: reporting duration based on last confirmed data transfer, which
  1100. // is reads for sshClient.activityConn.GetActiveDuration(), and not
  1101. // connection closing is important for protocols such as meek. For
  1102. // meek, the connection remains open until the HTTP session expires,
  1103. // which may be some time after the tunnel has closed. (The meek
  1104. // protocol has no allowance for signalling payload EOF, and even if
  1105. // it did the client may not have the opportunity to send a final
  1106. // request with an EOF flag set.)
  1107. sshClient.Lock()
  1108. logFields := getRequestLogFields(
  1109. sshClient.sshServer.support,
  1110. "server_tunnel",
  1111. sshClient.geoIPData,
  1112. sshClient.handshakeState.apiParams,
  1113. baseRequestParams)
  1114. logFields["handshake_completed"] = sshClient.handshakeState.completed
  1115. logFields["start_time"] = sshClient.activityConn.GetStartTime()
  1116. logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
  1117. logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
  1118. logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
  1119. logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
  1120. logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
  1121. logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
  1122. logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
  1123. logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
  1124. // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
  1125. logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
  1126. logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
  1127. // Merge in additional metrics from the optional metrics source
  1128. if additionalMetrics != nil {
  1129. for name, value := range additionalMetrics {
  1130. // Don't overwrite any basic fields
  1131. if logFields[name] == nil {
  1132. logFields[name] = value
  1133. }
  1134. }
  1135. }
  1136. sshClient.Unlock()
  1137. log.LogRawFieldsWithTimestamp(logFields)
  1138. }
  1139. func (sshClient *sshClient) runOSLSender() {
  1140. for {
  1141. // Await a signal that there are SLOKs to send
  1142. // TODO: use reflect.SelectCase, and optionally await timer here?
  1143. select {
  1144. case <-sshClient.signalIssueSLOKs:
  1145. case <-sshClient.runContext.Done():
  1146. return
  1147. }
  1148. retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
  1149. for {
  1150. err := sshClient.sendOSLRequest()
  1151. if err == nil {
  1152. break
  1153. }
  1154. log.WithContextFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
  1155. // If the request failed, retry after a delay (with exponential backoff)
  1156. // or when signaled that there are additional SLOKs to send
  1157. retryTimer := time.NewTimer(retryDelay)
  1158. select {
  1159. case <-retryTimer.C:
  1160. case <-sshClient.signalIssueSLOKs:
  1161. case <-sshClient.runContext.Done():
  1162. retryTimer.Stop()
  1163. return
  1164. }
  1165. retryTimer.Stop()
  1166. retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
  1167. }
  1168. }
  1169. }
  1170. // sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
  1171. // generate a payload, and send an OSL request to the client when
  1172. // there are new SLOKs in the payload.
  1173. func (sshClient *sshClient) sendOSLRequest() error {
  1174. seedPayload := sshClient.getOSLSeedPayload()
  1175. // Don't send when no SLOKs. This will happen when signalIssueSLOKs
  1176. // is received but no new SLOKs are issued.
  1177. if len(seedPayload.SLOKs) == 0 {
  1178. return nil
  1179. }
  1180. oslRequest := protocol.OSLRequest{
  1181. SeedPayload: seedPayload,
  1182. }
  1183. requestPayload, err := json.Marshal(oslRequest)
  1184. if err != nil {
  1185. return common.ContextError(err)
  1186. }
  1187. ok, _, err := sshClient.sshConn.SendRequest(
  1188. protocol.PSIPHON_API_OSL_REQUEST_NAME,
  1189. true,
  1190. requestPayload)
  1191. if err != nil {
  1192. return common.ContextError(err)
  1193. }
  1194. if !ok {
  1195. return common.ContextError(errors.New("client rejected request"))
  1196. }
  1197. sshClient.clearOSLSeedPayload()
  1198. return nil
  1199. }
  1200. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, logMessage string) {
  1201. // Note: Debug level, as logMessage may contain user traffic destination address information
  1202. log.WithContextFields(
  1203. LogFields{
  1204. "channelType": newChannel.ChannelType(),
  1205. "logMessage": logMessage,
  1206. "rejectReason": reason.String(),
  1207. }).Debug("reject new channel")
  1208. // Note: logMessage is internal, for logging only; just the RejectionReason is sent to the client
  1209. newChannel.Reject(reason, reason.String())
  1210. }
  1211. // setHandshakeState records that a client has completed a handshake API request.
  1212. // Some parameters from the handshake request may be used in future traffic rule
  1213. // selection. Port forwards are disallowed until a handshake is complete. The
  1214. // handshake parameters are included in the session summary log recorded in
  1215. // sshClient.stop().
  1216. func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
  1217. sshClient.Lock()
  1218. completed := sshClient.handshakeState.completed
  1219. if !completed {
  1220. sshClient.handshakeState = state
  1221. }
  1222. sshClient.Unlock()
  1223. // Client must only perform one handshake
  1224. if completed {
  1225. return common.ContextError(errors.New("handshake already completed"))
  1226. }
  1227. sshClient.setTrafficRules()
  1228. sshClient.setOSLConfig()
  1229. return nil
  1230. }
  1231. // setTrafficRules resets the client's traffic rules based on the latest server config
  1232. // and client properties. As sshClient.trafficRules may be reset by a concurrent
  1233. // goroutine, trafficRules must only be accessed within the sshClient mutex.
  1234. func (sshClient *sshClient) setTrafficRules() {
  1235. sshClient.Lock()
  1236. defer sshClient.Unlock()
  1237. sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
  1238. sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
  1239. if sshClient.throttledConn != nil {
  1240. // Any existing throttling state is reset.
  1241. sshClient.throttledConn.SetLimits(
  1242. sshClient.trafficRules.RateLimits.CommonRateLimits())
  1243. }
  1244. }
  1245. // setOSLConfig resets the client's OSL seed state based on the latest OSL config
  1246. // As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
  1247. // oslClientSeedState must only be accessed within the sshClient mutex.
  1248. func (sshClient *sshClient) setOSLConfig() {
  1249. sshClient.Lock()
  1250. defer sshClient.Unlock()
  1251. propagationChannelID, err := getStringRequestParam(
  1252. sshClient.handshakeState.apiParams, "propagation_channel_id")
  1253. if err != nil {
  1254. // This should not fail as long as client has sent valid handshake
  1255. return
  1256. }
  1257. // Use a cached seed state if one is found for the client's
  1258. // session ID. This enables resuming progress made in a previous
  1259. // tunnel.
  1260. // Note: go-cache is already concurency safe; the additional mutex
  1261. // is necessary to guarantee that Get/Delete is atomic; although in
  1262. // practice no two concurrent clients should ever supply the same
  1263. // session ID.
  1264. sshClient.sshServer.oslSessionCacheMutex.Lock()
  1265. oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
  1266. if found {
  1267. sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
  1268. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1269. sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
  1270. sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
  1271. return
  1272. }
  1273. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1274. // Two limitations when setOSLConfig() is invoked due to an
  1275. // OSL config hot reload:
  1276. //
  1277. // 1. any partial progress towards SLOKs is lost.
  1278. //
  1279. // 2. all existing osl.ClientSeedPortForwards for existing
  1280. // port forwards will not send progress to the new client
  1281. // seed state.
  1282. sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
  1283. sshClient.geoIPData.Country,
  1284. propagationChannelID,
  1285. sshClient.signalIssueSLOKs)
  1286. }
  1287. // newClientSeedPortForward will return nil when no seeding is
  1288. // associated with the specified ipAddress.
  1289. func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
  1290. sshClient.Lock()
  1291. defer sshClient.Unlock()
  1292. // Will not be initialized before handshake.
  1293. if sshClient.oslClientSeedState == nil {
  1294. return nil
  1295. }
  1296. return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
  1297. }
  1298. // getOSLSeedPayload returns a payload containing all seeded SLOKs for
  1299. // this client's session.
  1300. func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
  1301. sshClient.Lock()
  1302. defer sshClient.Unlock()
  1303. // Will not be initialized before handshake.
  1304. if sshClient.oslClientSeedState == nil {
  1305. return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
  1306. }
  1307. return sshClient.oslClientSeedState.GetSeedPayload()
  1308. }
  1309. func (sshClient *sshClient) clearOSLSeedPayload() {
  1310. sshClient.Lock()
  1311. defer sshClient.Unlock()
  1312. sshClient.oslClientSeedState.ClearSeedPayload()
  1313. }
  1314. func (sshClient *sshClient) rateLimits() common.RateLimits {
  1315. sshClient.Lock()
  1316. defer sshClient.Unlock()
  1317. return sshClient.trafficRules.RateLimits.CommonRateLimits()
  1318. }
  1319. func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
  1320. sshClient.Lock()
  1321. defer sshClient.Unlock()
  1322. return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
  1323. }
  1324. func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
  1325. sshClient.Lock()
  1326. defer sshClient.Unlock()
  1327. return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
  1328. }
  1329. func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
  1330. sshClient.Lock()
  1331. defer sshClient.Unlock()
  1332. sshClient.tcpPortForwardDialingAvailableSignal = signal
  1333. }
  // Port forward types. These values select which per-protocol traffic rules
  // and traffic state apply to a port forward.
  const (
  	// portForwardTypeTCP selects the TCP rules and TCP traffic state.
  	portForwardTypeTCP = iota

  	// portForwardTypeUDP selects the UDP rules and UDP traffic state.
  	portForwardTypeUDP

  	// portForwardTypeTransparentDNS is presumably for transparent DNS
  	// forwarding; the helpers in this part of the file treat every
  	// non-TCP type as UDP.
  	portForwardTypeTransparentDNS
  )
  1339. func (sshClient *sshClient) isPortForwardPermitted(
  1340. portForwardType int,
  1341. isTransparentDNSForwarding bool,
  1342. remoteIP net.IP,
  1343. port int) bool {
  1344. sshClient.Lock()
  1345. defer sshClient.Unlock()
  1346. if !sshClient.handshakeState.completed {
  1347. return false
  1348. }
  1349. // Disallow connection to loopback. This is a failsafe. The server
  1350. // should be run on a host with correctly configured firewall rules.
  1351. // And exception is made in the case of tranparent DNS forwarding,
  1352. // where the remoteIP has been rewritten.
  1353. if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
  1354. return false
  1355. }
  1356. var allowPorts []int
  1357. if portForwardType == portForwardTypeTCP {
  1358. allowPorts = sshClient.trafficRules.AllowTCPPorts
  1359. } else {
  1360. allowPorts = sshClient.trafficRules.AllowUDPPorts
  1361. }
  1362. if len(allowPorts) == 0 {
  1363. return true
  1364. }
  1365. // TODO: faster lookup?
  1366. if len(allowPorts) > 0 {
  1367. for _, allowPort := range allowPorts {
  1368. if port == allowPort {
  1369. return true
  1370. }
  1371. }
  1372. }
  1373. for _, subnet := range sshClient.trafficRules.AllowSubnets {
  1374. // Note: ignoring error as config has been validated
  1375. _, network, _ := net.ParseCIDR(subnet)
  1376. if network.Contains(remoteIP) {
  1377. return true
  1378. }
  1379. }
  1380. return false
  1381. }
  1382. func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
  1383. sshClient.Lock()
  1384. defer sshClient.Unlock()
  1385. state := &sshClient.tcpTrafficState
  1386. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1387. if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
  1388. return true
  1389. }
  1390. return false
  1391. }
  1392. func (sshClient *sshClient) isAtPortForwardLimit(
  1393. portForwardType int) bool {
  1394. sshClient.Lock()
  1395. defer sshClient.Unlock()
  1396. var max int
  1397. var state *trafficState
  1398. if portForwardType == portForwardTypeTCP {
  1399. max = *sshClient.trafficRules.MaxTCPPortForwardCount
  1400. state = &sshClient.tcpTrafficState
  1401. } else {
  1402. max = *sshClient.trafficRules.MaxUDPPortForwardCount
  1403. state = &sshClient.udpTrafficState
  1404. }
  1405. if max > 0 && state.concurrentPortForwardCount >= int64(max) {
  1406. return true
  1407. }
  1408. return false
  1409. }
  1410. func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
  1411. sshClient.Lock()
  1412. defer sshClient.Unlock()
  1413. return *sshClient.trafficRules.MaxTCPPortForwardCount +
  1414. *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1415. }
  1416. func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
  1417. sshClient.Lock()
  1418. defer sshClient.Unlock()
  1419. return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
  1420. }
  1421. func (sshClient *sshClient) dialingTCPPortForward() {
  1422. sshClient.Lock()
  1423. defer sshClient.Unlock()
  1424. state := &sshClient.tcpTrafficState
  1425. state.concurrentDialingPortForwardCount += 1
  1426. if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
  1427. state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
  1428. }
  1429. }
  1430. func (sshClient *sshClient) abortedTCPPortForward() {
  1431. sshClient.Lock()
  1432. defer sshClient.Unlock()
  1433. sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
  1434. }
  1435. // establishedPortForward increments the concurrent port
  1436. // forward counter. closedPortForward decrements it, so it
  1437. // must always be called for each establishedPortForward
  1438. // call.
  1439. //
  1440. // When at the limit of established port forwards, the LRU
  1441. // existing port forward is closed to make way for the newly
  1442. // established one. There can be a minor delay as, in addition
  1443. // to calling Close() on the port forward net.Conn,
  1444. // establishedPortForward waits for the LRU's closedPortForward()
  1445. // call which will decrement the concurrent counter. This
  1446. // ensures all resources associated with the LRU (socket,
  1447. // goroutine) are released or will very soon be released before
  1448. // proceeding.
  1449. func (sshClient *sshClient) establishedPortForward(
  1450. portForwardType int, portForwardLRU *common.LRUConns) {
  1451. var state *trafficState
  1452. if portForwardType == portForwardTypeTCP {
  1453. state = &sshClient.tcpTrafficState
  1454. } else {
  1455. state = &sshClient.udpTrafficState
  1456. }
  1457. // When the maximum number of port forwards is already
  1458. // established, close the LRU. CloseOldest will call
  1459. // Close on the port forward net.Conn. Both TCP and
  1460. // UDP port forwards have handler goroutines that may
  1461. // be blocked calling Read on the net.Conn. Close will
  1462. // eventually interrupt the Read and cause the handlers
  1463. // to exit, but not immediately. So the following logic
  1464. // waits for a LRU handler to be interrupted and signal
  1465. // availability.
  1466. //
  1467. // Note: the port forward limit can change via a traffic
  1468. // rules hot reload; the condition variable handles this
  1469. // case whereas a channel-based semaphore would not.
  1470. if sshClient.isAtPortForwardLimit(portForwardType) {
  1471. portForwardLRU.CloseOldest()
  1472. log.WithContext().Debug("closed LRU port forward")
  1473. state.availablePortForwardCond.L.Lock()
  1474. for sshClient.isAtPortForwardLimit(portForwardType) {
  1475. state.availablePortForwardCond.Wait()
  1476. }
  1477. state.availablePortForwardCond.L.Unlock()
  1478. }
  1479. sshClient.Lock()
  1480. if portForwardType == portForwardTypeTCP {
  1481. // Assumes TCP port forwards called dialingTCPPortForward
  1482. state.concurrentDialingPortForwardCount -= 1
  1483. if sshClient.tcpPortForwardDialingAvailableSignal != nil {
  1484. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1485. if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
  1486. sshClient.tcpPortForwardDialingAvailableSignal()
  1487. }
  1488. }
  1489. }
  1490. state.concurrentPortForwardCount += 1
  1491. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  1492. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  1493. }
  1494. state.totalPortForwardCount += 1
  1495. sshClient.Unlock()
  1496. }
  1497. func (sshClient *sshClient) closedPortForward(
  1498. portForwardType int, bytesUp, bytesDown int64) {
  1499. sshClient.Lock()
  1500. var state *trafficState
  1501. if portForwardType == portForwardTypeTCP {
  1502. state = &sshClient.tcpTrafficState
  1503. } else {
  1504. state = &sshClient.udpTrafficState
  1505. }
  1506. state.concurrentPortForwardCount -= 1
  1507. state.bytesUp += bytesUp
  1508. state.bytesDown += bytesDown
  1509. sshClient.Unlock()
  1510. // Signal any goroutine waiting in establishedPortForward
  1511. // that an established port forward slot is available.
  1512. state.availablePortForwardCond.Signal()
  1513. }
  1514. func (sshClient *sshClient) updateQualityMetricsWithDialResult(
  1515. tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
  1516. sshClient.Lock()
  1517. defer sshClient.Unlock()
  1518. if tcpPortForwardDialSuccess {
  1519. sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
  1520. sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
  1521. } else {
  1522. sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
  1523. sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
  1524. }
  1525. }
  1526. func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
  1527. sshClient.Lock()
  1528. defer sshClient.Unlock()
  1529. sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
  1530. }
  1531. func (sshClient *sshClient) handleTCPChannel(
  1532. remainingDialTimeout time.Duration,
  1533. hostToConnect string,
  1534. portToConnect int,
  1535. newChannel ssh.NewChannel) {
  1536. // Assumptions:
  1537. // - sshClient.dialingTCPPortForward() has been called
  1538. // - remainingDialTimeout > 0
  1539. established := false
  1540. defer func() {
  1541. if !established {
  1542. sshClient.abortedTCPPortForward()
  1543. }
  1544. }()
  1545. // Transparently redirect web API request connections.
  1546. isWebServerPortForward := false
  1547. config := sshClient.sshServer.support.Config
  1548. if config.WebServerPortForwardAddress != "" {
  1549. destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
  1550. if destination == config.WebServerPortForwardAddress {
  1551. isWebServerPortForward = true
  1552. if config.WebServerPortForwardRedirectAddress != "" {
  1553. // Note: redirect format is validated when config is loaded
  1554. host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
  1555. port, _ := strconv.Atoi(portStr)
  1556. hostToConnect = host
  1557. portToConnect = port
  1558. }
  1559. }
  1560. }
  1561. // Dial the remote address.
  1562. //
  1563. // Hostname resolution is performed explicitly, as a seperate step, as the target IP
  1564. // address is used for traffic rules (AllowSubnets) and OSL seed progress.
  1565. //
  1566. // Contexts are used for cancellation (via sshClient.runContext, which is cancelled
  1567. // when the client is stopping) and timeouts.
  1568. dialStartTime := monotime.Now()
  1569. log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
  1570. ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1571. IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
  1572. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1573. // TODO: shuffle list to try other IPs?
  1574. // TODO: IPv6 support
  1575. var IP net.IP
  1576. for _, ip := range IPs {
  1577. if ip.IP.To4() != nil {
  1578. IP = ip.IP
  1579. break
  1580. }
  1581. }
  1582. if err == nil && IP == nil {
  1583. err = errors.New("no IP address")
  1584. }
  1585. resolveElapsedTime := monotime.Since(dialStartTime)
  1586. if err != nil {
  1587. // Record a port forward failure
  1588. sshClient.updateQualityMetricsWithDialResult(true, resolveElapsedTime)
  1589. sshClient.rejectNewChannel(
  1590. newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
  1591. return
  1592. }
  1593. remainingDialTimeout -= resolveElapsedTime
  1594. if remainingDialTimeout <= 0 {
  1595. sshClient.rejectNewChannel(
  1596. newChannel, ssh.Prohibited, "TCP port forward timed out resolving")
  1597. return
  1598. }
  1599. // Enforce traffic rules, using the resolved IP address.
  1600. if !isWebServerPortForward &&
  1601. !sshClient.isPortForwardPermitted(
  1602. portForwardTypeTCP,
  1603. false,
  1604. IP,
  1605. portToConnect) {
  1606. // Note: not recording a port forward failure in this case
  1607. sshClient.rejectNewChannel(
  1608. newChannel, ssh.Prohibited, "port forward not permitted")
  1609. return
  1610. }
  1611. // TCP dial.
  1612. remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
  1613. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  1614. ctx, cancelCtx = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1615. fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
  1616. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1617. // Record port forward success or failure
  1618. sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
  1619. if err != nil {
  1620. // Monitor for low resource error conditions
  1621. sshClient.sshServer.monitorPortForwardDialError(err)
  1622. sshClient.rejectNewChannel(
  1623. newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
  1624. return
  1625. }
  1626. // The upstream TCP port forward connection has been established. Schedule
  1627. // some cleanup and notify the SSH client that the channel is accepted.
  1628. defer fwdConn.Close()
  1629. fwdChannel, requests, err := newChannel.Accept()
  1630. if err != nil {
  1631. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  1632. return
  1633. }
  1634. go ssh.DiscardRequests(requests)
  1635. defer fwdChannel.Close()
  1636. // Release the dialing slot and acquire an established slot.
  1637. //
  1638. // establishedPortForward increments the concurrent TCP port
  1639. // forward counter and closes the LRU existing TCP port forward
  1640. // when already at the limit.
  1641. //
  1642. // Known limitations:
  1643. //
  1644. // - Closed LRU TCP sockets will enter the TIME_WAIT state,
  1645. // continuing to consume some resources.
  1646. sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
  1647. // "established = true" cancels the deferred abortedTCPPortForward()
  1648. established = true
  1649. // TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
  1650. var bytesUp, bytesDown int64
  1651. defer func() {
  1652. sshClient.closedPortForward(
  1653. portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
  1654. }()
  1655. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
  1656. defer lruEntry.Remove()
  1657. // ActivityMonitoredConn monitors the TCP port forward I/O and updates
  1658. // its LRU status. ActivityMonitoredConn also times out I/O on the port
  1659. // forward if both reads and writes have been idle for the specified
  1660. // duration.
  1661. // Ensure nil interface if newClientSeedPortForward returns nil
  1662. var updater common.ActivityUpdater
  1663. seedUpdater := sshClient.newClientSeedPortForward(IP)
  1664. if seedUpdater != nil {
  1665. updater = seedUpdater
  1666. }
  1667. fwdConn, err = common.NewActivityMonitoredConn(
  1668. fwdConn,
  1669. sshClient.idleTCPPortForwardTimeout(),
  1670. true,
  1671. updater,
  1672. lruEntry)
  1673. if err != nil {
  1674. log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  1675. return
  1676. }
  1677. // Relay channel to forwarded connection.
  1678. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  1679. // TODO: relay errors to fwdChannel.Stderr()?
  1680. relayWaitGroup := new(sync.WaitGroup)
  1681. relayWaitGroup.Add(1)
  1682. go func() {
  1683. defer relayWaitGroup.Done()
  1684. // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
  1685. // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
  1686. // overall memory footprint.
  1687. bytes, err := io.CopyBuffer(
  1688. fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  1689. atomic.AddInt64(&bytesDown, bytes)
  1690. if err != nil && err != io.EOF {
  1691. // Debug since errors such as "connection reset by peer" occur during normal operation
  1692. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  1693. }
  1694. // Interrupt upstream io.Copy when downstream is shutting down.
  1695. // TODO: this is done to quickly cleanup the port forward when
  1696. // fwdConn has a read timeout, but is it clean -- upstream may still
  1697. // be flowing?
  1698. fwdChannel.Close()
  1699. }()
  1700. bytes, err := io.CopyBuffer(
  1701. fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  1702. atomic.AddInt64(&bytesUp, bytes)
  1703. if err != nil && err != io.EOF {
  1704. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  1705. }
  1706. // Shutdown special case: fwdChannel will be closed and return EOF when
  1707. // the SSH connection is closed, but we need to explicitly close fwdConn
  1708. // to interrupt the downstream io.Copy, which may be blocked on a
  1709. // fwdConn.Read().
  1710. fwdConn.Close()
  1711. relayWaitGroup.Wait()
  1712. log.WithContextFields(
  1713. LogFields{
  1714. "remoteAddr": remoteAddr,
  1715. "bytesUp": atomic.LoadInt64(&bytesUp),
  1716. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  1717. }