Payments webhook test, delete attachments/messages when reservations are removed,

This commit is contained in:
binwiederhier 2023-01-20 22:47:37 -05:00
parent 45b97c7054
commit 31a3bb7cd6
16 changed files with 571 additions and 157 deletions

View file

@ -100,22 +100,24 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) {
return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)")
}
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
permission, err := user.ParsePermission(perms)
if err != nil {
return err
}
u, err := manager.User(username)
if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
} else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
}
if err := manager.AllowAccess("", username, topic, read, write); err != nil {
if err := manager.AllowAccess(username, topic, permission); err != nil {
return err
}
if read && write {
if permission.IsReadWrite() {
fmt.Fprintf(c.App.ErrWriter, "granted read-write access to topic %s\n\n", topic)
} else if read {
} else if permission.IsRead() {
fmt.Fprintf(c.App.ErrWriter, "granted read-only access to topic %s\n\n", topic)
} else if write {
} else if permission.IsWrite() {
fmt.Fprintf(c.App.ErrWriter, "granted write-only access to topic %s\n\n", topic)
} else {
fmt.Fprintf(c.App.ErrWriter, "revoked all access to topic %s\n\n", topic)

View file

@ -58,6 +58,7 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
@ -96,7 +97,7 @@ const (
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
)
@ -506,6 +507,20 @@ func (c *messageCache) DeleteMessages(ids ...string) error {
return tx.Commit()
}
func (c *messageCache) ExpireMessages(topics ...string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil {
return err
}
}
return tx.Commit()
}
func (c *messageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {

View file

@ -362,6 +362,61 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, int64(0), size)
}
func TestSqliteCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t))
}
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)

View file

@ -40,13 +40,15 @@ import (
- v.user --> see publishSyncEventAsync() test
payments:
- delete messages + reserved topics on ResetTier
- delete messages + reserved topics on ResetTier delete attachments in access.go
- reconciliation
Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful?
-> test that the visitor is based on the IP address!
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
login/account endpoints
when ResetStats() is run, reset messagesLimiter (and others)?
Delete visitor when tier is changed to refresh rate limiters
Make sure account endpoints make sense for admins
UI:
@ -55,10 +57,9 @@ import (
- JS constants
Sync:
- sync problems with "deleteAfter=0" and "displayName="
Delete visitor when tier is changed to refresh rate limiters
Tests:
- Payment endpoints (make mocks)
- Change tier from higher to lower tier (delete reservations)
- Message rate limiting and reset tests
- test that the visitor is based on the IP address when a user has no tier
*/

View file

@ -119,7 +119,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
}
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID != "" {
if v.user.Billing.StripeSubscriptionID != "" {
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
@ -332,11 +332,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPTooManyRequestsLimitReservations
}
}
owner, username := v.user.Name, v.user.Name
if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil {
return err
}
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
@ -357,10 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
} else if !authorized {
return errHTTPUnauthorized
}
if err := s.userManager.ResetAccess(v.user.Name, topic); err != nil {
return err
}
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())

View file

