No more v.user races
This commit is contained in:
parent
e596834096
commit
92d563371c
5 changed files with 87 additions and 77 deletions
|
@ -39,7 +39,6 @@ import (
|
||||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||||
- HIGH Stripe payment methods
|
- HIGH Stripe payment methods
|
||||||
- MEDIUM: Test new token endpoints & never-expiring token
|
- MEDIUM: Test new token endpoints & never-expiring token
|
||||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
|
||||||
- MEDIUM: Test that anonymous user and user without tier are the same visitor
|
- MEDIUM: Test that anonymous user and user without tier are the same visitor
|
||||||
- MEDIUM: Make sure account endpoints make sense for admins
|
- MEDIUM: Make sure account endpoints make sense for admins
|
||||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||||
|
|
|
@ -19,11 +19,12 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
admin := v.user != nil && v.user.Role == user.RoleAdmin
|
u := v.User()
|
||||||
|
admin := u != nil && u.Role == user.RoleAdmin
|
||||||
if !admin {
|
if !admin {
|
||||||
if !s.config.EnableSignup {
|
if !s.config.EnableSignup {
|
||||||
return errHTTPBadRequestSignupNotEnabled
|
return errHTTPBadRequestSignupNotEnabled
|
||||||
} else if v.user != nil {
|
} else if u != nil {
|
||||||
return errHTTPUnauthorized // Cannot create account from user context
|
return errHTTPUnauthorized // Cannot create account from user context
|
||||||
}
|
}
|
||||||
if !v.AccountCreationAllowed() {
|
if !v.AccountCreationAllowed() {
|
||||||
|
@ -150,20 +151,21 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
|
||||||
} else if req.Password == "" {
|
} else if req.Password == "" {
|
||||||
return errHTTPBadRequest
|
return errHTTPBadRequest
|
||||||
}
|
}
|
||||||
if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
|
u := v.User()
|
||||||
|
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
|
||||||
return errHTTPBadRequestIncorrectPasswordConfirmation
|
return errHTTPBadRequestIncorrectPasswordConfirmation
|
||||||
}
|
}
|
||||||
if v.user.Billing.StripeSubscriptionID != "" {
|
if u.Billing.StripeSubscriptionID != "" {
|
||||||
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
|
||||||
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
|
if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
|
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
|
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), u.Name)
|
||||||
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
|
if err := s.userManager.MarkUserRemoved(u); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
|
@ -176,10 +178,11 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
|
||||||
} else if req.Password == "" || req.NewPassword == "" {
|
} else if req.Password == "" || req.NewPassword == "" {
|
||||||
return errHTTPBadRequest
|
return errHTTPBadRequest
|
||||||
}
|
}
|
||||||
if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
|
u := v.User()
|
||||||
|
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
|
||||||
return errHTTPBadRequestIncorrectPasswordConfirmation
|
return errHTTPBadRequestIncorrectPasswordConfirmation
|
||||||
}
|
}
|
||||||
if err := s.userManager.ChangePassword(v.user.Name, req.NewPassword); err != nil {
|
if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
|
@ -267,10 +270,11 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if v.user.Prefs == nil {
|
u := v.User()
|
||||||
v.user.Prefs = &user.Prefs{}
|
if u.Prefs == nil {
|
||||||
|
u.Prefs = &user.Prefs{}
|
||||||
}
|
}
|
||||||
prefs := v.user.Prefs
|
prefs := u.Prefs
|
||||||
if newPrefs.Language != nil {
|
if newPrefs.Language != nil {
|
||||||
prefs.Language = newPrefs.Language
|
prefs.Language = newPrefs.Language
|
||||||
}
|
}
|
||||||
|
@ -288,7 +292,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
|
||||||
prefs.Notification.MinPriority = newPrefs.Notification.MinPriority
|
prefs.Notification.MinPriority = newPrefs.Notification.MinPriority
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(u); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
|
@ -299,11 +303,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if v.user.Prefs == nil {
|
u := v.User()
|
||||||
v.user.Prefs = &user.Prefs{}
|
if u.Prefs == nil {
|
||||||
|
u.Prefs = &user.Prefs{}
|
||||||
}
|
}
|
||||||
newSubscription.ID = "" // Client cannot set ID
|
newSubscription.ID = "" // Client cannot set ID
|
||||||
for _, subscription := range v.user.Prefs.Subscriptions {
|
for _, subscription := range u.Prefs.Subscriptions {
|
||||||
if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic {
|
if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic {
|
||||||
newSubscription = subscription
|
newSubscription = subscription
|
||||||
break
|
break
|
||||||
|
@ -311,8 +316,8 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
|
||||||
}
|
}
|
||||||
if newSubscription.ID == "" {
|
if newSubscription.ID == "" {
|
||||||
newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||||
v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription)
|
u.Prefs.Subscriptions = append(u.Prefs.Subscriptions, newSubscription)
|
||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(u); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -329,11 +334,12 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil {
|
u := v.User()
|
||||||
|
if u.Prefs == nil || u.Prefs.Subscriptions == nil {
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
var subscription *user.Subscription
|
var subscription *user.Subscription
|
||||||
for _, sub := range v.user.Prefs.Subscriptions {
|
for _, sub := range u.Prefs.Subscriptions {
|
||||||
if sub.ID == subscriptionID {
|
if sub.ID == subscriptionID {
|
||||||
sub.DisplayName = updatedSubscription.DisplayName
|
sub.DisplayName = updatedSubscription.DisplayName
|
||||||
subscription = sub
|
subscription = sub
|
||||||
|
@ -343,7 +349,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
|
||||||
if subscription == nil {
|
if subscription == nil {
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(u); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, subscription)
|
return s.writeJSON(w, subscription)
|
||||||
|
@ -355,18 +361,19 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
|
||||||
return errHTTPInternalErrorInvalidPath
|
return errHTTPInternalErrorInvalidPath
|
||||||
}
|
}
|
||||||
subscriptionID := matches[1]
|
subscriptionID := matches[1]
|
||||||
if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil {
|
u := v.User()
|
||||||
|
if u.Prefs == nil || u.Prefs.Subscriptions == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
newSubscriptions := make([]*user.Subscription, 0)
|
newSubscriptions := make([]*user.Subscription, 0)
|
||||||
for _, subscription := range v.user.Prefs.Subscriptions {
|
for _, subscription := range u.Prefs.Subscriptions {
|
||||||
if subscription.ID != subscriptionID {
|
if subscription.ID != subscriptionID {
|
||||||
newSubscriptions = append(newSubscriptions, subscription)
|
newSubscriptions = append(newSubscriptions, subscription)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) {
|
if len(newSubscriptions) < len(u.Prefs.Subscriptions) {
|
||||||
v.user.Prefs.Subscriptions = newSubscriptions
|
u.Prefs.Subscriptions = newSubscriptions
|
||||||
if err := s.userManager.ChangeSettings(v.user); err != nil {
|
if err := s.userManager.ChangeSettings(u); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -374,7 +381,8 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user != nil && v.user.Role == user.RoleAdmin {
|
u := v.User()
|
||||||
|
if u != nil && u.Role == user.RoleAdmin {
|
||||||
return errHTTPBadRequestMakesNoSenseForAdmin
|
return errHTTPBadRequestMakesNoSenseForAdmin
|
||||||
}
|
}
|
||||||
req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false)
|
req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false)
|
||||||
|
@ -388,27 +396,27 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errHTTPBadRequestPermissionInvalid
|
return errHTTPBadRequestPermissionInvalid
|
||||||
}
|
}
|
||||||
if v.user.Tier == nil {
|
if u.Tier == nil {
|
||||||
return errHTTPUnauthorized
|
return errHTTPUnauthorized
|
||||||
}
|
}
|
||||||
// CHeck if we are allowed to reserve this topic
|
// CHeck if we are allowed to reserve this topic
|
||||||
if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
|
if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil {
|
||||||
return errHTTPConflictTopicReserved
|
return errHTTPConflictTopicReserved
|
||||||
}
|
}
|
||||||
hasReservation, err := s.userManager.HasReservation(v.user.Name, req.Topic)
|
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !hasReservation {
|
if !hasReservation {
|
||||||
reservations, err := s.userManager.ReservationsCount(v.user.Name)
|
reservations, err := s.userManager.ReservationsCount(u.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if reservations >= v.user.Tier.ReservationLimit {
|
} else if reservations >= u.Tier.ReservationLimit {
|
||||||
return errHTTPTooManyRequestsLimitReservations
|
return errHTTPTooManyRequestsLimitReservations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Actually add the reservation
|
// Actually add the reservation
|
||||||
if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
|
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Kill existing subscribers
|
// Kill existing subscribers
|
||||||
|
@ -416,7 +424,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.CancelSubscribers(v.user.ID)
|
t.CancelSubscribers(u.ID)
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -429,13 +437,14 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||||
if !topicRegex.MatchString(topic) {
|
if !topicRegex.MatchString(topic) {
|
||||||
return errHTTPBadRequestTopicInvalid
|
return errHTTPBadRequestTopicInvalid
|
||||||
}
|
}
|
||||||
authorized, err := s.userManager.HasReservation(v.user.Name, topic)
|
u := v.User()
|
||||||
|
authorized, err := s.userManager.HasReservation(u.Name, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if !authorized {
|
} else if !authorized {
|
||||||
return errHTTPUnauthorized
|
return errHTTPUnauthorized
|
||||||
}
|
}
|
||||||
if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil {
|
if err := s.userManager.RemoveReservations(u.Name, topic); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
|
@ -465,12 +474,23 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *u
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic
|
||||||
|
func (s *Server) publishSyncEventAsync(v *visitor) {
|
||||||
|
go func() {
|
||||||
|
if err := s.publishSyncEvent(v); err != nil {
|
||||||
|
log.Trace("%s Error publishing to user's sync topic: %s", v.String(), err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// publishSyncEvent publishes a sync message to the user's sync topic
|
||||||
func (s *Server) publishSyncEvent(v *visitor) error {
|
func (s *Server) publishSyncEvent(v *visitor) error {
|
||||||
if v.user == nil || v.user.SyncTopic == "" {
|
u := v.User()
|
||||||
|
if u == nil || u.SyncTopic == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
|
log.Trace("Publishing sync event to user %s's sync topic %s", u.Name, u.SyncTopic)
|
||||||
syncTopic, err := s.topicFromID(v.user.SyncTopic)
|
syncTopic, err := s.topicFromID(u.SyncTopic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -484,15 +504,3 @@ func (s *Server) publishSyncEvent(v *visitor) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publishSyncEventAsync(v *visitor) {
|
|
||||||
go func() {
|
|
||||||
u := v.User()
|
|
||||||
if u == nil || u.SyncTopic == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := s.publishSyncEvent(v); err != nil {
|
|
||||||
log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ func (s *Server) ensureUserManager(next handleFunc) handleFunc {
|
||||||
|
|
||||||
func (s *Server) ensureUser(next handleFunc) handleFunc {
|
func (s *Server) ensureUser(next handleFunc) handleFunc {
|
||||||
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user == nil {
|
if v.User() == nil {
|
||||||
return errHTTPUnauthorized
|
return errHTTPUnauthorized
|
||||||
}
|
}
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
|
@ -42,7 +42,7 @@ func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
|
||||||
|
|
||||||
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
|
func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
|
||||||
return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user.Billing.StripeCustomerID == "" {
|
if v.User().Billing.StripeCustomerID == "" {
|
||||||
return errHTTPBadRequestNotAPaidUser
|
return errHTTPBadRequestNotAPaidUser
|
||||||
}
|
}
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
|
@ -51,9 +51,6 @@ func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
|
||||||
|
|
||||||
func (s *Server) withAccountSync(next handleFunc) handleFunc {
|
func (s *Server) withAccountSync(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user == nil {
|
|
||||||
return next(w, r, v)
|
|
||||||
}
|
|
||||||
err := next(w, r, v)
|
err := next(w, r, v)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.publishSyncEventAsync(v)
|
s.publishSyncEventAsync(v)
|
||||||
|
|
|
@ -54,7 +54,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
|
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
|
||||||
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
|
// in the UI. Note that this endpoint does NOT have a user context (no u!).
|
||||||
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
tiers, err := s.userManager.Tiers()
|
tiers, err := s.userManager.Tiers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -107,7 +107,8 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
||||||
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
||||||
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
|
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
|
||||||
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user.Billing.StripeSubscriptionID != "" {
|
u := v.User()
|
||||||
|
if u.Billing.StripeSubscriptionID != "" {
|
||||||
return errHTTPBadRequestBillingSubscriptionExists
|
return errHTTPBadRequestBillingSubscriptionExists
|
||||||
}
|
}
|
||||||
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
|
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
|
||||||
|
@ -122,9 +123,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
}
|
}
|
||||||
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
||||||
var stripeCustomerID *string
|
var stripeCustomerID *string
|
||||||
if v.user.Billing.StripeCustomerID != "" {
|
if u.Billing.StripeCustomerID != "" {
|
||||||
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
stripeCustomerID = &u.Billing.StripeCustomerID
|
||||||
stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
|
stripeCustomer, err := s.stripe.GetCustomer(u.Billing.StripeCustomerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
|
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
|
||||||
|
@ -134,7 +135,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
|
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
|
||||||
params := &stripe.CheckoutSessionParams{
|
params := &stripe.CheckoutSessionParams{
|
||||||
Customer: stripeCustomerID, // A user may have previously deleted their subscription
|
Customer: stripeCustomerID, // A user may have previously deleted their subscription
|
||||||
ClientReferenceID: &v.user.ID,
|
ClientReferenceID: &u.ID,
|
||||||
SuccessURL: &successURL,
|
SuccessURL: &successURL,
|
||||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||||
AllowPromotionCodes: stripe.Bool(true),
|
AllowPromotionCodes: stripe.Bool(true),
|
||||||
|
@ -146,7 +147,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
},
|
},
|
||||||
Params: stripe.Params{
|
Params: stripe.Params{
|
||||||
Metadata: map[string]string{
|
Metadata: map[string]string{
|
||||||
"user_id": v.user.ID,
|
"user_id": u.ID,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -164,7 +165,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
|
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
|
||||||
// and only time we can map the local username with the Stripe customer ID.
|
// and only time we can map the local username with the Stripe customer ID.
|
||||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
// We don't have a v.user in this endpoint, only a userManager!
|
// We don't have v.User() in this endpoint, only a userManager!
|
||||||
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
||||||
if len(matches) != 2 {
|
if len(matches) != 2 {
|
||||||
return errHTTPInternalErrorInvalidPath
|
return errHTTPInternalErrorInvalidPath
|
||||||
|
@ -212,7 +213,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||||
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
|
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
|
||||||
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
|
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
|
||||||
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user.Billing.StripeSubscriptionID == "" {
|
u := v.User()
|
||||||
|
if u.Billing.StripeSubscriptionID == "" {
|
||||||
return errNoBillingSubscription
|
return errNoBillingSubscription
|
||||||
}
|
}
|
||||||
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
|
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
|
||||||
|
@ -223,8 +225,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
|
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
|
||||||
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
|
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -248,12 +250,13 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||||
// and cancelling the Stripe subscription entirely
|
// and cancelling the Stripe subscription entirely
|
||||||
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
u := v.User()
|
||||||
if v.user.Billing.StripeSubscriptionID != "" {
|
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
|
||||||
|
if u.Billing.StripeSubscriptionID != "" {
|
||||||
params := &stripe.SubscriptionParams{
|
params := &stripe.SubscriptionParams{
|
||||||
CancelAtPeriodEnd: stripe.Bool(true),
|
CancelAtPeriodEnd: stripe.Bool(true),
|
||||||
}
|
}
|
||||||
_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
|
_, err := s.stripe.UpdateSubscription(u.Billing.StripeSubscriptionID, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -264,12 +267,13 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
||||||
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
||||||
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
||||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if v.user.Billing.StripeCustomerID == "" {
|
u := v.User()
|
||||||
|
if u.Billing.StripeCustomerID == "" {
|
||||||
return errHTTPBadRequestNotAPaidUser
|
return errHTTPBadRequestNotAPaidUser
|
||||||
}
|
}
|
||||||
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
||||||
params := &stripe.BillingPortalSessionParams{
|
params := &stripe.BillingPortalSessionParams{
|
||||||
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
Customer: stripe.String(u.Billing.StripeCustomerID),
|
||||||
ReturnURL: stripe.String(s.config.BaseURL),
|
ReturnURL: stripe.String(s.config.BaseURL),
|
||||||
}
|
}
|
||||||
ps, err := s.stripe.NewPortalSession(params)
|
ps, err := s.stripe.NewPortalSession(params)
|
||||||
|
@ -284,8 +288,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
||||||
|
|
||||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||||
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
||||||
// visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
|
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
|
||||||
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||||
stripeSignature := r.Header.Get("Stripe-Signature")
|
stripeSignature := r.Header.Get("Stripe-Signature")
|
||||||
if stripeSignature == "" {
|
if stripeSignature == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
|
|
|
@ -30,6 +30,7 @@ const (
|
||||||
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
tokenMaxCount = 10 // Only keep this many tokens in the table per user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Default constants that may be overridden by configs
|
||||||
const (
|
const (
|
||||||
DefaultUserStatsQueueWriterInterval = 33 * time.Second
|
DefaultUserStatsQueueWriterInterval = 33 * time.Second
|
||||||
DefaultUserPasswordBcryptCost = 10
|
DefaultUserPasswordBcryptCost = 10
|
||||||
|
@ -1195,6 +1196,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database
|
||||||
func (a *Manager) Close() error {
|
func (a *Manager) Close() error {
|
||||||
return a.db.Close()
|
return a.db.Close()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue