Add password confirmation to account delete dialog, v1/tiers test

This commit is contained in:
binwiederhier 2023-01-23 10:58:39 -05:00
parent 954d919361
commit e82a2e518c
14 changed files with 242 additions and 93 deletions

View file

@ -62,7 +62,7 @@ var (
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
errHTTPBadRequestBillingRequestInvalid = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid billing request", ""}
errHTTPBadRequestBillingSubscriptionExists = &errHTTP{40029, http.StatusBadRequest, "invalid request: billing subscription already exists", ""}
errHTTPBadRequestCurrentPasswordWrong = &errHTTP{40030, http.StatusBadRequest, "invalid request: current password is not correct", ""}
errHTTPBadRequestIncorrectPasswordConfirmation = &errHTTP{40030, http.StatusBadRequest, "invalid request: password confirmation is not correct", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View file

@ -40,9 +40,9 @@ TODO
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe: Add metadata to customer
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Reservation icons (UI)
races:
- v.user --> see publishSyncEventAsync() test
@ -65,7 +65,6 @@ Make sure account endpoints make sense for admins
UI:
-
- reservation icons
- reservation table delete button: dialog "keep or delete messages?"
- flicker of upgrade banner
- JS constants
@ -858,7 +857,6 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Time > attachmentExpiry {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
}
fmt.Printf("v = %#v\nlimits = %#v\nstats = %#v\n", v, vinfo.Limits, vinfo.Stats)
contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)

View file

@ -7,6 +7,7 @@ import (
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
"strings"
)
const (
@ -118,17 +119,24 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
}
func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountDeleteRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
} else if req.Password == "" {
return errHTTPBadRequest
}
if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
return errHTTPBadRequestIncorrectPasswordConfirmation
}
if v.user.Billing.StripeSubscriptionID != "" {
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
return err
}
}
if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
return err
}
}
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
return err
}
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
return err
@ -144,7 +152,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
return errHTTPBadRequest
}
if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
return errHTTPBadRequestCurrentPasswordWrong
return errHTTPBadRequestIncorrectPasswordConfirmation
}
if err := s.userManager.ChangePassword(v.user.Name, req.NewPassword); err != nil {
return err
@ -365,6 +373,30 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
return s.writeJSON(w, newSuccessResponse())
}
// maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier),
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
return nil
}
func (s *Server) publishSyncEvent(v *visitor) error {
if v.user == nil || v.user.SyncTopic == "" {
return nil

View file

@ -319,7 +319,7 @@ func TestAccount_Delete_Success(t *testing.T) {
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "DELETE", "/v1/account", "", map[string]string{
rr = request(t, s, "DELETE", "/v1/account", `{"password":"mypass"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 200, rr.Code)
@ -345,6 +345,15 @@ func TestAccount_Delete_Not_Allowed(t *testing.T) {
rr = request(t, s, "DELETE", "/v1/account", "", nil)
require.Equal(t, 401, rr.Code)
rr = request(t, s, "DELETE", "/v1/account", `{"password":"mypass"}`, nil)
require.Equal(t, 401, rr.Code)
rr = request(t, s, "DELETE", "/v1/account", `{"password":"INCORRECT"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 400, rr.Code)
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
}
func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
@ -386,7 +395,6 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
// Create a tier
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
Paid: false,
MessagesLimit: 123,
MessagesExpiryDuration: 86400 * time.Second,
EmailsLimit: 32,

View file

@ -18,7 +18,6 @@ import (
"io"
"net/http"
"net/netip"
"strings"
"time"
)
@ -66,6 +65,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
{
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisIP),
Messages: freeTier.MessagesLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
Emails: freeTier.EmailsLimit,
@ -90,6 +90,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
Name: tier.Name,
Price: priceStr,
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier),
Messages: tier.MessagesLimit,
MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
Emails: tier.EmailsLimit,
@ -143,6 +144,11 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
Quantity: stripe.Int64(1),
},
},
Params: stripe.Params{
Metadata: map[string]string{
"user_id": v.user.ID,
},
},
}
sess, err := s.stripe.NewCheckoutSession(params)
if err != nil {
@ -185,6 +191,17 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
return err
}
v.SetUser(u)
customerParams := &stripe.CustomerParams{
Params: stripe.Params{
Metadata: map[string]string{
"user_id": u.ID,
"user_name": u.Name,
},
},
}
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err
}
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
@ -342,36 +359,12 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
return nil
}
// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
return err
}
if tier == nil {
@ -426,6 +419,7 @@ type stripeAPI interface {
GetCustomer(id string) (*stripe.Customer, error)
GetSession(id string) (*stripe.CheckoutSession, error)
GetSubscription(id string) (*stripe.Subscription, error)
UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error)
UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
CancelSubscription(id string) (*stripe.Subscription, error)
ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
@ -472,6 +466,10 @@ func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error)
return subscription.Get(id, nil)
}
func (s *realStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
return customer.Update(id, params)
}
func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
return subscription.Update(id, params)
}

