sshService.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  30. "golang.org/x/crypto/ssh"
  31. )
  32. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  33. return runSSHServer(config, false, shutdownBroadcast)
  34. }
  35. func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  36. return runSSHServer(config, true, shutdownBroadcast)
  37. }
  38. func runSSHServer(
  39. config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
  40. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  41. if err != nil {
  42. return psiphon.ContextError(err)
  43. }
  44. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  45. signer, err := ssh.NewSignerFromKey(privateKey)
  46. if err != nil {
  47. return psiphon.ContextError(err)
  48. }
  49. sshServer := &sshServer{
  50. config: config,
  51. useObfuscation: useObfuscation,
  52. shutdownBroadcast: shutdownBroadcast,
  53. sshHostKey: signer,
  54. nextClientID: 1,
  55. clients: make(map[sshClientID]*sshClient),
  56. }
  57. var serverPort int
  58. if useObfuscation {
  59. serverPort = config.ObfuscatedSSHServerPort
  60. } else {
  61. serverPort = config.SSHServerPort
  62. }
  63. listener, err := net.Listen(
  64. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
  65. if err != nil {
  66. return psiphon.ContextError(err)
  67. }
  68. log.WithContextFields(
  69. LogFields{
  70. "useObfuscation": useObfuscation,
  71. "port": serverPort,
  72. }).Info("starting")
  73. err = nil
  74. errors := make(chan error)
  75. waitGroup := new(sync.WaitGroup)
  76. waitGroup.Add(1)
  77. go func() {
  78. defer waitGroup.Done()
  79. loop:
  80. for {
  81. conn, err := listener.Accept()
  82. select {
  83. case <-shutdownBroadcast:
  84. if err == nil {
  85. conn.Close()
  86. }
  87. break loop
  88. default:
  89. }
  90. if err != nil {
  91. if e, ok := err.(net.Error); ok && e.Temporary() {
  92. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  93. // Temporary error, keep running
  94. continue
  95. }
  96. select {
  97. case errors <- psiphon.ContextError(err):
  98. default:
  99. }
  100. break loop
  101. }
  102. // process each client connection concurrently
  103. go sshServer.handleClient(conn.(*net.TCPConn))
  104. }
  105. sshServer.stopClients()
  106. log.WithContextFields(
  107. LogFields{"useObfuscation": useObfuscation}).Info("stopped")
  108. }()
  109. select {
  110. case <-shutdownBroadcast:
  111. case err = <-errors:
  112. }
  113. listener.Close()
  114. waitGroup.Wait()
  115. log.WithContextFields(
  116. LogFields{"useObfuscation": useObfuscation}).Info("exiting")
  117. return err
  118. }
  119. type sshClientID uint64
  120. type sshServer struct {
  121. config *Config
  122. useObfuscation bool
  123. shutdownBroadcast <-chan struct{}
  124. sshHostKey ssh.Signer
  125. nextClientID sshClientID
  126. clientsMutex sync.Mutex
  127. stoppingClients bool
  128. clients map[sshClientID]*sshClient
  129. }
  130. func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
  131. sshServer.clientsMutex.Lock()
  132. defer sshServer.clientsMutex.Unlock()
  133. if sshServer.stoppingClients {
  134. return 0, false
  135. }
  136. clientID := sshServer.nextClientID
  137. sshServer.nextClientID += 1
  138. sshServer.clients[clientID] = client
  139. return clientID, true
  140. }
  141. func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
  142. sshServer.clientsMutex.Lock()
  143. client := sshServer.clients[clientID]
  144. delete(sshServer.clients, clientID)
  145. sshServer.clientsMutex.Unlock()
  146. if client != nil {
  147. sshServer.stopClient(client)
  148. }
  149. }
  150. func (sshServer *sshServer) stopClient(client *sshClient) {
  151. client.sshConn.Close()
  152. client.sshConn.Wait()
  153. client.Lock()
  154. log.WithContextFields(
  155. LogFields{
  156. "startTime": client.startTime,
  157. "duration": time.Now().Sub(client.startTime),
  158. "psiphonSessionID": client.psiphonSessionID,
  159. "country": client.geoIPData.Country,
  160. "city": client.geoIPData.City,
  161. "ISP": client.geoIPData.ISP,
  162. "bytesUp": client.bytesUp,
  163. "bytesDown": client.bytesDown,
  164. "portForwardCount": client.portForwardCount,
  165. "maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
  166. }).Info("tunnel closed")
  167. client.Unlock()
  168. }
  169. func (sshServer *sshServer) stopClients() {
  170. sshServer.clientsMutex.Lock()
  171. sshServer.stoppingClients = true
  172. sshServer.clients = make(map[sshClientID]*sshClient)
  173. sshServer.clientsMutex.Unlock()
  174. for _, client := range sshServer.clients {
  175. sshServer.stopClient(client)
  176. }
  177. }
  178. func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
  179. sshClient := &sshClient{
  180. sshServer: sshServer,
  181. startTime: time.Now(),
  182. geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
  183. }
  184. // Wrap the base TCP connection in a TimeoutTCPConn which will terminate
  185. // the connection if it's idle for too long. This timeout is in effect for
  186. // the entire duration of the SSH connection. Clients must actively use
  187. // the connection or send SSH keep alive requests to keep the connection
  188. // active.
  189. conn := psiphon.NewTimeoutTCPConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
  190. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  191. // respect shutdownBroadcast and implement a specific handshake timeout.
  192. // The timeout is to reclaim network resources in case the handshake takes
  193. // too long.
  194. type sshNewServerConnResult struct {
  195. conn net.Conn
  196. sshConn *ssh.ServerConn
  197. channels <-chan ssh.NewChannel
  198. requests <-chan *ssh.Request
  199. err error
  200. }
  201. resultChannel := make(chan *sshNewServerConnResult, 2)
  202. if SSH_HANDSHAKE_TIMEOUT > 0 {
  203. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  204. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  205. })
  206. }
  207. go func() {
  208. result := &sshNewServerConnResult{}
  209. if sshServer.useObfuscation {
  210. result.conn, result.err = psiphon.NewObfuscatedSshConn(
  211. psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
  212. } else {
  213. result.conn = conn
  214. }
  215. if result.err == nil {
  216. sshServerConfig := &ssh.ServerConfig{
  217. PasswordCallback: sshClient.passwordCallback,
  218. AuthLogCallback: sshClient.authLogCallback,
  219. ServerVersion: sshServer.config.SSHServerVersion,
  220. }
  221. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  222. result.sshConn, result.channels, result.requests, result.err =
  223. ssh.NewServerConn(result.conn, sshServerConfig)
  224. }
  225. resultChannel <- result
  226. }()
  227. var result *sshNewServerConnResult
  228. select {
  229. case result = <-resultChannel:
  230. case <-sshServer.shutdownBroadcast:
  231. // Close() will interrupt an ongoing handshake
  232. // TODO: wait for goroutine to exit before returning?
  233. conn.Close()
  234. return
  235. }
  236. if result.err != nil {
  237. conn.Close()
  238. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  239. return
  240. }
  241. sshClient.Lock()
  242. sshClient.sshConn = result.sshConn
  243. sshClient.Unlock()
  244. clientID, ok := sshServer.registerClient(sshClient)
  245. if !ok {
  246. tcpConn.Close()
  247. log.WithContext().Warning("register failed")
  248. return
  249. }
  250. defer sshServer.unregisterClient(clientID)
  251. go ssh.DiscardRequests(result.requests)
  252. sshClient.handleChannels(result.channels)
  253. }
  254. type sshClient struct {
  255. sync.Mutex
  256. sshServer *sshServer
  257. sshConn ssh.Conn
  258. startTime time.Time
  259. geoIPData GeoIPData
  260. psiphonSessionID string
  261. bytesUp int64
  262. bytesDown int64
  263. portForwardCount int64
  264. concurrentPortForwardCount int64
  265. maxConcurrentPortForwardCount int64
  266. }
  267. func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
  268. for newChannel := range channels {
  269. if newChannel.ChannelType() != "direct-tcpip" {
  270. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  271. return
  272. }
  273. // process each port forward concurrently
  274. go sshClient.handleNewDirectTcpipChannel(newChannel)
  275. }
  276. }
  277. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  278. // TODO: log more details?
  279. log.WithContextFields(
  280. LogFields{
  281. "channelType": newChannel.ChannelType(),
  282. "rejectMessage": message,
  283. "rejectReason": reason,
  284. }).Warning("reject new channel")
  285. newChannel.Reject(reason, message)
  286. }
  287. func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
  288. // http://tools.ietf.org/html/rfc4254#section-7.2
  289. var directTcpipExtraData struct {
  290. HostToConnect string
  291. PortToConnect uint32
  292. OriginatorIPAddress string
  293. OriginatorPort uint32
  294. }
  295. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  296. if err != nil {
  297. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  298. return
  299. }
  300. targetAddr := fmt.Sprintf("%s:%d",
  301. directTcpipExtraData.HostToConnect,
  302. directTcpipExtraData.PortToConnect)
  303. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  304. // TODO: port forward dial timeout
  305. // TODO: report ssh.ResourceShortage when appropriate
  306. // TODO: IPv6 support
  307. fwdConn, err := net.Dial("tcp4", targetAddr)
  308. if err != nil {
  309. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  310. return
  311. }
  312. defer fwdConn.Close()
  313. fwdChannel, requests, err := newChannel.Accept()
  314. if err != nil {
  315. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  316. return
  317. }
  318. sshClient.Lock()
  319. sshClient.portForwardCount += 1
  320. sshClient.concurrentPortForwardCount += 1
  321. if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
  322. sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
  323. }
  324. sshClient.Unlock()
  325. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  326. go ssh.DiscardRequests(requests)
  327. defer fwdChannel.Close()
  328. // relay channel to forwarded connection
  329. // TODO: use a low-memory io.Copy?
  330. // TODO: relay errors to fwdChannel.Stderr()?
  331. var bytesUp, bytesDown int64
  332. relayWaitGroup := new(sync.WaitGroup)
  333. relayWaitGroup.Add(1)
  334. go func() {
  335. defer relayWaitGroup.Done()
  336. var err error
  337. bytesUp, err = io.Copy(fwdConn, fwdChannel)
  338. if err != nil {
  339. log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
  340. }
  341. }()
  342. bytesDown, err = io.Copy(fwdChannel, fwdConn)
  343. if err != nil {
  344. log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
  345. }
  346. fwdChannel.CloseWrite()
  347. relayWaitGroup.Wait()
  348. sshClient.Lock()
  349. sshClient.concurrentPortForwardCount -= 1
  350. sshClient.bytesUp += bytesUp
  351. sshClient.bytesDown += bytesDown
  352. sshClient.Unlock()
  353. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  354. }
  355. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  356. var sshPasswordPayload struct {
  357. SessionId string `json:"SessionId"`
  358. SshPassword string `json:"SshPassword"`
  359. }
  360. err := json.Unmarshal(password, &sshPasswordPayload)
  361. if err != nil {
  362. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  363. }
  364. userOk := (subtle.ConstantTimeCompare(
  365. []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
  366. passwordOk := (subtle.ConstantTimeCompare(
  367. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
  368. if !userOk || !passwordOk {
  369. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  370. }
  371. psiphonSessionID := sshPasswordPayload.SessionId
  372. sshClient.Lock()
  373. sshClient.psiphonSessionID = psiphonSessionID
  374. geoIPData := sshClient.geoIPData
  375. sshClient.Unlock()
  376. if sshClient.sshServer.config.UseRedis() {
  377. err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
  378. if err != nil {
  379. log.WithContextFields(LogFields{
  380. "psiphonSessionID": psiphonSessionID,
  381. "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
  382. // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
  383. }
  384. }
  385. return nil, nil
  386. }
  387. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  388. if err != nil {
  389. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  390. } else {
  391. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  392. }
  393. }