Merge branch 'main' into user-account

This commit is contained in:
binwiederhier 2023-02-22 19:22:47 -05:00
commit 4ab450309f
81 changed files with 4094 additions and 13687 deletions

View file

@ -61,7 +61,7 @@ var (
// DefaultDisallowedTopics defines the topics that are forbidden, because they are used elsewhere. This array can be
// extended using the server.yml config. If updated, also update in Android and web app.
DefaultDisallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "signup", "login"}
DefaultDisallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "signup", "login", "v1"}
)
// Config is the main config struct for the application. Use New to instantiate a default config struct.

View file

@ -11,19 +11,38 @@ import (
"unicode/utf8"
)
// Log tags
const (
tagStartup = "startup"
tagHTTP = "http"
tagPublish = "publish"
tagSubscribe = "subscribe"
tagFirebase = "firebase"
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"
tagManager = "manager"
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
)
// logr creates a new log event with HTTP request fields
func logr(r *http.Request) *log.Event {
return log.Fields(httpContext(r))
return log.Tag(tagHTTP).Fields(httpContext(r)) // Tag may be overwritten
}
// logr creates a new log event with visitor fields
// logv creates a new log event with visitor fields
func logv(v *visitor) *log.Event {
return log.With(v)
}
// logr creates a new log event with HTTP request and visitor fields
// logvr creates a new log event with HTTP request and visitor fields
func logvr(v *visitor, r *http.Request) *log.Event {
return logv(v).Fields(httpContext(r))
return logr(r).With(v)
}
// logvrm creates a new log event with HTTP request, visitor fields and message fields
@ -37,13 +56,12 @@ func logvm(v *visitor, m *message) *log.Event {
}
// logem creates a new log event with email fields
func logem(state *smtp.ConnectionState) *log.Event {
return log.
Tag(tagSMTP).
Fields(log.Context{
"smtp_hostname": state.Hostname,
"smtp_remote_addr": state.RemoteAddr.String(),
})
func logem(smtpConn *smtp.Conn) *log.Event {
ev := log.Tag(tagSMTP).Field("smtp_hostname", smtpConn.Hostname())
if smtpConn.Conn() != nil {
ev.Field("smtp_remote_addr", smtpConn.Conn().RemoteAddr().String())
}
return ev
}
func httpContext(r *http.Request) log.Context {

View file

@ -536,7 +536,7 @@ func (c *messageCache) ExpireMessages(topics ...string) error {
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil {
return err
}
}

View file

@ -33,16 +33,6 @@ import (
"heckel.io/ntfy/util"
)
/*
- MEDIUM fail2ban to work with ntfy log not nginx log
- HIGH Docs
- tiers
- api
- tokens
*/
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
@ -56,11 +46,11 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
closeChan chan bool
mu sync.Mutex
}
@ -134,24 +124,6 @@ const (
wsPongWait = 15 * time.Second
)
// Log tags
const (
tagStartup = "startup"
tagPublish = "publish"
tagSubscribe = "subscribe"
tagFirebase = "firebase"
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"
tagManager = "manager"
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
)
// New instantiates a new Server. It creates the cache and adds a Firebase
// subscriber (if configured).
func New(conf *Config) (*Server, error) {
@ -327,11 +299,11 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
s.handleError(w, r, v, err)
return
}
if log.IsTrace() {
logvr(v, r).Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
} else if log.IsDebug() {
logvr(v, r).Debug("HTTP request started")
ev := logvr(v, r)
if ev.IsTrace() {
ev.Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
} else if logvr(v, r).IsDebug() {
ev.Debug("HTTP request started")
}
logvr(v, r).
Timing(func() {
@ -344,8 +316,12 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, err error) {
httpErr, ok := err.(*errHTTP)
if !ok {
httpErr = errHTTPInternalError
}
isNormalError := strings.Contains(err.Error(), "i/o timeout") || util.Contains([]int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized}, httpErr.HTTPCode)
if websocket.IsWebSocketUpgrade(r) {
isNormalError := strings.Contains(err.Error(), "i/o timeout")
if isNormalError {
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
} else {
@ -354,22 +330,15 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
return // Do not attempt to write to upgraded connection
}
if matrixErr, ok := err.(*errMatrix); ok {
writeMatrixError(w, r, v, matrixErr)
if err := writeMatrixError(w, r, v, matrixErr); err != nil {
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Writing Matrix error failed")
}
return
}
httpErr, ok := err.(*errHTTP)
if !ok {
httpErr = errHTTPInternalError
}
isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest || httpErr.HTTPCode == http.StatusTooManyRequests
if isNormalError {
logvr(v, r).
Err(httpErr).
Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
logvr(v, r).Err(err).Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
} else {
logvr(v, r).
Err(httpErr).
Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
logvr(v, r).Err(err).Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
@ -629,7 +598,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
if cache {
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
}
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
@ -637,21 +608,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m.Message = emptyMessageBody
}
delayed := m.Time > time.Now().Unix()
logvrm(vRate, r, m).
ev := logvrm(vRate, r, m).
Tag(tagPublish).
Fields(log.Context{
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
"message_subscriber_rate_limited": vRate != v,
}).
Debug("Received message")
if log.IsTrace() {
logvrm(vRate, r, m).
Tag(tagPublish).
Field("message_body", util.MaybeMarshalJSON(m)).
Trace("Message body")
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
})
if ev.IsTrace() {
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
} else if ev.IsDebug() {
ev.Debug("Received message")
}
if !delayed {
if err := t.Publish(v, m); err != nil {
@ -1176,7 +1144,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}
err = g.Wait()
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Trace("WebSocket connection closed")
return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
}
@ -1309,161 +1277,6 @@ func (s *Server) topicFromID(id string) (*topic, error) {
return topics[0], nil
}
func (s *Server) execManager() {
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
// there is no mutex for the entire function.
// Expire visitors from rate visitors map
staleVisitors := 0
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for ip, v := range s.visitors {
if v.Stale() {
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
delete(s.visitors, ip)
staleVisitors++
}
}
}).
Field("stale_visitors", staleVisitors).
Debug("Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens and users
if s.userManager != nil {
log.
Tag(tagManager).
Timing(func() {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
}
}).
Debug("Removed expired tokens and users")
}
// Delete expired attachments
if s.fileCache != nil {
log.
Tag(tagManager).
Timing(func() {
ids, err := s.messageCache.AttachmentsExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
} else if len(ids) > 0 {
if log.Tag(tagManager).IsDebug() {
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
}
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired attachments to delete")
}
}).
Debug("Deleted expired attachments")
}
// Prune messages
log.
Tag(tagManager).
Timing(func() {
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
} else if len(expiredMessageIDs) > 0 {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
}
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired messages to delete")
}
}).
Debug("Pruned messages")
// Message count per topic
var messagesCached int
messageCounts, err := s.messageCache.MessageCounts()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
messageCounts = make(map[string]int) // Empty, so we can continue
}
for _, count := range messageCounts {
messagesCached += count
}
// Remove subscriptions without subscribers
var emptyTopics, subscribers int
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
ev := log.Tag(tagManager)
if ev.IsTrace() {
expiryMessage := ""
if subs == 0 {
expiryTime := time.Until(t.vRateExpires)
expiryMessage = ", expires in " + expiryTime.String()
}
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
}
msgs, exists := messageCounts[t.ID]
if t.Stale() && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
continue
}
subscribers += subs
}
}).
Debug("Removed %d empty topic(s)", emptyTopics)
// Mail stats
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
if s.smtpServerBackend != nil {
receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
}
var sentMailTotal, sentMailSuccess, sentMailFailure int64
if s.smtpSender != nil {
sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
}
// Print stats
s.mu.Lock()
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
s.mu.Unlock()
log.
Tag(tagManager).
Fields(log.Context{
"messages_published": messagesCount,
"messages_cached": messagesCached,
"topics_active": topicsCount,
"subscribers": subscribers,
"visitors": visitorsCount,
"emails_received": receivedMailTotal,
"emails_received_success": receivedMailSuccess,
"emails_received_failure": receivedMailFailure,
"emails_sent": sentMailTotal,
"emails_sent_success": sentMailSuccess,
"emails_sent_failure": sentMailFailure,
}).
Info("Server stats")
}
func (s *Server) runSMTPServer() error {
s.smtpServerBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpServerBackend)

