Upgrade dialog

This commit is contained in:
binwiederhier 2023-01-17 10:09:37 -05:00
parent 83de879894
commit 695c1349e8
11 changed files with 290 additions and 137 deletions

View file

@ -52,7 +52,6 @@ import (
update last_seen when API is accessed update last_seen when API is accessed
Make sure account endpoints make sense for admins Make sure account endpoints make sense for admins
triggerChange after publishing a message
UI: UI:
- flicker of upgrade banner - flicker of upgrade banner
- JS constants - JS constants
@ -83,6 +82,7 @@ type Server struct {
userManager *user.Manager // Might be nil! userManager *user.Manager // Might be nil!
messageCache *messageCache messageCache *messageCache
fileCache *fileCache fileCache *fileCache
priceCache map[string]string // Stripe price ID -> formatted price
closeChan chan bool closeChan chan bool
mu sync.Mutex mu sync.Mutex
} }
@ -102,27 +102,29 @@ var (
authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`) publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js" webConfigPath = "/config.js"
healthPath = "/v1/health" accountPath = "/account"
accountPath = "/v1/account" matrixPushPath = "/_matrix/push/v1/notify"
accountTokenPath = "/v1/account/token" apiHealthPath = "/v1/health"
accountPasswordPath = "/v1/account/password" apiAccountPath = "/v1/account"
accountSettingsPath = "/v1/account/settings" apiAccountTokenPath = "/v1/account/token"
accountSubscriptionPath = "/v1/account/subscription" apiAccountPasswordPath = "/v1/account/password"
accountReservationPath = "/v1/account/reservation" apiAccountSettingsPath = "/v1/account/settings"
accountBillingPortalPath = "/v1/account/billing/portal" apiAccountSubscriptionPath = "/v1/account/subscription"
accountBillingWebhookPath = "/v1/account/billing/webhook" apiAccountReservationPath = "/v1/account/reservation"
accountBillingSubscriptionPath = "/v1/account/billing/subscription" apiAccountBillingTiersPath = "/v1/account/billing/tiers"
accountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}" apiAccountBillingPortalPath = "/v1/account/billing/portal"
accountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`) apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
matrixPushPath = "/_matrix/push/v1/notify" apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
staticRegex = regexp.MustCompile(`^/static/.+`) apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) staticRegex = regexp.MustCompile(`^/static/.+`)
disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
urlRegex = regexp.MustCompile(`^https?://`) fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
urlRegex = regexp.MustCompile(`^https?://`)
//go:embed site //go:embed site
webFs embed.FS webFs embed.FS
@ -199,6 +201,7 @@ func New(conf *Config) (*Server, error) {
topics: topics, topics: topics,
userManager: userManager, userManager: userManager,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
priceCache: make(map[string]string),
}, nil }, nil
} }
@ -347,47 +350,49 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureWebEnabled(s.handleHome)(w, r, v) return s.ensureWebEnabled(s.handleHome)(w, r, v)
} else if r.Method == http.MethodHead && r.URL.Path == "/" { } else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.ensureWebEnabled(s.handleEmpty)(w, r, v) return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == healthPath { } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
return s.handleHealth(w, r, v) return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath { } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v) return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
return s.ensureUserManager(s.handleAccountCreate)(w, r, v) return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenIssue)(w, r, v) return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == accountPath { } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
return s.handleAccountGet(w, r, v) // Allowed by anonymous return s.handleAccountGet(w, r, v) // Allowed by anonymous
} else if r.Method == http.MethodDelete && r.URL.Path == accountPath { } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
return s.ensureUser(s.handleAccountPasswordChange)(w, r, v) return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath { } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenExtend)(w, r, v) return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath { } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenDelete)(w, r, v) return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath { } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v) return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath { } else if r.Method == http.MethodGet && r.URL.Path == apiAccountBillingTiersPath {
return s.ensurePaymentsEnabled(s.handleAccountBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context! return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
} else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath { } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath { } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v) return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w) return s.handleMatrixDiscovery(w)

View file

@ -91,7 +91,6 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
response.Tier = &apiAccountTier{ response.Tier = &apiAccountTier{
Code: v.user.Tier.Code, Code: v.user.Tier.Code,
Name: v.user.Tier.Name, Name: v.user.Tier.Name,
Paid: v.user.Tier.Paid,
} }
} }
if v.user.Billing.StripeCustomerID != "" { if v.user.Billing.StripeCustomerID != "" {
@ -268,7 +267,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
} }
func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 { if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath return errHTTPInternalErrorInvalidPath
} }
@ -303,7 +302,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
} }
func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 { if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath return errHTTPInternalErrorInvalidPath
} }
@ -374,7 +373,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
} }
func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountReservationSingleRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountReservationSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 { if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath return errHTTPInternalErrorInvalidPath
} }

View file

@ -3,14 +3,17 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/stripe/stripe-go/v74" "github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session" portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session" "github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/customer" "github.com/stripe/stripe-go/v74/customer"
"github.com/stripe/stripe-go/v74/price"
"github.com/stripe/stripe-go/v74/subscription" "github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook" "github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/http" "net/http"
"net/netip" "net/netip"
@ -21,6 +24,44 @@ const (
stripeBodyBytesLimit = 16384 stripeBodyBytesLimit = 16384
) )
func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
tiers, err := v.userManager.Tiers()
if err != nil {
return err
}
response := make([]*apiAccountBillingTier, 0)
for _, tier := range tiers {
if tier.StripePriceID == "" {
continue
}
priceStr, ok := s.priceCache[tier.StripePriceID]
if !ok {
p, err := price.Get(tier.StripePriceID, nil)
if err != nil {
return err
}
if p.UnitAmount%100 == 0 {
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
Price: priceStr,
Features: tier.Features,
})
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active. // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
@ -49,10 +90,10 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
return errors.New("customer cannot have more than one subscription") //FIXME return errors.New("customer cannot have more than one subscription") //FIXME
} }
} }
successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{ params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name, // FIXME Should be user ID ClientReferenceID: &v.user.Name,
SuccessURL: &successURL, SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{ LineItems: []*stripe.CheckoutSessionLineItemParams{
@ -69,7 +110,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
if err != nil { if err != nil {
return err return err
} }
response := &apiAccountCheckoutResponse{ response := &apiAccountBillingSubscriptionCreateResponse{
RedirectURL: sess.URL, RedirectURL: sess.URL,
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -82,29 +123,25 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager! // We don't have a v.user in this endpoint, only a userManager!
matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 { if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath return errHTTPInternalErrorInvalidPath
} }
sessionID := matches[1] sessionID := matches[1]
// FIXME how do I rate limit this? sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
if err != nil { if err != nil {
log.Warn("Stripe: %s", err) log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found") return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
} }
sub, err := subscription.Get(sess.Subscription.ID, nil) sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil { if err != nil {
return err return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item") return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "more than one line item in existing subscription")
return errHTTPBadRequestInvalidStripeRequest
} }
priceID := sub.Items.Data[0].Price.ID tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil { if err != nil {
return err return err
} }
@ -112,26 +149,17 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil { if err != nil {
return err return err
} }
u.Billing.StripeCustomerID = sess.Customer.ID if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
u.Billing.StripeSubscriptionID = sub.ID
u.Billing.StripeSubscriptionStatus = sub.Status
u.Billing.StripeSubscriptionPaidUntil = time.Unix(sub.CurrentPeriodEnd, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(sub.CancelAt, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err return err
} }
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
return nil return nil
} }
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
// a user's tier accordingly. This endpoint only works if there is an existing subscription. // a user's tier accordingly. This endpoint only works if there is an existing subscription.
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" { if v.user.Billing.StripeSubscriptionID == "" {
return errors.New("no existing subscription for user") return errors.New("no existing subscription for user")
} }
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit) req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
@ -161,10 +189,9 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil { if err != nil {
return err return err
} }
response := &apiAccountCheckoutResponse{} // FIXME
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil { if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
return err return err
} }
return nil return nil
@ -250,6 +277,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() { if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest return errHTTPBadRequestInvalidStripeRequest
} }
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
u, err := s.userManager.UserByStripeCustomer(customerID.String()) u, err := s.userManager.UserByStripeCustomer(customerID.String())
if err != nil { if err != nil {
return err return err
@ -258,41 +286,47 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil { if err != nil {
return err return err
} }
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
return err return err
} }
u.Billing.StripeSubscriptionID = subscriptionID.String()
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID)
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil return nil
} }
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
stripeCustomerID := gjson.GetBytes(event, "customer") customerID := gjson.GetBytes(event, "customer")
if !stripeCustomerID.Exists() { if !customerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest return errHTTPBadRequestInvalidStripeRequest
} }
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String()) log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String()) u, err := s.userManager.UserByStripeCustomer(customerID.String())
if err != nil { if err != nil {
return err return err
} }
if err := s.userManager.ResetTier(u.Name); err != nil { if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
return err
}
u.Billing.StripeSubscriptionID = ""
u.Billing.StripeSubscriptionStatus = ""
u.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil return nil
} }
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
u.Billing.StripeCustomerID = customerID
u.Billing.StripeSubscriptionID = subscriptionID
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
if tier == "" {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
} else {
if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
return err
}
}
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
return nil
}

