Upgrade dialog

This commit is contained in:
binwiederhier 2023-01-17 10:09:37 -05:00
parent 83de879894
commit 695c1349e8
11 changed files with 290 additions and 137 deletions

View file

@ -52,7 +52,6 @@ import (
update last_seen when API is accessed
Make sure account endpoints make sense for admins
triggerChange after publishing a message
UI:
- flicker of upgrade banner
- JS constants
@ -83,6 +82,7 @@ type Server struct {
userManager *user.Manager // Might be nil!
messageCache *messageCache
fileCache *fileCache
priceCache map[string]string // Stripe price ID -> formatted price
closeChan chan bool
mu sync.Mutex
}
@ -103,21 +103,23 @@ var (
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js"
healthPath = "/v1/health"
accountPath = "/v1/account"
accountTokenPath = "/v1/account/token"
accountPasswordPath = "/v1/account/password"
accountSettingsPath = "/v1/account/settings"
accountSubscriptionPath = "/v1/account/subscription"
accountReservationPath = "/v1/account/reservation"
accountBillingPortalPath = "/v1/account/billing/portal"
accountBillingWebhookPath = "/v1/account/billing/webhook"
accountBillingSubscriptionPath = "/v1/account/billing/subscription"
accountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
accountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
accountPath = "/account"
matrixPushPath = "/_matrix/push/v1/notify"
apiHealthPath = "/v1/health"
apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password"
apiAccountSettingsPath = "/v1/account/settings"
apiAccountSubscriptionPath = "/v1/account/subscription"
apiAccountReservationPath = "/v1/account/reservation"
apiAccountBillingTiersPath = "/v1/account/billing/tiers"
apiAccountBillingPortalPath = "/v1/account/billing/portal"
apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
@ -199,6 +201,7 @@ func New(conf *Config) (*Server, error) {
topics: topics,
userManager: userManager,
visitors: make(map[string]*visitor),
priceCache: make(map[string]string),
}, nil
}
@ -347,47 +350,49 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureWebEnabled(s.handleHome)(w, r, v)
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == healthPath {
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == accountPath {
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
return s.handleAccountGet(w, r, v) // Allowed by anonymous
} else if r.Method == http.MethodDelete && r.URL.Path == accountPath {
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
} else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
} else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath {
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountBillingTiersPath {
return s.ensurePaymentsEnabled(s.handleAccountBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
} else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath {
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)

View file

@ -91,7 +91,6 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
response.Tier = &apiAccountTier{
Code: v.user.Tier.Code,
Name: v.user.Tier.Name,
Paid: v.user.Tier.Paid,
}
}
if v.user.Billing.StripeCustomerID != "" {
@ -268,7 +267,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
}
func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
@ -303,7 +302,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
}
func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
@ -374,7 +373,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
}
func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
matches := accountReservationSingleRegex.FindStringSubmatch(r.URL.Path)
matches := apiAccountReservationSingleRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}

View file

@ -3,14 +3,17 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/customer"
"github.com/stripe/stripe-go/v74/price"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
"net/netip"
@ -21,6 +24,44 @@ const (
stripeBodyBytesLimit = 16384
)
func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
tiers, err := v.userManager.Tiers()
if err != nil {
return err
}
response := make([]*apiAccountBillingTier, 0)
for _, tier := range tiers {
if tier.StripePriceID == "" {
continue
}
priceStr, ok := s.priceCache[tier.StripePriceID]
if !ok {
p, err := price.Get(tier.StripePriceID, nil)
if err != nil {
return err
}
if p.UnitAmount%100 == 0 {
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
Price: priceStr,
Features: tier.Features,
})
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
@ -49,10 +90,10 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
return errors.New("customer cannot have more than one subscription") //FIXME
}
}
successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
ClientReferenceID: &v.user.Name,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
@ -69,7 +110,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{
response := &apiAccountBillingSubscriptionCreateResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
@ -82,29 +123,25 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
sessionID := matches[1]
// FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "customer or subscription not found")
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
return errHTTPBadRequestInvalidStripeRequest
return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "more than one line item in existing subscription")
}
priceID := sub.Items.Data[0].Price.ID
tier, err := s.userManager.TierByStripePrice(priceID)
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
if err != nil {
return err
}
@ -112,26 +149,17 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
u.Billing.StripeCustomerID = sess.Customer.ID
u.Billing.StripeSubscriptionID = sub.ID
u.Billing.StripeSubscriptionStatus = sub.Status
u.Billing.StripeSubscriptionPaidUntil = time.Unix(sub.CurrentPeriodEnd, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(sub.CancelAt, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
return nil
}
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
if v.user.Billing.StripeSubscriptionID == "" {
return errors.New("no existing subscription for user")
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
@ -161,10 +189,9 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{} // FIXME
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
return err
}
return nil
@ -250,6 +277,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
u, err := s.userManager.UserByStripeCustomer(customerID.String())
if err != nil {
return err
@ -258,41 +286,47 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
return err
}
u.Billing.StripeSubscriptionID = subscriptionID.String()
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0)
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID)
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
stripeCustomerID := gjson.GetBytes(event, "customer")
if !stripeCustomerID.Exists() {
customerID := gjson.GetBytes(event, "customer")
if !customerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String())
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String())
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
u, err := s.userManager.UserByStripeCustomer(customerID.String())
if err != nil {
return err
}
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
u.Billing.StripeSubscriptionID = ""
u.Billing.StripeSubscriptionStatus = ""
u.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0)
if err := s.userManager.ChangeBilling(u); err != nil {
if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
u.Billing.StripeCustomerID = customerID
u.Billing.StripeSubscriptionID = subscriptionID
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
if tier == "" {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
} else {
if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
return err
}
}
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
return nil
}

