From 8283b6be975e07a9a38c58ca51728ae68647d4b0 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 31 May 2022 20:38:56 -0400 Subject: [PATCH] Firebase quota limit --- server/config.go | 21 +++++---- server/errors.go | 1 + server/server.go | 97 ++++++++++++++++++-------------------- server/server_firebase.go | 11 ++++- server/server_test.go | 3 +- server/smtp_server.go | 42 +++++++++++++---- server/smtp_server_test.go | 97 ++++++++++++++++++++++---------------- server/topic.go | 6 +-- server/visitor.go | 21 +++++++-- 9 files changed, 180 insertions(+), 119 deletions(-) diff --git a/server/config.go b/server/config.go index 4db52a3..3de4bd7 100644 --- a/server/config.go +++ b/server/config.go @@ -6,15 +6,16 @@ import ( // Defines default config settings (excluding limits, see below) const ( - DefaultListenHTTP = ":80" - DefaultCacheDuration = 12 * time.Hour - DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!) - DefaultManagerInterval = time.Minute - DefaultAtSenderInterval = 10 * time.Second - DefaultMinDelay = 10 * time.Second - DefaultMaxDelay = 3 * 24 * time.Hour - DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery - DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) + DefaultListenHTTP = ":80" + DefaultCacheDuration = 12 * time.Hour + DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!) + DefaultManagerInterval = time.Minute + DefaultAtSenderInterval = 10 * time.Second + DefaultMinDelay = 10 * time.Second + DefaultMaxDelay = 3 * 24 * time.Hour + DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery + DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) + DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute ) // Defines all global and per-visitor limits @@ -69,6 +70,7 @@ type Config struct { AtSenderInterval time.Duration FirebaseKeepaliveInterval time.Duration FirebasePollInterval time.Duration + FirebaseQuotaLimitPenaltyDuration time.Duration UpstreamBaseURL string SMTPSenderAddr string SMTPSenderUser string @@ -121,6 +123,7 @@ func NewConfig() *Config { AtSenderInterval: DefaultAtSenderInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebasePollInterval: DefaultFirebasePollInterval, + FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration, TotalTopicLimit: DefaultTotalTopicLimit, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit, diff --git a/server/errors.go b/server/errors.go index 32c1b3b..2fa883f 100644 --- a/server/errors.go +++ b/server/errors.go @@ -59,6 +59,7 @@ var ( errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} ) diff --git a/server/server.go b/server/server.go index 86ed753..2baa366 100644 --- a/server/server.go +++ b/server/server.go @@ -7,13 +7,11 @@ import ( "embed" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "log" "net" "net/http" - "net/http/httptest" "net/url" "os" "path" @@ -221,7 +219,7 @@ func (s *Server) Run() error { } s.mu.Unlock() go s.runManager() - go s.runAtSender() + go s.runDelayedSender() go s.runFirebaseKeepaliver() return <-errChan @@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } delayed := m.Time > time.Now().Unix() if !delayed { - if err := t.Publish(m); err != nil { + if err := t.Publish(v, m); err != nil { return err } } @@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } func (s *Server) sendToFirebase(v *visitor, m *message) { - if err := s.firebase(m); err != nil { + if err := s.firebase(v, m); err != nil { log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) } } @@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * return err } var wlock sync.Mutex - sub := func(msg *message) error { + sub := func(v *visitor, msg *message) error { if !filters.Pass(msg) { return nil } @@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { - return s.sendOldMessages(topics, since, scheduled, sub) + return s.sendOldMessages(topics, since, scheduled, v, sub) } subscriberIDs := make([]int, 0) for _, t := range topics { @@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * topics[i].Unsubscribe(subscriberID) // Order! } }() - if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message + if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message return err } - if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { + if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil { return err } for { @@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * return nil case <-time.After(s.config.KeepaliveInterval): v.Keepalive() - if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message + if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message return err } } @@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } } }) - sub := func(msg *message) error { + sub := func(v *visitor, msg *message) error { if !filters.Pass(msg) { return nil } @@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests if poll { - return s.sendOldMessages(topics, since, scheduled, sub) + return s.sendOldMessages(topics, since, scheduled, v, sub) } subscriberIDs := make([]int, 0) for _, t := range topics { @@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi topics[i].Unsubscribe(subscriberID) // Order! } }() - if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message + if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message return err } - if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { + if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil { return err } err = g.Wait() @@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu return } -func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error { +func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error { if since.IsNone() { return nil } @@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b return err } for _, m := range messages { - if err := sub(m); err != nil { + if err := sub(v, m); err != nil { return err } } @@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() { } func (s *Server) runSMTPServer() error { - sub := func(m *message) error { - url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) - req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message)) - if err != nil { - return err - } - if m.Title != "" { - req.Header.Set("Title", m.Title) - } - rr := httptest.NewRecorder() - s.handle(rr, req) - if rr.Code != http.StatusOK { - return errors.New("error: " + rr.Body.String()) - } - return nil - } - s.smtpBackend = newMailBackend(s.config, sub) + s.smtpBackend = newMailBackend(s.config, s.handle) s.smtpServer = smtp.NewServer(s.smtpBackend) s.smtpServer.Addr = s.config.SMTPServerListen s.smtpServer.Domain = s.config.SMTPServerDomain @@ -1096,7 +1078,7 @@ func (s *Server) runManager() { } } -func (s *Server) runAtSender() { +func (s *Server) runDelayedSender() { for { select { case <-time.After(s.config.AtSenderInterval): @@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() { if s.firebase == nil { return } + v := newVisitor(s.config, s.messageCache, "0.0.0.0") for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): - if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { + if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil { log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error()) } case <-time.After(s.config.FirebasePollInterval): - if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil { + if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil { log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error()) } case <-s.closeChan: @@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() { } func (s *Server) sendDelayedMessages() error { - s.mu.Lock() - defer s.mu.Unlock() messages, err := s.messageCache.MessagesDue() if err != nil { return err } for _, m := range messages { - t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published - if ok { - if err := t.Publish(m); err != nil { - log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error()) - } + v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!! + if err := s.sendDelayedMessage(v, m); err != nil { + log.Printf("error sending delayed message: %s", err.Error()) } - if s.firebase != nil { // Firebase subscribers may not show up in topics map - if err := s.firebase(m); err != nil { - log.Printf("unable to publish to Firebase: %v", err.Error()) - } + } + return nil +} + +func (s *Server) sendDelayedMessage(v *visitor, m *message) error { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published + if ok { + if err := t.Publish(v, m); err != nil { + return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error()) } - if err := s.messageCache.MarkPublished(m); err != nil { - return err + } + if s.firebase != nil { // Firebase subscribers may not show up in topics map + if err := s.firebase(v, m); err != nil { + return fmt.Errorf("unable to publish to Firebase: %v", err.Error()) } } + if err := s.messageCache.MarkPublished(m); err != nil { + return err + } return nil } @@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool // visitor creates or retrieves a rate.Limiter for the given visitor. // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). func (s *Server) visitor(r *http.Request) *visitor { - s.mu.Lock() - defer s.mu.Unlock() remoteAddr := r.RemoteAddr ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { @@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor { if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" { ip = r.Header.Get("X-Forwarded-For") } + return s.visitorFromIP(ip) +} + +func (s *Server) visitorFromIP(ip string) *visitor { + s.mu.Lock() + defer s.mu.Unlock() v, exists := s.visitors[ip] if !exists { s.visitors[ip] = newVisitor(s.config, s.messageCache, ip) diff --git a/server/server_firebase.go b/server/server_firebase.go index 1facd5d..8368337 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log" "strings" firebase "firebase.google.com/go/v4" @@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc if err != nil { return nil, err } - return func(m *message) error { + return func(v *visitor, m *message) error { + if err := v.FirebaseAllowed(); err != nil { + return errHTTPTooManyRequestsFirebaseQuotaReached + } fbm, err := toFirebaseMessage(m, auther) if err != nil { return err } _, err = msg.Send(context.Background(), fbm) + if err != nil && messaging.IsQuotaExceeded(err) { + log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic) + v.FirebaseTemporarilyDeny() + return errHTTPTooManyRequestsFirebaseQuotaReached + } return err }, nil } diff --git a/server/server_test.go b/server/server_test.go index 06f3cd2..5e23e47 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) { require.NotEmpty(t, msg.ID) // Keepalive message - require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic))) + v := newVisitor(s.config, s.messageCache, "1.2.3.4") + require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic))) time.Sleep(500 * time.Millisecond) // Time for sends } diff --git a/server/smtp_server.go b/server/smtp_server.go index c437d23..a5b6f85 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -3,10 +3,13 @@ package server import ( "bytes" "errors" + "fmt" "github.com/emersion/go-smtp" "io" "mime" "mime/multipart" + "net/http" + "net/http/httptest" "net/mail" "strings" "sync" @@ -23,25 +26,25 @@ var ( // smtpBackend implements SMTP server methods. type smtpBackend struct { config *Config - sub subscriber + handler func(http.ResponseWriter, *http.Request) success int64 failure int64 mu sync.Mutex } -func newMailBackend(conf *Config, sub subscriber) *smtpBackend { +func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend { return &smtpBackend{ - config: conf, - sub: sub, + config: conf, + handler: handler, } } func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - return &smtpSession{backend: b}, nil + return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil } func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - return &smtpSession{backend: b}, nil + return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil } func (b *smtpBackend) Counts() (success int64, failure int64) { @@ -52,9 +55,10 @@ func (b *smtpBackend) Counts() (success int64, failure int64) { // smtpSession is returned after EHLO. type smtpSession struct { - backend *smtpBackend - topic string - mu sync.Mutex + backend *smtpBackend + remoteAddr string + topic string + mu sync.Mutex } func (s *smtpSession) AuthPlain(username, password string) error { @@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error { m.Message = m.Title // Flip them, this makes more sense m.Title = "" } - if err := s.backend.sub(m); err != nil { + if err := s.publishMessage(m); err != nil { return err } s.backend.mu.Lock() @@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error { }) } +func (s *smtpSession) publishMessage(m *message) error { + url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic) + req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message)) + req.RemoteAddr = s.remoteAddr // rate limiting!! + if err != nil { + return err + } + if m.Title != "" { + req.Header.Set("Title", m.Title) + } + rr := httptest.NewRecorder() + s.backend.handler(rr, req) + if rr.Code != http.StatusOK { + return errors.New("error: " + rr.Body.String()) + } + return nil +} + func (s *smtpSession) Reset() { s.mu.Lock() s.topic = "" diff --git a/server/smtp_server_test.go b/server/smtp_server_test.go index d0e8bfd..8e9d589 100644 --- a/server/smtp_server_test.go +++ b/server/smtp_server_test.go @@ -3,6 +3,9 @@ package server import ( "github.com/emersion/go-smtp" "github.com/stretchr/testify/require" + "io" + "net" + "net/http" "strings" "testing" ) @@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
what's up

--000000000000f3320b05d42915c9--` - _, backend := newTestBackend(t, func(m *message) error { - require.Equal(t, "mytopic", m.Topic) - require.Equal(t, "and one more", m.Title) - require.Equal(t, "what's up", m.Message) - return nil + _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/mytopic", r.URL.Path) + require.Equal(t, "and one more", r.Header.Get("Title")) + require.Equal(t, "what's up", readAll(t, r.Body)) }) - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"

--000000000000bcf4a405d429f8d4--` - _, backend := newTestBackend(t, func(m *message) error { - require.Equal(t, "emailtest", m.Topic) - require.Equal(t, "", m.Title) // We flipped message and body - require.Equal(t, "This email has a subject but no body", m.Message) - return nil + _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/emailtest", r.URL.Path) + require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body + require.Equal(t, "This email has a subject but no body", readAll(t, r.Body)) }) - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8" what's up ` - conf, backend := newTestBackend(t, func(m *message) error { - require.Equal(t, "mytopic", m.Topic) - require.Equal(t, "and one more", m.Title) - require.Equal(t, "what's up", m.Message) - return nil + conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/mytopic", r.URL.Path) + require.Equal(t, "and one more", r.Header.Get("Title")) + require.Equal(t, "what's up", readAll(t, r.Body)) }) conf.SMTPServerAddrPrefix = "" - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) { what's up ` - conf, backend := newTestBackend(t, func(m *message) error { - require.Equal(t, "mytopic", m.Topic) - require.Equal(t, "Very short mail", m.Title) - require.Equal(t, "what's up", m.Message) - return nil + conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/mytopic", r.URL.Path) + require.Equal(t, "Very short mail", r.Header.Get("Title")) + require.Equal(t, "what's up", readAll(t, r.Body)) }) conf.SMTPServerAddrPrefix = "" - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8" what's up ` - _, backend := newTestBackend(t, func(m *message) error { - require.Equal(t, "Three santas 🎅🎅🎅", m.Title) - return nil + _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title")) }) - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -140,7 +138,7 @@ To: mytopic@ntfy.sh Content-Type: text/plain; charset="UTF-8" you know this is a string. -it's a long string. +it's a long string. it's supposed to be longer than the max message length which is 4096 bytes, it used to be 512 bytes, but I increased that for the UnifiedPush support @@ -204,9 +202,9 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB that should do it ` - conf, backend := newTestBackend(t, func(m *message) error { + conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) { expected := `you know this is a string. -it's a long string. +it's a long string. it's supposed to be longer than the max message length which is 4096 bytes, it used to be 512 bytes, but I increased that for the UnifiedPush support @@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ...................................................................... ...................................................................... and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB -BBBBBBBBBBBBBBBBBBBBBBBB` +BBBBBBBBBBBBBBBBBBBBBBBBB` require.Equal(t, 4096, len(expected)) // Sanity check - require.Equal(t, expected, m.Message) - return nil + require.Equal(t, expected, readAll(t, r.Body)) }) conf.SMTPServerAddrPrefix = "" - session, _ := backend.AnonymousLogin(nil) + session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4")) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Data(strings.NewReader(email))) @@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE what's up ` - conf, backend := newTestBackend(t, func(m *message) error { - return nil + conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) { + // Nothing. }) conf.SMTPServerAddrPrefix = "" - session, _ := backend.Login(nil, "user", "pass") + session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass") require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email))) } -func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) { +func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) { conf := newTestConfig(t) conf.SMTPServerListen = ":25" conf.SMTPServerDomain = "ntfy.sh" conf.SMTPServerAddrPrefix = "ntfy-" - backend := newMailBackend(conf, sub) + backend := newMailBackend(conf, handler) return conf, backend } + +func readAll(t *testing.T, rc io.ReadCloser) string { + b, err := io.ReadAll(rc) + if err != nil { + t.Fatal(err) + } + return string(b) +} + +func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState { + ip, err := net.ResolveIPAddr("ip", remoteAddr) + if err != nil { + t.Fatal(err) + } + return &smtp.ConnectionState{ + Hostname: "myhostname", + LocalAddr: ip, + RemoteAddr: ip, + } +} diff --git a/server/topic.go b/server/topic.go index 9badd7b..eb53225 100644 --- a/server/topic.go +++ b/server/topic.go @@ -15,7 +15,7 @@ type topic struct { } // subscriber is a function that is called for every new message on a topic -type subscriber func(msg *message) error +type subscriber func(v *visitor, msg *message) error // newTopic creates a new topic func newTopic(id string) *topic { @@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) { } // Publish asynchronously publishes to all subscribers -func (t *topic) Publish(m *message) error { +func (t *topic) Publish(v *visitor, m *message) error { go func() { t.mu.Lock() defer t.mu.Unlock() for _, s := range t.subscribers { - if err := s(m); err != nil { + if err := s(v, m); err != nil { log.Printf("error publishing message to subscriber") } } diff --git a/server/visitor.go b/server/visitor.go index 58cc28a..1bbc4e0 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -28,6 +28,7 @@ type visitor struct { emails *rate.Limiter subscriptions util.Limiter bandwidth util.Limiter + firebase time.Time // Next allowed Firebase message seen time.Time mu sync.Mutex } @@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), + firebase: time.Unix(0, 0), seen: time.Now(), } } -func (v *visitor) IP() string { - return v.ip -} - func (v *visitor) RequestAllowed() error { if !v.requests.Allow() { return errVisitorLimitReached @@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error { return nil } +func (v *visitor) FirebaseAllowed() error { + v.mu.Lock() + defer v.mu.Unlock() + if time.Now().Before(v.firebase) { + return errVisitorLimitReached + } + return nil +} + +func (v *visitor) FirebaseTemporarilyDeny() { + v.mu.Lock() + defer v.mu.Unlock() + v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration) +} + func (v *visitor) EmailAllowed() error { if !v.emails.Allow() { return errVisitorLimitReached