Fix a bunch of FIXMEs

This commit is contained in:
binwiederhier 2023-01-18 15:50:06 -05:00
parent f945fb4cdd
commit 3bd6518309
15 changed files with 269 additions and 182 deletions

View file

@ -1,6 +1,7 @@
package server
import (
"bytes"
"encoding/json"
"errors"
"fmt"
@ -11,19 +12,15 @@ import (
"github.com/stripe/stripe-go/v74/price"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/netip"
"time"
)
const (
stripeBodyBytesLimit = 16384
)
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
@ -52,23 +49,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
},
},
}
prices, err := s.priceCache.Value()
if err != nil {
return err
}
for _, tier := range tiers {
if tier.StripePriceID == "" {
priceStr, ok := prices[tier.StripePriceID]
if tier.StripePriceID == "" || !ok {
continue
}
priceStr, ok := s.priceCache[tier.StripePriceID]
if !ok {
p, err := price.Get(tier.StripePriceID, nil)
if err != nil {
return err
}
if p.UnitAmount%100 == 0 {
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
},
})
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
return s.writeJSON(w, response)
}
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
response := &apiAccountBillingSubscriptionCreateResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
return s.writeJSON(w, response)
}
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
return err
}
return nil
return s.writeJSON(w, newSuccessResponse())
}
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
return err
}
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
return err
}
return nil
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
response := &apiAccountBillingPortalRedirectResponse{
RedirectURL: ps.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
return s.writeJSON(w, response)
}
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
if stripeSignature == "" {
return errHTTPBadRequestBillingRequestInvalid
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
body, err := util.Peek(r.Body, jsonBodyBytesLimit)
if err != nil {
return err
} else if body.LimitReached {
@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
subscriptionID := gjson.GetBytes(event, "id")
customerID := gjson.GetBytes(event, "customer")
status := gjson.GetBytes(event, "status")
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
cancelAt := gjson.GetBytes(event, "cancel_at")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
u, err := s.userManager.UserByStripeCustomer(customerID.String())
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
if err != nil {
return err
}
tier, err := s.userManager.TierByStripePrice(priceID.String())
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
customerID := gjson.GetBytes(event, "customer")
if !customerID.Exists() {
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.Customer == "" {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
u, err := s.userManager.UserByStripeCustomer(customerID.String())
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
}
return nil
}
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
// in memory, and ultimately for the web app to display the price table.
func fetchStripePrices() (map[string]string, error) {
log.Debug("Caching prices from Stripe API")
prices := make(map[string]string)
iter := price.List(&stripe.PriceListParams{
Active: stripe.Bool(true),
})
for iter.Next() {
p := iter.Price()
if p.UnitAmount%100 == 0 {
prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
}
if iter.Err() != nil {
log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
return nil, iter.Err()
}
return prices, nil
}