sshService.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  30. "golang.org/x/crypto/ssh"
  31. )
  32. // RunSSHServer runs an ssh server with plain SSH protocol.
  33. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  34. return runSSHServer(config, false, shutdownBroadcast)
  35. }
  36. // RunSSHServer runs an ssh server with Obfuscated SSH protocol.
  37. func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  38. return runSSHServer(config, true, shutdownBroadcast)
  39. }
  40. // runSSHServer runs an SSH or Obfuscated SSH server. In the Obfuscated SSH case, an
  41. // ObfuscatedSSHConn is layered in front of the client TCP connection; otherwise, both
  42. // modes are identical.
  43. //
  44. // runSSHServer listens on the designated port and spawns new goroutines to handle
  45. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  46. // clients is maintained, and when halting all clients are first shutdown.
  47. //
  48. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  49. // authentication, and then looping on client new channel requests. At this time, only
  50. // "direct-tcpip" channels, dynamic port fowards, are expected and supported.
  51. //
  52. // A new goroutine is spawned to handle each port forward. Each port forward tracks its
  53. // bytes transferred. Overall per-client stats for connection duration, GeoIP, number of
  54. // port forwards, and bytes transferred are tracked and logged when the client shuts down.
  55. func runSSHServer(
  56. config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
  57. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  58. if err != nil {
  59. return psiphon.ContextError(err)
  60. }
  61. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  62. signer, err := ssh.NewSignerFromKey(privateKey)
  63. if err != nil {
  64. return psiphon.ContextError(err)
  65. }
  66. sshServer := &sshServer{
  67. config: config,
  68. useObfuscation: useObfuscation,
  69. shutdownBroadcast: shutdownBroadcast,
  70. sshHostKey: signer,
  71. nextClientID: 1,
  72. clients: make(map[sshClientID]*sshClient),
  73. }
  74. var serverPort int
  75. if useObfuscation {
  76. serverPort = config.ObfuscatedSSHServerPort
  77. } else {
  78. serverPort = config.SSHServerPort
  79. }
  80. listener, err := net.Listen(
  81. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
  82. if err != nil {
  83. return psiphon.ContextError(err)
  84. }
  85. log.WithContextFields(
  86. LogFields{
  87. "useObfuscation": useObfuscation,
  88. "port": serverPort,
  89. }).Info("starting")
  90. err = nil
  91. errors := make(chan error)
  92. waitGroup := new(sync.WaitGroup)
  93. waitGroup.Add(1)
  94. go func() {
  95. defer waitGroup.Done()
  96. loop:
  97. for {
  98. conn, err := listener.Accept()
  99. select {
  100. case <-shutdownBroadcast:
  101. if err == nil {
  102. conn.Close()
  103. }
  104. break loop
  105. default:
  106. }
  107. if err != nil {
  108. if e, ok := err.(net.Error); ok && e.Temporary() {
  109. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  110. // Temporary error, keep running
  111. continue
  112. }
  113. select {
  114. case errors <- psiphon.ContextError(err):
  115. default:
  116. }
  117. break loop
  118. }
  119. // process each client connection concurrently
  120. go sshServer.handleClient(conn.(*net.TCPConn))
  121. }
  122. sshServer.stopClients()
  123. log.WithContextFields(
  124. LogFields{"useObfuscation": useObfuscation}).Info("stopped")
  125. }()
  126. select {
  127. case <-shutdownBroadcast:
  128. case err = <-errors:
  129. }
  130. listener.Close()
  131. waitGroup.Wait()
  132. log.WithContextFields(
  133. LogFields{"useObfuscation": useObfuscation}).Info("exiting")
  134. return err
  135. }
  136. type sshClientID uint64
  137. type sshServer struct {
  138. config *Config
  139. useObfuscation bool
  140. shutdownBroadcast <-chan struct{}
  141. sshHostKey ssh.Signer
  142. nextClientID sshClientID
  143. clientsMutex sync.Mutex
  144. stoppingClients bool
  145. clients map[sshClientID]*sshClient
  146. }
  147. func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
  148. sshServer.clientsMutex.Lock()
  149. defer sshServer.clientsMutex.Unlock()
  150. if sshServer.stoppingClients {
  151. return 0, false
  152. }
  153. clientID := sshServer.nextClientID
  154. sshServer.nextClientID += 1
  155. sshServer.clients[clientID] = client
  156. return clientID, true
  157. }
  158. func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
  159. sshServer.clientsMutex.Lock()
  160. client := sshServer.clients[clientID]
  161. delete(sshServer.clients, clientID)
  162. sshServer.clientsMutex.Unlock()
  163. if client != nil {
  164. sshServer.stopClient(client)
  165. }
  166. }
  167. func (sshServer *sshServer) stopClient(client *sshClient) {
  168. client.sshConn.Close()
  169. client.sshConn.Wait()
  170. client.Lock()
  171. log.WithContextFields(
  172. LogFields{
  173. "startTime": client.startTime,
  174. "duration": time.Now().Sub(client.startTime),
  175. "psiphonSessionID": client.psiphonSessionID,
  176. "country": client.geoIPData.Country,
  177. "city": client.geoIPData.City,
  178. "ISP": client.geoIPData.ISP,
  179. "bytesUp": client.bytesUp,
  180. "bytesDown": client.bytesDown,
  181. "portForwardCount": client.portForwardCount,
  182. "peakConcurrentPortForwardCount": client.peakConcurrentPortForwardCount,
  183. }).Info("tunnel closed")
  184. client.Unlock()
  185. }
  186. func (sshServer *sshServer) stopClients() {
  187. sshServer.clientsMutex.Lock()
  188. sshServer.stoppingClients = true
  189. sshServer.clients = make(map[sshClientID]*sshClient)
  190. sshServer.clientsMutex.Unlock()
  191. for _, client := range sshServer.clients {
  192. sshServer.stopClient(client)
  193. }
  194. }
  195. func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
  196. sshClient := &sshClient{
  197. sshServer: sshServer,
  198. startTime: time.Now(),
  199. geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
  200. }
  201. sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
  202. // Wrap the base TCP connection with an IdleTimeoutConn which will terminate
  203. // the connection if no data is received before the deadline. This timeout is
  204. // in effect for the entire duration of the SSH connection. Clients must actively
  205. // use the connection or send SSH keep alive requests to keep the connection
  206. // active.
  207. conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
  208. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  209. // respect shutdownBroadcast and implement a specific handshake timeout.
  210. // The timeout is to reclaim network resources in case the handshake takes
  211. // too long.
  212. type sshNewServerConnResult struct {
  213. conn net.Conn
  214. sshConn *ssh.ServerConn
  215. channels <-chan ssh.NewChannel
  216. requests <-chan *ssh.Request
  217. err error
  218. }
  219. resultChannel := make(chan *sshNewServerConnResult, 2)
  220. if SSH_HANDSHAKE_TIMEOUT > 0 {
  221. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  222. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  223. })
  224. }
  225. go func() {
  226. result := &sshNewServerConnResult{}
  227. if sshServer.useObfuscation {
  228. result.conn, result.err = psiphon.NewObfuscatedSshConn(
  229. psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
  230. } else {
  231. result.conn = conn
  232. }
  233. if result.err == nil {
  234. sshServerConfig := &ssh.ServerConfig{
  235. PasswordCallback: sshClient.passwordCallback,
  236. AuthLogCallback: sshClient.authLogCallback,
  237. ServerVersion: sshServer.config.SSHServerVersion,
  238. }
  239. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  240. result.sshConn, result.channels, result.requests, result.err =
  241. ssh.NewServerConn(result.conn, sshServerConfig)
  242. }
  243. resultChannel <- result
  244. }()
  245. var result *sshNewServerConnResult
  246. select {
  247. case result = <-resultChannel:
  248. case <-sshServer.shutdownBroadcast:
  249. // Close() will interrupt an ongoing handshake
  250. // TODO: wait for goroutine to exit before returning?
  251. conn.Close()
  252. return
  253. }
  254. if result.err != nil {
  255. conn.Close()
  256. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  257. return
  258. }
  259. sshClient.Lock()
  260. sshClient.sshConn = result.sshConn
  261. sshClient.Unlock()
  262. clientID, ok := sshServer.registerClient(sshClient)
  263. if !ok {
  264. conn.Close()
  265. log.WithContext().Warning("register failed")
  266. return
  267. }
  268. defer sshServer.unregisterClient(clientID)
  269. go ssh.DiscardRequests(result.requests)
  270. sshClient.handleChannels(result.channels)
  271. }
  272. type sshClient struct {
  273. sync.Mutex
  274. sshServer *sshServer
  275. sshConn ssh.Conn
  276. startTime time.Time
  277. geoIPData GeoIPData
  278. trafficRules TrafficRules
  279. psiphonSessionID string
  280. bytesUp int64
  281. bytesDown int64
  282. portForwardCount int64
  283. concurrentPortForwardCount int64
  284. peakConcurrentPortForwardCount int64
  285. }
  286. func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
  287. for newChannel := range channels {
  288. if newChannel.ChannelType() != "direct-tcpip" {
  289. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  290. return
  291. }
  292. if sshClient.trafficRules.MaxClientPortForwardCount > 0 {
  293. sshClient.Lock()
  294. limitExceeded := sshClient.portForwardCount >= int64(sshClient.trafficRules.MaxClientPortForwardCount)
  295. sshClient.Unlock()
  296. if limitExceeded {
  297. sshClient.rejectNewChannel(
  298. newChannel, ssh.ResourceShortage, "maximum port forward limit exceeded")
  299. return
  300. }
  301. }
  302. // process each port forward concurrently
  303. go sshClient.handleNewDirectTcpipChannel(newChannel)
  304. }
  305. }
  306. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  307. // TODO: log more details?
  308. log.WithContextFields(
  309. LogFields{
  310. "channelType": newChannel.ChannelType(),
  311. "rejectMessage": message,
  312. "rejectReason": reason,
  313. }).Warning("reject new channel")
  314. newChannel.Reject(reason, message)
  315. }
  316. func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
  317. // http://tools.ietf.org/html/rfc4254#section-7.2
  318. var directTcpipExtraData struct {
  319. HostToConnect string
  320. PortToConnect uint32
  321. OriginatorIPAddress string
  322. OriginatorPort uint32
  323. }
  324. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  325. if err != nil {
  326. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  327. return
  328. }
  329. targetAddr := fmt.Sprintf("%s:%d",
  330. directTcpipExtraData.HostToConnect,
  331. directTcpipExtraData.PortToConnect)
  332. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  333. // TODO: port forward dial timeout
  334. // TODO: report ssh.ResourceShortage when appropriate
  335. // TODO: IPv6 support
  336. fwdConn, err := net.Dial("tcp4", targetAddr)
  337. if err != nil {
  338. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  339. return
  340. }
  341. defer fwdConn.Close()
  342. fwdChannel, requests, err := newChannel.Accept()
  343. if err != nil {
  344. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  345. return
  346. }
  347. sshClient.Lock()
  348. sshClient.portForwardCount += 1
  349. sshClient.concurrentPortForwardCount += 1
  350. if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount {
  351. sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
  352. }
  353. sshClient.Unlock()
  354. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  355. go ssh.DiscardRequests(requests)
  356. defer fwdChannel.Close()
  357. // When idle port forward traffic rules are in place, wrap fwdConn
  358. // in an IdleTimeoutConn configured to reset idle on writes as well
  359. // as read. This ensures the port forward idle timeout only happens
  360. // when both upstream and downstream directions are are idle.
  361. if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
  362. fwdConn = psiphon.NewIdleTimeoutConn(
  363. fwdConn,
  364. time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
  365. true)
  366. }
  367. // relay channel to forwarded connection
  368. // TODO: relay errors to fwdChannel.Stderr()?
  369. var bytesUp, bytesDown int64
  370. relayWaitGroup := new(sync.WaitGroup)
  371. relayWaitGroup.Add(1)
  372. go func() {
  373. defer relayWaitGroup.Done()
  374. var err error
  375. bytesUp, err = copyWithThrottle(
  376. fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds)
  377. if err != nil {
  378. log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
  379. }
  380. }()
  381. bytesDown, err = copyWithThrottle(
  382. fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds)
  383. if err != nil {
  384. log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
  385. }
  386. fwdChannel.CloseWrite()
  387. relayWaitGroup.Wait()
  388. sshClient.Lock()
  389. sshClient.concurrentPortForwardCount -= 1
  390. sshClient.bytesUp += bytesUp
  391. sshClient.bytesDown += bytesDown
  392. sshClient.Unlock()
  393. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  394. }
  395. func copyWithThrottle(dst io.Writer, src io.Reader, throttleSleepMilliseconds int) (int64, error) {
  396. // TODO: use a low-memory io.Copy?
  397. if throttleSleepMilliseconds <= 0 {
  398. // No throttle
  399. return io.Copy(dst, src)
  400. }
  401. var totalBytes int64
  402. for {
  403. bytes, err := io.CopyN(dst, src, SSH_THROTTLED_PORT_FORWARD_MAX_COPY)
  404. totalBytes += bytes
  405. if err == io.EOF {
  406. err = nil
  407. break
  408. }
  409. if err != nil {
  410. return totalBytes, psiphon.ContextError(err)
  411. }
  412. time.Sleep(time.Duration(throttleSleepMilliseconds) * time.Millisecond)
  413. }
  414. return totalBytes, nil
  415. }
  416. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  417. var sshPasswordPayload struct {
  418. SessionId string `json:"SessionId"`
  419. SshPassword string `json:"SshPassword"`
  420. }
  421. err := json.Unmarshal(password, &sshPasswordPayload)
  422. if err != nil {
  423. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  424. }
  425. userOk := (subtle.ConstantTimeCompare(
  426. []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
  427. passwordOk := (subtle.ConstantTimeCompare(
  428. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
  429. if !userOk || !passwordOk {
  430. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  431. }
  432. psiphonSessionID := sshPasswordPayload.SessionId
  433. sshClient.Lock()
  434. sshClient.psiphonSessionID = psiphonSessionID
  435. geoIPData := sshClient.geoIPData
  436. sshClient.Unlock()
  437. if sshClient.sshServer.config.UseRedis() {
  438. err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
  439. if err != nil {
  440. log.WithContextFields(LogFields{
  441. "psiphonSessionID": psiphonSessionID,
  442. "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
  443. // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
  444. }
  445. }
  446. return nil, nil
  447. }
  448. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  449. if err != nil {
  450. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  451. } else {
  452. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  453. }
  454. }