Payments webhook test, delete attachments/messages when reservations are removed,

This commit is contained in:
binwiederhier 2023-01-20 22:47:37 -05:00
parent 45b97c7054
commit 31a3bb7cd6
16 changed files with 571 additions and 157 deletions

View file

@ -100,22 +100,24 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) { if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) {
return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)") return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)")
} }
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) permission, err := user.ParsePermission(perms)
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) if err != nil {
return err
}
u, err := manager.User(username) u, err := manager.User(username)
if err == user.ErrUserNotFound { if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username) return fmt.Errorf("user %s does not exist", username)
} else if u.Role == user.RoleAdmin { } else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
} }
if err := manager.AllowAccess("", username, topic, read, write); err != nil { if err := manager.AllowAccess(username, topic, permission); err != nil {
return err return err
} }
if read && write { if permission.IsReadWrite() {
fmt.Fprintf(c.App.ErrWriter, "granted read-write access to topic %s\n\n", topic) fmt.Fprintf(c.App.ErrWriter, "granted read-write access to topic %s\n\n", topic)
} else if read { } else if permission.IsRead() {
fmt.Fprintf(c.App.ErrWriter, "granted read-only access to topic %s\n\n", topic) fmt.Fprintf(c.App.ErrWriter, "granted read-only access to topic %s\n\n", topic)
} else if write { } else if permission.IsWrite() {
fmt.Fprintf(c.App.ErrWriter, "granted write-only access to topic %s\n\n", topic) fmt.Fprintf(c.App.ErrWriter, "granted write-only access to topic %s\n\n", topic)
} else { } else {
fmt.Fprintf(c.App.ErrWriter, "revoked all access to topic %s\n\n", topic) fmt.Fprintf(c.App.ErrWriter, "revoked all access to topic %s\n\n", topic)

View file

@ -57,9 +57,10 @@ const (
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published) INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectMessagesSinceTimeQuery = ` selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages FROM messages
WHERE topic = ? AND time >= ? AND published = 1 WHERE topic = ? AND time >= ? AND published = 1
@ -96,7 +97,7 @@ const (
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires <= ? AND attachment_deleted = 0` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
) )
@ -506,6 +507,20 @@ func (c *messageCache) DeleteMessages(ids ...string) error {
return tx.Commit() return tx.Commit()
} }
func (c *messageCache) ExpireMessages(topics ...string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil {
return err
}
}
return tx.Commit()
}
func (c *messageCache) AttachmentsExpired() ([]string, error) { func (c *messageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil { if err != nil {

View file

@ -362,6 +362,61 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(0), size) require.Equal(t, int64(0), size)
} }
func TestSqliteCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t))
}
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func TestSqliteCache_Migration_From0(t *testing.T) { func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t) filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename) db, err := sql.Open("sqlite3", filename)

View file

@ -40,13 +40,15 @@ import (
- v.user --> see publishSyncEventAsync() test - v.user --> see publishSyncEventAsync() test
payments: payments:
- delete messages + reserved topics on ResetTier - delete messages + reserved topics on ResetTier delete attachments in access.go
- reconciliation
Limits & rate limiting: Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful? users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
-> test that the visitor is based on the IP address!
login/account endpoints login/account endpoints
when ResetStats() is run, reset messagesLimiter (and others)? when ResetStats() is run, reset messagesLimiter (and others)?
Delete visitor when tier is changed to refresh rate limiters
Make sure account endpoints make sense for admins Make sure account endpoints make sense for admins
UI: UI:
@ -55,10 +57,9 @@ import (
- JS constants - JS constants
Sync: Sync:
- sync problems with "deleteAfter=0" and "displayName=" - sync problems with "deleteAfter=0" and "displayName="
Delete visitor when tier is changed to refresh rate limiters
Tests: Tests:
- Payment endpoints (make mocks) - Payment endpoints (make mocks)
- Change tier from higher to lower tier (delete reservations)
- Message rate limiting and reset tests - Message rate limiting and reset tests
- test that the visitor is based on the IP address when a user has no tier - test that the visitor is based on the IP address when a user has no tier
*/ */

View file

@ -119,7 +119,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
} }
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID != "" { if v.user.Billing.StripeSubscriptionID != "" {
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID) log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" { if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
@ -332,11 +332,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPTooManyRequestsLimitReservations return errHTTPTooManyRequestsLimitReservations
} }
} }
owner, username := v.user.Name, v.user.Name if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil {
if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil {
return err
}
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())
@ -357,10 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
} else if !authorized { } else if !authorized {
return errHTTPUnauthorized return errHTTPUnauthorized
} }
if err := s.userManager.ResetAccess(v.user.Name, topic); err != nil { if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil {
return err
}
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())