View file

@ -246,7 +246,7 @@
# Logging options
#
# By default, ntfy logs to the console (stderr), with a "info" log level, and in a human-readable text format.
# By default, ntfy logs to the console (stderr), with an "info" log level, and in a human-readable text format.
# ntfy supports five different log levels, can also write to a file, log as JSON, and even supports granular
# log level overrides for easier debugging. Some options (log-level and log-level-overrides) can be hot reloaded
# by calling "kill -HUP $pid" or "systemctl reload ntfy".

View file

@ -100,6 +100,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Customer: true,
Subscription: u.Billing.StripeSubscriptionID != "",
Status: string(u.Billing.StripeSubscriptionStatus),
Interval: string(u.Billing.StripeSubscriptionInterval),
PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(),
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
}
@ -479,6 +480,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
if err := s.messageCache.ExpireMessages(topic); err != nil {
return err
}
s.pruneMessages()
}
return s.writeJSON(w, newSuccessResponse())
}
@ -505,6 +507,7 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *vi
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
go s.pruneMessages()
return nil
}

View file

@ -669,8 +669,8 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
require.Equal(t, 200, rr.Code)
// Verify that messages and attachments were deleted
// This does not explicitly call the manager!
time.Sleep(time.Second)
s.execManager()
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
require.Nil(t, err)
@ -804,10 +804,27 @@ func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, int64(1), account.Stats.Messages) // Is not reset!
// Publish another message
rr = request(t, s, "POST", "/mytopic", "hi", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Verify that message stats were persisted
time.Sleep(300 * time.Millisecond)
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, int64(0), u.Stats.Messages) // v.EnqueueUserStats had run!
require.Equal(t, int64(2), u.Stats.Messages) // v.EnqueueUserStats had run!
// Stats keep counting
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
account, _ = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, int64(2), account.Stats.Messages) // Is not reset!
}

View file

