diff --git a/server/config.go b/server/config.go index 4fa48dc..f67fe8f 100644 --- a/server/config.go +++ b/server/config.go @@ -46,8 +46,8 @@ const ( DefaultVisitorRequestLimitReplenish = 5 * time.Second DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitReplenish = time.Hour - DefaultVisitorAccountCreateLimitBurst = 3 - DefaultVisitorAccountCreateLimitReplenish = 24 * time.Hour + DefaultVisitorAccountCreationLimitBurst = 3 + DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB ) @@ -107,8 +107,8 @@ type Config struct { VisitorRequestExemptIPAddrs []netip.Prefix VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration - VisitorAccountCreateLimitBurst int - VisitorAccountCreateLimitReplenish time.Duration + VisitorAccountCreationLimitBurst int + VisitorAccountCreationLimitReplenish time.Duration VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats BehindProxy bool StripeSecretKey string @@ -173,8 +173,8 @@ func NewConfig() *Config { VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, - VisitorAccountCreateLimitBurst: DefaultVisitorAccountCreateLimitBurst, - VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish, + VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst, + VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish, VisitorStatsResetTime: DefaultVisitorStatsResetTime, BehindProxy: false, StripeSecretKey: "", diff --git a/server/server.go b/server/server.go index 8ab9b2f..d79c967 100644 --- a/server/server.go +++ b/server/server.go @@ -40,6 +40,8 @@ TODO - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) +- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters +- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)? - MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) @@ -50,8 +52,6 @@ TODO Limits & rate limiting: users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! - when ResetStats() is run, reset messagesLimiter (and others)? - Delete visitor when tier is changed to refresh rate limiters Make sure account endpoints make sense for admins @@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) { } else { v = s.visitorFromIP(ip) } - v.mu.Lock() - v.user = u - v.mu.Unlock() + v.SetUser(u) // Update visitor user with latest from database! return v, err // Always return visitor, even when error occurs! } diff --git a/server/server_account.go b/server/server_account.go index 57547a1..da02cab 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil { return errHTTPConflictUserExists } - if v.accountLimiter != nil && !v.accountLimiter.Allow() { + if err := v.AccountCreationAllowed(); err != nil { return errHTTPTooManyRequestsLimitAccountCreation } if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User diff --git a/server/server_account_test.go b/server/server_account_test.go index 3405cae..79504b6 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -6,6 +6,7 @@ import ( "heckel.io/ntfy/user" "heckel.io/ntfy/util" "io" + "strings" "testing" "time" ) @@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) { require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code) } +func TestAccount_Signup_Rate_Limit(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + conf.EnableSignup = true + s := newTestServer(t, conf) + + for i := 0; i < 3; i++ { + rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil) + require.Equal(t, 200, rr.Code, "failed on iteration %d", i) + } + rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil) + require.Equal(t, 429, rr.Code) + require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code) +} + func TestAccount_Get_Anonymous(t *testing.T) { conf := newTestConfigWithAuthFile(t) conf.VisitorRequestLimitReplenish = 86 * time.Second @@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { s.topics["mytopic"].CancelSubscribers("") <-userCh } + +func TestAccount_Tier_Create(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + s := newTestServer(t, conf) + + // Create tier and user + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + Name: "Pro", + MessagesLimit: 123, + MessagesExpiryDuration: 86400 * time.Second, + EmailsLimit: 32, + ReservationsLimit: 2, + AttachmentFileSizeLimit: 1231231, + AttachmentTotalSizeLimit: 123123, + AttachmentExpiryDuration: 10800 * time.Second, + AttachmentBandwidthLimit: 21474836480, + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "pro")) + + ti, err := s.userManager.Tier("pro") + require.Nil(t, err) + + u, err := s.userManager.User("phil") + require.Nil(t, err) + + // These are populated by different SQL queries + require.Equal(t, ti, u.Tier) + + // Fields + require.True(t, strings.HasPrefix(ti.ID, "ti_")) + require.Equal(t, "pro", ti.Code) + require.Equal(t, "Pro", ti.Name) + require.Equal(t, int64(123), ti.MessagesLimit) + require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration) + require.Equal(t, int64(32), ti.EmailsLimit) + require.Equal(t, int64(2), ti.ReservationsLimit) + require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) + require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) + require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) + require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) +} + +func TestAccount_Tier_Create_With_ID(t *testing.T) { + conf := newTestConfigWithAuthFile(t) + s := newTestServer(t, conf) + + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", + Code: "pro", + })) + + ti, err := s.userManager.Tier("pro") + require.Nil(t, err) + require.Equal(t, "ti_123", ti.ID) +} diff --git a/server/server_payments_test.go b/server/server_payments_test.go index a2a4738..d7bc660 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", Code: "pro", StripePriceID: "price_123", })) @@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", Code: "pro", StripePriceID: "price_123", })) @@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", Code: "pro", StripePriceID: "price_123", })) @@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { require.Equal(t, 401, rr.Code) } +func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) { + // This tests a successful checkout flow (not a paying customer -> paying customer), + // and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user. + + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + c.VisitorRequestLimitBurst = 10 + c.VisitorRequestLimitReplenish = time.Hour + s := newTestServer(t, c) + s.stripe = stripeMock + + // Create a user with a Stripe subscription and 3 reservations + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_123", + Code: "starter", + StripePriceID: "price_1234", + ReservationsLimit: 1, + MessagesLimit: 100, + MessagesExpiryDuration: time.Hour, + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier + u, err := s.userManager.User("phil") + require.Nil(t, err) + + // Define how the mock should react + stripeMock. + On("GetSession", "SOMETOKEN"). + Return(&stripe.CheckoutSession{ + ClientReferenceID: u.ID, // ntfy user ID + Customer: &stripe.Customer{ + ID: "acct_5555", + }, + Subscription: &stripe.Subscription{ + ID: "sub_1234", + }, + }, nil) + stripeMock. + On("GetSubscription", "sub_1234"). + Return(&stripe.Subscription{ + ID: "sub_1234", + Status: stripe.SubscriptionStatusActive, + CurrentPeriodEnd: 123456789, + CancelAt: 0, + Items: &stripe.SubscriptionItemList{ + Data: []*stripe.SubscriptionItem{ + { + Price: &stripe.Price{ID: "price_1234"}, + }, + }, + }, + }, nil) + stripeMock. + On("UpdateCustomer", mock.Anything). + Return(&stripe.Customer{}, nil) + + // Send messages until rate limit of free tier is hit + for i := 0; i < 10; i++ { + rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + } + rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 429, rr.Code) + + // Simulate Stripe success return URL call (no user context) + rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil) + require.Equal(t, 303, rr.Code) + + // Verify that database columns were updated + u, err = s.userManager.User("phil") + require.Nil(t, err) + require.Equal(t, "starter", u.Tier.Code) // Not "pro" + require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) + require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) + require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) + require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) + require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) + + // FIXME FIXME This test is broken, because the rate limit logic is unclear! + + // Now for the fun part: Verify that new rate limits are immediately applied + for i := 0; i < 100; i++ { + rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code, "failed on iteration %d", i) + } + rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 429, rr.Code) +} + func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) { // This tests incoming webhooks from Stripe to update a subscription: // - All Stripe columns are updated in the user table @@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( // Create a user with a Stripe subscription and 3 reservations require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_1", Code: "starter", StripePriceID: "price_1234", // ! ReservationsLimit: 1, // ! @@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( AttachmentBandwidthLimit: 1000000, })) require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_2", Code: "pro", StripePriceID: "price_1111", // ! ReservationsLimit: 3, // ! diff --git a/server/visitor.go b/server/visitor.go index a7b5aaf..6dcf403 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -3,6 +3,7 @@ package server import ( "errors" "fmt" + "heckel.io/ntfy/log" "heckel.io/ntfy/user" "net/netip" "sync" @@ -41,7 +42,7 @@ type visitor struct { emailsLimiter *rate.Limiter // Rate limiter for emails subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections) bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads - accountLimiter *rate.Limiter // Rate limiter for account creation + accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil firebase time.Time // Next allowed Firebase message seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) mu sync.Mutex @@ -85,26 +86,12 @@ const ( ) func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { - var messagesLimiter, attachmentBandwidthLimiter util.Limiter - var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter var messages, emails int64 if user != nil { messages = user.Stats.Messages emails = user.Stats.Emails - } else { - accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst) } - if user != nil && user.Tier != nil { - requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst) - messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit) - emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst) - attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour) - } else { - requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst) - emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst) - attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour) - } - return &visitor{ + v := &visitor{ config: conf, messageCache: messageCache, userManager: userManager, // May be nil @@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana user: user, messages: messages, emails: emails, - requestLimiter: requestLimiter, - messagesLimiter: messagesLimiter, // May be nil - emailsLimiter: emailsLimiter, - subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), - bandwidthLimiter: attachmentBandwidthLimiter, - accountLimiter: accountLimiter, // May be nil firebase: time.Unix(0, 0), seen: time.Now(), + subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), + requestLimiter: nil, // Set in resetLimiters + messagesLimiter: nil, // Set in resetLimiters, may be nil + emailsLimiter: nil, // Set in resetLimiters + bandwidthLimiter: nil, // Set in resetLimiters + accountLimiter: nil, // Set in resetLimiters, may be nil } + v.resetLimiters() + return v } func (v *visitor) String() string { v.mu.Lock() defer v.mu.Unlock() + return v.stringNoLock() +} + +func (v *visitor) stringNoLock() string { if v.user != nil && v.user.Billing.StripeCustomerID != "" { return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID) } else if v.user != nil { @@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error { return nil } +func (v *visitor) AccountCreationAllowed() error { + if v.accountLimiter != nil && !v.accountLimiter.Allow() { + return errVisitorLimitReached + } + return nil +} + func (v *visitor) RemoveSubscription() { v.mu.Lock() defer v.mu.Unlock() @@ -235,7 +235,35 @@ func (v *visitor) ResetStats() { func (v *visitor) SetUser(u *user.User) { v.mu.Lock() defer v.mu.Unlock() + shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver v.user = u + if shouldResetLimiters { + v.resetLimiters() + } +} + +func (v *visitor) resetLimiters() { + log.Info("%s Resetting limiters for visitor", v.stringNoLock()) + var messagesLimiter, bandwidthLimiter util.Limiter + var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter + if v.user != nil && v.user.Tier != nil { + requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst) + messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit) + emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst) + bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour) + accountLimiter = nil // A logged-in user cannot create an account + } else { + requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst) + messagesLimiter = nil // Message limit is governed by the requestLimiter + emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst) + bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour) + accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) + } + v.requestLimiter = requestLimiter + v.messagesLimiter = messagesLimiter + v.emailsLimiter = emailsLimiter + v.bandwidthLimiter = bandwidthLimiter + v.accountLimiter = accountLimiter } // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, diff --git a/user/manager.go b/user/manager.go index f1bb73f..c252e1f 100644 --- a/user/manager.go +++ b/user/manager.go @@ -110,26 +110,26 @@ const ( ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string - var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { if tierCode.Valid { // See readTier() when this is changed! user.Tier = &Tier{ + ID: tierID.String, Code: tierCode.String, Name: tierName.String, MessagesLimit: messagesLimit.Int64, @@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission { // CreateTier creates a new tier in the database func (a *Manager) CreateTier(tier *Tier) error { - tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength) - if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { + if tier.ID == "" { + tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) + } + if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { return err } return nil diff --git a/user/types.go b/user/types.go index a4be9b4..d5b1a25 100644 --- a/user/types.go +++ b/user/types.go @@ -23,6 +23,15 @@ type User struct { Deleted bool } +// TierID returns the ID of the User.Tier, or an empty string if the user has no tier, +// or if the user itself is nil. +func (u *User) TierID() string { + if u == nil || u.Tier == nil { + return "" + } + return u.Tier.ID +} + // Auther is an interface for authentication and authorization type Auther interface { // Authenticate checks username and password and returns a user if correct. The method