tunnelServer.go 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "context"
  22. "crypto/subtle"
  23. "encoding/base64"
  24. "encoding/json"
  25. "errors"
  26. "fmt"
  27. "io"
  28. "net"
  29. "strconv"
  30. "sync"
  31. "sync/atomic"
  32. "syscall"
  33. "time"
  34. "github.com/Psiphon-Labs/goarista/monotime"
  35. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  36. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
  37. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
  38. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
  39. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
  40. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
  41. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
  42. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
  43. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
  44. "github.com/marusama/semaphore"
  45. cache "github.com/patrickmn/go-cache"
  46. )
  47. const (
  48. SSH_AUTH_LOG_PERIOD = 30 * time.Minute
  49. SSH_HANDSHAKE_TIMEOUT = 30 * time.Second
  50. SSH_BEGIN_HANDSHAKE_TIMEOUT = 1 * time.Second
  51. SSH_CONNECTION_READ_DEADLINE = 5 * time.Minute
  52. SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
  53. SSH_TCP_PORT_FORWARD_QUEUE_SIZE = 1024
  54. SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES = 0
  55. SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES = 256
  56. SSH_SEND_OSL_INITIAL_RETRY_DELAY = 30 * time.Second
  57. SSH_SEND_OSL_RETRY_FACTOR = 2
  58. OSL_SESSION_CACHE_TTL = 5 * time.Minute
  59. MAX_AUTHORIZATIONS = 16
  60. )
  61. // TunnelServer is the main server that accepts Psiphon client
  62. // connections, via various obfuscation protocols, and provides
  63. // port forwarding (TCP and UDP) services to the Psiphon client.
  64. // At its core, TunnelServer is an SSH server. SSH is the base
  65. // protocol that provides port forward multiplexing, and transport
  66. // security. Layered on top of SSH, optionally, is Obfuscated SSH
  67. // and meek protocols, which provide further circumvention
  68. // capabilities.
  69. type TunnelServer struct {
  70. runWaitGroup *sync.WaitGroup
  71. listenerError chan error
  72. shutdownBroadcast <-chan struct{}
  73. sshServer *sshServer
  74. }
  75. // NewTunnelServer initializes a new tunnel server.
  76. func NewTunnelServer(
  77. support *SupportServices,
  78. shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
  79. sshServer, err := newSSHServer(support, shutdownBroadcast)
  80. if err != nil {
  81. return nil, common.ContextError(err)
  82. }
  83. return &TunnelServer{
  84. runWaitGroup: new(sync.WaitGroup),
  85. listenerError: make(chan error),
  86. shutdownBroadcast: shutdownBroadcast,
  87. sshServer: sshServer,
  88. }, nil
  89. }
  90. // Run runs the tunnel server; this function blocks while running a selection of
  91. // listeners that handle connection using various obfuscation protocols.
  92. //
  93. // Run listens on each designated tunnel port and spawns new goroutines to handle
  94. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  95. // clients is maintained, and when halting all clients are cleanly shutdown.
  96. //
  97. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  98. // authentication, and then looping on client new channel requests. "direct-tcpip"
  99. // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
  100. // config parameter is configured, UDP port forwards over a TCP stream, following
  101. // the udpgw protocol, are handled.
  102. //
  103. // A new goroutine is spawned to handle each port forward for each client. Each port
  104. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  105. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  106. // client shuts down.
  107. //
  108. // Note: client handler goroutines may still be shutting down after Run() returns. See
  109. // comment in sshClient.stop(). TODO: fully synchronized shutdown.
  110. func (server *TunnelServer) Run() error {
  111. type sshListener struct {
  112. net.Listener
  113. localAddress string
  114. tunnelProtocol string
  115. }
  116. // TODO: should TunnelServer hold its own support pointer?
  117. support := server.sshServer.support
  118. // First bind all listeners; once all are successful,
  119. // start accepting connections on each.
  120. var listeners []*sshListener
  121. for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
  122. localAddress := fmt.Sprintf(
  123. "%s:%d", support.Config.ServerIPAddress, listenPort)
  124. var listener net.Listener
  125. var err error
  126. if protocol.TunnelProtocolUsesQUIC(tunnelProtocol) {
  127. listener, err = quic.Listen(localAddress)
  128. } else {
  129. listener, err = net.Listen("tcp", localAddress)
  130. }
  131. if err != nil {
  132. for _, existingListener := range listeners {
  133. existingListener.Listener.Close()
  134. }
  135. return common.ContextError(err)
  136. }
  137. tacticsListener := tactics.NewListener(
  138. listener,
  139. support.TacticsServer,
  140. tunnelProtocol,
  141. func(IPAddress string) common.GeoIPData {
  142. return common.GeoIPData(support.GeoIPService.Lookup(IPAddress))
  143. })
  144. log.WithContextFields(
  145. LogFields{
  146. "localAddress": localAddress,
  147. "tunnelProtocol": tunnelProtocol,
  148. }).Info("listening")
  149. listeners = append(
  150. listeners,
  151. &sshListener{
  152. Listener: tacticsListener,
  153. localAddress: localAddress,
  154. tunnelProtocol: tunnelProtocol,
  155. })
  156. }
  157. for _, listener := range listeners {
  158. server.runWaitGroup.Add(1)
  159. go func(listener *sshListener) {
  160. defer server.runWaitGroup.Done()
  161. log.WithContextFields(
  162. LogFields{
  163. "localAddress": listener.localAddress,
  164. "tunnelProtocol": listener.tunnelProtocol,
  165. }).Info("running")
  166. server.sshServer.runListener(
  167. listener.Listener,
  168. server.listenerError,
  169. listener.tunnelProtocol)
  170. log.WithContextFields(
  171. LogFields{
  172. "localAddress": listener.localAddress,
  173. "tunnelProtocol": listener.tunnelProtocol,
  174. }).Info("stopped")
  175. }(listener)
  176. }
  177. var err error
  178. select {
  179. case <-server.shutdownBroadcast:
  180. case err = <-server.listenerError:
  181. }
  182. for _, listener := range listeners {
  183. listener.Close()
  184. }
  185. server.sshServer.stopClients()
  186. server.runWaitGroup.Wait()
  187. log.WithContext().Info("stopped")
  188. return err
  189. }
  190. // GetLoadStats returns load stats for the tunnel server. The stats are
  191. // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
  192. // include current connected client count, total number of current port
  193. // forwards.
  194. func (server *TunnelServer) GetLoadStats() (ProtocolStats, RegionStats) {
  195. return server.sshServer.getLoadStats()
  196. }
  197. // ResetAllClientTrafficRules resets all established client traffic rules
  198. // to use the latest config and client properties. Any existing traffic
  199. // rule state is lost, including throttling state.
  200. func (server *TunnelServer) ResetAllClientTrafficRules() {
  201. server.sshServer.resetAllClientTrafficRules()
  202. }
  203. // ResetAllClientOSLConfigs resets all established client OSL state to use
  204. // the latest OSL config. Any existing OSL state is lost, including partial
  205. // progress towards SLOKs.
  206. func (server *TunnelServer) ResetAllClientOSLConfigs() {
  207. server.sshServer.resetAllClientOSLConfigs()
  208. }
  209. // SetClientHandshakeState sets the handshake state -- that it completed and
  210. // what parameters were passed -- in sshClient. This state is used for allowing
  211. // port forwards and for future traffic rule selection. SetClientHandshakeState
  212. // also triggers an immediate traffic rule re-selection, as the rules selected
  213. // upon tunnel establishment may no longer apply now that handshake values are
  214. // set.
  215. //
  216. // The authorizations received from the client handshake are verified and the
  217. // resulting list of authorized access types are applied to the client's tunnel
  218. // and traffic rules. A list of active authorization IDs and authorized access
  219. // types is returned for responding to the client and logging.
  220. func (server *TunnelServer) SetClientHandshakeState(
  221. sessionID string,
  222. state handshakeState,
  223. authorizations []string) ([]string, []string, error) {
  224. return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
  225. }
  226. // GetClientHandshaked indicates whether the client has completed a handshake
  227. // and whether its traffic rules are immediately exhausted.
  228. func (server *TunnelServer) GetClientHandshaked(
  229. sessionID string) (bool, bool, error) {
  230. return server.sshServer.getClientHandshaked(sessionID)
  231. }
  232. // ExpectClientDomainBytes indicates whether the client was configured to report
  233. // domain bytes in its handshake response.
  234. func (server *TunnelServer) ExpectClientDomainBytes(
  235. sessionID string) (bool, error) {
  236. return server.sshServer.expectClientDomainBytes(sessionID)
  237. }
  238. // SetEstablishTunnels sets whether new tunnels may be established or not.
  239. // When not establishing, incoming connections are immediately closed.
  240. func (server *TunnelServer) SetEstablishTunnels(establish bool) {
  241. server.sshServer.setEstablishTunnels(establish)
  242. }
  243. // GetEstablishTunnels returns whether new tunnels may be established or not.
  244. func (server *TunnelServer) GetEstablishTunnels() bool {
  245. return server.sshServer.getEstablishTunnels()
  246. }
  247. type sshServer struct {
  248. // Note: 64-bit ints used with atomic operations are placed
  249. // at the start of struct to ensure 64-bit alignment.
  250. // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
  251. lastAuthLog int64
  252. authFailedCount int64
  253. support *SupportServices
  254. establishTunnels int32
  255. concurrentSSHHandshakes semaphore.Semaphore
  256. shutdownBroadcast <-chan struct{}
  257. sshHostKey ssh.Signer
  258. clientsMutex sync.Mutex
  259. stoppingClients bool
  260. acceptedClientCounts map[string]map[string]int64
  261. clients map[string]*sshClient
  262. oslSessionCacheMutex sync.Mutex
  263. oslSessionCache *cache.Cache
  264. authorizationSessionIDsMutex sync.Mutex
  265. authorizationSessionIDs map[string]string
  266. }
  267. func newSSHServer(
  268. support *SupportServices,
  269. shutdownBroadcast <-chan struct{}) (*sshServer, error) {
  270. privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
  271. if err != nil {
  272. return nil, common.ContextError(err)
  273. }
  274. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  275. signer, err := ssh.NewSignerFromKey(privateKey)
  276. if err != nil {
  277. return nil, common.ContextError(err)
  278. }
  279. var concurrentSSHHandshakes semaphore.Semaphore
  280. if support.Config.MaxConcurrentSSHHandshakes > 0 {
  281. concurrentSSHHandshakes = semaphore.New(support.Config.MaxConcurrentSSHHandshakes)
  282. }
  283. // The OSL session cache temporarily retains OSL seed state
  284. // progress for disconnected clients. This enables clients
  285. // that disconnect and immediately reconnect to the same
  286. // server to resume their OSL progress. Cached progress
  287. // is referenced by session ID and is retained for
  288. // OSL_SESSION_CACHE_TTL after disconnect.
  289. //
  290. // Note: session IDs are assumed to be unpredictable. If a
  291. // rogue client could guess the session ID of another client,
  292. // it could resume its OSL progress and, if the OSL config
  293. // were known, infer some activity.
  294. oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
  295. return &sshServer{
  296. support: support,
  297. establishTunnels: 1,
  298. concurrentSSHHandshakes: concurrentSSHHandshakes,
  299. shutdownBroadcast: shutdownBroadcast,
  300. sshHostKey: signer,
  301. acceptedClientCounts: make(map[string]map[string]int64),
  302. clients: make(map[string]*sshClient),
  303. oslSessionCache: oslSessionCache,
  304. authorizationSessionIDs: make(map[string]string),
  305. }, nil
  306. }
  307. func (sshServer *sshServer) setEstablishTunnels(establish bool) {
  308. // Do nothing when the setting is already correct. This avoids
  309. // spurious log messages when setEstablishTunnels is called
  310. // periodically with the same setting.
  311. if establish == sshServer.getEstablishTunnels() {
  312. return
  313. }
  314. establishFlag := int32(1)
  315. if !establish {
  316. establishFlag = 0
  317. }
  318. atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
  319. log.WithContextFields(
  320. LogFields{"establish": establish}).Info("establishing tunnels")
  321. }
  322. func (sshServer *sshServer) getEstablishTunnels() bool {
  323. return atomic.LoadInt32(&sshServer.establishTunnels) == 1
  324. }
  325. // runListener is intended to run an a goroutine; it blocks
  326. // running a particular listener. If an unrecoverable error
  327. // occurs, it will send the error to the listenerError channel.
  328. func (sshServer *sshServer) runListener(
  329. listener net.Listener,
  330. listenerError chan<- error,
  331. listenerTunnelProtocol string) {
  332. runningProtocols := make([]string, 0)
  333. for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
  334. runningProtocols = append(runningProtocols, tunnelProtocol)
  335. }
  336. handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
  337. // Note: establish tunnel limiter cannot simply stop TCP
  338. // listeners in all cases (e.g., meek) since SSH tunnel can
  339. // span multiple TCP connections.
  340. if !sshServer.getEstablishTunnels() {
  341. log.WithContext().Debug("not establishing tunnels")
  342. clientConn.Close()
  343. return
  344. }
  345. // The tunnelProtocol passed to handleClient is used for stats,
  346. // throttling, etc. When the tunnel protocol can be determined
  347. // unambiguously from the listening port, use that protocol and
  348. // don't use any client-declared value. Only use the client's
  349. // value, if present, in special cases where the listenting port
  350. // cannot distinguish the protocol.
  351. tunnelProtocol := listenerTunnelProtocol
  352. if clientTunnelProtocol != "" &&
  353. protocol.UseClientTunnelProtocol(
  354. clientTunnelProtocol, runningProtocols) {
  355. tunnelProtocol = clientTunnelProtocol
  356. }
  357. // process each client connection concurrently
  358. go sshServer.handleClient(tunnelProtocol, clientConn)
  359. }
  360. // Note: when exiting due to a unrecoverable error, be sure
  361. // to try to send the error to listenerError so that the outer
  362. // TunnelServer.Run will properly shut down instead of remaining
  363. // running.
  364. if protocol.TunnelProtocolUsesMeekHTTP(listenerTunnelProtocol) ||
  365. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol) {
  366. meekServer, err := NewMeekServer(
  367. sshServer.support,
  368. listener,
  369. protocol.TunnelProtocolUsesMeekHTTPS(listenerTunnelProtocol),
  370. protocol.TunnelProtocolUsesObfuscatedSessionTickets(listenerTunnelProtocol),
  371. handleClient,
  372. sshServer.shutdownBroadcast)
  373. if err == nil {
  374. err = meekServer.Run()
  375. }
  376. if err != nil {
  377. select {
  378. case listenerError <- common.ContextError(err):
  379. default:
  380. }
  381. return
  382. }
  383. } else {
  384. for {
  385. conn, err := listener.Accept()
  386. select {
  387. case <-sshServer.shutdownBroadcast:
  388. if err == nil {
  389. conn.Close()
  390. }
  391. return
  392. default:
  393. }
  394. if err != nil {
  395. if e, ok := err.(net.Error); ok && e.Temporary() {
  396. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  397. // Temporary error, keep running
  398. continue
  399. }
  400. select {
  401. case listenerError <- common.ContextError(err):
  402. default:
  403. }
  404. return
  405. }
  406. handleClient("", conn)
  407. }
  408. }
  409. }
  410. // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
  411. // is for tracking the number of connections.
  412. func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
  413. sshServer.clientsMutex.Lock()
  414. defer sshServer.clientsMutex.Unlock()
  415. if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
  416. sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
  417. }
  418. sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
  419. }
  420. func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
  421. sshServer.clientsMutex.Lock()
  422. defer sshServer.clientsMutex.Unlock()
  423. sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
  424. }
  425. // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
  426. // for tracking the number of fully established clients and for maintaining a list of running
  427. // clients (for stopping at shutdown time).
  428. func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {
  429. sshServer.clientsMutex.Lock()
  430. if sshServer.stoppingClients {
  431. sshServer.clientsMutex.Unlock()
  432. return false
  433. }
  434. // In the case of a duplicate client sessionID, the previous client is closed.
  435. // - Well-behaved clients generate a random sessionID that should be unique (won't
  436. // accidentally conflict) and hard to guess (can't be targeted by a malicious
  437. // client).
  438. // - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
  439. // and reestablished. In this case, when the same server is selected, this logic
  440. // will be hit; closing the old, dangling client is desirable.
  441. // - Multi-tunnel clients should not normally use one server for multiple tunnels.
  442. existingClient := sshServer.clients[client.sessionID]
  443. sshServer.clients[client.sessionID] = client
  444. sshServer.clientsMutex.Unlock()
  445. // Call stop() outside the mutex to avoid deadlock.
  446. if existingClient != nil {
  447. existingClient.stop()
  448. // Since existingClient.run() isn't guaranteed to have terminated at
  449. // this point, synchronously release authorizations for the previous
  450. // client here. This ensures that the authorization IDs are not in
  451. // use when the reconnecting client submits its authorizations.
  452. existingClient.cleanupAuthorizations()
  453. log.WithContext().Debug(
  454. "stopped existing client with duplicate session ID")
  455. }
  456. return true
  457. }
  458. func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
  459. sshServer.clientsMutex.Lock()
  460. registeredClient := sshServer.clients[client.sessionID]
  461. // registeredClient will differ from client when client
  462. // is the existingClient terminated in registerEstablishedClient.
  463. // In that case, registeredClient remains connected, and
  464. // the sshServer.clients entry should be retained.
  465. if registeredClient == client {
  466. delete(sshServer.clients, client.sessionID)
  467. }
  468. sshServer.clientsMutex.Unlock()
  469. // Call stop() outside the mutex to avoid deadlock.
  470. client.stop()
  471. }
  472. type ProtocolStats map[string]map[string]int64
  473. type RegionStats map[string]map[string]map[string]int64
  474. func (sshServer *sshServer) getLoadStats() (ProtocolStats, RegionStats) {
  475. sshServer.clientsMutex.Lock()
  476. defer sshServer.clientsMutex.Unlock()
  477. // Explicitly populate with zeros to ensure 0 counts in log messages
  478. zeroStats := func() map[string]int64 {
  479. stats := make(map[string]int64)
  480. stats["accepted_clients"] = 0
  481. stats["established_clients"] = 0
  482. stats["dialing_tcp_port_forwards"] = 0
  483. stats["tcp_port_forwards"] = 0
  484. stats["total_tcp_port_forwards"] = 0
  485. stats["udp_port_forwards"] = 0
  486. stats["total_udp_port_forwards"] = 0
  487. stats["tcp_port_forward_dialed_count"] = 0
  488. stats["tcp_port_forward_dialed_duration"] = 0
  489. stats["tcp_port_forward_failed_count"] = 0
  490. stats["tcp_port_forward_failed_duration"] = 0
  491. stats["tcp_port_forward_rejected_dialing_limit_count"] = 0
  492. return stats
  493. }
  494. zeroProtocolStats := func() map[string]map[string]int64 {
  495. stats := make(map[string]map[string]int64)
  496. stats["ALL"] = zeroStats()
  497. for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
  498. stats[tunnelProtocol] = zeroStats()
  499. }
  500. return stats
  501. }
  502. // [<protocol or ALL>][<stat name>] -> count
  503. protocolStats := zeroProtocolStats()
  504. // [<region][<protocol or ALL>][<stat name>] -> count
  505. regionStats := make(RegionStats)
  506. // Note: as currently tracked/counted, each established client is also an accepted client
  507. for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
  508. for region, acceptedClientCount := range regionAcceptedClientCounts {
  509. if acceptedClientCount > 0 {
  510. if regionStats[region] == nil {
  511. regionStats[region] = zeroProtocolStats()
  512. }
  513. protocolStats["ALL"]["accepted_clients"] += acceptedClientCount
  514. protocolStats[tunnelProtocol]["accepted_clients"] += acceptedClientCount
  515. regionStats[region]["ALL"]["accepted_clients"] += acceptedClientCount
  516. regionStats[region][tunnelProtocol]["accepted_clients"] += acceptedClientCount
  517. }
  518. }
  519. }
  520. for _, client := range sshServer.clients {
  521. client.Lock()
  522. tunnelProtocol := client.tunnelProtocol
  523. region := client.geoIPData.Country
  524. if regionStats[region] == nil {
  525. regionStats[region] = zeroProtocolStats()
  526. }
  527. stats := []map[string]int64{
  528. protocolStats["ALL"],
  529. protocolStats[tunnelProtocol],
  530. regionStats[region]["ALL"],
  531. regionStats[region][tunnelProtocol]}
  532. for _, stat := range stats {
  533. stat["established_clients"] += 1
  534. // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
  535. stat["dialing_tcp_port_forwards"] += client.tcpTrafficState.concurrentDialingPortForwardCount
  536. stat["tcp_port_forwards"] += client.tcpTrafficState.concurrentPortForwardCount
  537. stat["total_tcp_port_forwards"] += client.tcpTrafficState.totalPortForwardCount
  538. // client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
  539. stat["udp_port_forwards"] += client.udpTrafficState.concurrentPortForwardCount
  540. stat["total_udp_port_forwards"] += client.udpTrafficState.totalPortForwardCount
  541. stat["tcp_port_forward_dialed_count"] += client.qualityMetrics.tcpPortForwardDialedCount
  542. stat["tcp_port_forward_dialed_duration"] +=
  543. int64(client.qualityMetrics.tcpPortForwardDialedDuration / time.Millisecond)
  544. stat["tcp_port_forward_failed_count"] += client.qualityMetrics.tcpPortForwardFailedCount
  545. stat["tcp_port_forward_failed_duration"] +=
  546. int64(client.qualityMetrics.tcpPortForwardFailedDuration / time.Millisecond)
  547. stat["tcp_port_forward_rejected_dialing_limit_count"] +=
  548. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount
  549. }
  550. client.qualityMetrics.tcpPortForwardDialedCount = 0
  551. client.qualityMetrics.tcpPortForwardDialedDuration = 0
  552. client.qualityMetrics.tcpPortForwardFailedCount = 0
  553. client.qualityMetrics.tcpPortForwardFailedDuration = 0
  554. client.qualityMetrics.tcpPortForwardRejectedDialingLimitCount = 0
  555. client.Unlock()
  556. }
  557. return protocolStats, regionStats
  558. }
  559. func (sshServer *sshServer) resetAllClientTrafficRules() {
  560. sshServer.clientsMutex.Lock()
  561. clients := make(map[string]*sshClient)
  562. for sessionID, client := range sshServer.clients {
  563. clients[sessionID] = client
  564. }
  565. sshServer.clientsMutex.Unlock()
  566. for _, client := range clients {
  567. client.setTrafficRules()
  568. }
  569. }
  570. func (sshServer *sshServer) resetAllClientOSLConfigs() {
  571. // Flush cached seed state. This has the same effect
  572. // and same limitations as calling setOSLConfig for
  573. // currently connected clients -- all progress is lost.
  574. sshServer.oslSessionCacheMutex.Lock()
  575. sshServer.oslSessionCache.Flush()
  576. sshServer.oslSessionCacheMutex.Unlock()
  577. sshServer.clientsMutex.Lock()
  578. clients := make(map[string]*sshClient)
  579. for sessionID, client := range sshServer.clients {
  580. clients[sessionID] = client
  581. }
  582. sshServer.clientsMutex.Unlock()
  583. for _, client := range clients {
  584. client.setOSLConfig()
  585. }
  586. }
  587. func (sshServer *sshServer) setClientHandshakeState(
  588. sessionID string,
  589. state handshakeState,
  590. authorizations []string) ([]string, []string, error) {
  591. sshServer.clientsMutex.Lock()
  592. client := sshServer.clients[sessionID]
  593. sshServer.clientsMutex.Unlock()
  594. if client == nil {
  595. return nil, nil, common.ContextError(errors.New("unknown session ID"))
  596. }
  597. activeAuthorizationIDs, authorizedAccessTypes, err := client.setHandshakeState(
  598. state, authorizations)
  599. if err != nil {
  600. return nil, nil, common.ContextError(err)
  601. }
  602. return activeAuthorizationIDs, authorizedAccessTypes, nil
  603. }
  604. func (sshServer *sshServer) getClientHandshaked(
  605. sessionID string) (bool, bool, error) {
  606. sshServer.clientsMutex.Lock()
  607. client := sshServer.clients[sessionID]
  608. sshServer.clientsMutex.Unlock()
  609. if client == nil {
  610. return false, false, common.ContextError(errors.New("unknown session ID"))
  611. }
  612. completed, exhausted := client.getHandshaked()
  613. return completed, exhausted, nil
  614. }
  615. func (sshServer *sshServer) revokeClientAuthorizations(sessionID string) {
  616. sshServer.clientsMutex.Lock()
  617. client := sshServer.clients[sessionID]
  618. sshServer.clientsMutex.Unlock()
  619. if client == nil {
  620. return
  621. }
  622. // sshClient.handshakeState.authorizedAccessTypes is not cleared. Clearing
  623. // authorizedAccessTypes may cause sshClient.logTunnel to fail to log
  624. // access types. As the revocation may be due to legitimate use of an
  625. // authorization in multiple sessions by a single client, useful metrics
  626. // would be lost.
  627. client.Lock()
  628. client.handshakeState.authorizationsRevoked = true
  629. client.Unlock()
  630. // Select and apply new traffic rules, as filtered by the client's new
  631. // authorization state.
  632. client.setTrafficRules()
  633. }
  634. func (sshServer *sshServer) expectClientDomainBytes(
  635. sessionID string) (bool, error) {
  636. sshServer.clientsMutex.Lock()
  637. client := sshServer.clients[sessionID]
  638. sshServer.clientsMutex.Unlock()
  639. if client == nil {
  640. return false, common.ContextError(errors.New("unknown session ID"))
  641. }
  642. return client.expectDomainBytes(), nil
  643. }
  644. func (sshServer *sshServer) stopClients() {
  645. sshServer.clientsMutex.Lock()
  646. sshServer.stoppingClients = true
  647. clients := sshServer.clients
  648. sshServer.clients = make(map[string]*sshClient)
  649. sshServer.clientsMutex.Unlock()
  650. for _, client := range clients {
  651. client.stop()
  652. }
  653. }
  654. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  655. geoIPData := sshServer.support.GeoIPService.Lookup(
  656. common.IPAddressFromAddr(clientConn.RemoteAddr()))
  657. sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
  658. defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
  659. // When configured, enforce a cap on the number of concurrent SSH
  660. // handshakes. This limits load spikes on busy servers when many clients
  661. // attempt to connect at once. Wait a short time, SSH_BEGIN_HANDSHAKE_TIMEOUT,
  662. // to acquire; waiting will avoid immediately creating more load on another
  663. // server in the network when the client tries a new candidate. Disconnect the
  664. // client when that wait time is exceeded.
  665. //
  666. // This mechanism limits memory allocations and CPU usage associated with the
  667. // SSH handshake. At this point, new direct TCP connections or new meek
  668. // connections, with associated resource usage, are already established. Those
  669. // connections are expected to be rate or load limited using other mechanisms.
  670. //
  671. // TODO:
  672. //
  673. // - deduct time spent acquiring the semaphore from SSH_HANDSHAKE_TIMEOUT in
  674. // sshClient.run, since the client is also applying an SSH handshake timeout
  675. // and won't exclude time spent waiting.
  676. // - each call to sshServer.handleClient (in sshServer.runListener) is invoked
  677. // in its own goroutine, but shutdown doesn't synchronously await these
  678. // goroutnes. Once this is synchronizes, the following context.WithTimeout
  679. // should use an sshServer parent context to ensure blocking acquires
  680. // interrupt immediately upon shutdown.
  681. var onSSHHandshakeFinished func()
  682. if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 {
  683. ctx, cancelFunc := context.WithTimeout(
  684. context.Background(), SSH_BEGIN_HANDSHAKE_TIMEOUT)
  685. defer cancelFunc()
  686. err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1)
  687. if err != nil {
  688. clientConn.Close()
  689. // This is a debug log as the only possible error is context timeout.
  690. log.WithContextFields(LogFields{"error": err}).Debug(
  691. "acquire SSH handshake semaphore failed")
  692. return
  693. }
  694. onSSHHandshakeFinished = func() {
  695. sshServer.concurrentSSHHandshakes.Release(1)
  696. }
  697. }
  698. sshClient := newSshClient(sshServer, tunnelProtocol, geoIPData)
  699. // sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
  700. // in any error case; or, as soon as the SSH handshake phase has successfully
  701. // completed.
  702. sshClient.run(clientConn, onSSHHandshakeFinished)
  703. }
  704. func (sshServer *sshServer) monitorPortForwardDialError(err error) {
  705. // "err" is the error returned from a failed TCP or UDP port
  706. // forward dial. Certain system error codes indicate low resource
  707. // conditions: insufficient file descriptors, ephemeral ports, or
  708. // memory. For these cases, log an alert.
  709. // TODO: also temporarily suspend new clients
  710. // Note: don't log net.OpError.Error() as the full error string
  711. // may contain client destination addresses.
  712. opErr, ok := err.(*net.OpError)
  713. if ok {
  714. if opErr.Err == syscall.EADDRNOTAVAIL ||
  715. opErr.Err == syscall.EAGAIN ||
  716. opErr.Err == syscall.ENOMEM ||
  717. opErr.Err == syscall.EMFILE ||
  718. opErr.Err == syscall.ENFILE {
  719. log.WithContextFields(
  720. LogFields{"error": opErr.Err}).Error(
  721. "port forward dial failed due to unavailable resource")
  722. }
  723. }
  724. }