View file

@ -14,6 +14,108 @@ import (
"time"
)
func TestPayments_Tiers(t *testing.T) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitReplenish = 12 * time.Hour
c.CacheDuration = 13 * time.Hour
c.AttachmentFileSizeLimit = 111
c.VisitorAttachmentTotalSizeLimit = 222
c.AttachmentExpiryDuration = 123 * time.Second
s := newTestServer(t, c)
s.stripe = stripeMock
// Define how the mock should react
stripeMock.
On("ListPrices", mock.Anything).
Return([]*stripe.Price{
{ID: "price_123", UnitAmount: 500},
{ID: "price_456", UnitAmount: 1000},
{ID: "price_999", UnitAmount: 9999},
}, nil)
// Create tiers
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_1",
Code: "admin",
Name: "Admin",
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
Name: "Pro",
MessagesLimit: 1000,
MessagesExpiryDuration: time.Hour,
EmailsLimit: 123,
ReservationsLimit: 777,
AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute,
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_444",
Code: "business",
Name: "Business",
MessagesLimit: 2000,
MessagesExpiryDuration: 10 * time.Hour,
EmailsLimit: 123123,
ReservationsLimit: 777333,
AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour,
StripePriceID: "price_456",
}))
response := request(t, s, "GET", "/v1/tiers", "", nil)
require.Equal(t, 200, response.Code)
var tiers []apiAccountBillingTier
require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
require.Equal(t, 3, len(tiers))
// Free tier
tier := tiers[0]
require.Equal(t, "", tier.Code)
require.Equal(t, "", tier.Name)
require.Equal(t, "ip", tier.Limits.Basis)
require.Equal(t, int64(0), tier.Limits.Reservations)
require.Equal(t, int64(2), tier.Limits.Messages) // :-(
require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
require.Equal(t, int64(24), tier.Limits.Emails)
require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
// Admin tier is not included, because it is not paid!
tier = tiers[1]
require.Equal(t, "pro", tier.Code)
require.Equal(t, "Pro", tier.Name)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(777), tier.Limits.Reservations)
require.Equal(t, int64(1000), tier.Limits.Messages)
require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
require.Equal(t, int64(123), tier.Limits.Emails)
require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
tier = tiers[2]
require.Equal(t, "business", tier.Code)
require.Equal(t, "Business", tier.Name)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(777333), tier.Limits.Reservations)
require.Equal(t, int64(2000), tier.Limits.Messages)
require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
require.Equal(t, int64(123123), tier.Limits.Emails)
require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
}
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
@ -122,7 +224,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Delete account
rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@ -258,6 +360,8 @@ type testStripeAPI struct {
mock.Mock
}
var _ stripeAPI = (*testStripeAPI)(nil)
func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
args := s.Called(params)
return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
@ -288,6 +392,11 @@ func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error)
return args.Get(0).(*stripe.Subscription), args.Error(1)
}
func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
args := s.Called(id)
return args.Get(0).(*stripe.Customer), args.Error(1)
}
func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
args := s.Called(id)
return args.Get(0).(*stripe.Subscription), args.Error(1)
@ -303,8 +412,6 @@ func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, sec
return args.Get(0).(stripe.Event), args.Error(1)
}
var _ stripeAPI = (*testStripeAPI)(nil)
func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
var e stripe.Event
if err := json.Unmarshal([]byte(v), &e); err != nil {

View file

@ -231,6 +231,10 @@ type apiAccountPasswordChangeRequest struct {
NewPassword string `json:"new_password"`
}
type apiAccountDeleteRequest struct {
Password string `json:"password"`
}
type apiAccountTokenResponse struct {
Token string `json:"token"`
Expires int64 `json:"expires"`

View file

@ -57,18 +57,18 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
if requestURI == "" {
requestURI = r.URL.Path
}
return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
return fmt.Sprintf("HTTP %s %s %s", v.String(), r.Method, requestURI)
}
func logStripePrefix(customerID, subscriptionID string) string {
if subscriptionID != "" {
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
return fmt.Sprintf("STRIPE %s/%s", customerID, subscriptionID)
}
return fmt.Sprintf("%s STRIPE", customerID)
return fmt.Sprintf("STRIPE %s", customerID)
}
func logSMTPPrefix(state *smtp.ConnectionState) string {
return fmt.Sprintf("%s/%s SMTP", state.Hostname, state.RemoteAddr.String())
return fmt.Sprintf("SMTP %s/%s", state.Hostname, state.RemoteAddr.String())
}
func renderHTTPRequest(r *http.Request) string {