Rate limits make sense now!

This commit is contained in:
binwiederhier 2023-01-26 22:57:18 -05:00
parent a036814d98
commit c874a641df
17 changed files with 365 additions and 205 deletions

View file

@ -77,6 +77,7 @@ var flagsServe = append(
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@ -150,6 +151,7 @@ func execServe(c *cli.Context) error {
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",") visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy") behindProxy := c.Bool("behind-proxy")
@ -289,6 +291,7 @@ func execServe(c *cli.Context) error {
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy conf.BehindProxy = behindProxy

View file

@ -44,6 +44,7 @@ const (
DefaultVisitorSubscriptionLimit = 30 DefaultVisitorSubscriptionLimit = 30
DefaultVisitorRequestLimitBurst = 60 DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 5 * time.Second DefaultVisitorRequestLimitReplenish = 5 * time.Second
DefaultVisitorMessageDailyLimit = 0
DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreationLimitBurst = 3 DefaultVisitorAccountCreationLimitBurst = 3
@ -105,6 +106,7 @@ type Config struct {
VisitorRequestLimitBurst int VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []netip.Prefix VisitorRequestExemptIPAddrs []netip.Prefix
VisitorMessageDailyLimit int
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
VisitorAccountCreationLimitBurst int VisitorAccountCreationLimitBurst int
@ -171,6 +173,7 @@ func NewConfig() *Config {
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorMessageDailyLimit: DefaultVisitorMessageDailyLimit,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst, VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,

View file

@ -75,10 +75,10 @@ var (
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""} errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

View file

@ -38,10 +38,9 @@ import (
TODO TODO
-- --
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)? - HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
- MEDIUM Rate limiting: Test daily message quota read from database initially
- MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
Tests: Tests:
- Payment endpoints (make mocks) - Payment endpoints (make mocks)
- test that the visitor is based on the IP address when a user has no tier
*/ */
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy
@ -308,7 +306,7 @@ func (s *Server) Stop() {
} }
func (s *Server) handle(w http.ResponseWriter, r *http.Request) { func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
v, err := s.visitor(r) // Note: Always returns v, even when error is returned v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
if err == nil { if err == nil {
log.Debug("%s Dispatching request", logHTTPPrefix(v, r)) log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
if log.IsTrace() { if log.IsTrace() {
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if v.user != nil { if v.user != nil {
m.User = v.user.ID m.User = v.user.ID
} }
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix() m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err return nil, err
} }
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
} }
v.IncrementMessages() v.IncrementMessages()
if s.userManager != nil && v.user != nil { if s.userManager != nil && v.user != nil {
s.userManager.EnqueueStats(v.user) s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
} }
s.mu.Lock() s.mu.Lock()
s.messages++ s.messages++
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt) log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
select { select {
case <-timer.C: case <-timer.C:
log.Debug("Stats resetter: Running")
s.resetStats() s.resetStats()
case <-s.closeChan: case <-s.closeChan:
log.Debug("Stats resetter: Stopping timer")
timer.Stop() timer.Stop()
return return
} }
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
return err return err
} }
for _, m := range messages { for _, m := range messages {
var v *visitor var u *user.User
if s.userManager != nil && m.User != "" { if s.userManager != nil && m.User != "" {
u, err := s.userManager.User(m.User) u, err = s.userManager.User(m.User)
if err != nil { if err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
continue continue
} }
v = s.visitorFromUser(u, m.Sender)
} else {
v = s.visitorFromIP(m.Sender)
} }
v := s.visitor(m.Sender, u)
if err := s.sendDelayedMessage(v, m); err != nil { if err := s.sendDelayedMessage(v, m); err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
} }
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
} }
} }
// visitor creates or retrieves a rate.Limiter for the given visitor. // maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs. // Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) { func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy) ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header! var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil { if u, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error()) log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs! err = errHTTPUnauthorized // Always return visitor, even when error occurs!
} }
if u != nil { v = s.visitor(ip, u)
v = s.visitorFromUser(u, ip)
} else {
v = s.visitorFromIP(ip)
}
v.SetUser(u) // Update visitor user with latest from database! v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs! return v, err // Always return visitor, even when error occurs!
} }
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
return s.userManager.AuthenticateToken(token) return s.userManager.AuthenticateToken(token)
} }
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor { func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
v, exists := s.visitors[visitorID] id := visitorID(ip, user)
v, exists := s.visitors[id]
if !exists { if !exists {
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user) s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
return s.visitors[visitorID] return s.visitors[id]
} }
v.Keepalive() v.Keepalive()
return v return v
} }
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
}
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error { func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests

View file

@ -200,6 +200,12 @@
# visitor-request-limit-replenish: "5s" # visitor-request-limit-replenish: "5s"
# visitor-request-limit-exempt-hosts: "" # visitor-request-limit-exempt-hosts: ""
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
# every day at midnight UTC. If the limit is not set (or set to zero), the request
# limit (see above) governs the upper limit.
#
# visitor-message-daily-limit: 0
# Rate limiting: Allowed emails per visitor: # Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has # - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled # - visitor-email-limit-replenish is the rate at which the bucket is refilled

View file

@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
} else if v.user != nil { } else if v.user != nil {
return errHTTPUnauthorized // Cannot create account from user context return errHTTPUnauthorized // Cannot create account from user context
} }
if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation
}
} }
newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit) newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
if err != nil { if err != nil {
@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil { if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
return errHTTPConflictUserExists return errHTTPConflictUserExists
} }
if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation
}
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err return err
} }
@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
response := &apiAccountResponse{ response := &apiAccountResponse{
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Basis: string(limits.Basis), Basis: string(limits.Basis),
Messages: limits.MessagesLimit, Messages: limits.MessageLimit,
MessagesExpiryDuration: int64(limits.MessagesExpiryDuration.Seconds()), MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
Emails: limits.EmailsLimit, Emails: limits.EmailLimit,
Reservations: limits.ReservationsLimit, Reservations: limits.ReservationsLimit,
AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit,
@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
reservations, err := s.userManager.ReservationsCount(v.user.Name) reservations, err := s.userManager.ReservationsCount(v.user.Name)
if err != nil { if err != nil {
return err return err
} else if reservations >= v.user.Tier.ReservationsLimit { } else if reservations >= v.user.Tier.ReservationLimit {
return errHTTPTooManyRequestsLimitReservations return errHTTPTooManyRequestsLimitReservations
} }
} }

View file

@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
// Create a tier // Create a tier
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro", Code: "pro",
MessagesLimit: 123, MessageLimit: 123,
MessagesExpiryDuration: 86400 * time.Second, MessageExpiryDuration: 86400 * time.Second,
EmailsLimit: 32, EmailLimit: 32,
ReservationsLimit: 2, ReservationLimit: 2,
AttachmentFileSizeLimit: 1231231, AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123, AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second, AttachmentExpiryDuration: 10800 * time.Second,
@ -491,9 +491,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro", Code: "pro",
MessagesLimit: 20, MessageLimit: 20,
ReservationsLimit: 2, ReservationLimit: 2,
})) }))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -525,9 +525,9 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro", Code: "pro",
MessagesLimit: 20, MessageLimit: 20,
ReservationsLimit: 2, ReservationLimit: 2,
})) }))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) {
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro", Code: "pro",
Name: "Pro", Name: "Pro",
MessagesLimit: 123, MessageLimit: 123,
MessagesExpiryDuration: 86400 * time.Second, MessageExpiryDuration: 86400 * time.Second,
EmailsLimit: 32, EmailLimit: 32,
ReservationsLimit: 2, ReservationLimit: 2,
AttachmentFileSizeLimit: 1231231, AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123, AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second, AttachmentExpiryDuration: 10800 * time.Second,
@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) {
require.True(t, strings.HasPrefix(ti.ID, "ti_")) require.True(t, strings.HasPrefix(ti.ID, "ti_"))
require.Equal(t, "pro", ti.Code) require.Equal(t, "pro", ti.Code)
require.Equal(t, "Pro", ti.Name) require.Equal(t, "Pro", ti.Name)
require.Equal(t, int64(123), ti.MessagesLimit) require.Equal(t, int64(123), ti.MessageLimit)
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration) require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
require.Equal(t, int64(32), ti.EmailsLimit) require.Equal(t, int64(32), ti.EmailLimit)
require.Equal(t, int64(2), ti.ReservationsLimit) require.Equal(t, int64(2), ti.ReservationLimit)
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)

View file

