Delayed deletion
This commit is contained in:
parent
9c082a8331
commit
954d919361
14 changed files with 280 additions and 131 deletions
|
@ -39,12 +39,10 @@ TODO
|
|||
--
|
||||
|
||||
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Stripe: Add metadata to customer
|
||||
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
|
||||
- Stripe: Add metadata to customer
|
||||
|
||||
races:
|
||||
- v.user --> see publishSyncEventAsync() test
|
||||
|
@ -53,7 +51,7 @@ payments:
|
|||
- reconciliation
|
||||
|
||||
delete messages + reserved topics on ResetTier delete attachments in access.go
|
||||
account deletion should delete messages and reservations and attachments
|
||||
|
||||
|
||||
Limits & rate limiting:
|
||||
rate limiting weirdness. wth is going on?
|
||||
|
@ -1256,11 +1254,14 @@ func (s *Server) execManager() {
|
|||
s.mu.Unlock()
|
||||
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
|
||||
|
||||
// Delete expired user tokens
|
||||
// Delete expired user tokens and users
|
||||
if s.userManager != nil {
|
||||
if err := s.userManager.RemoveExpiredTokens(); err != nil {
|
||||
log.Warn("Error expiring user tokens: %s", err.Error())
|
||||
}
|
||||
if err := s.userManager.RemoveDeletedUsers(); err != nil {
|
||||
log.Warn("Error deleting soft-deleted users: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Delete expired attachments
|
||||
|
@ -1283,7 +1284,7 @@ func (s *Server) execManager() {
|
|||
}
|
||||
}
|
||||
|
||||
// DeleteMessages message cache
|
||||
// Prune messages
|
||||
log.Debug("Manager: Pruning messages")
|
||||
expiredMessageIDs, err := s.messageCache.MessagesExpired()
|
||||
if err != nil {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
const (
|
||||
subscriptionIDLength = 16
|
||||
createdByAPI = "api"
|
||||
syncTopicAccountSyncEvent = "sync"
|
||||
)
|
||||
|
||||
|
@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||
return errHTTPTooManyRequestsLimitAccountCreation
|
||||
}
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
|
||||
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Info("Deleting user %s", v.user.Name)
|
||||
if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
|
||||
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
|
||||
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
|
|
@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
|
|||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
|||
|
||||
func TestAccount_ChangeSettings(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
user, _ := s.userManager.User("phil")
|
||||
token, _ := s.userManager.CreateToken(user)
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||
|
||||
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
|||
|
||||
func TestAccount_ChangePassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
|||
|
||||
func TestAccount_ExtendToken(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) {
|
|||
|
||||
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
|
||||
|
@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
|||
|
||||
func TestAccount_DeleteToken(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) {
|
|||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Account was marked deleted
|
||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Cannot re-create account, since still exists
|
||||
rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||
require.Equal(t, 409, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccount_Delete_Not_Allowed(t *testing.T) {
|
||||
|
@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
|||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
|
||||
|
||||
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "adminpass"),
|
||||
|
|
|
@ -18,15 +18,10 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
|
||||
)
|
||||
|
||||
// Payments in ntfy are done via Stripe.
|
||||
//
|
||||
// Pretty much all payments related things are in this file. The following processes
|
||||
|
@ -49,6 +44,16 @@ var (
|
|||
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
|
||||
// What Stripe says is mirrored and not questioned.
|
||||
|
||||
var (
|
||||
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
||||
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
||||
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
|
||||
)
|
||||
|
||||
var (
|
||||
retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
|
||||
)
|
||||
|
||||
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
|
||||
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
|
||||
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
|
@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
} else if tier.StripePriceID == "" {
|
||||
return errNotAPaidTier
|
||||
}
|
||||
log.Info("Stripe: No existing subscription, creating checkout flow")
|
||||
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
||||
var stripeCustomerID *string
|
||||
if v.user.Billing.StripeCustomerID != "" {
|
||||
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
||||
|
@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
|
||||
Enabled: stripe.Bool(true),
|
||||
},*/
|
||||
}
|
||||
sess, err := s.stripe.NewCheckoutSession(params)
|
||||
if err != nil {
|
||||
|
@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
|
||||
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
|
||||
// and only time we can map the local username with the Stripe customer ID.
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
// We don't have a v.user in this endpoint, only a userManager!
|
||||
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
||||
if len(matches) != 2 {
|
||||
|
@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
v.SetUser(u)
|
||||
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||
|
@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
|
||||
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
|
||||
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||
// and cancelling the Stripe subscription entirely
|
||||
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(true),
|
||||
|
@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|||
if v.user.Billing.StripeCustomerID == "" {
|
||||
return errHTTPBadRequestNotAPaidUser
|
||||
}
|
||||
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
||||
ReturnURL: stripe.String(s.config.BaseURL),
|
||||
|
@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
} else if event.Data == nil || event.Data.Raw == nil {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
|
||||
log.Info("Stripe: webhook event %s received", event.Type)
|
||||
switch event.Type {
|
||||
case "customer.subscription.updated":
|
||||
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
|
||||
case "customer.subscription.deleted":
|
||||
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
|
||||
default:
|
||||
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
||||
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
||||
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
|
||||
userFn := func() (*user.User, error) {
|
||||
return s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
}
|
||||
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
||||
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if r.Customer == "" {
|
||||
} else if ev.Customer == "" {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
||||
u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
||||
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
|
||||
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
// Remove excess reservations (if too many for tier), and mark associated messages deleted
|
||||
// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationsLimit
|
||||
}
|
||||
if int64(len(reservations)) > reservationsLimit {
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
// Change or remove tier
|
||||
if tier == nil {
|
||||
if err := s.userManager.ResetTier(u.Name); err != nil {
|
||||
return err
|
||||
|
|
|
@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
// Create subscription
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||
|
@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
|
||||
|
|
|
@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
|
|||
c := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
|
@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
|
||||
|
||||
|
@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "INVALID"),
|
||||
|
@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
|
@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionReadWrite // Open by default
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||
|
||||
|
@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
|||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
|
||||
|
||||
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
|
||||
response := request(t, s, "GET", u, "", nil)
|
||||
|
@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
go s.runStatsResetter()
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
|
@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
|||
MessagesLimit: 5,
|
||||
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish to reach message limit
|
||||
|
@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish and make sure we can retrieve it
|
||||
|
@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: 30 * time.Second,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
||||
// Publish small file as anonymous
|
||||
|
|
|
@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct {
|
|||
}
|
||||
|
||||
type apiStripeSubscriptionDeletedEvent struct {
|
||||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string {
|
|||
}
|
||||
|
||||
func logMessagePrefix(v *visitor, m *message) string {
|
||||
return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID)
|
||||
return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID)
|
||||
}
|
||||
|
||||
func logHTTPPrefix(v *visitor, r *http.Request) string {
|
||||
|
@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
|
|||
if requestURI == "" {
|
||||
requestURI = r.URL.Path
|
||||
}
|
||||
return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
|
||||
return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
|
||||
}
|
||||
|
||||
func logStripePrefix(customerID, subscriptionID string) string {
|
||||
if subscriptionID != "" {
|
||||
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
|
||||
}
|
||||
return fmt.Sprintf("%s STRIPE", customerID)
|
||||
}
|
||||
|
||||
func logSMTPPrefix(state *smtp.ConnectionState) string {
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/user"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
}
|
||||
}
|
||||
|
||||
func (v *visitor) String() string {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
|
||||
} else if v.user != nil {
|
||||
return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID)
|
||||
}
|
||||
return v.ip.String()
|
||||
}
|
||||
|
||||
func (v *visitor) RequestAllowed() error {
|
||||
if !v.requestLimiter.Allow() {
|
||||
return errVisitorLimitReached
|
||||
|
@ -216,6 +228,12 @@ func (v *visitor) ResetStats() {
|
|||
}
|
||||
}
|
||||
|
||||
func (v *visitor) SetUser(u *user.User) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.user = u
|
||||
}
|
||||
|
||||
func (v *visitor) Limits() *visitorLimits {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue