Fix a bunch of FIXMEs
This commit is contained in:
parent
f945fb4cdd
commit
3bd6518309
15 changed files with 269 additions and 182 deletions
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -11,19 +12,15 @@ import (
|
|||
"github.com/stripe/stripe-go/v74/price"
|
||||
"github.com/stripe/stripe-go/v74/subscription"
|
||||
"github.com/stripe/stripe-go/v74/webhook"
|
||||
"github.com/tidwall/gjson"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
stripeBodyBytesLimit = 16384
|
||||
)
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
|
@ -52,23 +49,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
},
|
||||
},
|
||||
}
|
||||
prices, err := s.priceCache.Value()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, tier := range tiers {
|
||||
if tier.StripePriceID == "" {
|
||||
priceStr, ok := prices[tier.StripePriceID]
|
||||
if tier.StripePriceID == "" || !ok {
|
||||
continue
|
||||
}
|
||||
priceStr, ok := s.priceCache[tier.StripePriceID]
|
||||
if !ok {
|
||||
p, err := price.Get(tier.StripePriceID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p.UnitAmount%100 == 0 {
|
||||
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||
} else {
|
||||
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||
}
|
||||
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
|
||||
}
|
||||
response = append(response, &apiAccountBillingTier{
|
||||
Code: tier.Code,
|
||||
Name: tier.Name,
|
||||
|
@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
},
|
||||
})
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
||||
|
@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
response := &apiAccountBillingSubscriptionCreateResponse{
|
||||
RedirectURL: sess.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
|
@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||
|
@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|||
return err
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|||
response := &apiAccountBillingPortalRedirectResponse{
|
||||
RedirectURL: ps.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||
|
@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
if stripeSignature == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
|
||||
body, err := util.Peek(r.Body, jsonBodyBytesLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if body.LimitReached {
|
||||
|
@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||
subscriptionID := gjson.GetBytes(event, "id")
|
||||
customerID := gjson.GetBytes(event, "customer")
|
||||
status := gjson.GetBytes(event, "status")
|
||||
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
|
||||
cancelAt := gjson.GetBytes(event, "cancel_at")
|
||||
priceID := gjson.GetBytes(event, "items.data.0.price.id")
|
||||
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
||||
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tier, err := s.userManager.TierByStripePrice(priceID.String())
|
||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||
customerID := gjson.GetBytes(event, "customer")
|
||||
if !customerID.Exists() {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.Customer == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
|
||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
||||
// in memory, and ultimately for the web app to display the price table.
|
||||
func fetchStripePrices() (map[string]string, error) {
|
||||
log.Debug("Caching prices from Stripe API")
|
||||
prices := make(map[string]string)
|
||||
iter := price.List(&stripe.PriceListParams{
|
||||
Active: stripe.Bool(true),
|
||||
})
|
||||
for iter.Next() {
|
||||
p := iter.Price()
|
||||
if p.UnitAmount%100 == 0 {
|
||||
prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||
} else {
|
||||
prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||
}
|
||||
log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
||||
}
|
||||
if iter.Err() != nil {
|
||||
log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
||||
return nil, iter.Err()
|
||||
}
|
||||
return prices, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue