throttled.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package common
  20. import (
  21. "net"
  22. "sync"
  23. "sync/atomic"
  24. "time"
  25. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  26. "golang.org/x/time/rate"
  27. )
  // RateLimits holds the throttling parameters applied by a ThrottledConn.
  //
  // Each direction has an unthrottled byte allowance and a bytes-per-second
  // cap. A cap of 0, the default, means that direction is not rate limited.
  type RateLimits struct {

  	// ReadUnthrottledBytes is the approximate number of bytes that
  	// may be read before read rate limiting begins.
  	ReadUnthrottledBytes int64

  	// ReadBytesPerSecond is the rate limit for read data transfer.
  	// The default, 0, is no limit.
  	ReadBytesPerSecond int64

  	// WriteUnthrottledBytes is the approximate number of bytes that
  	// may be written before write rate limiting begins.
  	WriteUnthrottledBytes int64

  	// WriteBytesPerSecond is the rate limit for write data transfer.
  	// The default, 0, is no limit.
  	WriteBytesPerSecond int64

  	// CloseAfterExhausted, when set, causes the underlying net.Conn
  	// to be closed once either the read or the write unthrottled
  	// bytes have been exhausted. In this case, throttling is never
  	// applied.
  	CloseAfterExhausted bool
  }
  48. // ThrottledConn wraps a net.Conn with read and write rate limiters.
  49. // Rates are specified as bytes per second. Optional unlimited byte
  50. // counts allow for a number of bytes to read or write before
  51. // applying rate limiting. Specify limit values of 0 to set no rate
  52. // limit (unlimited counts are ignored in this case).
  53. // The underlying rate limiter uses the token bucket algorithm to
  54. // calculate delay times for read and write operations.
  55. type ThrottledConn struct {
  56. net.Conn
  57. readUnthrottledBytes atomic.Int64
  58. readBytesPerSecond atomic.Int64
  59. writeUnthrottledBytes atomic.Int64
  60. writeBytesPerSecond atomic.Int64
  61. closeAfterExhausted int32
  62. readLock sync.Mutex
  63. readRateLimiter *rate.Limiter
  64. readDelayTimer *time.Timer
  65. writeLock sync.Mutex
  66. writeRateLimiter *rate.Limiter
  67. writeDelayTimer *time.Timer
  68. isClosed int32
  69. stopBroadcast chan struct{}
  70. isStream bool
  71. }
  72. // NewThrottledConn initializes a new ThrottledConn.
  73. //
  74. // Set isStreamConn to true when conn is stream-oriented, such as TCP, and
  75. // false when the conn is packet-oriented, such as UDP. When conn is a
  76. // stream, reads and writes may be split to accomodate rate limits.
  77. func NewThrottledConn(
  78. conn net.Conn, isStream bool, limits RateLimits) *ThrottledConn {
  79. throttledConn := &ThrottledConn{
  80. Conn: conn,
  81. isStream: isStream,
  82. stopBroadcast: make(chan struct{}),
  83. }
  84. throttledConn.SetLimits(limits)
  85. return throttledConn
  86. }
  87. // SetLimits modifies the rate limits of an existing
  88. // ThrottledConn. It is safe to call SetLimits while
  89. // other goroutines are calling Read/Write. This function
  90. // will not block, and the new rate limits will be
  91. // applied within Read/Write, but not necessarily until
  92. // some further I/O at previous rates.
  93. func (conn *ThrottledConn) SetLimits(limits RateLimits) {
  94. // Using atomic instead of mutex to avoid blocking
  95. // this function on throttled I/O in an ongoing
  96. // read or write. Precise synchronized application
  97. // of the rate limit values is not required.
  98. // Negative rates are invalid and -1 is a special
  99. // value to used to signal throttling initialized
  100. // state. Silently normalize negative values to 0.
  101. rate := limits.ReadBytesPerSecond
  102. if rate < 0 {
  103. rate = 0
  104. }
  105. conn.readBytesPerSecond.Store(rate)
  106. conn.readUnthrottledBytes.Store(limits.ReadUnthrottledBytes)
  107. rate = limits.WriteBytesPerSecond
  108. if rate < 0 {
  109. rate = 0
  110. }
  111. conn.writeBytesPerSecond.Store(rate)
  112. conn.writeUnthrottledBytes.Store(limits.WriteUnthrottledBytes)
  113. closeAfterExhausted := int32(0)
  114. if limits.CloseAfterExhausted {
  115. closeAfterExhausted = 1
  116. }
  117. atomic.StoreInt32(&conn.closeAfterExhausted, closeAfterExhausted)
  118. }
  119. func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
  120. // A mutex is used to ensure conformance with net.Conn concurrency semantics.
  121. // The atomic.SwapInt64 and subsequent assignment of readRateLimiter or
  122. // readDelayTimer could be a race condition with concurrent reads.
  123. conn.readLock.Lock()
  124. defer conn.readLock.Unlock()
  125. if atomic.LoadInt32(&conn.isClosed) == 1 {
  126. return 0, errors.TraceNew("throttled conn closed")
  127. }
  128. // Use the base conn until the unthrottled count is
  129. // exhausted. This is only an approximate enforcement
  130. // since this read, or concurrent reads, could exceed
  131. // the remaining count.
  132. if conn.readUnthrottledBytes.Load() > 0 {
  133. n, err := conn.Conn.Read(buffer)
  134. conn.readUnthrottledBytes.Add(-int64(n))
  135. return n, err
  136. }
  137. if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
  138. conn.Conn.Close()
  139. return 0, errors.TraceNew("throttled conn exhausted")
  140. }
  141. readRate := conn.readBytesPerSecond.Swap(-1)
  142. if readRate != -1 {
  143. // SetLimits has been called and a new rate limiter
  144. // must be initialized. When no limit is specified,
  145. // the reader/writer is simply the base conn.
  146. // No state is retained from the previous rate limiter,
  147. // so a pending I/O throttle sleep may be skipped when
  148. // the old and new rate are similar.
  149. if readRate == 0 {
  150. conn.readRateLimiter = nil
  151. } else {
  152. conn.readRateLimiter =
  153. rate.NewLimiter(rate.Limit(readRate), int(readRate))
  154. }
  155. }
  156. // The number of bytes read cannot exceed the rate limiter burst size,
  157. // which is enforced by rate.Limiter.ReserveN. Reduce any read buffer
  158. // size to be at most the burst size.
  159. //
  160. // Read should still return as soon as read bytes are available; and the
  161. // number of bytes that will be received is unknown; so there is no loop
  162. // here to read more bytes. Reducing the read buffer size minimizes
  163. // latency for the up-to-burst-size bytes read, whereas allowing a full
  164. // read followed by multiple ReserveN calls and sleeps would increase
  165. // latency.
  166. //
  167. // In practise, with Psiphon tunnels, throttling is not applied until
  168. // after the Psiphon API handshake, so read buffer reductions won't
  169. // impact early obfuscation traffic shaping; and reads are on the order
  170. // of one SSH "packet", up to 32K, unlikely to be split for all but the
  171. // most restrictive of rate limits.
  172. if conn.readRateLimiter != nil {
  173. burst := conn.readRateLimiter.Burst()
  174. if len(buffer) > burst {
  175. if !conn.isStream {
  176. return 0, errors.TraceNew("non-stream read buffer exceeds burst")
  177. }
  178. buffer = buffer[:burst]
  179. }
  180. }
  181. n, err := conn.Conn.Read(buffer)
  182. if n > 0 && conn.readRateLimiter != nil {
  183. // While rate.Limiter.WaitN would be simpler to use, internally Wait
  184. // creates a new timer for every call which must sleep, which is
  185. // expected to be most calls. Instead, call ReserveN to get the sleep
  186. // time and reuse one timer without allocation.
  187. //
  188. // TODO: avoid allocation: ReserveN allocates a *Reservation; while
  189. // the internal reserveN returns a struct, not a pointer.
  190. reservation := conn.readRateLimiter.ReserveN(time.Now(), n)
  191. if !reservation.OK() {
  192. // This error is not expected, given the buffer size adjustment.
  193. return 0, errors.TraceNew("burst size exceeded")
  194. }
  195. sleepDuration := reservation.Delay()
  196. if sleepDuration > 0 {
  197. if conn.readDelayTimer == nil {
  198. conn.readDelayTimer = time.NewTimer(sleepDuration)
  199. } else {
  200. conn.readDelayTimer.Reset(sleepDuration)
  201. }
  202. select {
  203. case <-conn.readDelayTimer.C:
  204. case <-conn.stopBroadcast:
  205. if !conn.readDelayTimer.Stop() {
  206. <-conn.readDelayTimer.C
  207. }
  208. }
  209. }
  210. }
  211. // Don't wrap I/O errors
  212. return n, err
  213. }
  214. func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
  215. // See comments in Read.
  216. conn.writeLock.Lock()
  217. defer conn.writeLock.Unlock()
  218. if atomic.LoadInt32(&conn.isClosed) == 1 {
  219. return 0, errors.TraceNew("throttled conn closed")
  220. }
  221. if conn.writeUnthrottledBytes.Load() > 0 {
  222. n, err := conn.Conn.Write(buffer)
  223. conn.writeUnthrottledBytes.Add(-int64(n))
  224. return n, err
  225. }
  226. if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
  227. conn.Conn.Close()
  228. return 0, errors.TraceNew("throttled conn exhausted")
  229. }
  230. writeRate := conn.writeBytesPerSecond.Swap(-1)
  231. if writeRate != -1 {
  232. if writeRate == 0 {
  233. conn.writeRateLimiter = nil
  234. } else {
  235. conn.writeRateLimiter =
  236. rate.NewLimiter(rate.Limit(writeRate), int(writeRate))
  237. }
  238. }
  239. if conn.writeRateLimiter == nil {
  240. n, err := conn.Conn.Write(buffer)
  241. // Don't wrap I/O errors
  242. return n, err
  243. }
  244. // The number of bytes written cannot exceed the rate limiter burst size,
  245. // which is enforced by rate.Limiter.ReserveN. Split writes to be at most
  246. // the burst size.
  247. //
  248. // Splitting writes may have some effect on the shape of TCP packets sent
  249. // on the network.
  250. //
  251. // In practise, with Psiphon tunnels, throttling is not applied until
  252. // after the Psiphon API handshake, so write splits won't impact early
  253. // obfuscation traffic shaping; and writes are on the order of one
  254. // SSH "packet", up to 32K, unlikely to be split for all but the most
  255. // restrictive of rate limits.
  256. burst := conn.writeRateLimiter.Burst()
  257. if !conn.isStream && len(buffer) > burst {
  258. return 0, errors.TraceNew("non-stream write exceeds burst")
  259. }
  260. totalWritten := 0
  261. for i := 0; i < len(buffer); i += burst {
  262. j := i + burst
  263. if j > len(buffer) {
  264. j = len(buffer)
  265. }
  266. b := buffer[i:j]
  267. // See comment in Read regarding rate.Limiter.ReserveN vs.
  268. // rate.Limiter.WaitN.
  269. reservation := conn.writeRateLimiter.ReserveN(time.Now(), len(b))
  270. if !reservation.OK() {
  271. // This error is not expected, given the write split adjustments.
  272. return 0, errors.TraceNew("burst size exceeded")
  273. }
  274. sleepDuration := reservation.Delay()
  275. if sleepDuration > 0 {
  276. if conn.writeDelayTimer == nil {
  277. conn.writeDelayTimer = time.NewTimer(sleepDuration)
  278. } else {
  279. conn.writeDelayTimer.Reset(sleepDuration)
  280. }
  281. select {
  282. case <-conn.writeDelayTimer.C:
  283. case <-conn.stopBroadcast:
  284. if !conn.writeDelayTimer.Stop() {
  285. <-conn.writeDelayTimer.C
  286. }
  287. }
  288. }
  289. n, err := conn.Conn.Write(b)
  290. totalWritten += n
  291. if err != nil {
  292. // Don't wrap I/O errors
  293. return totalWritten, err
  294. }
  295. }
  296. return totalWritten, nil
  297. }
  298. func (conn *ThrottledConn) Close() error {
  299. // Ensure close channel only called once.
  300. if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
  301. return nil
  302. }
  303. close(conn.stopBroadcast)
  304. return errors.Trace(conn.Conn.Close())
  305. }