Payment checkout test, rate limit resetting on tier change; failing

This commit is contained in:
binwiederhier 2023-01-25 22:26:04 -05:00
parent 236254d907
commit 593e0748a8
8 changed files with 257 additions and 42 deletions

View file

@ -46,8 +46,8 @@ const (
DefaultVisitorRequestLimitReplenish = 5 * time.Second
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreateLimitBurst = 3
DefaultVisitorAccountCreateLimitReplenish = 24 * time.Hour
DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
)
@ -107,8 +107,8 @@ type Config struct {
VisitorRequestExemptIPAddrs []netip.Prefix
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
VisitorAccountCreateLimitBurst int
VisitorAccountCreateLimitReplenish time.Duration
VisitorAccountCreationLimitBurst int
VisitorAccountCreationLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool
StripeSecretKey string
@ -173,8 +173,8 @@ func NewConfig() *Config {
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreateLimitBurst: DefaultVisitorAccountCreateLimitBurst,
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
BehindProxy: false,
StripeSecretKey: "",

View file

@ -40,6 +40,8 @@ TODO
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -50,8 +52,6 @@ TODO
Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
when ResetStats() is run, reset messagesLimiter (and others)?
Delete visitor when tier is changed to refresh rate limiters
Make sure account endpoints make sense for admins
@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
} else {
v = s.visitorFromIP(ip)
}
v.mu.Lock()
v.user = u
v.mu.Unlock()
v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs!
}

View file

@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
return errHTTPConflictUserExists
}
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation
}
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User

View file

@ -6,6 +6,7 @@ import (
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"strings"
"testing"
"time"
)
@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
}
func TestAccount_Signup_Rate_Limit(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true
s := newTestServer(t, conf)
for i := 0; i < 3; i++ {
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
}
rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
require.Equal(t, 429, rr.Code)
require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
}
func TestAccount_Get_Anonymous(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.VisitorRequestLimitReplenish = 86 * time.Second
@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
s.topics["mytopic"].CancelSubscribers("<invalid>")
<-userCh
}
func TestAccount_Tier_Create(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
s := newTestServer(t, conf)
// Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
Name: "Pro",
MessagesLimit: 123,
MessagesExpiryDuration: 86400 * time.Second,
EmailsLimit: 32,
ReservationsLimit: 2,
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
ti, err := s.userManager.Tier("pro")
require.Nil(t, err)
u, err := s.userManager.User("phil")
require.Nil(t, err)
// These are populated by different SQL queries
require.Equal(t, ti, u.Tier)
// Fields
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
require.Equal(t, "pro", ti.Code)
require.Equal(t, "Pro", ti.Name)
require.Equal(t, int64(123), ti.MessagesLimit)
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
require.Equal(t, int64(32), ti.EmailsLimit)
require.Equal(t, int64(2), ti.ReservationsLimit)
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
}
func TestAccount_Tier_Create_With_ID(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
s := newTestServer(t, conf)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
}))
ti, err := s.userManager.Tier("pro")
require.Nil(t, err)
require.Equal(t, "ti_123", ti.ID)
}

View file

@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
}))
@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
}))
@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
}))
@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code)
}
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
// This tests a successful checkout flow (not a paying customer -> paying customer),
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitBurst = 10
c.VisitorRequestLimitReplenish = time.Hour
s := newTestServer(t, c)
s.stripe = stripeMock
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
ReservationsLimit: 1,
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
u, err := s.userManager.User("phil")
require.Nil(t, err)
// Define how the mock should react
stripeMock.
On("GetSession", "SOMETOKEN").
Return(&stripe.CheckoutSession{
ClientReferenceID: u.ID, // ntfy user ID
Customer: &stripe.Customer{
ID: "acct_5555",
},
Subscription: &stripe.Subscription{
ID: "sub_1234",
},
}, nil)
stripeMock.
On("GetSubscription", "sub_1234").
Return(&stripe.Subscription{
ID: "sub_1234",
Status: stripe.SubscriptionStatusActive,
CurrentPeriodEnd: 123456789,
CancelAt: 0,
Items: &stripe.SubscriptionItemList{
Data: []*stripe.SubscriptionItem{
{
Price: &stripe.Price{ID: "price_1234"},
},
},
},
}, nil)
stripeMock.
On("UpdateCustomer", mock.Anything).
Return(&stripe.Customer{}, nil)
// Send messages until rate limit of free tier is hit
for i := 0; i < 10; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
// Simulate Stripe success return URL call (no user context)
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
require.Equal(t, 303, rr.Code)
// Verify that database columns were updated
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
// Now for the fun part: Verify that new rate limits are immediately applied
for i := 0; i < 100; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
}
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
// This tests incoming webhooks from Stripe to update a subscription:
// - All Stripe columns are updated in the user table
@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_1",
Code: "starter",
StripePriceID: "price_1234", // !
ReservationsLimit: 1, // !
@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_2",
Code: "pro",
StripePriceID: "price_1111", // !
ReservationsLimit: 3, // !

View file

@ -3,6 +3,7 @@ package server
import (
"errors"
"fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"net/netip"
"sync"
@ -41,7 +42,7 @@ type visitor struct {
emailsLimiter *rate.Limiter // Rate limiter for emails
subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
firebase time.Time // Next allowed Firebase message
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
mu sync.Mutex
@ -85,26 +86,12 @@ const (
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
var messages, emails int64
if user != nil {
messages = user.Stats.Messages
emails = user.Stats.Emails
} else {
accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
}
if user != nil && user.Tier != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else {
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
}
return &visitor{
v := &visitor{
config: conf,
messageCache: messageCache,
userManager: userManager, // May be nil
@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
user: user,
messages: messages,
emails: emails,
requestLimiter: requestLimiter,
messagesLimiter: messagesLimiter, // May be nil
emailsLimiter: emailsLimiter,
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidthLimiter: attachmentBandwidthLimiter,
accountLimiter: accountLimiter, // May be nil
firebase: time.Unix(0, 0),
seen: time.Now(),
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
requestLimiter: nil, // Set in resetLimiters
messagesLimiter: nil, // Set in resetLimiters, may be nil
emailsLimiter: nil, // Set in resetLimiters
bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil
}
v.resetLimiters()
return v
}
func (v *visitor) String() string {
v.mu.Lock()
defer v.mu.Unlock()
return v.stringNoLock()
}
func (v *visitor) stringNoLock() string {
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
} else if v.user != nil {
@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
return nil
}
func (v *visitor) AccountCreationAllowed() error {
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) RemoveSubscription() {
v.mu.Lock()
defer v.mu.Unlock()
@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u
if shouldResetLimiters {
v.resetLimiters()
}
}
func (v *visitor) resetLimiters() {
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
var messagesLimiter, bandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
if v.user != nil && v.user.Tier != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
accountLimiter = nil // A logged-in user cannot create an account
} else {
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
messagesLimiter = nil // Message limit is governed by the requestLimiter
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
}
v.requestLimiter = requestLimiter
v.messagesLimiter = messagesLimiter
v.emailsLimiter = emailsLimiter
v.bandwidthLimiter = bandwidthLimiter
v.accountLimiter = accountLimiter
}
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,

View file

@ -110,26 +110,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{
ID: tierID.String,
Code: tierCode.String,
Name: tierName.String,
MessagesLimit: messagesLimit.Int64,
@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error {
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err
}
return nil

View file

@ -23,6 +23,15 @@ type User struct {
Deleted bool
}
// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
// or if the user itself is nil.
func (u *User) TierID() string {
if u == nil || u.Tier == nil {
return ""
}
return u.Tier.ID
}
// Auther is an interface for authentication and authorization
type Auther interface {
// Authenticate checks username and password and returns a user if correct. The method