View file

@ -27,6 +27,28 @@ var (
errNoBillingSubscription = errors.New("user does not have an active billing subscription") errNoBillingSubscription = errors.New("user does not have an active billing subscription")
) )
// Payments in ntfy are done via Stripe.
//
// Pretty much all payments related things are in this file. The following processes
// handle payments:
//
// - Checkout:
// Creating a Stripe customer and subscription via the Checkout flow. This flow is only used if the
// ntfy user is not already a Stripe customer. This requires redirecting to the Stripe checkout page.
// It is implemented in handleAccountBillingSubscriptionCreate and the success callback
// handleAccountBillingSubscriptionCreateSuccess.
// - Update subscription:
// Switching between Stripe subscriptions (upgrade/downgrade) is handled via
// handleAccountBillingSubscriptionUpdate. This also handles proration.
// - Cancel subscription (at period end):
// Users can cancel the Stripe subscription via the web app at the end of the billing period. This
// simply updates the subscription and Stripe will cancel it. Users cannot immediately cancel the
// subscription.
// - Webhooks:
// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned.
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!). // in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -37,7 +59,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
freeTier := defaultVisitorLimits(s.config) freeTier := defaultVisitorLimits(s.config)
response := []*apiAccountBillingTier{ response := []*apiAccountBillingTier{
{ {
// Free tier: no code, name or price // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Messages: freeTier.MessagesLimit, Messages: freeTier.MessagesLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()), MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
@ -130,6 +152,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
return s.writeJSON(w, response) return s.writeJSON(w, response)
} }
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager! // We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
@ -139,8 +164,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sessionID := matches[1] sessionID := matches[1]
sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this? sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
if err != nil { if err != nil {
log.Warn("Stripe: %s", err) return err
return errHTTPBadRequestBillingRequestInvalid
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found") return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
} }
@ -158,7 +182,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil { if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err return err
} }
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -216,6 +240,8 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())
} }
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" { if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser return errHTTPBadRequestNotAPaidUser
@ -250,10 +276,11 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
} }
event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil { if err != nil {
return errHTTPBadRequestBillingRequestInvalid return err
} else if event.Data == nil || event.Data.Raw == nil { } else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid return errHTTPBadRequestBillingRequestInvalid
} }
log.Info("Stripe: webhook event %s received", event.Type) log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type { switch event.Type {
case "customer.subscription.updated": case "customer.subscription.updated":
@ -282,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil { if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -301,29 +328,54 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil { if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil return nil
} }
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error { func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
u.Billing.StripeCustomerID = customerID // Remove excess reservations (if too many for tier), and mark associated messages deleted
u.Billing.StripeSubscriptionID = subscriptionID reservations, err := s.userManager.Reservations(u.Name)
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status) if err != nil {
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0) return err
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0) }
if tier == "" { reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if int64(len(reservations)) > reservationsLimit {
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
}
// Change or remove tier
if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil { if err := s.userManager.ResetTier(u.Name); err != nil {
return err return err
} }
} else { } else {
if err := s.userManager.ChangeTier(u.Name, tier); err != nil { if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err return err
} }
} }
if err := s.userManager.ChangeBilling(u); err != nil { // Update billing fields
billing := &user.Billing{
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
return err return err
} }
return nil return nil

View file