// sshClient is the server-side state for a single client connection. It is
// created by newSshClient and driven by sshClient.run.
//
// The embedded sync.Mutex guards mutable fields. In this file, sshConn,
// activityConn, throttledConn, handshakeState, the TCP/UDP traffic states,
// qualityMetrics and oslClientSeedState are read or written while holding it.
type sshClient struct {
	sync.Mutex
	sshServer      *sshServer
	tunnelProtocol string
	// sshConn, activityConn and throttledConn are assigned, under the lock,
	// in run once the SSH handshake has succeeded.
	sshConn       ssh.Conn
	activityConn  *common.ActivityMonitoredConn
	throttledConn *common.ThrottledConn
	// GeoIP data for the client's remote address, looked up in handleClient.
	geoIPData GeoIPData
	// sessionID, isFirstTunnelInSession and supportsServerRequests are set
	// under the lock in passwordCallback. They are read-only afterwards and
	// are read without obtaining the lock.
	sessionID              string
	isFirstTunnelInSession bool
	supportsServerRequests bool
	handshakeState         handshakeState
	udpChannel             ssh.Channel
	packetTunnelChannel    ssh.Channel
	trafficRules           TrafficRules
	// Per-transport port forward counters; summed by getLoadStats.
	tcpTrafficState trafficState
	udpTrafficState trafficState
	// Reported and then reset by getLoadStats.
	qualityMetrics qualityMetrics
	// Established TCP port forwards. When the established port forward limit
	// is exceeded, the least recently used one is closed.
	tcpPortForwardLRU *common.LRUConns
	// OSL progress. run hands this to the server's OSL session cache on exit.
	oslClientSeedState *osl.ClientSeedState
	// Created with a buffer of 1 in newSshClient.
	signalIssueSLOKs chan struct{}
	// runCtx is created in newSshClient; stopRunning cancels it.
	runCtx                               context.Context
	stopRunning                          context.CancelFunc
	tcpPortForwardDialingAvailableSignal context.CancelFunc
	releaseAuthorizations                func()
	stopTimer                            *time.Timer
}
// trafficState holds the port forward counters for one transport (TCP or
// UDP) of one sshClient. getLoadStats sums the concurrent and total counts
// across clients while holding each client's lock.
type trafficState struct {
	bytesUp   int64
	bytesDown int64
	// Port forwards currently in the dialing phase. getLoadStats notes that
	// this isn't meaningful for UDP.
	concurrentDialingPortForwardCount     int64
	peakConcurrentDialingPortForwardCount int64
	// Established port forwards currently open.
	concurrentPortForwardCount int64
	// A per-client peak; it can't be summed across clients to get a global peak.
	peakConcurrentPortForwardCount int64
	totalPortForwardCount          int64
	// Created in newSshClient with its own sync.Mutex, distinct from the
	// sshClient lock. NOTE(review): presumably signalled when port forward
	// capacity becomes available; confirm against its users.
	availablePortForwardCond *sync.Cond
}
// qualityMetrics records upstream TCP dial attempts and
// elapsed time. Elapsed time includes the full TCP handshake
// and, in aggregate, is a measure of the quality of the
// upstream link. These stats are recorded by each sshClient
// and then reported and reset in sshServer.getLoadStats().
// getLoadStats reports the durations in whole milliseconds.
type qualityMetrics struct {
	tcpPortForwardDialedCount    int64
	tcpPortForwardDialedDuration time.Duration
	tcpPortForwardFailedCount    int64
	tcpPortForwardFailedDuration time.Duration
	// Port forwards rejected because their dial timeout expired while
	// queued or blocked at the concurrent dialing limit.
	tcpPortForwardRejectedDialingLimitCount int64
}
// handshakeState is the per-client state established by the Psiphon API
// handshake. It is stored in sshClient.handshakeState and accessed under the
// sshClient lock.
type handshakeState struct {
	completed   bool
	apiProtocol string
	apiParams   common.APIParameters
	// Passed to sshAPIRequestHandler for each API request. The slice is
	// treated as read-only once set, and is not cleared when authorizations
	// are revoked, so that logTunnel can still log access types.
	authorizedAccessTypes []string
	// Set to true by sshServer.revokeClientAuthorizations.
	authorizationsRevoked bool
	// NOTE(review): presumably the value returned by
	// sshClient.expectDomainBytes; confirm.
	expectDomainBytes bool
}
  782. func newSshClient(
  783. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData) *sshClient {
  784. runCtx, stopRunning := context.WithCancel(context.Background())
  785. // isFirstTunnelInSession is defaulted to true so that the pre-handshake
  786. // traffic rules won't apply UnthrottleFirstTunnelOnly and negate any
  787. // unthrottled bytes during the initial protocol negotiation.
  788. client := &sshClient{
  789. sshServer: sshServer,
  790. tunnelProtocol: tunnelProtocol,
  791. geoIPData: geoIPData,
  792. isFirstTunnelInSession: true,
  793. tcpPortForwardLRU: common.NewLRUConns(),
  794. signalIssueSLOKs: make(chan struct{}, 1),
  795. runCtx: runCtx,
  796. stopRunning: stopRunning,
  797. }
  798. client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  799. client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  800. return client
  801. }
