diff --git a/server/server.go b/server/server.go index 8f4370c..09256cf 100644 --- a/server/server.go +++ b/server/server.go @@ -52,7 +52,6 @@ import ( update last_seen when API is accessed Make sure account endpoints make sense for admins - triggerChange after publishing a message UI: - flicker of upgrade banner - JS constants @@ -83,6 +82,7 @@ type Server struct { userManager *user.Manager // Might be nil! messageCache *messageCache fileCache *fileCache + priceCache map[string]string // Stripe price ID -> formatted price closeChan chan bool mu sync.Mutex } @@ -102,27 +102,29 @@ var ( authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`) - webConfigPath = "/config.js" - healthPath = "/v1/health" - accountPath = "/v1/account" - accountTokenPath = "/v1/account/token" - accountPasswordPath = "/v1/account/password" - accountSettingsPath = "/v1/account/settings" - accountSubscriptionPath = "/v1/account/subscription" - accountReservationPath = "/v1/account/reservation" - accountBillingPortalPath = "/v1/account/billing/portal" - accountBillingWebhookPath = "/v1/account/billing/webhook" - accountBillingSubscriptionPath = "/v1/account/billing/subscription" - accountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}" - accountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`) - accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) - accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) - matrixPushPath = "/_matrix/push/v1/notify" - staticRegex = regexp.MustCompile(`^/static/.+`) - docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) - fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) - disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app - urlRegex = regexp.MustCompile(`^https?://`) + webConfigPath = "/config.js" + accountPath = "/account" + matrixPushPath = "/_matrix/push/v1/notify" + apiHealthPath = "/v1/health" + apiAccountPath = "/v1/account" + apiAccountTokenPath = "/v1/account/token" + apiAccountPasswordPath = "/v1/account/password" + apiAccountSettingsPath = "/v1/account/settings" + apiAccountSubscriptionPath = "/v1/account/subscription" + apiAccountReservationPath = "/v1/account/reservation" + apiAccountBillingTiersPath = "/v1/account/billing/tiers" + apiAccountBillingPortalPath = "/v1/account/billing/portal" + apiAccountBillingWebhookPath = "/v1/account/billing/webhook" + apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription" + apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}" + apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`) + apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`) + apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`) + staticRegex = regexp.MustCompile(`^/static/.+`) + docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) + fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) + disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app + urlRegex = regexp.MustCompile(`^https?://`) //go:embed site webFs embed.FS @@ -199,6 +201,7 @@ func New(conf *Config) (*Server, error) { topics: topics, userManager: userManager, visitors: make(map[string]*visitor), + priceCache: make(map[string]string), }, nil } @@ -347,47 +350,49 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureWebEnabled(s.handleHome)(w, r, v) } else if r.Method == http.MethodHead && r.URL.Path == "/" { return s.ensureWebEnabled(s.handleEmpty)(w, r, v) - } else if r.Method == http.MethodGet && r.URL.Path == healthPath { + } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath { return s.handleHealth(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath { return s.ensureWebEnabled(s.handleWebConfig)(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath { return s.ensureUserManager(s.handleAccountCreate)(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath { return s.ensureUser(s.handleAccountTokenIssue)(w, r, v) - } else if r.Method == http.MethodGet && r.URL.Path == accountPath { + } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath { return s.handleAccountGet(w, r, v) // Allowed by anonymous - } else if r.Method == http.MethodDelete && r.URL.Path == accountPath { + } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath { return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath { return s.ensureUser(s.handleAccountPasswordChange)(w, r, v) - } else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath { + } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath { return s.ensureUser(s.handleAccountTokenExtend)(w, r, v) - } else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath { + } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath { return s.ensureUser(s.handleAccountTokenDelete)(w, r, v) - } else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath { + } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath { return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath { return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v) - } else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) { return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v) - } else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) { return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath { return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v) - } else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) { return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath { + } else if r.Method == http.MethodGet && r.URL.Path == apiAccountBillingTiersPath { + return s.ensurePaymentsEnabled(s.handleAccountBillingTiersGet)(w, r, v) + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath { return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook - } else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) { + } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) { return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context! - } else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath { + } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath { return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook - } else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath { + } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath { return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook - } else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath { return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath { + } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath { return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { return s.handleMatrixDiscovery(w) diff --git a/server/server_account.go b/server/server_account.go index 9159ea4..f256b16 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -91,7 +91,6 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis response.Tier = &apiAccountTier{ Code: v.user.Tier.Code, Name: v.user.Tier.Name, - Paid: v.user.Tier.Paid, } } if v.user.Billing.StripeCustomerID != "" { @@ -268,7 +267,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req } func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { - matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) + matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { return errHTTPInternalErrorInvalidPath } @@ -303,7 +302,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. } func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { - matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) + matches := apiAccountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { return errHTTPInternalErrorInvalidPath } @@ -374,7 +373,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ } func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { - matches := accountReservationSingleRegex.FindStringSubmatch(r.URL.Path) + matches := apiAccountReservationSingleRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { return errHTTPInternalErrorInvalidPath } diff --git a/server/server_payments.go b/server/server_payments.go index 298945a..9476a62 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -3,14 +3,17 @@ package server import ( "encoding/json" "errors" + "fmt" "github.com/stripe/stripe-go/v74" portalsession "github.com/stripe/stripe-go/v74/billingportal/session" "github.com/stripe/stripe-go/v74/checkout/session" "github.com/stripe/stripe-go/v74/customer" + "github.com/stripe/stripe-go/v74/price" "github.com/stripe/stripe-go/v74/subscription" "github.com/stripe/stripe-go/v74/webhook" "github.com/tidwall/gjson" "heckel.io/ntfy/log" + "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" "net/netip" @@ -21,6 +24,44 @@ const ( stripeBodyBytesLimit = 16384 ) +func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Request, v *visitor) error { + tiers, err := v.userManager.Tiers() + if err != nil { + return err + } + response := make([]*apiAccountBillingTier, 0) + for _, tier := range tiers { + if tier.StripePriceID == "" { + continue + } + priceStr, ok := s.priceCache[tier.StripePriceID] + if !ok { + p, err := price.Get(tier.StripePriceID, nil) + if err != nil { + return err + } + if p.UnitAmount%100 == 0 { + priceStr = fmt.Sprintf("$%d", p.UnitAmount/100) + } else { + priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) + } + s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something + } + response = append(response, &apiAccountBillingTier{ + Code: tier.Code, + Name: tier.Name, + Price: priceStr, + Features: tier.Features, + }) + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier // will be updated by a subsequent webhook from Stripe, once the subscription becomes active. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -49,10 +90,10 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r return errors.New("customer cannot have more than one subscription") //FIXME } } - successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate + successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate params := &stripe.CheckoutSessionParams{ Customer: stripeCustomerID, // A user may have previously deleted their subscription - ClientReferenceID: &v.user.Name, // FIXME Should be user ID + ClientReferenceID: &v.user.Name, SuccessURL: &successURL, Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), LineItems: []*stripe.CheckoutSessionLineItemParams{ @@ -69,7 +110,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r if err != nil { return err } - response := &apiAccountCheckoutResponse{ + response := &apiAccountBillingSubscriptionCreateResponse{ RedirectURL: sess.URL, } w.Header().Set("Content-Type", "application/json") @@ -82,29 +123,25 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { // We don't have a v.user in this endpoint, only a userManager! - matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) + matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { return errHTTPInternalErrorInvalidPath } sessionID := matches[1] - // FIXME how do I rate limit this? - sess, err := session.Get(sessionID, nil) + sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this? if err != nil { log.Warn("Stripe: %s", err) return errHTTPBadRequestInvalidStripeRequest } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { - log.Warn("Stripe: Unexpected session, customer or subscription not found") - return errHTTPBadRequestInvalidStripeRequest + return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "customer or subscription not found") } sub, err := subscription.Get(sess.Subscription.ID, nil) if err != nil { return err } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { - log.Error("Stripe: Unexpected subscription, expected exactly one line item") - return errHTTPBadRequestInvalidStripeRequest + return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "more than one line item in existing subscription") } - priceID := sub.Items.Data[0].Price.ID - tier, err := s.userManager.TierByStripePrice(priceID) + tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID) if err != nil { return err } @@ -112,26 +149,17 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if err != nil { return err } - u.Billing.StripeCustomerID = sess.Customer.ID - u.Billing.StripeSubscriptionID = sub.ID - u.Billing.StripeSubscriptionStatus = sub.Status - u.Billing.StripeSubscriptionPaidUntil = time.Unix(sub.CurrentPeriodEnd, 0) - u.Billing.StripeSubscriptionCancelAt = time.Unix(sub.CancelAt, 0) - if err := s.userManager.ChangeBilling(u); err != nil { + if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil { return err } - if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { - return err - } - accountURL := s.config.BaseURL + "/account" // FIXME - http.Redirect(w, r, accountURL, http.StatusSeeOther) + http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) return nil } // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates // a user's tier accordingly. This endpoint only works if there is an existing subscription. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeSubscriptionID != "" { + if v.user.Billing.StripeSubscriptionID == "" { return errors.New("no existing subscription for user") } req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit) @@ -161,10 +189,9 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } - response := &apiAccountCheckoutResponse{} // FIXME w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this - if err := json.NewEncoder(w).Encode(response); err != nil { + if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil { return err } return nil @@ -250,6 +277,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() { return errHTTPBadRequestInvalidStripeRequest } + log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID) u, err := s.userManager.UserByStripeCustomer(customerID.String()) if err != nil { return err @@ -258,41 +286,47 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe if err != nil { return err } - if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { + if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil { return err } - u.Billing.StripeSubscriptionID = subscriptionID.String() - u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String()) - u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0) - u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0) - if err := s.userManager.ChangeBilling(u); err != nil { - return err - } - log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { - stripeCustomerID := gjson.GetBytes(event, "customer") - if !stripeCustomerID.Exists() { + customerID := gjson.GetBytes(event, "customer") + if !customerID.Exists() { return errHTTPBadRequestInvalidStripeRequest } - log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String()) - u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String()) + log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String()) + u, err := s.userManager.UserByStripeCustomer(customerID.String()) if err != nil { return err } - if err := s.userManager.ResetTier(u.Name); err != nil { - return err - } - u.Billing.StripeSubscriptionID = "" - u.Billing.StripeSubscriptionStatus = "" - u.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0) - u.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0) - if err := s.userManager.ChangeBilling(u); err != nil { + if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } + +func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error { + u.Billing.StripeCustomerID = customerID + u.Billing.StripeSubscriptionID = subscriptionID + u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status) + u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0) + u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0) + if tier == "" { + if err := s.userManager.ResetTier(u.Name); err != nil { + return err + } + } else { + if err := s.userManager.ChangeTier(u.Name, tier); err != nil { + return err + } + } + if err := s.userManager.ChangeBilling(u); err != nil { + return err + } + return nil +} diff --git a/server/types.go b/server/types.go index 0e37f55..99d521e 100644 --- a/server/types.go +++ b/server/types.go @@ -238,7 +238,6 @@ type apiAccountTokenResponse struct { type apiAccountTier struct { Code string `json:"code"` Name string `json:"name"` - Paid bool `json:"paid"` } type apiAccountLimits struct { @@ -305,14 +304,21 @@ type apiConfigResponse struct { DisallowedTopics []string `json:"disallowed_topics"` } -type apiAccountBillingSubscriptionChangeRequest struct { - Tier string `json:"tier"` +type apiAccountBillingTier struct { + Code string `json:"code"` + Name string `json:"name"` + Price string `json:"price"` + Features string `json:"features"` } -type apiAccountCheckoutResponse struct { +type apiAccountBillingSubscriptionCreateResponse struct { RedirectURL string `json:"redirect_url"` } +type apiAccountBillingSubscriptionChangeRequest struct { + Tier string `json:"tier"` +} + type apiAccountBillingPortalRedirectResponse struct { RedirectURL string `json:"redirect_url"` } @@ -320,3 +326,13 @@ type apiAccountBillingPortalRedirectResponse struct { type apiAccountSyncTopicResponse struct { Event string `json:"event"` } + +type apiSuccessResponse struct { + Success bool `json:"success"` +} + +func newSuccessResponse() *apiSuccessResponse { + return &apiSuccessResponse{ + Success: true, + } +} diff --git a/user/manager.go b/user/manager.go index de4fc74..a6086f6 100644 --- a/user/manager.go +++ b/user/manager.go @@ -45,6 +45,7 @@ const ( attachment_file_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL, attachment_expiry_duration INT NOT NULL, + features TEXT, stripe_price_id TEXT ); CREATE UNIQUE INDEX idx_tier_code ON tier (code); @@ -103,22 +104,22 @@ const ( ` selectUserByNameQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id FROM user u - LEFT JOIN tier p on p.id = u.tier_id + LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id - LEFT JOIN tier p on p.id = u.tier_id + LEFT JOIN tier t on t.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id + SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id FROM user u - LEFT JOIN tier p on p.id = u.tier_id + LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? ` selectTopicPermsQuery = ` @@ -220,14 +221,18 @@ const ( INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ` - selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` + selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` + selectTiersQuery = ` + SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id + FROM tier + ` selectTierByCodeQuery = ` - SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id FROM tier WHERE stripe_price_id = ? ` @@ -604,13 +609,13 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var username, hash, role, prefs, syncTopic string - var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName, tierFeatures sql.NullString var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &tierFeatures, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -649,7 +654,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - StripePriceID: stripePriceID.String, + Features: tierFeatures.String, // May be empty + StripePriceID: stripePriceID.String, // May be empty } } return user, nil @@ -881,11 +887,31 @@ func (a *Manager) ChangeBilling(user *User) error { return nil } +func (a *Manager) Tiers() ([]*Tier, error) { + rows, err := a.db.Query(selectTiersQuery) + if err != nil { + return nil, err + } + defer rows.Close() + tiers := make([]*Tier, 0) + for { + tier, err := a.readTier(rows) + if err == ErrTierNotFound { + break + } else if err != nil { + return nil, err + } + tiers = append(tiers, tier) + } + return tiers, nil +} + func (a *Manager) Tier(code string) (*Tier, error) { rows, err := a.db.Query(selectTierByCodeQuery, code) if err != nil { return nil, err } + defer rows.Close() return a.readTier(rows) } @@ -894,18 +920,18 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { if err != nil { return nil, err } + defer rows.Close() return a.readTier(rows) } func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { - defer rows.Close() var code, name string - var stripePriceID sql.NullString + var features, stripePriceID sql.NullString var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &features, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -922,7 +948,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - StripePriceID: stripePriceID.String, // May be empty! + Features: features.String, // May be empty + StripePriceID: stripePriceID.String, // May be empty }, nil } diff --git a/user/types.go b/user/types.go index 2aca565..bc4b709 100644 --- a/user/types.go +++ b/user/types.go @@ -60,6 +60,7 @@ type Tier struct { AttachmentFileSizeLimit int64 AttachmentTotalSizeLimit int64 AttachmentExpiryDuration time.Duration + Features string StripePriceID string } diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index 0ae7e45..648b967 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -201,8 +201,9 @@ "account_delete_dialog_label": "Type '{{username}}' to delete account", "account_delete_dialog_button_cancel": "Cancel", "account_delete_dialog_button_submit": "Permanently delete account", - "account_upgrade_dialog_title": "Change billing plan", + "account_upgrade_dialog_title": "Change account tier", "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.", + "account_upgrade_dialog_proration_info": "When switching between paid plans, the price difference will be charged or refunded in the next invoice.", "prefs_notifications_title": "Notifications", "prefs_notifications_sound_title": "Notification sound", "prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive", diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index ef9f13a..1f872a0 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -8,7 +8,7 @@ import { accountTokenUrl, accountUrl, maybeWithAuth, topicUrl, withBasicAuth, - withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl + withBearerAuth, accountBillingSubscriptionUrl, accountBillingPortalUrl, accountBillingTiersUrl } from "./utils"; import session from "./Session"; import subscriptionManager from "./SubscriptionManager"; @@ -264,6 +264,20 @@ class AccountApi { this.triggerChange(); // Dangle! } + async billingTiers() { + const url = accountBillingTiersUrl(config.base_url); + console.log(`[AccountApi] Fetching billing tiers`); + const response = await fetch(url, { + headers: withBearerAuth({}, session.token()) + }); + if (response.status === 401 || response.status === 403) { + throw new UnauthorizedError(); + } else if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + return await response.json(); + } + async createBillingSubscription(tier) { console.log(`[AccountApi] Creating billing subscription with ${tier}`); return await this.upsertBillingSubscription("POST", tier) diff --git a/web/src/app/utils.js b/web/src/app/utils.js index b34af7e..5e121cc 100644 --- a/web/src/app/utils.js +++ b/web/src/app/utils.js @@ -28,6 +28,7 @@ export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reserva export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`; export const accountBillingSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/billing/subscription`; export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`; +export const accountBillingTiersUrl = (baseUrl) => `${baseUrl}/v1/account/billing/tiers`; export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandSecureUrl = (url) => `https://${url}`; diff --git a/web/src/components/App.js b/web/src/components/App.js index 331119a..970f111 100644 --- a/web/src/components/App.js +++ b/web/src/components/App.js @@ -69,8 +69,9 @@ const Layout = () => { const [sendDialogOpenMode, setSendDialogOpenMode] = useState(""); const users = useLiveQuery(() => userManager.all()); const subscriptions = useLiveQuery(() => subscriptionManager.all()); - const newNotificationsCount = subscriptions?.reduce((prev, cur) => prev + cur.new, 0) || 0; - const [selected] = (subscriptions || []).filter(s => { + const subscriptionsWithoutInternal = subscriptions?.filter(s => !s.internal); + const newNotificationsCount = subscriptionsWithoutInternal?.reduce((prev, cur) => prev + cur.new, 0) || 0; + const [selected] = (subscriptionsWithoutInternal || []).filter(s => { return (params.baseUrl && expandUrl(params.baseUrl).includes(s.baseUrl) && params.topic === s.topic) || (config.base_url === s.baseUrl && params.topic === s.topic) }); @@ -101,7 +102,7 @@ const Layout = () => { onMobileDrawerToggle={() => setMobileDrawerOpen(!mobileDrawerOpen)} /> { />
- +
{ const { t } = useTranslation(); const { account } = useContext(AccountContext); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); + const [tiers, setTiers] = useState(null); const [newTier, setNewTier] = useState(account?.tier?.code || null); const [errorText, setErrorText] = useState(""); - if (!account) { + useEffect(() => { + (async () => { + setTiers(await accountApi.billingTiers()); + })(); + }, []); + + if (!account || !tiers) { return <>; } @@ -69,22 +77,43 @@ const UpgradeDialog = (props) => { return ( - Change billing plan + {t("account_upgrade_dialog_title")}
- setNewTier(null)}/> - setNewTier("starter")}/> - setNewTier("pro")}/> - setNewTier("business")}/> + setNewTier(null)} + /> + {tiers.map(tier => + setNewTier(tier.code)} + /> + )}
{action === Action.CANCEL && {t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })} } + {action === Action.UPDATE && + + {t("account_upgrade_dialog_proration_info")} + + }
@@ -95,20 +124,42 @@ const UpgradeDialog = (props) => { }; const TierCard = (props) => { - const cardStyle = (props.selected) ? { - background: "#eee" - } : {}; + const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : {}; return ( - - - + + + + {props.selected && +
Selected
+ } {props.name} - - Lizards are a widespread group of squamate reptiles, with over 6,000 - species, ranging across all continents except Antarctica - + {props.features && + + {props.features} + + } + {props.price && + + {props.price} / month + + }