Rate limiting refactor, race fixes, more tests
This commit is contained in:
parent
ccc2dd1128
commit
62140ec001
8 changed files with 241 additions and 117 deletions
|
@ -13,8 +13,17 @@ var ErrLimitReached = errors.New("limit reached")
|
|||
|
||||
// Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value
|
||||
type Limiter interface {
|
||||
// Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached
|
||||
Allow(n int64) error
|
||||
// Allow adds one to the limiters value, or returns false if the limit has been reached
|
||||
Allow() bool
|
||||
|
||||
// AllowN adds n to the limiters value, or returns false if the limit has been reached
|
||||
AllowN(n int64) bool
|
||||
|
||||
// Value returns the current internal limiter value
|
||||
Value() int64
|
||||
|
||||
// Reset resets the state of the limiter
|
||||
Reset()
|
||||
}
|
||||
|
||||
// FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
|
||||
|
@ -25,6 +34,8 @@ type FixedLimiter struct {
|
|||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var _ Limiter = (*FixedLimiter)(nil)
|
||||
|
||||
// NewFixedLimiter creates a new Limiter
|
||||
func NewFixedLimiter(limit int64) *FixedLimiter {
|
||||
return NewFixedLimiterWithValue(limit, 0)
|
||||
|
@ -38,16 +49,22 @@ func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
|
|||
}
|
||||
}
|
||||
|
||||
// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded after adding n, ErrLimitReached is returned.
|
||||
func (l *FixedLimiter) Allow(n int64) error {
|
||||
// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded, false is returned.
|
||||
func (l *FixedLimiter) Allow() bool {
|
||||
return l.AllowN(1)
|
||||
}
|
||||
|
||||
// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded after adding n, false is returned.
|
||||
func (l *FixedLimiter) AllowN(n int64) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.value+n > l.limit {
|
||||
return ErrLimitReached
|
||||
return false
|
||||
}
|
||||
l.value += n
|
||||
return nil
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns the current limiter value
|
||||
|
@ -66,12 +83,29 @@ func (l *FixedLimiter) Reset() {
|
|||
|
||||
// RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
|
||||
type RateLimiter struct {
|
||||
r rate.Limit
|
||||
b int
|
||||
value int64
|
||||
limiter *rate.Limiter
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var _ Limiter = (*RateLimiter)(nil)
|
||||
|
||||
// NewRateLimiter creates a new RateLimiter
|
||||
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
|
||||
return NewRateLimiterWithValue(r, b, 0)
|
||||
}
|
||||
|
||||
// NewRateLimiterWithValue creates a new RateLimiter with the given starting value.
|
||||
//
|
||||
// Note that the starting value only has informational value. It does not impact the underlying
|
||||
// value of the rate.Limiter.
|
||||
func NewRateLimiterWithValue(r rate.Limit, b int, value int64) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
r: r,
|
||||
b: b,
|
||||
value: value,
|
||||
limiter: rate.NewLimiter(r, b),
|
||||
}
|
||||
}
|
||||
|
@ -82,16 +116,40 @@ func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
|
|||
return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
|
||||
}
|
||||
|
||||
// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded after adding n, ErrLimitReached is returned.
|
||||
func (l *RateLimiter) Allow(n int64) error {
|
||||
// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded, false is returned.
|
||||
func (l *RateLimiter) Allow() bool {
|
||||
return l.AllowN(1)
|
||||
}
|
||||
|
||||
// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
|
||||
// exceeded after adding n, false is returned.
|
||||
func (l *RateLimiter) AllowN(n int64) bool {
|
||||
if n <= 0 {
|
||||
return nil // No-op. Can't take back bytes you're written!
|
||||
return false // No-op. Can't take back bytes you're written!
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if !l.limiter.AllowN(time.Now(), int(n)) {
|
||||
return ErrLimitReached
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
l.value += n
|
||||
return true
|
||||
}
|
||||
|
||||
// Value returns the current limiter value
|
||||
func (l *RateLimiter) Value() int64 {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.value
|
||||
}
|
||||
|
||||
// Reset sets the limiter's value back to zero, and resets the underlying rate.Limiter
|
||||
func (l *RateLimiter) Reset() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.limiter = rate.NewLimiter(l.r, l.b)
|
||||
l.value = 0
|
||||
}
|
||||
|
||||
// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
|
||||
|
@ -117,9 +175,9 @@ func (w *LimitWriter) Write(p []byte) (n int, err error) {
|
|||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
for i := 0; i < len(w.limiters); i++ {
|
||||
if err := w.limiters[i].Allow(int64(len(p))); err != nil {
|
||||
if !w.limiters[i].AllowN(int64(len(p))) {
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed
|
||||
w.limiters[j].AllowN(-int64(len(p))) // Revert limiters limits if not allowed
|
||||
}
|
||||
return 0, ErrLimitReached
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue