tunnelServer.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "sync/atomic"
  29. "time"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  31. "golang.org/x/crypto/ssh"
  32. )
  33. // TunnelServer is the main server that accepts Psiphon client
  34. // connections, via various obfuscation protocols, and provides
  35. // port forwarding (TCP and UDP) services to the Psiphon client.
  36. // At its core, TunnelServer is an SSH server. SSH is the base
  37. // protocol that provides port forward multiplexing, and transport
  38. // security. Layered on top of SSH, optionally, is Obfuscated SSH
  39. // and meek protocols, which provide further circumvention
  40. // capabilities.
  41. type TunnelServer struct {
  42. runWaitGroup *sync.WaitGroup
  43. listenerError chan error
  44. shutdownBroadcast <-chan struct{}
  45. sshServer *sshServer
  46. }
  47. // NewTunnelServer initializes a new tunnel server.
  48. func NewTunnelServer(
  49. support *SupportServices,
  50. shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
  51. sshServer, err := newSSHServer(support, shutdownBroadcast)
  52. if err != nil {
  53. return nil, psiphon.ContextError(err)
  54. }
  55. return &TunnelServer{
  56. runWaitGroup: new(sync.WaitGroup),
  57. listenerError: make(chan error),
  58. shutdownBroadcast: shutdownBroadcast,
  59. sshServer: sshServer,
  60. }, nil
  61. }
  62. // GetLoadStats returns load stats for the tunnel server. The stats are
  63. // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
  64. // include current connected client count, total number of current port
  65. // forwards.
  66. func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
  67. return server.sshServer.getLoadStats()
  68. }
  69. // Run runs the tunnel server; this function blocks while running a selection of
  70. // listeners that handle connection using various obfuscation protocols.
  71. //
  72. // Run listens on each designated tunnel port and spawns new goroutines to handle
  73. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  74. // clients is maintained, and when halting all clients are cleanly shutdown.
  75. //
  76. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  77. // authentication, and then looping on client new channel requests. "direct-tcpip"
  78. // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
  79. // config parameter is configured, UDP port forwards over a TCP stream, following
  80. // the udpgw protocol, are handled.
  81. //
  82. // A new goroutine is spawned to handle each port forward for each client. Each port
  83. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  84. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  85. // client shuts down.
  86. func (server *TunnelServer) Run() error {
  87. type sshListener struct {
  88. net.Listener
  89. localAddress string
  90. tunnelProtocol string
  91. }
  92. // TODO: should TunnelServer hold its own support pointer?
  93. support := server.sshServer.support
  94. // First bind all listeners; once all are successful,
  95. // start accepting connections on each.
  96. var listeners []*sshListener
  97. for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
  98. localAddress := fmt.Sprintf(
  99. "%s:%d", support.Config.ServerIPAddress, listenPort)
  100. listener, err := net.Listen("tcp", localAddress)
  101. if err != nil {
  102. for _, existingListener := range listeners {
  103. existingListener.Listener.Close()
  104. }
  105. return psiphon.ContextError(err)
  106. }
  107. log.WithContextFields(
  108. LogFields{
  109. "localAddress": localAddress,
  110. "tunnelProtocol": tunnelProtocol,
  111. }).Info("listening")
  112. listeners = append(
  113. listeners,
  114. &sshListener{
  115. Listener: listener,
  116. localAddress: localAddress,
  117. tunnelProtocol: tunnelProtocol,
  118. })
  119. }
  120. for _, listener := range listeners {
  121. server.runWaitGroup.Add(1)
  122. go func(listener *sshListener) {
  123. defer server.runWaitGroup.Done()
  124. log.WithContextFields(
  125. LogFields{
  126. "localAddress": listener.localAddress,
  127. "tunnelProtocol": listener.tunnelProtocol,
  128. }).Info("running")
  129. server.sshServer.runListener(
  130. listener.Listener,
  131. server.listenerError,
  132. listener.tunnelProtocol)
  133. log.WithContextFields(
  134. LogFields{
  135. "localAddress": listener.localAddress,
  136. "tunnelProtocol": listener.tunnelProtocol,
  137. }).Info("stopped")
  138. }(listener)
  139. }
  140. var err error
  141. select {
  142. case <-server.shutdownBroadcast:
  143. case err = <-server.listenerError:
  144. }
  145. for _, listener := range listeners {
  146. listener.Close()
  147. }
  148. server.sshServer.stopClients()
  149. server.runWaitGroup.Wait()
  150. log.WithContext().Info("stopped")
  151. return err
  152. }
  153. type sshClientID uint64
  154. type sshServer struct {
  155. support *SupportServices
  156. shutdownBroadcast <-chan struct{}
  157. sshHostKey ssh.Signer
  158. nextClientID sshClientID
  159. clientsMutex sync.Mutex
  160. stoppingClients bool
  161. acceptedClientCounts map[string]int64
  162. clients map[sshClientID]*sshClient
  163. }
  164. func newSSHServer(
  165. support *SupportServices,
  166. shutdownBroadcast <-chan struct{}) (*sshServer, error) {
  167. privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
  168. if err != nil {
  169. return nil, psiphon.ContextError(err)
  170. }
  171. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  172. signer, err := ssh.NewSignerFromKey(privateKey)
  173. if err != nil {
  174. return nil, psiphon.ContextError(err)
  175. }
  176. return &sshServer{
  177. support: support,
  178. shutdownBroadcast: shutdownBroadcast,
  179. sshHostKey: signer,
  180. nextClientID: 1,
  181. acceptedClientCounts: make(map[string]int64),
  182. clients: make(map[sshClientID]*sshClient),
  183. }, nil
  184. }
  185. // runListener is intended to run an a goroutine; it blocks
  186. // running a particular listener. If an unrecoverable error
  187. // occurs, it will send the error to the listenerError channel.
  188. func (sshServer *sshServer) runListener(
  189. listener net.Listener,
  190. listenerError chan<- error,
  191. tunnelProtocol string) {
  192. handleClient := func(clientConn net.Conn) {
  193. // process each client connection concurrently
  194. go sshServer.handleClient(tunnelProtocol, clientConn)
  195. }
  196. // Note: when exiting due to a unrecoverable error, be sure
  197. // to try to send the error to listenerError so that the outer
  198. // TunnelServer.Run will properly shut down instead of remaining
  199. // running.
  200. if psiphon.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
  201. psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
  202. meekServer, err := NewMeekServer(
  203. sshServer.support,
  204. listener,
  205. psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
  206. handleClient,
  207. sshServer.shutdownBroadcast)
  208. if err != nil {
  209. select {
  210. case listenerError <- psiphon.ContextError(err):
  211. default:
  212. }
  213. return
  214. }
  215. meekServer.Run()
  216. } else {
  217. for {
  218. conn, err := listener.Accept()
  219. select {
  220. case <-sshServer.shutdownBroadcast:
  221. if err == nil {
  222. conn.Close()
  223. }
  224. return
  225. default:
  226. }
  227. if err != nil {
  228. if e, ok := err.(net.Error); ok && e.Temporary() {
  229. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  230. // Temporary error, keep running
  231. continue
  232. }
  233. select {
  234. case listenerError <- psiphon.ContextError(err):
  235. default:
  236. }
  237. return
  238. }
  239. handleClient(conn)
  240. }
  241. }
  242. }
  243. // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
  244. // is for tracking the number of connections.
  245. func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol string) {
  246. sshServer.clientsMutex.Lock()
  247. defer sshServer.clientsMutex.Unlock()
  248. sshServer.acceptedClientCounts[tunnelProtocol] += 1
  249. }
  250. func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol string) {
  251. sshServer.clientsMutex.Lock()
  252. defer sshServer.clientsMutex.Unlock()
  253. sshServer.acceptedClientCounts[tunnelProtocol] -= 1
  254. }
  255. // An established client has completed its SSH handshake and has a ssh.Conn. Registration is
  256. // for tracking the number of fully established clients and for maintaining a list of running
  257. // clients (for stopping at shutdown time).
  258. func (sshServer *sshServer) registerEstablishedClient(client *sshClient) (sshClientID, bool) {
  259. sshServer.clientsMutex.Lock()
  260. defer sshServer.clientsMutex.Unlock()
  261. if sshServer.stoppingClients {
  262. return 0, false
  263. }
  264. clientID := sshServer.nextClientID
  265. sshServer.nextClientID += 1
  266. sshServer.clients[clientID] = client
  267. return clientID, true
  268. }
  269. func (sshServer *sshServer) unregisterEstablishedClient(clientID sshClientID) {
  270. sshServer.clientsMutex.Lock()
  271. client := sshServer.clients[clientID]
  272. delete(sshServer.clients, clientID)
  273. sshServer.clientsMutex.Unlock()
  274. if client != nil {
  275. client.stop()
  276. }
  277. }
  278. func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
  279. sshServer.clientsMutex.Lock()
  280. defer sshServer.clientsMutex.Unlock()
  281. loadStats := make(map[string]map[string]int64)
  282. // Explicitly populate with zeros to get 0 counts in log messages derived from getLoadStats()
  283. for tunnelProtocol, _ := range sshServer.support.Config.TunnelProtocolPorts {
  284. loadStats[tunnelProtocol] = make(map[string]int64)
  285. loadStats[tunnelProtocol]["AcceptedClients"] = 0
  286. loadStats[tunnelProtocol]["EstablishedClients"] = 0
  287. loadStats[tunnelProtocol]["TCPPortForwards"] = 0
  288. loadStats[tunnelProtocol]["TotalTCPPortForwards"] = 0
  289. loadStats[tunnelProtocol]["UDPPortForwards"] = 0
  290. loadStats[tunnelProtocol]["TotalUDPPortForwards"] = 0
  291. }
  292. // Note: as currently tracked/counted, each established client is also an accepted client
  293. for tunnelProtocol, acceptedClientCount := range sshServer.acceptedClientCounts {
  294. loadStats[tunnelProtocol]["AcceptedClients"] = acceptedClientCount
  295. }
  296. for _, client := range sshServer.clients {
  297. // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
  298. loadStats[client.tunnelProtocol]["EstablishedClients"] += 1
  299. client.Lock()
  300. loadStats[client.tunnelProtocol]["TCPPortForwards"] += client.tcpTrafficState.concurrentPortForwardCount
  301. loadStats[client.tunnelProtocol]["TotalTCPPortForwards"] += client.tcpTrafficState.totalPortForwardCount
  302. loadStats[client.tunnelProtocol]["UDPPortForwards"] += client.udpTrafficState.concurrentPortForwardCount
  303. loadStats[client.tunnelProtocol]["TotalUDPPortForwards"] += client.udpTrafficState.totalPortForwardCount
  304. client.Unlock()
  305. }
  306. return loadStats
  307. }
  308. func (sshServer *sshServer) stopClients() {
  309. sshServer.clientsMutex.Lock()
  310. sshServer.stoppingClients = true
  311. clients := sshServer.clients
  312. sshServer.clients = make(map[sshClientID]*sshClient)
  313. sshServer.clientsMutex.Unlock()
  314. for _, client := range clients {
  315. client.stop()
  316. }
  317. }
  318. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  319. sshServer.registerAcceptedClient(tunnelProtocol)
  320. defer sshServer.unregisterAcceptedClient(tunnelProtocol)
  321. geoIPData := sshServer.support.GeoIPService.Lookup(
  322. psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
  323. // TODO: apply reload of TrafficRulesSet to existing clients
  324. sshClient := newSshClient(
  325. sshServer,
  326. tunnelProtocol,
  327. geoIPData,
  328. sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country))
  329. // Wrap the base client connection with an ActivityMonitoredConn which will
  330. // terminate the connection if no data is received before the deadline. This
  331. // timeout is in effect for the entire duration of the SSH connection. Clients
  332. // must actively use the connection or send SSH keep alive requests to keep
  333. // the connection active.
  334. activityConn := psiphon.NewActivityMonitoredConn(
  335. clientConn,
  336. SSH_CONNECTION_READ_DEADLINE,
  337. false,
  338. nil)
  339. clientConn = activityConn
  340. // Further wrap the connection in a rate limiting ThrottledConn.
  341. rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
  342. clientConn = psiphon.NewThrottledConn(
  343. clientConn,
  344. rateLimits.DownstreamUnlimitedBytes,
  345. int64(rateLimits.DownstreamBytesPerSecond),
  346. rateLimits.UpstreamUnlimitedBytes,
  347. int64(rateLimits.UpstreamBytesPerSecond))
  348. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  349. // respect shutdownBroadcast and implement a specific handshake timeout.
  350. // The timeout is to reclaim network resources in case the handshake takes
  351. // too long.
  352. type sshNewServerConnResult struct {
  353. conn net.Conn
  354. sshConn *ssh.ServerConn
  355. channels <-chan ssh.NewChannel
  356. requests <-chan *ssh.Request
  357. err error
  358. }
  359. resultChannel := make(chan *sshNewServerConnResult, 2)
  360. if SSH_HANDSHAKE_TIMEOUT > 0 {
  361. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  362. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  363. })
  364. }
  365. go func(conn net.Conn) {
  366. sshServerConfig := &ssh.ServerConfig{
  367. PasswordCallback: sshClient.passwordCallback,
  368. AuthLogCallback: sshClient.authLogCallback,
  369. ServerVersion: sshServer.support.Config.SSHServerVersion,
  370. }
  371. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  372. result := &sshNewServerConnResult{}
  373. // Wrap the connection in an SSH deobfuscator when required.
  374. if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
  375. // Note: NewObfuscatedSshConn blocks on network I/O
  376. // TODO: ensure this won't block shutdown
  377. conn, result.err = psiphon.NewObfuscatedSshConn(
  378. psiphon.OBFUSCATION_CONN_MODE_SERVER,
  379. clientConn,
  380. sshServer.support.Config.ObfuscatedSSHKey)
  381. if result.err != nil {
  382. result.err = psiphon.ContextError(result.err)
  383. }
  384. }
  385. if result.err == nil {
  386. result.sshConn, result.channels, result.requests, result.err =
  387. ssh.NewServerConn(conn, sshServerConfig)
  388. }
  389. resultChannel <- result
  390. }(clientConn)
  391. var result *sshNewServerConnResult
  392. select {
  393. case result = <-resultChannel:
  394. case <-sshServer.shutdownBroadcast:
  395. // Close() will interrupt an ongoing handshake
  396. // TODO: wait for goroutine to exit before returning?
  397. clientConn.Close()
  398. return
  399. }
  400. if result.err != nil {
  401. clientConn.Close()
  402. // This is a Debug log due to noise. The handshake often fails due to I/O
  403. // errors as clients frequently interrupt connections in progress when
  404. // client-side load balancing completes a connection to a different server.
  405. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
  406. return
  407. }
  408. sshClient.Lock()
  409. sshClient.sshConn = result.sshConn
  410. sshClient.activityConn = activityConn
  411. sshClient.Unlock()
  412. clientID, ok := sshServer.registerEstablishedClient(sshClient)
  413. if !ok {
  414. clientConn.Close()
  415. log.WithContext().Warning("register failed")
  416. return
  417. }
  418. defer sshServer.unregisterEstablishedClient(clientID)
  419. sshClient.runClient(result.channels, result.requests)
  420. // Note: sshServer.unregisterClient calls sshClient.Close(),
  421. // which also closes underlying transport Conn.
  422. }
  423. type sshClient struct {
  424. sync.Mutex
  425. sshServer *sshServer
  426. tunnelProtocol string
  427. sshConn ssh.Conn
  428. activityConn *psiphon.ActivityMonitoredConn
  429. geoIPData GeoIPData
  430. psiphonSessionID string
  431. udpChannel ssh.Channel
  432. trafficRules TrafficRules
  433. tcpTrafficState *trafficState
  434. udpTrafficState *trafficState
  435. channelHandlerWaitGroup *sync.WaitGroup
  436. tcpPortForwardLRU *psiphon.LRUConns
  437. stopBroadcast chan struct{}
  438. }
  // trafficState holds the byte and port forward counters for one traffic
