tunnelServer.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package server
  20. import (
  21. "crypto/subtle"
  22. "encoding/json"
  23. "errors"
  24. "fmt"
  25. "io"
  26. "net"
  27. "sync"
  28. "sync/atomic"
  29. "time"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
  31. "golang.org/x/crypto/ssh"
  32. )
  33. // TunnelServer is the main server that accepts Psiphon client
  34. // connections, via various obfuscation protocols, and provides
  35. // port forwarding (TCP and UDP) services to the Psiphon client.
  36. // At its core, TunnelServer is an SSH server. SSH is the base
  37. // protocol that provides port forward multiplexing, and transport
  38. // security. Layered on top of SSH, optionally, is Obfuscated SSH
  39. // and meek protocols, which provide further circumvention
  40. // capabilities.
  41. type TunnelServer struct {
  42. config *Config
  43. runWaitGroup *sync.WaitGroup
  44. listenerError chan error
  45. shutdownBroadcast <-chan struct{}
  46. sshServer *sshServer
  47. }
  48. // NewTunnelServer initializes a new tunnel server.
  49. func NewTunnelServer(
  50. config *Config,
  51. psinetDatabase *PsinetDatabase,
  52. shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
  53. sshServer, err := newSSHServer(
  54. config, psinetDatabase, shutdownBroadcast)
  55. if err != nil {
  56. return nil, psiphon.ContextError(err)
  57. }
  58. return &TunnelServer{
  59. config: config,
  60. runWaitGroup: new(sync.WaitGroup),
  61. listenerError: make(chan error),
  62. shutdownBroadcast: shutdownBroadcast,
  63. sshServer: sshServer,
  64. }, nil
  65. }
  66. // GetLoadStats returns load stats for the tunnel server. The stats are
  67. // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
  68. // include current connected client count, total number of current port
  69. // forwards.
  70. func (server *TunnelServer) GetLoadStats() map[string]map[string]int64 {
  71. return server.sshServer.getLoadStats()
  72. }
  73. // Run runs the tunnel server; this function blocks while running a selection of
  74. // listeners that handle connection using various obfuscation protocols.
  75. //
  76. // Run listens on each designated tunnel port and spawns new goroutines to handle
  77. // each client connection. It halts when shutdownBroadcast is signaled. A list of active
  78. // clients is maintained, and when halting all clients are cleanly shutdown.
  79. //
  80. // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
  81. // authentication, and then looping on client new channel requests. "direct-tcpip"
  82. // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
  83. // config parameter is configured, UDP port forwards over a TCP stream, following
  84. // the udpgw protocol, are handled.
  85. //
  86. // A new goroutine is spawned to handle each port forward for each client. Each port
  87. // forward tracks its bytes transferred. Overall per-client stats for connection duration,
  88. // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
  89. // client shuts down.
  90. func (server *TunnelServer) Run() error {
  91. type sshListener struct {
  92. net.Listener
  93. localAddress string
  94. tunnelProtocol string
  95. }
  96. // First bind all listeners; once all are successful,
  97. // start accepting connections on each.
  98. var listeners []*sshListener
  99. for tunnelProtocol, listenPort := range server.config.TunnelProtocolPorts {
  100. localAddress := fmt.Sprintf(
  101. "%s:%d", server.config.ServerIPAddress, listenPort)
  102. listener, err := net.Listen("tcp", localAddress)
  103. if err != nil {
  104. for _, existingListener := range listeners {
  105. existingListener.Listener.Close()
  106. }
  107. return psiphon.ContextError(err)
  108. }
  109. log.WithContextFields(
  110. LogFields{
  111. "localAddress": localAddress,
  112. "tunnelProtocol": tunnelProtocol,
  113. }).Info("listening")
  114. listeners = append(
  115. listeners,
  116. &sshListener{
  117. Listener: listener,
  118. localAddress: localAddress,
  119. tunnelProtocol: tunnelProtocol,
  120. })
  121. }
  122. for _, listener := range listeners {
  123. server.runWaitGroup.Add(1)
  124. go func(listener *sshListener) {
  125. defer server.runWaitGroup.Done()
  126. log.WithContextFields(
  127. LogFields{
  128. "localAddress": listener.localAddress,
  129. "tunnelProtocol": listener.tunnelProtocol,
  130. }).Info("running")
  131. server.sshServer.runListener(
  132. listener.Listener,
  133. server.listenerError,
  134. listener.tunnelProtocol)
  135. log.WithContextFields(
  136. LogFields{
  137. "localAddress": listener.localAddress,
  138. "tunnelProtocol": listener.tunnelProtocol,
  139. }).Info("stopped")
  140. }(listener)
  141. }
  142. var err error
  143. select {
  144. case <-server.shutdownBroadcast:
  145. case err = <-server.listenerError:
  146. }
  147. for _, listener := range listeners {
  148. listener.Close()
  149. }
  150. server.sshServer.stopClients()
  151. server.runWaitGroup.Wait()
  152. log.WithContext().Info("stopped")
  153. return err
  154. }
  155. type sshClientID uint64
  156. type sshServer struct {
  157. config *Config
  158. psinetDatabase *PsinetDatabase
  159. shutdownBroadcast <-chan struct{}
  160. sshHostKey ssh.Signer
  161. nextClientID sshClientID
  162. clientsMutex sync.Mutex
  163. stoppingClients bool
  164. clients map[sshClientID]*sshClient
  165. }
  166. func newSSHServer(
  167. config *Config,
  168. psinetDatabase *PsinetDatabase,
  169. shutdownBroadcast <-chan struct{}) (*sshServer, error) {
  170. privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
  171. if err != nil {
  172. return nil, psiphon.ContextError(err)
  173. }
  174. // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
  175. signer, err := ssh.NewSignerFromKey(privateKey)
  176. if err != nil {
  177. return nil, psiphon.ContextError(err)
  178. }
  179. return &sshServer{
  180. config: config,
  181. psinetDatabase: psinetDatabase,
  182. shutdownBroadcast: shutdownBroadcast,
  183. sshHostKey: signer,
  184. nextClientID: 1,
  185. clients: make(map[sshClientID]*sshClient),
  186. }, nil
  187. }
  188. // runListener is intended to run an a goroutine; it blocks
  189. // running a particular listener. If an unrecoverable error
  190. // occurs, it will send the error to the listenerError channel.
  191. func (sshServer *sshServer) runListener(
  192. listener net.Listener,
  193. listenerError chan<- error,
  194. tunnelProtocol string) {
  195. handleClient := func(clientConn net.Conn) {
  196. // process each client connection concurrently
  197. go sshServer.handleClient(tunnelProtocol, clientConn)
  198. }
  199. // Note: when exiting due to a unrecoverable error, be sure
  200. // to try to send the error to listenerError so that the outer
  201. // TunnelServer.Run will properly shut down instead of remaining
  202. // running.
  203. if psiphon.TunnelProtocolUsesMeekHTTP(tunnelProtocol) ||
  204. psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
  205. meekServer, err := NewMeekServer(
  206. sshServer.config,
  207. listener,
  208. psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
  209. handleClient,
  210. sshServer.shutdownBroadcast)
  211. if err != nil {
  212. select {
  213. case listenerError <- psiphon.ContextError(err):
  214. default:
  215. }
  216. return
  217. }
  218. meekServer.Run()
  219. } else {
  220. for {
  221. conn, err := listener.Accept()
  222. select {
  223. case <-sshServer.shutdownBroadcast:
  224. if err == nil {
  225. conn.Close()
  226. }
  227. return
  228. default:
  229. }
  230. if err != nil {
  231. if e, ok := err.(net.Error); ok && e.Temporary() {
  232. log.WithContextFields(LogFields{"error": err}).Error("accept failed")
  233. // Temporary error, keep running
  234. continue
  235. }
  236. select {
  237. case listenerError <- psiphon.ContextError(err):
  238. default:
  239. }
  240. return
  241. }
  242. handleClient(conn)
  243. }
  244. }
  245. }
  246. func (sshServer *sshServer) registerClient(client *sshClient) (sshClientID, bool) {
  247. sshServer.clientsMutex.Lock()
  248. defer sshServer.clientsMutex.Unlock()
  249. if sshServer.stoppingClients {
  250. return 0, false
  251. }
  252. clientID := sshServer.nextClientID
  253. sshServer.nextClientID += 1
  254. sshServer.clients[clientID] = client
  255. return clientID, true
  256. }
  257. func (sshServer *sshServer) unregisterClient(clientID sshClientID) {
  258. sshServer.clientsMutex.Lock()
  259. client := sshServer.clients[clientID]
  260. delete(sshServer.clients, clientID)
  261. sshServer.clientsMutex.Unlock()
  262. if client != nil {
  263. client.stop()
  264. }
  265. }
  266. func (sshServer *sshServer) getLoadStats() map[string]map[string]int64 {
  267. sshServer.clientsMutex.Lock()
  268. defer sshServer.clientsMutex.Unlock()
  269. loadStats := make(map[string]map[string]int64)
  270. for _, client := range sshServer.clients {
  271. if loadStats[client.tunnelProtocol] == nil {
  272. loadStats[client.tunnelProtocol] = make(map[string]int64)
  273. }
  274. // Note: can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
  275. loadStats[client.tunnelProtocol]["CurrentClients"] += 1
  276. client.Lock()
  277. loadStats[client.tunnelProtocol]["CurrentTCPPortForwards"] += client.tcpTrafficState.concurrentPortForwardCount
  278. loadStats[client.tunnelProtocol]["TotalTCPPortForwards"] += client.tcpTrafficState.totalPortForwardCount
  279. loadStats[client.tunnelProtocol]["CurrentUDPPortForwards"] += client.udpTrafficState.concurrentPortForwardCount
  280. loadStats[client.tunnelProtocol]["TotalUDPPortForwards"] += client.udpTrafficState.totalPortForwardCount
  281. client.Unlock()
  282. }
  283. return loadStats
  284. }
  285. func (sshServer *sshServer) stopClients() {
  286. sshServer.clientsMutex.Lock()
  287. sshServer.stoppingClients = true
  288. clients := sshServer.clients
  289. sshServer.clients = make(map[sshClientID]*sshClient)
  290. sshServer.clientsMutex.Unlock()
  291. for _, client := range clients {
  292. client.stop()
  293. }
  294. }
  295. func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
  296. geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
  297. sshClient := newSshClient(
  298. sshServer,
  299. tunnelProtocol,
  300. geoIPData,
  301. sshServer.config.GetTrafficRules(geoIPData.Country))
  302. // Wrap the base client connection with an ActivityMonitoredConn which will
  303. // terminate the connection if no data is received before the deadline. This
  304. // timeout is in effect for the entire duration of the SSH connection. Clients
  305. // must actively use the connection or send SSH keep alive requests to keep
  306. // the connection active.
  307. activityConn := psiphon.NewActivityMonitoredConn(
  308. clientConn,
  309. SSH_CONNECTION_READ_DEADLINE,
  310. false,
  311. nil)
  312. clientConn = activityConn
  313. // Further wrap the connection in a rate limiting ThrottledConn.
  314. rateLimits := sshClient.trafficRules.GetRateLimits(tunnelProtocol)
  315. clientConn = psiphon.NewThrottledConn(
  316. clientConn,
  317. rateLimits.DownstreamUnlimitedBytes,
  318. int64(rateLimits.DownstreamBytesPerSecond),
  319. rateLimits.UpstreamUnlimitedBytes,
  320. int64(rateLimits.UpstreamBytesPerSecond))
  321. // Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  322. // respect shutdownBroadcast and implement a specific handshake timeout.
  323. // The timeout is to reclaim network resources in case the handshake takes
  324. // too long.
  325. type sshNewServerConnResult struct {
  326. conn net.Conn
  327. sshConn *ssh.ServerConn
  328. channels <-chan ssh.NewChannel
  329. requests <-chan *ssh.Request
  330. err error
  331. }
  332. resultChannel := make(chan *sshNewServerConnResult, 2)
  333. if SSH_HANDSHAKE_TIMEOUT > 0 {
  334. time.AfterFunc(time.Duration(SSH_HANDSHAKE_TIMEOUT), func() {
  335. resultChannel <- &sshNewServerConnResult{err: errors.New("ssh handshake timeout")}
  336. })
  337. }
  338. go func(conn net.Conn) {
  339. sshServerConfig := &ssh.ServerConfig{
  340. PasswordCallback: sshClient.passwordCallback,
  341. AuthLogCallback: sshClient.authLogCallback,
  342. ServerVersion: sshServer.config.SSHServerVersion,
  343. }
  344. sshServerConfig.AddHostKey(sshServer.sshHostKey)
  345. result := &sshNewServerConnResult{}
  346. // Wrap the connection in an SSH deobfuscator when required.
  347. if psiphon.TunnelProtocolUsesObfuscatedSSH(tunnelProtocol) {
  348. // Note: NewObfuscatedSshConn blocks on network I/O
  349. // TODO: ensure this won't block shutdown
  350. conn, result.err = psiphon.NewObfuscatedSshConn(
  351. psiphon.OBFUSCATION_CONN_MODE_SERVER,
  352. clientConn,
  353. sshServer.config.ObfuscatedSSHKey)
  354. if result.err != nil {
  355. result.err = psiphon.ContextError(result.err)
  356. }
  357. }
  358. if result.err == nil {
  359. result.sshConn, result.channels, result.requests, result.err =
  360. ssh.NewServerConn(conn, sshServerConfig)
  361. }
  362. resultChannel <- result
  363. }(clientConn)
  364. var result *sshNewServerConnResult
  365. select {
  366. case result = <-resultChannel:
  367. case <-sshServer.shutdownBroadcast:
  368. // Close() will interrupt an ongoing handshake
  369. // TODO: wait for goroutine to exit before returning?
  370. clientConn.Close()
  371. return
  372. }
  373. if result.err != nil {
  374. clientConn.Close()
  375. // This is a Debug log due to noise. The handshake often fails due to I/O
  376. // errors as clients frequently interrupt connections in progress when
  377. // client-side load balancing completes a connection to a different server.
  378. log.WithContextFields(LogFields{"error": result.err}).Debug("handshake failed")
  379. return
  380. }
  381. sshClient.Lock()
  382. sshClient.sshConn = result.sshConn
  383. sshClient.activityConn = activityConn
  384. sshClient.Unlock()
  385. clientID, ok := sshServer.registerClient(sshClient)
  386. if !ok {
  387. clientConn.Close()
  388. log.WithContext().Warning("register failed")
  389. return
  390. }
  391. defer sshServer.unregisterClient(clientID)
  392. sshClient.runClient(result.channels, result.requests)
  393. // TODO: clientConn.Close()?
  394. }
  395. type sshClient struct {
  396. sync.Mutex
  397. sshServer *sshServer
  398. tunnelProtocol string
  399. sshConn ssh.Conn
  400. activityConn *psiphon.ActivityMonitoredConn
  401. geoIPData GeoIPData
  402. psiphonSessionID string
  403. udpChannel ssh.Channel
  404. trafficRules TrafficRules
  405. tcpTrafficState *trafficState
  406. udpTrafficState *trafficState
  407. channelHandlerWaitGroup *sync.WaitGroup
  408. tcpPortForwardLRU *psiphon.LRUConns
  409. stopBroadcast chan struct{}
  410. }
  411. type trafficState struct {
  412. bytesUp int64
  413. bytesDown int64
  414. concurrentPortForwardCount int64
  415. peakConcurrentPortForwardCount int64
  416. totalPortForwardCount int64
  417. }
  418. func newSshClient(
  419. sshServer *sshServer, tunnelProtocol string, geoIPData GeoIPData, trafficRules TrafficRules) *sshClient {
  420. return &sshClient{
  421. sshServer: sshServer,
  422. tunnelProtocol: tunnelProtocol,
  423. geoIPData: geoIPData,
  424. trafficRules: trafficRules,
  425. tcpTrafficState: &trafficState{},
  426. udpTrafficState: &trafficState{},
  427. channelHandlerWaitGroup: new(sync.WaitGroup),
  428. tcpPortForwardLRU: psiphon.NewLRUConns(),
  429. stopBroadcast: make(chan struct{}),
  430. }
  431. }
  432. func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  433. var sshPasswordPayload struct {
  434. SessionId string `json:"SessionId"`
  435. SshPassword string `json:"SshPassword"`
  436. }
  437. err := json.Unmarshal(password, &sshPasswordPayload)
  438. if err != nil {
  439. return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
  440. }
  441. userOk := (subtle.ConstantTimeCompare(
  442. []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
  443. passwordOk := (subtle.ConstantTimeCompare(
  444. []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
  445. if !userOk || !passwordOk {
  446. return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
  447. }
  448. psiphonSessionID := sshPasswordPayload.SessionId
  449. sshClient.Lock()
  450. sshClient.psiphonSessionID = psiphonSessionID
  451. geoIPData := sshClient.geoIPData
  452. sshClient.Unlock()
  453. if sshClient.sshServer.config.UseRedis() {
  454. err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
  455. if err != nil {
  456. log.WithContextFields(LogFields{
  457. "psiphonSessionID": psiphonSessionID,
  458. "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
  459. // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
  460. }
  461. }
  462. return nil, nil
  463. }
  464. func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  465. if err != nil {
  466. if sshClient.sshServer.config.UseFail2Ban() {
  467. clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
  468. if clientIPAddress != "" {
  469. LogFail2Ban(clientIPAddress)
  470. }
  471. }
  472. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
  473. } else {
  474. log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  475. }
  476. }
  477. func (sshClient *sshClient) stop() {
  478. sshClient.sshConn.Close()
  479. sshClient.sshConn.Wait()
  480. close(sshClient.stopBroadcast)
  481. sshClient.channelHandlerWaitGroup.Wait()
  482. // Note: reporting duration based on last confirmed data transfer, which
  483. // is reads for sshClient.activityConn.GetActiveDuration(), and not
  484. // connection closing is important for protocols such as meek. For
  485. // meek, the connection remains open until the HTTP session expires,
  486. // which may be some time after the tunnel has closed. (The meek
  487. // protocol has no allowance for signalling payload EOF, and even if
  488. // it did the client may not have the opportunity to send a final
  489. // request with an EOF flag set.)
  490. sshClient.Lock()
  491. log.WithContextFields(
  492. LogFields{
  493. "startTime": sshClient.activityConn.GetStartTime(),
  494. "duration": sshClient.activityConn.GetActiveDuration(),
  495. "psiphonSessionID": sshClient.psiphonSessionID,
  496. "country": sshClient.geoIPData.Country,
  497. "city": sshClient.geoIPData.City,
  498. "ISP": sshClient.geoIPData.ISP,
  499. "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
  500. "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
  501. "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
  502. "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
  503. "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
  504. "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
  505. "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
  506. "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
  507. }).Info("tunnel closed")
  508. sshClient.Unlock()
  509. }
  510. // runClient handles/dispatches new channel and new requests from the client.
  511. // When the SSH client connection closes, both the channels and requests channels
  512. // will close and runClient will exit.
  513. func (sshClient *sshClient) runClient(
  514. channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
  515. requestsWaitGroup := new(sync.WaitGroup)
  516. requestsWaitGroup.Add(1)
  517. go func() {
  518. defer requestsWaitGroup.Done()
  519. for request := range requests {
  520. // requests are processed serially; responses must be sent in request order.
  521. responsePayload, err := sshAPIRequestHandler(
  522. sshClient.sshServer.config,
  523. sshClient.sshServer.psinetDatabase,
  524. sshClient.geoIPData,
  525. request.Type,
  526. request.Payload)
  527. if err == nil {
  528. err = request.Reply(true, responsePayload)
  529. } else {
  530. log.WithContextFields(LogFields{"error": err}).Warning("request failed")
  531. err = request.Reply(false, nil)
  532. }
  533. if err != nil {
  534. log.WithContextFields(LogFields{"error": err}).Warning("response failed")
  535. }
  536. }
  537. }()
  538. for newChannel := range channels {
  539. if newChannel.ChannelType() != "direct-tcpip" {
  540. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "unknown or unsupported channel type")
  541. continue
  542. }
  543. // process each port forward concurrently
  544. sshClient.channelHandlerWaitGroup.Add(1)
  545. go sshClient.handleNewPortForwardChannel(newChannel)
  546. }
  547. requestsWaitGroup.Wait()
  548. }
  549. func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason ssh.RejectionReason, message string) {
  550. // TODO: log more details?
  551. log.WithContextFields(
  552. LogFields{
  553. "channelType": newChannel.ChannelType(),
  554. "rejectMessage": message,
  555. "rejectReason": reason,
  556. }).Warning("reject new channel")
  557. newChannel.Reject(reason, message)
  558. }
  559. func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
  560. defer sshClient.channelHandlerWaitGroup.Done()
  561. // http://tools.ietf.org/html/rfc4254#section-7.2
  562. var directTcpipExtraData struct {
  563. HostToConnect string
  564. PortToConnect uint32
  565. OriginatorIPAddress string
  566. OriginatorPort uint32
  567. }
  568. err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  569. if err != nil {
  570. sshClient.rejectNewChannel(newChannel, ssh.Prohibited, "invalid extra data")
  571. return
  572. }
  573. // Intercept TCP port forwards to a specified udpgw server and handle directly.
  574. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  575. isUDPChannel := sshClient.sshServer.config.UDPInterceptUdpgwServerAddress != "" &&
  576. sshClient.sshServer.config.UDPInterceptUdpgwServerAddress ==
  577. fmt.Sprintf("%s:%d",
  578. directTcpipExtraData.HostToConnect,
  579. directTcpipExtraData.PortToConnect)
  580. if isUDPChannel {
  581. sshClient.handleUDPChannel(newChannel)
  582. } else {
  583. sshClient.handleTCPChannel(
  584. directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
  585. }
  586. }
  587. func (sshClient *sshClient) isPortForwardPermitted(
  588. port int, allowPorts []int, denyPorts []int) bool {
  589. // TODO: faster lookup?
  590. if len(allowPorts) > 0 {
  591. for _, allowPort := range allowPorts {
  592. if port == allowPort {
  593. return true
  594. }
  595. }
  596. return false
  597. }
  598. if len(denyPorts) > 0 {
  599. for _, denyPort := range denyPorts {
  600. if port == denyPort {
  601. return false
  602. }
  603. }
  604. }
  605. return true
  606. }
  607. func (sshClient *sshClient) isPortForwardLimitExceeded(
  608. state *trafficState, maxPortForwardCount int) bool {
  609. limitExceeded := false
  610. if maxPortForwardCount > 0 {
  611. sshClient.Lock()
  612. limitExceeded = state.concurrentPortForwardCount >= int64(maxPortForwardCount)
  613. sshClient.Unlock()
  614. }
  615. return limitExceeded
  616. }
  617. func (sshClient *sshClient) openedPortForward(
  618. state *trafficState) {
  619. sshClient.Lock()
  620. state.concurrentPortForwardCount += 1
  621. if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  622. state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  623. }
  624. state.totalPortForwardCount += 1
  625. sshClient.Unlock()
  626. }
  627. func (sshClient *sshClient) closedPortForward(
  628. state *trafficState, bytesUp, bytesDown int64) {
  629. sshClient.Lock()
  630. state.concurrentPortForwardCount -= 1
  631. state.bytesUp += bytesUp
  632. state.bytesDown += bytesDown
  633. sshClient.Unlock()
  634. }
  635. func (sshClient *sshClient) handleTCPChannel(
  636. hostToConnect string,
  637. portToConnect int,
  638. newChannel ssh.NewChannel) {
  639. if !sshClient.isPortForwardPermitted(
  640. portToConnect,
  641. sshClient.trafficRules.AllowTCPPorts,
  642. sshClient.trafficRules.DenyTCPPorts) {
  643. sshClient.rejectNewChannel(
  644. newChannel, ssh.Prohibited, "port forward not permitted")
  645. return
  646. }
  647. var bytesUp, bytesDown int64
  648. sshClient.openedPortForward(sshClient.tcpTrafficState)
  649. defer func() {
  650. sshClient.closedPortForward(
  651. sshClient.tcpTrafficState,
  652. atomic.LoadInt64(&bytesUp),
  653. atomic.LoadInt64(&bytesDown))
  654. }()
  655. // TOCTOU note: important to increment the port forward count (via
  656. // openPortForward) _before_ checking isPortForwardLimitExceeded
  657. // otherwise, the client could potentially consume excess resources
  658. // by initiating many port forwards concurrently.
  659. // TODO: close LRU connection (after successful Dial) instead of
  660. // rejecting new connection?
  661. if sshClient.isPortForwardLimitExceeded(
  662. sshClient.tcpTrafficState,
  663. sshClient.trafficRules.MaxTCPPortForwardCount) {
  664. // Close the oldest TCP port forward. CloseOldest() closes
  665. // the conn and the port forward's goroutine will complete
  666. // the cleanup asynchronously.
  667. //
  668. // Some known limitations:
  669. //
  670. // - Since CloseOldest() closes the upstream socket but does not
  671. // clean up all resources associated with the port forward. These
  672. // include the goroutine(s) relaying traffic as well as the SSH
  673. // channel. Closing the socket will interrupt the goroutines which
  674. // will then complete the cleanup. But, since the full cleanup is
  675. // asynchronous, there exists a possibility that a client can consume
  676. // more than max port forward resources -- just not upstream sockets.
  677. //
  678. // - An LRU list entry for this port forward is not added until
  679. // after the dial completes, but the port forward is counted
  680. // towards max limits. This means many dials in progress will
  681. // put established connections in jeopardy.
  682. //
  683. // - We're closing the oldest open connection _before_ successfully
  684. // dialing the new port forward. This means we are potentially
  685. // discarding a good connection to make way for a failed connection.
  686. // We cannot simply dial first and still maintain a limit on
  687. // resources used, so to address this we'd need to add some
  688. // accounting for connections still establishing.
  689. sshClient.tcpPortForwardLRU.CloseOldest()
  690. log.WithContextFields(
  691. LogFields{
  692. "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
  693. }).Debug("closed LRU TCP port forward")
  694. }
  695. // Dial the target remote address. This is done in a goroutine to
  696. // ensure the shutdown signal is handled immediately.
  697. remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
  698. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
  699. type dialTcpResult struct {
  700. conn net.Conn
  701. err error
  702. }
  703. resultChannel := make(chan *dialTcpResult, 1)
  704. go func() {
  705. // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
  706. // TODO: IPv6 support
  707. conn, err := net.DialTimeout(
  708. "tcp4", remoteAddr, SSH_TCP_PORT_FORWARD_DIAL_TIMEOUT)
  709. resultChannel <- &dialTcpResult{conn, err}
  710. }()
  711. var result *dialTcpResult
  712. select {
  713. case result = <-resultChannel:
  714. case <-sshClient.stopBroadcast:
  715. // Note: may leave dial in progress
  716. return
  717. }
  718. if result.err != nil {
  719. sshClient.rejectNewChannel(newChannel, ssh.ConnectionFailed, result.err.Error())
  720. return
  721. }
  722. // The upstream TCP port forward connection has been established. Schedule
  723. // some cleanup and notify the SSH client that the channel is accepted.
  724. fwdConn := result.conn
  725. defer fwdConn.Close()
  726. lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
  727. defer lruEntry.Remove()
  728. // ActivityMonitoredConn monitors the TCP port forward I/O and updates
  729. // its LRU status. ActivityMonitoredConn also times out read on the port
  730. // forward if both reads and writes have been idle for the specified
  731. // duration.
  732. fwdConn = psiphon.NewActivityMonitoredConn(
  733. fwdConn,
  734. time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
  735. true,
  736. lruEntry)
  737. fwdChannel, requests, err := newChannel.Accept()
  738. if err != nil {
  739. log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
  740. return
  741. }
  742. go ssh.DiscardRequests(requests)
  743. defer fwdChannel.Close()
  744. log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
  745. // Relay channel to forwarded connection.
  746. // TODO: relay errors to fwdChannel.Stderr()?
  747. relayWaitGroup := new(sync.WaitGroup)
  748. relayWaitGroup.Add(1)
  749. go func() {
  750. defer relayWaitGroup.Done()
  751. // io.Copy allocates a 32K temporary buffer, and each port forward relay uses
  752. // two of these buffers; using io.CopyBuffer with a smaller buffer reduces the
  753. // overall memory footprint.
  754. bytes, err := io.CopyBuffer(
  755. fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  756. atomic.AddInt64(&bytesDown, bytes)
  757. if err != nil && err != io.EOF {
  758. // Debug since errors such as "connection reset by peer" occur during normal operation
  759. log.WithContextFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
  760. }
  761. // Interrupt upstream io.Copy when downstream is shutting down.
  762. // TODO: this is done to quickly cleanup the port forward when
  763. // fwdConn has a read timeout, but is it clean -- upstream may still
  764. // be flowing?
  765. fwdChannel.Close()
  766. }()
  767. bytes, err := io.CopyBuffer(
  768. fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
  769. atomic.AddInt64(&bytesUp, bytes)
  770. if err != nil && err != io.EOF {
  771. log.WithContextFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
  772. }
  773. // Shutdown special case: fwdChannel will be closed and return EOF when
  774. // the SSH connection is closed, but we need to explicitly close fwdConn
  775. // to interrupt the downstream io.Copy, which may be blocked on a
  776. // fwdConn.Read().
  777. fwdConn.Close()
  778. relayWaitGroup.Wait()
  779. log.WithContextFields(
  780. LogFields{
  781. "remoteAddr": remoteAddr,
  782. "bytesUp": atomic.LoadInt64(&bytesUp),
  783. "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
  784. }