@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
if err != nil { if err != nil {
return err return err
} }
freeTier := defaultVisitorLimits(s.config) freeTier := configBasedVisitorLimits(s.config)
response := []*apiAccountBillingTier{ response := []*apiAccountBillingTier{
{ {
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price. // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisIP), Basis: string(visitorLimitBasisIP),
Messages: freeTier.MessagesLimit, Messages: freeTier.MessageLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()), MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
Emails: freeTier.EmailsLimit, Emails: freeTier.EmailLimit,
Reservations: freeTier.ReservationsLimit, Reservations: freeTier.ReservationsLimit,
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit, AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
AttachmentFileSize: freeTier.AttachmentFileSizeLimit, AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
Price: priceStr, Price: priceStr,
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier), Basis: string(visitorLimitBasisTier),
Messages: tier.MessagesLimit, Messages: tier.MessageLimit,
MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()), MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
Emails: tier.EmailsLimit, Emails: tier.EmailLimit,
Reservations: tier.ReservationsLimit, Reservations: tier.ReservationLimit,
AttachmentTotalSize: tier.AttachmentTotalSizeLimit, AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
AttachmentFileSize: tier.AttachmentFileSizeLimit, AttachmentFileSize: tier.AttachmentFileSizeLimit,
AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()), AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil return nil
} }
@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil { if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil return nil
} }
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit reservationsLimit := visitorDefaultReservationsLimit
if tier != nil { if tier != nil {
reservationsLimit = tier.ReservationsLimit reservationsLimit = tier.ReservationLimit
} }
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil { if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
return err return err

View file

@ -5,11 +5,14 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74" "github.com/stripe/stripe-go/v74"
"golang.org/x/time/rate"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"net/netip"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) {
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
Name: "Pro", Name: "Pro",
MessagesLimit: 1000, MessageLimit: 1000,
MessagesExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
EmailsLimit: 123, EmailLimit: 123,
ReservationsLimit: 777, ReservationLimit: 777,
AttachmentFileSizeLimit: 999, AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888, AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute, AttachmentExpiryDuration: time.Minute,
@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) {
ID: "ti_444", ID: "ti_444",
Code: "business", Code: "business",
Name: "Business", Name: "Business",
MessagesLimit: 2000, MessageLimit: 2000,
MessagesExpiryDuration: 10 * time.Hour, MessageExpiryDuration: 10 * time.Hour,
EmailsLimit: 123123, EmailLimit: 123123,
ReservationsLimit: 777333, ReservationLimit: 777333,
AttachmentFileSizeLimit: 999111, AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111, AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code) require.Equal(t, 401, rr.Code)
} }
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) { func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
// This tests a successful checkout flow (not a paying customer -> paying customer), // This test is too overloaded, but it's also a great end-to-end a test.
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user. //
// It tests:
// - A successful checkout flow (not a paying customer -> paying customer)
// - Tier-changes reset the rate limits for the user
// - The request limits for tier-less user and a tier-user
// - The message limits for a tier-user
stripeMock := &testStripeAPI{} stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t) defer stripeMock.AssertExpectations(t)
@ -248,19 +256,26 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
c := newTestConfigWithAuthFile(t) c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key" c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key" c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitBurst = 10 c.VisitorRequestLimitBurst = 5
c.VisitorRequestLimitReplenish = time.Hour c.VisitorRequestLimitReplenish = time.Hour
c.CacheStartupQueries = `
pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;
`
c.CacheBatchSize = 500
c.CacheBatchTimeout = time.Second
s := newTestServer(t, c) s := newTestServer(t, c)
s.stripe = stripeMock s.stripe = stripeMock
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", StripePriceID: "price_1234",
ReservationsLimit: 1, ReservationLimit: 1,
MessagesLimit: 100, MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessagesExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")
@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
Return(&stripe.Customer{}, nil) Return(&stripe.Customer{}, nil)
// Send messages until rate limit of free tier is hit // Send messages until rate limit of free tier is hit
for i := 0; i < 10; i++ { for i := 0; i < 5; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
// Now for the fun part: Verify that new rate limits are immediately applied // Now for the fun part: Verify that new rate limits are immediately applied
for i := 0; i < 100; i++ { // This only tests the request limiter, which kicks in before the message limiter.
for i := 0; i < 11; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 429, rr.Code) require.Equal(t, 429, rr.Code)
// Now let's test the message limiter by faking a ridiculously generous rate limiter
v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
var wg sync.WaitGroup
for i := 0; i < 209; i++ {
wg.Add(1)
go func() {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
wg.Done()
}()
}
wg.Wait()
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
// And now let's cross-check that the stats are correct too
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, int64(220), account.Limits.Messages)
require.Equal(t, int64(220), account.Stats.Messages)
require.Equal(t, int64(0), account.Stats.MessagesRemaining)
} }
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) { func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
ID: "ti_1", ID: "ti_1",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", // ! StripePriceID: "price_1234", // !
ReservationsLimit: 1, // ! ReservationLimit: 1, // !
MessagesLimit: 100, MessageLimit: 100,
MessagesExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000, AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,
@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
ID: "ti_2", ID: "ti_2",
Code: "pro", Code: "pro",
StripePriceID: "price_1111", // ! StripePriceID: "price_1111", // !
ReservationsLimit: 3, // ! ReservationLimit: 3, // !
MessagesLimit: 200, MessageLimit: 200,
MessagesExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000, AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,

View file

@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"io" "io"
"log"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -22,9 +21,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
) )
func init() {
// log.SetLevel(log.DebugLevel)
}
func TestServer_PublishAndPoll(t *testing.T) { func TestServer_PublishAndPoll(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
require.Equal(t, 401, response.Code) require.Equal(t, 401, response.Code)
} }
func TestServer_StatsResetter(t *testing.T) { func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
// This tests the stats resetter for
// - an anonymous user
// - a user without a tier (treated like the same as the anonymous user)
// - a user with a tier
c := newTestConfigWithAuthFile(t) c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second) c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
s := newTestServer(t, c) s := newTestServer(t, c)
go s.runStatsResetter() go s.runStatsResetter()
// Create user with tier (tieruser) and user without tier (phil)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
// Send an anonymous message
response := request(t, s, "PUT", "/mytopic", "test", nil)
// Send messages from user without tier (phil)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
} }
response := request(t, s, "GET", "/v1/account", "", map[string]string{ // Send messages from user with tier
"Authorization": util.BasicAuth("phil", "phil"), for i := 0; i < 2; i++ {
}) response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
require.Equal(t, 200, response.Code) "Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
}
// User stats show 10 messages // User stats show 6 messages (for user without tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{ response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(5), account.Stats.Messages) require.Equal(t, int64(6), account.Stats.Messages)
// User stats show 6 messages (for anonymous visitor)
response = request(t, s, "GET", "/v1/account", "", nil)
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(6), account.Stats.Messages)
// User stats show 2 messages (for user with tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(2), account.Stats.Messages)
// Wait for stats resetter to run // Wait for stats resetter to run
time.Sleep(2200 * time.Millisecond) time.Sleep(2200 * time.Millisecond)
// User stats show 0 messages now! // User stats show 0 messages now!
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages)
// Since this is a user without a tier, the anonymous user should have the same stats
response = request(t, s, "GET", "/v1/account", "", nil) response = request(t, s, "GET", "/v1/account", "", nil)
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages) require.Equal(t, int64(0), account.Stats.Messages)
// User stats show 0 messages (for user with tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages)
} }
type testMailer struct { type testMailer struct {
@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test", Code: "test",
MessagesLimit: 5, MessageLimit: 5,
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack! MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))
@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
sevenDays := time.Duration(604800) * time.Second sevenDays := time.Duration(604800) * time.Second
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test", Code: "test",
MessagesLimit: 10, MessageLimit: 10,
MessagesExpiryDuration: sevenDays, MessageExpiryDuration: sevenDays,
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days AttachmentExpiryDuration: sevenDays, // 7 days
@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test", Code: "test",
MessagesLimit: 10, MessageLimit: 10,
MessagesExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test", Code: "test",
MessagesLimit: 100, MessageLimit: 100,
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second, AttachmentExpiryDuration: 30 * time.Second,
@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil) r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty! r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
v, err := s.visitor(r) v, err := s.maybeAuthenticate(r)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "8.9.10.11", v.ip.String()) require.Equal(t, "8.9.10.11", v.ip.String())
} }
@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil) r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.1.1.1") r.Header.Set("X-Forwarded-For", "1.1.1.1")
v, err := s.visitor(r) v, err := s.maybeAuthenticate(r)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "1.1.1.1", v.ip.String()) require.Equal(t, "1.1.1.1", v.ip.String())
} }
@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil) r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11" r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
v, err := s.visitor(r) v, err := s.maybeAuthenticate(r)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "234.5.2.1", v.ip.String()) require.Equal(t, "234.5.2.1", v.ip.String())
} }
@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
// Add lots of messages // Add lots of messages
log.Printf("Adding %d messages", count) log.Info("Adding %d messages", count)
start := time.Now() start := time.Now()
messages := make([]*message, 0) messages := make([]*message, 0)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
messages = append(messages, newDefaultMessage(topicID, "some message")) messages = append(messages, newDefaultMessage(topicID, "some message"))
} }
require.Nil(t, s.messageCache.addMessages(messages)) require.Nil(t, s.messageCache.addMessages(messages))
log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
// Update stats // Update stats
statsChan := make(chan bool) statsChan := make(chan bool)
go func() { go func() {
log.Printf("Updating stats") log.Info("Updating stats")
start := time.Now() start := time.Now()
s.execManager() s.execManager()
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond)) log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
statsChan <- true statsChan <- true
}() }()
time.Sleep(50 * time.Millisecond) // Make sure it starts first time.Sleep(50 * time.Millisecond) // Make sure it starts first
// Publish message (during stats update) // Publish message (during stats update)
log.Printf("Publishing message") log.Info("Publishing message")
start = time.Now() start = time.Now()
response := request(t, s, "PUT", "/mytopic", "some body", nil) response := request(t, s, "PUT", "/mytopic", "some body", nil)
m := toMessage(t, response.Body.String()) m := toMessage(t, response.Body.String())
assert.Equal(t, "some body", m.Message) assert.Equal(t, "some body", m.Message)
assert.True(t, time.Since(start) < 100*time.Millisecond) assert.True(t, time.Since(start) < 100*time.Millisecond)
log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond)) log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
// Wait for all goroutines // Wait for all goroutines
<-statsChan <-statsChan
log.Printf("Done: Waiting for all locks") log.Info("Done: Waiting for all locks")
} }
func newTestConfig(t *testing.T) *Config { func newTestConfig(t *testing.T) *Config {

View file

@ -14,16 +14,39 @@ import (
) )
const ( const (
// oneDay is an approximation of a day as a time.Duration
oneDay = 24 * time.Hour
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number // visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since // has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
// they are replenished faster (typically). // they are replenished faster (typically).
visitorExpungeAfter = 24 * time.Hour visitorExpungeAfter = oneDay
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve. // visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise // This number is zero, and changing it may have unintended consequences in the web app, or otherwise
visitorDefaultReservationsLimit = int64(0) visitorDefaultReservationsLimit = int64(0)
) )
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
// values (token bucket).
//
// Example: Assuming a user.Tier's MessageLimit is 10,000:
// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
// - the replenish rate is 2 * 10,000 / 24 hours
const (
visitorMessageToRequestLimitBurstRate = 0.05
visitorMessageToRequestLimitBurstMax = 1000
visitorMessageToRequestLimitReplenishFactor = 2
)
// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter
// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is
// 40 (= 200 * 20%), which is <150 (the max).
const (
visitorEmailLimitBurstRate = 0.2
visitorEmailLimitBurstMax = 150
)
var ( var (
errVisitorLimitReached = errors.New("limit reached") errVisitorLimitReached = errors.New("limit reached")
) )
@ -55,9 +78,13 @@ type visitorInfo struct {
type visitorLimits struct { type visitorLimits struct {
Basis visitorLimitBasis Basis visitorLimitBasis
MessagesLimit int64 RequestLimitBurst int
MessagesExpiryDuration time.Duration RequestLimitReplenish rate.Limit
EmailsLimit int64 MessageLimit int64
MessageExpiryDuration time.Duration
EmailLimit int64
EmailLimitBurst int
EmailLimitReplenish rate.Limit
ReservationsLimit int64 ReservationsLimit int64
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error {
} }
func (v *visitor) AccountCreationAllowed() error { func (v *visitor) AccountCreationAllowed() error {
if v.accountLimiter != nil && !v.accountLimiter.Allow() { if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
return errVisitorLimitReached return errVisitorLimitReached
} }
return nil return nil
@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) {
} }
} }
func (v *visitor) resetLimiters() {
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
var messagesLimiter, bandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
if v.user != nil && v.user.Tier != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else {
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
messagesLimiter = nil // Message limit is governed by the requestLimiter
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
}
if v.user == nil {
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
}
v.requestLimiter = requestLimiter
v.messagesLimiter = messagesLimiter
v.emailsLimiter = emailsLimiter
v.bandwidthLimiter = bandwidthLimiter
v.accountLimiter = accountLimiter
}
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
// an empty string is returned. // an empty string is returned.
func (v *visitor) MaybeUserID() string { func (v *visitor) MaybeUserID() string {
@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string {
return "" return ""
} }
func (v *visitor) resetLimiters() {
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
limits := v.limitsNoLock()
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages)
v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
} else {
v.accountLimiter = nil // Users cannot create accounts when logged in
}
}
func (v *visitor) Limits() *visitorLimits { func (v *visitor) Limits() *visitorLimits {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
limits := defaultVisitorLimits(v.config) return v.limitsNoLock()
}
func (v *visitor) limitsNoLock() *visitorLimits {
if v.user != nil && v.user.Tier != nil { if v.user != nil && v.user.Tier != nil {
limits.Basis = visitorLimitBasisTier return tierBasedVisitorLimits(v.config, v.user.Tier)
limits.MessagesLimit = v.user.Tier.MessagesLimit }
limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration return configBasedVisitorLimits(v.config)
limits.EmailsLimit = v.user.Tier.EmailsLimit }
limits.ReservationsLimit = v.user.Tier.ReservationsLimit
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit return &visitorLimits{
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration Basis: visitorLimitBasisTier,
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit RequestLimitBurst: util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax),
RequestLimitReplenish: dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor),
MessageLimit: tier.MessageLimit,
MessageExpiryDuration: tier.MessageExpiryDuration,
EmailLimit: tier.EmailLimit,
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
ReservationsLimit: tier.ReservationLimit,
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
AttachmentExpiryDuration: tier.AttachmentExpiryDuration,
AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit,
}
}
func configBasedVisitorLimits(conf *Config) *visitorLimits {
messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation!
if conf.VisitorMessageDailyLimit > 0 {
messagesLimit = int64(conf.VisitorMessageDailyLimit)
}
return &visitorLimits{
Basis: visitorLimitBasisIP,
RequestLimitBurst: conf.VisitorRequestLimitBurst,
RequestLimitReplenish: rate.Every(conf.VisitorRequestLimitReplenish),
MessageLimit: messagesLimit,
MessageExpiryDuration: conf.CacheDuration,
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
EmailLimitBurst: conf.VisitorEmailLimitBurst,
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
} }
return limits
} }
func (v *visitor) Info() (*visitorInfo, error) { func (v *visitor) Info() (*visitorInfo, error) {
@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) {
limits := v.Limits() limits := v.Limits()
stats := &visitorStats{ stats := &visitorStats{
Messages: messages, Messages: messages,
MessagesRemaining: zeroIfNegative(limits.MessagesLimit - messages), MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
Emails: emails, Emails: emails,
EmailsRemaining: zeroIfNegative(limits.EmailsLimit - emails), EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
Reservations: reservations, Reservations: reservations,
ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations), ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations),
AttachmentTotalSize: attachmentsBytesUsed, AttachmentTotalSize: attachmentsBytesUsed,
@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 {
} }
func replenishDurationToDailyLimit(duration time.Duration) int64 { func replenishDurationToDailyLimit(duration time.Duration) int64 {
return int64(24 * time.Hour / duration) return int64(oneDay / duration)
} }
func dailyLimitToRate(limit int64) rate.Limit { func dailyLimitToRate(limit int64) rate.Limit {
return rate.Limit(limit) * rate.Every(24*time.Hour) return rate.Limit(limit) * rate.Every(oneDay)
} }
func defaultVisitorLimits(conf *Config) *visitorLimits { func visitorID(ip netip.Addr, u *user.User) string {
return &visitorLimits{ if u != nil && u.Tier != nil {
Basis: visitorLimitBasisIP, return fmt.Sprintf("user:%s", u.ID)
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
MessagesExpiryDuration: conf.CacheDuration,
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
} }
return fmt.Sprintf("ip:%s", ip.String())
} }

View file

@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
ID: tierID.String, ID: tierID.String,
Code: tierCode.String, Code: tierCode.String,
Name: tierName.String, Name: tierName.String,
MessagesLimit: messagesLimit.Int64, MessageLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64, EmailLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64, ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error {
t, err := a.Tier(tier) t, err := a.Tier(tier)
if err != nil { if err != nil {
return err return err
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil { } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
return err return err
} }
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil { if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
if err != nil { if err != nil {
return err return err
} }
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit { if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
reservations, err := a.Reservations(username) reservations, err := a.Reservations(username)
if err != nil { if err != nil {
return err return err
@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
if tier.ID == "" { if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
} }
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err return err
} }
return nil return nil
@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
ID: id, ID: id,
Code: code, Code: code,
Name: name, Name: name,
MessagesLimit: messagesLimit.Int64, MessageLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64, EmailLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64, ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,

View file

@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
Code: "pro", Code: "pro",
Name: "ntfy Pro", Name: "ntfy Pro",
StripePriceID: "price123", StripePriceID: "price123",
MessagesLimit: 5_000, MessageLimit: 5_000,
MessagesExpiryDuration: 3 * 24 * time.Hour, MessageExpiryDuration: 3 * 24 * time.Hour,
EmailsLimit: 50, EmailLimit: 50,
ReservationsLimit: 5, ReservationLimit: 5,
AttachmentFileSizeLimit: 52428800, AttachmentFileSizeLimit: 52428800,
AttachmentTotalSizeLimit: 524288000, AttachmentTotalSizeLimit: 524288000,
AttachmentExpiryDuration: 24 * time.Hour, AttachmentExpiryDuration: 24 * time.Hour,
@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, RoleUser, ben.Role) require.Equal(t, RoleUser, ben.Role)
require.Equal(t, "pro", ben.Tier.Code) require.Equal(t, "pro", ben.Tier.Code)
require.Equal(t, int64(5000), ben.Tier.MessagesLimit) require.Equal(t, int64(5000), ben.Tier.MessageLimit)
require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration) require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration)
require.Equal(t, int64(50), ben.Tier.EmailsLimit) require.Equal(t, int64(50), ben.Tier.EmailLimit)
require.Equal(t, int64(5), ben.Tier.ReservationsLimit) require.Equal(t, int64(5), ben.Tier.ReservationLimit)
require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit) require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit) require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration) require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)

