sshService.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/hex"
  23. "encoding/json"
  24. "errors"
  25. "fmt"
  26. "io"
  27. "net"
  28. "sync"
  29. "time"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  31. "golang.org/x/crypto/ssh"
  32. )
  33. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  34. return runSSHServer(config, false, shutdownBroadcast)
  35. }
  36. func RunObfuscatedSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  37. return runSSHServer(config, true, shutdownBroadcast)
  38. }
  39. type sshServer struct {
  40. config *Config
  41. useObfuscation bool
  42. shutdownBroadcast <-chan struct{}
  43. sshConfig *ssh.ServerConfig
  44. clientMutex sync.Mutex
  45. stoppingClients bool
  46. clients map[string]ssh.Conn
  47. }
  48. func runSSHServer(
  49. config *Config, useObfuscation bool, shutdownBroadcast <-chan struct{}) error {
  50. sshServer := &sshServer{
  51. config: config,
  52. useObfuscation: useObfuscation,
  53. shutdownBroadcast: shutdownBroadcast,
  54. clients: make(map[string]ssh.Conn),
  55. }
  56. sshServer.sshConfig = &ssh.ServerConfig{
  57. PasswordCallback: sshServer.passwordCallback,
  58. AuthLogCallback: sshServer.authLogCallback,
  59. ServerVersion: config.SSHServerVersion,
  60. }
  61. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  62. if err != nil {
  63. return psiphon.ContextError(err)
  64. }
  65. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  66. signer, err := ssh.NewSignerFromKey(privateKey)
  67. if err != nil {
  68. return psiphon.ContextError(err)
  69. }
  70. sshServer.sshConfig.AddHostKey(signer)
  71. var serverPort int
  72. if useObfuscation {
  73. serverPort = config.ObfuscatedSSHServerPort
  74. } else {
  75. serverPort = config.SSHServerPort
  76. }
  77. listener, err := net.Listen(
  78. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, serverPort))
  79. if err != nil {
  80. return psiphon.ContextError(err)
  81. }
  82. log.WithContextFields(
  83. LogFields{
  84. "useObfuscation": useObfuscation,
  85. "port": serverPort,
  86. }).Info("starting")
  87. err = nil
  88. errors := make(chan error)
  89. waitGroup := new(sync.WaitGroup)
  90. waitGroup.Add(1)
  91. go func() {
  92. defer waitGroup.Done()
  93. loop:
  94. for {
  95. conn, err := listener.Accept()
  96. select {
  97. case <-shutdownBroadcast:
  98. if err == nil {
  99. conn.Close()
  100. }
  101. break loop
  102. default:
  103. }
  104. if err != nil {
  105. if e, ok := err.(net.Error); ok && e.Temporary() {
  106. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  107. // Temporary error, keep running
  108. continue
  109. }
  110. select {
  111. case errors <- psiphon.ContextError(err):
  112. default:
  113. }
  114. break loop
  115. }
  116. // process each client connection concurrently
  117. go sshServer.handleClient(conn)
  118. }
  119. sshServer.stopClients()
  120. log.WithContextFields(
  121. LogFields{"useObfuscation": useObfuscation}).Info("stopped")
  122. }()
  123. select {
  124. case <-shutdownBroadcast:
  125. case err = <-errors:
  126. }
  127. listener.Close()
  128. waitGroup.Wait()
  129. log.WithContextFields(
  130. LogFields{"useObfuscation": useObfuscation}).Info("exiting")
  131. return err
  132. }
  133. func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  134. var sshPasswordPayload struct {
  135. SessionId string `json:"SessionId"`
  136. SshPassword string `json:"SshPassword"`
  137. }
  138. err := json.Unmarshal(password, &sshPasswordPayload)
  139. if err != nil {
  140. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  141. }
  142. userOk := (subtle.ConstantTimeCompare(
  143. []byte(conn.User()), []byte(sshServer.config.SSHUserName)) == 1)
  144. passwordOk := (subtle.ConstantTimeCompare(
  145. []byte(sshPasswordPayload.SshPassword), []byte(sshServer.config.SSHPassword)) == 1)
  146. if !userOk || !passwordOk {
  147. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  148. }
  149. geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(conn.RemoteAddr()))
  150. log.WithContextFields(
  151. LogFields{
  152. "sshSessionID": hex.EncodeToString(conn.SessionID()),
  153. "psiphonSessionID": sshPasswordPayload.SessionId,
  154. "country": geoIPData.Country,
  155. "city": geoIPData.City,
  156. "ISP": geoIPData.ISP,
  157. }).Info("tunnel started")
  158. return nil, nil
  159. }
  160. func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  161. if err != nil {
  162. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  163. } else {
  164. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  165. }
  166. }
  167. func (sshServer *sshServer) registerClient(sshConn ssh.Conn) bool {
  168. sshServer.clientMutex.Lock()
  169. defer sshServer.clientMutex.Unlock()
  170. if sshServer.stoppingClients {
  171. return false
  172. }
  173. existingSshConn := sshServer.clients[string(sshConn.SessionID())]
  174. if existingSshConn != nil {
  175. log.WithContext().Warning("unexpected existing connection")
  176. existingSshConn.Close()
  177. existingSshConn.Wait()
  178. }
  179. sshServer.clients[string(sshConn.SessionID())] = sshConn
  180. return true
  181. }
  182. func (sshServer *sshServer) unregisterClient(sshConn ssh.Conn) {
  183. sshServer.clientMutex.Lock()
  184. if sshServer.stoppingClients {
  185. return
  186. }
  187. delete(sshServer.clients, string(sshConn.SessionID()))
  188. sshServer.clientMutex.Unlock()
  189. sshConn.Close()
  190. }
  191. func (sshServer *sshServer) stopClients() {
  192. sshServer.clientMutex.Lock()
  193. sshServer.stoppingClients = true
  194. sshServer.clientMutex.Unlock()
  195. for _, sshConn := range sshServer.clients {
  196. sshConn.Close()
  197. sshConn.Wait()
  198. }
  199. }
  200. func (sshServer *sshServer) handleClient(conn net.Conn) {
  201. // Run the initial [obfuscated] SSH handshake in a goroutine
  202. // so we can both respect shutdownBroadcast and implement a
  203. // handshake timeout. The timeout is to reclaim network
  204. // resources in case the handshake takes too long.
  205. type sshNewServerConnResult struct {
  206. conn net.Conn
  207. sshConn *ssh.ServerConn
  208. channels <-chan ssh.NewChannel
  209. requests <-chan *ssh.Request
  210. err error
  211. }
  212. resultChannel := make(chan *sshNewServerConnResult, 2)
  213. if SSH_HANDSHAKE_TIMEOUT > 0 {
  214. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  215. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  216. })
  217. }
  218. go func() {
  219. result := &sshNewServerConnResult{}
  220. if sshServer.useObfuscation {
  221. result.conn, result.err = psiphon.NewObfuscatedSshConn(
  222. psiphon.OBFUSCATION_CONN_MODE_SERVER, conn, sshServer.config.ObfuscatedSSHKey)
  223. } else {
  224. result.conn = conn
  225. }
  226. if result.err == nil {
  227. result.sshConn, result.channels,
  228. result.requests, result.err = ssh.NewServerConn(result.conn, sshServer.sshConfig)
  229. }
  230. resultChannel <- result
  231. }()
  232. var result *sshNewServerConnResult
  233. select {
  234. case result = <-resultChannel:
  235. case <-sshServer.shutdownBroadcast:
  236. // Close() will interrupt an ongoing handshake
  237. // TODO: wait for goroutine to exit before returning?
  238. conn.Close()
  239. return
  240. }
  241. if result.err != nil {
  242. conn.Close()
  243. log.WithContextFields(LogFields{"error": result.err}).Warning("handshake failed")
  244. return
  245. }
  246. if !sshServer.registerClient(result.sshConn) {
  247. result.sshConn.Close()
  248. log.WithContext().Warning("register failed")
  249. return
  250. }
  251. defer sshServer.unregisterClient(result.sshConn)
  252. // TODO: don't record IP; do GeoIP
  253. log.WithContextFields(
  254. LogFields{"remoteAddr": result.sshConn.RemoteAddr()}).Warning("connection accepted")
  255. go ssh.DiscardRequests(result.requests)
  256. for newChannel := range result.channels {
  257. if newChannel.ChannelType() != "direct-tcpip" {
  258. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  259. return
  260. }
  261. // process each port forward concurrently
  262. go sshServer.handleNewDirectTcpipChannel(newChannel)
  263. }
  264. }
  265. func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  266. // TODO: log more details?
  267. log.WithContextFields(
  268. LogFields{
  269. "channelType": newChannel.ChannelType(),
  270. "rejectMessage": message,
  271. "rejectReason": reason,
  272. }).Warning("reject new channel")
  273. newChannel.Reject(reason, message)
  274. }
  275. func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
  276. // http://tools.ietf.org/html/rfc4254#section-7.2
  277. var directTcpipExtraData struct {
  278. HostToConnect string
  279. PortToConnect uint32
  280. OriginatorIPAddress string
  281. OriginatorPort uint32
  282. }
  283. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  284. if err != nil {
  285. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  286. return
  287. }
  288. targetAddr := fmt.Sprintf("%s:%d",
  289. directTcpipExtraData.HostToConnect,
  290. directTcpipExtraData.PortToConnect)
  291. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  292. // TODO: port forward dial timeout
  293. // TODO: report ssh.ResourceShortage when appropriate
  294. fwdConn, err := net.Dial("tcp", targetAddr)
  295. if err != nil {
  296. sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  297. return
  298. }
  299. defer fwdConn.Close()
  300. fwdChannel, requests, err := newChannel.Accept()
  301. if err != nil {
  302. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  303. return
  304. }
  305. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  306. go ssh.DiscardRequests(requests)
  307. defer fwdChannel.Close()
  308. // relay channel to forwarded connection
  309. // TODO: use a low-memory io.Copy?
  310. // TODO: relay errors to fwdChannel.Stderr()?
  311. relayWaitGroup := new(sync.WaitGroup)
  312. relayWaitGroup.Add(1)
  313. go func() {
  314. defer relayWaitGroup.Done()
  315. _, err := io.Copy(fwdConn, fwdChannel)
  316. if err != nil {
  317. log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
  318. }
  319. }()
  320. _, err = io.Copy(fwdChannel, fwdConn)
  321. if err != nil {
  322. log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
  323. }
  324. fwdChannel.CloseWrite()
  325. relayWaitGroup.Wait()
  326. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  327. }