WIP tier CLI

This commit is contained in:
binwiederhier 2023-02-06 22:38:22 -05:00
parent 9b54f63eb1
commit e3b39f670f
10 changed files with 367 additions and 35 deletions

View file

@ -189,7 +189,11 @@ func showUsers(c *cli.Context, manager *user.Manager, users []*user.User) error
if err != nil { if err != nil {
return err return err
} }
fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", u.Name, u.Role) tier := "none"
if u.Tier != nil {
tier = u.Tier.Name
}
fmt.Fprintf(c.App.ErrWriter, "user %s (role: %s, tier: %s)\n", u.Name, u.Role, tier)
if u.Role == user.RoleAdmin { if u.Role == user.RoleAdmin {
fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n")
} else if len(grants) > 0 { } else if len(grants) > 0 {

View file

@ -15,7 +15,7 @@ func TestCLI_Access_Show(t *testing.T) {
app, _, _, stderr := newTestApp() app, _, _, stderr := newTestApp()
require.Nil(t, runAccessCommand(app, conf)) require.Nil(t, runAccessCommand(app, conf))
require.Contains(t, stderr.String(), "user * (anonymous)\n- no topic-specific permissions\n- no access to any (other) topics (server config)") require.Contains(t, stderr.String(), "user * (role: anonymous, tier: none)\n- no topic-specific permissions\n- no access to any (other) topics (server config)")
} }
func TestCLI_Access_Grant_And_Publish(t *testing.T) { func TestCLI_Access_Grant_And_Publish(t *testing.T) {
@ -32,12 +32,12 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) {
app, _, _, stderr := newTestApp() app, _, _, stderr := newTestApp()
require.Nil(t, runAccessCommand(app, conf)) require.Nil(t, runAccessCommand(app, conf))
expected := `user phil (admin) expected := `user phil (role: admin, tier: none)
- read-write access to all topics (admin role) - read-write access to all topics (admin role)
user ben (user) user ben (role: user, tier: none)
- read-write access to topic announcements - read-write access to topic announcements
- read-only access to topic sometopic - read-only access to topic sometopic
user * (anonymous) user * (role: anonymous, tier: none)
- read-only access to topic announcements - read-only access to topic announcements
- no access to any (other) topics (server config) - no access to any (other) topics (server config)
` `

289
cmd/tier.go Normal file
View file

@ -0,0 +1,289 @@
//go:build !noserver
package cmd
import (
"errors"
"fmt"
"github.com/urfave/cli/v2"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"time"
)
func init() {
commands = append(commands, cmdTier)
}
const (
defaultMessageLimit = 5000
defaultMessageExpiryDuration = 12 * time.Hour
defaultEmailLimit = 20
defaultReservationLimit = 3
defaultAttachmentFileSizeLimit = "15M"
defaultAttachmentTotalSizeLimit = "100M"
defaultAttachmentExpiryDuration = 6 * time.Hour
defaultAttachmentBandwidthLimit = "1G"
)
var (
flagsTier = append([]cli.Flag{}, flagsUser...)
)
var cmdTier = &cli.Command{
Name: "tier",
Usage: "Manage/show tiers",
UsageText: "ntfy tier [list|add|remove] ...",
Flags: flagsTier,
Before: initConfigFileInputSourceFunc("config", flagsUser, initLogFunc),
Category: categoryServer,
Subcommands: []*cli.Command{
{
Name: "add",
Aliases: []string{"a"},
Usage: "Adds a new tier",
UsageText: "ntfy tier add [OPTIONS] CODE",
Action: execTierAdd,
Flags: []cli.Flag{
&cli.StringFlag{Name: "name", Usage: "tier name"},
&cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"},
&cli.DurationFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"},
&cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"},
},
Description: `
FIXME
`,
},
{
Name: "change",
Aliases: []string{"ch"},
Usage: "Change a tier",
UsageText: "ntfy tier change [OPTIONS] CODE",
Action: execTierChange,
Flags: []cli.Flag{
&cli.StringFlag{Name: "name", Usage: "tier name"},
&cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"},
&cli.DurationFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"},
&cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"},
},
Description: `
FIXME
`,
},
{
Name: "remove",
Aliases: []string{"del", "rm"},
Usage: "Removes a tier",
UsageText: "ntfy tier remove CODE",
Action: execTierDel,
Description: `
FIXME
`,
},
{
Name: "list",
Aliases: []string{"l"},
Usage: "Shows a list of tiers",
Action: execTierList,
Description: `
FIXME
`,
},
},
Description: `Manage tier of the ntfy server.
The command allows you to add/remove/change tier in the ntfy user database. Tiers are used
to grant users higher limits based on their tier.
This is a server-only command. It directly manages the user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined. Please also refer
to the related command 'ntfy access'.
FIXME
`,
}
func execTierAdd(c *cli.Context) error {
code := c.Args().Get(0)
if code == "" {
return errors.New("tier code expected, type 'ntfy tier add --help' for help")
} else if !user.AllowedTier(code) {
return errors.New("tier code must consist only of numbers and letters")
}
manager, err := createUserManager(c)
if err != nil {
return err
}
if tier, _ := manager.Tier(code); tier != nil {
return fmt.Errorf("tier %s already exists", code)
}
name := c.String("name")
if name == "" {
name = code
}
attachmentFileSizeLimit, err := util.ParseSize(c.String("attachment-file-size-limit"))
if err != nil {
return err
}
attachmentTotalSizeLimit, err := util.ParseSize(c.String("attachment-total-size-limit"))
if err != nil {
return err
}
attachmentBandwidthLimit, err := util.ParseSize(c.String("attachment-bandwidth-limit"))
if err != nil {
return err
}
tier := &user.Tier{
ID: "", // Generated
Code: code,
Name: name,
MessageLimit: c.Int64("message-limit"),
MessageExpiryDuration: c.Duration("message-expiry-duration"),
EmailLimit: c.Int64("email-limit"),
ReservationLimit: c.Int64("reservation-limit"),
AttachmentFileSizeLimit: attachmentFileSizeLimit,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit,
AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"),
AttachmentBandwidthLimit: attachmentBandwidthLimit,
StripePriceID: c.String("stripe-price-id"),
}
if err := manager.AddTier(tier); err != nil {
return err
}
tier, err = manager.Tier(code)
if err != nil {
return err
}
fmt.Fprintf(c.App.ErrWriter, "tier added\n\n")
printTier(c, tier)
return nil
}
func execTierChange(c *cli.Context) error {
code := c.Args().Get(0)
if code == "" {
return errors.New("tier code expected, type 'ntfy tier change --help' for help")
} else if !user.AllowedTier(code) {
return errors.New("tier code must consist only of numbers and letters")
}
manager, err := createUserManager(c)
if err != nil {
return err
}
tier, err := manager.Tier(code)
if err == user.ErrTierNotFound {
return fmt.Errorf("tier %s does not exist", code)
} else if err != nil {
return err
}
if c.IsSet("name") {
tier.Name = c.String("name")
}
if c.IsSet("message-limit") {
tier.MessageLimit = c.Int64("message-limit")
}
if c.IsSet("message-expiry-duration") {
tier.MessageExpiryDuration = c.Duration("message-expiry-duration")
}
if c.IsSet("email-limit") {
tier.EmailLimit = c.Int64("email-limit")
}
if c.IsSet("reservation-limit") {
tier.ReservationLimit = c.Int64("reservation-limit")
}
if c.IsSet("attachment-file-size-limit") {
tier.AttachmentFileSizeLimit, err = util.ParseSize(c.String("attachment-file-size-limit"))
if err != nil {
return err
}
}
if c.IsSet("attachment-total-size-limit") {
tier.AttachmentTotalSizeLimit, err = util.ParseSize(c.String("attachment-total-size-limit"))
if err != nil {
return err
}
}
if c.IsSet("attachment-expiry-duration") {
tier.AttachmentExpiryDuration = c.Duration("attachment-expiry-duration")
}
if c.IsSet("attachment-bandwidth-limit") {
tier.AttachmentBandwidthLimit, err = util.ParseSize(c.String("attachment-bandwidth-limit"))
if err != nil {
return err
}
}
if c.IsSet("stripe-price-id") {
tier.StripePriceID = c.String("stripe-price-id")
}
if err := manager.UpdateTier(tier); err != nil {
return err
}
fmt.Fprintf(c.App.ErrWriter, "tier updated\n\n")
printTier(c, tier)
return nil
}
func execTierDel(c *cli.Context) error {
code := c.Args().Get(0)
if code == "" {
return errors.New("tier code expected, type 'ntfy tier del --help' for help")
}
manager, err := createUserManager(c)
if err != nil {
return err
}
if _, err := manager.Tier(code); err == user.ErrTierNotFound {
return fmt.Errorf("tier %s does not exist", code)
}
if err := manager.RemoveTier(code); err != nil {
return err
}
fmt.Fprintf(c.App.ErrWriter, "tier %s removed\n", code)
return nil
}
func execTierList(c *cli.Context) error {
manager, err := createUserManager(c)
if err != nil {
return err
}
tiers, err := manager.Tiers()
if err != nil {
return err
}
for _, tier := range tiers {
printTier(c, tier)
}
return nil
}
func printTier(c *cli.Context, tier *user.Tier) {
stripePriceID := tier.StripePriceID
if stripePriceID == "" {
stripePriceID = "(none)"
}
fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID)
fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name)
fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit)
fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit)
fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit)
fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit))
fmt.Fprintf(c.App.ErrWriter, "- Stripe price: %s\n", stripePriceID)
}

