From 92d563371c32b1d894028bde81e8afa6d765f2f2 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 28 Jan 2023 20:43:06 -0500 Subject: [PATCH] No more v.user races --- server/server.go | 1 - server/server_account.go | 114 +++++++++++++++++++----------------- server/server_middleware.go | 7 +-- server/server_payments.go | 40 +++++++------ user/manager.go | 2 + 5 files changed, 87 insertions(+), 77 deletions(-) diff --git a/server/server.go b/server/server.go index c2357d5..7495f23 100644 --- a/server/server.go +++ b/server/server.go @@ -39,7 +39,6 @@ import ( - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Stripe payment methods - MEDIUM: Test new token endpoints & never-expiring token -- MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Test that anonymous user and user without tier are the same visitor - MEDIUM: Make sure account endpoints make sense for admins - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) diff --git a/server/server_account.go b/server/server_account.go index 0c5b25b..1fcfabe 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -19,11 +19,12 @@ const ( ) func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - admin := v.user != nil && v.user.Role == user.RoleAdmin + u := v.User() + admin := u != nil && u.Role == user.RoleAdmin if !admin { if !s.config.EnableSignup { return errHTTPBadRequestSignupNotEnabled - } else if v.user != nil { + } else if u != nil { return errHTTPUnauthorized // Cannot create account from user context } if !v.AccountCreationAllowed() { @@ -150,20 +151,21 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * } else if req.Password == "" { return errHTTPBadRequest } - if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil { + u := v.User() + if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil { return errHTTPBadRequestIncorrectPasswordConfirmation } - if v.user.Billing.StripeSubscriptionID != "" { - log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) - if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { + if u.Billing.StripeSubscriptionID != "" { + log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) + if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil { return err } } - if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil { + if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil { return err } - log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name) - if err := s.userManager.MarkUserRemoved(v.user); err != nil { + log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), u.Name) + if err := s.userManager.MarkUserRemoved(u); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -176,10 +178,11 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ } else if req.Password == "" || req.NewPassword == "" { return errHTTPBadRequest } - if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil { + u := v.User() + if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil { return errHTTPBadRequestIncorrectPasswordConfirmation } - if err := s.userManager.ChangePassword(v.user.Name, req.NewPassword); err != nil { + if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -267,10 +270,11 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ if err != nil { return err } - if v.user.Prefs == nil { - v.user.Prefs = &user.Prefs{} + u := v.User() + if u.Prefs == nil { + u.Prefs = &user.Prefs{} } - prefs := v.user.Prefs + prefs := u.Prefs if newPrefs.Language != nil { prefs.Language = newPrefs.Language } @@ -288,7 +292,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ prefs.Notification.MinPriority = newPrefs.Notification.MinPriority } } - if err := s.userManager.ChangeSettings(v.user); err != nil { + if err := s.userManager.ChangeSettings(u); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -299,11 +303,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req if err != nil { return err } - if v.user.Prefs == nil { - v.user.Prefs = &user.Prefs{} + u := v.User() + if u.Prefs == nil { + u.Prefs = &user.Prefs{} } newSubscription.ID = "" // Client cannot set ID - for _, subscription := range v.user.Prefs.Subscriptions { + for _, subscription := range u.Prefs.Subscriptions { if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic { newSubscription = subscription break @@ -311,8 +316,8 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req } if newSubscription.ID == "" { newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) - v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription) - if err := s.userManager.ChangeSettings(v.user); err != nil { + u.Prefs.Subscriptions = append(u.Prefs.Subscriptions, newSubscription) + if err := s.userManager.ChangeSettings(u); err != nil { return err } } @@ -329,11 +334,12 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. if err != nil { return err } - if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { + u := v.User() + if u.Prefs == nil || u.Prefs.Subscriptions == nil { return errHTTPNotFound } var subscription *user.Subscription - for _, sub := range v.user.Prefs.Subscriptions { + for _, sub := range u.Prefs.Subscriptions { if sub.ID == subscriptionID { sub.DisplayName = updatedSubscription.DisplayName subscription = sub @@ -343,7 +349,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. if subscription == nil { return errHTTPNotFound } - if err := s.userManager.ChangeSettings(v.user); err != nil { + if err := s.userManager.ChangeSettings(u); err != nil { return err } return s.writeJSON(w, subscription) @@ -355,18 +361,19 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. return errHTTPInternalErrorInvalidPath } subscriptionID := matches[1] - if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { + u := v.User() + if u.Prefs == nil || u.Prefs.Subscriptions == nil { return nil } newSubscriptions := make([]*user.Subscription, 0) - for _, subscription := range v.user.Prefs.Subscriptions { + for _, subscription := range u.Prefs.Subscriptions { if subscription.ID != subscriptionID { newSubscriptions = append(newSubscriptions, subscription) } } - if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) { - v.user.Prefs.Subscriptions = newSubscriptions - if err := s.userManager.ChangeSettings(v.user); err != nil { + if len(newSubscriptions) < len(u.Prefs.Subscriptions) { + u.Prefs.Subscriptions = newSubscriptions + if err := s.userManager.ChangeSettings(u); err != nil { return err } } @@ -374,7 +381,8 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. } func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user != nil && v.user.Role == user.RoleAdmin { + u := v.User() + if u != nil && u.Role == user.RoleAdmin { return errHTTPBadRequestMakesNoSenseForAdmin } req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false) @@ -388,27 +396,27 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ if err != nil { return errHTTPBadRequestPermissionInvalid } - if v.user.Tier == nil { + if u.Tier == nil { return errHTTPUnauthorized } // CHeck if we are allowed to reserve this topic - if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { + if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil { return errHTTPConflictTopicReserved } - hasReservation, err := s.userManager.HasReservation(v.user.Name, req.Topic) + hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic) if err != nil { return err } if !hasReservation { - reservations, err := s.userManager.ReservationsCount(v.user.Name) + reservations, err := s.userManager.ReservationsCount(u.Name) if err != nil { return err - } else if reservations >= v.user.Tier.ReservationLimit { + } else if reservations >= u.Tier.ReservationLimit { return errHTTPTooManyRequestsLimitReservations } } // Actually add the reservation - if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil { + if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil { return err } // Kill existing subscribers @@ -416,7 +424,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ if err != nil { return err } - t.CancelSubscribers(v.user.ID) + t.CancelSubscribers(u.ID) return s.writeJSON(w, newSuccessResponse()) } @@ -429,13 +437,14 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R if !topicRegex.MatchString(topic) { return errHTTPBadRequestTopicInvalid } - authorized, err := s.userManager.HasReservation(v.user.Name, topic) + u := v.User() + authorized, err := s.userManager.HasReservation(u.Name, topic) if err != nil { return err } else if !authorized { return errHTTPUnauthorized } - if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil { + if err := s.userManager.RemoveReservations(u.Name, topic); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -465,12 +474,23 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *u return nil } +// publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic +func (s *Server) publishSyncEventAsync(v *visitor) { + go func() { + if err := s.publishSyncEvent(v); err != nil { + log.Trace("%s Error publishing to user's sync topic: %s", v.String(), err.Error()) + } + }() +} + +// publishSyncEvent publishes a sync message to the user's sync topic func (s *Server) publishSyncEvent(v *visitor) error { - if v.user == nil || v.user.SyncTopic == "" { + u := v.User() + if u == nil || u.SyncTopic == "" { return nil } - log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic) - syncTopic, err := s.topicFromID(v.user.SyncTopic) + log.Trace("Publishing sync event to user %s's sync topic %s", u.Name, u.SyncTopic) + syncTopic, err := s.topicFromID(u.SyncTopic) if err != nil { return err } @@ -484,15 +504,3 @@ func (s *Server) publishSyncEvent(v *visitor) error { } return nil } - -func (s *Server) publishSyncEventAsync(v *visitor) { - go func() { - u := v.User() - if u == nil || u.SyncTopic == "" { - return - } - if err := s.publishSyncEvent(v); err != nil { - log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error()) - } - }() -} diff --git a/server/server_middleware.go b/server/server_middleware.go index 544bb59..a3d945b 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -24,7 +24,7 @@ func (s *Server) ensureUserManager(next handleFunc) handleFunc { func (s *Server) ensureUser(next handleFunc) handleFunc { return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user == nil { + if v.User() == nil { return errHTTPUnauthorized } return next(w, r, v) @@ -42,7 +42,7 @@ func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc { func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeCustomerID == "" { + if v.User().Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser } return next(w, r, v) @@ -51,9 +51,6 @@ func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { func (s *Server) withAccountSync(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user == nil { - return next(w, r, v) - } err := next(w, r, v) if err == nil { s.publishSyncEventAsync(v) diff --git a/server/server_payments.go b/server/server_payments.go index 7662897..bc616e7 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -54,7 +54,7 @@ var ( ) // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog -// in the UI. Note that this endpoint does NOT have a user context (no v.user!). +// in the UI. Note that this endpoint does NOT have a user context (no u!). func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { tiers, err := s.userManager.Tiers() if err != nil { @@ -107,7 +107,8 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier // will be updated by a subsequent webhook from Stripe, once the subscription becomes active. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeSubscriptionID != "" { + u := v.User() + if u.Billing.StripeSubscriptionID != "" { return errHTTPBadRequestBillingSubscriptionExists } req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) @@ -122,9 +123,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r } log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r)) var stripeCustomerID *string - if v.user.Billing.StripeCustomerID != "" { - stripeCustomerID = &v.user.Billing.StripeCustomerID - stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID) + if u.Billing.StripeCustomerID != "" { + stripeCustomerID = &u.Billing.StripeCustomerID + stripeCustomer, err := s.stripe.GetCustomer(u.Billing.StripeCustomerID) if err != nil { return err } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 { @@ -134,7 +135,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate params := &stripe.CheckoutSessionParams{ Customer: stripeCustomerID, // A user may have previously deleted their subscription - ClientReferenceID: &v.user.ID, + ClientReferenceID: &u.ID, SuccessURL: &successURL, Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), AllowPromotionCodes: stripe.Bool(true), @@ -146,7 +147,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r }, Params: stripe.Params{ Metadata: map[string]string{ - "user_id": v.user.ID, + "user_id": u.ID, }, }, } @@ -164,7 +165,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first // and only time we can map the local username with the Stripe customer ID. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error { - // We don't have a v.user in this endpoint, only a userManager! + // We don't have v.User() in this endpoint, only a userManager! matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { return errHTTPInternalErrorInvalidPath @@ -212,7 +213,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates // a user's tier accordingly. This endpoint only works if there is an existing subscription. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeSubscriptionID == "" { + u := v.User() + if u.Billing.StripeSubscriptionID == "" { return errNoBillingSubscription } req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) @@ -223,8 +225,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } - log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID) - sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID) + log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID) + sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) if err != nil { return err } @@ -248,12 +250,13 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, // and cancelling the Stripe subscription entirely func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { - log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) - if v.user.Billing.StripeSubscriptionID != "" { + u := v.User() + log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) + if u.Billing.StripeSubscriptionID != "" { params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(true), } - _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params) + _, err := s.stripe.UpdateSubscription(u.Billing.StripeSubscriptionID, params) if err != nil { return err } @@ -264,12 +267,13 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user.Billing.StripeCustomerID == "" { + u := v.User() + if u.Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser } log.Info("%s Creating billing portal session", logHTTPPrefix(v, r)) params := &stripe.BillingPortalSessionParams{ - Customer: stripe.String(v.user.Billing.StripeCustomerID), + Customer: stripe.String(u.Billing.StripeCustomerID), ReturnURL: stripe.String(s.config.BaseURL), } ps, err := s.stripe.NewPortalSession(params) @@ -284,8 +288,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the -// visitor (v) in this endpoint is the Stripe API, so we don't have v.user available. -func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error { +// visitor (v) in this endpoint is the Stripe API, so we don't have u available. +func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error { stripeSignature := r.Header.Get("Stripe-Signature") if stripeSignature == "" { return errHTTPBadRequestBillingRequestInvalid diff --git a/user/manager.go b/user/manager.go index 0ed5794..7009d12 100644 --- a/user/manager.go +++ b/user/manager.go @@ -30,6 +30,7 @@ const ( tokenMaxCount = 10 // Only keep this many tokens in the table per user ) +// Default constants that may be overridden by configs const ( DefaultUserStatsQueueWriterInterval = 33 * time.Second DefaultUserPasswordBcryptCost = 10 @@ -1195,6 +1196,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { }, nil } +// Close closes the underlying database func (a *Manager) Close() error { return a.db.Close() }