diff --git a/cmd/serve.go b/cmd/serve.go index 0e0bdc2..729b159 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -285,7 +285,7 @@ func execServe(c *cli.Context) error { conf.TotalTopicLimit = totalTopicLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit - conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit) + conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs diff --git a/server/config.go b/server/config.go index b54cb54..4fa48dc 100644 --- a/server/config.go +++ b/server/config.go @@ -101,7 +101,7 @@ type Config struct { TotalAttachmentSizeLimit int64 VisitorSubscriptionLimit int VisitorAttachmentTotalSizeLimit int64 - VisitorAttachmentDailyBandwidthLimit int + VisitorAttachmentDailyBandwidthLimit int64 VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration VisitorRequestExemptIPAddrs []netip.Prefix diff --git a/server/server.go b/server/server.go index 0d104ac..8ab9b2f 100644 --- a/server/server.go +++ b/server/server.go @@ -40,7 +40,6 @@ TODO - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) -- HIGH Rate limiting: Bandwidth limit must be in tier + TESTS - MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) @@ -866,7 +865,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), } - fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v) m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) if err == util.ErrLimitReached { return errHTTPEntityTooLargeAttachment diff --git a/server/server_account.go b/server/server_account.go index b56509f..57547a1 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -11,6 +11,7 @@ import ( const ( subscriptionIDLength = 16 + subscriptionIDPrefix = "su_" syncTopicAccountSyncEvent = "sync" ) @@ -55,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit, AttachmentExpiryDuration: int64(limits.AttachmentExpiryDuration.Seconds()), + AttachmentBandwidth: limits.AttachmentBandwidthLimit, }, Stats: &apiAccountStats{ Messages: stats.Messages, @@ -249,7 +251,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req } } if newSubscription.ID == "" { - newSubscription.ID = util.RandomString(subscriptionIDLength) + newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription) if err := s.userManager.ChangeSettings(v.user); err != nil { return err diff --git a/server/server_account_test.go b/server/server_account_test.go index 7361086..3405cae 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -153,9 +153,9 @@ func TestAccount_ChangeSettings(t *testing.T) { require.Equal(t, 200, rr.Code) account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body)) require.Equal(t, "de", account.Language) - require.Equal(t, 86400, account.Notification.DeleteAfter) - require.Equal(t, "juntos", account.Notification.Sound) - require.Equal(t, 0, account.Notification.MinPriority) // Not set + require.Equal(t, util.Int(86400), account.Notification.DeleteAfter) + require.Equal(t, util.String("juntos"), account.Notification.Sound) + require.Nil(t, account.Notification.MinPriority) // Not set } func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { @@ -176,7 +176,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { require.NotEmpty(t, account.Subscriptions[0].ID) require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) require.Equal(t, "def", account.Subscriptions[0].Topic) - require.Equal(t, "", account.Subscriptions[0].DisplayName) + require.Nil(t, account.Subscriptions[0].DisplayName) subscriptionID := account.Subscriptions[0].ID rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{ @@ -193,7 +193,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { require.Equal(t, subscriptionID, account.Subscriptions[0].ID) require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) require.Equal(t, "def", account.Subscriptions[0].Topic) - require.Equal(t, "ding dong", account.Subscriptions[0].DisplayName) + require.Equal(t, util.String("ding dong"), account.Subscriptions[0].DisplayName) rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -402,6 +402,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) { AttachmentFileSizeLimit: 1231231, AttachmentTotalSizeLimit: 123123, AttachmentExpiryDuration: 10800 * time.Second, + AttachmentBandwidthLimit: 21474836480, })) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) @@ -442,6 +443,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) { require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize) require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize) require.Equal(t, int64(10800), account.Limits.AttachmentExpiryDuration) + require.Equal(t, int64(21474836480), account.Limits.AttachmentBandwidth) require.Equal(t, 2, len(account.Reservations)) require.Equal(t, "another", account.Reservations[0].Topic) require.Equal(t, "write-only", account.Reservations[0].Everyone) diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 82d4ca8..a2a4738 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -265,6 +265,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( AttachmentExpiryDuration: time.Hour, AttachmentFileSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000, + AttachmentBandwidthLimit: 1000000, })) require.Nil(t, s.userManager.CreateTier(&user.Tier{ Code: "pro", @@ -275,6 +276,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( AttachmentExpiryDuration: time.Hour, AttachmentFileSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000, + AttachmentBandwidthLimit: 1000000, })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) diff --git a/server/server_test.go b/server/server_test.go index 303f198..81a243e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1368,6 +1368,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { AttachmentFileSizeLimit: 50_000, AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: sevenDays, // 7 days + AttachmentBandwidthLimit: 100000, })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) @@ -1376,6 +1377,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { response := request(t, s, "PUT", "/mytopic", content, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) + require.Equal(t, 200, response.Code) msg := toMessage(t, response.Body.String()) require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.True(t, msg.Attachment.Expires > time.Now().Add(sevenDays-30*time.Second).Unix()) @@ -1396,6 +1398,46 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { require.Equal(t, 200, response.Code) } +func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) { + content := util.RandomString(5000) // > 4096 + + c := newTestConfigWithAuthFile(t) + s := newTestServer(t, c) + + // Create tier with certain limits + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "test", + MessagesLimit: 10, + MessagesExpiryDuration: time.Hour, + AttachmentFileSizeLimit: 50_000, + AttachmentTotalSizeLimit: 200_000, + AttachmentExpiryDuration: time.Hour, + AttachmentBandwidthLimit: 14000, // < 3x5000 bytes -> enough for one upload, one download + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "test")) + + // Publish and make sure we can retrieve it + rr := request(t, s, "PUT", "/mytopic", content, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + msg := toMessage(t, rr.Body.String()) + + // Retrieve it (first time succeeds) + rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + require.Equal(t, content, rr.Body.String()) + + // Retrieve it AGAIN (fails, due to bandwidth limit) + rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 429, rr.Code) +} + func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { smallFile := util.RandomString(20_000) largeFile := util.RandomString(50_000) @@ -1412,6 +1454,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { AttachmentFileSizeLimit: 50_000, AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: 30 * time.Second, + AttachmentBandwidthLimit: 1000000, })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) diff --git a/server/types.go b/server/types.go index 50e41d6..15c1b84 100644 --- a/server/types.go +++ b/server/types.go @@ -246,7 +246,7 @@ type apiAccountTier struct { } type apiAccountLimits struct { - Basis string `json:"basis,omitempty"` // "ip", "role" or "tier" + Basis string `json:"basis,omitempty"` // "ip" or "tier" Messages int64 `json:"messages"` MessagesExpiryDuration int64 `json:"messages_expiry_duration"` Emails int64 `json:"emails"` @@ -254,6 +254,7 @@ type apiAccountLimits struct { AttachmentTotalSize int64 `json:"attachment_total_size"` AttachmentFileSize int64 `json:"attachment_file_size"` AttachmentExpiryDuration int64 `json:"attachment_expiry_duration"` + AttachmentBandwidth int64 `json:"attachment_bandwidth"` } type apiAccountStats struct { diff --git a/server/visitor.go b/server/visitor.go index 1619240..a7b5aaf 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -31,9 +31,9 @@ var ( type visitor struct { config *Config messageCache *messageCache - userManager *user.Manager // May be nil! - ip netip.Addr - user *user.User + userManager *user.Manager // May be nil + ip netip.Addr // Visitor IP address + user *user.User // Only set if authenticated user, otherwise nil messages int64 // Number of messages sent, reset every day emails int64 // Number of emails sent, reset every day requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) @@ -61,6 +61,7 @@ type visitorLimits struct { AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 AttachmentExpiryDuration time.Duration + AttachmentBandwidthLimit int64 } type visitorStats struct { @@ -84,7 +85,7 @@ const ( ) func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { - var messagesLimiter util.Limiter + var messagesLimiter, attachmentBandwidthLimiter util.Limiter var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter var messages, emails int64 if user != nil { @@ -97,9 +98,11 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst) messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit) emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst) + attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour) } else { requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst) emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst) + attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour) } return &visitor{ config: conf, @@ -113,7 +116,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana messagesLimiter: messagesLimiter, // May be nil emailsLimiter: emailsLimiter, subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), - bandwidthLimiter: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), + bandwidthLimiter: attachmentBandwidthLimiter, accountLimiter: accountLimiter, // May be nil firebase: time.Unix(0, 0), seen: time.Now(), @@ -259,6 +262,7 @@ func (v *visitor) Limits() *visitorLimits { limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration + limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit } return limits } @@ -327,5 +331,6 @@ func defaultVisitorLimits(conf *Config) *visitorLimits { AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, AttachmentExpiryDuration: conf.AttachmentExpiryDuration, + AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit, } } diff --git a/user/manager.go b/user/manager.go index e3c49c8..f1bb73f 100644 --- a/user/manager.go +++ b/user/manager.go @@ -52,6 +52,7 @@ const ( attachment_file_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL, attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, stripe_price_id TEXT ); CREATE UNIQUE INDEX idx_tier_code ON tier (code); @@ -109,26 +110,26 @@ const ( ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -232,20 +233,20 @@ const ( ` insertTierQuery = ` - INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` selectTiersQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id FROM tier ` selectTierByCodeQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id FROM tier WHERE stripe_price_id = ? ` @@ -670,11 +671,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { var id, username, hash, role, prefs, syncTopic string var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var messages, emails int64 - var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 + var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -714,6 +715,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, StripePriceID: stripePriceID.String, // May be empty } } @@ -994,7 +996,7 @@ func (a *Manager) DefaultAccess() Permission { // CreateTier creates a new tier in the database func (a *Manager) CreateTier(tier *Tier) error { tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength) - if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil { + if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { return err } return nil @@ -1051,11 +1053,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { var id, code, name string var stripePriceID sql.NullString - var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 + var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -1072,6 +1074,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, StripePriceID: stripePriceID.String, // May be empty }, nil } diff --git a/user/manager_test.go b/user/manager_test.go index 9fa1819..2f21bce 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -3,6 +3,7 @@ package user import ( "database/sql" "github.com/stretchr/testify/require" + "heckel.io/ntfy/util" "path/filepath" "strings" "testing" @@ -583,21 +584,21 @@ func TestManager_ChangeSettings(t *testing.T) { require.Nil(t, err) require.Nil(t, u.Prefs.Subscriptions) require.Nil(t, u.Prefs.Notification) - require.Equal(t, "", u.Prefs.Language) + require.Nil(t, u.Prefs.Language) // Save with new settings u.Prefs = &Prefs{ - Language: "de", + Language: util.String("de"), Notification: &NotificationPrefs{ - Sound: "ding", - MinPriority: 2, + Sound: util.String("ding"), + MinPriority: util.Int(2), }, Subscriptions: []*Subscription{ { ID: "someID", BaseURL: "https://ntfy.sh", Topic: "mytopic", - DisplayName: "My Topic", + DisplayName: util.String("My Topic"), }, }, } @@ -606,14 +607,14 @@ func TestManager_ChangeSettings(t *testing.T) { // Read again u, err = a.User("ben") require.Nil(t, err) - require.Equal(t, "de", u.Prefs.Language) - require.Equal(t, "ding", u.Prefs.Notification.Sound) - require.Equal(t, 2, u.Prefs.Notification.MinPriority) - require.Equal(t, 0, u.Prefs.Notification.DeleteAfter) + require.Equal(t, util.String("de"), u.Prefs.Language) + require.Equal(t, util.String("ding"), u.Prefs.Notification.Sound) + require.Equal(t, util.Int(2), u.Prefs.Notification.MinPriority) + require.Nil(t, u.Prefs.Notification.DeleteAfter) require.Equal(t, "someID", u.Prefs.Subscriptions[0].ID) require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL) require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic) - require.Equal(t, "My Topic", u.Prefs.Subscriptions[0].DisplayName) + require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName) } func TestSqliteCache_Migration_From1(t *testing.T) { diff --git a/user/types.go b/user/types.go index 1d6f1c3..a4be9b4 100644 --- a/user/types.go +++ b/user/types.go @@ -50,17 +50,18 @@ type Prefs struct { // Tier represents a user's account type, including its account limits type Tier struct { - ID string - Code string - Name string - MessagesLimit int64 - MessagesExpiryDuration time.Duration - EmailsLimit int64 - ReservationsLimit int64 - AttachmentFileSizeLimit int64 - AttachmentTotalSizeLimit int64 - AttachmentExpiryDuration time.Duration - StripePriceID string + ID string // Tier identifier (ti_...) + Code string // Code of the tier + Name string // Name of the tier + MessagesLimit int64 // Daily message limit + MessagesExpiryDuration time.Duration // Cache duration for messages + EmailsLimit int64 // Daily email limit + ReservationsLimit int64 // Number of topic reservations allowed by user + AttachmentFileSizeLimit int64 // Max file size per file (bytes) + AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) + AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted + AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user + StripePriceID string // Price ID for paid tiers (price_...) } // Subscription represents a user's topic subscription diff --git a/util/util.go b/util/util.go index 1650bb5..1bbf629 100644 --- a/util/util.go +++ b/util/util.go @@ -336,3 +336,13 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error } return nil, err } + +// String turns a string into a pointer of a string +func String(v string) *string { + return &v +} + +// Int turns a string into a pointer of an int +func Int(v int) *int { + return &v +}