@ -8,7 +8,6 @@ import (
"firebase.google.com/go/v4/messaging"
"fmt"
"google.golang.org/api/option"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"strings"
@ -46,16 +45,15 @@ func (c *firebaseClient) Send(v *visitor, m *message) error {
if err != nil {
return err
}
if log.Tag(tagFirebase).IsTrace() {
logvm(v, m).
Tag(tagFirebase).
Field("firebase_message", util.MaybeMarshalJSON(fbm)).
Trace("Firebase message")
ev := logvm(v, m).Tag(tagFirebase)
if ev.IsTrace() {
ev.Field("firebase_message", util.MaybeMarshalJSON(fbm)).Trace("Firebase message")
}
err = c.sender.Send(fbm)
if err == errFirebaseQuotaExceeded {
logvm(v, m).
Tag(tagFirebase).
Err(err).
Warn("Firebase quota exceeded (likely for topic), temporarily denying Firebase access to visitor")
v.FirebaseTemporarilyDeny()
}

175
server/server_manager.go Normal file
View file

@ -0,0 +1,175 @@
package server
import (
"heckel.io/ntfy/log"
"strings"
"time"
)
func (s *Server) execManager() {
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
// there is no mutex for the entire function.
// Prune all the things
s.pruneVisitors()
s.pruneTokens()
s.pruneAttachments()
s.pruneMessages()
// Message count per topic
var messagesCached int
messageCounts, err := s.messageCache.MessageCounts()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
messageCounts = make(map[string]int) // Empty, so we can continue
}
for _, count := range messageCounts {
messagesCached += count
}
// Remove subscriptions without subscribers
var emptyTopics, subscribers int
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
ev := log.Tag(tagManager)
if ev.IsTrace() {
expiryMessage := ""
if subs == 0 {
expiryTime := time.Until(t.vRateExpires)
expiryMessage = ", expires in " + expiryTime.String()
}
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
}
msgs, exists := messageCounts[t.ID]
if t.Stale() && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
continue
}
subscribers += subs
}
}).
Debug("Removed %d empty topic(s)", emptyTopics)
// Mail stats
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
if s.smtpServerBackend != nil {
receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
}
var sentMailTotal, sentMailSuccess, sentMailFailure int64
if s.smtpSender != nil {
sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
}
// Print stats
s.mu.Lock()
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
s.mu.Unlock()
log.
Tag(tagManager).
Fields(log.Context{
"messages_published": messagesCount,
"messages_cached": messagesCached,
"topics_active": topicsCount,
"subscribers": subscribers,
"visitors": visitorsCount,
"emails_received": receivedMailTotal,
"emails_received_success": receivedMailSuccess,
"emails_received_failure": receivedMailFailure,
"emails_sent": sentMailTotal,
"emails_sent_success": sentMailSuccess,
"emails_sent_failure": sentMailFailure,
}).
Info("Server stats")
}
func (s *Server) pruneVisitors() {
staleVisitors := 0
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for ip, v := range s.visitors {
if v.Stale() {
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
delete(s.visitors, ip)
staleVisitors++
}
}
}).
Field("stale_visitors", staleVisitors).
Debug("Deleted %d stale visitor(s)", staleVisitors)
}
func (s *Server) pruneTokens() {
if s.userManager != nil {
log.
Tag(tagManager).
Timing(func() {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
}
}).
Debug("Removed expired tokens and users")
}
}
func (s *Server) pruneAttachments() {
if s.fileCache == nil {
return
}
log.
Tag(tagManager).
Timing(func() {
ids, err := s.messageCache.AttachmentsExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
} else if len(ids) > 0 {
if log.Tag(tagManager).IsDebug() {
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
}
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired attachments to delete")
}
}).
Debug("Deleted expired attachments")
}
func (s *Server) pruneMessages() {
log.
Tag(tagManager).
Timing(func() {
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
} else if len(expiredMessageIDs) > 0 {
if s.fileCache != nil {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
}
}
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired messages to delete")
}
}).
Debug("Pruned messages")
}

View file

@ -0,0 +1,28 @@
package server
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
// Publish a message
rr := request(t, s, "POST", "/mytopic", "hi", nil)
require.Equal(t, 200, rr.Code)
m := toMessage(t, rr.Body.String())
// Expire message
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
// Does not panic
s.pruneMessages()
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, errMessageNotFound, err)
}

View file

@ -29,7 +29,7 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
vRate = topicCountsAgainst
}
r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)

View file