@ -1,13 +1,17 @@
package server package server
import ( import (
"encoding/json"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74" "github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"path/filepath"
"strings"
"testing" "testing"
"time"
) )
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
@ -70,8 +74,10 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")
require.Nil(t, err) require.Nil(t, err)
u.Billing.StripeCustomerID = "acct_123" billing := &user.Billing{
require.Nil(t, s.userManager.ChangeBilling(u)) StripeCustomerID: "acct_123",
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Create subscription // Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
@ -109,9 +115,11 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")
require.Nil(t, err) require.Nil(t, err)
u.Billing.StripeCustomerID = "acct_123" billing := &user.Billing{
u.Billing.StripeSubscriptionID = "sub_123" StripeCustomerID: "acct_123",
require.Nil(t, s.userManager.ChangeBilling(u)) StripeSubscriptionID: "sub_123",
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Delete account // Delete account
rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{ rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
@ -125,6 +133,127 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code) require.Equal(t, 401, rr.Code)
} }
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
// This tests incoming webhooks from Stripe to update a subscription:
// - All Stripe columns are updated in the user table
// - When downgrading, excess reservations are deleted, including messages and attachments in
// the corresponding topics
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
s.stripe = stripeMock
// Define how the mock should react
stripeMock.
On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "starter",
StripePriceID: "price_1234", // !
ReservationsLimit: 1, // !
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
StripePriceID: "price_1111", // !
ReservationsLimit: 3, // !
MessagesLimit: 200,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
// Add billing details
u, err := s.userManager.User("phil")
require.Nil(t, err)
billing := &user.Billing{
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0),
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
a2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
z2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
// Call the webhook: This does all the magic
rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
"Stripe-Signature": "stripe signature",
})
require.Equal(t, 200, rr.Code)
// Verify that database columns were updated
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted
r, err := s.userManager.Reservations("phil")
require.Nil(t, err)
require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
require.Equal(t, "atopic", r[0].Topic)
// Verify that messages and attachments were deleted
time.Sleep(time.Second)
s.execManager()
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(ms))
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
type testStripeAPI struct { type testStripeAPI struct {
mock.Mock mock.Mock
} }
@ -175,3 +304,34 @@ func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, sec
} }
var _ stripeAPI = (*testStripeAPI)(nil) var _ stripeAPI = (*testStripeAPI)(nil)
func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
var e stripe.Event
if err := json.Unmarshal([]byte(v), &e); err != nil {
t.Fatal(err)
}
return e
}
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`

View file

@ -640,7 +640,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"), "Authorization": util.BasicAuth("ben", "ben"),
@ -654,8 +654,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("", "ben", "anothertopic", true, true)) require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"), "Authorization": util.BasicAuth("ben", "ben"),
@ -688,7 +688,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "sometopic", true, true)) // Not mytopic! require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"), "Authorization": util.BasicAuth("ben", "ben"),
@ -702,8 +702,8 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "private", false, false)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "announcements", true, false)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
response := request(t, s, "PUT", "/mytopic", "test", nil) response := request(t, s, "PUT", "/mytopic", "test", nil)
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
@ -750,7 +750,7 @@ func TestServer_StatsResetter(t *testing.T) {
go s.runStatsResetter() go s.runStatsResetter()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "phil", "mytopic", true, true)) require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ response := request(t, s, "PUT", "/mytopic", "test", map[string]string{

View file

@ -16,6 +16,10 @@ const (
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since // has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
// they are replenished faster (typically). // they are replenished faster (typically).
visitorExpungeAfter = 24 * time.Hour visitorExpungeAfter = 24 * time.Hour
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
visitorDefaultReservationsLimit = int64(0)
) )
var ( var (
@ -289,7 +293,7 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish), MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
MessagesExpiryDuration: conf.CacheDuration, MessagesExpiryDuration: conf.CacheDuration,
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
ReservationsLimit: 0, // No reservations for anonymous users, or users without a tier ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration, AttachmentExpiryDuration: conf.AttachmentExpiryDuration,

View file

@ -219,8 +219,7 @@ const (
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` selectTiersQuery = `
selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier FROM tier
` `
@ -234,7 +233,7 @@ const (
FROM tier FROM tier
WHERE stripe_price_id = ? WHERE stripe_price_id = ?
` `
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = ` updateBillingQuery = `
@ -772,26 +771,47 @@ func (a *Manager) ChangeRole(username string, role Role) error {
return nil return nil
} }
// ChangeTier changes a user's tier using the tier code // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
// or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
func (a *Manager) ChangeTier(username, tier string) error { func (a *Manager) ChangeTier(username, tier string) error {
if !AllowedUsername(username) { if !AllowedUsername(username) {
return ErrInvalidArgument return ErrInvalidArgument
} }
rows, err := a.db.Query(selectTierIDQuery, tier) t, err := a.Tier(tier)
if err != nil {
return err
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
return err
}
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
return err
}
return nil
}
// ResetTier removes the tier from the given user
func (a *Manager) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument
} else if err := a.checkReservationsLimit(username, 0); err != nil {
return err
}
_, err := a.db.Exec(deleteUserTierQuery, username)
return err
}
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
u, err := a.User(username)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close() if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
if !rows.Next() { reservations, err := a.Reservations(username)
return ErrInvalidArgument if err != nil {
} return err
var tierID int64 } else if int64(len(reservations)) > reservationsLimit {
if err := rows.Scan(&tierID); err != nil { return ErrTooManyReservations
return err }
}
rows.Close()
if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
return err
} }
return nil return nil
} }
@ -823,20 +843,37 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
// AllowAccess adds or updates an entry in th access control list for a specific user. It controls // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty). // owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error { func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
if !AllowedUsername(username) && username != Everyone { if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument return ErrInvalidArgument
} else if owner != "" && !AllowedUsername(owner) {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) { } else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument return ErrInvalidArgument
} }
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil { owner := ""
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
return err return err
} }
return nil return nil
} }
func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*). // empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error { func (a *Manager) ResetAccess(username string, topicPattern string) error {
@ -856,13 +893,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error {
return err return err
} }
// ResetTier removes the tier from the given user func (a *Manager) RemoveReservations(username string, topics ...string) error {
func (a *Manager) ResetTier(username string) error { if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument return ErrInvalidArgument
} }
_, err := a.db.Exec(deleteUserTierQuery, username) for _, topic := range topics {
return err if !AllowedTopic(topic) {
return ErrInvalidArgument
}
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, topic := range topics {
if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
return err
}
if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
return err
}
}
return tx.Commit()
} }
// DefaultAccess returns the default read/write access if no access control entry matches // DefaultAccess returns the default read/write access if no access control entry matches
@ -879,8 +932,8 @@ func (a *Manager) CreateTier(tier *Tier) error {
} }
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
func (a *Manager) ChangeBilling(user *User) error { func (a *Manager) ChangeBilling(username string, billing *Billing) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil { if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
return err return err
} }
return nil return nil

View file

@ -15,13 +15,13 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair! require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false)) require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true)) require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
require.Nil(t, a.AllowAccess("", Everyone, "up*", false, true)) // Everyone can write to /up* require.Nil(t, a.AllowAccess(Everyone, "up*", PermissionWrite)) // Everyone can write to /up*
phil, err := a.Authenticate("phil", "phil") phil, err := a.Authenticate("phil", "phil")
require.Nil(t, err) require.Nil(t, err)
@ -130,12 +130,12 @@ func TestManager_UserManagement(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair! require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false)) require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true)) require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
// Query user details // Query user details
phil, err := a.User("phil") phil, err := a.User("phil")
@ -177,9 +177,9 @@ func TestManager_UserManagement(t *testing.T) {
}, everyoneGrants) }, everyoneGrants)
// Ben: Before revoking // Ben: Before revoking
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) // Overwrite! require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) // Overwrite!
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.Authorize(ben, "mytopic", PermissionRead)) require.Nil(t, a.Authorize(ben, "mytopic", PermissionRead))
require.Nil(t, a.Authorize(ben, "mytopic", PermissionWrite)) require.Nil(t, a.Authorize(ben, "mytopic", PermissionWrite))
require.Nil(t, a.Authorize(ben, "readme", PermissionRead)) require.Nil(t, a.Authorize(ben, "readme", PermissionRead))
@ -234,8 +234,8 @@ func TestManager_ChangePassword(t *testing.T) {
func TestManager_ChangeRole(t *testing.T) { func TestManager_ChangeRole(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
ben, err := a.User("ben") ben, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
@ -256,6 +256,28 @@ func TestManager_ChangeRole(t *testing.T) {
require.Equal(t, 0, len(benGrants)) require.Equal(t, 0, len(benGrants))
} }
func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
reservations, err := a.Reservations("ben")
require.Nil(t, err)
require.Equal(t, 2, len(reservations))
require.Equal(t, Reservation{
Topic: "readme",
Owner: PermissionReadWrite,
Everyone: PermissionRead,
}, reservations[0])
require.Equal(t, Reservation{
Topic: "ztopic",
Owner: PermissionReadWrite,
Everyone: PermissionDenyAll,
}, reservations[1])
}
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.CreateTier(&Tier{ require.Nil(t, a.CreateTier(&Tier{
@ -272,8 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
})) }))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ChangeTier("ben", "pro")) require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.AllowAccess("ben", "ben", "mytopic", true, true)) require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll))
require.Nil(t, a.AllowAccess("ben", Everyone, "mytopic", false, false))
ben, err := a.User("ben") ben, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
@ -298,6 +319,13 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Equal(t, 1, len(everyoneGrants)) require.Equal(t, 1, len(everyoneGrants))
require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow) require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow)
benReservations, err := a.Reservations("ben")
require.Nil(t, err)
require.Equal(t, 1, len(benReservations))
require.Equal(t, "mytopic", benReservations[0].Topic)
require.Equal(t, PermissionReadWrite, benReservations[0].Owner)
require.Equal(t, PermissionDenyAll, benReservations[0].Everyone)
// Switch to admin, this should remove all grants and owned ACL entries // Switch to admin, this should remove all grants and owned ACL entries
require.Nil(t, a.ChangeRole("ben", RoleAdmin)) require.Nil(t, a.ChangeRole("ben", RoleAdmin))

View file

@ -221,19 +221,10 @@ func AllowedTier(tier string) bool {
// Error constants used by the package // Error constants used by the package
var ( var (
ErrUnauthenticated = errors.New("unauthenticated") ErrUnauthenticated = errors.New("unauthenticated")
ErrUnauthorized = errors.New("unauthorized") ErrUnauthorized = errors.New("unauthorized")
ErrInvalidArgument = errors.New("invalid argument") ErrInvalidArgument = errors.New("invalid argument")
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
ErrTierNotFound = errors.New("tier not found") ErrTierNotFound = errors.New("tier not found")
) ErrTooManyReservations = errors.New("new tier has lower reservation limit")
// BillingStatus represents the status of a Stripe subscription
type BillingStatus string
// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview
const (
BillingStatusIncomplete = BillingStatus("incomplete")
BillingStatusActive = BillingStatus("active")
BillingStatusPastDue = BillingStatus("past_due")
) )

View file

@ -201,15 +201,19 @@
"account_delete_dialog_label": "Type '{{username}}' to delete account", "account_delete_dialog_label": "Type '{{username}}' to delete account",
"account_delete_dialog_button_cancel": "Cancel", "account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account", "account_delete_dialog_button_submit": "Permanently delete account",
"account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.",
"account_upgrade_dialog_title": "Change account tier", "account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.", "account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.",
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.", "account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.",
"account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics", "account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics",
"account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages", "account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages",
"account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails", "account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails",
"account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file", "account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file",
"account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage", "account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage",
"account_upgrade_dialog_tier_selected_label": "Selected", "account_upgrade_dialog_tier_selected_label": "Selected",
"account_upgrade_dialog_tier_current_label": "Current",
"account_upgrade_dialog_button_cancel": "Cancel", "account_upgrade_dialog_button_cancel": "Cancel",
"account_upgrade_dialog_button_redirect_signup": "Sign up now", "account_upgrade_dialog_button_redirect_signup": "Sign up now",
"account_upgrade_dialog_button_pay_now": "Pay now and subscribe", "account_upgrade_dialog_button_pay_now": "Pay now and subscribe",

View file

@ -264,7 +264,6 @@ const AccountType = () => {
const Stats = () => { const Stats = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { account } = useContext(AccountContext); const { account } = useContext(AccountContext);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) { if (!account) {
return <></>; return <></>;
@ -435,6 +434,7 @@ const DeleteAccount = () => {
const DeleteAccountDialog = (props) => { const DeleteAccountDialog = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [username, setUsername] = useState(""); const [username, setUsername] = useState("");
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const buttonEnabled = username === session.username(); const buttonEnabled = username === session.username();
@ -456,6 +456,9 @@ const DeleteAccountDialog = (props) => {
fullWidth fullWidth
variant="standard" variant="standard"
/> />
{account?.billing?.subscription &&
<Alert severity="warning" sx={{mt: 1}}>{t("account_delete_dialog_billing_warning")}</Alert>
}
</DialogContent> </DialogContent>
<DialogActions> <DialogActions>
<Button onClick={props.onCancel}>{t("account_delete_dialog_button_cancel")}</Button> <Button onClick={props.onCancel}>{t("account_delete_dialog_button_cancel")}</Button>

View file

@ -3,7 +3,7 @@ import {useContext, useEffect, useState} from 'react';
import { import {
Alert, Alert,
CardActions, CardActions,
CardContent, CardContent, Chip,
FormControl, FormControl,
Select, Select,
Stack, Stack,
@ -20,6 +20,7 @@ import prefs from "../app/Prefs";
import {Paragraph} from "./styles"; import {Paragraph} from "./styles";
import EditIcon from '@mui/icons-material/Edit'; import EditIcon from '@mui/icons-material/Edit';
import CloseIcon from "@mui/icons-material/Close"; import CloseIcon from "@mui/icons-material/Close";
import WarningIcon from '@mui/icons-material/Warning';
import IconButton from "@mui/material/IconButton"; import IconButton from "@mui/material/IconButton";
import PlayArrowIcon from '@mui/icons-material/PlayArrow'; import PlayArrowIcon from '@mui/icons-material/PlayArrow';
import Container from "@mui/material/Container"; import Container from "@mui/material/Container";
@ -41,10 +42,12 @@ import routes from "./routes";
import accountApi, {UnauthorizedError} from "../app/AccountApi"; import accountApi, {UnauthorizedError} from "../app/AccountApi";
import {Pref, PrefGroup} from "./Pref"; import {Pref, PrefGroup} from "./Pref";
import LockIcon from "@mui/icons-material/Lock"; import LockIcon from "@mui/icons-material/Lock";
import {Public, PublicOff} from "@mui/icons-material"; import {Check, Info, Public, PublicOff} from "@mui/icons-material";
import DialogContentText from "@mui/material/DialogContentText"; import DialogContentText from "@mui/material/DialogContentText";
import ReserveTopicSelect from "./ReserveTopicSelect"; import ReserveTopicSelect from "./ReserveTopicSelect";
import {AccountContext} from "./App"; import {AccountContext} from "./App";
import {useOutletContext} from "react-router-dom";
import subscriptionManager from "../app/SubscriptionManager";
const Preferences = () => { const Preferences = () => {
return ( return (
@ -543,6 +546,12 @@ const ReservationsTable = (props) => {
const [dialogKey, setDialogKey] = useState(0); const [dialogKey, setDialogKey] = useState(0);
const [dialogOpen, setDialogOpen] = useState(false); const [dialogOpen, setDialogOpen] = useState(false);
const [dialogReservation, setDialogReservation] = useState(null); const [dialogReservation, setDialogReservation] = useState(null);
const { subscriptions } = useOutletContext();
const localSubscriptions = Object.assign(
...subscriptions
.filter(s => s.baseUrl === config.base_url)
.map(s => ({[s.topic]: s}))
);
const handleEditClick = (reservation) => { const handleEditClick = (reservation) => {
setDialogKey(prev => prev+1); setDialogKey(prev => prev+1);
@ -592,7 +601,9 @@ const ReservationsTable = (props) => {
key={reservation.topic} key={reservation.topic}
sx={{'&:last-child td, &:last-child th': {border: 0}}} sx={{'&:last-child td, &:last-child th': {border: 0}}}
> >
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>{reservation.topic}</TableCell> <TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>
{reservation.topic}
</TableCell>
<TableCell aria-label={t("prefs_reservations_table_access_header")}> <TableCell aria-label={t("prefs_reservations_table_access_header")}>
{reservation.everyone === "read-write" && {reservation.everyone === "read-write" &&
<> <>
@ -620,6 +631,9 @@ const ReservationsTable = (props) => {
} }
</TableCell> </TableCell>
<TableCell align="right"> <TableCell align="right">
{!localSubscriptions[reservation.topic] &&
<Chip icon={<Info/>} label="Not subscribed" color="primary" variant="outlined"/>
}
<IconButton onClick={() => handleEditClick(reservation)} aria-label={t("prefs_reservations_edit_button")}> <IconButton onClick={() => handleEditClick(reservation)} aria-label={t("prefs_reservations_edit_button")}>
<EditIcon/> <EditIcon/>
</IconButton> </IconButton>

View file

@ -21,13 +21,14 @@ import {Check} from "@mui/icons-material";
import ListItemIcon from "@mui/material/ListItemIcon"; import ListItemIcon from "@mui/material/ListItemIcon";
import ListItemText from "@mui/material/ListItemText"; import ListItemText from "@mui/material/ListItemText";
import Box from "@mui/material/Box"; import Box from "@mui/material/Box";
import {NavLink} from "react-router-dom";
const UpgradeDialog = (props) => { const UpgradeDialog = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const { account } = useContext(AccountContext); // May be undefined! const { account } = useContext(AccountContext); // May be undefined!
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [tiers, setTiers] = useState(null); const [tiers, setTiers] = useState(null);
const [newTier, setNewTier] = useState(account?.tier?.code); // May be undefined const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [errorText, setErrorText] = useState(""); const [errorText, setErrorText] = useState("");
@ -41,47 +42,56 @@ const UpgradeDialog = (props) => {
return <></>; return <></>;
} }
const currentTier = account?.tier?.code; // May be undefined const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier})));
let action, submitButtonLabel, submitButtonEnabled; const newTier = tiersMap[newTierCode]; // May be undefined
const currentTier = account?.tier; // May be undefined
const currentTierCode = currentTier?.code; // May be undefined
// Figure out buttons, labels and the submit action
let submitAction, submitButtonLabel, banner;
if (!account) { if (!account) {
submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup"); submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup");
submitButtonEnabled = true; submitAction = Action.REDIRECT_SIGNUP;
action = Action.REDIRECT_SIGNUP; banner = null;
} else if (currentTier === newTier) { } else if (currentTierCode === newTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitButtonEnabled = false; submitAction = null;
action = null; banner = (currentTierCode) ? Banner.PRORATION_INFO : null;
} else if (!currentTier) { } else if (!currentTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_pay_now"); submitButtonLabel = t("account_upgrade_dialog_button_pay_now");
submitButtonEnabled = true; submitAction = Action.CREATE_SUBSCRIPTION;
action = Action.CREATE_SUBSCRIPTION; banner = null;
} else if (!newTier) { } else if (!newTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_cancel_subscription"); submitButtonLabel = t("account_upgrade_dialog_button_cancel_subscription");
submitButtonEnabled = true; submitAction = Action.CANCEL_SUBSCRIPTION;
action = Action.CANCEL_SUBSCRIPTION; banner = Banner.CANCEL_WARNING;
} else { } else {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitButtonEnabled = true; submitAction = Action.UPDATE_SUBSCRIPTION;
action = Action.UPDATE_SUBSCRIPTION; banner = Banner.PRORATION_INFO;
} }
// Exceptional conditions
if (loading) { if (loading) {
submitButtonEnabled = false; submitAction = null;
} else if (newTier?.code && account?.reservations.length > newTier?.limits.reservations) {
submitAction = null;
banner = Banner.RESERVATIONS_WARNING;
} }
const handleSubmit = async () => { const handleSubmit = async () => {
if (action === Action.REDIRECT_SIGNUP) { if (submitAction === Action.REDIRECT_SIGNUP) {
window.location.href = routes.signup; window.location.href = routes.signup;
return; return;
} }
try { try {
setLoading(true); setLoading(true);
if (action === Action.CREATE_SUBSCRIPTION) { if (submitAction === Action.CREATE_SUBSCRIPTION) {
const response = await accountApi.createBillingSubscription(newTier); const response = await accountApi.createBillingSubscription(newTierCode);
window.location.href = response.redirect_url; window.location.href = response.redirect_url;
} else if (action === Action.UPDATE_SUBSCRIPTION) { } else if (submitAction === Action.UPDATE_SUBSCRIPTION) {
await accountApi.updateBillingSubscription(newTier); await accountApi.updateBillingSubscription(newTierCode);
} else if (action === Action.CANCEL_SUBSCRIPTION) { } else if (submitAction === Action.CANCEL_SUBSCRIPTION) {
await accountApi.deleteBillingSubscription(); await accountApi.deleteBillingSubscription();
} }
props.onCancel(); props.onCancel();
@ -116,27 +126,39 @@ const UpgradeDialog = (props) => {
<TierCard <TierCard
key={`tierCard${tier.code || '_free'}`} key={`tierCard${tier.code || '_free'}`}
tier={tier} tier={tier}
selected={newTier === tier.code} // tier.code may be undefined! current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined!
onClick={() => setNewTier(tier.code)} // tier.code may be undefined! selected={newTierCode === tier.code} // tier.code may be undefined!
onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined!
/> />
)} )}
</div> </div>
{action === Action.CANCEL_SUBSCRIPTION && {banner === Banner.CANCEL_WARNING &&
<Alert severity="warning"> <Alert severity="warning">
<Trans <Trans
i18nKey="account_upgrade_dialog_cancel_warning" i18nKey="account_upgrade_dialog_cancel_warning"
values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} /> values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} />
</Alert> </Alert>
} }
{currentTier && (!action || action === Action.UPDATE_SUBSCRIPTION) && {banner === Banner.PRORATION_INFO &&
<Alert severity="info"> <Alert severity="info">
<Trans i18nKey="account_upgrade_dialog_proration_info" /> <Trans i18nKey="account_upgrade_dialog_proration_info" />
</Alert> </Alert>
} }
{banner === Banner.RESERVATIONS_WARNING &&
<Alert severity="warning">
<Trans
i18nKey="account_upgrade_dialog_reservations_warning"
count={account?.reservations.length - newTier?.limits.reservations}
components={{
Link: <NavLink to={routes.settings}/>,
}}
/>
</Alert>
}
</DialogContent> </DialogContent>
<DialogFooter status={errorText}> <DialogFooter status={errorText}>
<Button onClick={props.onCancel}>{t("account_upgrade_dialog_button_cancel")}</Button> <Button onClick={props.onCancel}>{t("account_upgrade_dialog_button_cancel")}</Button>
<Button onClick={handleSubmit} disabled={!submitButtonEnabled}>{submitButtonLabel}</Button> <Button onClick={handleSubmit} disabled={!submitAction}>{submitButtonLabel}</Button>
</DialogFooter> </DialogFooter>
</Dialog> </Dialog>
); );
@ -144,8 +166,19 @@ const UpgradeDialog = (props) => {
const TierCard = (props) => { const TierCard = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : { border: "2px solid transparent" };
const tier = props.tier; const tier = props.tier;
let cardStyle, labelStyle, labelText;
if (props.selected) {
cardStyle = { background: "#eee", border: "2px solid #338574" };
labelStyle = { background: "#338574", color: "white" };
labelText = t("account_upgrade_dialog_tier_selected_label");
} else if (props.current) {
cardStyle = { border: "2px solid #eee" };
labelStyle = { background: "#eee", color: "black" };
labelText = t("account_upgrade_dialog_tier_current_label");
} else {
cardStyle = { border: "2px solid transparent" };
}
return ( return (
<Box sx={{ <Box sx={{
@ -163,16 +196,15 @@ const TierCard = (props) => {
<Card sx={{ height: "100%" }}> <Card sx={{ height: "100%" }}>
<CardActionArea sx={{ height: "100%" }}> <CardActionArea sx={{ height: "100%" }}>
<CardContent onClick={props.onClick} sx={{ height: "100%" }}> <CardContent onClick={props.onClick} sx={{ height: "100%" }}>
{props.selected && {labelStyle &&
<div style={{ <div style={{
position: "absolute", position: "absolute",
top: "0", top: "0",
right: "15px", right: "15px",
padding: "2px 10px", padding: "2px 10px",
background: "#338574",
color: "white",
borderRadius: "3px", borderRadius: "3px",
}}>{t("account_upgrade_dialog_tier_selected_label")}</div> ...labelStyle
}}>{labelText}</div>
} }
<Typography variant="h5" component="div"> <Typography variant="h5" component="div">
{tier.name || t("account_usage_tier_free")} {tier.name || t("account_usage_tier_free")}
@ -217,10 +249,17 @@ const FeatureItem = (props) => {
}; };
const Action = { const Action = {
REDIRECT_SIGNUP: 0, REDIRECT_SIGNUP: 1,
CREATE_SUBSCRIPTION: 1, CREATE_SUBSCRIPTION: 2,
UPDATE_SUBSCRIPTION: 2, UPDATE_SUBSCRIPTION: 3,
CANCEL_SUBSCRIPTION: 3 CANCEL_SUBSCRIPTION: 4
}; };
const Banner = {
CANCEL_WARNING: 1,
PRORATION_INFO: 2,
RESERVATIONS_WARNING: 3
};
export default UpgradeDialog; export default UpgradeDialog;