From c874a641df84c30c5d303a96f7f7240bdbcf257e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 26 Jan 2023 22:57:18 -0500 Subject: [PATCH] Rate limits make sense now! --- cmd/serve.go | 3 + server/config.go | 3 + server/errors.go | 4 +- server/server.go | 49 ++++------- server/server.yml | 6 ++ server/server_account.go | 14 +-- server/server_account_test.go | 36 ++++---- server/server_payments.go | 22 ++--- server/server_payments_test.go | 101 +++++++++++++++------ server/server_test.go | 111 +++++++++++++++++------ server/visitor.go | 156 +++++++++++++++++++++------------ user/manager.go | 22 ++--- user/manager_test.go | 16 ++-- user/types.go | 8 +- util/limit.go | 6 ++ util/time.go | 2 +- util/util.go | 11 +++ 17 files changed, 365 insertions(+), 205 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 729b159..974ef4b 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -77,6 +77,7 @@ var flagsServe = append( altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), @@ -150,6 +151,7 @@ func execServe(c *cli.Context) error { visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",") + visitorMessageDailyLimit := c.Int("visitor-message-daily-limit") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") behindProxy := c.Bool("behind-proxy") @@ -289,6 +291,7 @@ func execServe(c *cli.Context) error { conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs + conf.VisitorMessageDailyLimit = visitorMessageDailyLimit conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.BehindProxy = behindProxy diff --git a/server/config.go b/server/config.go index f67fe8f..fb85670 100644 --- a/server/config.go +++ b/server/config.go @@ -44,6 +44,7 @@ const ( DefaultVisitorSubscriptionLimit = 30 DefaultVisitorRequestLimitBurst = 60 DefaultVisitorRequestLimitReplenish = 5 * time.Second + DefaultVisitorMessageDailyLimit = 0 DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorAccountCreationLimitBurst = 3 @@ -105,6 +106,7 @@ type Config struct { VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration VisitorRequestExemptIPAddrs []netip.Prefix + VisitorMessageDailyLimit int VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration VisitorAccountCreationLimitBurst int @@ -171,6 +173,7 @@ func NewConfig() *Config { VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), + VisitorMessageDailyLimit: DefaultVisitorMessageDailyLimit, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst, diff --git a/server/errors.go b/server/errors.go index 359cedc..81fb9ea 100644 --- a/server/errors.go +++ b/server/errors.go @@ -75,10 +75,10 @@ var ( errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""} - errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} diff --git a/server/server.go b/server/server.go index d79c967..158c4d4 100644 --- a/server/server.go +++ b/server/server.go @@ -38,10 +38,9 @@ import ( TODO -- -- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) -- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters - HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)? +- MEDIUM Rate limiting: Test daily message quota read from database initially - MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) @@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins Tests: - Payment endpoints (make mocks) -- test that the visitor is based on the IP address when a user has no tier */ // Server is the main server, providing the UI and API for ntfy @@ -308,7 +306,7 @@ func (s *Server) Stop() { } func (s *Server) handle(w http.ResponseWriter, r *http.Request) { - v, err := s.visitor(r) // Note: Always returns v, even when error is returned + v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned if err == nil { log.Debug("%s Dispatching request", logHTTPPrefix(v, r)) if log.IsTrace() { @@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if v.user != nil { m.User = v.user.ID } - m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix() + m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } v.IncrementMessages() if s.userManager != nil && v.user != nil { - s.userManager.EnqueueStats(v.user) + s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users } s.mu.Lock() s.messages++ @@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() { log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt) select { case <-timer.C: + log.Debug("Stats resetter: Running") s.resetStats() case <-s.closeChan: + log.Debug("Stats resetter: Stopping timer") timer.Stop() return } @@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error { return err } for _, m := range messages { - var v *visitor + var u *user.User if s.userManager != nil && m.User != "" { - u, err := s.userManager.User(m.User) + u, err = s.userManager.User(m.User) if err != nil { - log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) + log.Warn("Error sending delayed message %s: %s", m.ID, err.Error()) continue } - v = s.visitorFromUser(u, m.Sender) - } else { - v = s.visitorFromIP(m.Sender) } + v := s.visitor(m.Sender, u) if err := s.sendDelayedMessage(v, m); err != nil { log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) } @@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc } } -// visitor creates or retrieves a rate.Limiter for the given visitor. +// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor. // Note that this function will always return a visitor, even if an error occurs. -func (s *Server) visitor(r *http.Request) (v *visitor, err error) { +func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) { ip := extractIPAddress(r, s.config.BehindProxy) var u *user.User // may stay nil if no auth header! if u, err = s.authenticate(r); err != nil { log.Debug("authentication failed: %s", err.Error()) err = errHTTPUnauthorized // Always return visitor, even when error occurs! } - if u != nil { - v = s.visitorFromUser(u, ip) - } else { - v = s.visitorFromIP(ip) - } + v = s.visitor(ip, u) v.SetUser(u) // Update visitor user with latest from database! return v, err // Always return visitor, even when error occurs! } @@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro return s.userManager.AuthenticateToken(token) } -func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor { +func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor { s.mu.Lock() defer s.mu.Unlock() - v, exists := s.visitors[visitorID] + id := visitorID(ip, user) + v, exists := s.visitors[id] if !exists { - s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user) - return s.visitors[visitorID] + s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user) + return s.visitors[id] } v.Keepalive() return v } -func (s *Server) visitorFromIP(ip netip.Addr) *visitor { - return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil) -} - -func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor { - return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user) -} - func (s *Server) writeJSON(w http.ResponseWriter, v any) error { w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests diff --git a/server/server.yml b/server/server.yml index 51e29c4..76198f9 100644 --- a/server/server.yml +++ b/server/server.yml @@ -200,6 +200,12 @@ # visitor-request-limit-replenish: "5s" # visitor-request-limit-exempt-hosts: "" +# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset +# every day at midnight UTC. If the limit is not set (or set to zero), the request +# limit (see above) governs the upper limit. +# +# visitor-message-daily-limit: 0 + # Rate limiting: Allowed emails per visitor: # - visitor-email-limit-burst is the initial bucket of emails each visitor has # - visitor-email-limit-replenish is the rate at which the bucket is refilled diff --git a/server/server_account.go b/server/server_account.go index da02cab..718f022 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * } else if v.user != nil { return errHTTPUnauthorized // Cannot create account from user context } + if err := v.AccountCreationAllowed(); err != nil { + return errHTTPTooManyRequestsLimitAccountCreation + } } newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit) if err != nil { @@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil { return errHTTPConflictUserExists } - if err := v.AccountCreationAllowed(); err != nil { - return errHTTPTooManyRequestsLimitAccountCreation - } if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User return err } @@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis response := &apiAccountResponse{ Limits: &apiAccountLimits{ Basis: string(limits.Basis), - Messages: limits.MessagesLimit, - MessagesExpiryDuration: int64(limits.MessagesExpiryDuration.Seconds()), - Emails: limits.EmailsLimit, + Messages: limits.MessageLimit, + MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()), + Emails: limits.EmailLimit, Reservations: limits.ReservationsLimit, AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit, @@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ reservations, err := s.userManager.ReservationsCount(v.user.Name) if err != nil { return err - } else if reservations >= v.user.Tier.ReservationsLimit { + } else if reservations >= v.user.Tier.ReservationLimit { return errHTTPTooManyRequestsLimitReservations } } diff --git a/server/server_account_test.go b/server/server_account_test.go index 79504b6..e3bbf11 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) { // Create a tier require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "pro", - MessagesLimit: 123, - MessagesExpiryDuration: 86400 * time.Second, - EmailsLimit: 32, - ReservationsLimit: 2, + MessageLimit: 123, + MessageExpiryDuration: 86400 * time.Second, + EmailLimit: 32, + ReservationLimit: 2, AttachmentFileSizeLimit: 1231231, AttachmentTotalSizeLimit: 123123, AttachmentExpiryDuration: 10800 * time.Second, @@ -491,9 +491,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) { require.Equal(t, 200, rr.Code) require.Nil(t, s.userManager.CreateTier(&user.Tier{ - Code: "pro", - MessagesLimit: 20, - ReservationsLimit: 2, + Code: "pro", + MessageLimit: 20, + ReservationLimit: 2, })) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) @@ -525,9 +525,9 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { require.Equal(t, 200, rr.Code) require.Nil(t, s.userManager.CreateTier(&user.Tier{ - Code: "pro", - MessagesLimit: 20, - ReservationsLimit: 2, + Code: "pro", + MessageLimit: 20, + ReservationLimit: 2, })) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) @@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) { require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "pro", Name: "Pro", - MessagesLimit: 123, - MessagesExpiryDuration: 86400 * time.Second, - EmailsLimit: 32, - ReservationsLimit: 2, + MessageLimit: 123, + MessageExpiryDuration: 86400 * time.Second, + EmailLimit: 32, + ReservationLimit: 2, AttachmentFileSizeLimit: 1231231, AttachmentTotalSizeLimit: 123123, AttachmentExpiryDuration: 10800 * time.Second, @@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) { require.True(t, strings.HasPrefix(ti.ID, "ti_")) require.Equal(t, "pro", ti.Code) require.Equal(t, "Pro", ti.Name) - require.Equal(t, int64(123), ti.MessagesLimit) - require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration) - require.Equal(t, int64(32), ti.EmailsLimit) - require.Equal(t, int64(2), ti.ReservationsLimit) + require.Equal(t, int64(123), ti.MessageLimit) + require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration) + require.Equal(t, int64(32), ti.EmailLimit) + require.Equal(t, int64(2), ti.ReservationLimit) require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) diff --git a/server/server_payments.go b/server/server_payments.go index 2178d48..4e92757 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ if err != nil { return err } - freeTier := defaultVisitorLimits(s.config) + freeTier := configBasedVisitorLimits(s.config) response := []*apiAccountBillingTier{ { // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price. Limits: &apiAccountLimits{ Basis: string(visitorLimitBasisIP), - Messages: freeTier.MessagesLimit, - MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()), - Emails: freeTier.EmailsLimit, + Messages: freeTier.MessageLimit, + MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()), + Emails: freeTier.EmailLimit, Reservations: freeTier.ReservationsLimit, AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit, AttachmentFileSize: freeTier.AttachmentFileSizeLimit, @@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ Price: priceStr, Limits: &apiAccountLimits{ Basis: string(visitorLimitBasisTier), - Messages: tier.MessagesLimit, - MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()), - Emails: tier.EmailsLimit, - Reservations: tier.ReservationsLimit, + Messages: tier.MessageLimit, + MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()), + Emails: tier.EmailLimit, + Reservations: tier.ReservationLimit, AttachmentTotalSize: tier.AttachmentTotalSizeLimit, AttachmentFileSize: tier.AttachmentFileSizeLimit, AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()), @@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { return err } - s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) + s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) return nil } @@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil { return err } - s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) + s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) return nil } func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { reservationsLimit := visitorDefaultReservationsLimit if tier != nil { - reservationsLimit = tier.ReservationsLimit + reservationsLimit = tier.ReservationLimit } if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil { return err diff --git a/server/server_payments_test.go b/server/server_payments_test.go index d7bc660..1576ab0 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -5,11 +5,14 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stripe/stripe-go/v74" + "golang.org/x/time/rate" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "io" + "net/netip" "path/filepath" "strings" + "sync" "testing" "time" ) @@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) { ID: "ti_123", Code: "pro", Name: "Pro", - MessagesLimit: 1000, - MessagesExpiryDuration: time.Hour, - EmailsLimit: 123, - ReservationsLimit: 777, + MessageLimit: 1000, + MessageExpiryDuration: time.Hour, + EmailLimit: 123, + ReservationLimit: 777, AttachmentFileSizeLimit: 999, AttachmentTotalSizeLimit: 888, AttachmentExpiryDuration: time.Minute, @@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) { ID: "ti_444", Code: "business", Name: "Business", - MessagesLimit: 2000, - MessagesExpiryDuration: 10 * time.Hour, - EmailsLimit: 123123, - ReservationsLimit: 777333, + MessageLimit: 2000, + MessageExpiryDuration: 10 * time.Hour, + EmailLimit: 123123, + ReservationLimit: 777333, AttachmentFileSizeLimit: 999111, AttachmentTotalSizeLimit: 888111, AttachmentExpiryDuration: time.Hour, @@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { require.Equal(t, 401, rr.Code) } -func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) { - // This tests a successful checkout flow (not a paying customer -> paying customer), - // and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user. +func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) { + // This test is too overloaded, but it's also a great end-to-end a test. + // + // It tests: + // - A successful checkout flow (not a paying customer -> paying customer) + // - Tier-changes reset the rate limits for the user + // - The request limits for tier-less user and a tier-user + // - The message limits for a tier-user stripeMock := &testStripeAPI{} defer stripeMock.AssertExpectations(t) @@ -248,19 +256,26 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test c := newTestConfigWithAuthFile(t) c.StripeSecretKey = "secret key" c.StripeWebhookKey = "webhook key" - c.VisitorRequestLimitBurst = 10 + c.VisitorRequestLimitBurst = 5 c.VisitorRequestLimitReplenish = time.Hour + c.CacheStartupQueries = ` +pragma journal_mode = WAL; +pragma synchronous = normal; +pragma temp_store = memory; +` + c.CacheBatchSize = 500 + c.CacheBatchTimeout = time.Second s := newTestServer(t, c) s.stripe = stripeMock // Create a user with a Stripe subscription and 3 reservations require.Nil(t, s.userManager.CreateTier(&user.Tier{ - ID: "ti_123", - Code: "starter", - StripePriceID: "price_1234", - ReservationsLimit: 1, - MessagesLimit: 100, - MessagesExpiryDuration: time.Hour, + ID: "ti_123", + Code: "starter", + StripePriceID: "price_1234", + ReservationLimit: 1, + MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in + MessageExpiryDuration: time.Hour, })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier u, err := s.userManager.User("phil") @@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test Return(&stripe.Customer{}, nil) // Send messages until rate limit of free tier is hit - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) @@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) - // FIXME FIXME This test is broken, because the rate limit logic is unclear! - // Now for the fun part: Verify that new rate limits are immediately applied - for i := 0; i < 100; i++ { + // This only tests the request limiter, which kicks in before the message limiter. + for i := 0; i < 11; i++ { rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) @@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 429, rr.Code) + + // Now let's test the message limiter by faking a ridiculously generous rate limiter + v := s.visitor(netip.MustParseAddr("9.9.9.9"), u) + v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000) + + var wg sync.WaitGroup + for i := 0; i < 209; i++ { + wg.Add(1) + go func() { + rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + wg.Done() + }() + } + wg.Wait() + rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 429, rr.Code) + + // And now let's cross-check that the stats are correct too + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, int64(220), account.Limits.Messages) + require.Equal(t, int64(220), account.Stats.Messages) + require.Equal(t, int64(0), account.Stats.MessagesRemaining) } func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) { @@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( ID: "ti_1", Code: "starter", StripePriceID: "price_1234", // ! - ReservationsLimit: 1, // ! - MessagesLimit: 100, - MessagesExpiryDuration: time.Hour, + ReservationLimit: 1, // ! + MessageLimit: 100, + MessageExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour, AttachmentFileSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000, @@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( ID: "ti_2", Code: "pro", StripePriceID: "price_1111", // ! - ReservationsLimit: 3, // ! - MessagesLimit: 200, - MessagesExpiryDuration: time.Hour, + ReservationLimit: 3, // ! + MessageLimit: 200, + MessageExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour, AttachmentFileSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000, diff --git a/server/server_test.go b/server/server_test.go index 81a243e..a9b9889 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -8,7 +8,6 @@ import ( "fmt" "heckel.io/ntfy/user" "io" - "log" "math/rand" "net/http" "net/http/httptest" @@ -22,9 +21,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "heckel.io/ntfy/log" "heckel.io/ntfy/util" ) +func init() { + // log.SetLevel(log.DebugLevel) +} + func TestServer_PublishAndPoll(t *testing.T) { s := newTestServer(t, newTestConfig(t)) @@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) { require.Equal(t, 401, response.Code) } -func TestServer_StatsResetter(t *testing.T) { +func TestServer_StatsResetter_User_Without_Tier(t *testing.T) { + // This tests the stats resetter for + // - an anonymous user + // - a user without a tier (treated like the same as the anonymous user) + // - a user with a tier + c := newTestConfigWithAuthFile(t) - c.AuthDefault = user.PermissionDenyAll c.VisitorStatsResetTime = time.Now().Add(2 * time.Second) s := newTestServer(t, c) go s.runStatsResetter() + // Create user with tier (tieruser) and user without tier (phil) + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "test", + MessageLimit: 5, + MessageExpiryDuration: -5 * time.Second, // Second, what a hack! + })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) - require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite)) + require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("tieruser", "test")) + // Send an anonymous message + response := request(t, s, "PUT", "/mytopic", "test", nil) + + // Send messages from user without tier (phil) for i := 0; i < 5; i++ { response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) { require.Equal(t, 200, response.Code) } - response := request(t, s, "GET", "/v1/account", "", map[string]string{ - "Authorization": util.BasicAuth("phil", "phil"), - }) - require.Equal(t, 200, response.Code) + // Send messages from user with tier + for i := 0; i < 2; i++ { + response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ + "Authorization": util.BasicAuth("tieruser", "tieruser"), + }) + require.Equal(t, 200, response.Code) + } - // User stats show 10 messages + // User stats show 6 messages (for user without tier) response = request(t, s, "GET", "/v1/account", "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, response.Code) account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) require.Nil(t, err) - require.Equal(t, int64(5), account.Stats.Messages) + require.Equal(t, int64(6), account.Stats.Messages) + + // User stats show 6 messages (for anonymous visitor) + response = request(t, s, "GET", "/v1/account", "", nil) + require.Equal(t, 200, response.Code) + account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, int64(6), account.Stats.Messages) + + // User stats show 2 messages (for user with tier) + response = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("tieruser", "tieruser"), + }) + require.Equal(t, 200, response.Code) + account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, int64(2), account.Stats.Messages) // Wait for stats resetter to run time.Sleep(2200 * time.Millisecond) // User stats show 0 messages now! + response = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, response.Code) + account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, int64(0), account.Stats.Messages) + + // Since this is a user without a tier, the anonymous user should have the same stats response = request(t, s, "GET", "/v1/account", "", nil) require.Equal(t, 200, response.Code) account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) require.Nil(t, err) require.Equal(t, int64(0), account.Stats.Messages) + // User stats show 0 messages (for user with tier) + response = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("tieruser", "tieruser"), + }) + require.Equal(t, 200, response.Code) + account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body)) + require.Nil(t, err) + require.Equal(t, int64(0), account.Stats.Messages) } type testMailer struct { @@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) { // Create tier with certain limits require.Nil(t, s.userManager.CreateTier(&user.Tier{ - Code: "test", - MessagesLimit: 5, - MessagesExpiryDuration: -5 * time.Second, // Second, what a hack! + Code: "test", + MessageLimit: 5, + MessageExpiryDuration: -5 * time.Second, // Second, what a hack! })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) @@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { sevenDays := time.Duration(604800) * time.Second require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "test", - MessagesLimit: 10, - MessagesExpiryDuration: sevenDays, + MessageLimit: 10, + MessageExpiryDuration: sevenDays, AttachmentFileSizeLimit: 50_000, AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: sevenDays, // 7 days @@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) { // Create tier with certain limits require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "test", - MessagesLimit: 10, - MessagesExpiryDuration: time.Hour, + MessageLimit: 10, + MessageExpiryDuration: time.Hour, AttachmentFileSizeLimit: 50_000, AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: time.Hour, @@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { // Create tier with certain limits require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "test", - MessagesLimit: 100, + MessageLimit: 100, AttachmentFileSizeLimit: 50_000, AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: 30 * time.Second, @@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) { r, _ := http.NewRequest("GET", "/bla", nil) r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty! - v, err := s.visitor(r) + v, err := s.maybeAuthenticate(r) require.Nil(t, err) require.Equal(t, "8.9.10.11", v.ip.String()) } @@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) { r, _ := http.NewRequest("GET", "/bla", nil) r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.1.1.1") - v, err := s.visitor(r) + v, err := s.maybeAuthenticate(r) require.Nil(t, err) require.Equal(t, "1.1.1.1", v.ip.String()) } @@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) { r, _ := http.NewRequest("GET", "/bla", nil) r.RemoteAddr = "8.9.10.11" r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ") - v, err := s.visitor(r) + v, err := s.maybeAuthenticate(r) require.Nil(t, err) require.Equal(t, "234.5.2.1", v.ip.String()) } @@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { s := newTestServer(t, c) // Add lots of messages - log.Printf("Adding %d messages", count) + log.Info("Adding %d messages", count) start := time.Now() messages := make([]*message, 0) for i := 0; i < count; i++ { @@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { messages = append(messages, newDefaultMessage(topicID, "some message")) } require.Nil(t, s.messageCache.addMessages(messages)) - log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) + log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) // Update stats statsChan := make(chan bool) go func() { - log.Printf("Updating stats") + log.Info("Updating stats") start := time.Now() s.execManager() - log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond)) + log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond)) statsChan <- true }() time.Sleep(50 * time.Millisecond) // Make sure it starts first // Publish message (during stats update) - log.Printf("Publishing message") + log.Info("Publishing message") start = time.Now() response := request(t, s, "PUT", "/mytopic", "some body", nil) m := toMessage(t, response.Body.String()) assert.Equal(t, "some body", m.Message) assert.True(t, time.Since(start) < 100*time.Millisecond) - log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond)) + log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond)) // Wait for all goroutines <-statsChan - log.Printf("Done: Waiting for all locks") + log.Info("Done: Waiting for all locks") } func newTestConfig(t *testing.T) *Config { diff --git a/server/visitor.go b/server/visitor.go index be8e8e3..113662f 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -14,16 +14,39 @@ import ( ) const ( + // oneDay is an approximation of a day as a time.Duration + oneDay = 24 * time.Hour + // visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number // has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since // they are replenished faster (typically). - visitorExpungeAfter = 24 * time.Hour + visitorExpungeAfter = oneDay // visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve. // This number is zero, and changing it may have unintended consequences in the web app, or otherwise visitorDefaultReservationsLimit = int64(0) ) +// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter +// values (token bucket). +// +// Example: Assuming a user.Tier's MessageLimit is 10,000: +// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max) +// - the replenish rate is 2 * 10,000 / 24 hours +const ( + visitorMessageToRequestLimitBurstRate = 0.05 + visitorMessageToRequestLimitBurstMax = 1000 + visitorMessageToRequestLimitReplenishFactor = 2 +) + +// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter +// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is +// 40 (= 200 * 20%), which is <150 (the max). +const ( + visitorEmailLimitBurstRate = 0.2 + visitorEmailLimitBurstMax = 150 +) + var ( errVisitorLimitReached = errors.New("limit reached") ) @@ -55,9 +78,13 @@ type visitorInfo struct { type visitorLimits struct { Basis visitorLimitBasis - MessagesLimit int64 - MessagesExpiryDuration time.Duration - EmailsLimit int64 + RequestLimitBurst int + RequestLimitReplenish rate.Limit + MessageLimit int64 + MessageExpiryDuration time.Duration + EmailLimit int64 + EmailLimitBurst int + EmailLimitReplenish rate.Limit ReservationsLimit int64 AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 @@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error { } func (v *visitor) AccountCreationAllowed() error { - if v.accountLimiter != nil && !v.accountLimiter.Allow() { + if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) { return errVisitorLimitReached } return nil @@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) { } } -func (v *visitor) resetLimiters() { - log.Info("%s Resetting limiters for visitor", v.stringNoLock()) - var messagesLimiter, bandwidthLimiter util.Limiter - var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter - if v.user != nil && v.user.Tier != nil { - requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst) - messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit) - emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst) - bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour) - } else { - requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst) - messagesLimiter = nil // Message limit is governed by the requestLimiter - emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst) - bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour) - } - if v.user == nil { - accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) - } - v.requestLimiter = requestLimiter - v.messagesLimiter = messagesLimiter - v.emailsLimiter = emailsLimiter - v.bandwidthLimiter = bandwidthLimiter - v.accountLimiter = accountLimiter -} - // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, // an empty string is returned. func (v *visitor) MaybeUserID() string { @@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string { return "" } +func (v *visitor) resetLimiters() { + log.Debug("%s Resetting limiters for visitor", v.stringNoLock()) + limits := v.limitsNoLock() + v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) + v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages) + v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst) + v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) + if v.user == nil { + v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) + } else { + v.accountLimiter = nil // Users cannot create accounts when logged in + } +} + func (v *visitor) Limits() *visitorLimits { v.mu.Lock() defer v.mu.Unlock() - limits := defaultVisitorLimits(v.config) + return v.limitsNoLock() +} + +func (v *visitor) limitsNoLock() *visitorLimits { if v.user != nil && v.user.Tier != nil { - limits.Basis = visitorLimitBasisTier - limits.MessagesLimit = v.user.Tier.MessagesLimit - limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration - limits.EmailsLimit = v.user.Tier.EmailsLimit - limits.ReservationsLimit = v.user.Tier.ReservationsLimit - limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit - limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit - limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration - limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit + return tierBasedVisitorLimits(v.config, v.user.Tier) + } + return configBasedVisitorLimits(v.config) +} + +func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits { + return &visitorLimits{ + Basis: visitorLimitBasisTier, + RequestLimitBurst: util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax), + RequestLimitReplenish: dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor), + MessageLimit: tier.MessageLimit, + MessageExpiryDuration: tier.MessageExpiryDuration, + EmailLimit: tier.EmailLimit, + EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax), + EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit), + ReservationsLimit: tier.ReservationLimit, + AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit, + AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit, + AttachmentExpiryDuration: tier.AttachmentExpiryDuration, + AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit, + } +} + +func configBasedVisitorLimits(conf *Config) *visitorLimits { + messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation! + if conf.VisitorMessageDailyLimit > 0 { + messagesLimit = int64(conf.VisitorMessageDailyLimit) + } + return &visitorLimits{ + Basis: visitorLimitBasisIP, + RequestLimitBurst: conf.VisitorRequestLimitBurst, + RequestLimitReplenish: rate.Every(conf.VisitorRequestLimitReplenish), + MessageLimit: messagesLimit, + MessageExpiryDuration: conf.CacheDuration, + EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation! + EmailLimitBurst: conf.VisitorEmailLimitBurst, + EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish), + ReservationsLimit: visitorDefaultReservationsLimit, + AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, + AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, + AttachmentExpiryDuration: conf.AttachmentExpiryDuration, + AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit, } - return limits } func (v *visitor) Info() (*visitorInfo, error) { @@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) { limits := v.Limits() stats := &visitorStats{ Messages: messages, - MessagesRemaining: zeroIfNegative(limits.MessagesLimit - messages), + MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages), Emails: emails, - EmailsRemaining: zeroIfNegative(limits.EmailsLimit - emails), + EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails), Reservations: reservations, ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations), AttachmentTotalSize: attachmentsBytesUsed, @@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 { } func replenishDurationToDailyLimit(duration time.Duration) int64 { - return int64(24 * time.Hour / duration) + return int64(oneDay / duration) } func dailyLimitToRate(limit int64) rate.Limit { - return rate.Limit(limit) * rate.Every(24*time.Hour) + return rate.Limit(limit) * rate.Every(oneDay) } -func defaultVisitorLimits(conf *Config) *visitorLimits { - return &visitorLimits{ - Basis: visitorLimitBasisIP, - MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish), - MessagesExpiryDuration: conf.CacheDuration, - EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), - ReservationsLimit: visitorDefaultReservationsLimit, - AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, - AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, - AttachmentExpiryDuration: conf.AttachmentExpiryDuration, - AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit, +func visitorID(ip netip.Addr, u *user.User) string { + if u != nil && u.Tier != nil { + return fmt.Sprintf("user:%s", u.ID) } + return fmt.Sprintf("ip:%s", ip.String()) } diff --git a/user/manager.go b/user/manager.go index c252e1f..7ecba73 100644 --- a/user/manager.go +++ b/user/manager.go @@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { ID: tierID.String, Code: tierCode.String, Name: tierName.String, - MessagesLimit: messagesLimit.Int64, - MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailsLimit: emailsLimit.Int64, - ReservationsLimit: reservationsLimit.Int64, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, @@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error { t, err := a.Tier(tier) if err != nil { return err - } else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil { + } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { return err } if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil { @@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6 if err != nil { return err } - if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit { + if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit { reservations, err := a.Reservations(username) if err != nil { return err @@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error { if tier.ID == "" { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) } - if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { + if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { return err } return nil @@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { ID: id, Code: code, Name: name, - MessagesLimit: messagesLimit.Int64, - MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailsLimit: emailsLimit.Int64, - ReservationsLimit: reservationsLimit.Int64, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, diff --git a/user/manager_test.go b/user/manager_test.go index 2f21bce..2bc0492 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { Code: "pro", Name: "ntfy Pro", StripePriceID: "price123", - MessagesLimit: 5_000, - MessagesExpiryDuration: 3 * 24 * time.Hour, - EmailsLimit: 50, - ReservationsLimit: 5, + MessageLimit: 5_000, + MessageExpiryDuration: 3 * 24 * time.Hour, + EmailLimit: 50, + ReservationLimit: 5, AttachmentFileSizeLimit: 52428800, AttachmentTotalSizeLimit: 524288000, AttachmentExpiryDuration: 24 * time.Hour, @@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { require.Nil(t, err) require.Equal(t, RoleUser, ben.Role) require.Equal(t, "pro", ben.Tier.Code) - require.Equal(t, int64(5000), ben.Tier.MessagesLimit) - require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration) - require.Equal(t, int64(50), ben.Tier.EmailsLimit) - require.Equal(t, int64(5), ben.Tier.ReservationsLimit) + require.Equal(t, int64(5000), ben.Tier.MessageLimit) + require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration) + require.Equal(t, int64(50), ben.Tier.EmailLimit) + require.Equal(t, int64(5), ben.Tier.ReservationLimit) require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit) require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit) require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration) diff --git a/user/types.go b/user/types.go index d5b1a25..e14e757 100644 --- a/user/types.go +++ b/user/types.go @@ -62,10 +62,10 @@ type Tier struct { ID string // Tier identifier (ti_...) Code string // Code of the tier Name string // Name of the tier - MessagesLimit int64 // Daily message limit - MessagesExpiryDuration time.Duration // Cache duration for messages - EmailsLimit int64 // Daily email limit - ReservationsLimit int64 // Number of topic reservations allowed by user + MessageLimit int64 // Daily message limit + MessageExpiryDuration time.Duration // Cache duration for messages + EmailLimit int64 // Daily email limit + ReservationLimit int64 // Number of topic reservations allowed by user AttachmentFileSizeLimit int64 // Max file size per file (bytes) AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted diff --git a/util/limit.go b/util/limit.go index 8df768a..7f39c4c 100644 --- a/util/limit.go +++ b/util/limit.go @@ -27,8 +27,14 @@ type FixedLimiter struct { // NewFixedLimiter creates a new Limiter func NewFixedLimiter(limit int64) *FixedLimiter { + return NewFixedLimiterWithValue(limit, 0) +} + +// NewFixedLimiterWithValue creates a new Limiter and sets the initial value +func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter { return &FixedLimiter{ limit: limit, + value: value, } } diff --git a/util/time.go b/util/time.go index 2447548..04d78f8 100644 --- a/util/time.go +++ b/util/time.go @@ -17,7 +17,7 @@ var ( // NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence // of that time from the current time (in UTC). func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time { - hour, minute, seconds := timeOfDay.Clock() + hour, minute, seconds := timeOfDay.UTC().Clock() now := base.UTC() next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC) if next.Before(now) { diff --git a/util/util.go b/util/util.go index 1bbf629..2d021dc 100644 --- a/util/util.go +++ b/util/util.go @@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error return nil, err } +// MinMax returns value if it is between min and max, or either +// min or max if it is out of range +func MinMax[T int | int64](value, min, max T) T { + if value < min { + return min + } else if value > max { + return max + } + return value +} + // String turns a string into a pointer of a string func String(v string) *string { return &v