@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
return err
}
for _, tier := range tiers {
priceStr, ok := prices[tier.StripePriceID]
if tier.StripePriceID == "" || !ok {
priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID]
if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices!
continue
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
Price: priceStr,
Code: tier.Code,
Name: tier.Name,
Prices: &apiAccountBillingPrices{
Month: priceMonth,
Year: priceYear,
},
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier),
Messages: tier.MessageLimit,
@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
} else if tier.StripePriceID == "" {
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
With(tier).
Fields(log.Context{
"stripe_price_id": priceID,
"stripe_subscription_interval": req.Interval,
}).
Tag(tagStripe).
Info("Creating Stripe checkout flow")
var stripeCustomerID *string
@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
AllowPromotionCodes: stripe.Bool(true),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
}
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
Tag(tagStripe).
Fields(log.Context{
"stripe_customer_id": sess.Customer.ID,
"stripe_price_id": priceID,
"stripe_subscription_id": sub.ID,
"stripe_subscription_status": string(sub.Status),
"stripe_subscription_interval": string(interval),
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
}).
Info("Stripe checkout flow succeeded, updating user tier and subscription")
@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
"new_tier_stripe_price_id": priceID,
"new_tier_stripe_subscription_interval": req.Interval,
// Other stripe_* fields filled by visitor context
}).
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
},
},
}
@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil {
return err
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil {
logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe")
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"stripe_webhook_type": event.Type,
"stripe_customer_id": ev.Customer,
"stripe_price_id": priceID,
"stripe_subscription_id": ev.ID,
"stripe_subscription_status": ev.Status,
"stripe_subscription_interval": interval,
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
"stripe_subscription_cancel_at": ev.CancelAt,
"stripe_price_id": priceID,
}).
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
userFn := func() (*user.User, error) {
@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request,
Tag(tagStripe).
Field("stripe_webhook_type", event.Type).
Info("Subscription deleted, downgrading to unpaid tier")
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationLimit
@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
}).
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
// in memory, and ultimately for the web app to display the price table.
func (s *Server) fetchStripePrices() (map[string]string, error) {
func (s *Server) fetchStripePrices() (map[string]int64, error) {
log.Debug("Caching prices from Stripe API")
priceMap := make(map[string]string)
priceMap := make(map[string]int64)
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
if err != nil {
log.Warn("Fetching Stripe prices failed: %s", err.Error())
return nil, err
}
for _, p := range prices {
if p.UnitAmount%100 == 0 {
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
priceMap[p.ID] = p.UnitAmount
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
}
return priceMap, nil

View file

@ -37,7 +37,9 @@ func TestPayments_Tiers(t *testing.T) {
On("ListPrices", mock.Anything).
Return([]*stripe.Price{
{ID: "price_123", UnitAmount: 500},
{ID: "price_124", UnitAmount: 5000},
{ID: "price_456", UnitAmount: 1000},
{ID: "price_457", UnitAmount: 10000},
{ID: "price_999", UnitAmount: 9999},
}, nil)
@ -58,7 +60,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute,
StripePriceID: "price_123",
StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_444",
@ -71,7 +74,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour,
StripePriceID: "price_456",
StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
}))
response := request(t, s, "GET", "/v1/tiers", "", nil)
require.Equal(t, 200, response.Code)
@ -98,6 +102,8 @@ func TestPayments_Tiers(t *testing.T) {
require.Equal(t, "pro", tier.Code)
require.Equal(t, "Pro", tier.Name)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(500), tier.Prices.Month)
require.Equal(t, int64(5000), tier.Prices.Year)
require.Equal(t, int64(777), tier.Limits.Reservations)
require.Equal(t, int64(1000), tier.Limits.Messages)
require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
@ -109,6 +115,8 @@ func TestPayments_Tiers(t *testing.T) {
tier = tiers[2]
require.Equal(t, "business", tier.Code)
require.Equal(t, "Business", tier.Name)
require.Equal(t, int64(1000), tier.Prices.Month)
require.Equal(t, int64(10000), tier.Prices.Year)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(777333), tier.Limits.Reservations)
require.Equal(t, int64(2000), tier.Limits.Messages)
@ -136,14 +144,14 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
@ -172,9 +180,9 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -187,7 +195,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
@ -214,9 +222,9 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -267,7 +275,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1,
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessageExpiryDuration: time.Hour,
@ -298,7 +306,12 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
Items: &stripe.SubscriptionItemList{
Data: []*stripe.SubscriptionItem{
{
Price: &stripe.Price{ID: "price_1234"},
Price: &stripe.Price{
ID: "price_1234",
Recurring: &stripe.PriceRecurring{
Interval: stripe.PriceRecurringIntervalMonth,
},
},
},
},
},
@ -333,6 +346,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "", u.Billing.StripeCustomerID)
require.Equal(t, "", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
@ -349,6 +363,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages)
@ -423,7 +438,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1",
Code: "starter",
StripePriceID: "price_1234", // !
StripeMonthlyPriceID: "price_1234", // !
ReservationLimit: 1, // !
MessageLimit: 100,
MessageExpiryDuration: time.Hour,
@ -435,7 +450,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_2",
Code: "pro",
StripePriceID: "price_1111", // !
StripeMonthlyPriceID: "price_1111", // !
ReservationLimit: 3, // !
MessageLimit: 200,
MessageExpiryDuration: time.Hour,
@ -457,6 +472,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0),
}
@ -499,9 +515,10 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted
r, err := s.userManager.Reservations("phil")
@ -546,10 +563,10 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1",
Code: "pro",
StripePriceID: "price_1234",
ReservationLimit: 1,
ID: "ti_1",
Code: "pro",
StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -562,6 +579,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(0, 0),
}))
@ -615,11 +633,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
stripeMock.
On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String("someid_123"),
Price: stripe.String("price_456"),
Price: stripe.String("price_457"),
},
},
}).
@ -627,14 +645,16 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_456",
Code: "business",
StripePriceID: "price_456",
ID: "ti_456",
Code: "business",
StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -644,7 +664,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
}))
// Call endpoint to change subscription
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@ -795,7 +815,10 @@ const subscriptionUpdatedEventJSON = `
"data": [
{
"price": {
"id": "price_1234"
"id": "price_1234",
"recurring": {
"interval": "year"
}
}
}
]
@ -818,7 +841,10 @@ const subscriptionDeletedEventJSON = `
"data": [
{
"price": {
"id": "price_1234"
"id": "price_1234",
"recurring": {
"interval": "month"
}
}
}
]

