Rate limits make sense now!
This commit is contained in:
parent
a036814d98
commit
c874a641df
17 changed files with 365 additions and 205 deletions
|
@ -77,6 +77,7 @@ var flagsServe = append(
|
||||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
|
||||||
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
|
||||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
||||||
|
@ -150,6 +151,7 @@ func execServe(c *cli.Context) error {
|
||||||
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
|
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
|
||||||
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
|
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
|
||||||
visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
|
visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
|
||||||
|
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
|
||||||
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
|
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
|
||||||
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
|
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
|
||||||
behindProxy := c.Bool("behind-proxy")
|
behindProxy := c.Bool("behind-proxy")
|
||||||
|
@ -289,6 +291,7 @@ func execServe(c *cli.Context) error {
|
||||||
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
||||||
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
||||||
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
|
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
|
||||||
|
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
|
||||||
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
|
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
|
||||||
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
|
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
|
||||||
conf.BehindProxy = behindProxy
|
conf.BehindProxy = behindProxy
|
||||||
|
|
|
@ -44,6 +44,7 @@ const (
|
||||||
DefaultVisitorSubscriptionLimit = 30
|
DefaultVisitorSubscriptionLimit = 30
|
||||||
DefaultVisitorRequestLimitBurst = 60
|
DefaultVisitorRequestLimitBurst = 60
|
||||||
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
||||||
|
DefaultVisitorMessageDailyLimit = 0
|
||||||
DefaultVisitorEmailLimitBurst = 16
|
DefaultVisitorEmailLimitBurst = 16
|
||||||
DefaultVisitorEmailLimitReplenish = time.Hour
|
DefaultVisitorEmailLimitReplenish = time.Hour
|
||||||
DefaultVisitorAccountCreationLimitBurst = 3
|
DefaultVisitorAccountCreationLimitBurst = 3
|
||||||
|
@ -105,6 +106,7 @@ type Config struct {
|
||||||
VisitorRequestLimitBurst int
|
VisitorRequestLimitBurst int
|
||||||
VisitorRequestLimitReplenish time.Duration
|
VisitorRequestLimitReplenish time.Duration
|
||||||
VisitorRequestExemptIPAddrs []netip.Prefix
|
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||||
|
VisitorMessageDailyLimit int
|
||||||
VisitorEmailLimitBurst int
|
VisitorEmailLimitBurst int
|
||||||
VisitorEmailLimitReplenish time.Duration
|
VisitorEmailLimitReplenish time.Duration
|
||||||
VisitorAccountCreationLimitBurst int
|
VisitorAccountCreationLimitBurst int
|
||||||
|
@ -171,6 +173,7 @@ func NewConfig() *Config {
|
||||||
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||||
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||||
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||||
|
VisitorMessageDailyLimit: DefaultVisitorMessageDailyLimit,
|
||||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||||
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
||||||
|
|
|
@ -75,10 +75,10 @@ var (
|
||||||
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
||||||
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
|
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
|
||||||
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
||||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
||||||
|
|
|
@ -38,10 +38,9 @@ import (
|
||||||
TODO
|
TODO
|
||||||
--
|
--
|
||||||
|
|
||||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
|
||||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||||
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
|
|
||||||
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
||||||
|
- MEDIUM Rate limiting: Test daily message quota read from database initially
|
||||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||||
|
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
|
||||||
|
|
||||||
Tests:
|
Tests:
|
||||||
- Payment endpoints (make mocks)
|
- Payment endpoints (make mocks)
|
||||||
- test that the visitor is based on the IP address when a user has no tier
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
|
@ -308,7 +306,7 @@ func (s *Server) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
|
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
||||||
if log.IsTrace() {
|
if log.IsTrace() {
|
||||||
|
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
if v.user != nil {
|
if v.user != nil {
|
||||||
m.User = v.user.ID
|
m.User = v.user.ID
|
||||||
}
|
}
|
||||||
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
|
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
|
||||||
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
}
|
}
|
||||||
v.IncrementMessages()
|
v.IncrementMessages()
|
||||||
if s.userManager != nil && v.user != nil {
|
if s.userManager != nil && v.user != nil {
|
||||||
s.userManager.EnqueueStats(v.user)
|
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.messages++
|
s.messages++
|
||||||
|
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
|
||||||
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
|
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
|
||||||
select {
|
select {
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
|
log.Debug("Stats resetter: Running")
|
||||||
s.resetStats()
|
s.resetStats()
|
||||||
case <-s.closeChan:
|
case <-s.closeChan:
|
||||||
|
log.Debug("Stats resetter: Stopping timer")
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
var v *visitor
|
var u *user.User
|
||||||
if s.userManager != nil && m.User != "" {
|
if s.userManager != nil && m.User != "" {
|
||||||
u, err := s.userManager.User(m.User)
|
u, err = s.userManager.User(m.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
v = s.visitorFromUser(u, m.Sender)
|
|
||||||
} else {
|
|
||||||
v = s.visitorFromIP(m.Sender)
|
|
||||||
}
|
}
|
||||||
|
v := s.visitor(m.Sender, u)
|
||||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
|
||||||
// Note that this function will always return a visitor, even if an error occurs.
|
// Note that this function will always return a visitor, even if an error occurs.
|
||||||
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
||||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||||
var u *user.User // may stay nil if no auth header!
|
var u *user.User // may stay nil if no auth header!
|
||||||
if u, err = s.authenticate(r); err != nil {
|
if u, err = s.authenticate(r); err != nil {
|
||||||
log.Debug("authentication failed: %s", err.Error())
|
log.Debug("authentication failed: %s", err.Error())
|
||||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||||
}
|
}
|
||||||
if u != nil {
|
v = s.visitor(ip, u)
|
||||||
v = s.visitorFromUser(u, ip)
|
|
||||||
} else {
|
|
||||||
v = s.visitorFromIP(ip)
|
|
||||||
}
|
|
||||||
v.SetUser(u) // Update visitor user with latest from database!
|
v.SetUser(u) // Update visitor user with latest from database!
|
||||||
return v, err // Always return visitor, even when error occurs!
|
return v, err // Always return visitor, even when error occurs!
|
||||||
}
|
}
|
||||||
|
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
|
||||||
return s.userManager.AuthenticateToken(token)
|
return s.userManager.AuthenticateToken(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
|
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
v, exists := s.visitors[visitorID]
|
id := visitorID(ip, user)
|
||||||
|
v, exists := s.visitors[id]
|
||||||
if !exists {
|
if !exists {
|
||||||
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
||||||
return s.visitors[visitorID]
|
return s.visitors[id]
|
||||||
}
|
}
|
||||||
v.Keepalive()
|
v.Keepalive()
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
|
||||||
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
|
||||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
|
|
|
@ -200,6 +200,12 @@
|
||||||
# visitor-request-limit-replenish: "5s"
|
# visitor-request-limit-replenish: "5s"
|
||||||
# visitor-request-limit-exempt-hosts: ""
|
# visitor-request-limit-exempt-hosts: ""
|
||||||
|
|
||||||
|
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
|
||||||
|
# every day at midnight UTC. If the limit is not set (or set to zero), the request
|
||||||
|
# limit (see above) governs the upper limit.
|
||||||
|
#
|
||||||
|
# visitor-message-daily-limit: 0
|
||||||
|
|
||||||
# Rate limiting: Allowed emails per visitor:
|
# Rate limiting: Allowed emails per visitor:
|
||||||
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
|
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
|
||||||
# - visitor-email-limit-replenish is the rate at which the bucket is refilled
|
# - visitor-email-limit-replenish is the rate at which the bucket is refilled
|
||||||
|
|
|
@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
||||||
} else if v.user != nil {
|
} else if v.user != nil {
|
||||||
return errHTTPUnauthorized // Cannot create account from user context
|
return errHTTPUnauthorized // Cannot create account from user context
|
||||||
}
|
}
|
||||||
|
if err := v.AccountCreationAllowed(); err != nil {
|
||||||
|
return errHTTPTooManyRequestsLimitAccountCreation
|
||||||
|
}
|
||||||
}
|
}
|
||||||
newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
|
newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
||||||
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
||||||
return errHTTPConflictUserExists
|
return errHTTPConflictUserExists
|
||||||
}
|
}
|
||||||
if err := v.AccountCreationAllowed(); err != nil {
|
|
||||||
return errHTTPTooManyRequestsLimitAccountCreation
|
|
||||||
}
|
|
||||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
||||||
response := &apiAccountResponse{
|
response := &apiAccountResponse{
|
||||||
Limits: &apiAccountLimits{
|
Limits: &apiAccountLimits{
|
||||||
Basis: string(limits.Basis),
|
Basis: string(limits.Basis),
|
||||||
Messages: limits.MessagesLimit,
|
Messages: limits.MessageLimit,
|
||||||
MessagesExpiryDuration: int64(limits.MessagesExpiryDuration.Seconds()),
|
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
|
||||||
Emails: limits.EmailsLimit,
|
Emails: limits.EmailLimit,
|
||||||
Reservations: limits.ReservationsLimit,
|
Reservations: limits.ReservationsLimit,
|
||||||
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
|
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
|
||||||
AttachmentFileSize: limits.AttachmentFileSizeLimit,
|
AttachmentFileSize: limits.AttachmentFileSizeLimit,
|
||||||
|
@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||||
reservations, err := s.userManager.ReservationsCount(v.user.Name)
|
reservations, err := s.userManager.ReservationsCount(v.user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if reservations >= v.user.Tier.ReservationsLimit {
|
} else if reservations >= v.user.Tier.ReservationLimit {
|
||||||
return errHTTPTooManyRequestsLimitReservations
|
return errHTTPTooManyRequestsLimitReservations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
||||||
// Create a tier
|
// Create a tier
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
MessagesLimit: 123,
|
MessageLimit: 123,
|
||||||
MessagesExpiryDuration: 86400 * time.Second,
|
MessageExpiryDuration: 86400 * time.Second,
|
||||||
EmailsLimit: 32,
|
EmailLimit: 32,
|
||||||
ReservationsLimit: 2,
|
ReservationLimit: 2,
|
||||||
AttachmentFileSizeLimit: 1231231,
|
AttachmentFileSizeLimit: 1231231,
|
||||||
AttachmentTotalSizeLimit: 123123,
|
AttachmentTotalSizeLimit: 123123,
|
||||||
AttachmentExpiryDuration: 10800 * time.Second,
|
AttachmentExpiryDuration: 10800 * time.Second,
|
||||||
|
@ -491,9 +491,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
MessagesLimit: 20,
|
MessageLimit: 20,
|
||||||
ReservationsLimit: 2,
|
ReservationLimit: 2,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
|
@ -525,9 +525,9 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
MessagesLimit: 20,
|
MessageLimit: 20,
|
||||||
ReservationsLimit: 2,
|
ReservationLimit: 2,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
|
@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) {
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
Name: "Pro",
|
Name: "Pro",
|
||||||
MessagesLimit: 123,
|
MessageLimit: 123,
|
||||||
MessagesExpiryDuration: 86400 * time.Second,
|
MessageExpiryDuration: 86400 * time.Second,
|
||||||
EmailsLimit: 32,
|
EmailLimit: 32,
|
||||||
ReservationsLimit: 2,
|
ReservationLimit: 2,
|
||||||
AttachmentFileSizeLimit: 1231231,
|
AttachmentFileSizeLimit: 1231231,
|
||||||
AttachmentTotalSizeLimit: 123123,
|
AttachmentTotalSizeLimit: 123123,
|
||||||
AttachmentExpiryDuration: 10800 * time.Second,
|
AttachmentExpiryDuration: 10800 * time.Second,
|
||||||
|
@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) {
|
||||||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||||
require.Equal(t, "pro", ti.Code)
|
require.Equal(t, "pro", ti.Code)
|
||||||
require.Equal(t, "Pro", ti.Name)
|
require.Equal(t, "Pro", ti.Name)
|
||||||
require.Equal(t, int64(123), ti.MessagesLimit)
|
require.Equal(t, int64(123), ti.MessageLimit)
|
||||||
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
|
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
||||||
require.Equal(t, int64(32), ti.EmailsLimit)
|
require.Equal(t, int64(32), ti.EmailLimit)
|
||||||
require.Equal(t, int64(2), ti.ReservationsLimit)
|
require.Equal(t, int64(2), ti.ReservationLimit)
|
||||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||||
|
|
|
@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
freeTier := defaultVisitorLimits(s.config)
|
freeTier := configBasedVisitorLimits(s.config)
|
||||||
response := []*apiAccountBillingTier{
|
response := []*apiAccountBillingTier{
|
||||||
{
|
{
|
||||||
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
|
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
|
||||||
Limits: &apiAccountLimits{
|
Limits: &apiAccountLimits{
|
||||||
Basis: string(visitorLimitBasisIP),
|
Basis: string(visitorLimitBasisIP),
|
||||||
Messages: freeTier.MessagesLimit,
|
Messages: freeTier.MessageLimit,
|
||||||
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
|
MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
|
||||||
Emails: freeTier.EmailsLimit,
|
Emails: freeTier.EmailLimit,
|
||||||
Reservations: freeTier.ReservationsLimit,
|
Reservations: freeTier.ReservationsLimit,
|
||||||
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
|
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
|
||||||
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
|
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
|
||||||
|
@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
||||||
Price: priceStr,
|
Price: priceStr,
|
||||||
Limits: &apiAccountLimits{
|
Limits: &apiAccountLimits{
|
||||||
Basis: string(visitorLimitBasisTier),
|
Basis: string(visitorLimitBasisTier),
|
||||||
Messages: tier.MessagesLimit,
|
Messages: tier.MessageLimit,
|
||||||
MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
|
MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
|
||||||
Emails: tier.EmailsLimit,
|
Emails: tier.EmailLimit,
|
||||||
Reservations: tier.ReservationsLimit,
|
Reservations: tier.ReservationLimit,
|
||||||
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
|
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
|
||||||
AttachmentFileSize: tier.AttachmentFileSizeLimit,
|
AttachmentFileSize: tier.AttachmentFileSizeLimit,
|
||||||
AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
|
AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
|
||||||
|
@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
||||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
|
||||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||||
reservationsLimit := visitorDefaultReservationsLimit
|
reservationsLimit := visitorDefaultReservationsLimit
|
||||||
if tier != nil {
|
if tier != nil {
|
||||||
reservationsLimit = tier.ReservationsLimit
|
reservationsLimit = tier.ReservationLimit
|
||||||
}
|
}
|
||||||
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -5,11 +5,14 @@ import (
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stripe/stripe-go/v74"
|
"github.com/stripe/stripe-go/v74"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) {
|
||||||
ID: "ti_123",
|
ID: "ti_123",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
Name: "Pro",
|
Name: "Pro",
|
||||||
MessagesLimit: 1000,
|
MessageLimit: 1000,
|
||||||
MessagesExpiryDuration: time.Hour,
|
MessageExpiryDuration: time.Hour,
|
||||||
EmailsLimit: 123,
|
EmailLimit: 123,
|
||||||
ReservationsLimit: 777,
|
ReservationLimit: 777,
|
||||||
AttachmentFileSizeLimit: 999,
|
AttachmentFileSizeLimit: 999,
|
||||||
AttachmentTotalSizeLimit: 888,
|
AttachmentTotalSizeLimit: 888,
|
||||||
AttachmentExpiryDuration: time.Minute,
|
AttachmentExpiryDuration: time.Minute,
|
||||||
|
@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) {
|
||||||
ID: "ti_444",
|
ID: "ti_444",
|
||||||
Code: "business",
|
Code: "business",
|
||||||
Name: "Business",
|
Name: "Business",
|
||||||
MessagesLimit: 2000,
|
MessageLimit: 2000,
|
||||||
MessagesExpiryDuration: 10 * time.Hour,
|
MessageExpiryDuration: 10 * time.Hour,
|
||||||
EmailsLimit: 123123,
|
EmailLimit: 123123,
|
||||||
ReservationsLimit: 777333,
|
ReservationLimit: 777333,
|
||||||
AttachmentFileSizeLimit: 999111,
|
AttachmentFileSizeLimit: 999111,
|
||||||
AttachmentTotalSizeLimit: 888111,
|
AttachmentTotalSizeLimit: 888111,
|
||||||
AttachmentExpiryDuration: time.Hour,
|
AttachmentExpiryDuration: time.Hour,
|
||||||
|
@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
||||||
require.Equal(t, 401, rr.Code)
|
require.Equal(t, 401, rr.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
|
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
|
||||||
// This tests a successful checkout flow (not a paying customer -> paying customer),
|
// This test is too overloaded, but it's also a great end-to-end a test.
|
||||||
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
|
//
|
||||||
|
// It tests:
|
||||||
|
// - A successful checkout flow (not a paying customer -> paying customer)
|
||||||
|
// - Tier-changes reset the rate limits for the user
|
||||||
|
// - The request limits for tier-less user and a tier-user
|
||||||
|
// - The message limits for a tier-user
|
||||||
|
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
@ -248,19 +256,26 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
c.VisitorRequestLimitBurst = 10
|
c.VisitorRequestLimitBurst = 5
|
||||||
c.VisitorRequestLimitReplenish = time.Hour
|
c.VisitorRequestLimitReplenish = time.Hour
|
||||||
|
c.CacheStartupQueries = `
|
||||||
|
pragma journal_mode = WAL;
|
||||||
|
pragma synchronous = normal;
|
||||||
|
pragma temp_store = memory;
|
||||||
|
`
|
||||||
|
c.CacheBatchSize = 500
|
||||||
|
c.CacheBatchTimeout = time.Second
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
s.stripe = stripeMock
|
s.stripe = stripeMock
|
||||||
|
|
||||||
// Create a user with a Stripe subscription and 3 reservations
|
// Create a user with a Stripe subscription and 3 reservations
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
ID: "ti_123",
|
ID: "ti_123",
|
||||||
Code: "starter",
|
Code: "starter",
|
||||||
StripePriceID: "price_1234",
|
StripePriceID: "price_1234",
|
||||||
ReservationsLimit: 1,
|
ReservationLimit: 1,
|
||||||
MessagesLimit: 100,
|
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
|
||||||
MessagesExpiryDuration: time.Hour,
|
MessageExpiryDuration: time.Hour,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
|
||||||
u, err := s.userManager.User("phil")
|
u, err := s.userManager.User("phil")
|
||||||
|
@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
||||||
Return(&stripe.Customer{}, nil)
|
Return(&stripe.Customer{}, nil)
|
||||||
|
|
||||||
// Send messages until rate limit of free tier is hit
|
// Send messages until rate limit of free tier is hit
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
|
@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
||||||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||||
|
|
||||||
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
|
|
||||||
|
|
||||||
// Now for the fun part: Verify that new rate limits are immediately applied
|
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||||
for i := 0; i < 100; i++ {
|
// This only tests the request limiter, which kicks in before the message limiter.
|
||||||
|
for i := 0; i < 11; i++ {
|
||||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
|
@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 429, rr.Code)
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
// Now let's test the message limiter by faking a ridiculously generous rate limiter
|
||||||
|
v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
|
||||||
|
v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 209; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
// And now let's cross-check that the stats are correct too
|
||||||
|
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
|
||||||
|
require.Equal(t, int64(220), account.Limits.Messages)
|
||||||
|
require.Equal(t, int64(220), account.Stats.Messages)
|
||||||
|
require.Equal(t, int64(0), account.Stats.MessagesRemaining)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||||
|
@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||||
ID: "ti_1",
|
ID: "ti_1",
|
||||||
Code: "starter",
|
Code: "starter",
|
||||||
StripePriceID: "price_1234", // !
|
StripePriceID: "price_1234", // !
|
||||||
ReservationsLimit: 1, // !
|
ReservationLimit: 1, // !
|
||||||
MessagesLimit: 100,
|
MessageLimit: 100,
|
||||||
MessagesExpiryDuration: time.Hour,
|
MessageExpiryDuration: time.Hour,
|
||||||
AttachmentExpiryDuration: time.Hour,
|
AttachmentExpiryDuration: time.Hour,
|
||||||
AttachmentFileSizeLimit: 1000000,
|
AttachmentFileSizeLimit: 1000000,
|
||||||
AttachmentTotalSizeLimit: 1000000,
|
AttachmentTotalSizeLimit: 1000000,
|
||||||
|
@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||||
ID: "ti_2",
|
ID: "ti_2",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
StripePriceID: "price_1111", // !
|
StripePriceID: "price_1111", // !
|
||||||
ReservationsLimit: 3, // !
|
ReservationLimit: 3, // !
|
||||||
MessagesLimit: 200,
|
MessageLimit: 200,
|
||||||
MessagesExpiryDuration: time.Hour,
|
MessageExpiryDuration: time.Hour,
|
||||||
AttachmentExpiryDuration: time.Hour,
|
AttachmentExpiryDuration: time.Hour,
|
||||||
AttachmentFileSizeLimit: 1000000,
|
AttachmentFileSizeLimit: 1000000,
|
||||||
AttachmentTotalSizeLimit: 1000000,
|
AttachmentTotalSizeLimit: 1000000,
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -22,9 +21,14 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// log.SetLevel(log.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_PublishAndPoll(t *testing.T) {
|
func TestServer_PublishAndPoll(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t))
|
||||||
|
|
||||||
|
@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
||||||
require.Equal(t, 401, response.Code)
|
require.Equal(t, 401, response.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_StatsResetter(t *testing.T) {
|
func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
|
||||||
|
// This tests the stats resetter for
|
||||||
|
// - an anonymous user
|
||||||
|
// - a user without a tier (treated like the same as the anonymous user)
|
||||||
|
// - a user with a tier
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
|
||||||
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
|
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
go s.runStatsResetter()
|
go s.runStatsResetter()
|
||||||
|
|
||||||
|
// Create user with tier (tieruser) and user without tier (phil)
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
Code: "test",
|
||||||
|
MessageLimit: 5,
|
||||||
|
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||||
|
}))
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
|
||||||
|
|
||||||
|
// Send an anonymous message
|
||||||
|
response := request(t, s, "PUT", "/mytopic", "test", nil)
|
||||||
|
|
||||||
|
// Send messages from user without tier (phil)
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
response := request(t, s, "GET", "/v1/account", "", map[string]string{
|
// Send messages from user with tier
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
for i := 0; i < 2; i++ {
|
||||||
})
|
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||||
require.Equal(t, 200, response.Code)
|
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
}
|
||||||
|
|
||||||
// User stats show 10 messages
|
// User stats show 6 messages (for user without tier)
|
||||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, int64(5), account.Stats.Messages)
|
require.Equal(t, int64(6), account.Stats.Messages)
|
||||||
|
|
||||||
|
// User stats show 6 messages (for anonymous visitor)
|
||||||
|
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(6), account.Stats.Messages)
|
||||||
|
|
||||||
|
// User stats show 2 messages (for user with tier)
|
||||||
|
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(2), account.Stats.Messages)
|
||||||
|
|
||||||
// Wait for stats resetter to run
|
// Wait for stats resetter to run
|
||||||
time.Sleep(2200 * time.Millisecond)
|
time.Sleep(2200 * time.Millisecond)
|
||||||
|
|
||||||
// User stats show 0 messages now!
|
// User stats show 0 messages now!
|
||||||
|
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), account.Stats.Messages)
|
||||||
|
|
||||||
|
// Since this is a user without a tier, the anonymous user should have the same stats
|
||||||
response = request(t, s, "GET", "/v1/account", "", nil)
|
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, int64(0), account.Stats.Messages)
|
require.Equal(t, int64(0), account.Stats.Messages)
|
||||||
|
|
||||||
|
// User stats show 0 messages (for user with tier)
|
||||||
|
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), account.Stats.Messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testMailer struct {
|
type testMailer struct {
|
||||||
|
@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
||||||
|
|
||||||
// Create tier with certain limits
|
// Create tier with certain limits
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "test",
|
Code: "test",
|
||||||
MessagesLimit: 5,
|
MessageLimit: 5,
|
||||||
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||||
|
@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
||||||
sevenDays := time.Duration(604800) * time.Second
|
sevenDays := time.Duration(604800) * time.Second
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "test",
|
Code: "test",
|
||||||
MessagesLimit: 10,
|
MessageLimit: 10,
|
||||||
MessagesExpiryDuration: sevenDays,
|
MessageExpiryDuration: sevenDays,
|
||||||
AttachmentFileSizeLimit: 50_000,
|
AttachmentFileSizeLimit: 50_000,
|
||||||
AttachmentTotalSizeLimit: 200_000,
|
AttachmentTotalSizeLimit: 200_000,
|
||||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||||
|
@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
||||||
// Create tier with certain limits
|
// Create tier with certain limits
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "test",
|
Code: "test",
|
||||||
MessagesLimit: 10,
|
MessageLimit: 10,
|
||||||
MessagesExpiryDuration: time.Hour,
|
MessageExpiryDuration: time.Hour,
|
||||||
AttachmentFileSizeLimit: 50_000,
|
AttachmentFileSizeLimit: 50_000,
|
||||||
AttachmentTotalSizeLimit: 200_000,
|
AttachmentTotalSizeLimit: 200_000,
|
||||||
AttachmentExpiryDuration: time.Hour,
|
AttachmentExpiryDuration: time.Hour,
|
||||||
|
@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
||||||
// Create tier with certain limits
|
// Create tier with certain limits
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
Code: "test",
|
Code: "test",
|
||||||
MessagesLimit: 100,
|
MessageLimit: 100,
|
||||||
AttachmentFileSizeLimit: 50_000,
|
AttachmentFileSizeLimit: 50_000,
|
||||||
AttachmentTotalSizeLimit: 200_000,
|
AttachmentTotalSizeLimit: 200_000,
|
||||||
AttachmentExpiryDuration: 30 * time.Second,
|
AttachmentExpiryDuration: 30 * time.Second,
|
||||||
|
@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
||||||
r, _ := http.NewRequest("GET", "/bla", nil)
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
||||||
v, err := s.visitor(r)
|
v, err := s.maybeAuthenticate(r)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "8.9.10.11", v.ip.String())
|
require.Equal(t, "8.9.10.11", v.ip.String())
|
||||||
}
|
}
|
||||||
|
@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
||||||
r, _ := http.NewRequest("GET", "/bla", nil)
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
||||||
v, err := s.visitor(r)
|
v, err := s.maybeAuthenticate(r)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "1.1.1.1", v.ip.String())
|
require.Equal(t, "1.1.1.1", v.ip.String())
|
||||||
}
|
}
|
||||||
|
@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
||||||
r, _ := http.NewRequest("GET", "/bla", nil)
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||||
r.RemoteAddr = "8.9.10.11"
|
r.RemoteAddr = "8.9.10.11"
|
||||||
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
||||||
v, err := s.visitor(r)
|
v, err := s.maybeAuthenticate(r)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "234.5.2.1", v.ip.String())
|
require.Equal(t, "234.5.2.1", v.ip.String())
|
||||||
}
|
}
|
||||||
|
@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
// Add lots of messages
|
// Add lots of messages
|
||||||
log.Printf("Adding %d messages", count)
|
log.Info("Adding %d messages", count)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
messages := make([]*message, 0)
|
messages := make([]*message, 0)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
|
@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||||
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
||||||
}
|
}
|
||||||
require.Nil(t, s.messageCache.addMessages(messages))
|
require.Nil(t, s.messageCache.addMessages(messages))
|
||||||
log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||||
|
|
||||||
// Update stats
|
// Update stats
|
||||||
statsChan := make(chan bool)
|
statsChan := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("Updating stats")
|
log.Info("Updating stats")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
s.execManager()
|
s.execManager()
|
||||||
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
||||||
statsChan <- true
|
statsChan <- true
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond) // Make sure it starts first
|
time.Sleep(50 * time.Millisecond) // Make sure it starts first
|
||||||
|
|
||||||
// Publish message (during stats update)
|
// Publish message (during stats update)
|
||||||
log.Printf("Publishing message")
|
log.Info("Publishing message")
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
||||||
m := toMessage(t, response.Body.String())
|
m := toMessage(t, response.Body.String())
|
||||||
assert.Equal(t, "some body", m.Message)
|
assert.Equal(t, "some body", m.Message)
|
||||||
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
||||||
log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||||
|
|
||||||
// Wait for all goroutines
|
// Wait for all goroutines
|
||||||
<-statsChan
|
<-statsChan
|
||||||
log.Printf("Done: Waiting for all locks")
|
log.Info("Done: Waiting for all locks")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestConfig(t *testing.T) *Config {
|
func newTestConfig(t *testing.T) *Config {
|
||||||
|
|
|
@ -14,16 +14,39 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// oneDay is an approximation of a day as a time.Duration
|
||||||
|
oneDay = 24 * time.Hour
|
||||||
|
|
||||||
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
|
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
|
||||||
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
|
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
|
||||||
// they are replenished faster (typically).
|
// they are replenished faster (typically).
|
||||||
visitorExpungeAfter = 24 * time.Hour
|
visitorExpungeAfter = oneDay
|
||||||
|
|
||||||
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
|
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
|
||||||
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
|
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
|
||||||
visitorDefaultReservationsLimit = int64(0)
|
visitorDefaultReservationsLimit = int64(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
|
||||||
|
// values (token bucket).
|
||||||
|
//
|
||||||
|
// Example: Assuming a user.Tier's MessageLimit is 10,000:
|
||||||
|
// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
|
||||||
|
// - the replenish rate is 2 * 10,000 / 24 hours
|
||||||
|
const (
|
||||||
|
visitorMessageToRequestLimitBurstRate = 0.05
|
||||||
|
visitorMessageToRequestLimitBurstMax = 1000
|
||||||
|
visitorMessageToRequestLimitReplenishFactor = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter
|
||||||
|
// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is
|
||||||
|
// 40 (= 200 * 20%), which is <150 (the max).
|
||||||
|
const (
|
||||||
|
visitorEmailLimitBurstRate = 0.2
|
||||||
|
visitorEmailLimitBurstMax = 150
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errVisitorLimitReached = errors.New("limit reached")
|
errVisitorLimitReached = errors.New("limit reached")
|
||||||
)
|
)
|
||||||
|
@ -55,9 +78,13 @@ type visitorInfo struct {
|
||||||
|
|
||||||
type visitorLimits struct {
|
type visitorLimits struct {
|
||||||
Basis visitorLimitBasis
|
Basis visitorLimitBasis
|
||||||
MessagesLimit int64
|
RequestLimitBurst int
|
||||||
MessagesExpiryDuration time.Duration
|
RequestLimitReplenish rate.Limit
|
||||||
EmailsLimit int64
|
MessageLimit int64
|
||||||
|
MessageExpiryDuration time.Duration
|
||||||
|
EmailLimit int64
|
||||||
|
EmailLimitBurst int
|
||||||
|
EmailLimitReplenish rate.Limit
|
||||||
ReservationsLimit int64
|
ReservationsLimit int64
|
||||||
AttachmentTotalSizeLimit int64
|
AttachmentTotalSizeLimit int64
|
||||||
AttachmentFileSizeLimit int64
|
AttachmentFileSizeLimit int64
|
||||||
|
@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) AccountCreationAllowed() error {
|
func (v *visitor) AccountCreationAllowed() error {
|
||||||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
|
||||||
return errVisitorLimitReached
|
return errVisitorLimitReached
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) resetLimiters() {
|
|
||||||
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
|
|
||||||
var messagesLimiter, bandwidthLimiter util.Limiter
|
|
||||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
|
||||||
if v.user != nil && v.user.Tier != nil {
|
|
||||||
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
|
|
||||||
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
|
|
||||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
|
|
||||||
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
|
||||||
} else {
|
|
||||||
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
|
|
||||||
messagesLimiter = nil // Message limit is governed by the requestLimiter
|
|
||||||
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
|
|
||||||
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
|
||||||
}
|
|
||||||
if v.user == nil {
|
|
||||||
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
|
||||||
}
|
|
||||||
v.requestLimiter = requestLimiter
|
|
||||||
v.messagesLimiter = messagesLimiter
|
|
||||||
v.emailsLimiter = emailsLimiter
|
|
||||||
v.bandwidthLimiter = bandwidthLimiter
|
|
||||||
v.accountLimiter = accountLimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
||||||
// an empty string is returned.
|
// an empty string is returned.
|
||||||
func (v *visitor) MaybeUserID() string {
|
func (v *visitor) MaybeUserID() string {
|
||||||
|
@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *visitor) resetLimiters() {
|
||||||
|
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
|
||||||
|
limits := v.limitsNoLock()
|
||||||
|
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
||||||
|
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages)
|
||||||
|
v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
|
||||||
|
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
|
||||||
|
if v.user == nil {
|
||||||
|
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||||
|
} else {
|
||||||
|
v.accountLimiter = nil // Users cannot create accounts when logged in
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (v *visitor) Limits() *visitorLimits {
|
func (v *visitor) Limits() *visitorLimits {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
limits := defaultVisitorLimits(v.config)
|
return v.limitsNoLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *visitor) limitsNoLock() *visitorLimits {
|
||||||
if v.user != nil && v.user.Tier != nil {
|
if v.user != nil && v.user.Tier != nil {
|
||||||
limits.Basis = visitorLimitBasisTier
|
return tierBasedVisitorLimits(v.config, v.user.Tier)
|
||||||
limits.MessagesLimit = v.user.Tier.MessagesLimit
|
}
|
||||||
limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration
|
return configBasedVisitorLimits(v.config)
|
||||||
limits.EmailsLimit = v.user.Tier.EmailsLimit
|
}
|
||||||
limits.ReservationsLimit = v.user.Tier.ReservationsLimit
|
|
||||||
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
|
func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
|
||||||
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
|
return &visitorLimits{
|
||||||
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
|
Basis: visitorLimitBasisTier,
|
||||||
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
|
RequestLimitBurst: util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax),
|
||||||
|
RequestLimitReplenish: dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor),
|
||||||
|
MessageLimit: tier.MessageLimit,
|
||||||
|
MessageExpiryDuration: tier.MessageExpiryDuration,
|
||||||
|
EmailLimit: tier.EmailLimit,
|
||||||
|
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
|
||||||
|
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
|
||||||
|
ReservationsLimit: tier.ReservationLimit,
|
||||||
|
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
|
||||||
|
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
|
||||||
|
AttachmentExpiryDuration: tier.AttachmentExpiryDuration,
|
||||||
|
AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configBasedVisitorLimits(conf *Config) *visitorLimits {
|
||||||
|
messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation!
|
||||||
|
if conf.VisitorMessageDailyLimit > 0 {
|
||||||
|
messagesLimit = int64(conf.VisitorMessageDailyLimit)
|
||||||
|
}
|
||||||
|
return &visitorLimits{
|
||||||
|
Basis: visitorLimitBasisIP,
|
||||||
|
RequestLimitBurst: conf.VisitorRequestLimitBurst,
|
||||||
|
RequestLimitReplenish: rate.Every(conf.VisitorRequestLimitReplenish),
|
||||||
|
MessageLimit: messagesLimit,
|
||||||
|
MessageExpiryDuration: conf.CacheDuration,
|
||||||
|
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
|
||||||
|
EmailLimitBurst: conf.VisitorEmailLimitBurst,
|
||||||
|
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
|
||||||
|
ReservationsLimit: visitorDefaultReservationsLimit,
|
||||||
|
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||||
|
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||||
|
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
||||||
|
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
|
||||||
}
|
}
|
||||||
return limits
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) Info() (*visitorInfo, error) {
|
func (v *visitor) Info() (*visitorInfo, error) {
|
||||||
|
@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) {
|
||||||
limits := v.Limits()
|
limits := v.Limits()
|
||||||
stats := &visitorStats{
|
stats := &visitorStats{
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
MessagesRemaining: zeroIfNegative(limits.MessagesLimit - messages),
|
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
|
||||||
Emails: emails,
|
Emails: emails,
|
||||||
EmailsRemaining: zeroIfNegative(limits.EmailsLimit - emails),
|
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
|
||||||
Reservations: reservations,
|
Reservations: reservations,
|
||||||
ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations),
|
ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations),
|
||||||
AttachmentTotalSize: attachmentsBytesUsed,
|
AttachmentTotalSize: attachmentsBytesUsed,
|
||||||
|
@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func replenishDurationToDailyLimit(duration time.Duration) int64 {
|
func replenishDurationToDailyLimit(duration time.Duration) int64 {
|
||||||
return int64(24 * time.Hour / duration)
|
return int64(oneDay / duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dailyLimitToRate(limit int64) rate.Limit {
|
func dailyLimitToRate(limit int64) rate.Limit {
|
||||||
return rate.Limit(limit) * rate.Every(24*time.Hour)
|
return rate.Limit(limit) * rate.Every(oneDay)
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultVisitorLimits(conf *Config) *visitorLimits {
|
func visitorID(ip netip.Addr, u *user.User) string {
|
||||||
return &visitorLimits{
|
if u != nil && u.Tier != nil {
|
||||||
Basis: visitorLimitBasisIP,
|
return fmt.Sprintf("user:%s", u.ID)
|
||||||
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
|
|
||||||
MessagesExpiryDuration: conf.CacheDuration,
|
|
||||||
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
|
|
||||||
ReservationsLimit: visitorDefaultReservationsLimit,
|
|
||||||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
|
||||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
|
||||||
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
|
||||||
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
|
|
||||||
}
|
}
|
||||||
|
return fmt.Sprintf("ip:%s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
|
@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||||
ID: tierID.String,
|
ID: tierID.String,
|
||||||
Code: tierCode.String,
|
Code: tierCode.String,
|
||||||
Name: tierName.String,
|
Name: tierName.String,
|
||||||
MessagesLimit: messagesLimit.Int64,
|
MessageLimit: messagesLimit.Int64,
|
||||||
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||||
EmailsLimit: emailsLimit.Int64,
|
EmailLimit: emailsLimit.Int64,
|
||||||
ReservationsLimit: reservationsLimit.Int64,
|
ReservationLimit: reservationsLimit.Int64,
|
||||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||||
|
@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error {
|
||||||
t, err := a.Tier(tier)
|
t, err := a.Tier(tier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
|
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
|
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
|
||||||
|
@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
|
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
|
||||||
reservations, err := a.Reservations(username)
|
reservations, err := a.Reservations(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
||||||
if tier.ID == "" {
|
if tier.ID == "" {
|
||||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||||
}
|
}
|
||||||
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||||
ID: id,
|
ID: id,
|
||||||
Code: code,
|
Code: code,
|
||||||
Name: name,
|
Name: name,
|
||||||
MessagesLimit: messagesLimit.Int64,
|
MessageLimit: messagesLimit.Int64,
|
||||||
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||||
EmailsLimit: emailsLimit.Int64,
|
EmailLimit: emailsLimit.Int64,
|
||||||
ReservationsLimit: reservationsLimit.Int64,
|
ReservationLimit: reservationsLimit.Int64,
|
||||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||||
|
|
|
@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
Name: "ntfy Pro",
|
Name: "ntfy Pro",
|
||||||
StripePriceID: "price123",
|
StripePriceID: "price123",
|
||||||
MessagesLimit: 5_000,
|
MessageLimit: 5_000,
|
||||||
MessagesExpiryDuration: 3 * 24 * time.Hour,
|
MessageExpiryDuration: 3 * 24 * time.Hour,
|
||||||
EmailsLimit: 50,
|
EmailLimit: 50,
|
||||||
ReservationsLimit: 5,
|
ReservationLimit: 5,
|
||||||
AttachmentFileSizeLimit: 52428800,
|
AttachmentFileSizeLimit: 52428800,
|
||||||
AttachmentTotalSizeLimit: 524288000,
|
AttachmentTotalSizeLimit: 524288000,
|
||||||
AttachmentExpiryDuration: 24 * time.Hour,
|
AttachmentExpiryDuration: 24 * time.Hour,
|
||||||
|
@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, RoleUser, ben.Role)
|
require.Equal(t, RoleUser, ben.Role)
|
||||||
require.Equal(t, "pro", ben.Tier.Code)
|
require.Equal(t, "pro", ben.Tier.Code)
|
||||||
require.Equal(t, int64(5000), ben.Tier.MessagesLimit)
|
require.Equal(t, int64(5000), ben.Tier.MessageLimit)
|
||||||
require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration)
|
require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration)
|
||||||
require.Equal(t, int64(50), ben.Tier.EmailsLimit)
|
require.Equal(t, int64(50), ben.Tier.EmailLimit)
|
||||||
require.Equal(t, int64(5), ben.Tier.ReservationsLimit)
|
require.Equal(t, int64(5), ben.Tier.ReservationLimit)
|
||||||
require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
|
require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
|
||||||
require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
|
require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
|
||||||
require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)
|
require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)
|
||||||
|
|
|
@ -62,10 +62,10 @@ type Tier struct {
|
||||||
ID string // Tier identifier (ti_...)
|
ID string // Tier identifier (ti_...)
|
||||||
Code string // Code of the tier
|
Code string // Code of the tier
|
||||||
Name string // Name of the tier
|
Name string // Name of the tier
|
||||||
MessagesLimit int64 // Daily message limit
|
MessageLimit int64 // Daily message limit
|
||||||
MessagesExpiryDuration time.Duration // Cache duration for messages
|
MessageExpiryDuration time.Duration // Cache duration for messages
|
||||||
EmailsLimit int64 // Daily email limit
|
EmailLimit int64 // Daily email limit
|
||||||
ReservationsLimit int64 // Number of topic reservations allowed by user
|
ReservationLimit int64 // Number of topic reservations allowed by user
|
||||||
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
|
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
|
||||||
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
|
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
|
||||||
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
|
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
|
||||||
|
|
|
@ -27,8 +27,14 @@ type FixedLimiter struct {
|
||||||
|
|
||||||
// NewFixedLimiter creates a new Limiter
|
// NewFixedLimiter creates a new Limiter
|
||||||
func NewFixedLimiter(limit int64) *FixedLimiter {
|
func NewFixedLimiter(limit int64) *FixedLimiter {
|
||||||
|
return NewFixedLimiterWithValue(limit, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFixedLimiterWithValue creates a new Limiter and sets the initial value
|
||||||
|
func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
|
||||||
return &FixedLimiter{
|
return &FixedLimiter{
|
||||||
limit: limit,
|
limit: limit,
|
||||||
|
value: value,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ var (
|
||||||
// NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
|
// NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
|
||||||
// of that time from the current time (in UTC).
|
// of that time from the current time (in UTC).
|
||||||
func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
|
func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
|
||||||
hour, minute, seconds := timeOfDay.Clock()
|
hour, minute, seconds := timeOfDay.UTC().Clock()
|
||||||
now := base.UTC()
|
now := base.UTC()
|
||||||
next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
|
next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
|
||||||
if next.Before(now) {
|
if next.Before(now) {
|
||||||
|
|
11
util/util.go
11
util/util.go
|
@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MinMax returns value if it is between min and max, or either
|
||||||
|
// min or max if it is out of range
|
||||||
|
func MinMax[T int | int64](value, min, max T) T {
|
||||||
|
if value < min {
|
||||||
|
return min
|
||||||
|
} else if value > max {
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// String turns a string into a pointer of a string
|
// String turns a string into a pointer of a string
|
||||||
func String(v string) *string {
|
func String(v string) *string {
|
||||||
return &v
|
return &v
|
||||||
|
|
Loading…
Reference in a new issue