sshService.go 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "runtime"
  28. "sync"
  29. "sync/atomic"
  30. "time"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  32. "golang.org/x/crypto/ssh"
  33. )
  34. // RunSSHServer runs an SSH server, the core tunneling component of the Psiphon
  35. // server. The SSH server runs a selection of listeners that handle connections
  36. // using various, optional obfuscation protocols layered on top of SSH.
  37. // (Currently, just Obfuscated SSH).
  38. //
  39. // RunSSHServer listens on the designated port(s) and spawns new goroutines to handle
  40. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  41. // clients is maintained, and when halting all clients are first shutdown.
  42. //
  43. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  44. // authentication, and then looping on client new channel requests. At this time, only
  45. // "direct-tcpip" channels, dynamic port fowards, are expected and supported.
  46. //
  47. // A new goroutine is spawned to handle each port forward for each client. Each port
  48. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  49. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  50. // client shuts down.
  51. func RunSSHServer(
  52. config *Config, shutdownBroadcast <-chan struct{}) error {
  53. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  54. if err != nil {
  55. return psiphon.ContextError(err)
  56. }
  57. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  58. signer, err := ssh.NewSignerFromKey(privateKey)
  59. if err != nil {
  60. return psiphon.ContextError(err)
  61. }
  62. sshServer := &sshServer{
  63. config: config,
  64. runWaitGroup: new(sync.WaitGroup),
  65. listenerError: make(chan error),
  66. shutdownBroadcast: shutdownBroadcast,
  67. sshHostKey: signer,
  68. nextClientID: 1,
  69. clients: make(map[sshClientID]*sshClient),
  70. }
  71. type sshListener struct {
  72. net.Listener
  73. localAddress string
  74. tunnelProtocol string
  75. }
  76. var listeners []*sshListener
  77. if config.RunSSHServer() {
  78. listeners = append(listeners, &sshListener{
  79. localAddress: fmt.Sprintf(
  80. "%s:%d", config.ServerIPAddress, config.SSHServerPort),
  81. tunnelProtocol: psiphon.TUNNEL_PROTOCOL_SSH,
  82. })
  83. }
  84. if config.RunObfuscatedSSHServer() {
  85. listeners = append(listeners, &sshListener{
  86. localAddress: fmt.Sprintf(
  87. "%s:%d", config.ServerIPAddress, config.ObfuscatedSSHServerPort),
  88. tunnelProtocol: psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH,
  89. })
  90. }
  91. // TODO: add additional protocol listeners here (e.g, meek)
  92. for i, listener := range listeners {
  93. var err error
  94. listener.Listener, err = net.Listen("tcp", listener.localAddress)
  95. if err != nil {
  96. for j := 0; j < i; j++ {
  97. listener.Listener.Close()
  98. }
  99. return psiphon.ContextError(err)
  100. }
  101. log.WithContextFields(
  102. LogFields{
  103. "localAddress": listener.localAddress,
  104. "tunnelProtocol": listener.tunnelProtocol,
  105. }).Info("listening")
  106. }
  107. for _, listener := range listeners {
  108. sshServer.runWaitGroup.Add(1)
  109. go func(listener *sshListener) {
  110. defer sshServer.runWaitGroup.Done()
  111. sshServer.runListener(
  112. listener.Listener, listener.tunnelProtocol)
  113. log.WithContextFields(
  114. LogFields{
  115. "localAddress": listener.localAddress,
  116. "tunnelProtocol": listener.tunnelProtocol,
  117. }).Info("stopping")
  118. }(listener)
  119. }
  120. if config.RunLoadMonitor() {
  121. sshServer.runWaitGroup.Add(1)
  122. go func() {
  123. defer sshServer.runWaitGroup.Done()
  124. sshServer.runLoadMonitor()
  125. }()
  126. }
  127. err = nil
  128. select {
  129. case <-sshServer.shutdownBroadcast:
  130. case err = <-sshServer.listenerError:
  131. }
  132. for _, listener := range listeners {
  133. listener.Close()
  134. }
  135. sshServer.stopClients()
  136. sshServer.runWaitGroup.Wait()
  137. log.WithContext().Info("stopped")
  138. return err
  139. }
  // sshClientID identifies a client registered with an sshServer. IDs are
// handed out sequentially by registerClient from the server's nextClientID
// counter.
type sshClientID uint64
  141. type sshServer struct {
  142. config *Config
  143. runWaitGroup *sync.WaitGroup
  144. listenerError chan error
  145. shutdownBroadcast <-chan struct{}
  146. sshHostKey ssh.Signer
  147. nextClientID sshClientID
  148. clientsMutex sync.Mutex
  149. stoppingClients bool
  150. clients map[sshClientID]*sshClient
  151. }
  152. func (sshServer *sshServer) runListener(
  153. listener net.Listener, tunnelProtocol string) {
  154. for {
  155. conn, err := listener.Accept()
  156. if err == nil && tunnelProtocol == psiphon.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
  157. conn, err = psiphon.NewObfuscatedSshConn(
  158. psiphon.OBFUSCATION_CONN_MODE_SERVER,
  159. conn,
  160. sshServer.config.ObfuscatedSSHKey)
  161. }
  162. select {
  163. case <-sshServer.shutdownBroadcast:
  164. if err == nil {
  165. conn.Close()
  166. }
  167. return
  168. default:
  169. }
  170. if err != nil {
  171. if e, ok := err.(net.Error); ok && e.Temporary() {
  172. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  173. // Temporary error, keep running
  174. continue
  175. }
  176. select {
  177. case sshServer.listenerError <- psiphon.ContextError(err):
  178. default:
  179. }
  180. return
  181. }
  182. // process each client connection concurrently
  183. go sshServer.handleClient(tunnelProtocol, conn)
  184. }
  185. }
  186. func (sshServer *sshServer) runLoadMonitor() {
  187. ticker := time.NewTicker(
  188. time.Duration(sshServer.config.LoadMonitorPeriodSeconds) * time.Second)
  189. defer ticker.Stop()
  190. for {
  191. select {
  192. case <-sshServer.shutdownBroadcast:
  193. return
  194. case <-ticker.C:
  195. var memStats runtime.MemStats
  196. runtime.ReadMemStats(&memStats)
  197. fields := LogFields{
  198. "goroutines": runtime.NumGoroutine(),
  199. "memAlloc": memStats.Alloc,
  200. "memTotalAlloc": memStats.TotalAlloc,
  201. "memSysAlloc": memStats.Sys,
  202. }
  203. for tunnelProtocol, count := range sshServer.countClients() {
  204. fields[tunnelProtocol] = count
  205. }
  206. log.WithContextFields(fields).Info("load")
  207. }
  208. }
  209. }
  210. func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
  211. sshServer.clientsMutex.Lock()
  212. defer sshServer.clientsMutex.Unlock()
  213. if sshServer.stoppingClients {
  214. return 0, false
  215. }
  216. clientID := sshServer.nextClientID
  217. sshServer.nextClientID += 1
  218. sshServer.clients[clientID] = client
  219. return clientID, true
  220. }
  221. func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
  222. sshServer.clientsMutex.Lock()
  223. client := sshServer.clients[clientID]
  224. delete(sshServer.clients, clientID)
  225. sshServer.clientsMutex.Unlock()
  226. if client != nil {
  227. client.stop()
  228. }
  229. }
  230. func (sshServer *sshServer) countClients() map[string]int {
  231. sshServer.clientsMutex.Lock()
  232. defer sshServer.clientsMutex.Unlock()
  233. counts := make(map[string]int)
  234. for _, client := range sshServer.clients {
  235. counts[client.tunnelProtocol] += 1
  236. }
  237. return counts
  238. }
  239. func (sshServer *sshServer) stopClients() {
  240. sshServer.clientsMutex.Lock()
  241. sshServer.stoppingClients = true
  242. sshServer.clients = make(map[sshClientID]*sshClient)
  243. sshServer.clientsMutex.Unlock()
  244. for _, client := range sshServer.clients {
  245. client.stop()
  246. }
  247. }
  248. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  249. geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
  250. sshClient := newSshClient(
  251. sshServer,
  252. tunnelProtocol,
  253. geoIPData,
  254. sshServer.config.GetTrafficRules(geoIPData.Country))
  255. // Wrap the base client connection with an IdleTimeoutConn which will terminate
  256. // the connection if no data is received before the deadline. This timeout is
  257. // in effect for the entire duration of the SSH connection. Clients must actively
  258. // use the connection or send SSH keep alive requests to keep the connection
  259. // active.
  260. var conn net.Conn
  261. conn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
  262. // Further wrap the connection in a rate limiting ThrottledConn.
  263. conn = psiphon.NewThrottledConn(
  264. conn,
  265. int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
  266. int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
  267. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  268. // respect shutdownBroadcast and implement a specific handshake timeout.
  269. // The timeout is to reclaim network resources in case the handshake takes
  270. // too long.
  271. type sshNewServerConnResult struct {
  272. conn net.Conn
  273. sshConn *ssh.ServerConn
  274. channels <-chan ssh.NewChannel
  275. requests <-chan *ssh.Request
  276. err error
  277. }
  278. resultChannel := make(chan *sshNewServerConnResult, 2)
  279. if SSH_HANDSHAKE_TIMEOUT > 0 {
  280. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  281. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  282. })
  283. }
  284. go func(conn net.Conn) {
  285. sshServerConfig := &ssh.ServerConfig{
  286. PasswordCallback: sshClient.passwordCallback,
  287. AuthLogCallback: sshClient.authLogCallback,
  288. ServerVersion: sshServer.config.SSHServerVersion,
  289. }
  290. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  291. sshConn, channels, requests, err :=
  292. ssh.NewServerConn(conn, sshServerConfig)
  293. resultChannel <- &sshNewServerConnResult{
  294. conn: conn,
  295. sshConn: sshConn,
  296. channels: channels,
  297. requests: requests,
  298. err: err,
  299. }
  300. }(conn)
  301. var result *sshNewServerConnResult
  302. select {
  303. case result = <-resultChannel:
  304. case <-sshServer.shutdownBroadcast:
  305. // Close() will interrupt an ongoing handshake
  306. // TODO: wait for goroutine to exit before returning?
  307. conn.Close()
  308. return
  309. }
  310. if result.err != nil {
  311. conn.Close()
  312. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  313. return
  314. }
  315. sshClient.Lock()
  316. sshClient.sshConn = result.sshConn
  317. sshClient.Unlock()
  318. clientID, ok := sshServer.registerClient(sshClient)
  319. if !ok {
  320. conn.Close()
  321. log.WithContext().Warning("register failed")
  322. return
  323. }
  324. defer sshServer.unregisterClient(clientID)
  325. go ssh.DiscardRequests(result.requests)
  326. sshClient.handleChannels(result.channels)
  327. }
  328. type sshClient struct {
  329. sync.Mutex
  330. sshServer *sshServer
  331. tunnelProtocol string
  332. sshConn ssh.Conn
  333. startTime time.Time
  334. geoIPData GeoIPData
  335. psiphonSessionID string
  336. udpChannel ssh.Channel
  337. trafficRules TrafficRules
  338. tcpTrafficState *trafficState
  339. udpTrafficState *trafficState
  340. channelHandlerWaitGroup *sync.WaitGroup
  341. stopBroadcast chan struct{}
  342. }
  // trafficState accumulates port forward statistics for one traffic type
// (TCP or UDP) of a client. Its fields are read and written while holding
// the owning sshClient's mutex.
type trafficState struct {
	bytesUp                        int64 // total bytes relayed client -> destination
	bytesDown                      int64 // total bytes relayed destination -> client
	portForwardCount               int64 // port forwards ever opened; never decremented
	concurrentPortForwardCount     int64 // port forwards currently open
	peakConcurrentPortForwardCount int64 // high-water mark of concurrentPortForwardCount
}
  350. func newSshClient(
  351. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
  352. return &sshClient{
  353. sshServer: sshServer,
  354. tunnelProtocol: tunnelProtocol,
  355. startTime: time.Now(),
  356. geoIPData: geoIPData,
  357. trafficRules: trafficRules,
  358. tcpTrafficState: &trafficState{},
  359. udpTrafficState: &trafficState{},
  360. channelHandlerWaitGroup: new(sync.WaitGroup),
  361. stopBroadcast: make(chan struct{}),
  362. }
  363. }
  364. func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
  365. for newChannel := range channels {
  366. if newChannel.ChannelType() != "direct-tcpip" {
  367. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  368. continue
  369. }
  370. // process each port forward concurrently
  371. sshClient.channelHandlerWaitGroup.Add(1)
  372. go sshClient.handleNewPortForwardChannel(newChannel)
  373. }
  374. }
  375. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  376. // TODO: log more details?
  377. log.WithContextFields(
  378. LogFields{
  379. "channelType": newChannel.ChannelType(),
  380. "rejectMessage": message,
  381. "rejectReason": reason,
  382. }).Warning("reject new channel")
  383. newChannel.Reject(reason, message)
  384. }
  385. func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
  386. defer sshClient.channelHandlerWaitGroup.Done()
  387. // http://tools.ietf.org/html/rfc4254#section-7.2
  388. var directTcpipExtraData struct {
  389. HostToConnect string
  390. PortToConnect uint32
  391. OriginatorIPAddress string
  392. OriginatorPort uint32
  393. }
  394. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  395. if err != nil {
  396. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  397. return
  398. }
  399. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  400. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  401. isUDPChannel := sshClient.sshServer.config.UdpgwServerAddress != "" &&
  402. sshClient.sshServer.config.UdpgwServerAddress ==
  403. fmt.Sprintf("%s:%d",
  404. directTcpipExtraData.HostToConnect,
  405. directTcpipExtraData.PortToConnect)
  406. if isUDPChannel {
  407. sshClient.handleUDPChannel(newChannel)
  408. } else {
  409. sshClient.handleTCPChannel(
  410. directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
  411. }
  412. }
  413. func (sshClient *sshClient) isPortForwardPermitted(
  414. port int, allowPorts []int, denyPorts []int) bool {
  415. // TODO: faster lookup?
  416. if allowPorts != nil {
  417. for _, allowPort := range allowPorts {
  418. if port == allowPort {
  419. return true
  420. }
  421. }
  422. return false
  423. }
  424. if denyPorts != nil {
  425. for _, denyPort := range denyPorts {
  426. if port == denyPort {
  427. return false
  428. }
  429. }
  430. }
  431. return true
  432. }
  433. func (sshClient *sshClient) isPortForwardLimitExceeded(
  434. state *trafficState, maxPortForwardCount int) bool {
  435. limitExceeded := false
  436. if maxPortForwardCount > 0 {
  437. sshClient.Lock()
  438. limitExceeded = state.portForwardCount >= int64(maxPortForwardCount)
  439. sshClient.Unlock()
  440. }
  441. return limitExceeded
  442. }
  443. func (sshClient *sshClient) openedPortForward(
  444. state *trafficState) {
  445. sshClient.Lock()
  446. state.portForwardCount += 1
  447. state.concurrentPortForwardCount += 1
  448. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  449. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  450. }
  451. sshClient.Unlock()
  452. }
  453. func (sshClient *sshClient) closedPortForward(
  454. state *trafficState, bytesUp, bytesDown int64) {
  455. sshClient.Lock()
  456. state.concurrentPortForwardCount -= 1
  457. state.bytesUp += bytesUp
  458. state.bytesDown += bytesDown
  459. sshClient.Unlock()
  460. }
  461. func (sshClient *sshClient) handleTCPChannel(
  462. hostToConnect string,
  463. portToConnect int,
  464. newChannel ssh.NewChannel) {
  465. if !sshClient.isPortForwardPermitted(
  466. portToConnect,
  467. sshClient.trafficRules.AllowTCPPorts,
  468. sshClient.trafficRules.DenyTCPPorts) {
  469. sshClient.rejectNewChannel(
  470. newChannel, ssh.Prohibited, "port forward not permitted")
  471. return
  472. }
  473. var bytesUp, bytesDown int64
  474. sshClient.openedPortForward(sshClient.tcpTrafficState)
  475. defer sshClient.closedPortForward(
  476. sshClient.tcpTrafficState, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
  477. // TOCTOU note: important to increment the port forward count (via
  478. // openPortForward) _before_ checking isPortForwardLimitExceeded
  479. // otherwise, the client could potentially consume excess resources
  480. // by initiating many port forwards concurrently.
  481. // TODO: close LRU connection (after successful Dial) instead of
  482. // rejecting new connection?
  483. if sshClient.isPortForwardLimitExceeded(
  484. sshClient.tcpTrafficState,
  485. sshClient.trafficRules.MaxTCPPortForwardCount) {
  486. sshClient.rejectNewChannel(
  487. newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
  488. return
  489. }
  490. remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
  491. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  492. type dialTcpResult struct {
  493. conn net.Conn
  494. err error
  495. }
  496. resultChannel := make(chan *dialTcpResult, 1)
  497. go func() {
  498. // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
  499. // TODO: IPv6 support
  500. conn, err := net.DialTimeout(
  501. "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
  502. resultChannel <- &dialTcpResult{conn, err}
  503. }()
  504. var result *dialTcpResult
  505. select {
  506. case result = <-resultChannel:
  507. case <-sshClient.stopBroadcast:
  508. // Note: may leave dial in progress
  509. return
  510. }
  511. if result.err != nil {
  512. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
  513. return
  514. }
  515. fwdConn := result.conn
  516. defer fwdConn.Close()
  517. fwdChannel, requests, err := newChannel.Accept()
  518. if err != nil {
  519. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  520. return
  521. }
  522. go ssh.DiscardRequests(requests)
  523. defer fwdChannel.Close()
  524. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  525. // When idle port forward traffic rules are in place, wrap fwdConn
  526. // in an IdleTimeoutConn configured to reset idle on writes as well
  527. // as read. This ensures the port forward idle timeout only happens
  528. // when both upstream and downstream directions are are idle.
  529. if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
  530. fwdConn = psiphon.NewIdleTimeoutConn(
  531. fwdConn,
  532. time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
  533. true)
  534. }
  535. // relay channel to forwarded connection
  536. // TODO: relay errors to fwdChannel.Stderr()?
  537. // TODO: use a low-memory io.Copy?
  538. relayWaitGroup := new(sync.WaitGroup)
  539. relayWaitGroup.Add(1)
  540. go func() {
  541. defer relayWaitGroup.Done()
  542. bytes, err := io.Copy(fwdChannel, fwdConn)
  543. atomic.AddInt64(&bytesDown, bytes)
  544. if err != nil && err != io.EOF {
  545. // Debug since errors such as "connection reset by peer" occur during normal operation
  546. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  547. }
  548. }()
  549. bytes, err := io.Copy(fwdConn, fwdChannel)
  550. atomic.AddInt64(&bytesUp, bytes)
  551. if err != nil && err != io.EOF {
  552. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  553. }
  554. // Shutdown special case: fwdChannel will be closed and return EOF when
  555. // the SSH connection is closed, but we need to explicitly close fwdConn
  556. // to interrupt the downstream io.Copy, which may be blocked on a
  557. // fwdConn.Read().
  558. fwdConn.Close()
  559. relayWaitGroup.Wait()
  560. log.WithContextFields(
  561. LogFields{
  562. "remoteAddr": remoteAddr,
  563. "bytesUp": atomic.LoadInt64(&bytesUp),
  564. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  565. }
  566. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  567. var sshPasswordPayload struct {
  568. SessionId string `json:"SessionId"`
  569. SshPassword string `json:"SshPassword"`
  570. }
  571. err := json.Unmarshal(password, &sshPasswordPayload)
  572. if err != nil {
  573. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  574. }
  575. userOk := (subtle.ConstantTimeCompare(
  576. []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
  577. passwordOk := (subtle.ConstantTimeCompare(
  578. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
  579. if !userOk || !passwordOk {
  580. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  581. }
  582. psiphonSessionID := sshPasswordPayload.SessionId
  583. sshClient.Lock()
  584. sshClient.psiphonSessionID = psiphonSessionID
  585. geoIPData := sshClient.geoIPData
  586. sshClient.Unlock()
  587. if sshClient.sshServer.config.UseRedis() {
  588. err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
  589. if err != nil {
  590. log.WithContextFields(LogFields{
  591. "psiphonSessionID": psiphonSessionID,
  592. "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
  593. // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
  594. }
  595. }
  596. return nil, nil
  597. }
  598. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  599. if err != nil {
  600. if sshClient.sshServer.config.UseFail2Ban() {
  601. clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
  602. if clientIPAddress != "" {
  603. LogFail2Ban(clientIPAddress)
  604. }
  605. }
  606. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  607. } else {
  608. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  609. }
  610. }
  611. func (sshClient *sshClient) stop() {
  612. sshClient.sshConn.Close()
  613. sshClient.sshConn.Wait()
  614. close(sshClient.stopBroadcast)
  615. sshClient.channelHandlerWaitGroup.Wait()
  616. sshClient.Lock()
  617. log.WithContextFields(
  618. LogFields{
  619. "startTime": sshClient.startTime,
  620. "duration": time.Now().Sub(sshClient.startTime),
  621. "psiphonSessionID": sshClient.psiphonSessionID,
  622. "country": sshClient.geoIPData.Country,
  623. "city": sshClient.geoIPData.City,
  624. "ISP": sshClient.geoIPData.ISP,
  625. "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
  626. "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
  627. "portForwardCountTCP": sshClient.tcpTrafficState.portForwardCount,
  628. "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
  629. "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
  630. "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
  631. "portForwardCountUDP": sshClient.udpTrafficState.portForwardCount,
  632. "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
  633. }).Info("tunnel closed")
  634. sshClient.Unlock()
  635. }