View file

@ -62,10 +62,10 @@ type Tier struct {
ID string // Tier identifier (ti_...) ID string // Tier identifier (ti_...)
Code string // Code of the tier Code string // Code of the tier
Name string // Name of the tier Name string // Name of the tier
MessagesLimit int64 // Daily message limit MessageLimit int64 // Daily message limit
MessagesExpiryDuration time.Duration // Cache duration for messages MessageExpiryDuration time.Duration // Cache duration for messages
EmailsLimit int64 // Daily email limit EmailLimit int64 // Daily email limit
ReservationsLimit int64 // Number of topic reservations allowed by user ReservationLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 // Max file size per file (bytes) AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted

View file

@ -27,8 +27,14 @@ type FixedLimiter struct {
// NewFixedLimiter creates a new Limiter // NewFixedLimiter creates a new Limiter
func NewFixedLimiter(limit int64) *FixedLimiter { func NewFixedLimiter(limit int64) *FixedLimiter {
return NewFixedLimiterWithValue(limit, 0)
}
// NewFixedLimiterWithValue creates a new Limiter and sets the initial value
func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
return &FixedLimiter{ return &FixedLimiter{
limit: limit, limit: limit,
value: value,
} }
} }

View file

@ -17,7 +17,7 @@ var (
// NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence // NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
// of that time from the current time (in UTC). // of that time from the current time (in UTC).
func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time { func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
hour, minute, seconds := timeOfDay.Clock() hour, minute, seconds := timeOfDay.UTC().Clock()
now := base.UTC() now := base.UTC()
next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC) next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
if next.Before(now) { if next.Before(now) {

View file

@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
return nil, err return nil, err
} }
// MinMax returns value if it is between min and max, or either
// min or max if it is out of range
func MinMax[T int | int64](value, min, max T) T {
if value < min {
return min
} else if value > max {
return max
}
return value
}
// String turns a string into a pointer of a string // String turns a string into a pointer of a string
func String(v string) *string { func String(v string) *string {
return &v return &v