Merge branch 'main' into user-account
This commit is contained in:
commit
4ab450309f
81 changed files with 4094 additions and 13687 deletions
|
@ -61,7 +61,7 @@ var (
|
|||
|
||||
// DefaultDisallowedTopics defines the topics that are forbidden, because they are used elsewhere. This array can be
|
||||
// extended using the server.yml config. If updated, also update in Android and web app.
|
||||
DefaultDisallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "signup", "login"}
|
||||
DefaultDisallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "signup", "login", "v1"}
|
||||
)
|
||||
|
||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||
|
|
|
@ -11,19 +11,38 @@ import (
|
|||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Log tags
|
||||
const (
|
||||
tagStartup = "startup"
|
||||
tagHTTP = "http"
|
||||
tagPublish = "publish"
|
||||
tagSubscribe = "subscribe"
|
||||
tagFirebase = "firebase"
|
||||
tagSMTP = "smtp" // Receive email
|
||||
tagEmail = "email" // Send email
|
||||
tagFileCache = "file_cache"
|
||||
tagMessageCache = "message_cache"
|
||||
tagStripe = "stripe"
|
||||
tagAccount = "account"
|
||||
tagManager = "manager"
|
||||
tagResetter = "resetter"
|
||||
tagWebsocket = "websocket"
|
||||
tagMatrix = "matrix"
|
||||
)
|
||||
|
||||
// logr creates a new log event with HTTP request fields
|
||||
func logr(r *http.Request) *log.Event {
|
||||
return log.Fields(httpContext(r))
|
||||
return log.Tag(tagHTTP).Fields(httpContext(r)) // Tag may be overwritten
|
||||
}
|
||||
|
||||
// logr creates a new log event with visitor fields
|
||||
// logv creates a new log event with visitor fields
|
||||
func logv(v *visitor) *log.Event {
|
||||
return log.With(v)
|
||||
}
|
||||
|
||||
// logr creates a new log event with HTTP request and visitor fields
|
||||
// logvr creates a new log event with HTTP request and visitor fields
|
||||
func logvr(v *visitor, r *http.Request) *log.Event {
|
||||
return logv(v).Fields(httpContext(r))
|
||||
return logr(r).With(v)
|
||||
}
|
||||
|
||||
// logvrm creates a new log event with HTTP request, visitor fields and message fields
|
||||
|
@ -37,13 +56,12 @@ func logvm(v *visitor, m *message) *log.Event {
|
|||
}
|
||||
|
||||
// logem creates a new log event with email fields
|
||||
func logem(state *smtp.ConnectionState) *log.Event {
|
||||
return log.
|
||||
Tag(tagSMTP).
|
||||
Fields(log.Context{
|
||||
"smtp_hostname": state.Hostname,
|
||||
"smtp_remote_addr": state.RemoteAddr.String(),
|
||||
})
|
||||
func logem(smtpConn *smtp.Conn) *log.Event {
|
||||
ev := log.Tag(tagSMTP).Field("smtp_hostname", smtpConn.Hostname())
|
||||
if smtpConn.Conn() != nil {
|
||||
ev.Field("smtp_remote_addr", smtpConn.Conn().RemoteAddr().String())
|
||||
}
|
||||
return ev
|
||||
}
|
||||
|
||||
func httpContext(r *http.Request) log.Context {
|
||||
|
|
|
@ -536,7 +536,7 @@ func (c *messageCache) ExpireMessages(topics ...string) error {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil {
|
||||
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
255
server/server.go
255
server/server.go
|
@ -33,16 +33,6 @@ import (
|
|||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
- MEDIUM fail2ban to work with ntfy log not nginx log
|
||||
- HIGH Docs
|
||||
- tiers
|
||||
- api
|
||||
- tokens
|
||||
|
||||
*/
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
type Server struct {
|
||||
config *Config
|
||||
|
@ -56,11 +46,11 @@ type Server struct {
|
|||
visitors map[string]*visitor // ip:<ip> or user:<user>
|
||||
firebaseClient *firebaseClient
|
||||
messages int64
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
closeChan chan bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
@ -134,24 +124,6 @@ const (
|
|||
wsPongWait = 15 * time.Second
|
||||
)
|
||||
|
||||
// Log tags
|
||||
const (
|
||||
tagStartup = "startup"
|
||||
tagPublish = "publish"
|
||||
tagSubscribe = "subscribe"
|
||||
tagFirebase = "firebase"
|
||||
tagSMTP = "smtp" // Receive email
|
||||
tagEmail = "email" // Send email
|
||||
tagFileCache = "file_cache"
|
||||
tagMessageCache = "message_cache"
|
||||
tagStripe = "stripe"
|
||||
tagAccount = "account"
|
||||
tagManager = "manager"
|
||||
tagResetter = "resetter"
|
||||
tagWebsocket = "websocket"
|
||||
tagMatrix = "matrix"
|
||||
)
|
||||
|
||||
// New instantiates a new Server. It creates the cache and adds a Firebase
|
||||
// subscriber (if configured).
|
||||
func New(conf *Config) (*Server, error) {
|
||||
|
@ -327,11 +299,11 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
s.handleError(w, r, v, err)
|
||||
return
|
||||
}
|
||||
|
||||
if log.IsTrace() {
|
||||
logvr(v, r).Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
|
||||
} else if log.IsDebug() {
|
||||
logvr(v, r).Debug("HTTP request started")
|
||||
ev := logvr(v, r)
|
||||
if ev.IsTrace() {
|
||||
ev.Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
|
||||
} else if logvr(v, r).IsDebug() {
|
||||
ev.Debug("HTTP request started")
|
||||
}
|
||||
logvr(v, r).
|
||||
Timing(func() {
|
||||
|
@ -344,8 +316,12 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, err error) {
|
||||
httpErr, ok := err.(*errHTTP)
|
||||
if !ok {
|
||||
httpErr = errHTTPInternalError
|
||||
}
|
||||
isNormalError := strings.Contains(err.Error(), "i/o timeout") || util.Contains([]int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized}, httpErr.HTTPCode)
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
||||
if isNormalError {
|
||||
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
|
||||
} else {
|
||||
|
@ -354,22 +330,15 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
|
|||
return // Do not attempt to write to upgraded connection
|
||||
}
|
||||
if matrixErr, ok := err.(*errMatrix); ok {
|
||||
writeMatrixError(w, r, v, matrixErr)
|
||||
if err := writeMatrixError(w, r, v, matrixErr); err != nil {
|
||||
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Writing Matrix error failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
httpErr, ok := err.(*errHTTP)
|
||||
if !ok {
|
||||
httpErr = errHTTPInternalError
|
||||
}
|
||||
isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest || httpErr.HTTPCode == http.StatusTooManyRequests
|
||||
if isNormalError {
|
||||
logvr(v, r).
|
||||
Err(httpErr).
|
||||
Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
|
||||
logvr(v, r).Err(err).Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
|
||||
} else {
|
||||
logvr(v, r).
|
||||
Err(httpErr).
|
||||
Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
|
||||
logvr(v, r).Err(err).Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
|
@ -629,7 +598,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
}
|
||||
m.Sender = v.IP()
|
||||
m.User = v.MaybeUserID()
|
||||
m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
|
||||
if cache {
|
||||
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
||||
}
|
||||
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -637,21 +608,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
m.Message = emptyMessageBody
|
||||
}
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
logvrm(vRate, r, m).
|
||||
ev := logvrm(vRate, r, m).
|
||||
Tag(tagPublish).
|
||||
Fields(log.Context{
|
||||
"message_delayed": delayed,
|
||||
"message_firebase": firebase,
|
||||
"message_unifiedpush": unifiedpush,
|
||||
"message_email": email,
|
||||
"message_subscriber_rate_limited": vRate != v,
|
||||
}).
|
||||
Debug("Received message")
|
||||
if log.IsTrace() {
|
||||
logvrm(vRate, r, m).
|
||||
Tag(tagPublish).
|
||||
Field("message_body", util.MaybeMarshalJSON(m)).
|
||||
Trace("Message body")
|
||||
"message_delayed": delayed,
|
||||
"message_firebase": firebase,
|
||||
"message_unifiedpush": unifiedpush,
|
||||
"message_email": email,
|
||||
})
|
||||
if ev.IsTrace() {
|
||||
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
|
||||
} else if ev.IsDebug() {
|
||||
ev.Debug("Received message")
|
||||
}
|
||||
if !delayed {
|
||||
if err := t.Publish(v, m); err != nil {
|
||||
|
@ -1176,7 +1144,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
return err
|
||||
}
|
||||
err = g.Wait()
|
||||
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
|
||||
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Trace("WebSocket connection closed")
|
||||
return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
|
||||
}
|
||||
|
@ -1309,161 +1277,6 @@ func (s *Server) topicFromID(id string) (*topic, error) {
|
|||
return topics[0], nil
|
||||
}
|
||||
|
||||
func (s *Server) execManager() {
|
||||
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
|
||||
// there is no mutex for the entire function.
|
||||
|
||||
// Expire visitors from rate visitors map
|
||||
staleVisitors := 0
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for ip, v := range s.visitors {
|
||||
if v.Stale() {
|
||||
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
|
||||
delete(s.visitors, ip)
|
||||
staleVisitors++
|
||||
}
|
||||
}
|
||||
}).
|
||||
Field("stale_visitors", staleVisitors).
|
||||
Debug("Deleted %d stale visitor(s)", staleVisitors)
|
||||
|
||||
// Delete expired user tokens and users
|
||||
if s.userManager != nil {
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
if err := s.userManager.RemoveExpiredTokens(); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
|
||||
}
|
||||
if err := s.userManager.RemoveDeletedUsers(); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
|
||||
}
|
||||
}).
|
||||
Debug("Removed expired tokens and users")
|
||||
}
|
||||
|
||||
// Delete expired attachments
|
||||
if s.fileCache != nil {
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
ids, err := s.messageCache.AttachmentsExpired()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
|
||||
} else if len(ids) > 0 {
|
||||
if log.Tag(tagManager).IsDebug() {
|
||||
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
|
||||
}
|
||||
if err := s.fileCache.Remove(ids...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
|
||||
}
|
||||
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
|
||||
}
|
||||
} else {
|
||||
log.Tag(tagManager).Debug("No expired attachments to delete")
|
||||
}
|
||||
}).
|
||||
Debug("Deleted expired attachments")
|
||||
}
|
||||
|
||||
// Prune messages
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
expiredMessageIDs, err := s.messageCache.MessagesExpired()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
|
||||
} else if len(expiredMessageIDs) > 0 {
|
||||
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
|
||||
}
|
||||
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
|
||||
}
|
||||
} else {
|
||||
log.Tag(tagManager).Debug("No expired messages to delete")
|
||||
}
|
||||
}).
|
||||
Debug("Pruned messages")
|
||||
|
||||
// Message count per topic
|
||||
var messagesCached int
|
||||
messageCounts, err := s.messageCache.MessageCounts()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
|
||||
messageCounts = make(map[string]int) // Empty, so we can continue
|
||||
}
|
||||
for _, count := range messageCounts {
|
||||
messagesCached += count
|
||||
}
|
||||
|
||||
// Remove subscriptions without subscribers
|
||||
var emptyTopics, subscribers int
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, t := range s.topics {
|
||||
subs := t.SubscribersCount()
|
||||
ev := log.Tag(tagManager)
|
||||
if ev.IsTrace() {
|
||||
expiryMessage := ""
|
||||
if subs == 0 {
|
||||
expiryTime := time.Until(t.vRateExpires)
|
||||
expiryMessage = ", expires in " + expiryTime.String()
|
||||
}
|
||||
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
|
||||
}
|
||||
msgs, exists := messageCounts[t.ID]
|
||||
if t.Stale() && (!exists || msgs == 0) {
|
||||
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
|
||||
emptyTopics++
|
||||
delete(s.topics, t.ID)
|
||||
continue
|
||||
}
|
||||
subscribers += subs
|
||||
}
|
||||
}).
|
||||
Debug("Removed %d empty topic(s)", emptyTopics)
|
||||
|
||||
// Mail stats
|
||||
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
|
||||
if s.smtpServerBackend != nil {
|
||||
receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
|
||||
}
|
||||
var sentMailTotal, sentMailSuccess, sentMailFailure int64
|
||||
if s.smtpSender != nil {
|
||||
sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
|
||||
}
|
||||
|
||||
// Print stats
|
||||
s.mu.Lock()
|
||||
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
|
||||
s.mu.Unlock()
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Fields(log.Context{
|
||||
"messages_published": messagesCount,
|
||||
"messages_cached": messagesCached,
|
||||
"topics_active": topicsCount,
|
||||
"subscribers": subscribers,
|
||||
"visitors": visitorsCount,
|
||||
"emails_received": receivedMailTotal,
|
||||
"emails_received_success": receivedMailSuccess,
|
||||
"emails_received_failure": receivedMailFailure,
|
||||
"emails_sent": sentMailTotal,
|
||||
"emails_sent_success": sentMailSuccess,
|
||||
"emails_sent_failure": sentMailFailure,
|
||||
}).
|
||||
Info("Server stats")
|
||||
}
|
||||
|
||||
func (s *Server) runSMTPServer() error {
|
||||
s.smtpServerBackend = newMailBackend(s.config, s.handle)
|
||||
s.smtpServer = smtp.NewServer(s.smtpServerBackend)
|
||||
|
|
|
@ -246,7 +246,7 @@
|
|||
|
||||
# Logging options
|
||||
#
|
||||
# By default, ntfy logs to the console (stderr), with a "info" log level, and in a human-readable text format.
|
||||
# By default, ntfy logs to the console (stderr), with an "info" log level, and in a human-readable text format.
|
||||
# ntfy supports five different log levels, can also write to a file, log as JSON, and even supports granular
|
||||
# log level overrides for easier debugging. Some options (log-level and log-level-overrides) can be hot reloaded
|
||||
# by calling "kill -HUP $pid" or "systemctl reload ntfy".
|
||||
|
|
|
@ -100,6 +100,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
|
|||
Customer: true,
|
||||
Subscription: u.Billing.StripeSubscriptionID != "",
|
||||
Status: string(u.Billing.StripeSubscriptionStatus),
|
||||
Interval: string(u.Billing.StripeSubscriptionInterval),
|
||||
PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(),
|
||||
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
|
||||
}
|
||||
|
@ -479,6 +480,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
|||
if err := s.messageCache.ExpireMessages(topic); err != nil {
|
||||
return err
|
||||
}
|
||||
s.pruneMessages()
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
@ -505,6 +507,7 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *vi
|
|||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
go s.pruneMessages()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -669,8 +669,8 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
|||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Verify that messages and attachments were deleted
|
||||
// This does not explicitly call the manager!
|
||||
time.Sleep(time.Second)
|
||||
s.execManager()
|
||||
|
||||
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
|
@ -804,10 +804,27 @@ func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
|
|||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
|
||||
require.Equal(t, int64(1), account.Stats.Messages) // Is not reset!
|
||||
|
||||
// Publish another message
|
||||
rr = request(t, s, "POST", "/mytopic", "hi", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Verify that message stats were persisted
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
u, err = s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), u.Stats.Messages) // v.EnqueueUserStats had run!
|
||||
require.Equal(t, int64(2), u.Stats.Messages) // v.EnqueueUserStats had run!
|
||||
|
||||
// Stats keep counting
|
||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
account, _ = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
|
||||
require.Equal(t, int64(2), account.Stats.Messages) // Is not reset!
|
||||
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"firebase.google.com/go/v4/messaging"
|
||||
"fmt"
|
||||
"google.golang.org/api/option"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"strings"
|
||||
|
@ -46,16 +45,15 @@ func (c *firebaseClient) Send(v *visitor, m *message) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if log.Tag(tagFirebase).IsTrace() {
|
||||
logvm(v, m).
|
||||
Tag(tagFirebase).
|
||||
Field("firebase_message", util.MaybeMarshalJSON(fbm)).
|
||||
Trace("Firebase message")
|
||||
ev := logvm(v, m).Tag(tagFirebase)
|
||||
if ev.IsTrace() {
|
||||
ev.Field("firebase_message", util.MaybeMarshalJSON(fbm)).Trace("Firebase message")
|
||||
}
|
||||
err = c.sender.Send(fbm)
|
||||
if err == errFirebaseQuotaExceeded {
|
||||
logvm(v, m).
|
||||
Tag(tagFirebase).
|
||||
Err(err).
|
||||
Warn("Firebase quota exceeded (likely for topic), temporarily denying Firebase access to visitor")
|
||||
v.FirebaseTemporarilyDeny()
|
||||
}
|
||||
|
|
175
server/server_manager.go
Normal file
175
server/server_manager.go
Normal file
|
@ -0,0 +1,175 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Server) execManager() {
|
||||
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
|
||||
// there is no mutex for the entire function.
|
||||
|
||||
// Prune all the things
|
||||
s.pruneVisitors()
|
||||
s.pruneTokens()
|
||||
s.pruneAttachments()
|
||||
s.pruneMessages()
|
||||
|
||||
// Message count per topic
|
||||
var messagesCached int
|
||||
messageCounts, err := s.messageCache.MessageCounts()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
|
||||
messageCounts = make(map[string]int) // Empty, so we can continue
|
||||
}
|
||||
for _, count := range messageCounts {
|
||||
messagesCached += count
|
||||
}
|
||||
|
||||
// Remove subscriptions without subscribers
|
||||
var emptyTopics, subscribers int
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, t := range s.topics {
|
||||
subs := t.SubscribersCount()
|
||||
ev := log.Tag(tagManager)
|
||||
if ev.IsTrace() {
|
||||
expiryMessage := ""
|
||||
if subs == 0 {
|
||||
expiryTime := time.Until(t.vRateExpires)
|
||||
expiryMessage = ", expires in " + expiryTime.String()
|
||||
}
|
||||
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
|
||||
}
|
||||
msgs, exists := messageCounts[t.ID]
|
||||
if t.Stale() && (!exists || msgs == 0) {
|
||||
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
|
||||
emptyTopics++
|
||||
delete(s.topics, t.ID)
|
||||
continue
|
||||
}
|
||||
subscribers += subs
|
||||
}
|
||||
}).
|
||||
Debug("Removed %d empty topic(s)", emptyTopics)
|
||||
|
||||
// Mail stats
|
||||
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
|
||||
if s.smtpServerBackend != nil {
|
||||
receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
|
||||
}
|
||||
var sentMailTotal, sentMailSuccess, sentMailFailure int64
|
||||
if s.smtpSender != nil {
|
||||
sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
|
||||
}
|
||||
|
||||
// Print stats
|
||||
s.mu.Lock()
|
||||
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
|
||||
s.mu.Unlock()
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Fields(log.Context{
|
||||
"messages_published": messagesCount,
|
||||
"messages_cached": messagesCached,
|
||||
"topics_active": topicsCount,
|
||||
"subscribers": subscribers,
|
||||
"visitors": visitorsCount,
|
||||
"emails_received": receivedMailTotal,
|
||||
"emails_received_success": receivedMailSuccess,
|
||||
"emails_received_failure": receivedMailFailure,
|
||||
"emails_sent": sentMailTotal,
|
||||
"emails_sent_success": sentMailSuccess,
|
||||
"emails_sent_failure": sentMailFailure,
|
||||
}).
|
||||
Info("Server stats")
|
||||
}
|
||||
|
||||
func (s *Server) pruneVisitors() {
|
||||
staleVisitors := 0
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for ip, v := range s.visitors {
|
||||
if v.Stale() {
|
||||
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
|
||||
delete(s.visitors, ip)
|
||||
staleVisitors++
|
||||
}
|
||||
}
|
||||
}).
|
||||
Field("stale_visitors", staleVisitors).
|
||||
Debug("Deleted %d stale visitor(s)", staleVisitors)
|
||||
}
|
||||
|
||||
func (s *Server) pruneTokens() {
|
||||
if s.userManager != nil {
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
if err := s.userManager.RemoveExpiredTokens(); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
|
||||
}
|
||||
if err := s.userManager.RemoveDeletedUsers(); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
|
||||
}
|
||||
}).
|
||||
Debug("Removed expired tokens and users")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pruneAttachments() {
|
||||
if s.fileCache == nil {
|
||||
return
|
||||
}
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
ids, err := s.messageCache.AttachmentsExpired()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
|
||||
} else if len(ids) > 0 {
|
||||
if log.Tag(tagManager).IsDebug() {
|
||||
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
|
||||
}
|
||||
if err := s.fileCache.Remove(ids...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
|
||||
}
|
||||
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
|
||||
}
|
||||
} else {
|
||||
log.Tag(tagManager).Debug("No expired attachments to delete")
|
||||
}
|
||||
}).
|
||||
Debug("Deleted expired attachments")
|
||||
}
|
||||
|
||||
func (s *Server) pruneMessages() {
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Timing(func() {
|
||||
expiredMessageIDs, err := s.messageCache.MessagesExpired()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
|
||||
} else if len(expiredMessageIDs) > 0 {
|
||||
if s.fileCache != nil {
|
||||
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
|
||||
}
|
||||
}
|
||||
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
|
||||
}
|
||||
} else {
|
||||
log.Tag(tagManager).Debug("No expired messages to delete")
|
||||
}
|
||||
}).
|
||||
Debug("Pruned messages")
|
||||
}
|
28
server/server_manager_test.go
Normal file
28
server/server_manager_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||
c := newTestConfig(t)
|
||||
c.AttachmentCacheDir = ""
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Publish a message
|
||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
m := toMessage(t, rr.Body.String())
|
||||
|
||||
// Expire message
|
||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||
|
||||
// Does not panic
|
||||
s.pruneMessages()
|
||||
|
||||
// Actually deleted
|
||||
_, err := s.messageCache.Message(m.ID)
|
||||
require.Equal(t, errMessageNotFound, err)
|
||||
}
|
|
@ -29,7 +29,7 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
|
|||
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
||||
vRate = topicCountsAgainst
|
||||
}
|
||||
r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
|
||||
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
|
||||
|
||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||
return next(w, r, v)
|
||||
|
|
|
@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
return err
|
||||
}
|
||||
for _, tier := range tiers {
|
||||
priceStr, ok := prices[tier.StripePriceID]
|
||||
if tier.StripePriceID == "" || !ok {
|
||||
priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID]
|
||||
if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices!
|
||||
continue
|
||||
}
|
||||
response = append(response, &apiAccountBillingTier{
|
||||
Code: tier.Code,
|
||||
Name: tier.Name,
|
||||
Price: priceStr,
|
||||
Code: tier.Code,
|
||||
Name: tier.Name,
|
||||
Prices: &apiAccountBillingPrices{
|
||||
Month: priceMonth,
|
||||
Year: priceYear,
|
||||
},
|
||||
Limits: &apiAccountLimits{
|
||||
Basis: string(visitorLimitBasisTier),
|
||||
Messages: tier.MessageLimit,
|
||||
|
@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
tier, err := s.userManager.Tier(req.Tier)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if tier.StripePriceID == "" {
|
||||
}
|
||||
var priceID string
|
||||
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
|
||||
priceID = tier.StripeMonthlyPriceID
|
||||
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
|
||||
priceID = tier.StripeYearlyPriceID
|
||||
} else {
|
||||
return errNotAPaidTier
|
||||
}
|
||||
logvr(v, r).
|
||||
With(tier).
|
||||
Fields(log.Context{
|
||||
"stripe_price_id": priceID,
|
||||
"stripe_subscription_interval": req.Interval,
|
||||
}).
|
||||
Tag(tagStripe).
|
||||
Info("Creating Stripe checkout flow")
|
||||
var stripeCustomerID *string
|
||||
|
@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
AllowPromotionCodes: stripe.Bool(true),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(tier.StripePriceID),
|
||||
Price: stripe.String(priceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
|
@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
|
||||
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil {
|
||||
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
|
||||
}
|
||||
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
|
||||
priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval
|
||||
tier, err := s.userManager.TierByStripePrice(priceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
Tag(tagStripe).
|
||||
Fields(log.Context{
|
||||
"stripe_customer_id": sess.Customer.ID,
|
||||
"stripe_price_id": priceID,
|
||||
"stripe_subscription_id": sub.ID,
|
||||
"stripe_subscription_status": string(sub.Status),
|
||||
"stripe_subscription_interval": string(interval),
|
||||
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
|
||||
}).
|
||||
Info("Stripe checkout flow succeeded, updating user tier and subscription")
|
||||
|
@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||
|
@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var priceID string
|
||||
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
|
||||
priceID = tier.StripeMonthlyPriceID
|
||||
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
|
||||
priceID = tier.StripeYearlyPriceID
|
||||
} else {
|
||||
return errNotAPaidTier
|
||||
}
|
||||
logvr(v, r).
|
||||
Tag(tagStripe).
|
||||
Fields(log.Context{
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_name": tier.Name,
|
||||
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_code": tier.Code,
|
||||
"new_tier_stripe_price_id": priceID,
|
||||
"new_tier_stripe_subscription_interval": req.Interval,
|
||||
// Other stripe_* fields filled by visitor context
|
||||
}).
|
||||
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
|
||||
Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval)
|
||||
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|||
}
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(false),
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
|
||||
Items: []*stripe.SubscriptionItemsParams{
|
||||
{
|
||||
ID: stripe.String(sub.Items.Data[0].ID),
|
||||
Price: stripe.String(tier.StripePriceID),
|
||||
Price: stripe.String(priceID),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
|
|||
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
||||
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil {
|
||||
logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe")
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
||||
subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval
|
||||
logvr(v, r).
|
||||
Tag(tagStripe).
|
||||
Fields(log.Context{
|
||||
"stripe_webhook_type": event.Type,
|
||||
"stripe_customer_id": ev.Customer,
|
||||
"stripe_price_id": priceID,
|
||||
"stripe_subscription_id": ev.ID,
|
||||
"stripe_subscription_status": ev.Status,
|
||||
"stripe_subscription_interval": interval,
|
||||
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
|
||||
"stripe_subscription_cancel_at": ev.CancelAt,
|
||||
"stripe_price_id": priceID,
|
||||
}).
|
||||
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
|
||||
userFn := func() (*user.User, error) {
|
||||
|
@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
|
@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request,
|
|||
Tag(tagStripe).
|
||||
Field("stripe_webhook_type", event.Type).
|
||||
Info("Subscription deleted, downgrading to unpaid tier")
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error {
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationLimit
|
||||
|
@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
|
|||
logvr(v, r).
|
||||
Tag(tagStripe).
|
||||
Fields(log.Context{
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_name": tier.Name,
|
||||
"new_tier_stripe_price_id": tier.StripePriceID,
|
||||
"new_tier_id": tier.ID,
|
||||
"new_tier_code": tier.Code,
|
||||
}).
|
||||
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
|
||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||
|
@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
|
|||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: subscriptionID,
|
||||
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
|
||||
StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval),
|
||||
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
|
||||
}
|
||||
|
@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
|
|||
|
||||
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
||||
// in memory, and ultimately for the web app to display the price table.
|
||||
func (s *Server) fetchStripePrices() (map[string]string, error) {
|
||||
func (s *Server) fetchStripePrices() (map[string]int64, error) {
|
||||
log.Debug("Caching prices from Stripe API")
|
||||
priceMap := make(map[string]string)
|
||||
priceMap := make(map[string]int64)
|
||||
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
|
||||
if err != nil {
|
||||
log.Warn("Fetching Stripe prices failed: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
for _, p := range prices {
|
||||
if p.UnitAmount%100 == 0 {
|
||||
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
||||
} else {
|
||||
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
||||
}
|
||||
priceMap[p.ID] = p.UnitAmount
|
||||
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
|
||||
}
|
||||
return priceMap, nil
|
||||
|
|
|
@ -37,7 +37,9 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
On("ListPrices", mock.Anything).
|
||||
Return([]*stripe.Price{
|
||||
{ID: "price_123", UnitAmount: 500},
|
||||
{ID: "price_124", UnitAmount: 5000},
|
||||
{ID: "price_456", UnitAmount: 1000},
|
||||
{ID: "price_457", UnitAmount: 10000},
|
||||
{ID: "price_999", UnitAmount: 9999},
|
||||
}, nil)
|
||||
|
||||
|
@ -58,7 +60,8 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
AttachmentFileSizeLimit: 999,
|
||||
AttachmentTotalSizeLimit: 888,
|
||||
AttachmentExpiryDuration: time.Minute,
|
||||
StripePriceID: "price_123",
|
||||
StripeMonthlyPriceID: "price_123",
|
||||
StripeYearlyPriceID: "price_124",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_444",
|
||||
|
@ -71,7 +74,8 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
AttachmentFileSizeLimit: 999111,
|
||||
AttachmentTotalSizeLimit: 888111,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
StripePriceID: "price_456",
|
||||
StripeMonthlyPriceID: "price_456",
|
||||
StripeYearlyPriceID: "price_457",
|
||||
}))
|
||||
response := request(t, s, "GET", "/v1/tiers", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
@ -98,6 +102,8 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
require.Equal(t, "pro", tier.Code)
|
||||
require.Equal(t, "Pro", tier.Name)
|
||||
require.Equal(t, "tier", tier.Limits.Basis)
|
||||
require.Equal(t, int64(500), tier.Prices.Month)
|
||||
require.Equal(t, int64(5000), tier.Prices.Year)
|
||||
require.Equal(t, int64(777), tier.Limits.Reservations)
|
||||
require.Equal(t, int64(1000), tier.Limits.Messages)
|
||||
require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
|
||||
|
@ -109,6 +115,8 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
tier = tiers[2]
|
||||
require.Equal(t, "business", tier.Code)
|
||||
require.Equal(t, "Business", tier.Name)
|
||||
require.Equal(t, int64(1000), tier.Prices.Month)
|
||||
require.Equal(t, int64(10000), tier.Prices.Year)
|
||||
require.Equal(t, "tier", tier.Limits.Basis)
|
||||
require.Equal(t, int64(777333), tier.Limits.Reservations)
|
||||
require.Equal(t, int64(2000), tier.Limits.Messages)
|
||||
|
@ -136,14 +144,14 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripeMonthlyPriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
// Create subscription
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
@ -172,9 +180,9 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripeMonthlyPriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
|
@ -187,7 +195,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
|
||||
|
||||
// Create subscription
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
@ -214,9 +222,9 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripeMonthlyPriceID: "price_123",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
|
||||
|
@ -267,7 +275,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234",
|
||||
StripeMonthlyPriceID: "price_1234",
|
||||
ReservationLimit: 1,
|
||||
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
|
||||
MessageExpiryDuration: time.Hour,
|
||||
|
@ -298,7 +306,12 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||
Items: &stripe.SubscriptionItemList{
|
||||
Data: []*stripe.SubscriptionItem{
|
||||
{
|
||||
Price: &stripe.Price{ID: "price_1234"},
|
||||
Price: &stripe.Price{
|
||||
ID: "price_1234",
|
||||
Recurring: &stripe.PriceRecurring{
|
||||
Interval: stripe.PriceRecurringIntervalMonth,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -333,6 +346,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||
require.Equal(t, "", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
|
||||
require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
|
||||
|
@ -349,6 +363,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
||||
require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
|
||||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
require.Equal(t, int64(0), u.Stats.Messages)
|
||||
|
@ -423,7 +438,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_1",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234", // !
|
||||
StripeMonthlyPriceID: "price_1234", // !
|
||||
ReservationLimit: 1, // !
|
||||
MessageLimit: 100,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
|
@ -435,7 +450,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_2",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_1111", // !
|
||||
StripeMonthlyPriceID: "price_1111", // !
|
||||
ReservationLimit: 3, // !
|
||||
MessageLimit: 200,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
|
@ -457,6 +472,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
StripeCustomerID: "acct_5555",
|
||||
StripeSubscriptionID: "sub_1234",
|
||||
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
|
||||
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
|
||||
StripeSubscriptionPaidUntil: time.Unix(123, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(456, 0),
|
||||
}
|
||||
|
@ -499,9 +515,10 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
|
||||
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
|
||||
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
|
||||
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
|
||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
|
||||
require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
|
||||
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
|
||||
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
|
||||
|
||||
// Verify that reservations were deleted
|
||||
r, err := s.userManager.Reservations("phil")
|
||||
|
@ -546,10 +563,10 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
|||
|
||||
// Create a user with a Stripe subscription and 3 reservations
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_1",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_1234",
|
||||
ReservationLimit: 1,
|
||||
ID: "ti_1",
|
||||
Code: "pro",
|
||||
StripeMonthlyPriceID: "price_1234",
|
||||
ReservationLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
@ -562,6 +579,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
|||
StripeCustomerID: "acct_5555",
|
||||
StripeSubscriptionID: "sub_1234",
|
||||
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
|
||||
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
|
||||
StripeSubscriptionPaidUntil: time.Unix(123, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(0, 0),
|
||||
}))
|
||||
|
@ -615,11 +633,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
|||
stripeMock.
|
||||
On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(false),
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
|
||||
Items: []*stripe.SubscriptionItemsParams{
|
||||
{
|
||||
ID: stripe.String("someid_123"),
|
||||
Price: stripe.String("price_456"),
|
||||
Price: stripe.String("price_457"),
|
||||
},
|
||||
},
|
||||
}).
|
||||
|
@ -627,14 +645,16 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripeMonthlyPriceID: "price_123",
|
||||
StripeYearlyPriceID: "price_124",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
ID: "ti_456",
|
||||
Code: "business",
|
||||
StripePriceID: "price_456",
|
||||
ID: "ti_456",
|
||||
Code: "business",
|
||||
StripeMonthlyPriceID: "price_456",
|
||||
StripeYearlyPriceID: "price_457",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
@ -644,7 +664,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
|||
}))
|
||||
|
||||
// Call endpoint to change subscription
|
||||
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{
|
||||
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
@ -795,7 +815,10 @@ const subscriptionUpdatedEventJSON = `
|
|||
"data": [
|
||||
{
|
||||
"price": {
|
||||
"id": "price_1234"
|
||||
"id": "price_1234",
|
||||
"recurring": {
|
||||
"interval": "year"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
@ -818,7 +841,10 @@ const subscriptionDeletedEventJSON = `
|
|||
"data": [
|
||||
{
|
||||
"price": {
|
||||
"id": "price_1234"
|
||||
"id": "price_1234",
|
||||
"recurring": {
|
||||
"interval": "month"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
|
@ -149,6 +149,8 @@ func TestServer_PublishAndSubscribe(t *testing.T) {
|
|||
require.Equal(t, "", messages[1].Title)
|
||||
require.Equal(t, 0, messages[1].Priority)
|
||||
require.Nil(t, messages[1].Tags)
|
||||
require.True(t, time.Now().Add(12*time.Hour-5*time.Second).Unix() < messages[1].Expires)
|
||||
require.True(t, time.Now().Add(12*time.Hour+5*time.Second).Unix() > messages[1].Expires)
|
||||
|
||||
require.Equal(t, messageEvent, messages[2].Event)
|
||||
require.Equal(t, "mytopic", messages[2].Topic)
|
||||
|
@ -287,6 +289,7 @@ func TestServer_PublishNoCache(t *testing.T) {
|
|||
msg := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, msg.ID)
|
||||
require.Equal(t, "this message is not cached", msg.Message)
|
||||
require.Equal(t, int64(0), msg.Expires)
|
||||
|
||||
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||
messages := toMessages(t, response.Body.String())
|
||||
|
@ -324,6 +327,18 @@ func TestServer_PublishAt(t *testing.T) {
|
|||
require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
|
||||
}
|
||||
|
||||
func TestServer_PublishAt_Expires(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{
|
||||
"In": "2 days",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.True(t, m.Expires > time.Now().Add(12*time.Hour+48*time.Hour-time.Minute).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(12*time.Hour+48*time.Hour+time.Minute).Unix())
|
||||
}
|
||||
|
||||
func TestServer_PublishAtWithCacheError(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
|
@ -1486,7 +1501,7 @@ func TestServer_PublishAttachmentTooLargeBodyVisitorAttachmentTotalSizeLimit(t *
|
|||
c.VisitorAttachmentTotalSizeLimit = 10000
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "PUT", "/mytopic", util.RandomString(5000), nil)
|
||||
response := request(t, s, "PUT", "/mytopic", "text file!"+util.RandomString(4990), nil)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, "You received a file: attachment.txt", msg.Message)
|
||||
|
|
|
@ -37,18 +37,18 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
|||
return err
|
||||
}
|
||||
auth := smtp.PlainAuth("", s.config.SMTPSenderUser, s.config.SMTPSenderPass, host)
|
||||
logvm(v, m).
|
||||
ev := logvm(v, m).
|
||||
Tag(tagEmail).
|
||||
Fields(log.Context{
|
||||
"email_via": s.config.SMTPSenderAddr,
|
||||
"email_user": s.config.SMTPSenderUser,
|
||||
"email_to": to,
|
||||
}).
|
||||
Debug("Sending email")
|
||||
logvm(v, m).
|
||||
Tag(tagEmail).
|
||||
Field("email_body", message).
|
||||
Trace("Email body")
|
||||
})
|
||||
if ev.IsTrace() {
|
||||
ev.Field("email_body", message).Trace("Sending email")
|
||||
} else if ev.IsDebug() {
|
||||
ev.Debug("Sending email")
|
||||
}
|
||||
return smtp.SendMail(s.config.SMTPSenderAddr, auth, s.config.SMTPSenderFrom, []string{to}, []byte(message))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
|
@ -21,9 +22,14 @@ var (
|
|||
errInvalidAddress = errors.New("invalid address")
|
||||
errInvalidTopic = errors.New("invalid topic")
|
||||
errTooManyRecipients = errors.New("too many recipients")
|
||||
errMultipartNestedTooDeep = errors.New("multipart message nested too deep")
|
||||
errUnsupportedContentType = errors.New("unsupported content type")
|
||||
)
|
||||
|
||||
const (
|
||||
maxMultipartDepth = 2
|
||||
)
|
||||
|
||||
// smtpBackend implements SMTP server methods.
|
||||
type smtpBackend struct {
|
||||
config *Config
|
||||
|
@ -33,6 +39,9 @@ type smtpBackend struct {
|
|||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var _ smtp.Backend = (*smtpBackend)(nil)
|
||||
var _ smtp.Session = (*smtpSession)(nil)
|
||||
|
||||
func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
|
||||
return &smtpBackend{
|
||||
config: conf,
|
||||
|
@ -40,14 +49,9 @@ func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Reques
|
|||
}
|
||||
}
|
||||
|
||||
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, _ string) (smtp.Session, error) {
|
||||
logem(state).Debug("Incoming mail, login with user %s", username)
|
||||
return &smtpSession{backend: b, state: state}, nil
|
||||
}
|
||||
|
||||
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||
logem(state).Debug("Incoming mail, anonymous login")
|
||||
return &smtpSession{backend: b, state: state}, nil
|
||||
func (b *smtpBackend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
|
||||
logem(conn).Debug("Incoming mail")
|
||||
return &smtpSession{backend: b, conn: conn}, nil
|
||||
}
|
||||
|
||||
func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
|
||||
|
@ -59,24 +63,26 @@ func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
|
|||
// smtpSession is returned after EHLO.
|
||||
type smtpSession struct {
|
||||
backend *smtpBackend
|
||||
state *smtp.ConnectionState
|
||||
conn *smtp.Conn
|
||||
topic string
|
||||
token string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *smtpSession) AuthPlain(username, password string) error {
|
||||
logem(s.state).Debug("AUTH PLAIN (with username %s)", username)
|
||||
func (s *smtpSession) AuthPlain(username, _ string) error {
|
||||
logem(s.conn).Field("smtp_username", username).Debug("AUTH PLAIN (with username %s)", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
logem(s.state).Debug("MAIL FROM: %s (with options: %#v)", from, opts)
|
||||
func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error {
|
||||
logem(s.conn).Field("smtp_mail_from", from).Debug("MAIL FROM: %s", from)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *smtpSession) Rcpt(to string) error {
|
||||
logem(s.state).Debug("RCPT TO: %s", to)
|
||||
logem(s.conn).Field("smtp_rcpt_to", to).Debug("RCPT TO: %s", to)
|
||||
return s.withFailCount(func() error {
|
||||
token := ""
|
||||
conf := s.backend.config
|
||||
addressList, err := mail.ParseAddressList(to)
|
||||
if err != nil {
|
||||
|
@ -88,18 +94,27 @@ func (s *smtpSession) Rcpt(to string) error {
|
|||
if !strings.HasSuffix(to, "@"+conf.SMTPServerDomain) {
|
||||
return errInvalidDomain
|
||||
}
|
||||
// Remove @ntfy.sh from end of email
|
||||
to = strings.TrimSuffix(to, "@"+conf.SMTPServerDomain)
|
||||
if conf.SMTPServerAddrPrefix != "" {
|
||||
if !strings.HasPrefix(to, conf.SMTPServerAddrPrefix) {
|
||||
return errInvalidAddress
|
||||
}
|
||||
// remove ntfy- from beginning of email
|
||||
to = strings.TrimPrefix(to, conf.SMTPServerAddrPrefix)
|
||||
}
|
||||
// If email contains token, split topic and token
|
||||
if strings.Contains(to, "+") {
|
||||
parts := strings.Split(to, "+")
|
||||
to = parts[0]
|
||||
token = parts[1]
|
||||
}
|
||||
if !topicRegex.MatchString(to) {
|
||||
return errInvalidTopic
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.topic = to
|
||||
s.token = token
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
@ -112,17 +127,17 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ev := logem(s.state).Tag(tagSMTP)
|
||||
ev := logem(s.conn)
|
||||
if ev.IsTrace() {
|
||||
ev.Field("smtp_data", string(b)).Trace("DATA")
|
||||
} else if ev.IsDebug() {
|
||||
ev.Debug("DATA: %d byte(s)", len(b))
|
||||
ev.Field("smtp_data_len", len(b)).Debug("DATA")
|
||||
}
|
||||
msg, err := mail.ReadMessage(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := readMailBody(msg)
|
||||
body, err := readMailBody(msg.Body, msg.Header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -156,11 +171,10 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||
|
||||
func (s *smtpSession) publishMessage(m *message) error {
|
||||
// Extract remote address (for rate limiting)
|
||||
remoteAddr, _, err := net.SplitHostPort(s.state.RemoteAddr.String())
|
||||
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
|
||||
if err != nil {
|
||||
remoteAddr = s.state.RemoteAddr.String()
|
||||
remoteAddr = s.conn.Conn().RemoteAddr().String()
|
||||
}
|
||||
|
||||
// Call HTTP handler with fake HTTP request
|
||||
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(m.Message))
|
||||
|
@ -173,6 +187,9 @@ func (s *smtpSession) publishMessage(m *message) error {
|
|||
if m.Title != "" {
|
||||
req.Header.Set("Title", m.Title)
|
||||
}
|
||||
if s.token != "" {
|
||||
req.Header.Add("Authorization", "Bearer "+s.token)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
s.backend.handler(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
|
@ -198,54 +215,58 @@ func (s *smtpSession) withFailCount(fn func() error) error {
|
|||
if err != nil {
|
||||
// Almost all of these errors are parse errors, and user input errors.
|
||||
// We do not want to spam the log with WARN messages.
|
||||
logem(s.state).Err(err).Debug("Incoming mail error")
|
||||
logem(s.conn).Err(err).Debug("Incoming mail error")
|
||||
s.backend.failure++
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func readMailBody(msg *mail.Message) (string, error) {
|
||||
if msg.Header.Get("Content-Type") == "" {
|
||||
return readPlainTextMailBody(msg)
|
||||
func readMailBody(body io.Reader, header mail.Header) (string, error) {
|
||||
if header.Get("Content-Type") == "" {
|
||||
return readPlainTextMailBody(body, header.Get("Content-Transfer-Encoding"))
|
||||
}
|
||||
contentType, params, err := mime.ParseMediaType(msg.Header.Get("Content-Type"))
|
||||
contentType, params, err := mime.ParseMediaType(header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if contentType == "text/plain" {
|
||||
return readPlainTextMailBody(msg)
|
||||
} else if strings.HasPrefix(contentType, "multipart/") {
|
||||
return readMultipartMailBody(msg, params)
|
||||
if strings.ToLower(contentType) == "text/plain" {
|
||||
return readPlainTextMailBody(body, header.Get("Content-Transfer-Encoding"))
|
||||
} else if strings.HasPrefix(strings.ToLower(contentType), "multipart/") {
|
||||
return readMultipartMailBody(body, params, 0)
|
||||
}
|
||||
return "", errUnsupportedContentType
|
||||
}
|
||||
|
||||
func readPlainTextMailBody(msg *mail.Message) (string, error) {
|
||||
body, err := io.ReadAll(msg.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
func readMultipartMailBody(body io.Reader, params map[string]string, depth int) (string, error) {
|
||||
if depth >= maxMultipartDepth {
|
||||
return "", errMultipartNestedTooDeep
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
func readMultipartMailBody(msg *mail.Message, params map[string]string) (string, error) {
|
||||
mr := multipart.NewReader(msg.Body, params["boundary"])
|
||||
mr := multipart.NewReader(body, params["boundary"])
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err != nil { // may be io.EOF
|
||||
return "", err
|
||||
}
|
||||
partContentType, _, err := mime.ParseMediaType(part.Header.Get("Content-Type"))
|
||||
partContentType, partParams, err := mime.ParseMediaType(part.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if partContentType != "text/plain" {
|
||||
continue
|
||||
if strings.ToLower(partContentType) == "text/plain" {
|
||||
return readPlainTextMailBody(part, part.Header.Get("Content-Transfer-Encoding"))
|
||||
} else if strings.HasPrefix(strings.ToLower(partContentType), "multipart/") {
|
||||
return readMultipartMailBody(part, partParams, depth+1)
|
||||
}
|
||||
body, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
// Continue with next part
|
||||
}
|
||||
}
|
||||
|
||||
func readPlainTextMailBody(reader io.Reader, transferEncoding string) (string, error) {
|
||||
if strings.ToLower(transferEncoding) == "base64" {
|
||||
reader = base64.NewDecoder(base64.StdEncoding, reader)
|
||||
}
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
|
|
@ -1,16 +1,23 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSmtpBackend_Multipart(t *testing.T) {
|
||||
email := `MIME-Version: 1.0
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
MIME-Version: 1.0
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
|
@ -28,20 +35,25 @@ Content-Type: text/html; charset="UTF-8"
|
|||
|
||||
<div dir="ltr">what's up<br clear="all"><div><br></div></div>
|
||||
|
||||
--000000000000f3320b05d42915c9--`
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
--000000000000f3320b05d42915c9--
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_MultipartNoBody(t *testing.T) {
|
||||
email := `MIME-Version: 1.0
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-emailtest@ntfy.sh
|
||||
DATA
|
||||
MIME-Version: 1.0
|
||||
Date: Tue, 28 Dec 2021 01:33:34 +0100
|
||||
Message-ID: <CAAvm7ABCDsi9vsuu0WTRXzZQBC8dXrDOLT8iCWdqrsmg@mail.gmail.com>
|
||||
Subject: This email has a subject but no body
|
||||
|
@ -59,20 +71,25 @@ Content-Type: text/html; charset="UTF-8"
|
|||
|
||||
<div dir="ltr"><br></div>
|
||||
|
||||
--000000000000bcf4a405d429f8d4--`
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
--000000000000bcf4a405d429f8d4--
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/emailtest", r.URL.Path)
|
||||
require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
|
||||
require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Plaintext(t *testing.T) {
|
||||
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: mytopic@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
|
@ -80,56 +97,68 @@ To: mytopic@ntfy.sh
|
|||
Content-Type: text/plain; charset="UTF-8"
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
|
||||
email := `Subject: Very short mail
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: mytopic@ntfy.sh
|
||||
DATA
|
||||
Subject: Very short mail
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "Very short mail", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Plaintext_EncodedSubject(t *testing.T) {
|
||||
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Subject: =?UTF-8?B?VGhyZWUgc2FudGFzIPCfjoXwn46F8J+OhQ==?=
|
||||
From: Phil <phil@example.com>
|
||||
To: ntfy-mytopic@ntfy.sh
|
||||
Content-Type: text/plain; charset="UTF-8"
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Plaintext_TooLongTruncate(t *testing.T) {
|
||||
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: mytopic@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
|
@ -148,60 +177,61 @@ so i'm gonna fill the rest of this with AAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAa
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
that should do it
|
||||
.
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
expected := `you know this is a string.
|
||||
it's a long string.
|
||||
it's supposed to be longer than the max message length
|
||||
|
@ -214,68 +244,71 @@ so i'm gonna fill the rest of this with AAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAa
|
||||
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
......................................................................
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
pppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp
|
||||
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
BBBBBBBBBBBBBBBBBBBBBBBBB`
|
||||
require.Equal(t, 4096, len(expected)) // Sanity check
|
||||
require.Equal(t, expected, readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Unsupported(t *testing.T) {
|
||||
email := `Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
|
@ -283,34 +316,254 @@ To: mytopic@ntfy.sh
|
|||
Content-Type: text/SOMETHINGELSE
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
|
||||
// Nothing.
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("This should not be called")
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "554 5.0.0 Error: transaction failed, blame it on the weather: unsupported content type")
|
||||
}
|
||||
|
||||
func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
|
||||
conf := newTestConfig(t)
|
||||
func TestSmtpBackend_InvalidAddress(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: unsupported@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
To: mytopic@ntfy.sh
|
||||
Content-Type: text/plain
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("This should not be called")
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "451 4.0.0 invalid address")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Base64Body(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: test@mydomain.me
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
Content-Type: multipart/mixed; boundary="===============2138658284696597373=="
|
||||
MIME-Version: 1.0
|
||||
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
|
||||
From: =?utf-8?q?Robbie?= <test@mydomain.me>
|
||||
To: test@mydomain.me
|
||||
Date: Thu, 16 Feb 2023 01:04:00 -0000
|
||||
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
|
||||
|
||||
This is a multi-part message in MIME format.
|
||||
--===============2138658284696597373==
|
||||
Content-Type: text/plain; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
|
||||
|
||||
--===============2138658284696597373==
|
||||
Content-Type: text/html; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
|
||||
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
|
||||
|
||||
--===============2138658284696597373==--
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local", r.Header.Get("Title"))
|
||||
require.Equal(t, "This is a test message from TrueNAS CORE.", readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_NestedMultipartBase64(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: test@mydomain.me
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
Content-Type: multipart/mixed; boundary="===============2138658284696597373=="
|
||||
MIME-Version: 1.0
|
||||
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
|
||||
From: =?utf-8?q?Robbie?= <test@mydomain.me>
|
||||
To: test@mydomain.me
|
||||
Date: Thu, 16 Feb 2023 01:04:00 -0000
|
||||
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
|
||||
|
||||
This is a multi-part message in MIME format.
|
||||
--===============2138658284696597373==
|
||||
Content-Type: multipart/alternative; boundary="===============2233989480071754745=="
|
||||
MIME-Version: 1.0
|
||||
|
||||
--===============2233989480071754745==
|
||||
Content-Type: text/plain; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
|
||||
|
||||
--===============2233989480071754745==
|
||||
Content-Type: text/html; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
|
||||
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
|
||||
|
||||
--===============2233989480071754745==--
|
||||
|
||||
--===============2138658284696597373==--
|
||||
.
|
||||
`
|
||||
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local", r.Header.Get("Title"))
|
||||
require.Equal(t, "This is a test message from TrueNAS CORE.", readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_NestedMultipartTooDeep(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: test@mydomain.me
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
Content-Type: multipart/mixed; boundary="===============1=="
|
||||
MIME-Version: 1.0
|
||||
Subject: TrueNAS truenas.local: TrueNAS Test Message hostname: truenas.local
|
||||
From: =?utf-8?q?Robbie?= <test@mydomain.me>
|
||||
To: test@mydomain.me
|
||||
Date: Thu, 16 Feb 2023 01:04:00 -0000
|
||||
Message-ID: <truenas-20230216.010400.344514.b'8jfL'@truenas.local>
|
||||
|
||||
This is a multi-part message in MIME format.
|
||||
--===============1==
|
||||
Content-Type: multipart/alternative; boundary="===============2=="
|
||||
MIME-Version: 1.0
|
||||
|
||||
--===============2==
|
||||
Content-Type: multipart/alternative; boundary="===============3=="
|
||||
MIME-Version: 1.0
|
||||
|
||||
--===============3==
|
||||
Content-Type: text/plain; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
VGhpcyBpcyBhIHRlc3QgbWVzc2FnZSBmcm9tIFRydWVOQVMgQ09SRS4=
|
||||
|
||||
--===============3==
|
||||
Content-Type: text/html; charset="utf-8"
|
||||
MIME-Version: 1.0
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
PCFET0NUWVBFIEhUTUwgUFVCTElDICItLy9XM0MvL0RURCBIVE1MIDQuMCBUcmFuc2l0aW9uYWwv
|
||||
L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
|
||||
|
||||
--===============3==--
|
||||
|
||||
--===============2==--
|
||||
|
||||
--===============1==--
|
||||
.
|
||||
`
|
||||
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("This should not be called")
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "554 5.0.0 Error: transaction failed, blame it on the weather: multipart message nested too deep")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-mytopic+tk_KLORUqSqvNRLpY11DfkHVbHu9NGG2@ntfy.sh
|
||||
DATA
|
||||
Subject: Very short mail
|
||||
|
||||
what's up
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "Very short mail", r.Header.Get("Title"))
|
||||
require.Equal(t, "Bearer tk_KLORUqSqvNRLpY11DfkHVbHu9NGG2", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
|
||||
|
||||
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
|
||||
conf = newTestConfig(t)
|
||||
conf.SMTPServerListen = ":25"
|
||||
conf.SMTPServerDomain = "ntfy.sh"
|
||||
conf.SMTPServerAddrPrefix = "ntfy-"
|
||||
backend := newMailBackend(conf, handler)
|
||||
return conf, backend
|
||||
}
|
||||
|
||||
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
|
||||
ip, err := net.ResolveIPAddr("ip", remoteAddr)
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &smtp.ConnectionState{
|
||||
Hostname: "myhostname",
|
||||
LocalAddr: ip,
|
||||
RemoteAddr: ip,
|
||||
s = smtp.NewServer(backend)
|
||||
s.Domain = conf.SMTPServerDomain
|
||||
s.AllowInsecureAuth = true
|
||||
go func() {
|
||||
require.Nil(t, s.Serve(l))
|
||||
}()
|
||||
c, err = net.Dial("tcp", l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
scanner = bufio.NewScanner(c)
|
||||
return
|
||||
}
|
||||
|
||||
func writeAndReadUntilLine(t *testing.T, email string, conn net.Conn, scanner *bufio.Scanner, expectedLine string) {
|
||||
_, err := io.WriteString(conn, email)
|
||||
require.Nil(t, err)
|
||||
readUntilLine(t, conn, scanner, expectedLine)
|
||||
}
|
||||
|
||||
func readUntilLine(t *testing.T, conn net.Conn, scanner *bufio.Scanner, expectedLine string) {
|
||||
cancelChan := make(chan bool)
|
||||
go func() {
|
||||
select {
|
||||
case <-cancelChan:
|
||||
case <-time.After(3 * time.Second):
|
||||
conn.Close()
|
||||
t.Error("Failed waiting for expected output")
|
||||
}
|
||||
}()
|
||||
var output string
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if strings.TrimSpace(text) == expectedLine {
|
||||
cancelChan <- true
|
||||
return
|
||||
}
|
||||
output += text + "\n"
|
||||
//fmt.Println(text)
|
||||
}
|
||||
t.Fatalf("Expected line '%s' not found in output:\n%s", expectedLine, output)
|
||||
}
|
||||
|
|
|
@ -309,6 +309,7 @@ type apiAccountBilling struct {
|
|||
Customer bool `json:"customer"`
|
||||
Subscription bool `json:"subscription"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Interval string `json:"interval,omitempty"`
|
||||
PaidUntil int64 `json:"paid_until,omitempty"`
|
||||
CancelAt int64 `json:"cancel_at,omitempty"`
|
||||
}
|
||||
|
@ -343,11 +344,16 @@ type apiConfigResponse struct {
|
|||
DisallowedTopics []string `json:"disallowed_topics"`
|
||||
}
|
||||
|
||||
type apiAccountBillingPrices struct {
|
||||
Month int64 `json:"month"`
|
||||
Year int64 `json:"year"`
|
||||
}
|
||||
|
||||
type apiAccountBillingTier struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Price string `json:"price,omitempty"`
|
||||
Limits *apiAccountLimits `json:"limits"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Prices *apiAccountBillingPrices `json:"prices,omitempty"`
|
||||
Limits *apiAccountLimits `json:"limits"`
|
||||
}
|
||||
|
||||
type apiAccountBillingSubscriptionCreateResponse struct {
|
||||
|
@ -355,7 +361,8 @@ type apiAccountBillingSubscriptionCreateResponse struct {
|
|||
}
|
||||
|
||||
type apiAccountBillingSubscriptionChangeRequest struct {
|
||||
Tier string `json:"tier"`
|
||||
Tier string `json:"tier"`
|
||||
Interval string `json:"interval"`
|
||||
}
|
||||
|
||||
type apiAccountBillingPortalRedirectResponse struct {
|
||||
|
@ -385,7 +392,10 @@ type apiStripeSubscriptionUpdatedEvent struct {
|
|||
Items *struct {
|
||||
Data []*struct {
|
||||
Price *struct {
|
||||
ID string `json:"id"`
|
||||
ID string `json:"id"`
|
||||
Recurring *struct {
|
||||
Interval string `json:"interval"`
|
||||
} `json:"recurring"`
|
||||
} `json:"price"`
|
||||
} `json:"data"`
|
||||
} `json:"items"`
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||
|
@ -74,7 +72,7 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
|
|||
if err != nil {
|
||||
ip = netip.IPv4Unspecified()
|
||||
if remoteAddr != "@" || !behindProxy { // RemoteAddr is @ when unix socket is used
|
||||
log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
|
||||
logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", remoteAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -85,7 +83,7 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
|
|||
ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
|
||||
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
|
||||
if err != nil {
|
||||
log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
|
||||
logr(r).Err(err).Error("invalid IP address %s received in X-Forwarded-For header", ip)
|
||||
// Fall back to regular remote address if X-Forwarded-For is damaged
|
||||
} else {
|
||||
ip = realIP
|
||||
|
|
|
@ -159,8 +159,9 @@ func (v *visitor) contextNoLock() log.Context {
|
|||
fields["user_id"] = v.user.ID
|
||||
fields["user_name"] = v.user.Name
|
||||
if v.user.Tier != nil {
|
||||
fields["tier_id"] = v.user.Tier.ID
|
||||
fields["tier_name"] = v.user.Tier.Name
|
||||
for field, value := range v.user.Tier.Context() {
|
||||
fields[field] = value
|
||||
}
|
||||
}
|
||||
if v.user.Billing.StripeCustomerID != "" {
|
||||
fields["stripe_customer_id"] = v.user.Billing.StripeCustomerID
|
||||
|
@ -329,9 +330,13 @@ func (v *visitor) SetUser(u *user.User) {
|
|||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||
v.user = u
|
||||
v.user = u // u may be nil!
|
||||
if shouldResetLimiters {
|
||||
v.resetLimitersNoLock(0, 0, true)
|
||||
var messages, emails int64
|
||||
if u != nil {
|
||||
messages, emails = u.Stats.Messages, u.Stats.Emails
|
||||
}
|
||||
v.resetLimitersNoLock(messages, emails, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue