|
@@ -211,16 +211,20 @@ func (sshServer *sshServer) stopClient(client *sshClient) {
|
|
|
client.Lock()
|
|
client.Lock()
|
|
|
log.WithContextFields(
|
|
log.WithContextFields(
|
|
|
LogFields{
|
|
LogFields{
|
|
|
- "startTime": client.startTime,
|
|
|
|
|
- "duration": time.Now().Sub(client.startTime),
|
|
|
|
|
- "psiphonSessionID": client.psiphonSessionID,
|
|
|
|
|
- "country": client.geoIPData.Country,
|
|
|
|
|
- "city": client.geoIPData.City,
|
|
|
|
|
- "ISP": client.geoIPData.ISP,
|
|
|
|
|
- "bytesUp": client.bytesUp,
|
|
|
|
|
- "bytesDown": client.bytesDown,
|
|
|
|
|
- "portForwardCount": client.portForwardCount,
|
|
|
|
|
- "peakConcurrentPortForwardCount": client.peakConcurrentPortForwardCount,
|
|
|
|
|
|
|
+ "startTime": client.startTime,
|
|
|
|
|
+ "duration": time.Now().Sub(client.startTime),
|
|
|
|
|
+ "psiphonSessionID": client.psiphonSessionID,
|
|
|
|
|
+ "country": client.geoIPData.Country,
|
|
|
|
|
+ "city": client.geoIPData.City,
|
|
|
|
|
+ "ISP": client.geoIPData.ISP,
|
|
|
|
|
+ "bytesUpTCP": client.tcpTrafficState.bytesUp,
|
|
|
|
|
+ "bytesDownTCP": client.tcpTrafficState.bytesDown,
|
|
|
|
|
+ "portForwardCountTCP": client.tcpTrafficState.portForwardCount,
|
|
|
|
|
+ "peakConcurrentPortForwardCountTCP": client.tcpTrafficState.peakConcurrentPortForwardCount,
|
|
|
|
|
+ "bytesUpUDP": client.udpTrafficState.bytesUp,
|
|
|
|
|
+ "bytesDownUDP": client.udpTrafficState.bytesDown,
|
|
|
|
|
+ "portForwardCountUDP": client.udpTrafficState.portForwardCount,
|
|
|
|
|
+ "peakConcurrentPortForwardCountUDP": client.udpTrafficState.peakConcurrentPortForwardCount,
|
|
|
}).Info("tunnel closed")
|
|
}).Info("tunnel closed")
|
|
|
client.Unlock()
|
|
client.Unlock()
|
|
|
}
|
|
}
|
|
@@ -239,12 +243,16 @@ func (sshServer *sshServer) stopClients() {
|
|
|
|
|
|
|
|
func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
|
|
|
|
+ geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr()))
|
|
|
|
|
+
|
|
|
sshClient := &sshClient{
|
|
sshClient := &sshClient{
|
|
|
- sshServer: sshServer,
|
|
|
|
|
- startTime: time.Now(),
|
|
|
|
|
- geoIPData: GeoIPLookup(psiphon.IPAddressFromAddr(tcpConn.RemoteAddr())),
|
|
|
|
|
|
|
+ sshServer: sshServer,
|
|
|
|
|
+ startTime: time.Now(),
|
|
|
|
|
+ geoIPData: geoIPData,
|
|
|
|
|
+ trafficRules: sshServer.config.GetTrafficRules(geoIPData.Country),
|
|
|
|
|
+ tcpTrafficState: &trafficState{},
|
|
|
|
|
+ udpTrafficState: &trafficState{},
|
|
|
}
|
|
}
|
|
|
- sshClient.trafficRules = sshServer.config.GetTrafficRules(sshClient.geoIPData.Country)
|
|
|
|
|
|
|
|
|
|
// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
|
|
// Wrap the base TCP connection with an IdleTimeoutConn which will terminate
|
|
|
// the connection if no data is received before the deadline. This timeout is
|
|
// the connection if no data is received before the deadline. This timeout is
|
|
@@ -252,7 +260,16 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
// use the connection or send SSH keep alive requests to keep the connection
|
|
// use the connection or send SSH keep alive requests to keep the connection
|
|
|
// active.
|
|
// active.
|
|
|
|
|
|
|
|
- conn := psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
|
|
|
|
+ var conn net.Conn
|
|
|
|
|
+
|
|
|
|
|
+ conn = psiphon.NewIdleTimeoutConn(tcpConn, SSH_CONNECTION_READ_DEADLINE, false)
|
|
|
|
|
+
|
|
|
|
|
+ // Further wrap the connection in a rate limiting ThrottledConn.
|
|
|
|
|
+
|
|
|
|
|
+ conn = psiphon.NewThrottledConn(
|
|
|
|
|
+ conn,
|
|
|
|
|
+ int64(sshClient.trafficRules.LimitDownstreamBytesPerSecond),
|
|
|
|
|
+ int64(sshClient.trafficRules.LimitUpstreamBytesPerSecond))
|
|
|
|
|
|
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
|
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
// respect shutdownBroadcast and implement a specific handshake timeout.
|
|
@@ -334,12 +351,18 @@ func (sshServer *sshServer) handleClient(tcpConn *net.TCPConn) {
|
|
|
|
|
|
|
|
type sshClient struct {
|
|
type sshClient struct {
|
|
|
sync.Mutex
|
|
sync.Mutex
|
|
|
- sshServer *sshServer
|
|
|
|
|
- sshConn ssh.Conn
|
|
|
|
|
- startTime time.Time
|
|
|
|
|
- geoIPData GeoIPData
|
|
|
|
|
- trafficRules TrafficRules
|
|
|
|
|
- psiphonSessionID string
|
|
|
|
|
|
|
+ sshServer *sshServer
|
|
|
|
|
+ sshConn ssh.Conn
|
|
|
|
|
+ startTime time.Time
|
|
|
|
|
+ geoIPData GeoIPData
|
|
|
|
|
+ psiphonSessionID string
|
|
|
|
|
+ udpChannel ssh.Channel
|
|
|
|
|
+ trafficRules TrafficRules
|
|
|
|
|
+ tcpTrafficState *trafficState
|
|
|
|
|
+ udpTrafficState *trafficState
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+type trafficState struct {
|
|
|
bytesUp int64
|
|
bytesUp int64
|
|
|
bytesDown int64
|
|
bytesDown int64
|
|
|
portForwardCount int64
|
|
portForwardCount int64
|
|
@@ -355,20 +378,8 @@ func (sshClient *sshClient) handleChannels(channels <-chan ssh.NewChannel) {
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if sshClient.trafficRules.MaxClientPortForwardCount > 0 {
|
|
|
|
|
- sshClient.Lock()
|
|
|
|
|
- limitExceeded := sshClient.portForwardCount >= int64(sshClient.trafficRules.MaxClientPortForwardCount)
|
|
|
|
|
- sshClient.Unlock()
|
|
|
|
|
-
|
|
|
|
|
- if limitExceeded {
|
|
|
|
|
- sshClient.rejectNewChannel(
|
|
|
|
|
- newChannel, ssh.ResourceShortage, "maximum port forward limit exceeded")
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
// process each port forward concurrently
|
|
// process each port forward concurrently
|
|
|
- go sshClient.handleNewDirectTcpipChannel(newChannel)
|
|
|
|
|
|
|
+ go sshClient.handleNewPortForwardChannel(newChannel)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -383,7 +394,7 @@ func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, reason s
|
|
|
newChannel.Reject(reason, message)
|
|
newChannel.Reject(reason, message)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
|
+func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChannel) {
|
|
|
|
|
|
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
// http://tools.ietf.org/html/rfc4254#section-7.2
|
|
|
var directTcpipExtraData struct {
|
|
var directTcpipExtraData struct {
|
|
@@ -399,14 +410,109 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- targetAddr := fmt.Sprintf("%s:%d",
|
|
|
|
|
- directTcpipExtraData.HostToConnect,
|
|
|
|
|
- directTcpipExtraData.PortToConnect)
|
|
|
|
|
|
|
+ // Intercept TCP port forwards to a specified udpgw server and handle directly.
|
|
|
|
|
+ // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
|
|
|
|
|
+ isUDPChannel := sshClient.sshServer.config.UdpgwServerAddress != "" &&
|
|
|
|
|
+ sshClient.sshServer.config.UdpgwServerAddress ==
|
|
|
|
|
+ fmt.Sprintf("%s:%d",
|
|
|
|
|
+ directTcpipExtraData.HostToConnect,
|
|
|
|
|
+ directTcpipExtraData.PortToConnect)
|
|
|
|
|
+
|
|
|
|
|
+ if isUDPChannel {
|
|
|
|
|
+ sshClient.handleUDPChannel(newChannel)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ sshClient.handleTCPChannel(
|
|
|
|
|
+ directTcpipExtraData.HostToConnect, int(directTcpipExtraData.PortToConnect), newChannel)
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) isPortForwardPermitted(
|
|
|
|
|
+ port int, allowPorts []int, denyPorts []int) bool {
|
|
|
|
|
+
|
|
|
|
|
+ // TODO: faster lookup?
|
|
|
|
|
+ if allowPorts != nil {
|
|
|
|
|
+ for _, allowPort := range allowPorts {
|
|
|
|
|
+ if port == allowPort {
|
|
|
|
|
+ return true
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+ if denyPorts != nil {
|
|
|
|
|
+ for _, denyPort := range denyPorts {
|
|
|
|
|
+ if port == denyPort {
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return true
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) isPortForwardLimitExceeded(
|
|
|
|
|
+ state *trafficState, maxPortForwardCount int) bool {
|
|
|
|
|
+
|
|
|
|
|
+ limitExceeded := false
|
|
|
|
|
+ if maxPortForwardCount > 0 {
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ limitExceeded = state.portForwardCount >= int64(maxPortForwardCount)
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
+ }
|
|
|
|
|
+ return limitExceeded
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) establishedPortForward(
|
|
|
|
|
+ state *trafficState) {
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ state.portForwardCount += 1
|
|
|
|
|
+ state.concurrentPortForwardCount += 1
|
|
|
|
|
+ if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
|
|
|
|
|
+ state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
|
|
|
|
|
+ }
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) closedPortForward(
|
|
|
|
|
+ state *trafficState, bytesUp, bytesDown int64) {
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.Lock()
|
|
|
|
|
+ state.concurrentPortForwardCount -= 1
|
|
|
|
|
+ state.bytesUp += bytesUp
|
|
|
|
|
+ state.bytesDown += bytesDown
|
|
|
|
|
+ sshClient.Unlock()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (sshClient *sshClient) handleTCPChannel(
|
|
|
|
|
+ hostToConnect string,
|
|
|
|
|
+ portToConnect int,
|
|
|
|
|
+ newChannel ssh.NewChannel) {
|
|
|
|
|
+
|
|
|
|
|
+ if !sshClient.isPortForwardPermitted(
|
|
|
|
|
+ portToConnect,
|
|
|
|
|
+ sshClient.trafficRules.AllowTCPPorts,
|
|
|
|
|
+ sshClient.trafficRules.DenyTCPPorts) {
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
|
|
+ newChannel, ssh.Prohibited, "port forward not permitted")
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // TODO: close LRU connection (after successful Dial) instead of rejecting new connection?
|
|
|
|
|
+ if sshClient.isPortForwardLimitExceeded(
|
|
|
|
|
+ sshClient.tcpTrafficState,
|
|
|
|
|
+ sshClient.trafficRules.MaxTCPPortForwardCount) {
|
|
|
|
|
+
|
|
|
|
|
+ sshClient.rejectNewChannel(
|
|
|
|
|
+ newChannel, ssh.Prohibited, "maximum port forward limit exceeded")
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ targetAddr := fmt.Sprintf("%s:%d", hostToConnect, portToConnect)
|
|
|
|
|
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("dialing")
|
|
|
|
|
|
|
|
|
|
+ // TODO: on EADDRNOTAVAIL, temporarily suspend new clients
|
|
|
// TODO: port forward dial timeout
|
|
// TODO: port forward dial timeout
|
|
|
- // TODO: report ssh.ResourceShortage when appropriate
|
|
|
|
|
// TODO: IPv6 support
|
|
// TODO: IPv6 support
|
|
|
fwdConn, err := net.Dial("tcp4", targetAddr)
|
|
fwdConn, err := net.Dial("tcp4", targetAddr)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -420,21 +526,13 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
log.WithContextFields(LogFields{"error": err}).Warning("accept new channel failed")
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
+ go ssh.DiscardRequests(requests)
|
|
|
|
|
+ defer fwdChannel.Close()
|
|
|
|
|
|
|
|
- sshClient.Lock()
|
|
|
|
|
- sshClient.portForwardCount += 1
|
|
|
|
|
- sshClient.concurrentPortForwardCount += 1
|
|
|
|
|
- if sshClient.concurrentPortForwardCount > sshClient.peakConcurrentPortForwardCount {
|
|
|
|
|
- sshClient.peakConcurrentPortForwardCount = sshClient.concurrentPortForwardCount
|
|
|
|
|
- }
|
|
|
|
|
- sshClient.Unlock()
|
|
|
|
|
|
|
+ sshClient.establishedPortForward(sshClient.tcpTrafficState)
|
|
|
|
|
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("relaying")
|
|
|
|
|
|
|
|
- go ssh.DiscardRequests(requests)
|
|
|
|
|
-
|
|
|
|
|
- defer fwdChannel.Close()
|
|
|
|
|
-
|
|
|
|
|
// When idle port forward traffic rules are in place, wrap fwdConn
|
|
// When idle port forward traffic rules are in place, wrap fwdConn
|
|
|
// in an IdleTimeoutConn configured to reset idle on writes as well
|
|
// in an IdleTimeoutConn configured to reset idle on writes as well
|
|
|
// as read. This ensures the port forward idle timeout only happens
|
|
// as read. This ensures the port forward idle timeout only happens
|
|
@@ -449,6 +547,7 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
|
|
|
|
|
// relay channel to forwarded connection
|
|
// relay channel to forwarded connection
|
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
// TODO: relay errors to fwdChannel.Stderr()?
|
|
|
|
|
+ // TODO: use a low-memory io.Copy?
|
|
|
|
|
|
|
|
var bytesUp, bytesDown int64
|
|
var bytesUp, bytesDown int64
|
|
|
|
|
|
|
@@ -457,51 +556,23 @@ func (sshClient *sshClient) handleNewDirectTcpipChannel(newChannel ssh.NewChanne
|
|
|
go func() {
|
|
go func() {
|
|
|
defer relayWaitGroup.Done()
|
|
defer relayWaitGroup.Done()
|
|
|
var err error
|
|
var err error
|
|
|
- bytesUp, err = copyWithThrottle(
|
|
|
|
|
- fwdConn, fwdChannel, sshClient.trafficRules.ThrottleUpstreamSleepMilliseconds)
|
|
|
|
|
|
|
+ bytesUp, err = io.Copy(fwdConn, fwdChannel)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("upstream relay failed")
|
|
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("upstream TCP relay failed")
|
|
|
}
|
|
}
|
|
|
}()
|
|
}()
|
|
|
- bytesDown, err = copyWithThrottle(
|
|
|
|
|
- fwdChannel, fwdConn, sshClient.trafficRules.ThrottleDownstreamSleepMilliseconds)
|
|
|
|
|
|
|
+ bytesDown, err = io.Copy(fwdChannel, fwdConn)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- log.WithContextFields(LogFields{"error": err}).Warning("downstream relay failed")
|
|
|
|
|
|
|
+ log.WithContextFields(LogFields{"error": err}).Warning("downstream TCP relay failed")
|
|
|
}
|
|
}
|
|
|
fwdChannel.CloseWrite()
|
|
fwdChannel.CloseWrite()
|
|
|
relayWaitGroup.Wait()
|
|
relayWaitGroup.Wait()
|
|
|
|
|
|
|
|
- sshClient.Lock()
|
|
|
|
|
- sshClient.concurrentPortForwardCount -= 1
|
|
|
|
|
- sshClient.bytesUp += bytesUp
|
|
|
|
|
- sshClient.bytesDown += bytesDown
|
|
|
|
|
- sshClient.Unlock()
|
|
|
|
|
|
|
+ sshClient.closedPortForward(sshClient.tcpTrafficState, bytesUp, bytesDown)
|
|
|
|
|
|
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
log.WithContextFields(LogFields{"target": targetAddr}).Debug("exiting")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func copyWithThrottle(dst io.Writer, src io.Reader, throttleSleepMilliseconds int) (int64, error) {
|
|
|
|
|
- // TODO: use a low-memory io.Copy?
|
|
|
|
|
- if throttleSleepMilliseconds <= 0 {
|
|
|
|
|
- // No throttle
|
|
|
|
|
- return io.Copy(dst, src)
|
|
|
|
|
- }
|
|
|
|
|
- var totalBytes int64
|
|
|
|
|
- for {
|
|
|
|
|
- bytes, err := io.CopyN(dst, src, SSH_THROTTLED_PORT_FORWARD_MAX_COPY)
|
|
|
|
|
- totalBytes += bytes
|
|
|
|
|
- if err == io.EOF {
|
|
|
|
|
- err = nil
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return totalBytes, psiphon.ContextError(err)
|
|
|
|
|
- }
|
|
|
|
|
- time.Sleep(time.Duration(throttleSleepMilliseconds) * time.Millisecond)
|
|
|
|
|
- }
|
|
|
|
|
- return totalBytes, nil
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
|
var sshPasswordPayload struct {
|
|
var sshPasswordPayload struct {
|
|
|
SessionId string `json:"SessionId"`
|
|
SessionId string `json:"SessionId"`
|