@ -27,6 +27,28 @@ var (
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
// Payments in ntfy are done via Stripe.
//
// Pretty much all payments related things are in this file. The following processes
// handle payments:
//
// - Checkout:
// Creating a Stripe customer and subscription via the Checkout flow. This flow is only used if the
// ntfy user is not already a Stripe customer. This requires redirecting to the Stripe checkout page.
// It is implemented in handleAccountBillingSubscriptionCreate and the success callback
// handleAccountBillingSubscriptionCreateSuccess.
// - Update subscription:
// Switching between Stripe subscriptions (upgrade/downgrade) is handled via
// handleAccountBillingSubscriptionUpdate. This also handles proration.
// - Cancel subscription (at period end):
// Users can cancel the Stripe subscription via the web app at the end of the billing period. This
// simply updates the subscription and Stripe will cancel it. Users cannot immediately cancel the
// subscription.
// - Webhooks:
// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned.
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -37,7 +59,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
freeTier := defaultVisitorLimits(s.config)
response := []*apiAccountBillingTier{
{
// Free tier: no code, name or price
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{
Messages: freeTier.MessagesLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
@ -130,6 +152,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
return s.writeJSON(w, response)
}
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
@ -139,8 +164,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sessionID := matches[1]
sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestBillingRequestInvalid
return err
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
}
@ -158,7 +182,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -216,6 +240,8 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
return s.writeJSON(w, newSuccessResponse())
}
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
@ -250,10 +276,11 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
}
event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
return errHTTPBadRequestBillingRequestInvalid
return err
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type {
case "customer.subscription.updated":
@ -282,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -301,29 +328,54 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
u.Billing.StripeCustomerID = customerID
u.Billing.StripeSubscriptionID = subscriptionID
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
if tier == "" {
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
// Remove excess reservations (if too many for tier), and mark associated messages deleted
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
}
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if int64(len(reservations)) > reservationsLimit {
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
}
// Change or remove tier
if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
} else {
if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
}
if err := s.userManager.ChangeBilling(u); err != nil {
// Update billing fields
billing := &user.Billing{
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
return err
}
return nil

View file

@ -1,13 +1,17 @@
package server
import (
"encoding/json"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"path/filepath"
"strings"
"testing"
"time"
)
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
@ -70,8 +74,10 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
u, err := s.userManager.User("phil")
require.Nil(t, err)
u.Billing.StripeCustomerID = "acct_123"
require.Nil(t, s.userManager.ChangeBilling(u))
billing := &user.Billing{
StripeCustomerID: "acct_123",
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
@ -109,9 +115,11 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
u, err := s.userManager.User("phil")
require.Nil(t, err)
u.Billing.StripeCustomerID = "acct_123"
u.Billing.StripeSubscriptionID = "sub_123"
require.Nil(t, s.userManager.ChangeBilling(u))
billing := &user.Billing{
StripeCustomerID: "acct_123",
StripeSubscriptionID: "sub_123",
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Delete account
rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
@ -125,6 +133,127 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code)
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
// This tests incoming webhooks from Stripe to update a subscription:
// - All Stripe columns are updated in the user table
// - When downgrading, excess reservations are deleted, including messages and attachments in
// the corresponding topics
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
s.stripe = stripeMock
// Define how the mock should react
stripeMock.
On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "starter",
StripePriceID: "price_1234", // !
ReservationsLimit: 1, // !
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
StripePriceID: "price_1111", // !
ReservationsLimit: 3, // !
MessagesLimit: 200,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
// Add billing details
u, err := s.userManager.User("phil")
require.Nil(t, err)
billing := &user.Billing{
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0),
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
a2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
z2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
// Call the webhook: This does all the magic
rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
"Stripe-Signature": "stripe signature",
})
require.Equal(t, 200, rr.Code)
// Verify that database columns were updated
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted
r, err := s.userManager.Reservations("phil")
require.Nil(t, err)
require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
require.Equal(t, "atopic", r[0].Topic)
// Verify that messages and attachments were deleted
time.Sleep(time.Second)
s.execManager()
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(ms))
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
type testStripeAPI struct {
mock.Mock
}
@ -175,3 +304,34 @@ func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, sec
}
var _ stripeAPI = (*testStripeAPI)(nil)
func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
var e stripe.Event
if err := json.Unmarshal([]byte(v), &e); err != nil {
t.Fatal(err)
}
return e
}
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`

View file

@ -640,7 +640,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
@ -654,8 +654,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true))
require.Nil(t, s.userManager.AllowAccess("", "ben", "anothertopic", true, true))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
@ -688,7 +688,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "ben", "sometopic", true, true)) // Not mytopic!
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
@ -702,8 +702,8 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "private", false, false))
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "announcements", true, false))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
response := request(t, s, "PUT", "/mytopic", "test", nil)
require.Equal(t, 200, response.Code)
@ -750,7 +750,7 @@ func TestServer_StatsResetter(t *testing.T) {
go s.runStatsResetter()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AllowAccess("", "phil", "mytopic", true, true))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
for i := 0; i < 5; i++ {
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{

View file

@ -16,6 +16,10 @@ const (
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
// they are replenished faster (typically).
visitorExpungeAfter = 24 * time.Hour
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
visitorDefaultReservationsLimit = int64(0)
)
var (
@ -289,7 +293,7 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
MessagesExpiryDuration: conf.CacheDuration,
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
ReservationsLimit: 0, // No reservations for anonymous users, or users without a tier
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,

View file

@ -219,7 +219,6 @@ const (
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
@ -234,7 +233,7 @@ const (
FROM tier
WHERE stripe_price_id = ?
`
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `
@ -772,26 +771,47 @@ func (a *Manager) ChangeRole(username string, role Role) error {
return nil
}
// ChangeTier changes a user's tier using the tier code
// ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
// or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
func (a *Manager) ChangeTier(username, tier string) error {
if !AllowedUsername(username) {
return ErrInvalidArgument
}
rows, err := a.db.Query(selectTierIDQuery, tier)
t, err := a.Tier(tier)
if err != nil {
return err
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
return err
}
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
return err
}
return nil
}
// ResetTier removes the tier from the given user
func (a *Manager) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument
} else if err := a.checkReservationsLimit(username, 0); err != nil {
return err
}
_, err := a.db.Exec(deleteUserTierQuery, username)
return err
}
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
u, err := a.User(username)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return ErrInvalidArgument
}
var tierID int64
if err := rows.Scan(&tierID); err != nil {
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
reservations, err := a.Reservations(username)
if err != nil {
return err
} else if int64(len(reservations)) > reservationsLimit {
return ErrTooManyReservations
}
rows.Close()
if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
return err
}
return nil
}
@ -823,20 +843,37 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
// AllowAccess adds or updates an entry in th access control list for a specific user. It controls
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error {
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument
} else if owner != "" && !AllowedUsername(owner) {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument
}
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil {
owner := ""
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
return err
}
return nil
}
func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error {
@ -856,14 +893,30 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error {
return err
}
// ResetTier removes the tier from the given user
func (a *Manager) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
func (a *Manager) RemoveReservations(username string, topics ...string) error {
if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
return ErrInvalidArgument
}
_, err := a.db.Exec(deleteUserTierQuery, username)
for _, topic := range topics {
if !AllowedTopic(topic) {
return ErrInvalidArgument
}
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, topic := range topics {
if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
return err
}
if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
return err
}
}
return tx.Commit()
}
// DefaultAccess returns the default read/write access if no access control entry matches
func (a *Manager) DefaultAccess() Permission {
@ -879,8 +932,8 @@ func (a *Manager) CreateTier(tier *Tier) error {
}
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
func (a *Manager) ChangeBilling(username string, billing *Billing) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
return err
}
return nil

View file

@ -15,13 +15,13 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair!
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false))
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true))
require.Nil(t, a.AllowAccess("", Everyone, "up*", false, true)) // Everyone can write to /up*
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
require.Nil(t, a.AllowAccess(Everyone, "up*", PermissionWrite)) // Everyone can write to /up*
phil, err := a.Authenticate("phil", "phil")
require.Nil(t, err)
@ -130,12 +130,12 @@ func TestManager_UserManagement(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair!
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false))
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
// Query user details
phil, err := a.User("phil")
@ -177,9 +177,9 @@ func TestManager_UserManagement(t *testing.T) {
}, everyoneGrants)
// Ben: Before revoking
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) // Overwrite!
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) // Overwrite!
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
require.Nil(t, a.Authorize(ben, "mytopic", PermissionRead))
require.Nil(t, a.Authorize(ben, "mytopic", PermissionWrite))
require.Nil(t, a.Authorize(ben, "readme", PermissionRead))
@ -234,8 +234,8 @@ func TestManager_ChangePassword(t *testing.T) {
func TestManager_ChangeRole(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
ben, err := a.User("ben")
require.Nil(t, err)
@ -256,6 +256,28 @@ func TestManager_ChangeRole(t *testing.T) {
require.Equal(t, 0, len(benGrants))
}
func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
reservations, err := a.Reservations("ben")
require.Nil(t, err)
require.Equal(t, 2, len(reservations))
require.Equal(t, Reservation{
Topic: "readme",
Owner: PermissionReadWrite,
Everyone: PermissionRead,
}, reservations[0])
require.Equal(t, Reservation{
Topic: "ztopic",
Owner: PermissionReadWrite,
Everyone: PermissionDenyAll,
}, reservations[1])
}
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.CreateTier(&Tier{
@ -272,8 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
}))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.AllowAccess("ben", "ben", "mytopic", true, true))
require.Nil(t, a.AllowAccess("ben", Everyone, "mytopic", false, false))
require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll))
ben, err := a.User("ben")
require.Nil(t, err)
@ -298,6 +319,13 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Equal(t, 1, len(everyoneGrants))
require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow)
benReservations, err := a.Reservations("ben")
require.Nil(t, err)
require.Equal(t, 1, len(benReservations))
require.Equal(t, "mytopic", benReservations[0].Topic)
require.Equal(t, PermissionReadWrite, benReservations[0].Owner)
require.Equal(t, PermissionDenyAll, benReservations[0].Everyone)
// Switch to admin, this should remove all grants and owned ACL entries
require.Nil(t, a.ChangeRole("ben", RoleAdmin))

View file

@ -226,14 +226,5 @@ var (
ErrInvalidArgument = errors.New("invalid argument")
ErrUserNotFound = errors.New("user not found")
ErrTierNotFound = errors.New("tier not found")
)
// BillingStatus represents the status of a Stripe subscription
type BillingStatus string
// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview
const (
BillingStatusIncomplete = BillingStatus("incomplete")
BillingStatusActive = BillingStatus("active")
BillingStatusPastDue = BillingStatus("past_due")
ErrTooManyReservations = errors.New("new tier has lower reservation limit")
)

View file

@ -201,15 +201,19 @@
"account_delete_dialog_label": "Type '{{username}}' to delete account",
"account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account",
"account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.",
"account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.",
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.",
"account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics",
"account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages",
"account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails",
"account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file",
"account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage",
"account_upgrade_dialog_tier_selected_label": "Selected",
"account_upgrade_dialog_tier_current_label": "Current",
"account_upgrade_dialog_button_cancel": "Cancel",
"account_upgrade_dialog_button_redirect_signup": "Sign up now",
"account_upgrade_dialog_button_pay_now": "Pay now and subscribe",

View file

@ -264,7 +264,6 @@ const AccountType = () => {
const Stats = () => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) {
return <></>;
@ -435,6 +434,7 @@ const DeleteAccount = () => {
const DeleteAccountDialog = (props) => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [username, setUsername] = useState("");
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const buttonEnabled = username === session.username();
@ -456,6 +456,9 @@ const DeleteAccountDialog = (props) => {
fullWidth
variant="standard"
/>
{account?.billing?.subscription &&
<Alert severity="warning" sx={{mt: 1}}>{t("account_delete_dialog_billing_warning")}</Alert>
}
</DialogContent>
<DialogActions>
<Button onClick={props.onCancel}>{t("account_delete_dialog_button_cancel")}</Button>

View file

@ -3,7 +3,7 @@ import {useContext, useEffect, useState} from 'react';
import {
Alert,
CardActions,
CardContent,
CardContent, Chip,
FormControl,
Select,
Stack,
@ -20,6 +20,7 @@ import prefs from "../app/Prefs";
import {Paragraph} from "./styles";
import EditIcon from '@mui/icons-material/Edit';
import CloseIcon from "@mui/icons-material/Close";
import WarningIcon from '@mui/icons-material/Warning';
import IconButton from "@mui/material/IconButton";
import PlayArrowIcon from '@mui/icons-material/PlayArrow';
import Container from "@mui/material/Container";
@ -41,10 +42,12 @@ import routes from "./routes";
import accountApi, {UnauthorizedError} from "../app/AccountApi";
import {Pref, PrefGroup} from "./Pref";
import LockIcon from "@mui/icons-material/Lock";
import {Public, PublicOff} from "@mui/icons-material";
import {Check, Info, Public, PublicOff} from "@mui/icons-material";
import DialogContentText from "@mui/material/DialogContentText";
import ReserveTopicSelect from "./ReserveTopicSelect";
import {AccountContext} from "./App";
import {useOutletContext} from "react-router-dom";
import subscriptionManager from "../app/SubscriptionManager";
const Preferences = () => {
return (
@ -543,6 +546,12 @@ const ReservationsTable = (props) => {
const [dialogKey, setDialogKey] = useState(0);
const [dialogOpen, setDialogOpen] = useState(false);
const [dialogReservation, setDialogReservation] = useState(null);
const { subscriptions } = useOutletContext();
const localSubscriptions = Object.assign(
...subscriptions
.filter(s => s.baseUrl === config.base_url)
.map(s => ({[s.topic]: s}))
);
const handleEditClick = (reservation) => {
setDialogKey(prev => prev+1);
@ -592,7 +601,9 @@ const ReservationsTable = (props) => {
key={reservation.topic}
sx={{'&:last-child td, &:last-child th': {border: 0}}}
>
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>{reservation.topic}</TableCell>
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>
{reservation.topic}
</TableCell>
<TableCell aria-label={t("prefs_reservations_table_access_header")}>
{reservation.everyone === "read-write" &&
<>
@ -620,6 +631,9 @@ const ReservationsTable = (props) => {
}
</TableCell>
<TableCell align="right">
{!localSubscriptions[reservation.topic] &&
<Chip icon={<Info/>} label="Not subscribed" color="primary" variant="outlined"/>
}
<IconButton onClick={() => handleEditClick(reservation)} aria-label={t("prefs_reservations_edit_button")}>
<EditIcon/>
</IconButton>

View file

@ -21,13 +21,14 @@ import {Check} from "@mui/icons-material";
import ListItemIcon from "@mui/material/ListItemIcon";
import ListItemText from "@mui/material/ListItemText";
import Box from "@mui/material/Box";
import {NavLink} from "react-router-dom";
const UpgradeDialog = (props) => {
const { t } = useTranslation();
const { account } = useContext(AccountContext); // May be undefined!
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [tiers, setTiers] = useState(null);
const [newTier, setNewTier] = useState(account?.tier?.code); // May be undefined
const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined
const [loading, setLoading] = useState(false);
const [errorText, setErrorText] = useState("");
@ -41,47 +42,56 @@ const UpgradeDialog = (props) => {
return <></>;
}
const currentTier = account?.tier?.code; // May be undefined
let action, submitButtonLabel, submitButtonEnabled;
const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier})));
const newTier = tiersMap[newTierCode]; // May be undefined
const currentTier = account?.tier; // May be undefined
const currentTierCode = currentTier?.code; // May be undefined
// Figure out buttons, labels and the submit action
let submitAction, submitButtonLabel, banner;
if (!account) {
submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup");
submitButtonEnabled = true;
action = Action.REDIRECT_SIGNUP;
} else if (currentTier === newTier) {
submitAction = Action.REDIRECT_SIGNUP;
banner = null;
} else if (currentTierCode === newTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitButtonEnabled = false;
action = null;
} else if (!currentTier) {
submitAction = null;
banner = (currentTierCode) ? Banner.PRORATION_INFO : null;
} else if (!currentTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_pay_now");
submitButtonEnabled = true;
action = Action.CREATE_SUBSCRIPTION;
} else if (!newTier) {
submitAction = Action.CREATE_SUBSCRIPTION;
banner = null;
} else if (!newTierCode) {
submitButtonLabel = t("account_upgrade_dialog_button_cancel_subscription");
submitButtonEnabled = true;
action = Action.CANCEL_SUBSCRIPTION;
submitAction = Action.CANCEL_SUBSCRIPTION;
banner = Banner.CANCEL_WARNING;
} else {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitButtonEnabled = true;
action = Action.UPDATE_SUBSCRIPTION;
submitAction = Action.UPDATE_SUBSCRIPTION;
banner = Banner.PRORATION_INFO;
}
// Exceptional conditions
if (loading) {
submitButtonEnabled = false;
submitAction = null;
} else if (newTier?.code && account?.reservations.length > newTier?.limits.reservations) {
submitAction = null;
banner = Banner.RESERVATIONS_WARNING;
}
const handleSubmit = async () => {
if (action === Action.REDIRECT_SIGNUP) {
if (submitAction === Action.REDIRECT_SIGNUP) {
window.location.href = routes.signup;
return;
}
try {
setLoading(true);
if (action === Action.CREATE_SUBSCRIPTION) {
const response = await accountApi.createBillingSubscription(newTier);
if (submitAction === Action.CREATE_SUBSCRIPTION) {
const response = await accountApi.createBillingSubscription(newTierCode);
window.location.href = response.redirect_url;
} else if (action === Action.UPDATE_SUBSCRIPTION) {
await accountApi.updateBillingSubscription(newTier);
} else if (action === Action.CANCEL_SUBSCRIPTION) {
} else if (submitAction === Action.UPDATE_SUBSCRIPTION) {
await accountApi.updateBillingSubscription(newTierCode);
} else if (submitAction === Action.CANCEL_SUBSCRIPTION) {
await accountApi.deleteBillingSubscription();
}
props.onCancel();
@ -116,27 +126,39 @@ const UpgradeDialog = (props) => {
<TierCard
key={`tierCard${tier.code || '_free'}`}
tier={tier}
selected={newTier === tier.code} // tier.code may be undefined!
onClick={() => setNewTier(tier.code)} // tier.code may be undefined!
current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined!
selected={newTierCode === tier.code} // tier.code may be undefined!
onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined!
/>
)}
</div>
{action === Action.CANCEL_SUBSCRIPTION &&
{banner === Banner.CANCEL_WARNING &&
<Alert severity="warning">
<Trans
i18nKey="account_upgrade_dialog_cancel_warning"
values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} />
</Alert>
}
{currentTier && (!action || action === Action.UPDATE_SUBSCRIPTION) &&
{banner === Banner.PRORATION_INFO &&
<Alert severity="info">
<Trans i18nKey="account_upgrade_dialog_proration_info" />
</Alert>
}
{banner === Banner.RESERVATIONS_WARNING &&
<Alert severity="warning">
<Trans
i18nKey="account_upgrade_dialog_reservations_warning"
count={account?.reservations.length - newTier?.limits.reservations}
components={{
Link: <NavLink to={routes.settings}/>,
}}
/>
</Alert>
}
</DialogContent>
<DialogFooter status={errorText}>
<Button onClick={props.onCancel}>{t("account_upgrade_dialog_button_cancel")}</Button>
<Button onClick={handleSubmit} disabled={!submitButtonEnabled}>{submitButtonLabel}</Button>
<Button onClick={handleSubmit} disabled={!submitAction}>{submitButtonLabel}</Button>
</DialogFooter>
</Dialog>
);
@ -144,8 +166,19 @@ const UpgradeDialog = (props) => {
const TierCard = (props) => {
const { t } = useTranslation();
const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : { border: "2px solid transparent" };
const tier = props.tier;
let cardStyle, labelStyle, labelText;
if (props.selected) {
cardStyle = { background: "#eee", border: "2px solid #338574" };
labelStyle = { background: "#338574", color: "white" };
labelText = t("account_upgrade_dialog_tier_selected_label");
} else if (props.current) {
cardStyle = { border: "2px solid #eee" };
labelStyle = { background: "#eee", color: "black" };
labelText = t("account_upgrade_dialog_tier_current_label");
} else {
cardStyle = { border: "2px solid transparent" };
}
return (
<Box sx={{
@ -163,16 +196,15 @@ const TierCard = (props) => {
<Card sx={{ height: "100%" }}>
<CardActionArea sx={{ height: "100%" }}>
<CardContent onClick={props.onClick} sx={{ height: "100%" }}>
{props.selected &&
{labelStyle &&
<div style={{
position: "absolute",
top: "0",
right: "15px",
padding: "2px 10px",
background: "#338574",
color: "white",
borderRadius: "3px",
}}>{t("account_upgrade_dialog_tier_selected_label")}</div>
...labelStyle
}}>{labelText}</div>
}
<Typography variant="h5" component="div">
{tier.name || t("account_usage_tier_free")}
@ -217,10 +249,17 @@ const FeatureItem = (props) => {
};
const Action = {
REDIRECT_SIGNUP: 0,
CREATE_SUBSCRIPTION: 1,
UPDATE_SUBSCRIPTION: 2,
CANCEL_SUBSCRIPTION: 3
REDIRECT_SIGNUP: 1,
CREATE_SUBSCRIPTION: 2,
UPDATE_SUBSCRIPTION: 3,
CANCEL_SUBSCRIPTION: 4
};
const Banner = {
CANCEL_WARNING: 1,
PRORATION_INFO: 2,
RESERVATIONS_WARNING: 3
};
export default UpgradeDialog;