View file

@ -238,7 +238,6 @@ type apiAccountTokenResponse struct {
type apiAccountTier struct {
Code string `json:"code"`
Name string `json:"name"`
Paid bool `json:"paid"`
}
type apiAccountLimits struct {
@ -305,14 +304,21 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"`
}
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
type apiAccountBillingTier struct {
Code string `json:"code"`
Name string `json:"name"`
Price string `json:"price"`
Features string `json:"features"`
}
type apiAccountCheckoutResponse struct {
type apiAccountBillingSubscriptionCreateResponse struct {
RedirectURL string `json:"redirect_url"`
}
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
}
type apiAccountBillingPortalRedirectResponse struct {
RedirectURL string `json:"redirect_url"`
}
@ -320,3 +326,13 @@ type apiAccountBillingPortalRedirectResponse struct {
type apiAccountSyncTopicResponse struct {
Event string `json:"event"`
}
type apiSuccessResponse struct {
Success bool `json:"success"`
}
func newSuccessResponse() *apiSuccessResponse {
return &apiSuccessResponse{
Success: true,
}
}

View file

@ -45,6 +45,7 @@ const (
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
features TEXT,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
@ -103,22 +104,22 @@ const (
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
selectTopicPermsQuery = `
@ -221,13 +222,17 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier
`
selectTierByCodeQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
@ -604,13 +609,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName, tierFeatures sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &tierFeatures, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -649,7 +654,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String,
Features: tierFeatures.String, // May be empty
StripePriceID: stripePriceID.String, // May be empty
}
}
return user, nil
@ -881,11 +887,31 @@ func (a *Manager) ChangeBilling(user *User) error {
return nil
}
func (a *Manager) Tiers() ([]*Tier, error) {
rows, err := a.db.Query(selectTiersQuery)
if err != nil {
return nil, err
}
defer rows.Close()
tiers := make([]*Tier, 0)
for {
tier, err := a.readTier(rows)
if err == ErrTierNotFound {
break
} else if err != nil {
return nil, err
}
tiers = append(tiers, tier)
}
return tiers, nil
}
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(selectTierByCodeQuery, code)
if err != nil {
return nil, err
}
defer rows.Close()
return a.readTier(rows)
}
@ -894,18 +920,18 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
if err != nil {
return nil, err
}
defer rows.Close()
return a.readTier(rows)
}
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string
var stripePriceID sql.NullString
var features, stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &features, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -922,7 +948,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, // May be empty!
Features: features.String, // May be empty
StripePriceID: stripePriceID.String, // May be empty
}, nil
}

View file

@ -60,6 +60,7 @@ type Tier struct {
AttachmentFileSizeLimit int64
AttachmentTotalSizeLimit int64
AttachmentExpiryDuration time.Duration
Features string
StripePriceID string
}

View file

@ -201,8 +201,9 @@
"account_delete_dialog_label": "Type '{{username}}' to delete account",
"account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account",
"account_upgrade_dialog_title": "Change billing plan",
"account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.",
"account_upgrade_dialog_proration_info": "When switching between paid plans, the price difference will be charged or refunded in the next invoice.",
"prefs_notifications_title": "Notifications",
"prefs_notifications_sound_title": "Notification sound",
"prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive",

View file