// run drives the full lifecycle of one client connection:
//   - wrap clientConn with activity monitoring and throttling
//   - perform the [obfuscated] SSH handshake, with a timeout
//   - register the client as established and run the tunnel until the SSH
//     connection closes
//   - clean up: unregister the client, log the tunnel, hibernate the OSL
//     seed state into the session cache, and mark the GeoIP session cache
//     entry to expire
//
// onSSHHandshakeFinished, when non-nil, is invoked exactly once. That happens
// either as soon as the SSH handshake succeeds, or when run returns in any
// earlier failure case.
func (sshClient *sshClient) run(
	clientConn net.Conn, onSSHHandshakeFinished func()) {

	// onSSHHandshakeFinished must be called even if the SSH handshake is aborted.
	// The closure reads the variable at return time. The success path below
	// calls it and then sets it to nil, so it is not invoked twice.
	defer func() {
		if onSSHHandshakeFinished != nil {
			onSSHHandshakeFinished()
		}
	}()

	// Some conns report additional metrics
	metricsSource, isMetricsSource := clientConn.(MetricsSource)

	// Set initial traffic rules, pre-handshake, based on currently known info.
	sshClient.setTrafficRules()

	// Wrap the base client connection with an ActivityMonitoredConn which will
	// terminate the connection if no data is received before the deadline. This
	// timeout is in effect for the entire duration of the SSH connection. Clients
	// must actively use the connection or send SSH keep alive requests to keep
	// the connection active. Writes are not considered reliable activity indicators
	// due to buffering.

	activityConn, err := common.NewActivityMonitoredConn(
		clientConn,
		SSH_CONNECTION_READ_DEADLINE,
		false,
		nil,
		nil)
	if err != nil {
		clientConn.Close()
		log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
		return
	}
	clientConn = activityConn

	// Further wrap the connection in a rate limiting ThrottledConn.

	throttledConn := common.NewThrottledConn(clientConn, sshClient.rateLimits())
	clientConn = throttledConn

	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
	// respect shutdownBroadcast and implement a specific handshake timeout.
	// The timeout is to reclaim network resources in case the handshake takes
	// too long.

	type sshNewServerConnResult struct {
		// NOTE(review): conn is neither set nor read in this function.
		conn     net.Conn
		sshConn  *ssh.ServerConn
		channels <-chan ssh.NewChannel
		requests <-chan *ssh.Request
		err      error
	}

	// Capacity 2: there are two senders, the timeout timer and the handshake
	// goroutine. Only one result is received, so the buffer lets the other
	// sender complete without blocking.
	resultChannel := make(chan *sshNewServerConnResult, 2)

	var afterFunc *time.Timer
	if SSH_HANDSHAKE_TIMEOUT > 0 {
		afterFunc = time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
			resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
		})
	}

	go func(conn net.Conn) {
		sshServerConfig := &ssh.ServerConfig{
			PasswordCallback: sshClient.passwordCallback,
			AuthLogCallback:  sshClient.authLogCallback,
			ServerVersion:    sshClient.sshServer.support.Config.SSHServerVersion,
		}
		sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)

		if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
			// This is the list of supported non-Encrypt-then-MAC algorithms from
			// https://github.com/Psiphon-Labs/psiphon-tunnel-core/blob/3ef11effe6acd92c3aefd140ee09c42a1f15630b/psiphon/common/crypto/ssh/common.go#L60
			//
			// With Encrypt-then-MAC algorithms, packet length is transmitted in
			// plaintext, which aids in traffic analysis; clients may still send
			// Encrypt-then-MAC algorithms in their KEX_INIT message, but do not
			// select these algorithms.
			//
			// The exception is TUNNEL_PROTOCOL_SSH, which is intended to appear
			// like SSH on the wire.
			sshServerConfig.MACs = []string{"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96"}
		}

		result := &sshNewServerConnResult{}

		// Wrap the connection in an SSH deobfuscator when required.

		if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
			// Note: NewObfuscatedSshConn blocks on network I/O
			// TODO: ensure this won't block shutdown
			conn, result.err = obfuscator.NewObfuscatedSshConn(
				obfuscator.OBFUSCATION_CONN_MODE_SERVER,
				conn,
				sshClient.sshServer.support.Config.ObfuscatedSSHKey,
				nil,
				nil)
			if result.err != nil {
				result.err = common.ContextError(result.err)
			}
		}

		if result.err == nil {
			result.sshConn, result.channels, result.requests, result.err =
				ssh.NewServerConn(conn, sshServerConfig)
		}

		resultChannel <- result

	}(clientConn)

	var result *sshNewServerConnResult
	select {
	case result = <-resultChannel:
	case <-sshClient.sshServer.shutdownBroadcast:
		// Close() will interrupt an ongoing handshake
		// TODO: wait for SSH handshake goroutines to exit before returning?
		// Note: afterFunc is not stopped on this path; if it fires later it
		// only sends into the buffered resultChannel.
		clientConn.Close()
		return
	}

	if afterFunc != nil {
		afterFunc.Stop()
	}

	// result.err is either the handshake goroutine's error or the timer's
	// "ssh handshake timeout".
	if result.err != nil {
		clientConn.Close()
		// This is a Debug log due to noise. The handshake often fails due to I/O
		// errors as clients frequently interrupt connections in progress when
		// client-side load balancing completes a connection to a different server.
		log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
		return
	}

	// The SSH handshake has finished successfully; notify now to allow other
	// blocked SSH handshakes to proceed.
	if onSSHHandshakeFinished != nil {
		onSSHHandshakeFinished()
	}
	onSSHHandshakeFinished = nil

	sshClient.Lock()
	sshClient.sshConn = result.sshConn
	sshClient.activityConn = activityConn
	sshClient.throttledConn = throttledConn
	sshClient.Unlock()

	if !sshClient.sshServer.registerEstablishedClient(sshClient) {
		clientConn.Close()
		log.WithContext().Warning("register failed")
		return
	}

	// Blocks until the SSH connection's channels and requests are closed.
	sshClient.runTunnel(result.channels, result.requests)

	// Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
	// which also closes underlying transport Conn.

	sshClient.sshServer.unregisterEstablishedClient(sshClient)

	var additionalMetrics LogFields
	if isMetricsSource {
		additionalMetrics = metricsSource.GetMetrics()
	}
	sshClient.logTunnel(additionalMetrics)

	// Transfer OSL seed state -- the OSL progress -- from the closing
	// client to the session cache so the client can resume its progress
	// if it reconnects to this same server.
	// Note: following setOSLConfig order of locking.

	sshClient.Lock()
	if sshClient.oslClientSeedState != nil {
		sshClient.sshServer.oslSessionCacheMutex.Lock()
		sshClient.oslClientSeedState.Hibernate()
		sshClient.sshServer.oslSessionCache.Set(
			sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
		sshClient.sshServer.oslSessionCacheMutex.Unlock()
		sshClient.oslClientSeedState = nil
	}
	sshClient.Unlock()

	// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
	// final status requests, the lifetime of cached GeoIP records exceeds the
	// lifetime of the sshClient.
	sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
}
  958. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  959. expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
  960. expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
  961. var sshPasswordPayload protocol.SSHPasswordPayload
  962. err := json.Unmarshal(password, &sshPasswordPayload)
  963. if err != nil {
  964. // Backwards compatibility case: instead of a JSON payload, older clients
  965. // send the hex encoded session ID prepended to the SSH password.
  966. // Note: there's an even older case where clients don't send any session ID,
  967. // but that's no longer supported.
  968. if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
  969. sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
  970. sshPasswordPayload.SshPassword = string(password[expectedSSHPasswordLength:])
  971. } else {
  972. return nil, common.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  973. }
  974. }
  975. if !isHexDigits(sshClient.sshServer.support.Config, sshPasswordPayload.SessionId) ||
  976. len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
  977. return nil, common.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
  978. }
  979. userOk := (subtle.ConstantTimeCompare(
  980. []byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
  981. passwordOk := (subtle.ConstantTimeCompare(
  982. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
  983. if !userOk || !passwordOk {
  984. return nil, common.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  985. }
  986. sessionID := sshPasswordPayload.SessionId
  987. // The GeoIP session cache will be populated if there was a previous tunnel
  988. // with this session ID. This will be true up to GEOIP_SESSION_CACHE_TTL, which
  989. // is currently much longer than the OSL session cache, another option to use if
  990. // the GeoIP session cache is retired (the GeoIP session cache currently only
  991. // supports legacy use cases).
  992. isFirstTunnelInSession := !sshClient.sshServer.support.GeoIPService.InSessionCache(sessionID)
  993. supportsServerRequests := common.Contains(
  994. sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
  995. sshClient.Lock()
  996. // After this point, these values are read-only as they are read
  997. // without obtaining sshClient.Lock.
  998. sshClient.sessionID = sessionID
  999. sshClient.isFirstTunnelInSession = isFirstTunnelInSession
  1000. sshClient.supportsServerRequests = supportsServerRequests
  1001. geoIPData := sshClient.geoIPData
  1002. sshClient.Unlock()
  1003. // Store the GeoIP data associated with the session ID. This makes
  1004. // the GeoIP data available to the web server for web API requests.
  1005. // A cache that's distinct from the sshClient record is used to allow
  1006. // for or post-tunnel final status requests.
  1007. // If the client is reconnecting with the same session ID, this call
  1008. // will undo the expiry set by MarkSessionCacheToExpire.
  1009. sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
  1010. return nil, nil
  1011. }
  1012. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  1013. if err != nil {
  1014. if method == "none" && err.Error() == "no auth passed yet" {
  1015. // In this case, the callback invocation is noise from auth negotiation
  1016. return
  1017. }
  1018. // Note: here we previously logged messages for fail2ban to act on. This is no longer
  1019. // done as the complexity outweighs the benefits.
  1020. //
  1021. // - The SSH credential is not secret -- it's in the server entry. Attackers targeting
  1022. // the server likely already have the credential. On the other hand, random scanning and
  1023. // brute forcing is mitigated with high entropy random passwords, rate limiting
  1024. // (implemented on the host via iptables), and limited capabilities (the SSH session can
  1025. // only port forward).
  1026. //
  1027. // - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
  1028. // an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
  1029. // The X-Forwarded-For header cant be used instead as it may be forged and used to get IPs
  1030. // deliberately blocked; and in any case fail2ban adds iptables rules which can only block
  1031. // by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
  1032. //
  1033. // Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
  1034. // not every authentication failure is logged. A summary log is emitted periodically to
  1035. // retain some record of this activity in case this is relevant to, e.g., a performance
  1036. // investigation.
  1037. atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
  1038. lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
  1039. if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
  1040. now := int64(monotime.Now())
  1041. if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
  1042. count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
  1043. log.WithContextFields(
  1044. LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
  1045. }
  1046. }
  1047. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
  1048. } else {
  1049. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  1050. }
  1051. }
  1052. // stop signals the ssh connection to shutdown. After sshConn() returns,
  1053. // the connection has terminated but sshClient.run() may still be
  1054. // running and in the process of exiting.
  1055. func (sshClient *sshClient) stop() {
  1056. sshClient.sshConn.Close()
  1057. sshClient.sshConn.Wait()
  1058. }
  1059. // runTunnel handles/dispatches new channels and new requests from the client.
  1060. // When the SSH client connection closes, both the channels and requests channels
  1061. // will close and runTunnel will exit.
  1062. func (sshClient *sshClient) runTunnel(
  1063. channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
  1064. waitGroup := new(sync.WaitGroup)
  1065. // Start client SSH API request handler
  1066. waitGroup.Add(1)
  1067. go func() {
  1068. defer waitGroup.Done()
  1069. for request := range requests {
  1070. // Requests are processed serially; API responses must be sent in request order.
  1071. var responsePayload []byte
  1072. var err error
  1073. if request.Type == "keepalive@openssh.com" {
  1074. // SSH keep alive round trips are used as speed test samples.
  1075. responsePayload, err = tactics.MakeSpeedTestResponse(
  1076. SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES, SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES)
  1077. } else {
  1078. // All other requests are assumed to be API requests.
  1079. sshClient.Lock()
  1080. authorizedAccessTypes := sshClient.handshakeState.authorizedAccessTypes
  1081. sshClient.Unlock()
  1082. // Note: unlock before use is only safe as long as referenced sshClient data,
  1083. // such as slices in handshakeState, is read-only after initially set.
  1084. responsePayload, err = sshAPIRequestHandler(
  1085. sshClient.sshServer.support,
  1086. sshClient.geoIPData,
  1087. authorizedAccessTypes,
  1088. request.Type,
  1089. request.Payload)
  1090. }
  1091. if err == nil {
  1092. err = request.Reply(true, responsePayload)
  1093. } else {
  1094. log.WithContextFields(LogFields{"error": err}).Warning("request failed")
  1095. err = request.Reply(false, nil)
  1096. }
  1097. if err != nil {
  1098. log.WithContextFields(LogFields{"error": err}).Warning("response failed")
  1099. }
  1100. }
  1101. }()
  1102. // Start OSL sender
  1103. if sshClient.supportsServerRequests {
  1104. waitGroup.Add(1)
  1105. go func() {
  1106. defer waitGroup.Done()
  1107. sshClient.runOSLSender()
  1108. }()
  1109. }
  1110. // Lifecycle of a TCP port forward:
  1111. //
  1112. // 1. A "direct-tcpip" SSH request is received from the client.
  1113. //
  1114. // A new TCP port forward request is enqueued. The queue delivers TCP port
  1115. // forward requests to the TCP port forward manager, which enforces the TCP
  1116. // port forward dial limit.
  1117. //
  1118. // Enqueuing new requests allows for reading further SSH requests from the
  1119. // client without blocking when the dial limit is hit; this is to permit new
  1120. // UDP/udpgw port forwards to be restablished without delay. The maximum size
  1121. // of the queue enforces a hard cap on resources consumed by a client in the
  1122. // pre-dial phase. When the queue is full, new TCP port forwards are
  1123. // immediately rejected.
  1124. //
  1125. // 2. The TCP port forward manager dequeues the request.
  1126. //
  1127. // The manager calls dialingTCPPortForward(), which increments
  1128. // concurrentDialingPortForwardCount, and calls
  1129. // isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
  1130. // count.
  1131. //
  1132. // The manager enforces the concurrent TCP dial limit: when at the limit, the
  1133. // manager blocks waiting for the number of dials to drop below the limit before
  1134. // dispatching the request to handleTCPPortForward(), which will run in its own
  1135. // goroutine and will dial and relay the port forward.
  1136. //
  1137. // The block delays the current request and also halts dequeuing of subsequent
  1138. // requests and could ultimately cause requests to be immediately rejected if
  1139. // the queue fills. These actions are intended to apply back pressure when
  1140. // upstream network resources are impaired.
  1141. //
  1142. // The time spent in the queue is deducted from the port forward's dial timeout.
  1143. // The time spent blocking while at the dial limit is similarly deducted from
  1144. // the dial timeout. If the dial timeout has expired before the dial begins, the
  1145. // port forward is rejected and a stat is recorded.
  1146. //
  1147. // 3. handleTCPPortForward() performs the port forward dial and relaying.
  1148. //
  1149. // a. Dial the target, using the dial timeout remaining after queue and blocking
  1150. // time is deducted.
  1151. //
  1152. // b. If the dial fails, call abortedTCPPortForward() to decrement
  1153. // concurrentDialingPortForwardCount, freeing up a dial slot.
  1154. //
  1155. // c. If the dial succeeds, call establishedPortForward(), which decrements
  1156. // concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
  1157. // the "established" port forward count.
  1158. //
  1159. // d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
  1160. // concurrentPortForwardCount, the number of _established_ TCP port forwards.
  1161. // If the limit is exceeded, the LRU established TCP port forward is closed and
  1162. // the newly established TCP port forward proceeds. This LRU logic allows some
  1163. // dangling resource consumption (e.g., TIME_WAIT) while providing a better
  1164. // experience for clients.
  1165. //
  1166. // e. Relay data.
  1167. //
  1168. // f. Call closedPortForward() which decrements concurrentPortForwardCount and
  1169. // records bytes transferred.
  1170. // Start the TCP port forward manager
  1171. type newTCPPortForward struct {
  1172. enqueueTime monotime.Time
  1173. hostToConnect string
  1174. portToConnect int
  1175. newChannel ssh.NewChannel
  1176. }
  1177. // The queue size is set to the traffic rules (MaxTCPPortForwardCount +
  1178. // MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
  1179. // limits per client; when that value is not set, a default is used.
  1180. // A limitation: this queue size is set once and doesn't change, for this client,
  1181. // when traffic rules are reloaded.
  1182. queueSize := sshClient.getTCPPortForwardQueueSize()
  1183. if queueSize == 0 {
  1184. queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
  1185. }
  1186. newTCPPortForwards := make(chan *newTCPPortForward, queueSize)
  1187. waitGroup.Add(1)
  1188. go func() {
  1189. defer waitGroup.Done()
  1190. for newPortForward := range newTCPPortForwards {
  1191. remainingDialTimeout :=
  1192. time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
  1193. monotime.Since(newPortForward.enqueueTime)
  1194. if remainingDialTimeout <= 0 {
  1195. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1196. sshClient.rejectNewChannel(
  1197. newPortForward.newChannel, "TCP port forward timed out in queue")
  1198. continue
  1199. }
  1200. // Reserve a TCP dialing slot.
  1201. //
  1202. // TOCTOU note: important to increment counts _before_ checking limits; otherwise,
  1203. // the client could potentially consume excess resources by initiating many port
  1204. // forwards concurrently.
  1205. sshClient.dialingTCPPortForward()
  1206. // When max dials are in progress, wait up to remainingDialTimeout for dialing
  1207. // to become available. This blocks all dequeing.
  1208. if sshClient.isTCPDialingPortForwardLimitExceeded() {
  1209. blockStartTime := monotime.Now()
  1210. ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
  1211. sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
  1212. <-ctx.Done()
  1213. sshClient.setTCPPortForwardDialingAvailableSignal(nil)
  1214. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  1215. remainingDialTimeout -= monotime.Since(blockStartTime)
  1216. }
  1217. if remainingDialTimeout <= 0 {
  1218. // Release the dialing slot here since handleTCPChannel() won't be called.
  1219. sshClient.abortedTCPPortForward()
  1220. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1221. sshClient.rejectNewChannel(
  1222. newPortForward.newChannel, "TCP port forward timed out before dialing")
  1223. continue
  1224. }
  1225. // Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
  1226. // handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
  1227. // will deal with remainingDialTimeout <= 0.
  1228. waitGroup.Add(1)
  1229. go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
  1230. defer waitGroup.Done()
  1231. sshClient.handleTCPChannel(
  1232. remainingDialTimeout,
  1233. newPortForward.hostToConnect,
  1234. newPortForward.portToConnect,
  1235. newPortForward.newChannel)
  1236. }(remainingDialTimeout, newPortForward)
  1237. }
  1238. }()
  1239. // Handle new channel (port forward) requests from the client.
  1240. //
  1241. // packet tunnel channels are handled by the packet tunnel server
  1242. // component. Each client may have at most one packet tunnel channel.
  1243. //
  1244. // udpgw client connections are dispatched immediately (clients use this for
  1245. // DNS, so it's essential to not block; and only one udpgw connection is
  1246. // retained at a time).
  1247. //
  1248. // All other TCP port forwards are dispatched via the TCP port forward
  1249. // manager queue.
  1250. for newChannel := range channels {
  1251. if newChannel.ChannelType() == protocol.PACKET_TUNNEL_CHANNEL_TYPE {
  1252. if !sshClient.sshServer.support.Config.RunPacketTunnel {
  1253. sshClient.rejectNewChannel(newChannel, "unsupported packet tunnel channel type")
  1254. continue
  1255. }
  1256. // Accept this channel immediately. This channel will replace any
  1257. // previously existing packet tunnel channel for this client.
  1258. packetTunnelChannel, requests, err := newChannel.Accept()
  1259. if err != nil {
  1260. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  1261. continue
  1262. }
  1263. go ssh.DiscardRequests(requests)
  1264. sshClient.setPacketTunnelChannel(packetTunnelChannel)
  1265. // PacketTunnelServer will run the client's packet tunnel. If necessary, ClientConnected
  1266. // will stop packet tunnel workers for any previous packet tunnel channel.
  1267. checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  1268. return sshClient.isPortForwardPermitted(portForwardTypeTCP, false, upstreamIPAddress, port)
  1269. }
  1270. checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  1271. return sshClient.isPortForwardPermitted(portForwardTypeUDP, false, upstreamIPAddress, port)
  1272. }
  1273. flowActivityUpdaterMaker := func(
  1274. upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
  1275. var updaters []tun.FlowActivityUpdater
  1276. oslUpdater := sshClient.newClientSeedPortForward(upstreamIPAddress)
  1277. if oslUpdater != nil {
  1278. updaters = append(updaters, oslUpdater)
  1279. }
  1280. return updaters
  1281. }
  1282. metricUpdater := func(
  1283. TCPApplicationBytesUp, TCPApplicationBytesDown,
  1284. UDPApplicationBytesUp, UDPApplicationBytesDown int64) {
  1285. sshClient.Lock()
  1286. sshClient.tcpTrafficState.bytesUp += TCPApplicationBytesUp
  1287. sshClient.tcpTrafficState.bytesDown += TCPApplicationBytesDown
  1288. sshClient.udpTrafficState.bytesUp += UDPApplicationBytesUp
  1289. sshClient.udpTrafficState.bytesDown += UDPApplicationBytesDown
  1290. sshClient.Unlock()
  1291. }
  1292. err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
  1293. sshClient.sessionID,
  1294. packetTunnelChannel,
  1295. checkAllowedTCPPortFunc,
  1296. checkAllowedUDPPortFunc,
  1297. flowActivityUpdaterMaker,
  1298. metricUpdater)
  1299. if err != nil {
  1300. log.WithContextFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
  1301. sshClient.setPacketTunnelChannel(nil)
  1302. }
  1303. continue
  1304. }
  1305. if newChannel.ChannelType() != "direct-tcpip" {
  1306. sshClient.rejectNewChannel(newChannel, "unknown or unsupported channel type")
  1307. continue
  1308. }
  1309. // http://tools.ietf.org/html/rfc4254#section-7.2
  1310. var directTcpipExtraData struct {
  1311. HostToConnect string
  1312. PortToConnect uint32
  1313. OriginatorIPAddress string
  1314. OriginatorPort uint32
  1315. }
  1316. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  1317. if err != nil {
  1318. sshClient.rejectNewChannel(newChannel, "invalid extra data")
  1319. continue
  1320. }
  1321. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  1322. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  1323. isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
  1324. sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
  1325. net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
  1326. if isUDPChannel {
  1327. // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
  1328. // own worker goroutine.
  1329. waitGroup.Add(1)
  1330. go func(channel ssh.NewChannel) {
  1331. defer waitGroup.Done()
  1332. sshClient.handleUDPChannel(channel)
  1333. }(newChannel)
  1334. } else {
  1335. // Dispatch via TCP port forward manager. When the queue is full, the channel
  1336. // is immediately rejected.
  1337. tcpPortForward := &newTCPPortForward{
  1338. enqueueTime: monotime.Now(),
  1339. hostToConnect: directTcpipExtraData.HostToConnect,
  1340. portToConnect: int(directTcpipExtraData.PortToConnect),
  1341. newChannel: newChannel,
  1342. }
  1343. select {
  1344. case newTCPPortForwards <- tcpPortForward:
  1345. default:
  1346. sshClient.updateQualityMetricsWithRejectedDialingLimit()
  1347. sshClient.rejectNewChannel(newChannel, "TCP port forward dial queue full")
  1348. }
  1349. }
  1350. }
  1351. // The channel loop is interrupted by a client
  1352. // disconnect or by calling sshClient.stop().
  1353. // Stop the TCP port forward manager
  1354. close(newTCPPortForwards)
  1355. // Stop all other worker goroutines
  1356. sshClient.stopRunning()
  1357. if sshClient.sshServer.support.Config.RunPacketTunnel {
  1358. // PacketTunnelServer.ClientDisconnected stops packet tunnel workers.
  1359. sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
  1360. sshClient.sessionID)
  1361. }
  1362. waitGroup.Wait()
  1363. sshClient.cleanupAuthorizations()
  1364. }
  1365. func (sshClient *sshClient) cleanupAuthorizations() {
  1366. sshClient.Lock()
  1367. if sshClient.releaseAuthorizations != nil {
  1368. sshClient.releaseAuthorizations()
  1369. }
  1370. if sshClient.stopTimer != nil {
  1371. sshClient.stopTimer.Stop()
  1372. }
  1373. sshClient.Unlock()
  1374. }
  1375. // setPacketTunnelChannel sets the single packet tunnel channel
  1376. // for this sshClient. Any existing packet tunnel channel is
  1377. // closed.
  1378. func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
  1379. sshClient.Lock()
  1380. if sshClient.packetTunnelChannel != nil {
  1381. sshClient.packetTunnelChannel.Close()
  1382. }
  1383. sshClient.packetTunnelChannel = channel
  1384. sshClient.Unlock()
  1385. }
  1386. // setUDPChannel sets the single UDP channel for this sshClient.
  1387. // Each sshClient may have only one concurrent UDP channel. Each
  1388. // UDP channel multiplexes many UDP port forwards via the udpgw
  1389. // protocol. Any existing UDP channel is closed.
  1390. func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
  1391. sshClient.Lock()
  1392. if sshClient.udpChannel != nil {
  1393. sshClient.udpChannel.Close()
  1394. }
  1395. sshClient.udpChannel = channel
  1396. sshClient.Unlock()
  1397. }
  1398. func (sshClient *sshClient) logTunnel(additionalMetrics LogFields) {
  1399. // Note: reporting duration based on last confirmed data transfer, which
  1400. // is reads for sshClient.activityConn.GetActiveDuration(), and not
  1401. // connection closing is important for protocols such as meek. For
  1402. // meek, the connection remains open until the HTTP session expires,
  1403. // which may be some time after the tunnel has closed. (The meek
  1404. // protocol has no allowance for signalling payload EOF, and even if
  1405. // it did the client may not have the opportunity to send a final
  1406. // request with an EOF flag set.)
  1407. sshClient.Lock()
  1408. logFields := getRequestLogFields(
  1409. "server_tunnel",
  1410. sshClient.geoIPData,
  1411. sshClient.handshakeState.authorizedAccessTypes,
  1412. sshClient.handshakeState.apiParams,
  1413. baseRequestParams)
  1414. logFields["session_id"] = sshClient.sessionID
  1415. logFields["handshake_completed"] = sshClient.handshakeState.completed
  1416. logFields["start_time"] = sshClient.activityConn.GetStartTime()
  1417. logFields["duration"] = sshClient.activityConn.GetActiveDuration() / time.Millisecond
  1418. logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
  1419. logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
  1420. logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
  1421. logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
  1422. logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
  1423. logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
  1424. logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
  1425. // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
  1426. logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
  1427. logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
  1428. // Pre-calculate a total-tunneled-bytes field. This total is used
  1429. // extensively in analytics and is more performant when pre-calculated.
  1430. logFields["bytes"] = sshClient.tcpTrafficState.bytesUp +
  1431. sshClient.tcpTrafficState.bytesDown +
  1432. sshClient.udpTrafficState.bytesUp +
  1433. sshClient.udpTrafficState.bytesDown
  1434. // Merge in additional metrics from the optional metrics source
  1435. if additionalMetrics != nil {
  1436. for name, value := range additionalMetrics {
  1437. // Don't overwrite any basic fields
  1438. if logFields[name] == nil {
  1439. logFields[name] = value
  1440. }
  1441. }
  1442. }
  1443. sshClient.Unlock()
  1444. // Note: unlock before use is only safe as long as referenced sshClient data,
  1445. // such as slices in handshakeState, is read-only after initially set.
  1446. log.LogRawFieldsWithTimestamp(logFields)
  1447. }
  1448. func (sshClient *sshClient) runOSLSender() {
  1449. for {
  1450. // Await a signal that there are SLOKs to send
  1451. // TODO: use reflect.SelectCase, and optionally await timer here?
  1452. select {
  1453. case <-sshClient.signalIssueSLOKs:
  1454. case <-sshClient.runCtx.Done():
  1455. return
  1456. }
  1457. retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
  1458. for {
  1459. err := sshClient.sendOSLRequest()
  1460. if err == nil {
  1461. break
  1462. }
  1463. log.WithContextFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
  1464. // If the request failed, retry after a delay (with exponential backoff)
  1465. // or when signaled that there are additional SLOKs to send
  1466. retryTimer := time.NewTimer(retryDelay)
  1467. select {
  1468. case <-retryTimer.C:
  1469. case <-sshClient.signalIssueSLOKs:
  1470. case <-sshClient.runCtx.Done():
  1471. retryTimer.Stop()
  1472. return
  1473. }
  1474. retryTimer.Stop()
  1475. retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
  1476. }
  1477. }
  1478. }
  1479. // sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
  1480. // generate a payload, and send an OSL request to the client when
  1481. // there are new SLOKs in the payload.
  1482. func (sshClient *sshClient) sendOSLRequest() error {
  1483. seedPayload := sshClient.getOSLSeedPayload()
  1484. // Don't send when no SLOKs. This will happen when signalIssueSLOKs
  1485. // is received but no new SLOKs are issued.
  1486. if len(seedPayload.SLOKs) == 0 {
  1487. return nil
  1488. }
  1489. oslRequest := protocol.OSLRequest{
  1490. SeedPayload: seedPayload,
  1491. }
  1492. requestPayload, err := json.Marshal(oslRequest)
  1493. if err != nil {
  1494. return common.ContextError(err)
  1495. }
  1496. ok, _, err := sshClient.sshConn.SendRequest(
  1497. protocol.PSIPHON_API_OSL_REQUEST_NAME,
  1498. true,
  1499. requestPayload)
  1500. if err != nil {
  1501. return common.ContextError(err)
  1502. }
  1503. if !ok {
  1504. return common.ContextError(errors.New("client rejected request"))
  1505. }
  1506. sshClient.clearOSLSeedPayload()
  1507. return nil
  1508. }
  1509. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessage string) {
  1510. // We always return the reject reason "Prohibited":
  1511. // - Traffic rules and connection limits may prohibit the connection.
  1512. // - External firewall rules may prohibit the connection, and this is not currently
  1513. // distinguishable from other failure modes.
  1514. // - We limit the failure information revealed to the client.
  1515. reason := ssh.Prohibited
  1516. // Note: Debug level, as logMessage may contain user traffic destination address information
  1517. log.WithContextFields(
  1518. LogFields{
  1519. "channelType": newChannel.ChannelType(),
  1520. "logMessage": logMessage,
  1521. "rejectReason": reason.String(),
  1522. }).Debug("reject new channel")
  1523. // Note: logMessage is internal, for logging only; just the reject reason is sent to the client.
  1524. newChannel.Reject(reason, reason.String())
  1525. }
  1526. // setHandshakeState records that a client has completed a handshake API request.
  1527. // Some parameters from the handshake request may be used in future traffic rule
  1528. // selection. Port forwards are disallowed until a handshake is complete. The
  1529. // handshake parameters are included in the session summary log recorded in
  1530. // sshClient.stop().
  1531. func (sshClient *sshClient) setHandshakeState(
  1532. state handshakeState,
  1533. authorizations []string) ([]string, []string, error) {
  1534. sshClient.Lock()
  1535. completed := sshClient.handshakeState.completed
  1536. if !completed {
  1537. sshClient.handshakeState = state
  1538. }
  1539. sshClient.Unlock()
  1540. // Client must only perform one handshake
  1541. if completed {
  1542. return nil, nil, common.ContextError(errors.New("handshake already completed"))
  1543. }
  1544. // Verify the authorizations submitted by the client. Verified, active
  1545. // (non-expired) access types will be available for traffic rules
  1546. // filtering.
  1547. //
  1548. // When an authorization is active but expires while the client is
  1549. // connected, the client is disconnected to ensure the access is reset.
  1550. // This is implemented by setting a timer to perform the disconnect at the
  1551. // expiry time of the soonest expiring authorization.
  1552. //
  1553. // sshServer.authorizationSessionIDs tracks the unique mapping of active
  1554. // authorization IDs to client session IDs and is used to detect and
  1555. // prevent multiple malicious clients from reusing a single authorization
  1556. // (within the scope of this server).
  1557. // authorizationIDs and authorizedAccessTypes are returned to the client
  1558. // and logged, respectively; initialize to empty lists so the
  1559. // protocol/logs don't need to handle 'null' values.
  1560. authorizationIDs := make([]string, 0)
  1561. authorizedAccessTypes := make([]string, 0)
  1562. var stopTime time.Time
  1563. for i, authorization := range authorizations {
  1564. // This sanity check mitigates malicious clients causing excess CPU use.
  1565. if i >= MAX_AUTHORIZATIONS {
  1566. log.WithContext().Warning("too many authorizations")
  1567. break
  1568. }
  1569. verifiedAuthorization, err := accesscontrol.VerifyAuthorization(
  1570. &sshClient.sshServer.support.Config.AccessControlVerificationKeyRing,
  1571. authorization)
  1572. if err != nil {
  1573. log.WithContextFields(
  1574. LogFields{"error": err}).Warning("verify authorization failed")
  1575. continue
  1576. }
  1577. authorizationID := base64.StdEncoding.EncodeToString(verifiedAuthorization.ID)
  1578. if common.Contains(authorizedAccessTypes, verifiedAuthorization.AccessType) {
  1579. log.WithContextFields(
  1580. LogFields{"accessType": verifiedAuthorization.AccessType}).Warning("duplicate authorization access type")
  1581. continue
  1582. }
  1583. authorizationIDs = append(authorizationIDs, authorizationID)
  1584. authorizedAccessTypes = append(authorizedAccessTypes, verifiedAuthorization.AccessType)
  1585. if stopTime.IsZero() || stopTime.After(verifiedAuthorization.Expires) {
  1586. stopTime = verifiedAuthorization.Expires
  1587. }
  1588. }
  1589. // Associate all verified authorizationIDs with this client's session ID.
  1590. // Handle cases where previous associations exist:
  1591. //
  1592. // - Multiple malicious clients reusing a single authorization. In this
  1593. // case, authorizations are revoked from the previous client.
  1594. //
  1595. // - The client reconnected with a new session ID due to user toggling.
  1596. // This case is expected due to server affinity. This cannot be
  1597. // distinguished from the previous case and the same action is taken;
  1598. // this will have no impact on a legitimate client as the previous
  1599. // session is dangling.
  1600. //
  1601. // - The client automatically reconnected with the same session ID. This
  1602. // case is not expected as sshServer.registerEstablishedClient
  1603. // synchronously calls sshClient.releaseAuthorizations; as a safe guard,
  1604. // this case is distinguished and no revocation action is taken.
  1605. sshClient.sshServer.authorizationSessionIDsMutex.Lock()
  1606. for _, authorizationID := range authorizationIDs {
  1607. sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
  1608. if ok && sessionID != sshClient.sessionID {
  1609. log.WithContextFields(
  1610. LogFields{"authorizationID": authorizationID}).Warning("duplicate active authorization")
  1611. // Invoke asynchronously to avoid deadlocks.
  1612. // TODO: invoke only once for each distinct sessionID?
  1613. go sshClient.sshServer.revokeClientAuthorizations(sessionID)
  1614. }
  1615. sshClient.sshServer.authorizationSessionIDs[authorizationID] = sshClient.sessionID
  1616. }
  1617. sshClient.sshServer.authorizationSessionIDsMutex.Unlock()
  1618. if len(authorizationIDs) > 0 {
  1619. sshClient.Lock()
  1620. // Make the authorizedAccessTypes available for traffic rules filtering.
  1621. sshClient.handshakeState.authorizedAccessTypes = authorizedAccessTypes
  1622. // On exit, sshClient.runTunnel will call releaseAuthorizations, which
  1623. // will release the authorization IDs so the client can reconnect and
  1624. // present the same authorizations again. sshClient.runTunnel will
  1625. // also cancel the stopTimer in case it has not yet fired.
  1626. // Note: termination of the stopTimer goroutine is not synchronized.
  1627. sshClient.releaseAuthorizations = func() {
  1628. sshClient.sshServer.authorizationSessionIDsMutex.Lock()
  1629. for _, authorizationID := range authorizationIDs {
  1630. sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
  1631. if ok && sessionID == sshClient.sessionID {
  1632. delete(sshClient.sshServer.authorizationSessionIDs, authorizationID)
  1633. }
  1634. }
  1635. sshClient.sshServer.authorizationSessionIDsMutex.Unlock()
  1636. }
  1637. sshClient.stopTimer = time.AfterFunc(
  1638. stopTime.Sub(time.Now()),
  1639. func() {
  1640. sshClient.stop()
  1641. })
  1642. sshClient.Unlock()
  1643. }
  1644. sshClient.setTrafficRules()
  1645. sshClient.setOSLConfig()
  1646. return authorizationIDs, authorizedAccessTypes, nil
  1647. }
  1648. // getHandshaked returns whether the client has completed a handshake API
  1649. // request and whether the traffic rules that were selected after the
  1650. // handshake immediately exhaust the client.
  1651. //
  1652. // When the client is immediately exhausted it will be closed; but this
  1653. // takes effect asynchronously. The "exhausted" return value is used to
  1654. // prevent API requests by clients that will close.
  1655. func (sshClient *sshClient) getHandshaked() (bool, bool) {
  1656. sshClient.Lock()
  1657. defer sshClient.Unlock()
  1658. completed := sshClient.handshakeState.completed
  1659. exhausted := false
  1660. // Notes:
  1661. // - "Immediately exhausted" is when CloseAfterExhausted is set and
  1662. // either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
  1663. // 0, so no bytes would be read or written. This check does not
  1664. // examine whether 0 bytes _remain_ in the ThrottledConn.
  1665. // - This check is made against the current traffic rules, which
  1666. // could have changed in a hot reload since the handshake.
  1667. if completed &&
  1668. *sshClient.trafficRules.RateLimits.CloseAfterExhausted == true &&
  1669. (*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
  1670. *sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
  1671. exhausted = true
  1672. }
  1673. return completed, exhausted
  1674. }
  1675. func (sshClient *sshClient) expectDomainBytes() bool {
  1676. sshClient.Lock()
  1677. defer sshClient.Unlock()
  1678. return sshClient.handshakeState.expectDomainBytes
  1679. }
  1680. // setTrafficRules resets the client's traffic rules based on the latest server config
  1681. // and client properties. As sshClient.trafficRules may be reset by a concurrent
  1682. // goroutine, trafficRules must only be accessed within the sshClient mutex.
  1683. func (sshClient *sshClient) setTrafficRules() {
  1684. sshClient.Lock()
  1685. defer sshClient.Unlock()
  1686. sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
  1687. sshClient.isFirstTunnelInSession,
  1688. sshClient.tunnelProtocol,
  1689. sshClient.geoIPData,
  1690. sshClient.handshakeState)
  1691. if sshClient.throttledConn != nil {
  1692. // Any existing throttling state is reset.
  1693. sshClient.throttledConn.SetLimits(
  1694. sshClient.trafficRules.RateLimits.CommonRateLimits())
  1695. }
  1696. }
  1697. // setOSLConfig resets the client's OSL seed state based on the latest OSL config
  1698. // As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
  1699. // oslClientSeedState must only be accessed within the sshClient mutex.
  1700. func (sshClient *sshClient) setOSLConfig() {
  1701. sshClient.Lock()
  1702. defer sshClient.Unlock()
  1703. propagationChannelID, err := getStringRequestParam(
  1704. sshClient.handshakeState.apiParams, "propagation_channel_id")
  1705. if err != nil {
  1706. // This should not fail as long as client has sent valid handshake
  1707. return
  1708. }
  1709. // Use a cached seed state if one is found for the client's
  1710. // session ID. This enables resuming progress made in a previous
  1711. // tunnel.
  1712. // Note: go-cache is already concurency safe; the additional mutex
  1713. // is necessary to guarantee that Get/Delete is atomic; although in
  1714. // practice no two concurrent clients should ever supply the same
  1715. // session ID.
  1716. sshClient.sshServer.oslSessionCacheMutex.Lock()
  1717. oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
  1718. if found {
  1719. sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
  1720. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1721. sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
  1722. sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
  1723. return
  1724. }
  1725. sshClient.sshServer.oslSessionCacheMutex.Unlock()
  1726. // Two limitations when setOSLConfig() is invoked due to an
  1727. // OSL config hot reload:
  1728. //
  1729. // 1. any partial progress towards SLOKs is lost.
  1730. //
  1731. // 2. all existing osl.ClientSeedPortForwards for existing
  1732. // port forwards will not send progress to the new client
  1733. // seed state.
  1734. sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
  1735. sshClient.geoIPData.Country,
  1736. propagationChannelID,
  1737. sshClient.signalIssueSLOKs)
  1738. }
  1739. // newClientSeedPortForward will return nil when no seeding is
  1740. // associated with the specified ipAddress.
  1741. func (sshClient *sshClient) newClientSeedPortForward(ipAddress net.IP) *osl.ClientSeedPortForward {
  1742. sshClient.Lock()
  1743. defer sshClient.Unlock()
  1744. // Will not be initialized before handshake.
  1745. if sshClient.oslClientSeedState == nil {
  1746. return nil
  1747. }
  1748. return sshClient.oslClientSeedState.NewClientSeedPortForward(ipAddress)
  1749. }
  1750. // getOSLSeedPayload returns a payload containing all seeded SLOKs for
  1751. // this client's session.
  1752. func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
  1753. sshClient.Lock()
  1754. defer sshClient.Unlock()
  1755. // Will not be initialized before handshake.
  1756. if sshClient.oslClientSeedState == nil {
  1757. return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
  1758. }
  1759. return sshClient.oslClientSeedState.GetSeedPayload()
  1760. }
  1761. func (sshClient *sshClient) clearOSLSeedPayload() {
  1762. sshClient.Lock()
  1763. defer sshClient.Unlock()
  1764. sshClient.oslClientSeedState.ClearSeedPayload()
  1765. }
  1766. func (sshClient *sshClient) rateLimits() common.RateLimits {
  1767. sshClient.Lock()
  1768. defer sshClient.Unlock()
  1769. return sshClient.trafficRules.RateLimits.CommonRateLimits()
  1770. }
  1771. func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
  1772. sshClient.Lock()
  1773. defer sshClient.Unlock()
  1774. return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
  1775. }
  1776. func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
  1777. sshClient.Lock()
  1778. defer sshClient.Unlock()
  1779. return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
  1780. }
  1781. func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
  1782. sshClient.Lock()
  1783. defer sshClient.Unlock()
  1784. sshClient.tcpPortForwardDialingAvailableSignal = signal
  1785. }
  1786. const (
  1787. portForwardTypeTCP = iota
  1788. portForwardTypeUDP
  1789. )
  1790. func (sshClient *sshClient) isPortForwardPermitted(
  1791. portForwardType int,
  1792. isTransparentDNSForwarding bool,
  1793. remoteIP net.IP,
  1794. port int) bool {
  1795. sshClient.Lock()
  1796. defer sshClient.Unlock()
  1797. if !sshClient.handshakeState.completed {
  1798. return false
  1799. }
  1800. // Disallow connection to loopback. This is a failsafe. The server
  1801. // should be run on a host with correctly configured firewall rules.
  1802. // An exception is made in the case of tranparent DNS forwarding,
  1803. // where the remoteIP has been rewritten.
  1804. if !isTransparentDNSForwarding && remoteIP.IsLoopback() {
  1805. return false
  1806. }
  1807. var allowPorts []int
  1808. if portForwardType == portForwardTypeTCP {
  1809. allowPorts = sshClient.trafficRules.AllowTCPPorts
  1810. } else {
  1811. allowPorts = sshClient.trafficRules.AllowUDPPorts
  1812. }
  1813. if len(allowPorts) == 0 {
  1814. return true
  1815. }
  1816. // TODO: faster lookup?
  1817. if len(allowPorts) > 0 {
  1818. for _, allowPort := range allowPorts {
  1819. if port == allowPort {
  1820. return true
  1821. }
  1822. }
  1823. }
  1824. for _, subnet := range sshClient.trafficRules.AllowSubnets {
  1825. // Note: ignoring error as config has been validated
  1826. _, network, _ := net.ParseCIDR(subnet)
  1827. if network.Contains(remoteIP) {
  1828. return true
  1829. }
  1830. }
  1831. log.WithContextFields(
  1832. LogFields{
  1833. "type": portForwardType,
  1834. "port": port,
  1835. }).Debug("port forward denied by traffic rules")
  1836. return false
  1837. }
  1838. func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
  1839. sshClient.Lock()
  1840. defer sshClient.Unlock()
  1841. state := &sshClient.tcpTrafficState
  1842. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1843. if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
  1844. return true
  1845. }
  1846. return false
  1847. }
  1848. func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
  1849. sshClient.Lock()
  1850. defer sshClient.Unlock()
  1851. return *sshClient.trafficRules.MaxTCPPortForwardCount +
  1852. *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1853. }
  1854. func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
  1855. sshClient.Lock()
  1856. defer sshClient.Unlock()
  1857. return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
  1858. }
  1859. func (sshClient *sshClient) dialingTCPPortForward() {
  1860. sshClient.Lock()
  1861. defer sshClient.Unlock()
  1862. state := &sshClient.tcpTrafficState
  1863. state.concurrentDialingPortForwardCount += 1
  1864. if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
  1865. state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
  1866. }
  1867. }
  1868. func (sshClient *sshClient) abortedTCPPortForward() {
  1869. sshClient.Lock()
  1870. defer sshClient.Unlock()
  1871. sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
  1872. }
  1873. func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
  1874. sshClient.Lock()
  1875. defer sshClient.Unlock()
  1876. // Check if at port forward limit. The subsequent counter
  1877. // changes must be atomic with the limit check to ensure
  1878. // the counter never exceeds the limit in the case of
  1879. // concurrent allocations.
  1880. var max int
  1881. var state *trafficState
  1882. if portForwardType == portForwardTypeTCP {
  1883. max = *sshClient.trafficRules.MaxTCPPortForwardCount
  1884. state = &sshClient.tcpTrafficState
  1885. } else {
  1886. max = *sshClient.trafficRules.MaxUDPPortForwardCount
  1887. state = &sshClient.udpTrafficState
  1888. }
  1889. if max > 0 && state.concurrentPortForwardCount >= int64(max) {
  1890. return false
  1891. }
  1892. // Update port forward counters.
  1893. if portForwardType == portForwardTypeTCP {
  1894. // Assumes TCP port forwards called dialingTCPPortForward
  1895. state.concurrentDialingPortForwardCount -= 1
  1896. if sshClient.tcpPortForwardDialingAvailableSignal != nil {
  1897. max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  1898. if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
  1899. sshClient.tcpPortForwardDialingAvailableSignal()
  1900. }
  1901. }
  1902. }
  1903. state.concurrentPortForwardCount += 1
  1904. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  1905. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  1906. }
  1907. state.totalPortForwardCount += 1
  1908. return true
  1909. }
  1910. // establishedPortForward increments the concurrent port
  1911. // forward counter. closedPortForward decrements it, so it
  1912. // must always be called for each establishedPortForward
  1913. // call.
  1914. //
  1915. // When at the limit of established port forwards, the LRU
  1916. // existing port forward is closed to make way for the newly
  1917. // established one. There can be a minor delay as, in addition
  1918. // to calling Close() on the port forward net.Conn,
  1919. // establishedPortForward waits for the LRU's closedPortForward()
  1920. // call which will decrement the concurrent counter. This
  1921. // ensures all resources associated with the LRU (socket,
  1922. // goroutine) are released or will very soon be released before
  1923. // proceeding.
  1924. func (sshClient *sshClient) establishedPortForward(
  1925. portForwardType int, portForwardLRU *common.LRUConns) {
  1926. // Do not lock sshClient here.
  1927. var state *trafficState
  1928. if portForwardType == portForwardTypeTCP {
  1929. state = &sshClient.tcpTrafficState
  1930. } else {
  1931. state = &sshClient.udpTrafficState
  1932. }
  1933. // When the maximum number of port forwards is already
  1934. // established, close the LRU. CloseOldest will call
  1935. // Close on the port forward net.Conn. Both TCP and
  1936. // UDP port forwards have handler goroutines that may
  1937. // be blocked calling Read on the net.Conn. Close will
  1938. // eventually interrupt the Read and cause the handlers
  1939. // to exit, but not immediately. So the following logic
  1940. // waits for a LRU handler to be interrupted and signal
  1941. // availability.
  1942. //
  1943. // Notes:
  1944. //
  1945. // - the port forward limit can change via a traffic
  1946. // rules hot reload; the condition variable handles
  1947. // this case whereas a channel-based semaphore would
  1948. // not.
  1949. //
  1950. // - if a number of goroutines exceeding the total limit
  1951. // arrive here all concurrently, some CloseOldest() calls
  1952. // will have no effect as there can be less existing port
  1953. // forwards than new ones. In this case, the new port
  1954. // forward will be delayed. This is highly unlikely in
  1955. // practise since UDP calls to establishedPortForward are
  1956. // serialized and TCP calls are limited by the dial
  1957. // queue/count.
  1958. if !sshClient.allocatePortForward(portForwardType) {
  1959. portForwardLRU.CloseOldest()
  1960. log.WithContext().Debug("closed LRU port forward")
  1961. state.availablePortForwardCond.L.Lock()
  1962. for !sshClient.allocatePortForward(portForwardType) {
  1963. state.availablePortForwardCond.Wait()
  1964. }
  1965. state.availablePortForwardCond.L.Unlock()
  1966. }
  1967. }
  1968. func (sshClient *sshClient) closedPortForward(
  1969. portForwardType int, bytesUp, bytesDown int64) {
  1970. sshClient.Lock()
  1971. var state *trafficState
  1972. if portForwardType == portForwardTypeTCP {
  1973. state = &sshClient.tcpTrafficState
  1974. } else {
  1975. state = &sshClient.udpTrafficState
  1976. }
  1977. state.concurrentPortForwardCount -= 1
  1978. state.bytesUp += bytesUp
  1979. state.bytesDown += bytesDown
  1980. sshClient.Unlock()
  1981. // Signal any goroutine waiting in establishedPortForward
  1982. // that an established port forward slot is available.
  1983. state.availablePortForwardCond.Signal()
  1984. }
  1985. func (sshClient *sshClient) updateQualityMetricsWithDialResult(
  1986. tcpPortForwardDialSuccess bool, dialDuration time.Duration) {
  1987. sshClient.Lock()
  1988. defer sshClient.Unlock()
  1989. if tcpPortForwardDialSuccess {
  1990. sshClient.qualityMetrics.tcpPortForwardDialedCount += 1
  1991. sshClient.qualityMetrics.tcpPortForwardDialedDuration += dialDuration
  1992. } else {
  1993. sshClient.qualityMetrics.tcpPortForwardFailedCount += 1
  1994. sshClient.qualityMetrics.tcpPortForwardFailedDuration += dialDuration
  1995. }
  1996. }
  1997. func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
  1998. sshClient.Lock()
  1999. defer sshClient.Unlock()
  2000. sshClient.qualityMetrics.tcpPortForwardRejectedDialingLimitCount += 1
  2001. }
  2002. func (sshClient *sshClient) handleTCPChannel(
  2003. remainingDialTimeout time.Duration,
  2004. hostToConnect string,
  2005. portToConnect int,
  2006. newChannel ssh.NewChannel) {
  2007. // Assumptions:
  2008. // - sshClient.dialingTCPPortForward() has been called
  2009. // - remainingDialTimeout > 0
  2010. established := false
  2011. defer func() {
  2012. if !established {
  2013. sshClient.abortedTCPPortForward()
  2014. }
  2015. }()
  2016. // Transparently redirect web API request connections.
  2017. isWebServerPortForward := false
  2018. config := sshClient.sshServer.support.Config
  2019. if config.WebServerPortForwardAddress != "" {
  2020. destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
  2021. if destination == config.WebServerPortForwardAddress {
  2022. isWebServerPortForward = true
  2023. if config.WebServerPortForwardRedirectAddress != "" {
  2024. // Note: redirect format is validated when config is loaded
  2025. host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
  2026. port, _ := strconv.Atoi(portStr)
  2027. hostToConnect = host
  2028. portToConnect = port
  2029. }
  2030. }
  2031. }
  2032. // Dial the remote address.
  2033. //
  2034. // Hostname resolution is performed explicitly, as a separate step, as the target IP
  2035. // address is used for traffic rules (AllowSubnets) and OSL seed progress.
  2036. //
  2037. // Contexts are used for cancellation (via sshClient.runCtx, which is cancelled
  2038. // when the client is stopping) and timeouts.
  2039. dialStartTime := monotime.Now()
  2040. log.WithContextFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")
  2041. ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
  2042. IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
  2043. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  2044. // TODO: shuffle list to try other IPs?
  2045. // TODO: IPv6 support
  2046. var IP net.IP
  2047. for _, ip := range IPs {
  2048. if ip.IP.To4() != nil {
  2049. IP = ip.IP
  2050. break
  2051. }
  2052. }
  2053. if err == nil && IP == nil {
  2054. err = errors.New("no IP address")
  2055. }
  2056. resolveElapsedTime := monotime.Since(dialStartTime)
  2057. if err != nil {
  2058. // Record a port forward failure
  2059. sshClient.updateQualityMetricsWithDialResult(true, resolveElapsedTime)
  2060. sshClient.rejectNewChannel(newChannel, fmt.Sprintf("LookupIP failed: %s", err))
  2061. return
  2062. }
  2063. remainingDialTimeout -= resolveElapsedTime
  2064. if remainingDialTimeout <= 0 {
  2065. sshClient.rejectNewChannel(newChannel, "TCP port forward timed out resolving")
  2066. return
  2067. }
  2068. // Enforce traffic rules, using the resolved IP address.
  2069. if !isWebServerPortForward &&
  2070. !sshClient.isPortForwardPermitted(
  2071. portForwardTypeTCP,
  2072. false,
  2073. IP,
  2074. portToConnect) {
  2075. // Note: not recording a port forward failure in this case
  2076. sshClient.rejectNewChannel(newChannel, "port forward not permitted")
  2077. return
  2078. }
  2079. // TCP dial.
  2080. remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))
  2081. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  2082. ctx, cancelCtx = context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
  2083. fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
  2084. cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
  2085. // Record port forward success or failure
  2086. sshClient.updateQualityMetricsWithDialResult(err == nil, monotime.Since(dialStartTime))
  2087. if err != nil {
  2088. // Monitor for low resource error conditions
  2089. sshClient.sshServer.monitorPortForwardDialError(err)
  2090. sshClient.rejectNewChannel(newChannel, fmt.Sprintf("DialTimeout failed: %s", err))
  2091. return
  2092. }
  2093. // The upstream TCP port forward connection has been established. Schedule
  2094. // some cleanup and notify the SSH client that the channel is accepted.
  2095. defer fwdConn.Close()
  2096. fwdChannel, requests, err := newChannel.Accept()
  2097. if err != nil {
  2098. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  2099. return
  2100. }
  2101. go ssh.DiscardRequests(requests)
  2102. defer fwdChannel.Close()
  2103. // Release the dialing slot and acquire an established slot.
  2104. //
  2105. // establishedPortForward increments the concurrent TCP port
  2106. // forward counter and closes the LRU existing TCP port forward
  2107. // when already at the limit.
  2108. //
  2109. // Known limitations:
  2110. //
  2111. // - Closed LRU TCP sockets will enter the TIME_WAIT state,
  2112. // continuing to consume some resources.
  2113. sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)
  2114. // "established = true" cancels the deferred abortedTCPPortForward()
  2115. established = true
  2116. // TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
  2117. var bytesUp, bytesDown int64
  2118. defer func() {
  2119. sshClient.closedPortForward(
  2120. portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
  2121. }()
  2122. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
  2123. defer lruEntry.Remove()
  2124. // ActivityMonitoredConn monitors the TCP port forward I/O and updates
  2125. // its LRU status. ActivityMonitoredConn also times out I/O on the port
  2126. // forward if both reads and writes have been idle for the specified
  2127. // duration.
  2128. // Ensure nil interface if newClientSeedPortForward returns nil
  2129. var updater common.ActivityUpdater
  2130. seedUpdater := sshClient.newClientSeedPortForward(IP)
  2131. if seedUpdater != nil {
  2132. updater = seedUpdater
  2133. }
  2134. fwdConn, err = common.NewActivityMonitoredConn(
  2135. fwdConn,
  2136. sshClient.idleTCPPortForwardTimeout(),
  2137. true,
  2138. updater,
  2139. lruEntry)
  2140. if err != nil {
  2141. log.WithContextFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  2142. return
  2143. }
  2144. // Relay channel to forwarded connection.
  2145. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  2146. // TODO: relay errors to fwdChannel.Stderr()?
  2147. relayWaitGroup := new(sync.WaitGroup)
  2148. relayWaitGroup.Add(1)
  2149. go func() {
  2150. defer relayWaitGroup.Done()
  2151. // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
  2152. // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
  2153. // overall memory footprint.
  2154. bytes, err := io.CopyBuffer(
  2155. fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  2156. atomic.AddInt64(&bytesDown, bytes)
  2157. if err != nil && err != io.EOF {
  2158. // Debug since errors such as "connection reset by peer" occur during normal operation
  2159. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  2160. }
  2161. // Interrupt upstream io.Copy when downstream is shutting down.
  2162. // TODO: this is done to quickly cleanup the port forward when
  2163. // fwdConn has a read timeout, but is it clean -- upstream may still
  2164. // be flowing?
  2165. fwdChannel.Close()
  2166. }()
  2167. bytes, err := io.CopyBuffer(
  2168. fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  2169. atomic.AddInt64(&bytesUp, bytes)
  2170. if err != nil && err != io.EOF {
  2171. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  2172. }
  2173. // Shutdown special case: fwdChannel will be closed and return EOF when
  2174. // the SSH connection is closed, but we need to explicitly close fwdConn
  2175. // to interrupt the downstream io.Copy, which may be blocked on a
  2176. // fwdConn.Read().
  2177. fwdConn.Close()
  2178. relayWaitGroup.Wait()
  2179. log.WithContextFields(
  2180. LogFields{
  2181. "remoteAddr": remoteAddr,
  2182. "bytesUp": atomic.LoadInt64(&bytesUp),
  2183. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  2184. }