|
|
@@ -355,13 +355,17 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
|
|
|
geoIPData,
|
|
|
sshServer.config.GetTrafficRules(geoIPData.Country))
|
|
|
|
|
|
- // Wrap the base client connection with an IdleTimeoutConn which will terminate
|
|
|
- // the connection if no data is received before the deadline. This timeout is
|
|
|
- // in effect for the entire duration of the SSH connection. Clients must actively
|
|
|
- // use the connection or send SSH keep alive requests to keep the connection
|
|
|
- // active.
|
|
|
+ // Wrap the base client connection with an ActivityMonitoredConn which will
|
|
|
+ // terminate the connection if no data is received before the deadline. This
|
|
|
+ // timeout is in effect for the entire duration of the SSH connection. Clients
|
|
|
+ // must actively use the connection or send SSH keep alive requests to keep
|
|
|
+ // the connection active.
|
|
|
|
|
|
- clientConn = psiphon.NewIdleTimeoutConn(clientConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
+ clientConn = psiphon.NewActivityMonitoredConn(
|
|
|
+ clientConn,
|
|
|
+ SSH_CONNECTION_READ_DEADLINE,
|
|
|
+ false,
|
|
|
+ nil)
|
|
|
|
|
|
// Further wrap the connection in a rate limiting ThrottledConn.
|
|
|
|
|
|
@@ -478,6 +482,7 @@ type sshClient struct {
|
|
|
tcpTrafficState *trafficState
|
|
|
udpTrafficState *trafficState
|
|
|
channelHandlerWaitGroup *sync.WaitGroup
|
|
|
+ tcpPortForwardLRU *psiphon.LRUConns
|
|
|
stopBroadcast chan struct{}
|
|
|
}
|
|
|
|
|
|
@@ -500,10 +505,94 @@ func newSshClient(
|
|
|
tcpTrafficState: &trafficState{},
|
|
|
udpTrafficState: &trafficState{},
|
|
|
channelHandlerWaitGroup: new(sync.WaitGroup),
|
|
|
+ tcpPortForwardLRU: psiphon.NewLRUConns(),
|
|
|
stopBroadcast: make(chan struct{}),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
+ var sshPasswordPayload struct {
|
|
|
+ SessionId string `json:"SessionId"`
|
|
|
+ SshPassword string `json:"SshPassword"`
|
|
|
+ }
|
|
|
+ err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
+ if err != nil {
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
+ }
|
|
|
+
|
|
|
+ userOk := (subtle.ConstantTimeCompare(
|
|
|
+ []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
|
|
|
+
|
|
|
+ passwordOk := (subtle.ConstantTimeCompare(
|
|
|
+ []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
|
|
|
+
|
|
|
+ if !userOk || !passwordOk {
|
|
|
+ return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
+ }
|
|
|
+
|
|
|
+ psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ sshClient.psiphonSessionID = psiphonSessionID
|
|
|
+ geoIPData := sshClient.geoIPData
|
|
|
+ sshClient.Unlock()
|
|
|
+
|
|
|
+ if sshClient.sshServer.config.UseRedis() {
|
|
|
+ err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
|
|
|
+ if err != nil {
|
|
|
+ log.WithContextFields(LogFields{
|
|
|
+ "psiphonSessionID": psiphonSessionID,
|
|
|
+ "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
+ // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
+ if err != nil {
|
|
|
+ if sshClient.sshServer.config.UseFail2Ban() {
|
|
|
+ clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
|
|
|
+ if clientIPAddress != "" {
|
|
|
+ LogFail2Ban(clientIPAddress)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
|
|
|
+ } else {
|
|
|
+ log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (sshClient *sshClient) stop() {
|
|
|
+
|
|
|
+ sshClient.sshConn.Close()
|
|
|
+ sshClient.sshConn.Wait()
|
|
|
+
|
|
|
+ close(sshClient.stopBroadcast)
|
|
|
+ sshClient.channelHandlerWaitGroup.Wait()
|
|
|
+
|
|
|
+ sshClient.Lock()
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "startTime": sshClient.startTime,
|
|
|
+ "duration": time.Now().Sub(sshClient.startTime),
|
|
|
+ "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
+ "country": sshClient.geoIPData.Country,
|
|
|
+ "city": sshClient.geoIPData.City,
|
|
|
+ "ISP": sshClient.geoIPData.ISP,
|
|
|
+ "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
+ "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
+ "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
+ "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
|
|
|
+ "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
+ "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
+ "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
+ "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
|
|
|
+ }).Info("tunnel closed")
|
|
|
+ sshClient.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
for newChannel := range channels {
|
|
|
|
|
|
@@ -652,11 +741,43 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
sshClient.tcpTrafficState,
|
|
|
sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
|
|
|
|
- sshClient.rejectNewChannel(
|
|
|
- newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
|
|
|
- return
|
|
|
+ // Close the oldest TCP port forward. CloseOldest() closes
|
|
|
+ // the conn and the port forward's goroutine will complete
|
|
|
+ // the cleanup asynchronously.
|
|
|
+ //
|
|
|
+ // Some known limitations:
|
|
|
+ //
|
|
|
+ // - Since CloseOldest() closes the upstream socket but does not
|
|
|
+ // clean up all resources associated with the port forward. These
|
|
|
+ // include the goroutine(s) relaying traffic as well as the SSH
|
|
|
+ // channel. Closing the socket will interrupt the goroutines which
|
|
|
+ // will then complete the cleanup. But, since the full cleanup is
|
|
|
+ // asynchronous, there exists a possibility that a client can consume
|
|
|
+ // more than max port forward resources -- just not upstream sockets.
|
|
|
+ //
|
|
|
+ // - An LRU list entry for this port forward is not added until
|
|
|
+ // after the dial completes, but the port forward is counted
|
|
|
+ // towards max limits. This means many dials in progress will
|
|
|
+ // put established connections in jeopardy.
|
|
|
+ //
|
|
|
+ // - We're closing the oldest open connection _before_ successfully
|
|
|
+ // dialing the new port forward. This means we are potentially
|
|
|
+ // discarding a good connection to make way for a failed connection.
|
|
|
+ // We cannot simply dial first and still maintain a limit on
|
|
|
+ // resources used, so to address this we'd need to add some
|
|
|
+ // accounting for connections still establishing.
|
|
|
+
|
|
|
+ sshClient.tcpPortForwardLRU.CloseOldest()
|
|
|
+
|
|
|
+ log.WithContextFields(
|
|
|
+ LogFields{
|
|
|
+ "maxCount": sshClient.trafficRules.MaxTCPPortForwardCount,
|
|
|
+ }).Debug("closed LRU TCP port forward")
|
|
|
}
|
|
|
|
|
|
+ // Dial the target remote address. This is done in a goroutine to
|
|
|
+ // ensure the shutdown signal is handled immediately.
|
|
|
+
|
|
|
remoteAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")
|
|
|
@@ -689,9 +810,25 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // The upstream TCP port forward connection has been established. Schedule
|
|
|
+ // some cleanup and notify the SSH client that the channel is accepted.
|
|
|
+
|
|
|
fwdConn := result.conn
|
|
|
defer fwdConn.Close()
|
|
|
|
|
|
+ lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
|
|
|
+ defer lruEntry.Remove()
|
|
|
+
|
|
|
+ // ActivityMonitoredConn monitors the TCP port forward I/O and updates
|
|
|
+ // its LRU status. ActivityMonitoredConn also times out read on the port
|
|
|
+ // forward if both reads and writes have been idle for the specified
|
|
|
+ // duration.
|
|
|
+ fwdConn = psiphon.NewActivityMonitoredConn(
|
|
|
+ fwdConn,
|
|
|
+ time.Duration(sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds)*time.Millisecond,
|
|
|
+ true,
|
|
|
+ lruEntry)
|
|
|
+
|
|
|
fwdChannel, requests, err := newChannel.Accept()
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
@@ -702,21 +839,9 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
|
log.WithContextFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")
|
|
|
|
|
|
- // When idle port forward traffic rules are in place, wrap fwdConn
|
|
|
- // in an IdleTimeoutConn configured to reset idle on writes as well
|
|
|
- // as read. This ensures the port forward idle timeout only happens
|
|
|
- // when both upstream and downstream directions are are idle.
|
|
|
-
|
|
|
- if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
|
|
|
- fwdConn = psiphon.NewIdleTimeoutConn(
|
|
|
- fwdConn,
|
|
|
- time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond,
|
|
|
- true)
|
|
|
- }
|
|
|
-
|
|
|
// Relay channel to forwarded connection.
|
|
|
- // TODO: relay errors to fwdChannel.Stderr()?
|
|
|
|
|
|
+ // TODO: relay errors to fwdChannel.Stderr()?
|
|
|
relayWaitGroup := new(sync.WaitGroup)
|
|
|
relayWaitGroup.Add(1)
|
|
|
go func() {
|
|
|
@@ -757,86 +882,3 @@ func (sshClient *sshClient) handleTCPChannel(
|
|
|
"bytesUp": atomic.LoadInt64(&bytesUp),
|
|
|
"bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting")
|
|
|
}
|
|
|
-
|
|
|
-func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
- var sshPasswordPayload struct {
|
|
|
- SessionId string `json:"SessionId"`
|
|
|
- SshPassword string `json:"SshPassword"`
|
|
|
- }
|
|
|
- err := json.Unmarshal(password, &sshPasswordPayload)
|
|
|
- if err != nil {
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password payload for %q", conn.User()))
|
|
|
- }
|
|
|
-
|
|
|
- userOk := (subtle.ConstantTimeCompare(
|
|
|
- []byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
|
|
|
-
|
|
|
- passwordOk := (subtle.ConstantTimeCompare(
|
|
|
- []byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
|
|
|
-
|
|
|
- if !userOk || !passwordOk {
|
|
|
- return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
|
|
|
- }
|
|
|
-
|
|
|
- psiphonSessionID := sshPasswordPayload.SessionId
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
- sshClient.psiphonSessionID = psiphonSessionID
|
|
|
- geoIPData := sshClient.geoIPData
|
|
|
- sshClient.Unlock()
|
|
|
-
|
|
|
- if sshClient.sshServer.config.UseRedis() {
|
|
|
- err = UpdateRedisForLegacyPsiWeb(psiphonSessionID, geoIPData)
|
|
|
- if err != nil {
|
|
|
- log.WithContextFields(LogFields{
|
|
|
- "psiphonSessionID": psiphonSessionID,
|
|
|
- "error": err}).Warning("UpdateRedisForLegacyPsiWeb failed")
|
|
|
- // Allow the connection to proceed; legacy psi_web will not get accurate GeoIP values.
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return nil, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
|
|
|
- if err != nil {
|
|
|
- if sshClient.sshServer.config.UseFail2Ban() {
|
|
|
- clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
|
|
|
- if clientIPAddress != "" {
|
|
|
- LogFail2Ban(clientIPAddress)
|
|
|
- }
|
|
|
- }
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
|
|
|
- } else {
|
|
|
- log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func (sshClient *sshClient) stop() {
|
|
|
-
|
|
|
- sshClient.sshConn.Close()
|
|
|
- sshClient.sshConn.Wait()
|
|
|
-
|
|
|
- close(sshClient.stopBroadcast)
|
|
|
- sshClient.channelHandlerWaitGroup.Wait()
|
|
|
-
|
|
|
- sshClient.Lock()
|
|
|
- log.WithContextFields(
|
|
|
- LogFields{
|
|
|
- "startTime": sshClient.startTime,
|
|
|
- "duration": time.Now().Sub(sshClient.startTime),
|
|
|
- "psiphonSessionID": sshClient.psiphonSessionID,
|
|
|
- "country": sshClient.geoIPData.Country,
|
|
|
- "city": sshClient.geoIPData.City,
|
|
|
- "ISP": sshClient.geoIPData.ISP,
|
|
|
- "bytesUpTCP": sshClient.tcpTrafficState.bytesUp,
|
|
|
- "bytesDownTCP": sshClient.tcpTrafficState.bytesDown,
|
|
|
- "peakConcurrentPortForwardCountTCP": sshClient.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- "totalPortForwardCountTCP": sshClient.tcpTrafficState.totalPortForwardCount,
|
|
|
- "bytesUpUDP": sshClient.udpTrafficState.bytesUp,
|
|
|
- "bytesDownUDP": sshClient.udpTrafficState.bytesDown,
|
|
|
- "peakConcurrentPortForwardCountUDP": sshClient.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
- "totalPortForwardCountUDP": sshClient.udpTrafficState.totalPortForwardCount,
|
|
|
- }).Info("tunnel closed")
|
|
|
- sshClient.Unlock()
|
|
|
-}
|