View file

@ -238,7 +238,6 @@ type apiAccountTokenResponse struct {
type apiAccountTier struct { type apiAccountTier struct {
Code string `json:"code"` Code string `json:"code"`
Name string `json:"name"` Name string `json:"name"`
Paid bool `json:"paid"`
} }
type apiAccountLimits struct { type apiAccountLimits struct {
@ -305,14 +304,21 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"` DisallowedTopics []string `json:"disallowed_topics"`
} }
type apiAccountBillingSubscriptionChangeRequest struct { type apiAccountBillingTier struct {
Tier string `json:"tier"` Code string `json:"code"`
Name string `json:"name"`
Price string `json:"price"`
Features string `json:"features"`
} }
type apiAccountCheckoutResponse struct { type apiAccountBillingSubscriptionCreateResponse struct {
RedirectURL string `json:"redirect_url"` RedirectURL string `json:"redirect_url"`
} }
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
}
type apiAccountBillingPortalRedirectResponse struct { type apiAccountBillingPortalRedirectResponse struct {
RedirectURL string `json:"redirect_url"` RedirectURL string `json:"redirect_url"`
} }
@ -320,3 +326,13 @@ type apiAccountBillingPortalRedirectResponse struct {
type apiAccountSyncTopicResponse struct { type apiAccountSyncTopicResponse struct {
Event string `json:"event"` Event string `json:"event"`
} }
type apiSuccessResponse struct {
Success bool `json:"success"`
}
func newSuccessResponse() *apiSuccessResponse {
return &apiSuccessResponse{
Success: true,
}
}

View file

@ -45,6 +45,7 @@ const (
attachment_file_size_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL, attachment_expiry_duration INT NOT NULL,
features TEXT,
stripe_price_id TEXT stripe_price_id TEXT
); );
CREATE UNIQUE INDEX idx_tier_code ON tier (code); CREATE UNIQUE INDEX idx_tier_code ON tier (code);
@ -103,22 +104,22 @@ const (
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier p on p.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
` `
selectTopicPermsQuery = ` selectTopicPermsQuery = `
@ -220,14 +221,18 @@ const (
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration) INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier
`
selectTierByCodeQuery = ` selectTierByCodeQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier FROM tier
WHERE code = ? WHERE code = ?
` `
selectTierByPriceIDQuery = ` selectTierByPriceIDQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier FROM tier
WHERE stripe_price_id = ? WHERE stripe_price_id = ?
` `
@ -604,13 +609,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role, prefs, syncTopic string var username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName, tierFeatures sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &tierFeatures, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -649,7 +654,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, Features: tierFeatures.String, // May be empty
StripePriceID: stripePriceID.String, // May be empty
} }
} }
return user, nil return user, nil
@ -881,11 +887,31 @@ func (a *Manager) ChangeBilling(user *User) error {
return nil return nil
} }
func (a *Manager) Tiers() ([]*Tier, error) {
rows, err := a.db.Query(selectTiersQuery)
if err != nil {
return nil, err
}
defer rows.Close()
tiers := make([]*Tier, 0)
for {
tier, err := a.readTier(rows)
if err == ErrTierNotFound {
break
} else if err != nil {
return nil, err
}
tiers = append(tiers, tier)
}
return tiers, nil
}
func (a *Manager) Tier(code string) (*Tier, error) { func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(selectTierByCodeQuery, code) rows, err := a.db.Query(selectTierByCodeQuery, code)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
return a.readTier(rows) return a.readTier(rows)
} }
@ -894,18 +920,18 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
return a.readTier(rows) return a.readTier(rows)
} }
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string var code, name string
var stripePriceID sql.NullString var features, stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrTierNotFound return nil, ErrTierNotFound
} }
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &features, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -922,7 +948,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, // May be empty! Features: features.String, // May be empty
StripePriceID: stripePriceID.String, // May be empty
}, nil }, nil
} }

View file

@ -60,6 +60,7 @@ type Tier struct {
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentExpiryDuration time.Duration AttachmentExpiryDuration time.Duration
Features string
StripePriceID string StripePriceID string
} }

View file

@ -201,8 +201,9 @@
"account_delete_dialog_label": "Type '{{username}}' to delete account", "account_delete_dialog_label": "Type '{{username}}' to delete account",
"account_delete_dialog_button_cancel": "Cancel", "account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account", "account_delete_dialog_button_submit": "Permanently delete account",
"account_upgrade_dialog_title": "Change billing plan", "account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.", "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.",
"account_upgrade_dialog_proration_info": "When switching between paid plans, the price difference will be charged or refunded in the next invoice.",
"prefs_notifications_title": "Notifications", "prefs_notifications_title": "Notifications",
"prefs_notifications_sound_title": "Notification sound", "prefs_notifications_sound_title": "Notification sound",
"prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive", "prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive",

View file

@ -8,7 +8,7 @@ import {
accountTokenUrl, accountTokenUrl,
accountUrl, maybeWithAuth, topicUrl, accountUrl, maybeWithAuth, topicUrl,
withBasicAuth, withBasicAuth,
withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl, accountBillingTiersUrl
} from "./utils"; } from "./utils";
import session from "./Session"; import session from "./Session";
import subscriptionManager from "./SubscriptionManager"; import subscriptionManager from "./SubscriptionManager";
@ -264,6 +264,20 @@ class AccountApi {
this.triggerChange(); // Dangle! this.triggerChange(); // Dangle!
} }
async billingTiers() {
const url = accountBillingTiersUrl(config.base_url);
console.log(`[AccountApi] Fetching billing tiers`);
const response = await fetch(url, {
headers: withBearerAuth({}, session.token())
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async createBillingSubscription(tier) { async createBillingSubscription(tier) {
console.log(`[AccountApi] Creating billing subscription with ${tier}`); console.log(`[AccountApi] Creating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("POST", tier) return await this.upsertBillingSubscription("POST", tier)

View file

@ -28,6 +28,7 @@ export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reserva
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`; export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`;
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`; export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
export const accountBillingTiersUrl = (baseUrl) => `${baseUrl}/v1/account/billing/tiers`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`; export const expandSecureUrl = (url) => `https://${url}`;

View file

@ -69,8 +69,9 @@ const Layout = () => {
const [sendDialogOpenMode, setSendDialogOpenMode] = useState(""); const [sendDialogOpenMode, setSendDialogOpenMode] = useState("");
const users = useLiveQuery(() => userManager.all()); const users = useLiveQuery(() => userManager.all());
const subscriptions = useLiveQuery(() => subscriptionManager.all()); const subscriptions = useLiveQuery(() => subscriptionManager.all());
const newNotificationsCount = subscriptions?.reduce((prev, cur) => prev + cur.new, 0) || 0; const subscriptionsWithoutInternal = subscriptions?.filter(s => !s.internal);
const [selected] = (subscriptions || []).filter(s => { const newNotificationsCount = subscriptionsWithoutInternal?.reduce((prev, cur) => prev + cur.new, 0) || 0;
const [selected] = (subscriptionsWithoutInternal || []).filter(s => {
return (params.baseUrl && expandUrl(params.baseUrl).includes(s.baseUrl) && params.topic === s.topic) return (params.baseUrl && expandUrl(params.baseUrl).includes(s.baseUrl) && params.topic === s.topic)
|| (config.base_url === s.baseUrl && params.topic === s.topic) || (config.base_url === s.baseUrl && params.topic === s.topic)
}); });
@ -101,7 +102,7 @@ const Layout = () => {
onMobileDrawerToggle={() => setMobileDrawerOpen(!mobileDrawerOpen)} onMobileDrawerToggle={() => setMobileDrawerOpen(!mobileDrawerOpen)}
/> />
<Navigation <Navigation
subscriptions={subscriptions} subscriptions={subscriptionsWithoutInternal}
selectedSubscription={selected} selectedSubscription={selected}
notificationsGranted={notificationsGranted} notificationsGranted={notificationsGranted}
mobileDrawerOpen={mobileDrawerOpen} mobileDrawerOpen={mobileDrawerOpen}
@ -111,7 +112,10 @@ const Layout = () => {
/> />
<Main> <Main>
<Toolbar/> <Toolbar/>
<Outlet context={{ subscriptions, selected }}/> <Outlet context={{
subscriptions: subscriptionsWithoutInternal,
selected: selected
}}/>
</Main> </Main>
<Messaging <Messaging
selected={selected} selected={selected}

View file

@ -9,21 +9,29 @@ import Button from "@mui/material/Button";
import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi"; import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
import session from "../app/Session"; import session from "../app/Session";
import routes from "./routes"; import routes from "./routes";
import {useContext, useState} from "react"; import {useContext, useEffect, useState} from "react";
import Card from "@mui/material/Card"; import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography"; import Typography from "@mui/material/Typography";
import {AccountContext} from "./App"; import {AccountContext} from "./App";
import {formatShortDate} from "../app/utils"; import {formatShortDate} from "../app/utils";
import {useTranslation} from "react-i18next"; import {useTranslation} from "react-i18next";
import subscriptionManager from "../app/SubscriptionManager";
const UpgradeDialog = (props) => { const UpgradeDialog = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const { account } = useContext(AccountContext); const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [tiers, setTiers] = useState(null);
const [newTier, setNewTier] = useState(account?.tier?.code || null); const [newTier, setNewTier] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState(""); const [errorText, setErrorText] = useState("");
if (!account) { useEffect(() => {
(async () => {
setTiers(await accountApi.billingTiers());
})();
}, []);
if (!account || !tiers) {
return <></>; return <></>;
} }
@ -69,22 +77,43 @@ const UpgradeDialog = (props) => {
return ( return (
<Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}> <Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
<DialogTitle>Change billing plan</DialogTitle> <DialogTitle>{t("account_upgrade_dialog_title")}</DialogTitle>
<DialogContent> <DialogContent>
<div style={{ <div style={{
display: "flex", display: "flex",
flexDirection: "row" flexDirection: "row",
marginBottom: "8px",
width: "100%"
}}> }}>
<TierCard code={null} name={"Free"} selected={newTier === null} onClick={() => setNewTier(null)}/> <TierCard
<TierCard code="starter" name={"Starter"} selected={newTier === "starter"} onClick={() => setNewTier("starter")}/> code={null}
<TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/> name={t("account_usage_tier_free")}
<TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/> price={null}
selected={newTier === null}
onClick={() => setNewTier(null)}
/>
{tiers.map(tier =>
<TierCard
key={`tierCard${tier.code}`}
code={tier.code}
name={tier.name}
price={tier.price}
features={tier.features}
selected={newTier === tier.code}
onClick={() => setNewTier(tier.code)}
/>
)}
</div> </div>
{action === Action.CANCEL && {action === Action.CANCEL &&
<Alert severity="warning"> <Alert severity="warning">
{t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })} {t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })}
</Alert> </Alert>
} }
{action === Action.UPDATE &&
<Alert severity="info">
{t("account_upgrade_dialog_proration_info")}
</Alert>
}
</DialogContent> </DialogContent>
<DialogFooter status={errorText}> <DialogFooter status={errorText}>
<Button onClick={props.onCancel}>Cancel</Button> <Button onClick={props.onCancel}>Cancel</Button>
@ -95,20 +124,42 @@ const UpgradeDialog = (props) => {
}; };
const TierCard = (props) => { const TierCard = (props) => {
const cardStyle = (props.selected) ? { const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : {};
background: "#eee"
} : {};
return ( return (
<Card sx={{ m: 1, maxWidth: 345 }}> <Card sx={{
<CardActionArea> m: 1,
<CardContent sx={{...cardStyle}} onClick={props.onClick}> minWidth: "190px",
maxWidth: "250px",
"&:first-child": { ml: 0 },
"&:last-child": { mr: 0 },
...cardStyle
}}>
<CardActionArea sx={{ height: "100%" }}>
<CardContent onClick={props.onClick} sx={{ height: "100%" }}>
{props.selected &&
<div style={{
position: "absolute",
top: "0",
right: "15px",
padding: "2px 10px",
background: "#338574",
color: "white",
borderRadius: "3px",
}}>Selected</div>
}
<Typography gutterBottom variant="h5" component="div"> <Typography gutterBottom variant="h5" component="div">
{props.name} {props.name}
</Typography> </Typography>
<Typography variant="body2" color="text.secondary"> {props.features &&
Lizards are a widespread group of squamate reptiles, with over 6,000 <Typography variant="body2" color="text.secondary" sx={{whiteSpace: "pre-wrap"}}>
species, ranging across all continents except Antarctica {props.features}
</Typography> </Typography>
}
{props.price &&
<Typography variant="subtitle1" sx={{mt: 1}}>
{props.price} / month
</Typography>
}
</CardContent> </CardContent>
</CardActionArea> </CardActionArea>
</Card> </Card>