Delayed deletion

This commit is contained in:
binwiederhier 2023-01-22 22:21:30 -05:00
parent 9c082a8331
commit 954d919361
14 changed files with 280 additions and 131 deletions

View file

@ -39,12 +39,10 @@ TODO
--
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe: Add metadata to customer
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
- Stripe: Add metadata to customer
races:
- v.user --> see publishSyncEventAsync() test
@ -53,7 +51,7 @@ payments:
- reconciliation
delete messages + reserved topics on ResetTier delete attachments in access.go
account deletion should delete messages and reservations and attachments
Limits & rate limiting:
rate limiting weirdness. wth is going on?
@ -1256,11 +1254,14 @@ func (s *Server) execManager() {
s.mu.Unlock()
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens
// Delete expired user tokens and users
if s.userManager != nil {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error())
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Warn("Error deleting soft-deleted users: %s", err.Error())
}
}
// Delete expired attachments
@ -1283,7 +1284,7 @@ func (s *Server) execManager() {
}
}
// DeleteMessages message cache
// Prune messages
log.Debug("Manager: Pruning messages")
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {

View file

@ -11,7 +11,6 @@ import (
const (
subscriptionIDLength = 16
createdByAPI = "api"
syncTopicAccountSyncEvent = "sync"
)
@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errHTTPTooManyRequestsLimitAccountCreation
}
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err
}
return s.writeJSON(w, newSuccessResponse())
@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
return s.writeJSON(w, response)
}
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
return err
}
}
} else {
log.Info("Deleting user %s", v.user.Name)
if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
return err
}
}
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())

View file

@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
conf.EnableSignup = true
s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
func TestAccount_ChangeSettings(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
user, _ := s.userManager.User("phil")
token, _ := s.userManager.CreateToken(user)
@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
func TestAccount_ExtendToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) {
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
func TestAccount_DeleteToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) {
})
require.Equal(t, 200, rr.Code)
// Account was marked deleted
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 401, rr.Code)
// Cannot re-create account, since still exists
rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 409, rr.Code)
}
func TestAccount_Delete_Not_Allowed(t *testing.T) {
@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true
s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "adminpass"),

View file

@ -18,15 +18,10 @@ import (
"io"
"net/http"
"net/netip"
"strings"
"time"
)
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
// Payments in ntfy are done via Stripe.
//
// Pretty much all payments related things are in this file. The following processes
@ -49,6 +44,16 @@ var (
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned.
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
var (
retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
)
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
} else if tier.StripePriceID == "" {
return errNotAPaidTier
}
log.Info("Stripe: No existing subscription, creating checkout flow")
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID
@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
Quantity: stripe.Int64(1),
},
},
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
Enabled: stripe.Bool(true),
},*/
}
sess, err := s.stripe.NewCheckoutSession(params)
if err != nil {
@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
v.SetUser(u)
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type {
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
default:
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
userFn := func() (*user.User, error) {
return s.userManager.UserByStripeCustomer(ev.Customer)
}
u, err := util.Retry[user.User](userFn, retryUserDelays...)
if err != nil {
return err
}
@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.Customer == "" {
} else if ev.Customer == "" {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
// Remove excess reservations (if too many for tier), and mark associated messages deleted
// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if int64(len(reservations)) > reservationsLimit {
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
return err
}
// Change or remove tier
if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err

View file

@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))

View file

@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "INVALID"),
@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefault = user.PermissionReadWrite // Open by default
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
response := request(t, s, "GET", u, "", nil)
@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) {
s := newTestServer(t, c)
go s.runStatsResetter()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
for i := 0; i < 5; i++ {
@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
MessagesLimit: 5,
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish to reach message limit
@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish small file as anonymous

View file

@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct {
}
type apiStripeSubscriptionDeletedEvent struct {
ID string `json:"id"`
Customer string `json:"customer"`
}

View file

@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string {
}
func logMessagePrefix(v *visitor, m *message) string {
return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID)
return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID)
}
func logHTTPPrefix(v *visitor, r *http.Request) string {
@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
if requestURI == "" {
requestURI = r.URL.Path
}
return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
}
func logStripePrefix(customerID, subscriptionID string) string {
if subscriptionID != "" {
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
}
return fmt.Sprintf("%s STRIPE", customerID)
}
func logSMTPPrefix(state *smtp.ConnectionState) string {

View file

@ -2,6 +2,7 @@ package server
import (
"errors"
"fmt"
"heckel.io/ntfy/user"
"net/netip"
"sync"
@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
}
}
func (v *visitor) String() string {
v.mu.Lock()
defer v.mu.Unlock()
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
} else if v.user != nil {
return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID)
}
return v.ip.String()
}
func (v *visitor) RequestAllowed() error {
if !v.requestLimiter.Allow() {
return errVisitorLimitReached
@ -216,6 +228,12 @@ func (v *visitor) ResetStats() {
}
}
func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
v.user = u
}
func (v *visitor) Limits() *visitorLimits {
v.mu.Lock()
defer v.mu.Unlock()