// type (TCP or UDP) of an sshClient. It has no lock of its own; the owning
// sshClient's mutex guards it.
type trafficState struct {
	// Bytes transferred, accumulated as port forwards close.
	bytesUp   int64
	bytesDown int64

	// Port forward counts:
	//   - concurrentPortForwardCount: currently open;
	//   - peakConcurrentPortForwardCount: highest open count observed;
	//   - totalPortForwardCount: total ever opened.
	concurrentPortForwardCount     int64
	peakConcurrentPortForwardCount int64
	totalPortForwardCount          int64
}
  446. func newSshClient(
  447. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
  448. return &sshClient{
  449. sshServer: sshServer,
  450. tunnelProtocol: tunnelProtocol,
  451. geoIPData: geoIPData,
  452. trafficRules: trafficRules,
  453. tcpTrafficState: &trafficState{},
  454. udpTrafficState: &trafficState{},
  455. channelHandlerWaitGroup: new(sync.WaitGroup),
  456. tcpPortForwardLRU: psiphon.NewLRUConns(),
  457. stopBroadcast: make(chan struct{}),
  458. }
  459. }
  460. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  461. var sshPasswordPayload struct {
  462. SessionId string `json:"SessionId"`
  463. SshPassword string `json:"SshPassword"`
  464. }
  465. err := json.Unmarshal(password, &sshPasswordPayload)
  466. if err != nil {
  467. // Backwards compatibility case: instead of a JSON payload, older clients
  468. // send the hex encoded session ID prepended to the SSH password.
  469. // Note: there's an even older case where clients don't send any session ID,
  470. // but that's no longer supported.
  471. if len(password) == 2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH+2*SSH_PASSWORD_BYTE_LENGTH {
  472. sshPasswordPayload.SessionId = string(password[0 : 2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH])
  473. sshPasswordPayload.SshPassword = string(password[2*psiphon.PSIPHON_API_CLIENT_SESSION_ID_LENGTH : len(password)])
  474. } else {
  475. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  476. }
  477. }
  478. if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
  479. return nil, psiphon.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
  480. }
  481. userOk := (subtle.ConstantTimeCompare(
  482. []byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
  483. passwordOk := (subtle.ConstantTimeCompare(
  484. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
  485. if !userOk || !passwordOk {
  486. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  487. }
  488. psiphonSessionID := sshPasswordPayload.SessionId
  489. sshClient.Lock()
  490. sshClient.psiphonSessionID = psiphonSessionID
  491. geoIPData := sshClient.geoIPData
  492. sshClient.Unlock()
  493. // Store the GeoIP data associated with the session ID. This makes the GeoIP data
  494. // available to the web server for web transport Psiphon API requests.
  495. sshClient.sshServer.support.GeoIPService.SetSessionCache(
  496. psiphonSessionID, geoIPData)
  497. return nil, nil
  498. }
  499. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  500. if err != nil {
  501. if method == "none" && err.Error() == "no auth passed yet" {
  502. // In this case, the callback invocation is noise from auth negotiation
  503. return
  504. }
  505. logFields := LogFields{"error": err, "method": method}
  506. if sshClient.sshServer.support.Config.UseFail2Ban() {
  507. clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
  508. if clientIPAddress != "" {
  509. logFields["fail2ban"] = fmt.Sprintf(
  510. sshClient.sshServer.support.Config.Fail2BanFormat, clientIPAddress)
  511. }
  512. }
  513. log.WithContextFields(logFields).Error("authentication failed")
  514. } else {
  515. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  516. }
  517. }
  518. func (sshClient *sshClient) stop() {
  519. sshClient.sshConn.Close()
  520. sshClient.sshConn.Wait()
  521. close(sshClient.stopBroadcast)
  522. sshClient.channelHandlerWaitGroup.Wait()
  523. // Note: reporting duration based on last confirmed data transfer, which
  524. // is reads for sshClient.activityConn.GetActiveDuration(), and not
  525. // connection closing is important for protocols such as meek. For
  526. // meek, the connection remains open until the HTTP session expires,
  527. // which may be some time after the tunnel has closed. (The meek
  528. // protocol has no allowance for signalling payload EOF, and even if
  529. // it did the client may not have the opportunity to send a final
  530. // request with an EOF flag set.)
  531. sshClient.Lock()
  532. log.WithContextFields(
  533. LogFields{
  534. "startTime": sshClient.activityConn.GetStartTime(),
  535. "duration": sshClient.activityConn.GetActiveDuration(),
  536. "psiphonSessionID": sshClient.psiphonSessionID,
  537. "country": sshClient.geoIPData.Country,
  538. "city": sshClient.geoIPData.City,
  539. "ISP": sshClient.geoIPData.ISP,
  540. "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
  541. "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
  542. "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
  543. "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
  544. "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
  545. "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
  546. "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
  547. "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
  548. }).Info("tunnel closed")
  549. sshClient.Unlock()
  550. }
  551. // runClient handles/dispatches new channel and new requests from the client.
  552. // When the SSH client connection closes, both the channels and requests channels
  553. // will close and runClient will exit.
  554. func (sshClient *sshClient) runClient(
  555. channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
  556. requestsWaitGroup := new(sync.WaitGroup)
  557. requestsWaitGroup.Add(1)
  558. go func() {
  559. defer requestsWaitGroup.Done()
  560. for request := range requests {
  561. // Requests are processed serially; API responses must be sent in request order.
  562. var responsePayload []byte
  563. var err error
  564. if request.Type == "keepalive@openssh.com" {
  565. // Keepalive requests have an empty response.
  566. } else {
  567. // All other requests are assumed to be API requests.
  568. responsePayload, err = sshAPIRequestHandler(
  569. sshClient.sshServer.support,
  570. sshClient.geoIPData,
  571. request.Type,
  572. request.Payload)
  573. }
  574. if err == nil {
  575. err = request.Reply(true, responsePayload)
  576. } else {
  577. log.WithContextFields(LogFields{"error": err}).Warning("request failed")
  578. err = request.Reply(false, nil)
  579. }
  580. if err != nil {
  581. log.WithContextFields(LogFields{"error": err}).Warning("response failed")
  582. }
  583. }
  584. }()
  585. for newChannel := range channels {
  586. if newChannel.ChannelType() != "direct-tcpip" {
  587. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  588. continue
  589. }
  590. // process each port forward concurrently
  591. sshClient.channelHandlerWaitGroup.Add(1)
  592. go sshClient.handleNewPortForwardChannel(newChannel)
  593. }
  594. requestsWaitGroup.Wait()
  595. }
  596. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  597. // TODO: log more details?
  598. log.WithContextFields(
  599. LogFields{
  600. "channelType": newChannel.ChannelType(),
  601. "rejectMessage": message,
  602. "rejectReason": reason,
  603. }).Warning("reject new channel")
  604. newChannel.Reject(reason, message)
  605. }
  606. func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
  607. defer sshClient.channelHandlerWaitGroup.Done()
  608. // http://tools.ietf.org/html/rfc4254#section-7.2
  609. var directTcpipExtraData struct {
  610. HostToConnect string
  611. PortToConnect uint32
  612. OriginatorIPAddress string
  613. OriginatorPort uint32
  614. }
  615. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  616. if err != nil {
  617. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  618. return
  619. }
  620. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  621. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  622. isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
  623. sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
  624. fmt.Sprintf("%s:%d",
  625. directTcpipExtraData.HostToConnect,
  626. directTcpipExtraData.PortToConnect)
  627. if isUDPChannel {
  628. sshClient.handleUDPChannel(newChannel)
  629. } else {
  630. sshClient.handleTCPChannel(
  631. directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
  632. }
  633. }
  634. func (sshClient *sshClient) isPortForwardPermitted(
  635. port int, allowPorts []int, denyPorts []int) bool {
  636. // TODO: faster lookup?
  637. if len(allowPorts) > 0 {
  638. for _, allowPort := range allowPorts {
  639. if port == allowPort {
  640. return true
  641. }
  642. }
  643. return false
  644. }
  645. if len(denyPorts) > 0 {
  646. for _, denyPort := range denyPorts {
  647. if port == denyPort {
  648. return false
  649. }
  650. }
  651. }
  652. return true
  653. }
  654. func (sshClient *sshClient) isPortForwardLimitExceeded(
  655. state *trafficState, maxPortForwardCount int) bool {
  656. limitExceeded := false
  657. if maxPortForwardCount > 0 {
  658. sshClient.Lock()
  659. limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
  660. sshClient.Unlock()
  661. }
  662. return limitExceeded
  663. }
  664. func (sshClient *sshClient) openedPortForward(
  665. state *trafficState) {
  666. sshClient.Lock()
  667. state.concurrentPortForwardCount += 1
  668. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  669. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  670. }
  671. state.totalPortForwardCount += 1
  672. sshClient.Unlock()
  673. }
  674. func (sshClient *sshClient) closedPortForward(
  675. state *trafficState, bytesUp, bytesDown int64) {
  676. sshClient.Lock()
  677. state.concurrentPortForwardCount -= 1
  678. state.bytesUp += bytesUp
  679. state.bytesDown += bytesDown
  680. sshClient.Unlock()
  681. }
  682. func (sshClient *sshClient) handleTCPChannel(
  683. hostToConnect string,
  684. portToConnect int,
  685. newChannel ssh.NewChannel) {
  686. if !sshClient.isPortForwardPermitted(
  687. portToConnect,
  688. sshClient.trafficRules.AllowTCPPorts,
  689. sshClient.trafficRules.DenyTCPPorts) {
  690. sshClient.rejectNewChannel(
  691. newChannel, ssh.Prohibited, "port forward not permitted")
  692. return
  693. }
  694. var bytesUp, bytesDown int64
  695. sshClient.openedPortForward(sshClient.tcpTrafficState)
  696. defer func() {
  697. sshClient.closedPortForward(
  698. sshClient.tcpTrafficState,
  699. atomic.LoadInt64(&bytesUp),
  700. atomic.LoadInt64(&bytesDown))
  701. }()
  702. // TOCTOU note: important to increment the port forward count (via
  703. // openPortForward) _before_ checking isPortForwardLimitExceeded
  704. // otherwise, the client could potentially consume excess resources
  705. // by initiating many port forwards concurrently.
  706. // TODO: close LRU connection (after successful Dial) instead of
  707. // rejecting new connection?
  708. if sshClient.isPortForwardLimitExceeded(
  709. sshClient.tcpTrafficState,
  710. sshClient.trafficRules.MaxTCPPortForwardCount) {
  711. // Close the oldest TCP port forward. CloseOldest() closes
  712. // the conn and the port forward's goroutine will complete
  713. // the cleanup asynchronously.
  714. //
  715. // Some known limitations:
  716. //
  717. // - Since CloseOldest() closes the upstream socket but does not
  718. // clean up all resources associated with the port forward. These
  719. // include the goroutine(s) relaying traffic as well as the SSH
  720. // channel. Closing the socket will interrupt the goroutines which
  721. // will then complete the cleanup. But, since the full cleanup is
  722. // asynchronous, there exists a possibility that a client can consume
  723. // more than max port forward resources -- just not upstream sockets.
  724. //
  725. // - An LRU list entry for this port forward is not added until
  726. // after the dial completes, but the port forward is counted
  727. // towards max limits. This means many dials in progress will
  728. // put established connections in jeopardy.
  729. //
  730. // - We're closing the oldest open connection _before_ successfully
  731. // dialing the new port forward. This means we are potentially
  732. // discarding a good connection to make way for a failed connection.
  733. // We cannot simply dial first and still maintain a limit on
  734. // resources used, so to address this we'd need to add some
  735. // accounting for connections still establishing.
  736. sshClient.tcpPortForwardLRU.CloseOldest()
  737. log.WithContextFields(
  738. LogFields{
  739. "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
  740. }).Debug("closed LRU TCP port forward")
  741. }
  742. // Dial the target remote address. This is done in a goroutine to
  743. // ensure the shutdown signal is handled immediately.
  744. remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
  745. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  746. type dialTcpResult struct {
  747. conn net.Conn
  748. err error
  749. }
  750. resultChannel := make(chan *dialTcpResult, 1)
  751. go func() {
  752. // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
  753. // TODO: IPv6 support
  754. conn, err := net.DialTimeout(
  755. "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
  756. resultChannel <- &dialTcpResult{conn, err}
  757. }()
  758. var result *dialTcpResult
  759. select {
  760. case result = <-resultChannel:
  761. case <-sshClient.stopBroadcast:
  762. // Note: may leave dial in progress
  763. return
  764. }
  765. if result.err != nil {
  766. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
  767. return
  768. }
  769. // The upstream TCP port forward connection has been established. Schedule
  770. // some cleanup and notify the SSH client that the channel is accepted.
  771. fwdConn := result.conn
  772. defer fwdConn.Close()
  773. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
  774. defer lruEntry.Remove()
  775. // ActivityMonitoredConn monitors the TCP port forward I/O and updates
  776. // its LRU status. ActivityMonitoredConn also times out read on the port
  777. // forward if both reads and writes have been idle for the specified
  778. // duration.
  779. fwdConn = psiphon.NewActivityMonitoredConn(
  780. fwdConn,
  781. time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
  782. true,
  783. lruEntry)
  784. fwdChannel, requests, err := newChannel.Accept()
  785. if err != nil {
  786. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  787. return
  788. }
  789. go ssh.DiscardRequests(requests)
  790. defer fwdChannel.Close()
  791. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  792. // Relay channel to forwarded connection.
  793. // TODO: relay errors to fwdChannel.Stderr()?
  794. relayWaitGroup := new(sync.WaitGroup)
  795. relayWaitGroup.Add(1)
  796. go func() {
  797. defer relayWaitGroup.Done()
  798. // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
  799. // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
  800. // overall memory footprint.
  801. bytes, err := io.CopyBuffer(
  802. fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  803. atomic.AddInt64(&bytesDown, bytes)
  804. if err != nil && err != io.EOF {
  805. // Debug since errors such as "connection reset by peer" occur during normal operation
  806. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  807. }
  808. // Interrupt upstream io.Copy when downstream is shutting down.
  809. // TODO: this is done to quickly cleanup the port forward when
  810. // fwdConn has a read timeout, but is it clean -- upstream may still
  811. // be flowing?
  812. fwdChannel.Close()
  813. }()
  814. bytes, err := io.CopyBuffer(
  815. fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  816. atomic.AddInt64(&bytesUp, bytes)
  817. if err != nil && err != io.EOF {
  818. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  819. }
  820. // Shutdown special case: fwdChannel will be closed and return EOF when
  821. // the SSH connection is closed, but we need to explicitly close fwdConn
  822. // to interrupt the downstream io.Copy, which may be blocked on a
  823. // fwdConn.Read().
  824. fwdConn.Close()
  825. relayWaitGroup.Wait()
  826. log.WithContextFields(
  827. LogFields{
  828. "remoteAddr": remoteAddr,
  829. "bytesUp": atomic.LoadInt64(&bytesUp),
  830. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  831. }