View file

@ -149,6 +149,8 @@ func TestServer_PublishAndSubscribe(t *testing.T) {
require.Equal(t, "", messages[1].Title)
require.Equal(t, 0, messages[1].Priority)
require.Nil(t, messages[1].Tags)
require.True(t, time.Now().Add(12*time.Hour-5*time.Second).Unix() < messages[1].Expires)
require.True(t, time.Now().Add(12*time.Hour+5*time.Second).Unix() > messages[1].Expires)
require.Equal(t, messageEvent, messages[2].Event)
require.Equal(t, "mytopic", messages[2].Topic)
@ -287,6 +289,7 @@ func TestServer_PublishNoCache(t *testing.T) {
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
require.Equal(t, "this message is not cached", msg.Message)
require.Equal(t, int64(0), msg.Expires)
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
messages := toMessages(t, response.Body.String())
@ -324,6 +327,18 @@ func TestServer_PublishAt(t *testing.T) {
require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
}
func TestServer_PublishAt_Expires(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{
"In": "2 days",
})
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.True(t, m.Expires > time.Now().Add(12*time.Hour+48*time.Hour-time.Minute).Unix())
require.True(t, m.Expires < time.Now().Add(12*time.Hour+48*time.Hour+time.Minute).Unix())
}
func TestServer_PublishAtWithCacheError(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
@ -1486,7 +1501,7 @@ func TestServer_PublishAttachmentTooLargeBodyVisitorAttachmentTotalSizeLimit(t *
c.VisitorAttachmentTotalSizeLimit = 10000
s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", util.RandomString(5000), nil)
response := request(t, s, "PUT", "/mytopic", "text file!"+util.RandomString(4990), nil)
msg := toMessage(t, response.Body.String())
require.Equal(t, 200, response.Code)
require.Equal(t, "You received a file: attachment.txt", msg.Message)

View file

@ -37,18 +37,18 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
return err
}
auth := smtp.PlainAuth("", s.config.SMTPSenderUser, s.config.SMTPSenderPass, host)
logvm(v, m).
ev := logvm(v, m).
Tag(tagEmail).
Fields(log.Context{
"email_via": s.config.SMTPSenderAddr,
"email_user": s.config.SMTPSenderUser,
"email_to": to,
}).
Debug("Sending email")
logvm(v, m).
Tag(tagEmail).
Field("email_body", message).
Trace("Email body")
})
if ev.IsTrace() {
ev.Field("email_body", message).Trace("Sending email")
} else if ev.IsDebug() {
ev.Debug("Sending email")
}
return smtp.SendMail(s.config.SMTPSenderAddr, auth, s.config.SMTPSenderFrom, []string{to}, []byte(message))
})
}

View file

