package util

import (
	"errors"
	"golang.org/x/time/rate"
	"io"
	"sync"
	"time"
)

// ErrLimitReached is the error returned by LimitWriter.Write when one of its limiters rejects a write.
// Limiter implementations themselves never return it; they signal rejection by returning false.
var ErrLimitReached = errors.New("limit reached")

// Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value.
// Implementations report whether an addition is allowed by returning a bool; they do not return errors.
type Limiter interface {
	// Allow adds one to the limiter's value, or returns false if the limit has been reached
	Allow() bool

	// AllowN adds n to the limiter's value, or returns false if the limit has been reached
	AllowN(n int64) bool

	// Value returns the current internal limiter value
	Value() int64

	// Reset resets the state of the limiter
	Reset()
}

// FixedLimiter is a helper that allows adding values up to a well-defined limit. Once an addition would
// exceed the limit, Allow and AllowN return false and leave the value unchanged. FixedLimiter may be used
// by multiple goroutines; all access to its state is guarded by a mutex.
type FixedLimiter struct {
	value int64
	limit int64
	mu    sync.Mutex
}

var _ Limiter = (*FixedLimiter)(nil)

// NewFixedLimiter creates a new FixedLimiter with the given limit and an initial value of zero
func NewFixedLimiter(limit int64) *FixedLimiter {
	return NewFixedLimiterWithValue(limit, 0)
}

// NewFixedLimiterWithValue creates a new FixedLimiter with the given limit and initial value.
// The initial value is not validated against the limit.
func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
	return &FixedLimiter{
		limit: limit,
		value: value,
	}
}

// Allow adds one to the limiter's internal value, but only if the result does not exceed the limit.
// Otherwise it returns false and the value is left unchanged.
func (l *FixedLimiter) Allow() bool {
	return l.AllowN(1)
}

// AllowN adds n to the limiter's internal value, but only if the result does not exceed the limit.
// Otherwise, or if value+n would overflow int64, it returns false and the value is left unchanged.
// A negative n lowers the value, which is how LimitWriter reverts a charge for a rejected write.
func (l *FixedLimiter) AllowN(n int64) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	next := l.value + n
	// Adding a positive n can only make the sum smaller (and a negative n larger) if it wrapped
	// around; without this check a huge n would wrap, pass the limit test and corrupt the value.
	overflowed := (n > 0 && next < l.value) || (n < 0 && next > l.value)
	if overflowed || next > l.limit {
		return false
	}
	l.value = next
	return true
}

// Value returns the current limiter value
func (l *FixedLimiter) Value() int64 {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.value
}

// Reset sets the limiter's value back to zero
func (l *FixedLimiter) Reset() {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.value = 0
}

// RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
// RateLimiter may be used by multiple goroutines; its own state is guarded by a mutex.
type RateLimiter struct {
	r       rate.Limit
	b       int
	value   int64
	limiter *rate.Limiter
	mu      sync.Mutex
}

var _ Limiter = (*RateLimiter)(nil)

// NewRateLimiter creates a new RateLimiter with rate r and burst size b
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
	return NewRateLimiterWithValue(r, b, 0)
}

// NewRateLimiterWithValue creates a new RateLimiter with the given starting value.
//
// Note that the starting value only has informational value. It does not impact the underlying
// value of the rate.Limiter.
func NewRateLimiterWithValue(r rate.Limit, b int, value int64) *RateLimiter {
	return &RateLimiter{
		r:       r,
		b:       b,
		value:   value,
		limiter: rate.NewLimiter(r, b),
	}
}

// NewBytesLimiter creates a RateLimiter that is meant to be used for a bytes-per-interval limit,
// e.g. 250 MB per day. The underlying rate.Limiter refills at `bytes` per interval and has a burst
// size of `bytes`. An example of the underlying idea can be found here: https://go.dev/play/p/0ljgzIZQ6dJ
func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
	return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
}

// Allow adds one to the limiter's internal value, but only if the underlying rate.Limiter allows it
// right now. Otherwise false is returned.
func (l *RateLimiter) Allow() bool {
	return l.AllowN(1)
}

// AllowN adds n to the limiter's internal value, but only if the underlying rate.Limiter allows n
// events right now. Otherwise false is returned.
//
// Values of n <= 0 are always rejected and change nothing; in particular a negative n does not
// refund tokens to the rate.Limiter.
func (l *RateLimiter) AllowN(n int64) bool {
	if n <= 0 {
		return false // No-op. Can't take back bytes you've written!
	}
	if int64(int(n)) != n {
		// n does not fit in an int (possible on 32-bit platforms); converting it would silently
		// truncate it to a much smaller request, so reject it instead.
		return false
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	if !l.limiter.AllowN(time.Now(), int(n)) {
		return false
	}
	l.value += n
	return true
}

// Value returns the current limiter value
func (l *RateLimiter) Value() int64 {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.value
}

// Reset sets the limiter's value back to zero, and resets the underlying rate.Limiter by replacing it
// with a fresh one built from the original rate and burst size
func (l *RateLimiter) Reset() {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.limiter = rate.NewLimiter(l.r, l.b)
	l.value = 0
}

// LimitWriter implements an io.Writer that passes all Write calls through to the underlying writer w
// until any of its limiters' limits is reached, at which point Write returns ErrLimitReached.
// Each limiter's value is increased by the size of every accepted write, and the number of bytes
// reported written by w is accumulated in written. Write calls are serialized by a mutex.
type LimitWriter struct {
	w        io.Writer
	written  int64
	limiters []Limiter
	mu       sync.Mutex
}

// NewLimitWriter creates a new LimitWriter that writes to w and enforces all given limiters.
// With no limiters, every write is passed through unchanged.
func NewLimitWriter(w io.Writer, limiters ...Limiter) *LimitWriter {
	return &LimitWriter{
		w:        w,
		limiters: limiters,
	}
}

// Write passes p through to the underlying writer if every limiter accepts len(p) more bytes.
// If a limiter rejects the write, the limiters that already accepted it are asked to revert the
// charge (via AllowN with a negative value), nothing is written, and (0, ErrLimitReached) is returned.
// Zero-length writes consume no quota and are passed through without consulting the limiters.
func (w *LimitWriter) Write(p []byte) (n int, err error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	// Skipping the limiters for empty writes matters: RateLimiter.AllowN rejects n <= 0, so
	// consulting it would turn every harmless empty write into ErrLimitReached.
	if size := int64(len(p)); size > 0 {
		for i, limiter := range w.limiters {
			if limiter.AllowN(size) {
				continue
			}
			// Best-effort rollback of the limiters that already accepted this write. Limiters
			// that ignore negative values (such as RateLimiter) keep the charge.
			for j := i - 1; j >= 0; j-- {
				w.limiters[j].AllowN(-size)
			}
			return 0, ErrLimitReached
		}
	}
	n, err = w.w.Write(p)
	w.written += int64(n)
	return n, err
}