tunnelServer.go 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "context"
  22. "crypto/subtle"
  23. "encoding/json"
  24. "errors"
  25. "fmt"
  26. "io"
  27. "net"
  28. "strconv"
  29. "sync"
  30. "sync/atomic"
  31. "syscall"
  32. "time"
  33. cache "github.com/Psiphon-Inc/go-cache"
  34. "github.com/Psiphon-Inc/goarista/monotime"
  35. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  36. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
  37. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
  38. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  39. )
  // Tuning constants for the SSH tunnel server. Durations carry explicit
  // units; the sizes and the retry factor are untyped integer constants.
  // See the use sites for exact semantics.
  const (
  	SSH_AUTH_LOG_PERIOD          = time.Minute * 30
  	SSH_HANDSHAKE_TIMEOUT        = time.Second * 30
  	SSH_CONNECTION_READ_DEADLINE = time.Minute * 5

  	// Presumably byte / element counts -- confirm at the use sites.
  	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8 * 1024
  	SSH_TCP_PORT_FORWARD_QUEUE_SIZE       = 1 << 10

  	SSH_SEND_OSL_INITIAL_RETRY_DELAY = time.Second * 30
  	SSH_SEND_OSL_RETRY_FACTOR        = 2

  	// OSL_SESSION_CACHE_TTL is how long OSL seed progress of a disconnected
  	// client is retained in the OSL session cache (see newSSHServer).
  	OSL_SESSION_CACHE_TTL = time.Minute * 5
  )
  50. // TunnelServer is the main server that accepts Psiphon client
  51. // connections, via various obfuscation protocols, and provides
  52. // port forwarding (TCP and UDP) services to the Psiphon client.
  53. // At its core, TunnelServer is an SSH server. SSH is the base
  54. // protocol that provides port forward multiplexing, and transport
  55. // security. Layered on top of SSH, optionally, is Obfuscated SSH
  56. // and meek protocols, which provide further circumvention
  57. // capabilities.
  58. type TunnelServer struct {
  59. runWaitGroup *sync.WaitGroup
  60. listenerError chan error
  61. shutdownBroadcast <-chan struct{}
  62. sshServer *sshServer
  63. }
  64. // NewTunnelServer initializes a new tunnel server.
  65. func NewTunnelServer(
  66. support *SupportServices,
  67. shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
  68. sshServer, err := newSSHServer(support, shutdownBroadcast)
  69. if err != nil {
  70. return nil, common.ContextError(err)
  71. }
  72. return &TunnelServer{
  73. runWaitGroup: new(sync.WaitGroup),
  74. listenerError: make(chan error),
  75. shutdownBroadcast: shutdownBroadcast,
  76. sshServer: sshServer,
  77. }, nil
  78. }
  79. // Run runs the tunnel server; this function blocks while running a selection of
  80. // listeners that handle connection using various obfuscation protocols.
  81. //
  82. // Run listens on each designated tunnel port and spawns new goroutines to handle
  83. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  84. // clients is maintained, and when halting all clients are cleanly shutdown.
  85. //
  86. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  87. // authentication, and then looping on client new channel requests. "direct-tcpip"
  88. // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
  89. // config parameter is configured, UDP port forwards over a TCP stream, following
  90. // the udpgw protocol, are handled.
  91. //
  92. // A new goroutine is spawned to handle each port forward for each client. Each port
  93. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  94. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  95. // client shuts down.
  96. //
  97. // Note: client handler goroutines may still be shutting down after Run() returns. See
  98. // comment in sshClient.stop(). TODO: fully synchronized shutdown.
  99. func (server *TunnelServer) Run() error {
  100. type sshListener struct {
  101. net.Listener
  102. localAddress string
  103. tunnelProtocol string
  104. }
  105. // TODO: should TunnelServer hold its own support pointer?
  106. support := server.sshServer.support
  107. // First bind all listeners; once all are successful,
  108. // start accepting connections on each.
  109. var listeners []*sshListener
  110. for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
  111. localAddress := fmt.Sprintf(
  112. "%s:%d", support.Config.ServerIPAddress, listenPort)
  113. listener, err := net.Listen("tcp", localAddress)
  114. if err != nil {
  115. for _, existingListener := range listeners {
  116. existingListener.Listener.Close()
  117. }
  118. return common.ContextError(err)
  119. }
  120. log.WithContextFields(
  121. LogFields{
  122. "localAddress": localAddress,
  123. "tunnelProtocol": tunnelProtocol,
  124. }).Info("listening")
  125. listeners = append(
  126. listeners,
  127. &sshListener{
  128. Listener: listener,
  129. localAddress: localAddress,
  130. tunnelProtocol: tunnelProtocol,
  131. })
  132. }
  133. for _, listener := range listeners {
  134. server.runWaitGroup.Add(1)
  135. go func(listener *sshListener) {
  136. defer server.runWaitGroup.Done()
  137. log.WithContextFields(
  138. LogFields{
  139. "localAddress": listener.localAddress,
  140. "tunnelProtocol": listener.tunnelProtocol,
  141. }).Info("running")
  142. server.sshServer.runListener(
  143. listener.Listener,
  144. server.listenerError,
  145. listener.tunnelProtocol)
  146. log.WithContextFields(
  147. LogFields{
  148. "localAddress": listener.localAddress,
  149. "tunnelProtocol": listener.tunnelProtocol,
  150. }).Info("stopped")
  151. }(listener)
  152. }
  153. var err error
  154. select {
  155. case <-server.shutdownBroadcast:
  156. case err = <-server.listenerError:
  157. }
  158. for _, listener := range listeners {
  159. listener.Close()
  160. }
  161. server.sshServer.stopClients()
  162. server.runWaitGroup.Wait()
  163. log.WithContext().Info("stopped")
  164. return err
  165. }
  166. // GetLoadStats returns load stats for the tunnel server. The stats are
  167. // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
  168. // include current connected client count, total number of current port
  169. // forwards.
  170. func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
  171. return server.sshServer.getLoadStats()
  172. }
  173. // ResetAllClientTrafficRules resets all established client traffic rules
  174. // to use the latest config and client properties. Any existing traffic
  175. // rule state is lost, including throttling state.
  176. func (server *TunnelServer) ResetAllClientTrafficRules() {
  177. server.sshServer.resetAllClientTrafficRules()
  178. }
  179. // ResetAllClientOSLConfigs resets all established client OSL state to use
  180. // the latest OSL config. Any existing OSL state is lost, including partial
  181. // progress towards SLOKs.
  182. func (server *TunnelServer) ResetAllClientOSLConfigs() {
  183. server.sshServer.resetAllClientOSLConfigs()
  184. }
  185. // SetClientHandshakeState sets the handshake state -- that it completed and
  186. // what parameters were passed -- in sshClient. This state is used for allowing
  187. // port forwards and for future traffic rule selection. SetClientHandshakeState
  188. // also triggers an immediate traffic rule re-selection, as the rules selected
  189. // upon tunnel establishment may no longer apply now that handshake values are
  190. // set.
  191. func (server *TunnelServer) SetClientHandshakeState(
  192. sessionID string, state handshakeState) error {
  193. return server.sshServer.setClientHandshakeState(sessionID, state)
  194. }
  195. // GetClientHandshaked indicates whether the client has completed a handshake
  196. // and whether its traffic rules are immediately exhausted.
  197. func (server *TunnelServer) GetClientHandshaked(
  198. sessionID string) (bool, bool, error) {
  199. return server.sshServer.getClientHandshaked(sessionID)
  200. }
  201. // SetEstablishTunnels sets whether new tunnels may be established or not.
  202. // When not establishing, incoming connections are immediately closed.
  203. func (server *TunnelServer) SetEstablishTunnels(establish bool) {
  204. server.sshServer.setEstablishTunnels(establish)
  205. }
  206. // GetEstablishTunnels returns whether new tunnels may be established or not.
  207. func (server *TunnelServer) GetEstablishTunnels() bool {
  208. return server.sshServer.getEstablishTunnels()
  209. }
  210. type sshServer struct {
  211. // Note: 64-bit ints used with atomic operations are placed
  212. // at the start of struct to ensure 64-bit alignment.
  213. // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
  214. lastAuthLog int64
  215. authFailedCount int64
  216. support *SupportServices
  217. establishTunnels int32
  218. shutdownBroadcast <-chan struct{}
  219. sshHostKey ssh.Signer
  220. clientsMutex sync.Mutex
  221. stoppingClients bool
  222. acceptedClientCounts map[string]map[string]int64
  223. clients map[string]*sshClient
  224. oslSessionCacheMutex sync.Mutex
  225. oslSessionCache *cache.Cache
  226. }
  227. func newSSHServer(
  228. support *SupportServices,
  229. shutdownBroadcast <-chan struct{}) (*sshServer, error) {
  230. privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
  231. if err != nil {
  232. return nil, common.ContextError(err)
  233. }
  234. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  235. signer, err := ssh.NewSignerFromKey(privateKey)
  236. if err != nil {
  237. return nil, common.ContextError(err)
  238. }
  239. // The OSL session cache temporarily retains OSL seed state
  240. // progress for disconnected clients. This enables clients
  241. // that disconnect and immediately reconnect to the same
  242. // server to resume their OSL progress. Cached progress
  243. // is referenced by session ID and is retained for
  244. // OSL_SESSION_CACHE_TTL after disconnect.
  245. //
  246. // Note: session IDs are assumed to be unpredictable. If a
  247. // rogue client could guess the session ID of another client,
  248. // it could resume its OSL progress and, if the OSL config
  249. // were known, infer some activity.
  250. oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
  251. return &sshServer{
  252. support: support,
  253. establishTunnels: 1,
  254. shutdownBroadcast: shutdownBroadcast,
  255. sshHostKey: signer,
  256. acceptedClientCounts: make(map[string]map[string]int64),
  257. clients: make(map[string]*sshClient),
  258. oslSessionCache: oslSessionCache,
  259. }, nil
  260. }
  261. func (sshServer *sshServer) setEstablishTunnels(establish bool) {
  262. // Do nothing when the setting is already correct. This avoids
  263. // spurious log messages when setEstablishTunnels is called
  264. // periodically with the same setting.
  265. if establish == sshServer.getEstablishTunnels() {
  266. return
  267. }
  268. establishFlag := int32(1)
  269. if !establish {
  270. establishFlag = 0
  271. }
  272. atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
  273. log.WithContextFields(
  274. LogFields{"establish": establish}).Info("establishing tunnels")
  275. }
  276. func (sshServer *sshServer) getEstablishTunnels() bool {
  277. return atomic.LoadInt32(&sshServer.establishTunnels) == 1
  278. }
  279. // runListener is intended to run an a goroutine; it blocks
  280. // running a particular listener. If an unrecoverable error
  281. // occurs, it will send the error to the listenerError channel.
  282. func (sshServer *sshServer) runListener(
  283. listener net.Listener,
  284. listenerError chan<- error,
  285. listenerTunnelProtocol string) {
  286. runningProtocols := make([]string, 0)
  287. for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
  288. runningProtocols = append(runningProtocols, tunnelProtocol)
  289. }
  290. handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
  291. // Note: establish tunnel limiter cannot simply stop TCP
  292. // listeners in all cases (e.g., meek) since SSH tunnel can
  293. // span multiple TCP connections.
  294. if !sshServer.getEstablishTunnels() {
  295. log.WithContext().Debug("not establishing tunnels")
  296. clientConn.Close()
  297. return
  298. }
  299. // The tunnelProtocol passed to handleClient is used for stats,
  300. // throttling, etc. When the tunnel protocol can be determined
  301. // unambiguously from the listening port, use that protocol and
  302. // don't use any client-declared value. Only use the client's
  303. // value, if present, in special cases where the listenting port
  304. // cannot distinguish the protocol.
  305. tunnelProtocol := listenerTunnelProtocol
  306. if clientTunnelProtocol != "" &&
  307. protocol.UseClientTunnelProtocol(
  308. clientTunnelProtocol, runningProtocols) {
  309. tunnelProtocol = clientTunnelProtocol
  310. }
  311. // process each client connection concurrently
  312. go sshServer.handleClient(tunnelProtocol, clientConn)
  313. }
  314. // Note: when exiting due to a unrecoverable error, be sure
  315. // to try to send the error to listenerError so that the outer
  316. // TunnelServer.Run will properly shut down instead of remaining
  317. // running.
  318. if protocol.TunnelProtocolUsesMeekHTTP(listenerTunnelProtocol) ||
  319. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol) {
  320. meekServer, err := NewMeekServer(
  321. sshServer.support,
  322. listener,
  323. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol),
  324. protocol.TunnelProtocolUsesObfuscatedSessionTickets(listenerTunnelProtocol),
  325. handleClient,
  326. sshServer.shutdownBroadcast)
  327. if err != nil {
  328. select {
  329. case listenerError <- common.ContextError(err):
  330. default:
  331. }
  332. return
  333. }
  334. meekServer.Run()
  335. } else {
  336. for {
  337. conn, err := listener.Accept()
  338. select {
  339. case <-sshServer.shutdownBroadcast:
  340. if err == nil {
  341. conn.Close()
  342. }
  343. return
  344. default:
  345. }
  346. if err != nil {
  347. if e, ok := err.(net.Error); ok && e.Temporary() {
  348. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  349. // Temporary error, keep running
  350. continue
  351. }
  352. select {
  353. case listenerError <- common.ContextError(err):
  354. default:
  355. }
  356. return
  357. }
  358. handleClient("", conn)
  359. }
  360. }
  361. }
  362. // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
  363. // is for tracking the number of connections.
  364. func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
  365. sshServer.clientsMutex.Lock()
  366. defer sshServer.clientsMutex.Unlock()
  367. if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
  368. sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
  369. }
  370. sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
  371. }
  372. func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
  373. sshServer.clientsMutex.Lock()
  374. defer sshServer.clientsMutex.Unlock()
  375. sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
  376. }
  377. // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
  378. // for tracking the number of fully established clients and for maintaining a list of running
  379. // clients (for stopping at shutdown time).
  380. func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
  381. sshServer.clientsMutex.Lock()
  382. if sshServer.stoppingClients {
  383. sshServer.clientsMutex.Unlock()
  384. return false
  385. }
  386. // In the case of a duplicate client sessionID, the previous client is closed.
  387. // - Well-behaved clients generate pick a random sessionID that should be
  388. // unique (won't accidentally conflict) and hard to guess (can't be targeted
  389. // by a malicious client).
  390. // - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
  391. // and resestablished. In this case, when the same server is selected, this logic
  392. // will be hit; closing the old, dangling client is desirable.
  393. // - Multi-tunnel clients should not normally use one server for multiple tunnels.
  394. existingClient := sshServer.clients[client.sessionID]
  395. sshServer.clients[client.sessionID] = client
  396. sshServer.clientsMutex.Unlock()
  397. // Call stop() outside the mutex to avoid deadlock.
  398. if existingClient != nil {
  399. existingClient.stop()
  400. log.WithContext().Info(
  401. "stopped existing client with duplicate session ID")
  402. }
  403. return true
  404. }
  405. func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
  406. sshServer.clientsMutex.Lock()
  407. registeredClient := sshServer.clients[client.sessionID]
  408. // registeredClient will differ from client when client
  409. // is the existingClient terminated in registerEstablishedClient.
  410. // In that case, registeredClient remains connected, and
  411. // the sshServer.clients entry should be retained.
  412. if registeredClient == client {
  413. delete(sshServer.clients, client.sessionID)
  414. }
  415. sshServer.clientsMutex.Unlock()
  416. // Call stop() outside the mutex to avoid deadlock.
  417. client.stop()
  418. }
  // Load statistics produced by sshServer.getLoadStats.
  type (
  	// ProtocolStats is indexed [<protocol or "ALL">][<stat name>] -> count.
  	ProtocolStats map[string]map[string]int64

  	// RegionStats is indexed [<region>][<protocol or "ALL">][<stat name>] -> count.
  	RegionStats map[string]map[string]map[string]int64
  )
  421. func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
  422. sshServer.clientsMutex.Lock()
  423. defer sshServer.clientsMutex.Unlock()
  424. // Explicitly populate with zeros to ensure 0 counts in log messages
  425. zeroStats := func() map[string]int64 {
  426. stats := make(map[string]int64)
  427. stats["accepted_clients"] = 0
  428. stats["established_clients"] = 0
  429. stats["dialing_tcp_port_forwards"] = 0
  430. stats["tcp_port_forwards"] = 0
  431. stats["total_tcp_port_forwards"] = 0
  432. stats["udp_port_forwards"] = 0
  433. stats["total_udp_port_forwards"] = 0
  434. stats["tcp_port_forward_dialed_count"] = 0
  435. stats["tcp_port_forward_dialed_duration"] = 0
  436. stats["tcp_port_forward_failed_count"] = 0
  437. stats["tcp_port_forward_failed_duration"] = 0
  438. stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
  439. return stats
  440. }
  441. zeroProtocolStats := func() map[string]map[string]int64 {
  442. stats := make(map[string]map[string]int64)
  443. stats["ALL"] = zeroStats()
  444. for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
  445. stats[tunnelProtocol] = zeroStats()
  446. }
  447. return stats
  448. }
  449. // [<protocol or ALL>][<stat name>] -> count
  450. protocolStats := zeroProtocolStats()
  451. // [<region][<protocol or ALL>][<stat name>] -> count
  452. regionStats := make(RegionStats)
  453. // Note: as currently tracked/counted, each established client is also an accepted client
  454. for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
  455. for region, acceptedClientCount := range regionAcceptedClientCounts {
  456. if acceptedClientCount > 0 {
  457. if regionStats[region] == nil {
  458. regionStats[region] = zeroProtocolStats()
  459. }
  460. protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
  461. protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
  462. regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
  463. regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
  464. }
  465. }
  466. }
  467. for _, client := range sshServer.clients {
  468. client.Lock()
  469. tunnelProtocol := client.tunnelProtocol
  470. region := client.geoIPData.Country
  471. if regionStats[region] == nil {
  472. regionStats[region] = zeroProtocolStats()
  473. }
  474. stats := []map[string]int64{
  475. protocolStats["ALL"],
  476. protocolStats[tunnelProtocol],
  477. regionStats[region]["ALL"],
  478. regionStats[region][tunnelProtocol]}
  479. for _, stat := range stats {
  480. stat["established_clients"] += 1
  481. // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
  482. stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
  483. stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
  484. stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
  485. // client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
  486. stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
  487. stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
  488. stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.tcpPortForwardDialedCount
  489. stat["tcp_port_forward_dialed_duration"] +=
  490. int64(client.qualityMetrics.tcpPortForwardDialedDuration / time.Millisecond)
  491. stat["tcp_port_forward_failed_count"] += client.qualityMetrics.tcpPortForwardFailedCount
  492. stat["tcp_port_forward_failed_duration"] +=
  493. int64(client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond)
  494. stat["tcp_port_forward_rejected_dialing_limit_count"] +=
  495. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
  496. }
  497. client.qualityMetrics.tcpPortForwardDialedCount = 0
  498. client.qualityMetrics.tcpPortForwardDialedDuration = 0
  499. client.qualityMetrics.tcpPortForwardFailedCount = 0
  500. client.qualityMetrics.tcpPortForwardFailedDuration = 0
  501. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
  502. client.Unlock()
  503. }
  504. return protocolStats, regionStats
  505. }
  506. func (sshServer *sshServer) resetAllClientTrafficRules() {
  507. sshServer.clientsMutex.Lock()
  508. clients := make(map[string]*sshClient)
  509. for sessionID, client := range sshServer.clients {
  510. clients[sessionID] = client
  511. }
  512. sshServer.clientsMutex.Unlock()
  513. for _, client := range clients {
  514. client.setTrafficRules()
  515. }
  516. }
  517. func (sshServer *sshServer) resetAllClientOSLConfigs() {
  518. // Flush cached seed state. This has the same effect
  519. // and same limitations as calling setOSLConfig for
  520. // currently connected clients -- all progress is lost.
  521. sshServer.oslSessionCacheMutex.Lock()
  522. sshServer.oslSessionCache.Flush()
  523. sshServer.oslSessionCacheMutex.Unlock()
  524. sshServer.clientsMutex.Lock()
  525. clients := make(map[string]*sshClient)
  526. for sessionID, client := range sshServer.clients {
  527. clients[sessionID] = client
  528. }
  529. sshServer.clientsMutex.Unlock()
  530. for _, client := range clients {
  531. client.setOSLConfig()
  532. }
  533. }
  534. func (sshServer *sshServer) setClientHandshakeState(
  535. sessionID string, state handshakeState) error {
  536. sshServer.clientsMutex.Lock()
  537. client := sshServer.clients[sessionID]
  538. sshServer.clientsMutex.Unlock()
  539. if client == nil {
  540. return common.ContextError(errors.New("unknown session ID"))
  541. }
  542. err := client.setHandshakeState(state)
  543. if err != nil {
  544. return common.ContextError(err)
  545. }
  546. return nil
  547. }
  548. func (sshServer *sshServer) getClientHandshaked(
  549. sessionID string) (bool, bool, error) {
  550. sshServer.clientsMutex.Lock()
  551. client := sshServer.clients[sessionID]
  552. sshServer.clientsMutex.Unlock()
  553. if client == nil {
  554. return false, false, common.ContextError(errors.New("unknown session ID"))
  555. }
  556. completed, exhausted := client.getHandshaked()
  557. return completed, exhausted, nil
  558. }
  559. func (sshServer *sshServer) stopClients() {
  560. sshServer.clientsMutex.Lock()
  561. sshServer.stoppingClients = true
  562. clients := sshServer.clients
  563. sshServer.clients = make(map[string]*sshClient)
  564. sshServer.clientsMutex.Unlock()
  565. for _, client := range clients {
  566. client.stop()
  567. }
  568. }
  569. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  570. geoIPData := sshServer.support.GeoIPService.Lookup(
  571. common.IPAddressFromAddr(clientConn.RemoteAddr()))
  572. sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
  573. defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
  574. sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
  575. sshClient.run(clientConn)
  576. }
  577. func (sshServer *sshServer) monitorPortForwardDialError(err error) {
  578. // "err" is the error returned from a failed TCP or UDP port
  579. // forward dial. Certain system error codes indicate low resource
  580. // conditions: insufficient file descriptors, ephemeral ports, or
  581. // memory. For these cases, log an alert.
  582. // TODO: also temporarily suspend new clients
  583. // Note: don't log net.OpError.Error() as the full error string
  584. // may contain client destination addresses.
  585. opErr, ok := err.(*net.OpError)
  586. if ok {
  587. if opErr.Err == syscall.EADDRNOTAVAIL ||
  588. opErr.Err == syscall.EAGAIN ||
  589. opErr.Err == syscall.ENOMEM ||
  590. opErr.Err == syscall.EMFILE ||
  591. opErr.Err == syscall.ENFILE {
  592. log.WithContextFields(
  593. LogFields{"error": opErr.Err}).Error(
  594. "port forward dial failed due to unavailable resource")
  595. }
  596. }
  597. }
  598. type sshClient struct {
  599. sync.Mutex
  600. sshServer *sshServer
  601. tunnelProtocol string
  602. sshConn ssh.Conn
  603. activityConn *common.ActivityMonitoredConn
  604. throttledConn *common.ThrottledConn
  605. geoIPData GeoIPData
  606. sessionID string
  607. supportsServerRequests bool
  608. handshakeState handshakeState
  609. udpChannel ssh.Channel
  610. packetTunnelChannel ssh.Channel
  611. trafficRules TrafficRules
  612. tcpTrafficState trafficState
  613. udpTrafficState trafficState
  614. qualityMetrics qualityMetrics
  615. tcpPortForwardLRU *common.LRUConns
  616. oslClientSeedState *osl.ClientSeedState
  617. signalIssueSLOKs chan struct{}
  618. runContext context.Context
  619. stopRunning context.CancelFunc
  620. tcpPortForwardDialingAvailableSignal context.CancelFunc
  621. }
  622. type trafficState struct {
  623. bytesUp int64
  624. bytesDown int64
  625. concurrentDialingPortForwardCount int64
  626. peakConcurrentDialingPortForwardCount int64
  627. concurrentPortForwardCount int64
  628. peakConcurrentPortForwardCount int64
  629. totalPortForwardCount int64
  630. availablePortForwardCond *sync.Cond
  631. }
  632. // qualityMetrics records upstream TCP dial attempts and
  633. // elapsed time. Elapsed time includes the full TCP handshake
  634. // and, in aggregate, is a measure of the quality of the
  635. // upstream link. These stats are recorded by each sshClient
  636. // and then reported and reset in sshServer.getLoadStats().
  637. type qualityMetrics struct {
  638. tcpPortForwardDialedCount int64
  639. tcpPortForwardDialedDuration time.Duration
  640. tcpPortForwardFailedCount int64
  641. tcpPortForwardFailedDuration time.Duration
  642. tcpPortForwardRejectedDialingLimitCount int64
  643. }
  644. type handshakeState struct {
  645. completed bool
  646. apiProtocol string
  647. apiParams requestJSONObject
  648. }
  649. func newSshClient(
  650. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
  651. runContext, stopRunning := context.WithCancel(context.Background())
  652. client := &sshClient{
  653. sshServer: sshServer,
  654. tunnelProtocol: tunnelProtocol,
  655. geoIPData: geoIPData,
  656. tcpPortForwardLRU: common.NewLRUConns(),
  657. signalIssueSLOKs: make(chan struct{}, 1),
  658. runContext: runContext,
  659. stopRunning: stopRunning,
  660. }
  661. client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  662. client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  663. return client
  664. }
  665. func (sshClient *sshClient) run(clientConn net.Conn) {
  666. // Some conns report additional metrics
  667. metricsSource, isMetricsSource := clientConn.(MetricsSource)
  668. // Set initial traffic rules, pre-handshake, based on currently known info.
  669. sshClient.setTrafficRules()
  670. // Wrap the base client connection with an ActivityMonitoredConn which will
  671. // terminate the connection if no data is received before the deadline. This
  672. // timeout is in effect for the entire duration of the SSH connection. Clients
  673. // must actively use the connection or send SSH keep alive requests to keep
  674. // the connection active. Writes are not considered reliable activity indicators
  675. // due to buffering.
  676. activityConn, err := common.NewActivityMonitoredConn(
  677. clientConn,
  678. SSH_CONNECTION_READ_DEADLINE,
  679. false,
  680. nil,
  681. nil)
  682. if err != nil {
  683. clientConn.Close()
  684. log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  685. return
  686. }
  687. clientConn = activityConn
  688. // Further wrap the connection in a rate limiting ThrottledConn.
  689. throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
  690. clientConn = throttledConn
  691. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  692. // respect shutdownBroadcast and implement a specific handshake timeout.
  693. // The timeout is to reclaim network resources in case the handshake takes
  694. // too long.
  695. type sshNewServerConnResult struct {
  696. conn net.Conn
  697. sshConn *ssh.ServerConn
  698. channels <-chan ssh.NewChannel
  699. requests <-chan *ssh.Request
  700. err error
  701. }
  702. resultChannel := make(chan *sshNewServerConnResult, 2)
  703. if SSH_HANDSHAKE_TIMEOUT > 0 {
  704. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  705. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  706. })
  707. }
  708. go func(conn net.Conn) {
  709. sshServerConfig := &ssh.ServerConfig{
  710. PasswordCallback: sshClient.passwordCallback,
  711. AuthLogCallback: sshClient.authLogCallback,
  712. ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion,
  713. }
  714. sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
  715. result := &sshNewServerConnResult{}
  716. // Wrap the connection in an SSH deobfuscator when required.
  717. if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
  718. // Note: NewObfuscatedSshConn blocks on network I/O
  719. // TODO: ensure this won't block shutdown
  720. conn, result.err = common.NewObfuscatedSshConn(
  721. common.OBFUSCATION_CONN_MODE_SERVER,
  722. conn,
  723. sshClient.sshServer.support.Config.ObfuscatedSSHKey)
  724. if result.err != nil {
  725. result.err = common.ContextError(result.err)
  726. }
  727. }
  728. if result.err == nil {
  729. result.sshConn, result.channels, result.requests, result.err =
  730. ssh.NewServerConn(conn, sshServerConfig)
  731. }
  732. resultChannel <- result
  733. }(clientConn)
  734. var result *sshNewServerConnResult
  735. select {
  736. case result = <-resultChannel:
  737. case <-sshClient.sshServer.shutdownBroadcast:
  738. // Close() will interrupt an ongoing handshake
  739. // TODO: wait for goroutine to exit before returning?
  740. clientConn.Close()
  741. return
  742. }
  743. if result.err != nil {
  744. clientConn.Close()
  745. // This is a Debug log due to noise. The handshake often fails due to I/O
  746. // errors as clients frequently interrupt connections in progress when
  747. // client-side load balancing completes a connection to a different server.
  748. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
  749. return
  750. }
  751. sshClient.Lock()
  752. sshClient.sshConn = result.sshConn
  753. sshClient.activityConn = activityConn
  754. sshClient.throttledConn = throttledConn
  755. sshClient.Unlock()
  756. if !sshClient.sshServer.registerEstablishedClient(sshClient) {
  757. clientConn.Close()
  758. log.WithContext().Warning("register failed")
  759. return
  760. }
  761. sshClient.runTunnel(result.channels, result.requests)
  762. // Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
  763. // which also closes underlying transport Conn.
  764. sshClient.sshServer.unregisterEstablishedClient(sshClient)
  765. var additionalMetrics LogFields
  766. if isMetricsSource {
  767. additionalMetrics = metricsSource.GetMetrics()
  768. }
  769. sshClient.logTunnel(additionalMetrics)
  770. // Transfer OSL seed state -- the OSL progress -- from the closing
  771. // client to the session cache so the client can resume its progress
  772. // if it reconnects to this same server.
  773. // Note: following setOSLConfig order of locking.
  774. sshClient.Lock()
  775. if sshClient.oslClientSeedState != nil {
  776. sshClient.sshServer.oslSessionCacheMutex.Lock()
  777. sshClient.oslClientSeedState.Hibernate()
  778. sshClient.sshServer.oslSessionCache.Set(
  779. sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
  780. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  781. sshClient.oslClientSeedState = nil
  782. }
  783. sshClient.Unlock()
  784. // Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
  785. // final status requests, the lifetime of cached GeoIP records exceeds the
  786. // lifetime of the sshClient.
  787. sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
  788. }
  789. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  790. expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
  791. expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
  792. var sshPasswordPayload protocol.SSHPasswordPayload
  793. err := json.Unmarshal(password, &sshPasswordPayload)
  794. if err != nil {
  795. // Backwards compatibility case: instead of a JSON payload, older clients
  796. // send the hex encoded session ID prepended to the SSH password.
  797. // Note: there's an even older case where clients don't send any session ID,
  798. // but that's no longer supported.
  799. if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
  800. sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
  801. sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:])
  802. } else {
  803. return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  804. }
  805. }
  806. if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) ||
  807. len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
  808. return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
  809. }
  810. userOk := (subtle.ConstantTimeCompare(
  811. []byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
  812. passwordOk := (subtle.ConstantTimeCompare(
  813. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
  814. if !userOk || !passwordOk {
  815. return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  816. }
  817. sessionID := sshPasswordPayload.SessionId
  818. supportsServerRequests := common.Contains(
  819. sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
  820. sshClient.Lock()
  821. sshClient.sessionID = sessionID
  822. sshClient.supportsServerRequests = supportsServerRequests
  823. geoIPData := sshClient.geoIPData
  824. sshClient.Unlock()
  825. // Store the GeoIP data associated with the session ID. This makes
  826. // the GeoIP data available to the web server for web API requests.
  827. // A cache that's distinct from the sshClient record is used to allow
  828. // for or post-tunnel final status requests.
  829. // If the client is reconnecting with the same session ID, this call
  830. // will undo the expiry set by MarkSessionCacheToExpire.
  831. sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
  832. return nil, nil
  833. }
  834. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  835. if err != nil {
  836. if method == "none" && err.Error() == "no auth passed yet" {
  837. // In this case, the callback invocation is noise from auth negotiation
  838. return
  839. }
  840. // Note: here we previously logged messages for fail2ban to act on. This is no longer
  841. // done as the complexity outweighs the benefits.
  842. //
  843. // - The SSH credential is not secret -- it's in the server entry. Attackers targeting
  844. // the server likely already have the credential. On the other hand, random scanning and
  845. // brute forcing is mitigated with high entropy random passwords, rate limiting
  846. // (implemented on the host via iptables), and limited capabilities (the SSH session can
  847. // only port forward).
  848. //
  849. // - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
  850. // an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
  851. // The X-Forwarded-For header cant be used instead as it may be forged and used to get IPs
  852. // deliberately blocked; and in any case fail2ban adds iptables rules which can only block
  853. // by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
  854. //
  855. // Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
  856. // not every authentication failure is logged. A summary log is emitted periodically to
  857. // retain some record of this activity in case this is relevant to, e.g., a performance
  858. // investigation.
  859. atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
  860. lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
  861. if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
  862. now := int64(monotime.Now())
  863. if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
  864. count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
  865. log.WithContextFields(
  866. LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
  867. }
  868. }
  869. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
  870. } else {
  871. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  872. }
  873. }
  874. // stop signals the ssh connection to shutdown. After sshConn() returns,
  875. // the connection has terminated but sshClient.run() may still be
  876. // running and in the process of exiting.
  877. func (sshClient *sshClient) stop() {
  878. sshClient.sshConn.Close()
  879. sshClient.sshConn.Wait()
  880. }
  881. // runTunnel handles/dispatches new channels and new requests from the client.
  882. // When the SSH client connection closes, both the channels and requests channels
  883. // will close and runTunnel will exit.
  884. func (sshClient *sshClient) runTunnel(
  885. channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
  886. waitGroup := new(sync.WaitGroup)
  887. // Start client SSH API request handler
  888. waitGroup.Add(1)
  889. go func() {
  890. defer waitGroup.Done()
  891. for request := range requests {
  892. // Requests are processed serially; API responses must be sent in request order.
  893. var responsePayload []byte
  894. var err error
  895. if request.Type == "keepalive@openssh.com" {
  896. // Keepalive requests have an empty response.
  897. } else {
  898. // All other requests are assumed to be API requests.
  899. responsePayload, err = sshAPIRequestHandler(
  900. sshClient.sshServer.support,
  901. sshClient.geoIPData,
  902. request.Type,
  903. request.Payload)
  904. }
  905. if err == nil {
  906. err = request.Reply(true, responsePayload)
  907. } else {
  908. log.WithContextFields(LogFields{"error": err}).Warning("request failed")
  909. err = request.Reply(false, nil)
  910. }
  911. if err != nil {
  912. log.WithContextFields(LogFields{"error": err}).Warning("response failed")
  913. }
  914. }
  915. }()
  916. // Start OSL sender
  917. if sshClient.supportsServerRequests {
  918. waitGroup.Add(1)
  919. go func() {
  920. defer waitGroup.Done()
  921. sshClient.runOSLSender()
  922. }()
  923. }
  924. // Lifecycle of a TCP port forward:
  925. //
  926. // 1. A "direct-tcpip" SSH request is received from the client.
  927. //
  928. // A new TCP port forward request is enqueued. The queue delivers TCP port
  929. // forward requests to the TCP port forward manager, which enforces the TCP
  930. // port forward dial limit.
  931. //
  932. // Enqueuing new requests allows for reading further SSH requests from the
  933. // client without blocking when the dial limit is hit; this is to permit new
  934. // UDP/udpgw port forwards to be restablished without delay. The maximum size
  935. // of the queue enforces a hard cap on resources consumed by a client in the
  936. // pre-dial phase. When the queue is full, new TCP port forwards are
  937. // immediately rejected.
  938. //
  939. // 2. The TCP port forward manager dequeues the request.
  940. //
  941. // The manager calls dialingTCPPortForward(), which increments
  942. // concurrentDialingPortForwardCount, and calls
  943. // isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
  944. // count.
  945. //
  946. // The manager enforces the concurrent TCP dial limit: when at the limit, the
  947. // manager blocks waiting for the number of dials to drop below the limit before
  948. // dispatching the request to handleTCPPortForward(), which will run in its own
  949. // goroutine and will dial and relay the port forward.
  950. //
  951. // The block delays the current request and also halts dequeuing of subsequent
  952. // requests and could ultimately cause requests to be immediately rejected if
  953. // the queue fills. These actions are intended to apply back pressure when
  954. // upstream network resources are impaired.
  955. //
  956. // The time spent in the queue is deducted from the port forward's dial timeout.
  957. // The time spent blocking while at the dial limit is similarly deducted from
  958. // the dial timeout. If the dial timeout has expired before the dial begins, the
  959. // port forward is rejected and a stat is recorded.
  960. //
  961. // 3. handleTCPPortForward() performs the port forward dial and relaying.
  962. //
  963. // a. Dial the target, using the dial timeout remaining after queue and blocking
  964. // time is deducted.
  965. //
  966. // b. If the dial fails, call abortedTCPPortForward() to decrement
  967. // concurrentDialingPortForwardCount, freeing up a dial slot.
  968. //
  969. // c. If the dial succeeds, call establishedPortForward(), which decrements
  970. // concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
  971. // the "established" port forward count.
  972. //
  973. // d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
  974. // concurrentPortForwardCount, the number of _established_ TCP port forwards.
  975. // If the limit is exceeded, the LRU established TCP port forward is closed and
  976. // the newly established TCP port forward proceeds. This LRU logic allows some
  977. // dangling resource consumption (e.g., TIME_WAIT) while providing a better
  978. // experience for clients.
  979. //
  980. // e. Relay data.
  981. //
  982. // f. Call closedPortForward() which decrements concurrentPortForwardCount and
  983. // records bytes transferred.
  984. // Start the TCP port forward manager
  985. type newTCPPortForward struct {
  986. enqueueTime monotime.Time
  987. hostToConnect string
  988. portToConnect int
  989. newChannel ssh.NewChannel
  990. }
  991. // The queue size is set to the traffic rules (MaxTCPPortForwardCount +
  992. // MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
  993. // limits per client; when that value is not set, a default is used.
  994. // A limitation: this queue size is set once and doesn't change, for this client,
  995. // when traffic rules are reloaded.
  996. queueSize := sshClient.getTCPPortForwardQueueSize()
  997. if queueSize == 0 {
  998. queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
  999. }
  1000. newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
  1001. waitGroup.Add(1)
  1002. go func() {
  1003. defer waitGroup.Done()
  1004. for newPortForward := range newTCPPortForwards {
  1005. remainingDialTimeout :=
  1006. time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
  1007. monotime.Since(newPortForward.enqueueTime)
  1008. if remainingDialTimeout <= 0 {
  1009. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1010. sshClient.rejectNewChannel(
  1011. newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out in queue")
  1012. continue
  1013. }
  1014. // Reserve a TCP dialing slot.
  1015. //
  1016. // TOCTOU note: important to increment counts _before_ checking limits; otherwise,
  1017. // the client could potentially consume excess resources by initiating many port
  1018. // forwards concurrently.
  1019. sshClient.dialingTCPPortForward()
  1020. // When max dials are in progress, wait up to remainingDialTimeout for dialing
  1021. // to become available. This blocks all dequeing.
  1022. if sshClient.isTCPDialingPortForwardLimitExceeded() {
  1023. blockStartTime := monotime.Now()
  1024. ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1025. sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
  1026. <-ctx.Done()
  1027. sshClient.setTCPPortForwardDialingAvailableSignal(nil)
  1028. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1029. remainingDialTimeout -= monotime.Since(blockStartTime)
  1030. }
  1031. if remainingDialTimeout <= 0 {
  1032. // Release the dialing slot here since handleTCPChannel() won't be called.
  1033. sshClient.abortedTCPPortForward()
  1034. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1035. sshClient.rejectNewChannel(
  1036. newPortForward.newChannel, ssh.Prohibited, "TCP port forward timed out before dialing")
  1037. continue
  1038. }
  1039. // Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
  1040. // handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
  1041. // will deal with remainingDialTimeout <= 0.
  1042. waitGroup.Add(1)
  1043. go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
  1044. defer waitGroup.Done()
  1045. sshClient.handleTCPChannel(
  1046. remainingDialTimeout,
  1047. newPortForward.hostToConnect,
  1048. newPortForward.portToConnect,
  1049. newPortForward.newChannel)
  1050. }(remainingDialTimeout, newPortForward)
  1051. }
  1052. }()
  1053. // Handle new channel (port forward) requests from the client.
  1054. //
  1055. // packet tunnel channels are handled by the packet tunnel server
  1056. // component. Each client may have at most one packet tunnel channel.
  1057. //
  1058. // udpgw client connections are dispatched immediately (clients use this for
  1059. // DNS, so it's essential to not block; and only one udpgw connection is
  1060. // retained at a time).
  1061. //
  1062. // All other TCP port forwards are dispatched via the TCP port forward
  1063. // manager queue.
  1064. for newChannel := range channels {
  1065. if newChannel.ChannelType() == protocol.PACKET_TUNNEL_CHANNEL_TYPE {
  1066. // Accept this channel immediately. This channel will replace any
  1067. // previously existing packet tunnel channel for this client.
  1068. packetTunnelChannel, requests, err := newChannel.Accept()
  1069. if err != nil {
  1070. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  1071. continue
  1072. }
  1073. go ssh.DiscardRequests(requests)
  1074. sshClient.setPacketTunnelChannel(packetTunnelChannel)
  1075. // PacketTunnelServer will run the client's packet tunnel. ClientDisconnected will
  1076. // be called by setPacketTunnelChannel: either if the client starts a new packet
  1077. // tunnel channel, or on exit of this function.
  1078. checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  1079. return sshClient.isPortForwardPermitted(portForwardTypeTCP, false, upstreamIPAddress, port)
  1080. }
  1081. checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  1082. return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
  1083. }
  1084. sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
  1085. sshClient.sessionID,
  1086. packetTunnelChannel,
  1087. checkAllowedTCPPortFunc,
  1088. checkAllowedUDPPortFunc)
  1089. }
  1090. if newChannel.ChannelType() != "direct-tcpip" {
  1091. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  1092. continue
  1093. }
  1094. // http://tools.ietf.org/html/rfc4254#section-7.2
  1095. var directTcpipExtraData struct {
  1096. HostToConnect string
  1097. PortToConnect uint32
  1098. OriginatorIPAddress string
  1099. OriginatorPort uint32
  1100. }
  1101. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  1102. if err != nil {
  1103. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  1104. continue
  1105. }
  1106. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  1107. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  1108. isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
  1109. sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
  1110. net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
  1111. if isUDPChannel {
  1112. // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
  1113. // own worker goroutine.
  1114. waitGroup.Add(1)
  1115. go func(channel ssh.NewChannel) {
  1116. defer waitGroup.Done()
  1117. sshClient.handleUDPChannel(channel)
  1118. }(newChannel)
  1119. } else {
  1120. // Dispatch via TCP port forward manager. When the queue is full, the channel
  1121. // is immediately rejected.
  1122. tcpPortForward := &newTCPPortForward{
  1123. enqueueTime: monotime.Now(),
  1124. hostToConnect: directTcpipExtraData.HostToConnect,
  1125. portToConnect: int(directTcpipExtraData.PortToConnect),
  1126. newChannel: newChannel,
  1127. }
  1128. select {
  1129. case newTCPPortForwards <- tcpPortForward:
  1130. default:
  1131. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1132. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "TCP port forward dial queue full")
  1133. }
  1134. }
  1135. }
  1136. // The channel loop is interrupted by a client
  1137. // disconnect or by calling sshClient.stop().
  1138. // Stop the TCP port forward manager
  1139. close(newTCPPortForwards)
  1140. // Stop all other worker goroutines
  1141. sshClient.stopRunning()
  1142. // This calls PacketTunnelServer.ClientDisconnected,
  1143. // which stops packet tunnel workers.
  1144. sshClient.setPacketTunnelChannel(nil)
  1145. waitGroup.Wait()
  1146. }
  1147. // setPacketTunnelChannel sets the single packet tunnel channel
  1148. // for this sshClient. Any existing packet tunnel channel is
  1149. // closed and its underlying session idled.
  1150. func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
  1151. sshClient.Lock()
  1152. if sshClient.packetTunnelChannel != nil {
  1153. sshClient.packetTunnelChannel.Close()
  1154. sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
  1155. sshClient.sessionID)
  1156. }
  1157. sshClient.packetTunnelChannel = channel
  1158. sshClient.Unlock()
  1159. }
  1160. // setUDPChannel sets the single UDP channel for this sshClient.
  1161. // Each sshClient may have only one concurrent UDP channel. Each
  1162. // UDP channel multiplexes many UDP port forwards via the udpgw
  1163. // protocol. Any existing UDP channel is closed.
  1164. func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
  1165. sshClient.Lock()
  1166. if sshClient.udpChannel != nil {
  1167. sshClient.udpChannel.Close()
  1168. }
  1169. sshClient.udpChannel = channel
  1170. sshClient.Unlock()
  1171. }
  1172. func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
  1173. // Note: reporting duration based on last confirmed data transfer, which
  1174. // is reads for sshClient.activityConn.GetActiveDuration(), and not
  1175. // connection closing is important for protocols such as meek. For
  1176. // meek, the connection remains open until the HTTP session expires,
  1177. // which may be some time after the tunnel has closed. (The meek
  1178. // protocol has no allowance for signalling payload EOF, and even if
  1179. // it did the client may not have the opportunity to send a final
  1180. // request with an EOF flag set.)
  1181. sshClient.Lock()
  1182. logFields := getRequestLogFields(
  1183. sshClient.sshServer.support,
  1184. "server_tunnel",
  1185. sshClient.geoIPData,
  1186. sshClient.handshakeState.apiParams,
  1187. baseRequestParams)
  1188. logFields["handshake_completed"] = sshClient.handshakeState.completed
  1189. logFields["start_time"] = sshClient.activityConn.GetStartTime()
  1190. logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
  1191. logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
  1192. logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
  1193. logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
  1194. logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
  1195. logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
  1196. logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
  1197. logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
  1198. // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
  1199. logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
  1200. logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
  1201. // Merge in additional metrics from the optional metrics source
  1202. if additionalMetrics != nil {
  1203. for name, value := range additionalMetrics {
  1204. // Don't overwrite any basic fields
  1205. if logFields[name] == nil {
  1206. logFields[name] = value
  1207. }
  1208. }
  1209. }
  1210. sshClient.Unlock()
  1211. log.LogRawFieldsWithTimestamp(logFields)
  1212. }
  1213. func (sshClient *sshClient) runOSLSender() {
  1214. for {
  1215. // Await a signal that there are SLOKs to send
  1216. // TODO: use reflect.SelectCase, and optionally await timer here?
  1217. select {
  1218. case <-sshClient.signalIssueSLOKs:
  1219. case <-sshClient.runContext.Done():
  1220. return
  1221. }
  1222. retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
  1223. for {
  1224. err := sshClient.sendOSLRequest()
  1225. if err == nil {
  1226. break
  1227. }
  1228. log.WithContextFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
  1229. // If the request failed, retry after a delay (with exponential backoff)
  1230. // or when signaled that there are additional SLOKs to send
  1231. retryTimer := time.NewTimer(retryDelay)
  1232. select {
  1233. case <-retryTimer.C:
  1234. case <-sshClient.signalIssueSLOKs:
  1235. case <-sshClient.runContext.Done():
  1236. retryTimer.Stop()
  1237. return
  1238. }
  1239. retryTimer.Stop()
  1240. retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
  1241. }
  1242. }
  1243. }
  1244. // sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
  1245. // generate a payload, and send an OSL request to the client when
  1246. // there are new SLOKs in the payload.
  1247. func (sshClient *sshClient) sendOSLRequest() error {
  1248. seedPayload := sshClient.getOSLSeedPayload()
  1249. // Don't send when no SLOKs. This will happen when signalIssueSLOKs
  1250. // is received but no new SLOKs are issued.
  1251. if len(seedPayload.SLOKs) == 0 {
  1252. return nil
  1253. }
  1254. oslRequest := protocol.OSLRequest{
  1255. SeedPayload: seedPayload,
  1256. }
  1257. requestPayload, err := json.Marshal(oslRequest)
  1258. if err != nil {
  1259. return common.ContextError(err)
  1260. }
  1261. ok, _, err := sshClient.sshConn.SendRequest(
  1262. protocol.PSIPHON_API_OSL_REQUEST_NAME,
  1263. true,
  1264. requestPayload)
  1265. if err != nil {
  1266. return common.ContextError(err)
  1267. }
  1268. if !ok {
  1269. return common.ContextError(errors.New("client rejected request"))
  1270. }
  1271. sshClient.clearOSLSeedPayload()
  1272. return nil
  1273. }
  1274. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, logMessage string) {
  1275. // Note: Debug level, as logMessage may contain user traffic destination address information
  1276. log.WithContextFields(
  1277. LogFields{
  1278. "channelType": newChannel.ChannelType(),
  1279. "logMessage": logMessage,
  1280. "rejectReason": reason.String(),
  1281. }).Debug("reject new channel")
  1282. // Note: logMessage is internal, for logging only; just the RejectionReason is sent to the client
  1283. newChannel.Reject(reason, reason.String())
  1284. }
  1285. // setHandshakeState records that a client has completed a handshake API request.
  1286. // Some parameters from the handshake request may be used in future traffic rule
  1287. // selection. Port forwards are disallowed until a handshake is complete. The
  1288. // handshake parameters are included in the session summary log recorded in
  1289. // sshClient.stop().
  1290. func (sshClient *sshClient) setHandshakeState(state handshakeState) error {
  1291. sshClient.Lock()
  1292. completed := sshClient.handshakeState.completed
  1293. if !completed {
  1294. sshClient.handshakeState = state
  1295. }
  1296. sshClient.Unlock()
  1297. // Client must only perform one handshake
  1298. if completed {
  1299. return common.ContextError(errors.New("handshake already completed"))
  1300. }
  1301. sshClient.setTrafficRules()
  1302. sshClient.setOSLConfig()
  1303. return nil
  1304. }
  1305. // getHandshaked returns whether the client has completed a handshake API
  1306. // request and whether the traffic rules that were selected after the
  1307. // handshake immediately exhaust the client.
  1308. //
  1309. // When the client is immediately exhausted it will be closed; but this
  1310. // takes effect asynchronously. The "exhausted" return value is used to
  1311. // prevent API requests by clients that will close.
  1312. func (sshClient *sshClient) getHandshaked() (bool, bool) {
  1313. sshClient.Lock()
  1314. defer sshClient.Unlock()
  1315. completed := sshClient.handshakeState.completed
  1316. exhausted := false
  1317. // Notes:
  1318. // - "Immediately exhausted" is when CloseAfterExhausted is set and
  1319. // either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
  1320. // 0, so no bytes would be read or written. This check does not
  1321. // examine whether 0 bytes _remain_ in the ThrottledConn.
  1322. // - This check is made against the current traffic rules, which
  1323. // could have changed in a hot reload since the handshake.
  1324. if completed &&
  1325. *sshClient.trafficRules.RateLimits.CloseAfterExhausted == true &&
  1326. (*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
  1327. *sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
  1328. exhausted = true
  1329. }
  1330. return completed, exhausted
  1331. }
  1332. // setTrafficRules resets the client's traffic rules based on the latest server config
  1333. // and client properties. As sshClient.trafficRules may be reset by a concurrent
  1334. // goroutine, trafficRules must only be accessed within the sshClient mutex.
  1335. func (sshClient *sshClient) setTrafficRules() {
  1336. sshClient.Lock()
  1337. defer sshClient.Unlock()
  1338. sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
  1339. sshClient.tunnelProtocol, sshClient.geoIPData, sshClient.handshakeState)
  1340. if sshClient.throttledConn != nil {
  1341. // Any existing throttling state is reset.
  1342. sshClient.throttledConn.SetLimits(
  1343. sshClient.trafficRules.RateLimits.CommonRateLimits())
  1344. }
  1345. }
  1346. // setOSLConfig resets the client's OSL seed state based on the latest OSL config
  1347. // As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
  1348. // oslClientSeedState must only be accessed within the sshClient mutex.
  1349. func (sshClient *sshClient) setOSLConfig() {
  1350. sshClient.Lock()
  1351. defer sshClient.Unlock()
  1352. propagationChannelID, err := getStringRequestParam(
  1353. sshClient.handshakeState.apiParams, "propagation_channel_id")
  1354. if err != nil {
  1355. // This should not fail as long as client has sent valid handshake
  1356. return
  1357. }
  1358. // Use a cached seed state if one is found for the client's
  1359. // session ID. This enables resuming progress made in a previous
  1360. // tunnel.
  1361. // Note: go-cache is already concurency safe; the additional mutex
  1362. // is necessary to guarantee that Get/Delete is atomic; although in
  1363. // practice no two concurrent clients should ever supply the same
  1364. // session ID.
  1365. sshClient.sshServer.oslSessionCacheMutex.Lock()
  1366. oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
  1367. if found {
  1368. sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
  1369. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1370. sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
  1371. sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
  1372. return
  1373. }
  1374. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1375. // Two limitations when setOSLConfig() is invoked due to an
  1376. // OSL config hot reload:
  1377. //
  1378. // 1. any partial progress towards SLOKs is lost.
  1379. //
  1380. // 2. all existing osl.ClientSeedPortForwards for existing
  1381. // port forwards will not send progress to the new client
  1382. // seed state.
  1383. sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
  1384. sshClient.geoIPData.Country,
  1385. propagationChannelID,
  1386. sshClient.signalIssueSLOKs)
  1387. }
  1388. // newClientSeedPortForward will return nil when no seeding is
  1389. // associated with the specified ipAddress.
  1390. func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
  1391. sshClient.Lock()
  1392. defer sshClient.Unlock()
  1393. // Will not be initialized before handshake.
  1394. if sshClient.oslClientSeedState == nil {
  1395. return nil
  1396. }
  1397. return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
  1398. }
  1399. // getOSLSeedPayload returns a payload containing all seeded SLOKs for
  1400. // this client's session.
  1401. func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
  1402. sshClient.Lock()
  1403. defer sshClient.Unlock()
  1404. // Will not be initialized before handshake.
  1405. if sshClient.oslClientSeedState == nil {
  1406. return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
  1407. }
  1408. return sshClient.oslClientSeedState.GetSeedPayload()
  1409. }
  1410. func (sshClient *sshClient) clearOSLSeedPayload() {
  1411. sshClient.Lock()
  1412. defer sshClient.Unlock()
  1413. sshClient.oslClientSeedState.ClearSeedPayload()
  1414. }
  1415. func (sshClient *sshClient) rateLimits() common.RateLimits {
  1416. sshClient.Lock()
  1417. defer sshClient.Unlock()
  1418. return sshClient.trafficRules.RateLimits.CommonRateLimits()
  1419. }
  1420. func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
  1421. sshClient.Lock()
  1422. defer sshClient.Unlock()
  1423. return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
  1424. }
  1425. func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
  1426. sshClient.Lock()
  1427. defer sshClient.Unlock()
  1428. return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
  1429. }
  1430. func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
  1431. sshClient.Lock()
  1432. defer sshClient.Unlock()
  1433. sshClient.tcpPortForwardDialingAvailableSignal = signal
  1434. }
  1435. const (
  1436. portForwardTypeTCP = iota
  1437. portForwardTypeUDP
  1438. portForwardTypeTransparentDNS
  1439. )
  1440. func (sshClient *sshClient) isPortForwardPermitted(
  1441. portForwardType int,
  1442. isTransparentDNSForwarding bool,
  1443. remoteIP net.IP,
  1444. port int) bool {
  1445. sshClient.Lock()
  1446. defer sshClient.Unlock()
  1447. if !sshClient.handshakeState.completed {
  1448. return false
  1449. }
  1450. // Disallow connection to loopback. This is a failsafe. The server
  1451. // should be run on a host with correctly configured firewall rules.
  1452. // An exception is made in the case of tranparent DNS forwarding,
  1453. // where the remoteIP has been rewritten.
  1454. if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
  1455. return false
  1456. }
  1457. var allowPorts []int
  1458. if portForwardType == portForwardTypeTCP {
  1459. allowPorts = sshClient.trafficRules.AllowTCPPorts
  1460. } else {
  1461. allowPorts = sshClient.trafficRules.AllowUDPPorts
  1462. }
  1463. if len(allowPorts) == 0 {
  1464. return true
  1465. }
  1466. // TODO: faster lookup?
  1467. if len(allowPorts) > 0 {
  1468. for _, allowPort := range allowPorts {
  1469. if port == allowPort {
  1470. return true
  1471. }
  1472. }
  1473. }
  1474. for _, subnet := range sshClient.trafficRules.AllowSubnets {
  1475. // Note: ignoring error as config has been validated
  1476. _, network, _ := net.ParseCIDR(subnet)
  1477. if network.Contains(remoteIP) {
  1478. return true
  1479. }
  1480. }
  1481. return false
  1482. }
  1483. func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
  1484. sshClient.Lock()
  1485. defer sshClient.Unlock()
  1486. state := &sshClient.tcpTrafficState
  1487. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1488. if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
  1489. return true
  1490. }
  1491. return false
  1492. }
  1493. func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
  1494. sshClient.Lock()
  1495. defer sshClient.Unlock()
  1496. return *sshClient.trafficRules.MaxTCPPortForwardCount +
  1497. *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1498. }
  1499. func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
  1500. sshClient.Lock()
  1501. defer sshClient.Unlock()
  1502. return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
  1503. }
  1504. func (sshClient *sshClient) dialingTCPPortForward() {
  1505. sshClient.Lock()
  1506. defer sshClient.Unlock()
  1507. state := &sshClient.tcpTrafficState
  1508. state.concurrentDialingPortForwardCount += 1
  1509. if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
  1510. state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
  1511. }
  1512. }
  1513. func (sshClient *sshClient) abortedTCPPortForward() {
  1514. sshClient.Lock()
  1515. defer sshClient.Unlock()
  1516. sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
  1517. }
  1518. func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
  1519. sshClient.Lock()
  1520. defer sshClient.Unlock()
  1521. // Check if at port forward limit. The subsequent counter
  1522. // changes must be atomic with the limit check to ensure
  1523. // the counter never exceeds the limit in the case of
  1524. // concurrent allocations.
  1525. var max int
  1526. var state *trafficState
  1527. if portForwardType == portForwardTypeTCP {
  1528. max = *sshClient.trafficRules.MaxTCPPortForwardCount
  1529. state = &sshClient.tcpTrafficState
  1530. } else {
  1531. max = *sshClient.trafficRules.MaxUDPPortForwardCount
  1532. state = &sshClient.udpTrafficState
  1533. }
  1534. if max > 0 && state.concurrentPortForwardCount >= int64(max) {
  1535. return false
  1536. }
  1537. // Update port forward counters.
  1538. if portForwardType == portForwardTypeTCP {
  1539. // Assumes TCP port forwards called dialingTCPPortForward
  1540. state.concurrentDialingPortForwardCount -= 1
  1541. if sshClient.tcpPortForwardDialingAvailableSignal != nil {
  1542. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1543. if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
  1544. sshClient.tcpPortForwardDialingAvailableSignal()
  1545. }
  1546. }
  1547. }
  1548. state.concurrentPortForwardCount += 1
  1549. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  1550. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  1551. }
  1552. state.totalPortForwardCount += 1
  1553. return true
  1554. }
  1555. // establishedPortForward increments the concurrent port
  1556. // forward counter. closedPortForward decrements it, so it
  1557. // must always be called for each establishedPortForward
  1558. // call.
  1559. //
  1560. // When at the limit of established port forwards, the LRU
  1561. // existing port forward is closed to make way for the newly
  1562. // established one. There can be a minor delay as, in addition
  1563. // to calling Close() on the port forward net.Conn,
  1564. // establishedPortForward waits for the LRU's closedPortForward()
  1565. // call which will decrement the concurrent counter. This
  1566. // ensures all resources associated with the LRU (socket,
  1567. // goroutine) are released or will very soon be released before
  1568. // proceeding.
  1569. func (sshClient *sshClient) establishedPortForward(
  1570. portForwardType int, portForwardLRU *common.LRUConns) {
  1571. // Do not lock sshClient here.
  1572. var state *trafficState
  1573. if portForwardType == portForwardTypeTCP {
  1574. state = &sshClient.tcpTrafficState
  1575. } else {
  1576. state = &sshClient.udpTrafficState
  1577. }
  1578. // When the maximum number of port forwards is already
  1579. // established, close the LRU. CloseOldest will call
  1580. // Close on the port forward net.Conn. Both TCP and
  1581. // UDP port forwards have handler goroutines that may
  1582. // be blocked calling Read on the net.Conn. Close will
  1583. // eventually interrupt the Read and cause the handlers
  1584. // to exit, but not immediately. So the following logic
  1585. // waits for a LRU handler to be interrupted and signal
  1586. // availability.
  1587. //
  1588. // Notes:
  1589. //
  1590. // - the port forward limit can change via a traffic
  1591. // rules hot reload; the condition variable handles
  1592. // this case whereas a channel-based semaphore would
  1593. // not.
  1594. //
  1595. // - if a number of goroutines exceeding the total limit
  1596. // arrive here all concurrently, some CloseOldest() calls
  1597. // will have no effect as there can be less existing port
  1598. // forwards than new ones. In this case, the new port
  1599. // forward will be delayed. This is highly unlikely in
  1600. // practise since UDP calls to establishedPortForward are
  1601. // serialized and TCP calls are limited by the dial
  1602. // queue/count.
  1603. if !sshClient.allocatePortForward(portForwardType) {
  1604. portForwardLRU.CloseOldest()
  1605. log.WithContext().Debug("closed LRU port forward")
  1606. state.availablePortForwardCond.L.Lock()
  1607. for !sshClient.allocatePortForward(portForwardType) {
  1608. state.availablePortForwardCond.Wait()
  1609. }
  1610. state.availablePortForwardCond.L.Unlock()
  1611. }
  1612. }
  1613. func (sshClient *sshClient) closedPortForward(
  1614. portForwardType int, bytesUp, bytesDown int64) {
  1615. sshClient.Lock()
  1616. var state *trafficState
  1617. if portForwardType == portForwardTypeTCP {
  1618. state = &sshClient.tcpTrafficState
  1619. } else {
  1620. state = &sshClient.udpTrafficState
  1621. }
  1622. state.concurrentPortForwardCount -= 1
  1623. state.bytesUp += bytesUp
  1624. state.bytesDown += bytesDown
  1625. sshClient.Unlock()
  1626. // Signal any goroutine waiting in establishedPortForward
  1627. // that an established port forward slot is available.
  1628. state.availablePortForwardCond.Signal()
  1629. }
  1630. func (sshClient *sshClient) updateQualityMetricsWithDialResult(
  1631. tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
  1632. sshClient.Lock()
  1633. defer sshClient.Unlock()
  1634. if tcpPortForwardDialSuccess {
  1635. sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
  1636. sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
  1637. } else {
  1638. sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
  1639. sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
  1640. }
  1641. }
  1642. func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
  1643. sshClient.Lock()
  1644. defer sshClient.Unlock()
  1645. sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
  1646. }
  1647. func (sshClient *sshClient) handleTCPChannel(
  1648. remainingDialTimeout time.Duration,
  1649. hostToConnect string,
  1650. portToConnect int,
  1651. newChannel ssh.NewChannel) {
  1652. // Assumptions:
  1653. // - sshClient.dialingTCPPortForward() has been called
  1654. // - remainingDialTimeout > 0
  1655. established := false
  1656. defer func() {
  1657. if !established {
  1658. sshClient.abortedTCPPortForward()
  1659. }
  1660. }()
  1661. // Transparently redirect web API request connections.
  1662. isWebServerPortForward := false
  1663. config := sshClient.sshServer.support.Config
  1664. if config.WebServerPortForwardAddress != "" {
  1665. destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
  1666. if destination == config.WebServerPortForwardAddress {
  1667. isWebServerPortForward = true
  1668. if config.WebServerPortForwardRedirectAddress != "" {
  1669. // Note: redirect format is validated when config is loaded
  1670. host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
  1671. port, _ := strconv.Atoi(portStr)
  1672. hostToConnect = host
  1673. portToConnect = port
  1674. }
  1675. }
  1676. }
  1677. // Dial the remote address.
  1678. //
  1679. // Hostname resolution is performed explicitly, as a separate step, as the target IP
  1680. // address is used for traffic rules (AllowSubnets) and OSL seed progress.
  1681. //
  1682. // Contexts are used for cancellation (via sshClient.runContext, which is cancelled
  1683. // when the client is stopping) and timeouts.
  1684. dialStartTime := monotime.Now()
  1685. log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
  1686. ctx, cancelCtx := context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1687. IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
  1688. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1689. // TODO: shuffle list to try other IPs?
  1690. // TODO: IPv6 support
  1691. var IP net.IP
  1692. for _, ip := range IPs {
  1693. if ip.IP.To4() != nil {
  1694. IP = ip.IP
  1695. break
  1696. }
  1697. }
  1698. if err == nil && IP == nil {
  1699. err = errors.New("no IP address")
  1700. }
  1701. resolveElapsedTime := monotime.Since(dialStartTime)
  1702. if err != nil {
  1703. // Record a port forward failure
  1704. sshClient.updateQualityMetricsWithDialResult(true, resolveElapsedTime)
  1705. sshClient.rejectNewChannel(
  1706. newChannel, ssh.ConnectionFailed, fmt.Sprintf("LookupIP failed: %s", err))
  1707. return
  1708. }
  1709. remainingDialTimeout -= resolveElapsedTime
  1710. if remainingDialTimeout <= 0 {
  1711. sshClient.rejectNewChannel(
  1712. newChannel, ssh.Prohibited, "TCP port forward timed out resolving")
  1713. return
  1714. }
  1715. // Enforce traffic rules, using the resolved IP address.
  1716. if !isWebServerPortForward &&
  1717. !sshClient.isPortForwardPermitted(
  1718. portForwardTypeTCP,
  1719. false,
  1720. IP,
  1721. portToConnect) {
  1722. // Note: not recording a port forward failure in this case
  1723. sshClient.rejectNewChannel(
  1724. newChannel, ssh.Prohibited, "port forward not permitted")
  1725. return
  1726. }
  1727. // TCP dial.
  1728. remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
  1729. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  1730. ctx, cancelCtx = context.WithTimeout(sshClient.runContext, remainingDialTimeout)
  1731. fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
  1732. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1733. // Record port forward success or failure
  1734. sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
  1735. if err != nil {
  1736. // Monitor for low resource error conditions
  1737. sshClient.sshServer.monitorPortForwardDialError(err)
  1738. sshClient.rejectNewChannel(
  1739. newChannel, ssh.ConnectionFailed, fmt.Sprintf("DialTimeout failed: %s", err))
  1740. return
  1741. }
  1742. // The upstream TCP port forward connection has been established. Schedule
  1743. // some cleanup and notify the SSH client that the channel is accepted.
  1744. defer fwdConn.Close()
  1745. fwdChannel, requests, err := newChannel.Accept()
  1746. if err != nil {
  1747. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  1748. return
  1749. }
  1750. go ssh.DiscardRequests(requests)
  1751. defer fwdChannel.Close()
  1752. // Release the dialing slot and acquire an established slot.
  1753. //
  1754. // establishedPortForward increments the concurrent TCP port
  1755. // forward counter and closes the LRU existing TCP port forward
  1756. // when already at the limit.
  1757. //
  1758. // Known limitations:
  1759. //
  1760. // - Closed LRU TCP sockets will enter the TIME_WAIT state,
  1761. // continuing to consume some resources.
  1762. sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
  1763. // "established = true" cancels the deferred abortedTCPPortForward()
  1764. established = true
  1765. // TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
  1766. var bytesUp, bytesDown int64
  1767. defer func() {
  1768. sshClient.closedPortForward(
  1769. portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
  1770. }()
  1771. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
  1772. defer lruEntry.Remove()
  1773. // ActivityMonitoredConn monitors the TCP port forward I/O and updates
  1774. // its LRU status. ActivityMonitoredConn also times out I/O on the port
  1775. // forward if both reads and writes have been idle for the specified
  1776. // duration.
  1777. // Ensure nil interface if newClientSeedPortForward returns nil
  1778. var updater common.ActivityUpdater
  1779. seedUpdater := sshClient.newClientSeedPortForward(IP)
  1780. if seedUpdater != nil {
  1781. updater = seedUpdater
  1782. }
  1783. fwdConn, err = common.NewActivityMonitoredConn(
  1784. fwdConn,
  1785. sshClient.idleTCPPortForwardTimeout(),
  1786. true,
  1787. updater,
  1788. lruEntry)
  1789. if err != nil {
  1790. log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  1791. return
  1792. }
  1793. // Relay channel to forwarded connection.
  1794. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  1795. // TODO: relay errors to fwdChannel.Stderr()?
  1796. relayWaitGroup := new(sync.WaitGroup)
  1797. relayWaitGroup.Add(1)
  1798. go func() {
  1799. defer relayWaitGroup.Done()
  1800. // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
  1801. // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
  1802. // overall memory footprint.
  1803. bytes, err := io.CopyBuffer(
  1804. fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  1805. atomic.AddInt64(&bytesDown, bytes)
  1806. if err != nil && err != io.EOF {
  1807. // Debug since errors such as "connection reset by peer" occur during normal operation
  1808. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  1809. }
  1810. // Interrupt upstream io.Copy when downstream is shutting down.
  1811. // TODO: this is done to quickly cleanup the port forward when
  1812. // fwdConn has a read timeout, but is it clean -- upstream may still
  1813. // be flowing?
  1814. fwdChannel.Close()
  1815. }()
  1816. bytes, err := io.CopyBuffer(
  1817. fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  1818. atomic.AddInt64(&bytesUp, bytes)
  1819. if err != nil && err != io.EOF {
  1820. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  1821. }
  1822. // Shutdown special case: fwdChannel will be closed and return EOF when
  1823. // the SSH connection is closed, but we need to explicitly close fwdConn
  1824. // to interrupt the downstream io.Copy, which may be blocked on a
  1825. // fwdConn.Read().
  1826. fwdConn.Close()
  1827. relayWaitGroup.Wait()
  1828. log.WithContextFields(
  1829. LogFields{
  1830. "remoteAddr": remoteAddr,
  1831. "bytesUp": atomic.LoadInt64(&bytesUp),
  1832. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  1833. }