@ -2,6 +2,7 @@ package server
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"github.com/emersion/go-smtp"
@ -21,9 +22,14 @@ var (
errInvalidAddress = errors.New("invalid address")
errInvalidTopic = errors.New("invalid topic")
errTooManyRecipients = errors.New("too many recipients")
errMultipartNestedTooDeep = errors.New("multipart message nested too deep")
errUnsupportedContentType = errors.New("unsupported content type")
)
const (
maxMultipartDepth = 2
)
// smtpBackend implements SMTP server methods.
type smtpBackend struct {
config *Config
@ -33,6 +39,9 @@ type smtpBackend struct {
mu sync.Mutex
}
var _ smtp.Backend = (*smtpBackend)(nil)
var _ smtp.Session = (*smtpSession)(nil)
func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
return &smtpBackend{
config: conf,
@ -40,14 +49,9 @@ func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Reques
}
}
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, _ string) (smtp.Session, error) {
logem(state).Debug("Incoming mail, login with user %s", username)
return &smtpSession{backend: b, state: state}, nil
}
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
logem(state).Debug("Incoming mail, anonymous login")
return &smtpSession{backend: b, state: state}, nil
func (b *smtpBackend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
logem(conn).Debug("Incoming mail")
return &smtpSession{backend: b, conn: conn}, nil
}
func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
@ -59,24 +63,26 @@ func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
// smtpSession is returned after EHLO.
type smtpSession struct {
backend *smtpBackend
state *smtp.ConnectionState
conn *smtp.Conn
topic string
token string
mu sync.Mutex
}
func (s *smtpSession) AuthPlain(username, password string) error {
logem(s.state).Debug("AUTH PLAIN (with username %s)", username)
func (s *smtpSession) AuthPlain(username, _ string) error {
logem(s.conn).Field("smtp_username", username).Debug("AUTH PLAIN (with username %s)", username)
return nil
}
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts)
func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error {
logem(s.conn).Field("smtp_mail_from", from).Debug("MAIL FROM: %s", from)
return nil
}
func (s *smtpSession) Rcpt(to string) error {
logem(s.state).Debug("RCPT TO: %s", to)
logem(s.conn).Field("smtp_rcpt_to", to).Debug("RCPT TO: %s", to)
return s.withFailCount(func() error {
token := ""
conf := s.backend.config
addressList, err := mail.ParseAddressList(to)
if err != nil {
@ -88,18 +94,27 @@ func (s *smtpSession) Rcpt(to string) error {
if !strings.HasSuffix(to, "@"+conf.SMTPServerDomain) {
return errInvalidDomain
}
// Remove @ntfy.sh from end of email
to = strings.TrimSuffix(to, "@"+conf.SMTPServerDomain)
if conf.SMTPServerAddrPrefix != "" {
if !strings.HasPrefix(to, conf.SMTPServerAddrPrefix) {
return errInvalidAddress
}
// remove ntfy- from beginning of email
to = strings.TrimPrefix(to, conf.SMTPServerAddrPrefix)
}
// If email contains token, split topic and token
if strings.Contains(to, "+") {
parts := strings.Split(to, "+")
to = parts[0]
token = parts[1]
}
if !topicRegex.MatchString(to) {
return errInvalidTopic
}
s.mu.Lock()
s.topic = to
s.token = token
s.mu.Unlock()
return nil
})
@ -112,17 +127,17 @@ func (s *smtpSession) Data(r io.Reader) error {
if err != nil {
return err
}
ev := logem(s.state).Tag(tagSMTP)
ev := logem(s.conn)
if ev.IsTrace() {
ev.Field("smtp_data", string(b)).Trace("DATA")
} else if ev.IsDebug() {
ev.Debug("DATA: %d byte(s)", len(b))
ev.Field("smtp_data_len", len(b)).Debug("DATA")
}
msg, err := mail.ReadMessage(bytes.NewReader(b))
if err != nil {
return err
}
body, err := readMailBody(msg)
body, err := readMailBody(msg.Body, msg.Header)
if err != nil {
return err
}
@ -156,11 +171,10 @@ func (s *smtpSession) Data(r io.Reader) error {
func (s *smtpSession) publishMessage(m *message) error {
// Extract remote address (for rate limiting)
remoteAddr, _, err := net.SplitHostPort(s.state.RemoteAddr.String())
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
if err != nil {
remoteAddr = s.state.RemoteAddr.String()
remoteAddr = s.conn.Conn().RemoteAddr().String()
}
// Call HTTP handler with fake HTTP request
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
req, err := http.NewRequest("POST", url, strings.NewReader(m.Message))
@ -173,6 +187,9 @@ func (s *smtpSession) publishMessage(m *message) error {
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
if s.token != "" {
req.Header.Add("Authorization", "Bearer "+s.token)
}
rr := httptest.NewRecorder()
s.backend.handler(rr, req)
if rr.Code != http.StatusOK {
@ -198,54 +215,58 @@ func (s *smtpSession) withFailCount(fn func() error) error {
if err != nil {
// Almost all of these errors are parse errors, and user input errors.
// We do not want to spam the log with WARN messages.
logem(s.state).Err(err).Debug("Incoming mail error")
logem(s.conn).Err(err).Debug("Incoming mail error")
s.backend.failure++
}
return err
}
func readMailBody(msg *mail.Message) (string, error) {
if msg.Header.Get("Content-Type") == "" {
return readPlainTextMailBody(msg)
func readMailBody(body io.Reader, header mail.Header) (string, error) {
if header.Get("Content-Type") == "" {
return readPlainTextMailBody(body, header.Get("Content-Transfer-Encoding"))
}
contentType, params, err := mime.ParseMediaType(msg.Header.Get("Content-Type"))
contentType, params, err := mime.ParseMediaType(header.Get("Content-Type"))
if err != nil {
return "", err
}
if contentType == "text/plain" {
return readPlainTextMailBody(msg)
} else if strings.HasPrefix(contentType, "multipart/") {
return readMultipartMailBody(msg, params)
if strings.ToLower(contentType) == "text/plain" {
return readPlainTextMailBody(body, header.Get("Content-Transfer-Encoding"))
} else if strings.HasPrefix(strings.ToLower(contentType), "multipart/") {
return readMultipartMailBody(body, params, 0)
}
return "", errUnsupportedContentType
}
func readPlainTextMailBody(msg *mail.Message) (string, error) {
body, err := io.ReadAll(msg.Body)
if err != nil {
return "", err
func readMultipartMailBody(body io.Reader, params map[string]string, depth int) (string, error) {
if depth >= maxMultipartDepth {
return "", errMultipartNestedTooDeep
}
return string(body), nil
}
func readMultipartMailBody(msg *mail.Message, params map[string]string) (string, error) {
mr := multipart.NewReader(msg.Body, params["boundary"])
mr := multipart.NewReader(body, params["boundary"])
for {
part, err := mr.NextPart()
if err != nil { // may be io.EOF
return "", err
}
partContentType, _, err := mime.ParseMediaType(part.Header.Get("Content-Type"))
partContentType, partParams, err := mime.ParseMediaType(part.Header.Get("Content-Type"))
if err != nil {
return "", err
}
if partContentType != "text/plain" {
continue
if strings.ToLower(partContentType) == "text/plain" {
return readPlainTextMailBody(part, part.Header.Get("Content-Transfer-Encoding"))
} else if strings.HasPrefix(strings.ToLower(partContentType), "multipart/") {
return readMultipartMailBody(part, partParams, depth+1)
}
body, err := io.ReadAll(part)
if err != nil {
return "", err
}
return string(body), nil
// Continue with next part
}
}
func readPlainTextMailBody(reader io.Reader, transferEncoding string) (string, error) {
if strings.ToLower(transferEncoding) == "base64" {
reader = base64.NewDecoder(base64.StdEncoding, reader)
}
body, err := io.ReadAll(reader)
if err != nil {
return "", err
}
return string(body), nil
}

View file

@ -1,16 +1,23 @@
package server
import (
"bufio"
"github.com/emersion/go-smtp"
"github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
func TestSmtpBackend_Multipart(t *testing.T) {
email := `MIME-Version: 1.0
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
MIME-Version: 1.0
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
@ -28,20 +35,25 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
--000000000000f3320b05d42915c9--`
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
--000000000000f3320b05d42915c9--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_MultipartNoBody(t *testing.T) {
email := `MIME-Version: 1.0
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-emailtest@ntfy.sh
DATA
MIME-Version: 1.0
Date: Tue, 28 Dec 2021 01:33:34 +0100
Message-ID: <CAAvm7ABCDsi9vsuu0WTRXzZQBC8dXrDOLT8iCWdqrsmg@mail.gmail.com>
Subject: This email has a subject but no body
@ -59,20 +71,25 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr"><br></div>
--000000000000bcf4a405d429f8d4--`
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
--000000000000bcf4a405d429f8d4--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/emailtest", r.URL.Path)
require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
})
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Plaintext(t *testing.T) {
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: mytopic@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
From: Phil <phil@example.com>
@ -80,56 +97,68 @@ To: mytopic@ntfy.sh
Content-Type: text/plain; charset="UTF-8"
what's up
.
`
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
email := `Subject: Very short mail
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: mytopic@ntfy.sh
DATA
Subject: Very short mail
what's up
.
`
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Plaintext_EncodedSubject(t *testing.T) {
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Subject: =?UTF-8?B?VGhyZWUgc2FudGFzIPCfjoXwn46F8J+OhQ==?=
From: Phil <phil@example.com>
To: ntfy-mytopic@ntfy.sh
Content-Type: text/plain; charset="UTF-8"
what's up
.
`
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
})
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Plaintext_TooLongTruncate(t *testing.T) {
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: mytopic@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
From: Phil <phil@example.com>
@ -148,60 +177,61 @@ so i'm gonna fill the rest of this with AAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAa
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
that should do it
.
`
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
expected := `you know this is a string.
it's a long string.
it's supposed to be longer than the max message length
@ -214,68 +244,71 @@ so i'm gonna fill the rest of this with AAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAa
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
......................................................................
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBB`
require.Equal(t, 4096, len(expected)) // Sanity check
require.Equal(t, expected, readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email)))
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Unsupported(t *testing.T) {
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
From: Phil <phil@example.com>
@ -283,34 +316,254 @@ To: mytopic@ntfy.sh
Content-Type: text/SOMETHINGELSE
what's up
.
`
conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
// Nothing.
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
t.Fatal("This should not be called")
})
conf.SMTPServerAddrPrefix = ""
session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "554 5.0.0 Error: transaction failed, blame it on the weather: unsupported content type")
}
func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
conf := newTestConfig(t)
func TestSmtpBackend_InvalidAddress(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: unsupported@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Subject: and one more
From: Phil <phil@example.com>
To: mytopic@ntfy.sh
Content-Type: text/plain
what's up
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
t.Fatal("This should not be called")
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "451 4.0.0 invalid address")
}
func TestSmtpBackend_Base64Body(t *testing.T) {
email := `EHLO example.com
MAIL FROM: test@mydomain.me
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
Content-Type: multipart/mixed; boundary="===============2138658284696597373=="
MIME-Version: 1.0
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
From: =?utf-8?q?Robbie?= <test@mydomain.me>
To: test@mydomain.me
Date: Thu, 16 Feb 2023 01:04:00 -0000
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
This is a multi-part message in MIME format.
--===============2138658284696597373==
Content-Type: text/plain; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
--===============2138658284696597373==
Content-Type: text/html; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
--===============2138658284696597373==--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local", r.Header.Get("Title"))
require.Equal(t, "This is a test message from TrueNAS CORE.", readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_NestedMultipartBase64(t *testing.T) {
email := `EHLO example.com
MAIL FROM: test@mydomain.me
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
Content-Type: multipart/mixed; boundary="===============2138658284696597373=="
MIME-Version: 1.0
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
From: =?utf-8?q?Robbie?= <test@mydomain.me>
To: test@mydomain.me
Date: Thu, 16 Feb 2023 01:04:00 -0000
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
This is a multi-part message in MIME format.
--===============2138658284696597373==
Content-Type: multipart/alternative; boundary="===============2233989480071754745=="
MIME-Version: 1.0
--===============2233989480071754745==
Content-Type: text/plain; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
--===============2233989480071754745==
Content-Type: text/html; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
--===============2233989480071754745==--
--===============2138658284696597373==--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local", r.Header.Get("Title"))
require.Equal(t, "This is a test message from TrueNAS CORE.", readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_NestedMultipartTooDeep(t *testing.T) {
email := `EHLO example.com
MAIL FROM: test@mydomain.me
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
Content-Type: multipart/mixed; boundary="===============1=="
MIME-Version: 1.0
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
From: =?utf-8?q?Robbie?= <test@mydomain.me>
To: test@mydomain.me
Date: Thu, 16 Feb 2023 01:04:00 -0000
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
This is a multi-part message in MIME format.
--===============1==
Content-Type: multipart/alternative; boundary="===============2=="
MIME-Version: 1.0
--===============2==
Content-Type: multipart/alternative; boundary="===============3=="
MIME-Version: 1.0
--===============3==
Content-Type: text/plain; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
--===============3==
Content-Type: text/html; charset="utf-8"
MIME-Version: 1.0
Content-Transfer-Encoding: base64
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
--===============3==--
--===============2==--
--===============1==--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
t.Fatal("This should not be called")
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "554 5.0.0 Error: transaction failed, blame it on the weather: multipart message nested too deep")
}
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-mytopic+tk_KLORUqSqvNRLpY11DfkHVbHu9NGG2@ntfy.sh
DATA
Subject: Very short mail
what's up
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "Bearer tk_KLORUqSqvNRLpY11DfkHVbHu9NGG2", r.Header.Get("Authorization"))
require.Equal(t, "what's up", readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
conf = newTestConfig(t)
conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-"
backend := newMailBackend(conf, handler)
return conf, backend
}
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
ip, err := net.ResolveIPAddr("ip", remoteAddr)
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
return &smtp.ConnectionState{
Hostname: "myhostname",
LocalAddr: ip,
RemoteAddr: ip,
s = smtp.NewServer(backend)
s.Domain = conf.SMTPServerDomain
s.AllowInsecureAuth = true
go func() {
require.Nil(t, s.Serve(l))
}()
c, err = net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
scanner = bufio.NewScanner(c)
return
}
func writeAndReadUntilLine(t *testing.T, email string, conn net.Conn, scanner *bufio.Scanner, expectedLine string) {
_, err := io.WriteString(conn, email)
require.Nil(t, err)
readUntilLine(t, conn, scanner, expectedLine)
}
func readUntilLine(t *testing.T, conn net.Conn, scanner *bufio.Scanner, expectedLine string) {
cancelChan := make(chan bool)
go func() {
select {
case <-cancelChan:
case <-time.After(3 * time.Second):
conn.Close()
t.Error("Failed waiting for expected output")
}
}()
var output string
for scanner.Scan() {
text := scanner.Text()
if strings.TrimSpace(text) == expectedLine {
cancelChan <- true
return
}
output += text + "\n"
//fmt.Println(text)
}
t.Fatalf("Expected line '%s' not found in output:\n%s", expectedLine, output)
}

View file

@ -309,6 +309,7 @@ type apiAccountBilling struct {
Customer bool `json:"customer"`
Subscription bool `json:"subscription"`
Status string `json:"status,omitempty"`
Interval string `json:"interval,omitempty"`
PaidUntil int64 `json:"paid_until,omitempty"`
CancelAt int64 `json:"cancel_at,omitempty"`
}
@ -343,11 +344,16 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"`
}
type apiAccountBillingPrices struct {
Month int64 `json:"month"`
Year int64 `json:"year"`
}
type apiAccountBillingTier struct {
Code string `json:"code,omitempty"`
Name string `json:"name,omitempty"`
Price string `json:"price,omitempty"`
Limits *apiAccountLimits `json:"limits"`
Code string `json:"code,omitempty"`
Name string `json:"name,omitempty"`
Prices *apiAccountBillingPrices `json:"prices,omitempty"`
Limits *apiAccountLimits `json:"limits"`
}
type apiAccountBillingSubscriptionCreateResponse struct {
@ -355,7 +361,8 @@ type apiAccountBillingSubscriptionCreateResponse struct {
}
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
Tier string `json:"tier"`
Interval string `json:"interval"`
}
type apiAccountBillingPortalRedirectResponse struct {
@ -385,7 +392,10 @@ type apiStripeSubscriptionUpdatedEvent struct {
Items *struct {
Data []*struct {
Price *struct {
ID string `json:"id"`
ID string `json:"id"`
Recurring *struct {
Interval string `json:"interval"`
} `json:"recurring"`
} `json:"price"`
} `json:"data"`
} `json:"items"`

View file

@ -1,13 +1,11 @@
package server
import (
"heckel.io/ntfy/util"
"io"
"net/http"
"net/netip"
"strings"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
)
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@ -74,7 +72,7 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
if err != nil {
ip = netip.IPv4Unspecified()
if remoteAddr != "@" || !behindProxy { // RemoteAddr is @ when unix socket is used
log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", remoteAddr)
}
}
}
@ -85,7 +83,7 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
logr(r).Err(err).Error("invalid IP address %s received in X-Forwarded-For header", ip)
// Fall back to regular remote address if X-Forwarded-For is damaged
} else {
ip = realIP

View file

@ -159,8 +159,9 @@ func (v *visitor) contextNoLock() log.Context {
fields["user_id"] = v.user.ID
fields["user_name"] = v.user.Name
if v.user.Tier != nil {
fields["tier_id"] = v.user.Tier.ID
fields["tier_name"] = v.user.Tier.Name
for field, value := range v.user.Tier.Context() {
fields[field] = value
}
}
if v.user.Billing.StripeCustomerID != "" {
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
@ -329,9 +330,13 @@ func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u
v.user = u // u may be nil!
if shouldResetLimiters {
v.resetLimitersNoLock(0, 0, true)
var messages, emails int64
if u != nil {
messages, emails = u.Stats.Messages, u.Stats.Emails
}
v.resetLimitersNoLock(messages, emails, true)
}
}