diff --git a/cmd/user.go b/cmd/user.go index ff87018..9ca9f40 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -16,8 +16,7 @@ import ( ) const ( - tierReset = "-" - createdByCLI = "cli" + tierReset = "-" ) func init() { @@ -197,7 +196,7 @@ func execUserAdd(c *cli.Context) error { password = p } - if err := manager.AddUser(username, password, role, createdByCLI); err != nil { + if err := manager.AddUser(username, password, role); err != nil { return err } fmt.Fprintf(c.App.ErrWriter, "user %s added with role %s\n", username, role) diff --git a/server/server.go b/server/server.go index a098fc1..05524fd 100644 --- a/server/server.go +++ b/server/server.go @@ -39,12 +39,10 @@ TODO -- - Reservation: Kill existing subscribers when topic is reserved (deadcade) +- Rate limiting: Sensitive endpoints (account/login/change-password/...) +- Stripe: Add metadata to customer - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - Reservation (UI): Ask for confirmation when removing reservation (deadcade) -- Logging: Add detailed logging with username/customerID for all Stripe events (phil) -- Rate limiting: Sensitive endpoints (account/login/change-password/...) -- Stripe webhook: Do not respond wih error if user does not exist (after account deletion) -- Stripe: Add metadata to customer races: - v.user --> see publishSyncEventAsync() test @@ -53,7 +51,7 @@ payments: - reconciliation delete messages + reserved topics on ResetTier delete attachments in access.go -account deletion should delete messages and reservations and attachments + Limits & rate limiting: rate limiting weirdness. wth is going on? @@ -1256,11 +1254,14 @@ func (s *Server) execManager() { s.mu.Unlock() log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) - // Delete expired user tokens + // Delete expired user tokens and users if s.userManager != nil { if err := s.userManager.RemoveExpiredTokens(); err != nil { log.Warn("Error expiring user tokens: %s", err.Error()) } + if err := s.userManager.RemoveDeletedUsers(); err != nil { + log.Warn("Error deleting soft-deleted users: %s", err.Error()) + } } // Delete expired attachments @@ -1283,7 +1284,7 @@ func (s *Server) execManager() { } } - // DeleteMessages message cache + // Prune messages log.Debug("Manager: Pruning messages") expiredMessageIDs, err := s.messageCache.MessagesExpired() if err != nil { diff --git a/server/server_account.go b/server/server_account.go index 1b24383..72a7826 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -11,7 +11,6 @@ import ( const ( subscriptionIDLength = 16 - createdByAPI = "api" syncTopicAccountSyncEvent = "sync" ) @@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if v.accountLimiter != nil && !v.accountLimiter.Allow() { return errHTTPTooManyRequestsLimitAccountCreation } - if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User + if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User return err } return s.writeJSON(w, newSuccessResponse()) @@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis return s.writeJSON(w, response) } -func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { +func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { if v.user.Billing.StripeSubscriptionID != "" { - log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID) + log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) if v.user.Billing.StripeSubscriptionID != "" { if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { return err } } - } else { - log.Info("Deleting user %s", v.user.Name) + if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil { + return err + } } - if err := s.userManager.RemoveUser(v.user.Name); err != nil { + log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name) + if err := s.userManager.MarkUserRemoved(v.user); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) diff --git a/server/server_account_test.go b/server/server_account_test.go index 4e4e452..dba1ad9 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) { conf.EnableSignup = true s := newTestServer(t, conf) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) - require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) { func TestAccount_ChangeSettings(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) user, _ := s.userManager.User("phil") token, _ := s.userManager.CreateToken(user) @@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) { func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { func TestAccount_ChangePassword(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) { func TestAccount_ExtendToken(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) { func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), // Not Bearer! @@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { func TestAccount_DeleteToken(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) { }) require.Equal(t, 200, rr.Code) + // Account was marked deleted rr = request(t, s, "GET", "/v1/account", "", map[string]string{ "Authorization": util.BasicAuth("phil", "mypass"), }) require.Equal(t, 401, rr.Code) + + // Cannot re-create account, since still exists + rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil) + require.Equal(t, 409, rr.Code) } func TestAccount_Delete_Not_Allowed(t *testing.T) { @@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) { conf := newTestConfigWithAuthFile(t) conf.EnableSignup = true s := newTestServer(t, conf) - require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin)) rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "adminpass"), diff --git a/server/server_payments.go b/server/server_payments.go index 40b961a..3c29dae 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -18,15 +18,10 @@ import ( "io" "net/http" "net/netip" + "strings" "time" ) -var ( - errNotAPaidTier = errors.New("tier does not have billing price identifier") - errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions") - errNoBillingSubscription = errors.New("user does not have an active billing subscription") -) - // Payments in ntfy are done via Stripe. // // Pretty much all payments related things are in this file. The following processes @@ -49,6 +44,16 @@ var ( // This is used to keep the local user database fields up to date. Stripe is the source of truth. // What Stripe says is mirrored and not questioned. +var ( + errNotAPaidTier = errors.New("tier does not have billing price identifier") + errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions") + errNoBillingSubscription = errors.New("user does not have an active billing subscription") +) + +var ( + retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second} +) + // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog // in the UI. Note that this endpoint does NOT have a user context (no v.user!). func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { @@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r } else if tier.StripePriceID == "" { return errNotAPaidTier } - log.Info("Stripe: No existing subscription, creating checkout flow") + log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r)) var stripeCustomerID *string if v.user.Billing.StripeCustomerID != "" { stripeCustomerID = &v.user.Billing.StripeCustomerID @@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r Quantity: stripe.Int64(1), }, }, - /*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{ - Enabled: stripe.Bool(true), - },*/ } sess, err := s.stripe.NewCheckoutSession(params) if err != nil { @@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r // handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first // and only time we can map the local username with the Stripe customer ID. -func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { +func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error { // We don't have a v.user in this endpoint, only a userManager! matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { @@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { + v.SetUser(u) + if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { return err } http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) @@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } - log.Info("Stripe: Changing tier and subscription to %s", tier.Code) + log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID) sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID) if err != nil { return err @@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, // and cancelling the Stripe subscription entirely func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { + log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) if v.user.Billing.StripeSubscriptionID != "" { params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(true), @@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, if v.user.Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser } + log.Info("%s Creating billing portal session", logHTTPPrefix(v, r)) params := &stripe.BillingPortalSessionParams{ Customer: stripe.String(v.user.Billing.StripeCustomerID), ReturnURL: stripe.String(s.config.BaseURL), @@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ } else if event.Data == nil || event.Data.Raw == nil { return errHTTPBadRequestBillingRequestInvalid } - - log.Info("Stripe: webhook event %s received", event.Type) switch event.Type { case "customer.subscription.updated": return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw) case "customer.subscription.deleted": return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw) default: + log.Warn("STRIPE Unhandled webhook event %s received", event.Type) return nil } } func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error { - r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event))) + ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event))) if err != nil { return err - } else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" { + } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" { return errHTTPBadRequestBillingRequestInvalid } - subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID - log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID) - u, err := s.userManager.UserByStripeCustomer(r.Customer) + subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID + log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID) + userFn := func() (*user.User, error) { + return s.userManager.UserByStripeCustomer(ev.Customer) + } + u, err := util.Retry[user.User](userFn, retryUserDelays...) if err != nil { return err } @@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil { + if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) @@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe } func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { - r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event))) + ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event))) if err != nil { return err - } else if r.Customer == "" { + } else if ev.Customer == "" { return errHTTPBadRequestBillingRequestInvalid } - log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer) - u, err := s.userManager.UserByStripeCustomer(r.Customer) + log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID)) + u, err := s.userManager.UserByStripeCustomer(ev.Customer) if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil { + if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } -func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { - // Remove excess reservations (if too many for tier), and mark associated messages deleted +// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier), +// and marks associated messages for the topics as deleted. This also eventually deletes attachments. +// The process relies on the manager to perform the actual deletions (see runManager). +func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error { reservations, err := s.userManager.Reservations(u.Name) if err != nil { return err + } else if int64(len(reservations)) <= reservationsLimit { + return nil } + topics := make([]string, 0) + for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- { + topics = append(topics, reservations[i].Topic) + } + log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", ")) + if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil { + return err + } + if err := s.messageCache.ExpireMessages(topics...); err != nil { + return err + } + return nil +} + +func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { reservationsLimit := visitorDefaultReservationsLimit if tier != nil { reservationsLimit = tier.ReservationsLimit } - if int64(len(reservations)) > reservationsLimit { - topics := make([]string, 0) - for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- { - topics = append(topics, reservations[i].Topic) - } - if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil { - return err - } - if err := s.messageCache.ExpireMessages(topics...); err != nil { - return err - } + if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil { + return err } - // Change or remove tier if tier == nil { if err := s.userManager.ResetTier(u.Name); err != nil { return err diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 4ee3d0e..d1008ef 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { Code: "pro", StripePriceID: "price_123", })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // Create subscription response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ @@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { Code: "pro", StripePriceID: "price_123", })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) u, err := s.userManager.User("phil") require.Nil(t, err) @@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { Code: "pro", StripePriceID: "price_123", })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) u, err := s.userManager.User("phil") require.Nil(t, err) @@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( AttachmentFileSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000, })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll)) diff --git a/server/server_test.go b/server/server_test.go index 0986380..303f198 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) { c := newTestConfigWithAuthFile(t) s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) { c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ @@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) { c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite)) @@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) { c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": util.BasicAuth("phil", "INVALID"), @@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) { c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic! response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ @@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) { c.AuthDefault = user.PermissionReadWrite // Open by default s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) @@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) { c.AuthDefault = user.PermissionDenyAll s := newTestServer(t, c) - require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test")) + require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin)) u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass")))) response := request(t, s, "GET", u, "", nil) @@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) { s := newTestServer(t, c) go s.runStatsResetter() - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite)) for i := 0; i < 5; i++ { @@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) { MessagesLimit: 5, MessagesExpiryDuration: -5 * time.Second, // Second, what a hack! })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) // Publish to reach message limit @@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) { AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: sevenDays, // 7 days })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) // Publish and make sure we can retrieve it @@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) { AttachmentTotalSizeLimit: 200_000, AttachmentExpiryDuration: 30 * time.Second, })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "test")) // Publish small file as anonymous diff --git a/server/types.go b/server/types.go index 3d651d3..38364e9 100644 --- a/server/types.go +++ b/server/types.go @@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct { } type apiStripeSubscriptionDeletedEvent struct { + ID string `json:"id"` Customer string `json:"customer"` } diff --git a/server/util.go b/server/util.go index c1eff1f..a9c4df5 100644 --- a/server/util.go +++ b/server/util.go @@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string { } func logMessagePrefix(v *visitor, m *message) string { - return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID) + return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID) } func logHTTPPrefix(v *visitor, r *http.Request) string { @@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string { if requestURI == "" { requestURI = r.URL.Path } - return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI) + return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI) +} + +func logStripePrefix(customerID, subscriptionID string) string { + if subscriptionID != "" { + return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID) + } + return fmt.Sprintf("%s STRIPE", customerID) } func logSMTPPrefix(state *smtp.ConnectionState) string { diff --git a/server/visitor.go b/server/visitor.go index b23f66e..77ed946 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,6 +2,7 @@ package server import ( "errors" + "fmt" "heckel.io/ntfy/user" "net/netip" "sync" @@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana } } +func (v *visitor) String() string { + v.mu.Lock() + defer v.mu.Unlock() + if v.user != nil && v.user.Billing.StripeCustomerID != "" { + return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID) + } else if v.user != nil { + return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID) + } + return v.ip.String() +} + func (v *visitor) RequestAllowed() error { if !v.requestLimiter.Allow() { return errVisitorLimitReached @@ -216,6 +228,12 @@ func (v *visitor) ResetStats() { } } +func (v *visitor) SetUser(u *user.User) { + v.mu.Lock() + defer v.mu.Unlock() + v.user = u +} + func (v *visitor) Limits() *visitorLimits { v.mu.Lock() defer v.mu.Unlock() diff --git a/user/manager.go b/user/manager.go index 5c3baf7..a46150e 100644 --- a/user/manager.go +++ b/user/manager.go @@ -25,6 +25,7 @@ const ( userPasswordBcryptCost = 10 userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost userStatsQueueWriterInterval = 33 * time.Second + userHardDeleteAfterDuration = 7 * 24 * time.Hour tokenPrefix = "tk_" tokenLength = 32 tokenMaxCount = 10 // Only keep this many tokens in the table per user @@ -57,7 +58,7 @@ const ( CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE TABLE IF NOT EXISTS user ( id TEXT PRIMARY KEY, - tier_id INT, + tier_id TEXT, user TEXT NOT NULL, pass TEXT NOT NULL, role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, @@ -70,8 +71,8 @@ const ( stripe_subscription_status TEXT, stripe_subscription_paid_until INT, stripe_subscription_cancel_at INT, - created_by TEXT NOT NULL, - created_at INT NOT NULL, + created INT NOT NULL, + deleted INT, FOREIGN KEY (tier_id) REFERENCES tier (id) ); CREATE UNIQUE INDEX idx_user ON user (user); @@ -98,8 +99,8 @@ const ( id INT PRIMARY KEY, version INT NOT NULL ); - INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) - VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH()) + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; ` createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` @@ -108,26 +109,26 @@ const ( ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id - WHERE user = ? + WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -141,8 +142,8 @@ const ( ` insertUserQuery = ` - INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES (?, ?, ?, ?, ?, ?) ` selectUsernamesQuery = ` SELECT user @@ -159,6 +160,8 @@ const ( updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?` updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?` updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` + updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?` + deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?` deleteUserQuery = `DELETE FROM user WHERE user = ?` upsertUserAccessQuery = ` @@ -214,7 +217,8 @@ const ( selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` - deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` + deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` + deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` deleteExcessTokensQuery = ` DELETE FROM user_token @@ -268,8 +272,8 @@ const ( ` migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` migrate1To2InsertUserNoTx = ` - INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) - SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ? + INSERT INTO user (id, user, pass, role, sync_topic, created) + SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? ` migrate1To2InsertFromOldTablesAndDropNoTx = ` INSERT INTO user_access (user_id, topic, read, write) @@ -320,9 +324,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, stats return manager, nil } -// Authenticate checks username and password and returns a User if correct. The method -// returns in constant-ish time, regardless of whether the user exists or the password is -// correct or incorrect. +// Authenticate checks username and password and returns a User if correct, and the user has not been +// marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or +// the password is correct or incorrect. func (a *Manager) Authenticate(username, password string) (*User, error) { if username == Everyone { return nil, ErrUnauthenticated @@ -332,9 +336,12 @@ func (a *Manager) Authenticate(username, password string) (*User, error) { log.Trace("authentication of user %s failed (1): %s", username, err.Error()) bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) return nil, ErrUnauthenticated - } - if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { - log.Trace("authentication of user %s failed (2): %s", username, err.Error()) + } else if user.Deleted { + log.Trace("authentication of user %s failed (2): user marked deleted", username) + bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) + return nil, ErrUnauthenticated + } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { + log.Trace("authentication of user %s failed (3): %s", username, err.Error()) return nil, ErrUnauthenticated } return user, nil @@ -415,7 +422,7 @@ func (a *Manager) RemoveToken(user *User) error { if user.Token == "" { return ErrUnauthorized } - if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil { + if _, err := a.db.Exec(deleteTokenQuery, user.ID, user.Token); err != nil { return err } return nil @@ -429,6 +436,14 @@ func (a *Manager) RemoveExpiredTokens() error { return nil } +// RemoveDeletedUsers deletes all users that have been marked deleted for +func (a *Manager) RemoveDeletedUsers() error { + if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil { + return err + } + return nil +} + // ChangeSettings persists the user settings func (a *Manager) ChangeSettings(user *User) error { prefs, err := json.Marshal(user.Prefs) @@ -533,7 +548,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error { } // AddUser adds a user with the given username, password and role -func (a *Manager) AddUser(username, password string, role Role, createdBy string) error { +func (a *Manager) AddUser(username, password string, role Role) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } @@ -543,7 +558,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string } userID := util.RandomStringPrefix(userIDPrefix, userIDLength) syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix() - if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil { + if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil { return err } return nil @@ -562,6 +577,29 @@ func (a *Manager) RemoveUser(username string) error { return nil } +// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents +// successful auth via Authenticate. A background process will delete the user at a later date. +func (a *Manager) MarkUserRemoved(user *User) error { + if !AllowedUsername(user.Name) { + return ErrInvalidArgument + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil { + return err + } + if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil { + return err + } + if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil { + return err + } + return tx.Commit() +} + // Users returns a list of users. It always also returns the Everyone user ("*"). func (a *Manager) Users() ([]*User, error) { rows, err := a.db.Query(selectUsernamesQuery) @@ -632,11 +670,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { var id, username, hash, role, prefs, syncTopic string var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var messages, emails int64 - var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 + var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -659,6 +697,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero }, + Deleted: deleted.Valid, } if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { return nil, err diff --git a/user/manager_test.go b/user/manager_test.go index 35d8cac..0fedc1c 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -13,8 +13,8 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh func TestManager_FullScenario_Default_DenyAll(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) @@ -92,20 +92,20 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) { func TestManager_AddUser_Invalid(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, "unit-test")) - require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", "unit-test")) + require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin)) + require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role")) } func TestManager_AddUser_Timing(t *testing.T) { a := newTestManager(t, PermissionDenyAll) start := time.Now().UnixMilli() - require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test")) + require.Nil(t, a.AddUser("user", "pass", RoleAdmin)) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) } func TestManager_Authenticate_Timing(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test")) + require.Nil(t, a.AddUser("user", "pass", RoleAdmin)) // Timing a correct attempt start := time.Now().UnixMilli() @@ -126,10 +126,60 @@ func TestManager_Authenticate_Timing(t *testing.T) { require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) } +func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { + a := newTestManager(t, PermissionDenyAll) + + // Create user, add reservations and token + require.Nil(t, a.AddUser("user", "pass", RoleAdmin)) + require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead)) + + u, err := a.User("user") + require.Nil(t, err) + require.False(t, u.Deleted) + + token, err := a.CreateToken(u) + require.Nil(t, err) + + u, err = a.Authenticate("user", "pass") + require.Nil(t, err) + + _, err = a.AuthenticateToken(token.Value) + require.Nil(t, err) + + reservations, err := a.Reservations("user") + require.Nil(t, err) + require.Equal(t, 1, len(reservations)) + + // Mark deleted: cannot auth anymore, and all reservations are gone + require.Nil(t, a.MarkUserRemoved(u)) + + _, err = a.Authenticate("user", "pass") + require.Equal(t, ErrUnauthenticated, err) + + _, err = a.AuthenticateToken(token.Value) + require.Equal(t, ErrUnauthenticated, err) + + reservations, err = a.Reservations("user") + require.Nil(t, err) + require.Equal(t, 0, len(reservations)) + + // Make sure user is still there + u, err = a.User("user") + require.Nil(t, err) + require.True(t, u.Deleted) + + _, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) + require.Nil(t, err) + require.Nil(t, a.RemoveDeletedUsers()) + + _, err = a.User("user") + require.Equal(t, ErrUserNotFound, err) +} + func TestManager_UserManagement(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) @@ -219,7 +269,7 @@ func TestManager_UserManagement(t *testing.T) { func TestManager_ChangePassword(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) + require.Nil(t, a.AddUser("phil", "phil", RoleAdmin)) _, err := a.Authenticate("phil", "phil") require.Nil(t, err) @@ -233,7 +283,7 @@ func TestManager_ChangePassword(t *testing.T) { func TestManager_ChangeRole(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) @@ -258,7 +308,7 @@ func TestManager_ChangeRole(t *testing.T) { func TestManager_Reservations(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) @@ -292,7 +342,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { AttachmentTotalSizeLimit: 524288000, AttachmentExpiryDuration: 24 * time.Hour, })) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.ChangeTier("ben", "pro")) require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll)) @@ -340,7 +390,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { func TestManager_Token_Valid(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.User("ben") require.Nil(t, err) @@ -365,7 +415,7 @@ func TestManager_Token_Valid(t *testing.T) { func TestManager_Token_Invalid(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length require.Nil(t, u) @@ -378,7 +428,7 @@ func TestManager_Token_Invalid(t *testing.T) { func TestManager_Token_Expire(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) u, err := a.User("ben") require.Nil(t, err) @@ -426,7 +476,7 @@ func TestManager_Token_Expire(t *testing.T) { func TestManager_Token_Extend(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) // Try to extend token for user without token u, err := a.User("ben") @@ -453,7 +503,7 @@ func TestManager_Token_Extend(t *testing.T) { func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { a := newTestManager(t, PermissionDenyAll) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) // Try to extend token for user without token u, err := a.User("ben") @@ -497,7 +547,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { func TestManager_EnqueueStats(t *testing.T) { a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) // Baseline: No messages or emails u, err := a.User("ben") @@ -527,7 +577,7 @@ func TestManager_EnqueueStats(t *testing.T) { func TestManager_ChangeSettings(t *testing.T) { a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) - require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) // No settings u, err := a.User("ben") diff --git a/user/types.go b/user/types.go index 2213753..8783479 100644 --- a/user/types.go +++ b/user/types.go @@ -20,8 +20,7 @@ type User struct { Stats *Stats Billing *Billing SyncTopic string - Created time.Time - LastSeen time.Time + Deleted bool } // Auther is an interface for authentication and authorization @@ -186,7 +185,8 @@ const ( // Everyone is a special username representing anonymous users const ( - Everyone = "*" + Everyone = "*" + everyoneID = "u_everyone" ) var ( diff --git a/util/util.go b/util/util.go index 15a922d..1650bb5 100644 --- a/util/util.go +++ b/util/util.go @@ -324,3 +324,15 @@ func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) { } return &obj, nil } + +// Retry executes function f until if succeeds, and then returns t. If f fails, it sleeps +// and tries again. The sleep durations are passed as the after params. +func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) { + for _, delay := range after { + if t, err = f(); err == nil { + return t, nil + } + time.Sleep(delay) + } + return nil, err +}