sshService.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  30. "golang.org/x/crypto/ssh"
  31. )
  32. // RunSSHServer runs an ssh server with plain SSH protocol.
  33. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  34. return runSSHServer(config, false, shutdownBroadcast)
  35. }
  36. // RunSSHServer runs an ssh server with Obfuscated SSH protocol.
  37. func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  38. return runSSHServer(config, true, shutdownBroadcast)
  39. }
  40. // runSSHServer runs an SSH or Obfuscated SSH server. In the Obfuscated SSH case, an
  41. // ObfuscatedSSHConn is layered in front of the client TCP connection; otherwise, both
  42. // modes are identical.
  43. //
  44. // runSSHServer listens on the designated port and spawns new goroutines to handle
  45. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  46. // clients is maintained, and when halting all clients are first shutdown.
  47. //
  48. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  49. // authentication, and then looping on client new channel requests. At this time, only
  50. // "direct-tcpip" channels, dynamic port fowards, are expected and supported.
  51. //
  52. // A new goroutine is spawned to handle each port forward. Each port forward tracks its
  53. // bytes transferred. Overall per-client stats for connection duration, GeoIP, number of
  54. // port forwards, and bytes transferred are tracked and logged when the client shuts down.
  55. func runSSHServer(
  56. config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
  57. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  58. if err != nil {
  59. return psiphon.ContextError(err)
  60. }
  61. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  62. signer, err := ssh.NewSignerFromKey(privateKey)
  63. if err != nil {
  64. return psiphon.ContextError(err)
  65. }
  66. sshServer := &sshServer{
  67. config: config,
  68. useObfuscation: useObfuscation,
  69. shutdownBroadcast: shutdownBroadcast,
  70. sshHostKey: signer,
  71. nextClientID: 1,
  72. clients: make(map[sshClientID]*sshClient),
  73. }
  74. var serverPort int
  75. if useObfuscation {
  76. serverPort = config.ObfuscatedSSHServerPort
  77. } else {
  78. serverPort = config.SSHServerPort
  79. }
  80. listener, err := net.Listen(
  81. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
  82. if err != nil {
  83. return psiphon.ContextError(err)
  84. }
  85. log.WithContextFields(
  86. LogFields{
  87. "useObfuscation": useObfuscation,
  88. "port": serverPort,
  89. }).Info("starting")
  90. err = nil
  91. errors := make(chan error)
  92. waitGroup := new(sync.WaitGroup)
  93. waitGroup.Add(1)
  94. go func() {
  95. defer waitGroup.Done()
  96. loop:
  97. for {
  98. conn, err := listener.Accept()
  99. select {
  100. case <-shutdownBroadcast:
  101. if err == nil {
  102. conn.Close()
  103. }
  104. break loop
  105. default:
  106. }
  107. if err != nil {
  108. if e, ok := err.(net.Error); ok && e.Temporary() {
  109. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  110. // Temporary error, keep running
  111. continue
  112. }
  113. select {
  114. case errors <- psiphon.ContextError(err):
  115. default:
  116. }
  117. break loop
  118. }
  119. // process each client connection concurrently
  120. go sshServer.handleClient(conn.(*net.TCPConn))
  121. }
  122. sshServer.stopClients()
  123. log.WithContextFields(
  124. LogFields{"useObfuscation": useObfuscation}).Info("stopped")
  125. }()
  126. select {
  127. case <-shutdownBroadcast:
  128. case err = <-errors:
  129. }
  130. listener.Close()
  131. waitGroup.Wait()
  132. log.WithContextFields(
  133. LogFields{"useObfuscation": useObfuscation}).Info("exiting")
  134. return err
  135. }
  136. type sshClientID uint64
  137. type sshServer struct {
  138. config *Config
  139. useObfuscation bool
  140. shutdownBroadcast <-chan struct{}
  141. sshHostKey ssh.Signer
  142. nextClientID sshClientID
  143. clientsMutex sync.Mutex
  144. stoppingClients bool
  145. clients map[sshClientID]*sshClient
  146. }
  147. func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
  148. sshServer.clientsMutex.Lock()
  149. defer sshServer.clientsMutex.Unlock()
  150. if sshServer.stoppingClients {
  151. return 0, false
  152. }
  153. clientID := sshServer.nextClientID
  154. sshServer.nextClientID += 1
  155. sshServer.clients[clientID] = client
  156. return clientID, true
  157. }
  158. func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
  159. sshServer.clientsMutex.Lock()
  160. client := sshServer.clients[clientID]
  161. delete(sshServer.clients, clientID)
  162. sshServer.clientsMutex.Unlock()
  163. if client != nil {
  164. sshServer.stopClient(client)
  165. }
  166. }
  167. func (sshServer *sshServer) stopClient(client *sshClient) {
  168. client.sshConn.Close()
  169. client.sshConn.Wait()
  170. client.Lock()
  171. log.WithContextFields(
  172. LogFields{
  173. "startTime": client.startTime,
  174. "duration": time.Now().Sub(client.startTime),
  175. "psiphonSessionID": client.psiphonSessionID,
  176. "country": client.geoIPData.Country,
  177. "city": client.geoIPData.City,
  178. "ISP": client.geoIPData.ISP,
  179. "bytesUpTCP": client.tcpTrafficState.bytesUp,
  180. "bytesDownTCP": client.tcpTrafficState.bytesDown,
  181. "portForwardCountTCP": client.tcpTrafficState.portForwardCount,
  182. "peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
  183. "bytesUpUDP": client.udpTrafficState.bytesUp,
  184. "bytesDownUDP": client.udpTrafficState.bytesDown,
  185. "portForwardCountUDP": client.udpTrafficState.portForwardCount,
  186. "peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
  187. }).Info("tunnel closed")
  188. client.Unlock()
  189. }
  190. func (sshServer *sshServer) stopClients() {
  191. sshServer.clientsMutex.Lock()
  192. sshServer.stoppingClients = true
  193. sshServer.clients = make(map[sshClientID]*sshClient)
  194. sshServer.clientsMutex.Unlock()
  195. for _, client := range sshServer.clients {
  196. sshServer.stopClient(client)
  197. }
  198. }
  199. func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
  200. geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
  201. sshClient := &sshClient{
  202. sshServer: sshServer,
  203. startTime: time.Now(),
  204. geoIPData: geoIPData,
  205. trafficRules: sshServer.config.GetTrafficRules(geoIPData.Country),
  206. tcpTrafficState: &trafficState{},
  207. udpTrafficState: &trafficState{},
  208. }
  209. // Wrap the base TCP connection with an IdleTimeoutConn which will terminate
  210. // the connection if no data is received before the deadline. This timeout is
  211. // in effect for the entire duration of the SSH connection. Clients must actively
  212. // use the connection or send SSH keep alive requests to keep the connection
  213. // active.
  214. var conn net.Conn
  215. conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
  216. // Further wrap the connection in a rate limiting ThrottledConn.
  217. conn = psiphon.NewThrottledConn(
  218. conn,
  219. int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
  220. int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
  221. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  222. // respect shutdownBroadcast and implement a specific handshake timeout.
  223. // The timeout is to reclaim network resources in case the handshake takes
  224. // too long.
  225. type sshNewServerConnResult struct {
  226. conn net.Conn
  227. sshConn *ssh.ServerConn
  228. channels <-chan ssh.NewChannel
  229. requests <-chan *ssh.Request
  230. err error
  231. }
  232. resultChannel := make(chan *sshNewServerConnResult, 2)
  233. if SSH_HANDSHAKE_TIMEOUT > 0 {
  234. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  235. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  236. })
  237. }
  238. go func() {
  239. result := &sshNewServerConnResult{}
  240. if sshServer.useObfuscation {
  241. result.conn, result.err = psiphon.NewObfuscatedSshConn(
  242. psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
  243. } else {
  244. result.conn = conn
  245. }
  246. if result.err == nil {
  247. sshServerConfig := &ssh.ServerConfig{
  248. PasswordCallback: sshClient.passwordCallback,
  249. AuthLogCallback: sshClient.authLogCallback,
  250. ServerVersion: sshServer.config.SSHServerVersion,
  251. }
  252. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  253. result.sshConn, result.channels, result.requests, result.err =
  254. ssh.NewServerConn(result.conn, sshServerConfig)
  255. }
  256. resultChannel <- result
  257. }()
  258. var result *sshNewServerConnResult
  259. select {
  260. case result = <-resultChannel:
  261. case <-sshServer.shutdownBroadcast:
  262. // Close() will interrupt an ongoing handshake
  263. // TODO: wait for goroutine to exit before returning?
  264. conn.Close()
  265. return
  266. }
  267. if result.err != nil {
  268. conn.Close()
  269. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  270. return
  271. }
  272. sshClient.Lock()
  273. sshClient.sshConn = result.sshConn
  274. sshClient.Unlock()
  275. clientID, ok := sshServer.registerClient(sshClient)
  276. if !ok {
  277. conn.Close()
  278. log.WithContext().Warning("register failed")
  279. return
  280. }
  281. defer sshServer.unregisterClient(clientID)
  282. go ssh.DiscardRequests(result.requests)
  283. sshClient.handleChannels(result.channels)
  284. }
  285. type sshClient struct {
  286. sync.Mutex
  287. sshServer *sshServer
  288. sshConn ssh.Conn
  289. startTime time.Time
  290. geoIPData GeoIPData
  291. psiphonSessionID string
  292. udpChannel ssh.Channel
  293. trafficRules TrafficRules
  294. tcpTrafficState *trafficState
  295. udpTrafficState *trafficState
  296. }
  297. type trafficState struct {
  298. bytesUp int64
  299. bytesDown int64
  300. portForwardCount int64
  301. concurrentPortForwardCount int64
  302. peakConcurrentPortForwardCount int64
  303. }
  304. func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
  305. for newChannel := range channels {
  306. if newChannel.ChannelType() != "direct-tcpip" {
  307. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  308. return
  309. }
  310. // process each port forward concurrently
  311. go sshClient.handleNewPortForwardChannel(newChannel)
  312. }
  313. }
  314. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  315. // TODO: log more details?
  316. log.WithContextFields(
  317. LogFields{
  318. "channelType": newChannel.ChannelType(),
  319. "rejectMessage": message,
  320. "rejectReason": reason,
  321. }).Warning("reject new channel")
  322. newChannel.Reject(reason, message)
  323. }
  324. func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
  325. // http://tools.ietf.org/html/rfc4254#section-7.2
  326. var directTcpipExtraData struct {
  327. HostToConnect string
  328. PortToConnect uint32
  329. OriginatorIPAddress string
  330. OriginatorPort uint32
  331. }
  332. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  333. if err != nil {
  334. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  335. return
  336. }
  337. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  338. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  339. isUDPChannel := sshClient.sshServer.config.UdpgwServerAddress != "" &&
  340. sshClient.sshServer.config.UdpgwServerAddress ==
  341. fmt.Sprintf("%s:%d",
  342. directTcpipExtraData.HostToConnect,
  343. directTcpipExtraData.PortToConnect)
  344. if isUDPChannel {
  345. sshClient.handleUDPChannel(newChannel)
  346. } else {
  347. sshClient.handleTCPChannel(
  348. directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
  349. }
  350. }
  351. func (sshClient *sshClient) isPortForwardPermitted(
  352. port int, allowPorts []int, denyPorts []int) bool {
  353. // TODO: faster lookup?
  354. if allowPorts != nil {
  355. for _, allowPort := range allowPorts {
  356. if port == allowPort {
  357. return true
  358. }
  359. }
  360. return false
  361. }
  362. if denyPorts != nil {
  363. for _, denyPort := range denyPorts {
  364. if port == denyPort {
  365. return false
  366. }
  367. }
  368. }
  369. return true
  370. }
  371. func (sshClient *sshClient) isPortForwardLimitExceeded(
  372. state *trafficState, maxPortForwardCount int) bool {
  373. limitExceeded := false
  374. if maxPortForwardCount > 0 {
  375. sshClient.Lock()
  376. limitExceeded = state.portForwardCount >= int64(maxPortForwardCount)
  377. sshClient.Unlock()
  378. }
  379. return limitExceeded
  380. }
  381. func (sshClient *sshClient) establishedPortForward(
  382. state *trafficState) {
  383. sshClient.Lock()
  384. state.portForwardCount += 1
  385. state.concurrentPortForwardCount += 1
  386. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  387. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  388. }
  389. sshClient.Unlock()
  390. }
  391. func (sshClient *sshClient) closedPortForward(
  392. state *trafficState, bytesUp, bytesDown int64) {
  393. sshClient.Lock()
  394. state.concurrentPortForwardCount -= 1
  395. state.bytesUp += bytesUp
  396. state.bytesDown += bytesDown
  397. sshClient.Unlock()
  398. }
  399. func (sshClient *sshClient) handleTCPChannel(
  400. hostToConnect string,
  401. portToConnect int,
  402. newChannel ssh.NewChannel) {
  403. if !sshClient.isPortForwardPermitted(
  404. portToConnect,
  405. sshClient.trafficRules.AllowTCPPorts,
  406. sshClient.trafficRules.DenyTCPPorts) {
  407. sshClient.rejectNewChannel(
  408. newChannel, ssh.Prohibited, "port forward not permitted")
  409. return
  410. }
  411. // TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
  412. if sshClient.isPortForwardLimitExceeded(
  413. sshClient.tcpTrafficState,
  414. sshClient.trafficRules.MaxTCPPortForwardCount) {
  415. sshClient.rejectNewChannel(
  416. newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
  417. return
  418. }
  419. targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
  420. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  421. // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
  422. // TODO: port forward dial timeout
  423. // TODO: IPv6 support
  424. fwdConn, err := net.Dial("tcp4", targetAddr)
  425. if err != nil {
  426. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  427. return
  428. }
  429. defer fwdConn.Close()
  430. fwdChannel, requests, err := newChannel.Accept()
  431. if err != nil {
  432. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  433. return
  434. }
  435. go ssh.DiscardRequests(requests)
  436. defer fwdChannel.Close()
  437. sshClient.establishedPortForward(sshClient.tcpTrafficState)
  438. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  439. // When idle port forward traffic rules are in place, wrap fwdConn
  440. // in an IdleTimeoutConn configured to reset idle on writes as well
  441. // as read. This ensures the port forward idle timeout only happens
  442. // when both upstream and downstream directions are are idle.
  443. if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
  444. fwdConn = psiphon.NewIdleTimeoutConn(
  445. fwdConn,
  446. time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
  447. true)
  448. }
  449. // relay channel to forwarded connection
  450. // TODO: relay errors to fwdChannel.Stderr()?
  451. // TODO: use a low-memory io.Copy?
  452. var bytesUp, bytesDown int64
  453. relayWaitGroup := new(sync.WaitGroup)
  454. relayWaitGroup.Add(1)
  455. go func() {
  456. defer relayWaitGroup.Done()
  457. var err error
  458. bytesUp, err = io.Copy(fwdConn, fwdChannel)
  459. if err != nil {
  460. log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
  461. }
  462. }()
  463. bytesDown, err = io.Copy(fwdChannel, fwdConn)
  464. if err != nil {
  465. log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
  466. }
  467. fwdChannel.CloseWrite()
  468. relayWaitGroup.Wait()
  469. sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
  470. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  471. }
  472. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  473. var sshPasswordPayload struct {
  474. SessionId string `json:"SessionId"`
  475. SshPassword string `json:"SshPassword"`
  476. }
  477. err := json.Unmarshal(password, &sshPasswordPayload)
  478. if err != nil {
  479. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  480. }
  481. userOk := (subtle.ConstantTimeCompare(
  482. []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
  483. passwordOk := (subtle.ConstantTimeCompare(
  484. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
  485. if !userOk || !passwordOk {
  486. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  487. }
  488. psiphonSessionID := sshPasswordPayload.SessionId
  489. sshClient.Lock()
  490. sshClient.psiphonSessionID = psiphonSessionID
  491. geoIPData := sshClient.geoIPData
  492. sshClient.Unlock()
  493. if sshClient.sshServer.config.UseRedis() {
  494. err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
  495. if err != nil {
  496. log.WithContextFields(LogFields{
  497. "psiphonSessionID": psiphonSessionID,
  498. "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
  499. // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
  500. }
  501. }
  502. return nil, nil
  503. }
  504. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  505. if err != nil {
  506. if sshClient.sshServer.config.UseFail2Ban() {
  507. clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
  508. if clientIPAddress != "" {
  509. LogFail2Ban(clientIPAddress)
  510. }
  511. }
  512. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  513. } else {
  514. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  515. }
  516. }