Logging WIP
This commit is contained in:
parent
a6641980c2
commit
5d6051c490
11 changed files with 108 additions and 124 deletions
|
@ -26,6 +26,8 @@ func TestCLI_Access_Grant_And_Publish(t *testing.T) {
|
||||||
stdin.WriteString("philpass\nphilpass\nbenpass\nbenpass")
|
stdin.WriteString("philpass\nphilpass\nbenpass\nbenpass")
|
||||||
require.Nil(t, runUserCommand(app, conf, "add", "--role=admin", "phil"))
|
require.Nil(t, runUserCommand(app, conf, "add", "--role=admin", "phil"))
|
||||||
require.Nil(t, runUserCommand(app, conf, "add", "ben"))
|
require.Nil(t, runUserCommand(app, conf, "add", "ben"))
|
||||||
|
|
||||||
|
app, stdin, _, _ = newTestApp()
|
||||||
require.Nil(t, runAccessCommand(app, conf, "ben", "announcements", "rw"))
|
require.Nil(t, runAccessCommand(app, conf, "ben", "announcements", "rw"))
|
||||||
require.Nil(t, runAccessCommand(app, conf, "ben", "sometopic", "read"))
|
require.Nil(t, runAccessCommand(app, conf, "ben", "sometopic", "read"))
|
||||||
require.Nil(t, runAccessCommand(app, conf, "everyone", "announcements", "read"))
|
require.Nil(t, runAccessCommand(app, conf, "everyone", "announcements", "read"))
|
||||||
|
|
|
@ -76,7 +76,7 @@ func (e *Event) Fields(fields map[string]any) *Event {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Event) Context(contexts ...Ctx) *Event {
|
func (e *Event) Context(contexts ...Contexter) *Event {
|
||||||
for _, c := range contexts {
|
for _, c := range contexts {
|
||||||
e.Fields(c.Context())
|
e.Fields(c.Context())
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,7 @@ func Trace(message string, v ...any) {
|
||||||
newEvent().Trace(message, v...)
|
newEvent().Trace(message, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Context(contexts ...Ctx) *Event {
|
func Context(contexts ...Contexter) *Event {
|
||||||
return newEvent().Context(contexts...)
|
return newEvent().Context(contexts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@ func ToFormat(s string) Format {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ctx interface {
|
type Contexter interface {
|
||||||
Context() map[string]any
|
Context() map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ func (f fieldsCtx) Context() map[string]any {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCtx(fields map[string]any) Ctx {
|
func NewCtx(fields map[string]any) Contexter {
|
||||||
return fieldsCtx(fields)
|
return fieldsCtx(fields)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -149,6 +149,7 @@ const (
|
||||||
tagManager = "manager"
|
tagManager = "manager"
|
||||||
tagResetter = "resetter"
|
tagResetter = "resetter"
|
||||||
tagWebsocket = "websocket"
|
tagWebsocket = "websocket"
|
||||||
|
tagMatrix = "matrix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New instantiates a new Server. It creates the cache and adds a Firebase
|
// New instantiates a new Server. It creates the cache and adds a Firebase
|
||||||
|
@ -328,9 +329,9 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
if websocket.IsWebSocketUpgrade(r) {
|
if websocket.IsWebSocketUpgrade(r) {
|
||||||
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
||||||
if isNormalError {
|
if isNormalError {
|
||||||
logvr(v, r).Tag(tagWebsocket).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
|
logvr(v, r).Tag(tagWebsocket).Err(err).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
|
||||||
} else {
|
} else {
|
||||||
logvr(v, r).Tag(tagWebsocket).Info("WebSocket error: %s", err.Error())
|
logvr(v, r).Tag(tagWebsocket).Err(err).Info("WebSocket error: %s", err.Error())
|
||||||
}
|
}
|
||||||
return // Do not attempt to write to upgraded connection
|
return // Do not attempt to write to upgraded connection
|
||||||
}
|
}
|
||||||
|
@ -711,7 +712,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||||
logvm(v, m).Err(err).Warn("Unable to publish poll request")
|
logvm(v, m).Err(err).Warn("Unable to publish poll request")
|
||||||
return
|
return
|
||||||
} else if response.StatusCode != http.StatusOK {
|
} else if response.StatusCode != http.StatusOK {
|
||||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d")
|
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1537,6 +1538,7 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
} else if err := v.RequestAllowed(); err != nil {
|
} else if err := v.RequestAllowed(); err != nil {
|
||||||
|
logvr(v, r).Err(err).Fields(requestLimiterFields(v.RequestLimiter())).Trace("Request not allowed by rate limiter")
|
||||||
return errHTTPTooManyRequestsLimitRequests
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
}
|
}
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
|
@ -1601,6 +1603,7 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
|
newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logvr(v, r).Tag(tagMatrix).Err(err).Trace("Invalid Matrix request")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := next(w, newRequest, v); err != nil {
|
if err := next(w, newRequest, v); err != nil {
|
||||||
|
@ -1630,7 +1633,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
||||||
u := v.User()
|
u := v.User()
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
|
if err := s.userManager.Authorize(u, t.ID, perm); err != nil {
|
||||||
logvr(v, r).Err(err).Debug("Unauthorized")
|
logvr(v, r).Err(err).Field("message_topic", t.ID).Debug("Access to topic %s not authorized", t.ID)
|
||||||
return errHTTPForbidden
|
return errHTTPForbidden
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1644,7 +1647,7 @@ func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
||||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||||
var u *user.User // may stay nil if no auth header!
|
var u *user.User // may stay nil if no auth header!
|
||||||
if u, err = s.authenticate(r); err != nil {
|
if u, err = s.authenticate(r); err != nil {
|
||||||
logr(r).Debug("Authentication failed: %s", err.Error())
|
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
|
||||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||||
}
|
}
|
||||||
v = s.visitor(ip, u)
|
v = s.visitor(ip, u)
|
||||||
|
|
|
@ -160,7 +160,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil {
|
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, 0); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name)
|
logvr(v, r).Tag(tagAccount).Info("Marking user %s as deleted", u.Name)
|
||||||
|
@ -462,18 +462,19 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||||
// maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier),
|
// maybeRemoveMessagesAndExcessReservations deletes topic reservations for the given user (if too many for tier),
|
||||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
|
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
|
||||||
reservations, err := s.userManager.Reservations(u.Name)
|
reservations, err := s.userManager.Reservations(u.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if int64(len(reservations)) <= reservationsLimit {
|
} else if int64(len(reservations)) <= reservationsLimit {
|
||||||
|
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
topics := make([]string, 0)
|
topics := make([]string, 0)
|
||||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||||
topics = append(topics, reservations[i].Topic)
|
topics = append(topics, reservations[i].Topic)
|
||||||
}
|
}
|
||||||
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
|
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
|
||||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -147,7 +146,7 @@ func writeMatrixDiscoveryResponse(w http.ResponseWriter) error {
|
||||||
|
|
||||||
// writeMatrixError logs and writes the errMatrix to the given http.ResponseWriter as a matrixResponse
|
// writeMatrixError logs and writes the errMatrix to the given http.ResponseWriter as a matrixResponse
|
||||||
func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error {
|
func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error {
|
||||||
log.Debug("%s Matrix gateway error: %s", logHTTPPrefix(v, r), err.Error())
|
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Matrix gateway error")
|
||||||
return writeMatrixResponse(w, err.pushKey)
|
return writeMatrixResponse(w, err.pushKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stripe/stripe-go/v74"
|
"github.com/stripe/stripe-go/v74"
|
||||||
|
@ -121,7 +120,13 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
||||||
} else if tier.StripePriceID == "" {
|
} else if tier.StripePriceID == "" {
|
||||||
return errNotAPaidTier
|
return errNotAPaidTier
|
||||||
}
|
}
|
||||||
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Fields(map[string]any{
|
||||||
|
"tier": tier,
|
||||||
|
"stripe_price_id": tier.StripePriceID,
|
||||||
|
}).
|
||||||
|
Info("Creating Stripe checkout flow")
|
||||||
var stripeCustomerID *string
|
var stripeCustomerID *string
|
||||||
if u.Billing.StripeCustomerID != "" {
|
if u.Billing.StripeCustomerID != "" {
|
||||||
stripeCustomerID = &u.Billing.StripeCustomerID
|
stripeCustomerID = &u.Billing.StripeCustomerID
|
||||||
|
@ -190,6 +195,18 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
v.SetUser(u)
|
v.SetUser(u)
|
||||||
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Fields(map[string]any{
|
||||||
|
"tier_id": tier.ID,
|
||||||
|
"tier_name": tier.Name,
|
||||||
|
"stripe_price_id": tier.StripePriceID,
|
||||||
|
"stripe_customer_id": sess.Customer.ID,
|
||||||
|
"stripe_subscription_id": sub.ID,
|
||||||
|
"stripe_subscription_status": string(sub.Status),
|
||||||
|
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
|
||||||
|
}).
|
||||||
|
Info("Stripe checkout flow succeeded, updating user tier and subscription")
|
||||||
customerParams := &stripe.CustomerParams{
|
customerParams := &stripe.CustomerParams{
|
||||||
Params: stripe.Params{
|
Params: stripe.Params{
|
||||||
Metadata: map[string]string{
|
Metadata: map[string]string{
|
||||||
|
@ -201,7 +218,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
||||||
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
|
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||||
|
@ -223,7 +240,15 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Fields(map[string]any{
|
||||||
|
"new_tier_id": tier.ID,
|
||||||
|
"new_tier_name": tier.Name,
|
||||||
|
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||||
|
// Other stripe_* fields filled by visitor context
|
||||||
|
}).
|
||||||
|
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
|
||||||
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -250,8 +275,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
||||||
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
||||||
// and cancelling the Stripe subscription entirely
|
// and cancelling the Stripe subscription entirely
|
||||||
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
|
logvr(v, r).Tag(tagPay).Info("Deleting Stripe subscription")
|
||||||
u := v.User()
|
u := v.User()
|
||||||
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
|
|
||||||
if u.Billing.StripeSubscriptionID != "" {
|
if u.Billing.StripeSubscriptionID != "" {
|
||||||
params := &stripe.SubscriptionParams{
|
params := &stripe.SubscriptionParams{
|
||||||
CancelAtPeriodEnd: stripe.Bool(true),
|
CancelAtPeriodEnd: stripe.Bool(true),
|
||||||
|
@ -267,11 +292,11 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
||||||
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
||||||
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
||||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
|
logvr(v, r).Tag(tagPay).Info("Creating Stripe billing portal session")
|
||||||
u := v.User()
|
u := v.User()
|
||||||
if u.Billing.StripeCustomerID == "" {
|
if u.Billing.StripeCustomerID == "" {
|
||||||
return errHTTPBadRequestNotAPaidUser
|
return errHTTPBadRequestNotAPaidUser
|
||||||
}
|
}
|
||||||
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
|
||||||
params := &stripe.BillingPortalSessionParams{
|
params := &stripe.BillingPortalSessionParams{
|
||||||
Customer: stripe.String(u.Billing.StripeCustomerID),
|
Customer: stripe.String(u.Billing.StripeCustomerID),
|
||||||
ReturnURL: stripe.String(s.config.BaseURL),
|
ReturnURL: stripe.String(s.config.BaseURL),
|
||||||
|
@ -289,7 +314,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
||||||
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
||||||
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
||||||
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
|
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
|
||||||
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
|
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
stripeSignature := r.Header.Get("Stripe-Signature")
|
stripeSignature := r.Header.Get("Stripe-Signature")
|
||||||
if stripeSignature == "" {
|
if stripeSignature == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
|
@ -308,74 +333,105 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ
|
||||||
}
|
}
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
case "customer.subscription.updated":
|
case "customer.subscription.updated":
|
||||||
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
|
return s.handleAccountBillingWebhookSubscriptionUpdated(r, v, event)
|
||||||
case "customer.subscription.deleted":
|
case "customer.subscription.deleted":
|
||||||
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
|
return s.handleAccountBillingWebhookSubscriptionDeleted(r, v, event)
|
||||||
default:
|
default:
|
||||||
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Field("stripe_webhook_type", event.Type).
|
||||||
|
Warn("Unhandled Stripe webhook event %s received", event.Type)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, v *visitor, event stripe.Event) error {
|
||||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
}
|
}
|
||||||
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
||||||
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Fields(map[string]any{
|
||||||
|
"stripe_webhook_type": event.Type,
|
||||||
|
"stripe_customer_id": ev.Customer,
|
||||||
|
"stripe_subscription_id": ev.ID,
|
||||||
|
"stripe_subscription_status": ev.Status,
|
||||||
|
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
|
||||||
|
"stripe_subscription_cancel_at": ev.CancelAt,
|
||||||
|
"stripe_price_id": priceID,
|
||||||
|
}).
|
||||||
|
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
|
||||||
userFn := func() (*user.User, error) {
|
userFn := func() (*user.User, error) {
|
||||||
return s.userManager.UserByStripeCustomer(ev.Customer)
|
return s.userManager.UserByStripeCustomer(ev.Customer)
|
||||||
}
|
}
|
||||||
|
// We retry the user retrieval function, because during the Stripe checkout, there a race between the browser
|
||||||
|
// checkout success redirect (see handleAccountBillingSubscriptionCreateSuccess), and this webhook. The checkout
|
||||||
|
// success call is the one that updates the user with the Stripe customer ID.
|
||||||
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
v.SetUser(u)
|
||||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, v *visitor, event stripe.Event) error {
|
||||||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if ev.Customer == "" {
|
} else if ev.Customer == "" {
|
||||||
return errHTTPBadRequestBillingRequestInvalid
|
return errHTTPBadRequestBillingRequestInvalid
|
||||||
}
|
}
|
||||||
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
|
|
||||||
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
v.SetUser(u)
|
||||||
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Field("stripe_webhook_type", event.Type).
|
||||||
|
Info("Subscription deleted, downgrading to unpaid tier")
|
||||||
|
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||||
reservationsLimit := visitorDefaultReservationsLimit
|
reservationsLimit := visitorDefaultReservationsLimit
|
||||||
if tier != nil {
|
if tier != nil {
|
||||||
reservationsLimit = tier.ReservationLimit
|
reservationsLimit = tier.ReservationLimit
|
||||||
}
|
}
|
||||||
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, reservationsLimit); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if tier == nil {
|
if tier == nil && u.Tier != nil {
|
||||||
|
logvr(v, r).Tag(tagPay).Info("Resetting tier for user %s", u.Name)
|
||||||
if err := s.userManager.ResetTier(u.Name); err != nil {
|
if err := s.userManager.ResetTier(u.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else if tier != nil && u.TierID() != tier.ID {
|
||||||
|
logvr(v, r).
|
||||||
|
Tag(tagPay).
|
||||||
|
Fields(map[string]any{
|
||||||
|
"new_tier_id": tier.ID,
|
||||||
|
"new_tier_name": tier.Name,
|
||||||
|
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||||
|
}).
|
||||||
|
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
|
||||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ func (s *smtpSession) AuthPlain(username, password string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||||
logem(s.state).Debug("%s MAIL FROM: %s (with options: %#v)", from, opts)
|
logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/emersion/go-smtp"
|
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||||
|
@ -48,90 +45,6 @@ func readQueryParam(r *http.Request, names ...string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func logr(r *http.Request) *log.Event {
|
|
||||||
return log.Fields(logFieldsHTTP(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
func logv(v *visitor) *log.Event {
|
|
||||||
return log.Context(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logvr(v *visitor, r *http.Request) *log.Event {
|
|
||||||
return logv(v).Fields(logFieldsHTTP(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
|
|
||||||
return logvr(v, r).Context(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logvm(v *visitor, m *message) *log.Event {
|
|
||||||
return logv(v).Context(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logem(state *smtp.ConnectionState) *log.Event {
|
|
||||||
return log.
|
|
||||||
Tag(tagSMTP).
|
|
||||||
Fields(map[string]any{
|
|
||||||
"smtp_hostname": state.Hostname,
|
|
||||||
"smtp_remote_addr": state.RemoteAddr.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func logFieldsHTTP(r *http.Request) map[string]any {
|
|
||||||
requestURI := r.RequestURI
|
|
||||||
if requestURI == "" {
|
|
||||||
requestURI = r.URL.Path
|
|
||||||
}
|
|
||||||
return map[string]any{
|
|
||||||
"http_method": r.Method,
|
|
||||||
"http_path": requestURI,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logHTTPPrefix(v *visitor, r *http.Request) string {
|
|
||||||
requestURI := r.RequestURI
|
|
||||||
if requestURI == "" {
|
|
||||||
requestURI = r.URL.Path
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("HTTP %s %s %s", v.String(), r.Method, requestURI)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logStripePrefix(customerID, subscriptionID string) string {
|
|
||||||
if subscriptionID != "" {
|
|
||||||
return fmt.Sprintf("STRIPE %s/%s", customerID, subscriptionID)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("STRIPE %s", customerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func renderHTTPRequest(r *http.Request) string {
|
|
||||||
peekLimit := 4096
|
|
||||||
lines := fmt.Sprintf("%s %s %s\n", r.Method, r.URL.RequestURI(), r.Proto)
|
|
||||||
for key, values := range r.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
lines += fmt.Sprintf("%s: %s\n", key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lines += "\n"
|
|
||||||
body, err := util.Peek(r.Body, peekLimit)
|
|
||||||
if err != nil {
|
|
||||||
lines = fmt.Sprintf("(could not read body: %s)\n", err.Error())
|
|
||||||
} else if utf8.Valid(body.PeekedBytes) {
|
|
||||||
lines += string(body.PeekedBytes)
|
|
||||||
if body.LimitReached {
|
|
||||||
lines += fmt.Sprintf(" ... (peeked %d bytes)", peekLimit)
|
|
||||||
}
|
|
||||||
lines += "\n"
|
|
||||||
} else {
|
|
||||||
if body.LimitReached {
|
|
||||||
lines += fmt.Sprintf("(peeked bytes not UTF-8, peek limit of %d bytes reached, hex: %x ...)\n", peekLimit, body.PeekedBytes)
|
|
||||||
} else {
|
|
||||||
lines += fmt.Sprintf("(peeked bytes not UTF-8, %d bytes, hex: %x)\n", len(body.PeekedBytes), body.PeekedBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.Body = body // Important: Reset body, so it can be re-read
|
|
||||||
return strings.TrimSpace(lines)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
|
func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
|
||||||
remoteAddr := r.RemoteAddr
|
remoteAddr := r.RemoteAddr
|
||||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||||
|
|
|
@ -159,6 +159,10 @@ func (v *visitor) Context() map[string]any {
|
||||||
if v.user != nil {
|
if v.user != nil {
|
||||||
fields["user_id"] = v.user.ID
|
fields["user_id"] = v.user.ID
|
||||||
fields["user_name"] = v.user.Name
|
fields["user_name"] = v.user.Name
|
||||||
|
if v.user.Tier != nil {
|
||||||
|
fields["tier_id"] = v.user.Tier.ID
|
||||||
|
fields["tier_name"] = v.user.Tier.Name
|
||||||
|
}
|
||||||
if v.user.Billing.StripeCustomerID != "" {
|
if v.user.Billing.StripeCustomerID != "" {
|
||||||
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
|
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
|
||||||
}
|
}
|
||||||
|
@ -178,6 +182,12 @@ func (v *visitor) RequestAllowed() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *visitor) RequestLimiter() *rate.Limiter {
|
||||||
|
v.mu.Lock() // limiters could be replaced!
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
return v.requestLimiter
|
||||||
|
}
|
||||||
|
|
||||||
func (v *visitor) FirebaseAllowed() error {
|
func (v *visitor) FirebaseAllowed() error {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
|
|
Loading…
Reference in a new issue