sshService.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
/*
 * Copyright (c) 2016, Psiphon Inc.
 * All rights reserved.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */
  19. package server
import (
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"sync"

	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
	"golang.org/x/crypto/ssh"
)
  29. type sshServer struct {
  30. config *Config
  31. sshConfig *ssh.ServerConfig
  32. clientMutex sync.Mutex
  33. stoppingClients bool
  34. clients map[string]ssh.Conn
  35. }
  36. func RunSSHServer(config *Config, shutdownBroadcast <-chan struct{}) error {
  37. sshServer := &sshServer{
  38. config: config,
  39. clients: make(map[string]ssh.Conn),
  40. }
  41. sshServer.sshConfig = &ssh.ServerConfig{
  42. PasswordCallback: sshServer.passwordCallback,
  43. AuthLogCallback: sshServer.authLogCallback,
  44. ServerVersion: config.SSHServerVersion,
  45. }
  46. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  47. if err != nil {
  48. return psiphon.ContextError(err)
  49. }
  50. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  51. signer, err := ssh.NewSignerFromKey(privateKey)
  52. if err != nil {
  53. return psiphon.ContextError(err)
  54. }
  55. sshServer.sshConfig.AddHostKey(signer)
  56. listener, err := net.Listen(
  57. "tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, config.SSHPort))
  58. if err != nil {
  59. return psiphon.ContextError(err)
  60. }
  61. log.WithContext().Info("starting")
  62. err = nil
  63. errors := make(chan error)
  64. waitGroup := new(sync.WaitGroup)
  65. waitGroup.Add(1)
  66. go func() {
  67. defer waitGroup.Done()
  68. loop:
  69. for {
  70. conn, err := listener.Accept()
  71. select {
  72. case <-shutdownBroadcast:
  73. break loop
  74. default:
  75. }
  76. if err != nil {
  77. if e, ok := err.(net.Error); ok && e.Temporary() {
  78. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  79. // Temporary error, keep running
  80. continue
  81. }
  82. select {
  83. case errors <- psiphon.ContextError(err):
  84. default:
  85. }
  86. break loop
  87. }
  88. // process each client connection concurrently
  89. go sshServer.handleClient(conn)
  90. }
  91. sshServer.stopClients()
  92. log.WithContext().Info("stopped")
  93. }()
  94. select {
  95. case <-shutdownBroadcast:
  96. case err = <-errors:
  97. }
  98. listener.Close()
  99. waitGroup.Wait()
  100. log.WithContext().Info("exiting")
  101. return err
  102. }
  103. func (sshServer *sshServer) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  104. var sshPasswordPayload struct {
  105. SessionId string `json:"SessionId"`
  106. SshPassword string `json:"SshPassword"`
  107. }
  108. err := json.Unmarshal(password, &sshPasswordPayload)
  109. if err != nil {
  110. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  111. }
  112. if conn.User() == sshServer.config.SSHUserName &&
  113. sshPasswordPayload.SshPassword == sshServer.config.SSHPassword {
  114. return nil, nil
  115. }
  116. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  117. }
  118. func (sshServer *sshServer) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  119. if err != nil {
  120. log.WithContextFields(LogFields{"error": err, "method": method}).Warning("authentication failed")
  121. } else {
  122. log.WithContextFields(LogFields{"error": err, "method": method}).Info("authentication success")
  123. }
  124. }
  125. func (sshServer *sshServer) registerClient(sshConn ssh.Conn) bool {
  126. sshServer.clientMutex.Lock()
  127. defer sshServer.clientMutex.Unlock()
  128. if sshServer.stoppingClients {
  129. return false
  130. }
  131. existingSshConn := sshServer.clients[string(sshConn.SessionID())]
  132. if existingSshConn != nil {
  133. log.WithContext().Warning("unexpected existing connection")
  134. existingSshConn.Close()
  135. existingSshConn.Wait()
  136. }
  137. sshServer.clients[string(sshConn.SessionID())] = sshConn
  138. return true
  139. }
  140. func (sshServer *sshServer) unregisterClient(sshConn ssh.Conn) {
  141. sshServer.clientMutex.Lock()
  142. if sshServer.stoppingClients {
  143. return
  144. }
  145. delete(sshServer.clients, string(sshConn.SessionID()))
  146. sshServer.clientMutex.Unlock()
  147. sshConn.Close()
  148. }
  149. func (sshServer *sshServer) stopClients() {
  150. sshServer.clientMutex.Lock()
  151. sshServer.stoppingClients = true
  152. sshServer.clientMutex.Unlock()
  153. for _, sshConn := range sshServer.clients {
  154. sshConn.Close()
  155. sshConn.Wait()
  156. }
  157. }
  158. func (sshServer *sshServer) handleClient(conn net.Conn) {
  159. // TODO: does this block on SSH handshake (so should be in goroutine)?
  160. sshConn, channels, requests, err := ssh.NewServerConn(conn, sshServer.sshConfig)
  161. if err != nil {
  162. conn.Close()
  163. log.WithContextFields(LogFields{"error": err}).Warning("establish failed")
  164. return
  165. }
  166. if !sshServer.registerClient(sshConn) {
  167. sshConn.Close()
  168. log.WithContext().Warning("register failed")
  169. return
  170. }
  171. defer sshServer.unregisterClient(sshConn)
  172. // TODO: don't record IP; do GeoIP
  173. log.WithContextFields(LogFields{"remoteAddr": sshConn.RemoteAddr()}).Warning("connection accepted")
  174. go ssh.DiscardRequests(requests)
  175. for newChannel := range channels {
  176. if newChannel.ChannelType() != "direct-tcpip" {
  177. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  178. return
  179. }
  180. // process each port forward concurrently
  181. go sshServer.handleNewDirectTcpipChannel(newChannel)
  182. }
  183. }
  184. func (sshServer *sshServer) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  185. // TODO: log more details?
  186. log.WithContextFields(
  187. LogFields{
  188. "channelType": newChannel.ChannelType(),
  189. "rejectMessage": message,
  190. "rejectReason": reason,
  191. }).Warning("reject new channel")
  192. newChannel.Reject(reason, message)
  193. }
  194. func (sshServer *sshServer) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
  195. // http://tools.ietf.org/html/rfc4254#section-7.2
  196. var directTcpipExtraData struct {
  197. HostToConnect string
  198. PortToConnect uint32
  199. OriginatorIPAddress string
  200. OriginatorPort uint32
  201. }
  202. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  203. if err != nil {
  204. sshServer.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  205. return
  206. }
  207. targetAddr := fmt.Sprintf("%s:%d",
  208. directTcpipExtraData.HostToConnect,
  209. directTcpipExtraData.PortToConnect)
  210. log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
  211. // TODO: port forward dial timeout
  212. // TODO: report ssh.ResourceShortage when appropriate
  213. fwdConn, err := net.Dial("tcp", targetAddr)
  214. if err != nil {
  215. sshServer.rejectNewChannel(newChannel, ssh.ConnectionFailed, err.Error())
  216. return
  217. }
  218. defer fwdConn.Close()
  219. fwdChannel, requests, err := newChannel.Accept()
  220. if err != nil {
  221. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  222. return
  223. }
  224. log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
  225. go ssh.DiscardRequests(requests)
  226. defer fwdChannel.Close()
  227. // relay channel to forwarded connection
  228. // TODO: use a low-memory io.Copy?
  229. // TODO: relay errors to fwdChannel.Stderr()?
  230. relayWaitGroup := new(sync.WaitGroup)
  231. relayWaitGroup.Add(1)
  232. go func() {
  233. defer relayWaitGroup.Done()
  234. _, err := io.Copy(fwdConn, fwdChannel)
  235. if err != nil {
  236. log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
  237. }
  238. }()
  239. _, err = io.Copy(fwdChannel, fwdConn)
  240. if err != nil {
  241. log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
  242. }
  243. fwdChannel.CloseWrite()
  244. relayWaitGroup.Wait()
  245. log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
  246. }