sshService.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  30. "golang.org/x/crypto/ssh"
  31. )
  32. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  33. return runSSHServer(config, false, shutdownBroadcast)
  34. }
  35. func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  36. return runSSHServer(config, true, shutdownBroadcast)
  37. }
  38. type sshServer struct {
  39. config *Config
  40. useObfuscation bool
  41. shutdownBroadcast <-chan struct{}
  42. sshConfig *ssh.ServerConfig
  43. clientsMutex sync.Mutex
  44. stoppingClients bool
  45. clients map[string]*sshClient
  46. }
  47. type sshClient struct {
  48. sync.Mutex
  49. sshConn ssh.Conn
  50. startTime time.Time
  51. geoIPData GeoIPData
  52. psiphonSessionID string
  53. bytesUp int64
  54. bytesDown int64
  55. portForwardCount int64
  56. concurrentPortForwardCount int64
  57. maxConcurrentPortForwardCount int64
  58. }
  59. func runSSHServer(
  60. config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
  61. sshServer := &sshServer{
  62. config: config,
  63. useObfuscation: useObfuscation,
  64. shutdownBroadcast: shutdownBroadcast,
  65. clients: make(map[string]*sshClient),
  66. }
  67. sshServer.sshConfig = &ssh.ServerConfig{
  68. PasswordCallback: sshServer.passwordCallback,
  69. AuthLogCallback: sshServer.authLogCallback,
  70. ServerVersion: config.SSHServerVersion,
  71. }
  72. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  73. if err != nil {
  74. return psiphon.ContextError(err)
  75. }
  76. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  77. signer, err := ssh.NewSignerFromKey(privateKey)
  78. if err != nil {
  79. return psiphon.ContextError(err)
  80. }
  81. sshServer.sshConfig.AddHostKey(signer)
  82. var serverPort int
  83. if useObfuscation {
  84. serverPort = config.ObfuscatedSSHServerPort
  85. } else {
  86. serverPort = config.SSHServerPort
  87. }
  88. listener, err := net.Listen(
  89. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
  90. if err != nil {
  91. return psiphon.ContextError(err)
  92. }
  93. log.WithContextFields(
  94. LogFields{
  95. "useObfuscation": useObfuscation,
  96. "port": serverPort,
  97. }).Info("starting")
  98. err = nil
  99. errors := make(chan error)
  100. waitGroup := new(sync.WaitGroup)
  101. waitGroup.Add(1)
  102. go func() {
  103. defer waitGroup.Done()
  104. loop:
  105. for {
  106. conn, err := listener.Accept()
  107. select {
  108. case <-shutdownBroadcast:
  109. if err == nil {
  110. conn.Close()
  111. }
  112. break loop
  113. default:
  114. }
  115. if err != nil {
  116. if e, ok := err.(net.Error); ok && e.Temporary() {
  117. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  118. // Temporary error, keep running
  119. continue
  120. }
  121. select {
  122. case errors <- psiphon.ContextError(err):
  123. default:
  124. }
  125. break loop
  126. }
  127. // process each client connection concurrently
  128. go sshServer.handleClient(conn)
  129. }
  130. sshServer.stopClients()
  131. log.WithContextFields(
  132. LogFields{"useObfuscation": useObfuscation}).Info("stopped")
  133. }()
  134. select {
  135. case <-shutdownBroadcast:
  136. case err = <-errors:
  137. }
  138. listener.Close()
  139. waitGroup.Wait()
  140. log.WithContextFields(
  141. LogFields{"useObfuscation": useObfuscation}).Info("exiting")
  142. return err
  143. }
  144. func (sshServer *sshServer) registerClient(client *sshClient) (string, bool) {
  145. sshServer.clientsMutex.Lock()
  146. defer sshServer.clientsMutex.Unlock()
  147. if sshServer.stoppingClients {
  148. return "", false
  149. }
  150. key := string(client.sshConn.SessionID())
  151. existingClient := sshServer.clients[key]
  152. if existingClient != nil {
  153. log.WithContext().Warning("unexpected existing connection")
  154. client.sshConn.Close()
  155. client.sshConn.Wait()
  156. }
  157. sshServer.clients[key] = client
  158. return key, true
  159. }
  160. func (sshServer *sshServer) updateClient(
  161. clientKey string, updater func(*sshClient)) {
  162. sshServer.clientsMutex.Lock()
  163. sshClient, ok := sshServer.clients[clientKey]
  164. sshServer.clientsMutex.Unlock()
  165. if ok {
  166. sshClient.Lock()
  167. updater(sshClient)
  168. sshClient.Unlock()
  169. }
  170. }
  171. func (sshServer *sshServer) unregisterClient(clientKey string) {
  172. sshServer.clientsMutex.Lock()
  173. if sshServer.stoppingClients {
  174. return
  175. }
  176. client := sshServer.clients[clientKey]
  177. delete(sshServer.clients, clientKey)
  178. sshServer.clientsMutex.Unlock()
  179. if client == nil {
  180. return
  181. }
  182. client.sshConn.Close()
  183. client.Lock()
  184. log.WithContextFields(
  185. LogFields{
  186. "startTime": client.startTime,
  187. "duration": time.Now().Sub(client.startTime),
  188. "psiphonSessionID": client.psiphonSessionID,
  189. "country": client.geoIPData.Country,
  190. "city": client.geoIPData.City,
  191. "ISP": client.geoIPData.ISP,
  192. "bytesUp": client.bytesUp,
  193. "bytesDown": client.bytesDown,
  194. "portForwardCount": client.portForwardCount,
  195. "maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
  196. }).Info("tunnel closed")
  197. client.Unlock()
  198. }
  199. func (sshServer *sshServer) stopClients() {
  200. sshServer.clientsMutex.Lock()
  201. sshServer.stoppingClients = true
  202. sshServer.clientsMutex.Unlock()
  203. for _, client := range sshServer.clients {
  204. client.sshConn.Close()
  205. client.sshConn.Wait()
  206. }
  207. sshServer.clients = make(map[string]*sshClient)
  208. }
  209. func (sshServer *sshServer) handleClient(conn net.Conn) {
  210. startTime := time.Now()
  211. geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
  212. // Run the initial [obfuscated] SSH handshake in a goroutine
  213. // so we can both respect shutdownBroadcast and implement a
  214. // handshake timeout. The timeout is to reclaim network
  215. // resources in case the handshake takes too long.
  216. type sshNewServerConnResult struct {
  217. conn net.Conn
  218. sshConn *ssh.ServerConn
  219. channels <-chan ssh.NewChannel
  220. requests <-chan *ssh.Request
  221. err error
  222. }
  223. resultChannel := make(chan *sshNewServerConnResult, 2)
  224. if SSH_HANDSHAKE_TIMEOUT > 0 {
  225. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  226. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  227. })
  228. }
  229. go func() {
  230. result := &sshNewServerConnResult{}
  231. if sshServer.useObfuscation {
  232. result.conn, result.err = psiphon.NewObfuscatedSshConn(
  233. psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
  234. } else {
  235. result.conn = conn
  236. }
  237. if result.err == nil {
  238. result.sshConn, result.channels,
  239. result.requests, result.err = ssh.NewServerConn(result.conn, sshServer.sshConfig)
  240. }
  241. resultChannel <- result
  242. }()
  243. var result *sshNewServerConnResult
  244. select {
  245. case result = <-resultChannel:
  246. case <-sshServer.shutdownBroadcast:
  247. // Close() will interrupt an ongoing handshake
  248. // TODO: wait for goroutine to exit before returning?
  249. conn.Close()
  250. return
  251. }
  252. if result.err != nil {
  253. conn.Close()
  254. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  255. return
  256. }
  257. clientKey, ok := sshServer.registerClient(
  258. &sshClient{
  259. sshConn: result.sshConn,
  260. startTime: startTime,
  261. geoIPData: geoIPData,
  262. })
  263. if !ok {
  264. result.sshConn.Close()
  265. log.WithContext().Warning("register failed")
  266. return
  267. }
  268. defer sshServer.unregisterClient(clientKey)
  269. go ssh.DiscardRequests(result.requests)
  270. for newChannel := range result.channels {
  271. if newChannel.ChannelType() != "direct-tcpip" {
  272. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  273. return
  274. }
  275. // process each port forward concurrently
  276. go sshServer.handleNewDirectTcpipChannel(clientKey, newChannel)
  277. }
  278. }
  279. func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  280. var sshPasswordPayload struct {
  281. SessionId string `json:"SessionId"`
  282. SshPassword string `json:"SshPassword"`
  283. }
  284. err := json.Unmarshal(password, &sshPasswordPayload)
  285. if err != nil {
  286. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  287. }
  288. userOk := (subtle.ConstantTimeCompare(
  289. []byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
  290. passwordOk := (subtle.ConstantTimeCompare(
  291. []byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
  292. if !userOk || !passwordOk {
  293. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  294. }
  295. clientKey := string(conn.SessionID())
  296. sshServer.updateClient(clientKey, func(client *sshClient) {
  297. client.psiphonSessionID = sshPasswordPayload.SessionId
  298. })
  299. return nil, nil
  300. }
  301. func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  302. if err != nil {
  303. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  304. } else {
  305. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  306. }
  307. }
  308. func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  309. // TODO: log more details?
  310. log.WithContextFields(
  311. LogFields{
  312. "channelType": newChannel.ChannelType(),
  313. "rejectMessage": message,
  314. "rejectReason": reason,
  315. }).Warning("reject new channel")
  316. newChannel.Reject(reason, message)
  317. }
  318. func (sshServer *sshServer) handleNewDirectTcpipChannel(clientKey string, newChannel ssh.NewChannel) {
  319. // http://tools.ietf.org/html/rfc4254#section-7.2
  320. var directTcpipExtraData struct {
  321. HostToConnect string
  322. PortToConnect uint32
  323. OriginatorIPAddress string
  324. OriginatorPort uint32
  325. }
  326. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  327. if err != nil {
  328. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  329. return
  330. }
  331. targetAddr := fmt.Sprintf("%s:%d",
  332. directTcpipExtraData.HostToConnect,
  333. directTcpipExtraData.PortToConnect)
  334. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  335. // TODO: port forward dial timeout
  336. // TODO: report ssh.ResourceShortage when appropriate
  337. fwdConn, err := net.Dial("tcp", targetAddr)
  338. if err != nil {
  339. sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  340. return
  341. }
  342. defer fwdConn.Close()
  343. fwdChannel, requests, err := newChannel.Accept()
  344. if err != nil {
  345. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  346. return
  347. }
  348. sshServer.updateClient(clientKey, func(client *sshClient) {
  349. client.portForwardCount += 1
  350. client.concurrentPortForwardCount += 1
  351. if client.concurrentPortForwardCount > client.maxConcurrentPortForwardCount {
  352. client.maxConcurrentPortForwardCount = client.concurrentPortForwardCount
  353. }
  354. })
  355. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  356. go ssh.DiscardRequests(requests)
  357. defer fwdChannel.Close()
  358. // relay channel to forwarded connection
  359. // TODO: use a low-memory io.Copy?
  360. // TODO: relay errors to fwdChannel.Stderr()?
  361. var bytesUp, bytesDown int64
  362. relayWaitGroup := new(sync.WaitGroup)
  363. relayWaitGroup.Add(1)
  364. go func() {
  365. defer relayWaitGroup.Done()
  366. var err error
  367. bytesUp, err = io.Copy(fwdConn, fwdChannel)
  368. if err != nil {
  369. log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
  370. }
  371. }()
  372. bytesDown, err = io.Copy(fwdChannel, fwdConn)
  373. if err != nil {
  374. log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
  375. }
  376. fwdChannel.CloseWrite()
  377. relayWaitGroup.Wait()
  378. sshServer.updateClient(clientKey, func(client *sshClient) {
  379. client.concurrentPortForwardCount -= 1
  380. client.bytesUp += bytesUp
  381. client.bytesDown += bytesDown
  382. })
  383. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  384. }