View file

@ -38,7 +38,6 @@ import (
- HIGH Account limit creation triggers when account is taken! - HIGH Account limit creation triggers when account is taken!
- HIGH Docs - HIGH Docs
- HIGH CLI "ntfy tier [add|list|delete]" - HIGH CLI "ntfy tier [add|list|delete]"
- HIGH CLI "ntfy user" should show tier
- HIGH Self-review - HIGH Self-review
- MEDIUM: Test for expiring messages after reservation removal - MEDIUM: Test for expiring messages after reservation removal
- MEDIUM: Test new token endpoints & never-expiring token - MEDIUM: Test new token endpoints & never-expiring token

View file

@ -437,7 +437,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
s := newTestServer(t, conf) s := newTestServer(t, conf)
// A user, an admin, and a reservation walk into a bar // A user, an admin, and a reservation walk into a bar
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro", Code: "pro",
ReservationLimit: 2, ReservationLimit: 2,
})) }))
@ -493,7 +493,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
// Create a tier // Create a tier
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro", Code: "pro",
MessageLimit: 123, MessageLimit: 123,
MessageExpiryDuration: 86400 * time.Second, MessageExpiryDuration: 86400 * time.Second,
@ -575,7 +575,7 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro", Code: "pro",
MessageLimit: 20, MessageLimit: 20,
ReservationLimit: 2, ReservationLimit: 2,
@ -610,7 +610,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro", Code: "pro",
MessageLimit: 20, MessageLimit: 20,
ReservationLimit: 2, ReservationLimit: 2,
@ -689,11 +689,11 @@ func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
// Create user with tier // Create user with tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "starter", Code: "starter",
MessageLimit: 10, MessageLimit: 10,
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro", Code: "pro",
MessageLimit: 20, MessageLimit: 20,
})) }))

