|
|
@@ -211,16 +211,16 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
|
|
|
client.Lock()
|
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
|
- "startTime": client.startTime,
|
|
|
- "duration": time.Now().Sub(client.startTime),
|
|
|
- "psiphonSessionID": client.psiphonSessionID,
|
|
|
- "country": client.geoIPData.Country,
|
|
|
- "city": client.geoIPData.City,
|
|
|
- "ISP": client.geoIPData.ISP,
|
|
|
- "bytesUp": client.bytesUp,
|
|
|
- "bytesDown": client.bytesDown,
|
|
|
- "portForwardCount": client.portForwardCount,
|
|
|
- "maxConcurrentPortForwardCount": client.maxConcurrentPortForwardCount,
|
|
|
+ "startTime": client.startTime,
|
|
|
+ "duration": time.Now().Sub(client.startTime),
|
|
|
+ "psiphonSessionID": client.psiphonSessionID,
|
|
|
+ "country": client.geoIPData.Country,
|
|
|
+ "city": client.geoIPData.City,
|
|
|
+ "ISP": client.geoIPData.ISP,
|
|
|
+ "bytesUp": client.bytesUp,
|
|
|
+ "bytesDown": client.bytesDown,
|
|
|
+ "portForwardCount": client.portForwardCount,
|
|
|
+ "peakConcurrentPortForwardCount": client.peakConcurrentPortForwardCount,
|
|
|
}).Info("tunnel closed")
|
|
|
client.Unlock()
|
|
|
}
|
|
|
@@ -244,14 +244,15 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
startTime: time.Now(),
|
|
|
geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
|
|
|
}
|
|
|
+ sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
|
|
|
|
|
|
- // Wrap the base TCP connection in a TimeoutTCPConn which will terminate
|
|
|
+ // Wrap the base TCP connection with an IdleTimeoutConn which will terminate
|
|
|
// the connection if it's idle for too long. This timeout is in effect for
|
|
|
// the entire duration of the SSH connection. Clients must actively use
|
|
|
// the connection or send SSH keep alive requests to keep the connection
|
|
|
// active.
|
|
|
|
|
|
- conn := psiphon.NewTimeoutTCPConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
|
|
|
+ conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE)
|
|
|
|
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
|
@@ -320,7 +321,7 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
clientID, ok := sshServer.registerClient(sshClient)
|
|
|
if !ok {
|
|
|
- tcpConn.Close()
|
|
|
+ conn.Close()
|
|
|
log.WithContext().Warning("register failed")
|
|
|
return
|
|
|
}
|
|
|
@@ -333,16 +334,17 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
type sshClient struct {
|
|
|
sync.Mutex
|
|
|
- sshServer *sshServer
|
|
|
- sshConn ssh.Conn
|
|
|
- startTime time.Time
|
|
|
- geoIPData GeoIPData
|
|
|
- psiphonSessionID string
|
|
|
- bytesUp int64
|
|
|
- bytesDown int64
|
|
|
- portForwardCount int64
|
|
|
- concurrentPortForwardCount int64
|
|
|
- maxConcurrentPortForwardCount int64
|
|
|
+ sshServer *sshServer
|
|
|
+ sshConn ssh.Conn
|
|
|
+ startTime time.Time
|
|
|
+ geoIPData GeoIPData
|
|
|
+ trafficRules TrafficRules
|
|
|
+ psiphonSessionID string
|
|
|
+ bytesUp int64
|
|
|
+ bytesDown int64
|
|
|
+ portForwardCount int64
|
|
|
+ concurrentPortForwardCount int64
|
|
|
+ peakConcurrentPortForwardCount int64
|
|
|
}
|
|
|
|
|
|
func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
@@ -353,6 +355,18 @@ func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ if sshClient.trafficRules.MaxClientPortForwardCount > 0 {
|
|
|
+ sshClient.Lock()
|
|
|
+ limitExceeded := sshClient.portForwardCount >= int64(sshClient.trafficRules.MaxClientPortForwardCount)
|
|
|
+ sshClient.Unlock()
|
|
|
+
|
|
|
+ if limitExceeded {
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
+ newChannel, ssh.ResourceShortage, "maximum port forward limit exceeded")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// process each port forward concurrently
|
|
|
go sshClient.handleNewDirectTcpipChannel(newChannel)
|
|
|
}
|
|
|
@@ -410,8 +424,8 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
sshClient.Lock()
|
|
|
sshClient.portForwardCount += 1
|
|
|
sshClient.concurrentPortForwardCount += 1
|
|
|
- if sshClient.concurrentPortForwardCount > sshClient.maxConcurrentPortForwardCount {
|
|
|
- sshClient.maxConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
|
|
|
+ if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount {
|
|
|
+ sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
|
|
|
}
|
|
|
sshClient.Unlock()
|
|
|
|
|
|
@@ -421,9 +435,20 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
|
|
|
defer fwdChannel.Close()
|
|
|
|
|
|
- // relay channel to forwarded connection
|
|
|
+ // TODO: Fix -- fwdChannel is a ssh.Channel which is not a net.Conn, which
|
|
|
+ // NewJointIdleTimeoutConn expects.
|
|
|
+ /*
|
|
|
+ // When idle port forward traffic rules are in place, wrap each end of the
|
|
|
+ // port forward in peer JointIdleTimeoutConns which triggers an idle
|
|
|
+ // timeout only if both ends are (read) idle.
|
|
|
+ if sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds > 0 {
|
|
|
+ fwdConn, fwdChannel = psiphon.NewJointIdleTimeoutConn(
|
|
|
+ fwdConn, fwdChannel,
|
|
|
+ time.Duration(sshClient.trafficRules.IdlePortForwardTimeoutMilliseconds)*time.Millisecond)
|
|
|
+ }
|
|
|
+ */
|
|
|
|
|
|
- // TODO: use a low-memory io.Copy?
|
|
|
+ // relay channel to forwarded connection
|
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
|
|
|
|
var bytesUp, bytesDown int64
|
|
|
@@ -433,12 +458,14 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
go func() {
|
|
|
defer relayWaitGroup.Done()
|
|
|
var err error
|
|
|
- bytesUp, err = io.Copy(fwdConn, fwdChannel)
|
|
|
+ bytesUp, err = copyWithThrottle(
|
|
|
+ fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds)
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
|
|
|
}
|
|
|
}()
|
|
|
- bytesDown, err = io.Copy(fwdChannel, fwdConn)
|
|
|
+ bytesDown, err = copyWithThrottle(
|
|
|
+ fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds)
|
|
|
if err != nil {
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
|
|
|
}
|
|
|
@@ -454,6 +481,28 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
|
}
|
|
|
|
|
|
+func copyWithThrottle(dst io.Writer, src io.Reader, throttleSleepMilliseconds int) (int64, error) {
|
|
|
+ // TODO: use a low-memory io.Copy?
|
|
|
+ if throttleSleepMilliseconds <= 0 {
|
|
|
+ // No throttle
|
|
|
+ return io.Copy(dst, src)
|
|
|
+ }
|
|
|
+ var totalBytes int64
|
|
|
+ for {
|
|
|
+ bytes, err := io.CopyN(dst, src, SSH_THROTTLED_PORT_FORWARD_MAX_COPY)
|
|
|
+ totalBytes += bytes
|
|
|
+ if err == io.EOF {
|
|
|
+ err = nil
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return totalBytes, psiphon.ContextError(err)
|
|
|
+ }
|
|
|
+ time.Sleep(time.Duration(throttleSleepMilliseconds) * time.Millisecond)
|
|
|
+ }
|
|
|
+ return totalBytes, nil
|
|
|
+}
|
|
|
+
|
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
var sshPasswordPayload struct {
|
|
|
SessionId string `json:"SessionId"`
|