@ -8,7 +8,7 @@ import {
accountTokenUrl,
accountUrl, maybeWithAuth, topicUrl,
withBasicAuth,
withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl
withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl, accountBillingTiersUrl
} from "./utils";
import session from "./Session";
import subscriptionManager from "./SubscriptionManager";
@ -264,6 +264,20 @@ class AccountApi {
this.triggerChange(); // Dangle!
}
async billingTiers() {
const url = accountBillingTiersUrl(config.base_url);
console.log(`[AccountApi] Fetching billing tiers`);
const response = await fetch(url, {
headers: withBearerAuth({}, session.token())
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async createBillingSubscription(tier) {
console.log(`[AccountApi] Creating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("POST", tier)

View file

@ -28,6 +28,7 @@ export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reserva
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`;
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
export const accountBillingTiersUrl = (baseUrl) => `${baseUrl}/v1/account/billing/tiers`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`;

View file

@ -69,8 +69,9 @@ const Layout = () => {
const [sendDialogOpenMode, setSendDialogOpenMode] = useState("");
const users = useLiveQuery(() => userManager.all());
const subscriptions = useLiveQuery(() => subscriptionManager.all());
const newNotificationsCount = subscriptions?.reduce((prev, cur) => prev + cur.new, 0) || 0;
const [selected] = (subscriptions || []).filter(s => {
const subscriptionsWithoutInternal = subscriptions?.filter(s => !s.internal);
const newNotificationsCount = subscriptionsWithoutInternal?.reduce((prev, cur) => prev + cur.new, 0) || 0;
const [selected] = (subscriptionsWithoutInternal || []).filter(s => {
return (params.baseUrl && expandUrl(params.baseUrl).includes(s.baseUrl) && params.topic === s.topic)
|| (config.base_url === s.baseUrl && params.topic === s.topic)
});
@ -101,7 +102,7 @@ const Layout = () => {
onMobileDrawerToggle={() => setMobileDrawerOpen(!mobileDrawerOpen)}
/>
<Navigation
subscriptions={subscriptions}
subscriptions={subscriptionsWithoutInternal}
selectedSubscription={selected}
notificationsGranted={notificationsGranted}
mobileDrawerOpen={mobileDrawerOpen}
@ -111,7 +112,10 @@ const Layout = () => {
/>
<Main>
<Toolbar/>
<Outlet context={{ subscriptions, selected }}/>
<Outlet context={{
subscriptions: subscriptionsWithoutInternal,
selected: selected
}}/>
</Main>
<Messaging
selected={selected}

View file

@ -9,21 +9,29 @@ import Button from "@mui/material/Button";
import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
import session from "../app/Session";
import routes from "./routes";
import {useContext, useState} from "react";
import {useContext, useEffect, useState} from "react";
import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography";
import {AccountContext} from "./App";
import {formatShortDate} from "../app/utils";
import {useTranslation} from "react-i18next";
import subscriptionManager from "../app/SubscriptionManager";
const UpgradeDialog = (props) => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [tiers, setTiers] = useState(null);
const [newTier, setNewTier] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState("");
if (!account) {
useEffect(() => {
(async () => {
setTiers(await accountApi.billingTiers());
})();
}, []);
if (!account || !tiers) {
return <></>;
}
@ -69,22 +77,43 @@ const UpgradeDialog = (props) => {
return (
<Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
<DialogTitle>Change billing plan</DialogTitle>
<DialogTitle>{t("account_upgrade_dialog_title")}</DialogTitle>
<DialogContent>
<div style={{
display: "flex",
flexDirection: "row"
flexDirection: "row",
marginBottom: "8px",
width: "100%"
}}>
<TierCard code={null} name={"Free"} selected={newTier === null} onClick={() => setNewTier(null)}/>
<TierCard code="starter" name={"Starter"} selected={newTier === "starter"} onClick={() => setNewTier("starter")}/>
<TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/>
<TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/>
<TierCard
code={null}
name={t("account_usage_tier_free")}
price={null}
selected={newTier === null}
onClick={() => setNewTier(null)}
/>
{tiers.map(tier =>
<TierCard
key={`tierCard${tier.code}`}
code={tier.code}
name={tier.name}
price={tier.price}
features={tier.features}
selected={newTier === tier.code}
onClick={() => setNewTier(tier.code)}
/>
)}
</div>
{action === Action.CANCEL &&
<Alert severity="warning">
{t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })}
</Alert>
}
{action === Action.UPDATE &&
<Alert severity="info">
{t("account_upgrade_dialog_proration_info")}
</Alert>
}
</DialogContent>
<DialogFooter status={errorText}>
<Button onClick={props.onCancel}>Cancel</Button>
@ -95,20 +124,42 @@ const UpgradeDialog = (props) => {
};
const TierCard = (props) => {
const cardStyle = (props.selected) ? {
background: "#eee"
} : {};
const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : {};
return (
<Card sx={{ m: 1, maxWidth: 345 }}>
<CardActionArea>
<CardContent sx={{...cardStyle}} onClick={props.onClick}>
<Card sx={{
m: 1,
minWidth: "190px",
maxWidth: "250px",
"&:first-child": { ml: 0 },
"&:last-child": { mr: 0 },
...cardStyle
}}>
<CardActionArea sx={{ height: "100%" }}>
<CardContent onClick={props.onClick} sx={{ height: "100%" }}>
{props.selected &&
<div style={{
position: "absolute",
top: "0",
right: "15px",
padding: "2px 10px",
background: "#338574",
color: "white",
borderRadius: "3px",
}}>Selected</div>
}
<Typography gutterBottom variant="h5" component="div">
{props.name}
</Typography>
<Typography variant="body2" color="text.secondary">
Lizards are a widespread group of squamate reptiles, with over 6,000
species, ranging across all continents except Antarctica
{props.features &&
<Typography variant="body2" color="text.secondary" sx={{whiteSpace: "pre-wrap"}}>
{props.features}
</Typography>
}
{props.price &&
<Typography variant="subtitle1" sx={{mt: 1}}>
{props.price} / month
</Typography>
}
</CardContent>
</CardActionArea>
</Card>