View file

@ -42,12 +42,12 @@ func TestPayments_Tiers(t *testing.T) {
}, nil) }, nil)
// Create tiers // Create tiers
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1", ID: "ti_1",
Code: "admin", Code: "admin",
Name: "Admin", Name: "Admin",
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
Name: "Pro", Name: "Pro",
@ -60,7 +60,7 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentExpiryDuration: time.Minute, AttachmentExpiryDuration: time.Minute,
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_444", ID: "ti_444",
Code: "business", Code: "business",
Name: "Business", Name: "Business",
@ -135,7 +135,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
@ -171,7 +171,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil) Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
@ -213,7 +213,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
Return(&stripe.Subscription{}, nil) Return(&stripe.Subscription{}, nil)
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
@ -264,7 +264,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
s.stripe = stripeMock s.stripe = stripeMock
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", StripePriceID: "price_1234",
@ -420,7 +420,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil) Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1", ID: "ti_1",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", // ! StripePriceID: "price_1234", // !
@ -432,7 +432,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000, AttachmentBandwidthLimit: 1000000,
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_2", ID: "ti_2",
Code: "pro", Code: "pro",
StripePriceID: "price_1111", // ! StripePriceID: "price_1111", // !
@ -545,7 +545,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil) Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1", ID: "ti_1",
Code: "pro", Code: "pro",
StripePriceID: "price_1234", StripePriceID: "price_1234",
@ -626,12 +626,12 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
Return(&stripe.Subscription{}, nil) Return(&stripe.Subscription{}, nil)
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_456", ID: "ti_456",
Code: "business", Code: "business",
StripePriceID: "price_456", StripePriceID: "price_456",

View file

@ -761,7 +761,7 @@ func TestServer_StatsResetter(t *testing.T) {
go s.runStatsResetter() go s.runStatsResetter()
// Create user with tier (tieruser) and user without tier (phil) // Create user with tier (tieruser) and user without tier (phil)
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
MessageLimit: 5, MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack! MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
@ -898,7 +898,7 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
// Create user, and update it with some message and email stats // Create user, and update it with some message and email stats
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -1275,7 +1275,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
MessageLimit: 5, MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack! MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
@ -1504,7 +1504,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
// Create tier with certain limits // Create tier with certain limits
sevenDays := time.Duration(604800) * time.Second sevenDays := time.Duration(604800) * time.Second
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
MessageLimit: 10, MessageLimit: 10,
MessageExpiryDuration: sevenDays, MessageExpiryDuration: sevenDays,
@ -1549,7 +1549,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
MessageLimit: 10, MessageLimit: 10,
MessageExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
@ -1588,7 +1588,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
// Create tier with certain limits // Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test", Code: "test",
MessageLimit: 100, MessageLimit: 100,
AttachmentFileSizeLimit: 50_000, AttachmentFileSizeLimit: 50_000,

View file

@ -248,6 +248,11 @@ const (
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id) INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
updateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
WHERE code = ?
`
selectTiersQuery = ` selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
FROM tier FROM tier
@ -264,6 +269,7 @@ const (
` `
updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
deleteTierQuery = `DELETE FROM tier WHERE code = ?`
updateBillingQuery = ` updateBillingQuery = `
UPDATE user UPDATE user
@ -1116,8 +1122,8 @@ func (a *Manager) DefaultAccess() Permission {
return a.defaultAccess return a.defaultAccess
} }
// CreateTier creates a new tier in the database // AddTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error { func (a *Manager) AddTier(tier *Tier) error {
if tier.ID == "" { if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
} }
@ -1127,6 +1133,26 @@ func (a *Manager) CreateTier(tier *Tier) error {
return nil return nil
} }
// UpdateTier updates a tier's properties in the database
func (a *Manager) UpdateTier(tier *Tier) error {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil {
return err
}
return nil
}
// RemoveTier deletes the tier with the given code
func (a *Manager) RemoveTier(code string) error {
if !AllowedTier(code) {
return ErrInvalidArgument
}
// This fails if any user has this tier
if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
return err
}
return nil
}
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
func (a *Manager) ChangeBilling(username string, billing *Billing) error { func (a *Manager) ChangeBilling(username string, billing *Billing) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {

View file

@ -333,7 +333,7 @@ func TestManager_Reservations(t *testing.T) {
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.CreateTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
Code: "pro", Code: "pro",
Name: "ntfy Pro", Name: "ntfy Pro",
StripePriceID: "price123", StripePriceID: "price123",
@ -629,7 +629,7 @@ func TestManager_Tier_Create(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
// Create tier and user // Create tier and user
require.Nil(t, a.CreateTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
Code: "pro", Code: "pro",
Name: "Pro", Name: "Pro",
MessageLimit: 123, MessageLimit: 123,
@ -670,7 +670,7 @@ func TestManager_Tier_Create(t *testing.T) {
func TestAccount_Tier_Create_With_ID(t *testing.T) { func TestAccount_Tier_Create_With_ID(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.CreateTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
})) }))

View file

@ -222,6 +222,20 @@ func ParseSize(s string) (int64, error) {
} }
} }
// FormatSize formats bytes into a human-readable notation, e.g. 2.1 MB
func FormatSize(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d bytes", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}
// ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the // ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the
// input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing). // input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing).
func ReadPassword(in io.Reader) ([]byte, error) { func ReadPassword(in io.Reader) ([]byte, error) {