Fix a bunch of FIXMEs
This commit is contained in:
parent
f945fb4cdd
commit
3bd6518309
15 changed files with 269 additions and 182 deletions
|
@ -19,6 +19,7 @@ const (
|
|||
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
||||
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
||||
DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
|
||||
DefaultStripePriceCacheDuration = time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
|
||||
)
|
||||
|
||||
// Defines all global and per-visitor limits
|
||||
|
@ -112,10 +113,12 @@ type Config struct {
|
|||
BehindProxy bool
|
||||
StripeSecretKey string
|
||||
StripeWebhookKey string
|
||||
StripePriceCacheDuration time.Duration
|
||||
EnableWeb bool
|
||||
EnableSignup bool // Enable creation of accounts via API and UI
|
||||
EnableLogin bool
|
||||
EnableReservations bool // Allow users with role "user" to own/reserve topics
|
||||
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
|
||||
Version string // injected by App
|
||||
}
|
||||
|
||||
|
@ -132,9 +135,11 @@ func NewConfig() *Config {
|
|||
FirebaseKeyFile: "",
|
||||
CacheFile: "",
|
||||
CacheDuration: DefaultCacheDuration,
|
||||
CacheStartupQueries: "",
|
||||
CacheBatchSize: 0,
|
||||
CacheBatchTimeout: 0,
|
||||
AuthFile: "",
|
||||
AuthStartupQueries: "",
|
||||
AuthDefault: user.NewPermission(true, true),
|
||||
AttachmentCacheDir: "",
|
||||
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
||||
|
@ -142,14 +147,24 @@ func NewConfig() *Config {
|
|||
AttachmentExpiryDuration: DefaultAttachmentExpiryDuration,
|
||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||
ManagerInterval: DefaultManagerInterval,
|
||||
MessageLimit: DefaultMessageLengthLimit,
|
||||
MinDelay: DefaultMinDelay,
|
||||
MaxDelay: DefaultMaxDelay,
|
||||
WebRootIsApp: false,
|
||||
DelayedSenderInterval: DefaultDelayedSenderInterval,
|
||||
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
||||
FirebasePollInterval: DefaultFirebasePollInterval,
|
||||
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
|
||||
UpstreamBaseURL: "",
|
||||
SMTPSenderAddr: "",
|
||||
SMTPSenderUser: "",
|
||||
SMTPSenderPass: "",
|
||||
SMTPSenderFrom: "",
|
||||
SMTPServerListen: "",
|
||||
SMTPServerDomain: "",
|
||||
SMTPServerAddrPrefix: "",
|
||||
MessageLimit: DefaultMessageLengthLimit,
|
||||
MinDelay: DefaultMinDelay,
|
||||
MaxDelay: DefaultMaxDelay,
|
||||
TotalTopicLimit: DefaultTotalTopicLimit,
|
||||
TotalAttachmentSizeLimit: 0,
|
||||
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
||||
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
|
||||
|
@ -162,7 +177,14 @@ func NewConfig() *Config {
|
|||
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
|
||||
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
||||
BehindProxy: false,
|
||||
StripeSecretKey: "",
|
||||
StripeWebhookKey: "",
|
||||
StripePriceCacheDuration: DefaultStripePriceCacheDuration,
|
||||
EnableWeb: true,
|
||||
EnableSignup: false,
|
||||
EnableLogin: false,
|
||||
EnableReservations: false,
|
||||
AccessControlAllowOrigin: "*",
|
||||
Version: "",
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,21 +39,18 @@ import (
|
|||
payments:
|
||||
- send dunning emails when overdue
|
||||
- payment methods
|
||||
- unmarshal to stripe.Subscription instead of gjson
|
||||
- delete subscription when account deleted
|
||||
- delete messages + reserved topics on ResetTier
|
||||
|
||||
- move v1/account/tiers to v1/tiers
|
||||
|
||||
Limits & rate limiting:
|
||||
users without tier: should the stats be persisted? are they meaningful?
|
||||
-> test that the visitor is based on the IP address!
|
||||
login/account endpoints
|
||||
when ResetStats() is run, reset messagesLimiter (and others)?
|
||||
update last_seen when API is accessed
|
||||
Make sure account endpoints make sense for admins
|
||||
|
||||
UI:
|
||||
- revert home page change
|
||||
- flicker of upgrade banner
|
||||
- JS constants
|
||||
Sync:
|
||||
|
@ -82,7 +79,7 @@ type Server struct {
|
|||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache
|
||||
fileCache *fileCache
|
||||
priceCache map[string]string // Stripe price ID -> formatted price
|
||||
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
|
||||
closeChan chan bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
@ -144,7 +141,8 @@ const (
|
|||
emptyMessageBody = "triggered" // Used if message body is empty
|
||||
newMessageBody = "New message" // Used in poll requests as generic message
|
||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||
encodingBase64 = "base64"
|
||||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||
jsonBodyBytesLimit = 16384
|
||||
)
|
||||
|
||||
// WebSocket constants
|
||||
|
@ -201,7 +199,7 @@ func New(conf *Config) (*Server, error) {
|
|||
topics: topics,
|
||||
userManager: userManager,
|
||||
visitors: make(map[string]*visitor),
|
||||
priceCache: make(map[string]string),
|
||||
priceCache: util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -454,22 +452,14 @@ func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor)
|
|||
}
|
||||
|
||||
func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
_, err := io.WriteString(w, `{"success":true}`+"\n")
|
||||
return err
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
response := &apiHealthResponse{
|
||||
Healthy: true,
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
|
@ -620,12 +610,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
if err := json.NewEncoder(w).Encode(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, m)
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -1175,8 +1160,8 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
|||
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1482,7 +1467,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
|
|||
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
|
||||
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit)
|
||||
m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1650,3 +1635,12 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
|||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
|
||||
}
|
||||
|
||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
jsonBodyBytesLimit = 4096
|
||||
subscriptionIDLength = 16
|
||||
createdByAPI = "api"
|
||||
syncTopicAccountSyncEvent = "sync"
|
||||
|
@ -38,9 +37,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
|
@ -118,21 +115,14 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
response.Username = user.Everyone
|
||||
response.Role = string(user.RoleAnonymous)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -143,9 +133,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
|
|||
if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
|
@ -154,16 +142,11 @@ func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, _ *http.Request,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
response := &apiAccountTokenResponse{
|
||||
Token: token.Value,
|
||||
Expires: token.Expires.Unix(),
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
|
@ -177,16 +160,11 @@ func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, _ *http.Request
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
response := &apiAccountTokenResponse{
|
||||
Token: token.Value,
|
||||
Expires: token.Expires.Unix(),
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
|
@ -197,8 +175,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, _ *http.Request
|
|||
if err := s.userManager.RemoveToken(v.user); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -230,9 +207,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
|
|||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -257,12 +232,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
|
|||
return err
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(newSubscription); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, newSubscription)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -292,12 +262,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
|
|||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(subscription); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, subscription)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -321,9 +286,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
|
|||
return err
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -366,9 +329,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -392,9 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
|||
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) publishSyncEvent(v *visitor) error {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -11,19 +12,15 @@ import (
|
|||
"github.com/stripe/stripe-go/v74/price"
|
||||
"github.com/stripe/stripe-go/v74/subscription"
|
||||
"github.com/stripe/stripe-go/v74/webhook"
|
||||
"github.com/tidwall/gjson"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
stripeBodyBytesLimit = 16384
|
||||
)
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
|
@ -52,23 +49,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
},
|
||||
},
|
||||
}
|
||||
prices, err := s.priceCache.Value()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, tier := range tiers {
|
||||
if tier.StripePriceID == "" {
|
||||
priceStr, ok := prices[tier.StripePriceID]
|
||||
if tier.StripePriceID == "" || !ok {
|
||||
continue
|
||||
}
|
||||
priceStr, ok := s.priceCache[tier.StripePriceID]
|
||||
if !ok {
|
||||
p, err := price.Get(tier.StripePriceID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p.UnitAmount%100 == 0 {
|
||||
priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||
} else {
|
||||
priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||
}
|
||||
s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
|
||||
}
|
||||
response = append(response, &apiAccountBillingTier{
|
||||
Code: tier.Code,
|
||||
Name: tier.Name,
|
||||
|
@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
},
|
||||
})
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
||||
|
@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
response := &apiAccountBillingSubscriptionCreateResponse{
|
||||
RedirectURL: sess.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
|
@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||
|
@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|||
return err
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
|
@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|||
response := &apiAccountBillingPortalRedirectResponse{
|
||||
RedirectURL: ps.URL,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||
|
@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
if stripeSignature == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
|
||||
body, err := util.Peek(r.Body, jsonBodyBytesLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if body.LimitReached {
|
||||
|
@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||
subscriptionID := gjson.GetBytes(event, "id")
|
||||
customerID := gjson.GetBytes(event, "customer")
|
||||
status := gjson.GetBytes(event, "status")
|
||||
currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
|
||||
cancelAt := gjson.GetBytes(event, "cancel_at")
|
||||
priceID := gjson.GetBytes(event, "items.data.0.price.id")
|
||||
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
||||
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tier, err := s.userManager.TierByStripePrice(priceID.String())
|
||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||
customerID := gjson.GetBytes(event, "customer")
|
||||
if !customerID.Exists() {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.Customer == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
|
||||
u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
||||
// in memory, and ultimately for the web app to display the price table.
|
||||
func fetchStripePrices() (map[string]string, error) {
|
||||
log.Debug("Caching prices from Stripe API")
|
||||
prices := make(map[string]string)
|
||||
iter := price.List(&stripe.PriceListParams{
|
||||
Active: stripe.Bool(true),
|
||||
})
|
||||
for iter.Next() {
|
||||
p := iter.Price()
|
||||
if p.UnitAmount%100 == 0 {
|
||||
prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||
} else {
|
||||
prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||
}
|
||||
log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
||||
}
|
||||
if iter.Err() != nil {
|
||||
log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
||||
return nil, iter.Err()
|
||||
}
|
||||
return prices, nil
|
||||
}
|
||||
|
|
|
@ -1463,7 +1463,7 @@ func TestServer_PublishAttachmentBandwidthLimit(t *testing.T) {
|
|||
msg := toMessage(t, response.Body.String())
|
||||
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
|
||||
|
||||
// Get it 4 times successfully
|
||||
// Value it 4 times successfully
|
||||
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
||||
for i := 1; i <= 4; i++ { // 4 successful downloads
|
||||
response = request(t, s, "GET", path, "", nil)
|
||||
|
|
|
@ -336,3 +336,22 @@ func newSuccessResponse() *apiSuccessResponse {
|
|||
Success: true,
|
||||
}
|
||||
}
|
||||
|
||||
type apiStripeSubscriptionUpdatedEvent struct {
|
||||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
Status string `json:"status"`
|
||||
CurrentPeriodEnd int64 `json:"current_period_end"`
|
||||
CancelAt int64 `json:"cancel_at"`
|
||||
Items *struct {
|
||||
Data []*struct {
|
||||
Price *struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"price"`
|
||||
} `json:"data"`
|
||||
} `json:"items"`
|
||||
}
|
||||
|
||||
type apiStripeSubscriptionDeletedEvent struct {
|
||||
Customer string `json:"customer"`
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue