throttled.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package common
  20. import (
  21. "net"
  22. "sync"
  23. "sync/atomic"
  24. "time"
  25. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  26. "github.com/juju/ratelimit"
  27. )
  // RateLimits holds the throttling parameters applied by a ThrottledConn.
  // Read and write directions are configured independently.
  type RateLimits struct {

  	// ReadUnthrottledBytes is the approximate number of bytes that
  	// may be read before read rate limiting begins.
  	ReadUnthrottledBytes int64

  	// ReadBytesPerSecond is the read transfer rate limit. The
  	// default value, 0, means reads are not limited.
  	ReadBytesPerSecond int64

  	// WriteUnthrottledBytes is the approximate number of bytes that
  	// may be written before write rate limiting begins.
  	WriteUnthrottledBytes int64

  	// WriteBytesPerSecond is the write transfer rate limit. The
  	// default value, 0, means writes are not limited.
  	WriteBytesPerSecond int64

  	// CloseAfterExhausted, when set, causes the underlying net.Conn
  	// to be closed as soon as either the read or the write
  	// unthrottled byte count is used up. Throttling is never
  	// applied in this mode.
  	CloseAfterExhausted bool
  }
  48. // ThrottledConn wraps a net.Conn with read and write rate limiters.
  49. // Rates are specified as bytes per second. Optional unlimited byte
  50. // counts allow for a number of bytes to read or write before
  51. // applying rate limiting. Specify limit values of 0 to set no rate
  52. // limit (unlimited counts are ignored in this case).
  53. // The underlying rate limiter uses the token bucket algorithm to
  54. // calculate delay times for read and write operations.
  55. type ThrottledConn struct {
  56. // Note: 64-bit ints used with atomic operations are placed
  57. // at the start of struct to ensure 64-bit alignment.
  58. // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
  59. readUnthrottledBytes int64
  60. readBytesPerSecond int64
  61. writeUnthrottledBytes int64
  62. writeBytesPerSecond int64
  63. closeAfterExhausted int32
  64. readLock sync.Mutex
  65. readRateLimiter *ratelimit.Bucket
  66. readDelayTimer *time.Timer
  67. writeLock sync.Mutex
  68. writeRateLimiter *ratelimit.Bucket
  69. writeDelayTimer *time.Timer
  70. isClosed int32
  71. stopBroadcast chan struct{}
  72. net.Conn
  73. }
  74. // NewThrottledConn initializes a new ThrottledConn.
  75. func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
  76. throttledConn := &ThrottledConn{
  77. Conn: conn,
  78. stopBroadcast: make(chan struct{}),
  79. }
  80. throttledConn.SetLimits(limits)
  81. return throttledConn
  82. }
  83. // SetLimits modifies the rate limits of an existing
  84. // ThrottledConn. It is safe to call SetLimits while
  85. // other goroutines are calling Read/Write. This function
  86. // will not block, and the new rate limits will be
  87. // applied within Read/Write, but not necessarily until
  88. // some further I/O at previous rates.
  89. func (conn *ThrottledConn) SetLimits(limits RateLimits) {
  90. // Using atomic instead of mutex to avoid blocking
  91. // this function on throttled I/O in an ongoing
  92. // read or write. Precise synchronized application
  93. // of the rate limit values is not required.
  94. // Negative rates are invalid and -1 is a special
  95. // value to used to signal throttling initialized
  96. // state. Silently normalize negative values to 0.
  97. rate := limits.ReadBytesPerSecond
  98. if rate < 0 {
  99. rate = 0
  100. }
  101. atomic.StoreInt64(&conn.readBytesPerSecond, rate)
  102. atomic.StoreInt64(&conn.readUnthrottledBytes, limits.ReadUnthrottledBytes)
  103. rate = limits.WriteBytesPerSecond
  104. if rate < 0 {
  105. rate = 0
  106. }
  107. atomic.StoreInt64(&conn.writeBytesPerSecond, rate)
  108. atomic.StoreInt64(&conn.writeUnthrottledBytes, limits.WriteUnthrottledBytes)
  109. closeAfterExhausted := int32(0)
  110. if limits.CloseAfterExhausted {
  111. closeAfterExhausted = 1
  112. }
  113. atomic.StoreInt32(&conn.closeAfterExhausted, closeAfterExhausted)
  114. }
  115. func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
  116. // A mutex is used to ensure conformance with net.Conn concurrency semantics.
  117. // The atomic.SwapInt64 and subsequent assignment of readRateLimiter or
  118. // readDelayTimer could be a race condition with concurrent reads.
  119. conn.readLock.Lock()
  120. defer conn.readLock.Unlock()
  121. select {
  122. case <-conn.stopBroadcast:
  123. return 0, errors.TraceNew("throttled conn closed")
  124. default:
  125. }
  126. // Use the base conn until the unthrottled count is
  127. // exhausted. This is only an approximate enforcement
  128. // since this read, or concurrent reads, could exceed
  129. // the remaining count.
  130. if atomic.LoadInt64(&conn.readUnthrottledBytes) > 0 {
  131. n, err := conn.Conn.Read(buffer)
  132. atomic.AddInt64(&conn.readUnthrottledBytes, -int64(n))
  133. return n, err
  134. }
  135. if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
  136. conn.Conn.Close()
  137. return 0, errors.TraceNew("throttled conn exhausted")
  138. }
  139. rate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
  140. if rate != -1 {
  141. // SetLimits has been called and a new rate limiter
  142. // must be initialized. When no limit is specified,
  143. // the reader/writer is simply the base conn.
  144. // No state is retained from the previous rate limiter,
  145. // so a pending I/O throttle sleep may be skipped when
  146. // the old and new rate are similar.
  147. if rate == 0 {
  148. conn.readRateLimiter = nil
  149. } else {
  150. conn.readRateLimiter =
  151. ratelimit.NewBucketWithRate(float64(rate), rate)
  152. }
  153. }
  154. n, err := conn.Conn.Read(buffer)
  155. // Sleep to enforce the rate limit. This is the same logic as implemented in
  156. // ratelimit.Reader, but using a timer and a close signal instead of an
  157. // uninterruptible time.Sleep.
  158. //
  159. // The readDelayTimer is always expired/stopped and drained after this code
  160. // block and is ready to be Reset on the next call.
  161. if n >= 0 && conn.readRateLimiter != nil {
  162. sleepDuration := conn.readRateLimiter.Take(int64(n))
  163. if sleepDuration > 0 {
  164. if conn.readDelayTimer == nil {
  165. conn.readDelayTimer = time.NewTimer(sleepDuration)
  166. } else {
  167. conn.readDelayTimer.Reset(sleepDuration)
  168. }
  169. select {
  170. case <-conn.readDelayTimer.C:
  171. case <-conn.stopBroadcast:
  172. if !conn.readDelayTimer.Stop() {
  173. <-conn.readDelayTimer.C
  174. }
  175. }
  176. }
  177. }
  178. return n, errors.Trace(err)
  179. }
  180. func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
  181. // See comments in Read.
  182. conn.writeLock.Lock()
  183. defer conn.writeLock.Unlock()
  184. select {
  185. case <-conn.stopBroadcast:
  186. return 0, errors.TraceNew("throttled conn closed")
  187. default:
  188. }
  189. if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
  190. n, err := conn.Conn.Write(buffer)
  191. atomic.AddInt64(&conn.writeUnthrottledBytes, -int64(n))
  192. return n, err
  193. }
  194. if atomic.LoadInt32(&conn.closeAfterExhausted) == 1 {
  195. conn.Conn.Close()
  196. return 0, errors.TraceNew("throttled conn exhausted")
  197. }
  198. rate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
  199. if rate != -1 {
  200. if rate == 0 {
  201. conn.writeRateLimiter = nil
  202. } else {
  203. conn.writeRateLimiter =
  204. ratelimit.NewBucketWithRate(float64(rate), rate)
  205. }
  206. }
  207. if len(buffer) >= 0 && conn.writeRateLimiter != nil {
  208. sleepDuration := conn.writeRateLimiter.Take(int64(len(buffer)))
  209. if sleepDuration > 0 {
  210. if conn.writeDelayTimer == nil {
  211. conn.writeDelayTimer = time.NewTimer(sleepDuration)
  212. } else {
  213. conn.writeDelayTimer.Reset(sleepDuration)
  214. }
  215. select {
  216. case <-conn.writeDelayTimer.C:
  217. case <-conn.stopBroadcast:
  218. if !conn.writeDelayTimer.Stop() {
  219. <-conn.writeDelayTimer.C
  220. }
  221. }
  222. }
  223. }
  224. n, err := conn.Conn.Write(buffer)
  225. return n, errors.Trace(err)
  226. }
  227. func (conn *ThrottledConn) Close() error {
  228. // Ensure close channel only called once.
  229. if !atomic.CompareAndSwapInt32(&conn.isClosed, 0, 1) {
  230. return nil
  231. }
  232. close(conn.stopBroadcast)
  233. return errors.Trace(conn.Conn.Close())
  234. }