Firebase quota limit
This commit is contained in:
parent
8a81c8e95b
commit
8283b6be97
9 changed files with 180 additions and 119 deletions
|
@ -15,6 +15,7 @@ const (
|
||||||
DefaultMaxDelay = 3 * 24 * time.Hour
|
DefaultMaxDelay = 3 * 24 * time.Hour
|
||||||
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
||||||
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
||||||
|
DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines all global and per-visitor limits
|
// Defines all global and per-visitor limits
|
||||||
|
@ -69,6 +70,7 @@ type Config struct {
|
||||||
AtSenderInterval time.Duration
|
AtSenderInterval time.Duration
|
||||||
FirebaseKeepaliveInterval time.Duration
|
FirebaseKeepaliveInterval time.Duration
|
||||||
FirebasePollInterval time.Duration
|
FirebasePollInterval time.Duration
|
||||||
|
FirebaseQuotaLimitPenaltyDuration time.Duration
|
||||||
UpstreamBaseURL string
|
UpstreamBaseURL string
|
||||||
SMTPSenderAddr string
|
SMTPSenderAddr string
|
||||||
SMTPSenderUser string
|
SMTPSenderUser string
|
||||||
|
@ -121,6 +123,7 @@ func NewConfig() *Config {
|
||||||
AtSenderInterval: DefaultAtSenderInterval,
|
AtSenderInterval: DefaultAtSenderInterval,
|
||||||
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
||||||
FirebasePollInterval: DefaultFirebasePollInterval,
|
FirebasePollInterval: DefaultFirebasePollInterval,
|
||||||
|
FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration,
|
||||||
TotalTopicLimit: DefaultTotalTopicLimit,
|
TotalTopicLimit: DefaultTotalTopicLimit,
|
||||||
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||||
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
||||||
|
|
|
@ -59,6 +59,7 @@ var (
|
||||||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||||
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
|
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,13 +7,11 @@ import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -221,7 +219,7 @@ func (s *Server) Run() error {
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
go s.runManager()
|
go s.runManager()
|
||||||
go s.runAtSender()
|
go s.runDelayedSender()
|
||||||
go s.runFirebaseKeepaliver()
|
go s.runFirebaseKeepaliver()
|
||||||
|
|
||||||
return <-errChan
|
return <-errChan
|
||||||
|
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
}
|
}
|
||||||
delayed := m.Time > time.Now().Unix()
|
delayed := m.Time > time.Now().Unix()
|
||||||
if !delayed {
|
if !delayed {
|
||||||
if err := t.Publish(m); err != nil {
|
if err := t.Publish(v, m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||||
if err := s.firebase(m); err != nil {
|
if err := s.firebase(v, m); err != nil {
|
||||||
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var wlock sync.Mutex
|
var wlock sync.Mutex
|
||||||
sub := func(msg *message) error {
|
sub := func(v *visitor, msg *message) error {
|
||||||
if !filters.Pass(msg) {
|
if !filters.Pass(msg) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||||
if poll {
|
if poll {
|
||||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
|
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
topics[i].Unsubscribe(subscriberID) // Order!
|
topics[i].Unsubscribe(subscriberID) // Order!
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
|
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
return nil
|
return nil
|
||||||
case <-time.After(s.config.KeepaliveInterval):
|
case <-time.After(s.config.KeepaliveInterval):
|
||||||
v.Keepalive()
|
v.Keepalive()
|
||||||
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
sub := func(msg *message) error {
|
sub := func(v *visitor, msg *message) error {
|
||||||
if !filters.Pass(msg) {
|
if !filters.Pass(msg) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||||
if poll {
|
if poll {
|
||||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
|
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
topics[i].Unsubscribe(subscriberID) // Order!
|
topics[i].Unsubscribe(subscriberID) // Order!
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = g.Wait()
|
err = g.Wait()
|
||||||
|
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
|
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||||
if since.IsNone() {
|
if since.IsNone() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
if err := sub(m); err != nil {
|
if err := sub(v, m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runSMTPServer() error {
|
func (s *Server) runSMTPServer() error {
|
||||||
sub := func(m *message) error {
|
s.smtpBackend = newMailBackend(s.config, s.handle)
|
||||||
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
|
||||||
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if m.Title != "" {
|
|
||||||
req.Header.Set("Title", m.Title)
|
|
||||||
}
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
s.handle(rr, req)
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
return errors.New("error: " + rr.Body.String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s.smtpBackend = newMailBackend(s.config, sub)
|
|
||||||
s.smtpServer = smtp.NewServer(s.smtpBackend)
|
s.smtpServer = smtp.NewServer(s.smtpBackend)
|
||||||
s.smtpServer.Addr = s.config.SMTPServerListen
|
s.smtpServer.Addr = s.config.SMTPServerListen
|
||||||
s.smtpServer.Domain = s.config.SMTPServerDomain
|
s.smtpServer.Domain = s.config.SMTPServerDomain
|
||||||
|
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runAtSender() {
|
func (s *Server) runDelayedSender() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(s.config.AtSenderInterval):
|
case <-time.After(s.config.AtSenderInterval):
|
||||||
|
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
|
||||||
if s.firebase == nil {
|
if s.firebase == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||||
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
||||||
}
|
}
|
||||||
case <-time.After(s.config.FirebasePollInterval):
|
case <-time.After(s.config.FirebasePollInterval):
|
||||||
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
|
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
||||||
}
|
}
|
||||||
case <-s.closeChan:
|
case <-s.closeChan:
|
||||||
|
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendDelayedMessages() error {
|
func (s *Server) sendDelayedMessages() error {
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
messages, err := s.messageCache.MessagesDue()
|
messages, err := s.messageCache.MessagesDue()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
|
v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
|
||||||
|
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||||
|
log.Printf("error sending delayed message: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||||
if ok {
|
if ok {
|
||||||
if err := t.Publish(m); err != nil {
|
if err := t.Publish(v, m); err != nil {
|
||||||
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
||||||
if err := s.firebase(m); err != nil {
|
if err := s.firebase(v, m); err != nil {
|
||||||
log.Printf("unable to publish to Firebase: %v", err.Error())
|
return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
||||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||||
func (s *Server) visitor(r *http.Request) *visitor {
|
func (s *Server) visitor(r *http.Request) *visitor {
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
remoteAddr := r.RemoteAddr
|
remoteAddr := r.RemoteAddr
|
||||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
||||||
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
||||||
ip = r.Header.Get("X-Forwarded-For")
|
ip = r.Header.Get("X-Forwarded-For")
|
||||||
}
|
}
|
||||||
|
return s.visitorFromIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) visitorFromIP(ip string) *visitor {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
v, exists := s.visitors[ip]
|
v, exists := s.visitors[ip]
|
||||||
if !exists {
|
if !exists {
|
||||||
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
firebase "firebase.google.com/go/v4"
|
firebase "firebase.google.com/go/v4"
|
||||||
|
@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return func(m *message) error {
|
return func(v *visitor, m *message) error {
|
||||||
|
if err := v.FirebaseAllowed(); err != nil {
|
||||||
|
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||||
|
}
|
||||||
fbm, err := toFirebaseMessage(m, auther)
|
fbm, err := toFirebaseMessage(m, auther)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = msg.Send(context.Background(), fbm)
|
_, err = msg.Send(context.Background(), fbm)
|
||||||
|
if err != nil && messaging.IsQuotaExceeded(err) {
|
||||||
|
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
|
||||||
|
v.FirebaseTemporarilyDeny()
|
||||||
|
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
|
||||||
require.NotEmpty(t, msg.ID)
|
require.NotEmpty(t, msg.ID)
|
||||||
|
|
||||||
// Keepalive message
|
// Keepalive message
|
||||||
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic)))
|
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
|
||||||
|
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond) // Time for sends
|
time.Sleep(500 * time.Millisecond) // Time for sends
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,10 +3,13 @@ package server
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
"mime"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -23,25 +26,25 @@ var (
|
||||||
// smtpBackend implements SMTP server methods.
|
// smtpBackend implements SMTP server methods.
|
||||||
type smtpBackend struct {
|
type smtpBackend struct {
|
||||||
config *Config
|
config *Config
|
||||||
sub subscriber
|
handler func(http.ResponseWriter, *http.Request)
|
||||||
success int64
|
success int64
|
||||||
failure int64
|
failure int64
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMailBackend(conf *Config, sub subscriber) *smtpBackend {
|
func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
|
||||||
return &smtpBackend{
|
return &smtpBackend{
|
||||||
config: conf,
|
config: conf,
|
||||||
sub: sub,
|
handler: handler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
||||||
return &smtpSession{backend: b}, nil
|
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||||
return &smtpSession{backend: b}, nil
|
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *smtpBackend) Counts() (success int64, failure int64) {
|
func (b *smtpBackend) Counts() (success int64, failure int64) {
|
||||||
|
@ -53,6 +56,7 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
|
||||||
// smtpSession is returned after EHLO.
|
// smtpSession is returned after EHLO.
|
||||||
type smtpSession struct {
|
type smtpSession struct {
|
||||||
backend *smtpBackend
|
backend *smtpBackend
|
||||||
|
remoteAddr string
|
||||||
topic string
|
topic string
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
||||||
m.Message = m.Title // Flip them, this makes more sense
|
m.Message = m.Title // Flip them, this makes more sense
|
||||||
m.Title = ""
|
m.Title = ""
|
||||||
}
|
}
|
||||||
if err := s.backend.sub(m); err != nil {
|
if err := s.publishMessage(m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.backend.mu.Lock()
|
s.backend.mu.Lock()
|
||||||
|
@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *smtpSession) publishMessage(m *message) error {
|
||||||
|
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
|
||||||
|
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
|
||||||
|
req.RemoteAddr = s.remoteAddr // rate limiting!!
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if m.Title != "" {
|
||||||
|
req.Header.Set("Title", m.Title)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.backend.handler(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
return errors.New("error: " + rr.Body.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *smtpSession) Reset() {
|
func (s *smtpSession) Reset() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.topic = ""
|
s.topic = ""
|
||||||
|
|
|
@ -3,6 +3,9 @@ package server
|
||||||
import (
|
import (
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
|
||||||
<div dir="ltr">what's up<br clear="all"><div><br></div></div>
|
<div dir="ltr">what's up<br clear="all"><div><br></div></div>
|
||||||
|
|
||||||
--000000000000f3320b05d42915c9--`
|
--000000000000f3320b05d42915c9--`
|
||||||
_, backend := newTestBackend(t, func(m *message) error {
|
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "mytopic", m.Topic)
|
require.Equal(t, "/mytopic", r.URL.Path)
|
||||||
require.Equal(t, "and one more", m.Title)
|
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||||
require.Equal(t, "what's up", m.Message)
|
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
|
||||||
<div dir="ltr"><br></div>
|
<div dir="ltr"><br></div>
|
||||||
|
|
||||||
--000000000000bcf4a405d429f8d4--`
|
--000000000000bcf4a405d429f8d4--`
|
||||||
_, backend := newTestBackend(t, func(m *message) error {
|
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "emailtest", m.Topic)
|
require.Equal(t, "/emailtest", r.URL.Path)
|
||||||
require.Equal(t, "", m.Title) // We flipped message and body
|
require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
|
||||||
require.Equal(t, "This email has a subject but no body", m.Message)
|
require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
|
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
|
||||||
|
|
||||||
what's up
|
what's up
|
||||||
`
|
`
|
||||||
conf, backend := newTestBackend(t, func(m *message) error {
|
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "mytopic", m.Topic)
|
require.Equal(t, "/mytopic", r.URL.Path)
|
||||||
require.Equal(t, "and one more", m.Title)
|
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||||
require.Equal(t, "what's up", m.Message)
|
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
conf.SMTPServerAddrPrefix = ""
|
conf.SMTPServerAddrPrefix = ""
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
|
||||||
|
|
||||||
what's up
|
what's up
|
||||||
`
|
`
|
||||||
conf, backend := newTestBackend(t, func(m *message) error {
|
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "mytopic", m.Topic)
|
require.Equal(t, "/mytopic", r.URL.Path)
|
||||||
require.Equal(t, "Very short mail", m.Title)
|
require.Equal(t, "Very short mail", r.Header.Get("Title"))
|
||||||
require.Equal(t, "what's up", m.Message)
|
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
conf.SMTPServerAddrPrefix = ""
|
conf.SMTPServerAddrPrefix = ""
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
|
||||||
|
|
||||||
what's up
|
what's up
|
||||||
`
|
`
|
||||||
_, backend := newTestBackend(t, func(m *message) error {
|
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "Three santas 🎅🎅🎅", m.Title)
|
require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -204,7 +202,7 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||||
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||||
that should do it
|
that should do it
|
||||||
`
|
`
|
||||||
conf, backend := newTestBackend(t, func(m *message) error {
|
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
expected := `you know this is a string.
|
expected := `you know this is a string.
|
||||||
it's a long string.
|
it's a long string.
|
||||||
it's supposed to be longer than the max message length
|
it's supposed to be longer than the max message length
|
||||||
|
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||||
......................................................................
|
......................................................................
|
||||||
......................................................................
|
......................................................................
|
||||||
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||||
BBBBBBBBBBBBBBBBBBBBBBBB`
|
BBBBBBBBBBBBBBBBBBBBBBBBB`
|
||||||
require.Equal(t, 4096, len(expected)) // Sanity check
|
require.Equal(t, 4096, len(expected)) // Sanity check
|
||||||
require.Equal(t, expected, m.Message)
|
require.Equal(t, expected, readAll(t, r.Body))
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
conf.SMTPServerAddrPrefix = ""
|
conf.SMTPServerAddrPrefix = ""
|
||||||
session, _ := backend.AnonymousLogin(nil)
|
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||||
|
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
|
||||||
|
|
||||||
what's up
|
what's up
|
||||||
`
|
`
|
||||||
conf, backend := newTestBackend(t, func(m *message) error {
|
conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
|
||||||
return nil
|
// Nothing.
|
||||||
})
|
})
|
||||||
conf.SMTPServerAddrPrefix = ""
|
conf.SMTPServerAddrPrefix = ""
|
||||||
session, _ := backend.Login(nil, "user", "pass")
|
session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
|
||||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||||
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
|
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) {
|
func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
|
||||||
conf := newTestConfig(t)
|
conf := newTestConfig(t)
|
||||||
conf.SMTPServerListen = ":25"
|
conf.SMTPServerListen = ":25"
|
||||||
conf.SMTPServerDomain = "ntfy.sh"
|
conf.SMTPServerDomain = "ntfy.sh"
|
||||||
conf.SMTPServerAddrPrefix = "ntfy-"
|
conf.SMTPServerAddrPrefix = "ntfy-"
|
||||||
backend := newMailBackend(conf, sub)
|
backend := newMailBackend(conf, handler)
|
||||||
return conf, backend
|
return conf, backend
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readAll(t *testing.T, rc io.ReadCloser) string {
|
||||||
|
b, err := io.ReadAll(rc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
|
||||||
|
ip, err := net.ResolveIPAddr("ip", remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return &smtp.ConnectionState{
|
||||||
|
Hostname: "myhostname",
|
||||||
|
LocalAddr: ip,
|
||||||
|
RemoteAddr: ip,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ type topic struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// subscriber is a function that is called for every new message on a topic
|
// subscriber is a function that is called for every new message on a topic
|
||||||
type subscriber func(msg *message) error
|
type subscriber func(v *visitor, msg *message) error
|
||||||
|
|
||||||
// newTopic creates a new topic
|
// newTopic creates a new topic
|
||||||
func newTopic(id string) *topic {
|
func newTopic(id string) *topic {
|
||||||
|
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish asynchronously publishes to all subscribers
|
// Publish asynchronously publishes to all subscribers
|
||||||
func (t *topic) Publish(m *message) error {
|
func (t *topic) Publish(v *visitor, m *message) error {
|
||||||
go func() {
|
go func() {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
for _, s := range t.subscribers {
|
for _, s := range t.subscribers {
|
||||||
if err := s(m); err != nil {
|
if err := s(v, m); err != nil {
|
||||||
log.Printf("error publishing message to subscriber")
|
log.Printf("error publishing message to subscriber")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ type visitor struct {
|
||||||
emails *rate.Limiter
|
emails *rate.Limiter
|
||||||
subscriptions util.Limiter
|
subscriptions util.Limiter
|
||||||
bandwidth util.Limiter
|
bandwidth util.Limiter
|
||||||
|
firebase time.Time // Next allowed Firebase message
|
||||||
seen time.Time
|
seen time.Time
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
|
||||||
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
|
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
|
||||||
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||||
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
|
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
|
||||||
|
firebase: time.Unix(0, 0),
|
||||||
seen: time.Now(),
|
seen: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) IP() string {
|
|
||||||
return v.ip
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *visitor) RequestAllowed() error {
|
func (v *visitor) RequestAllowed() error {
|
||||||
if !v.requests.Allow() {
|
if !v.requests.Allow() {
|
||||||
return errVisitorLimitReached
|
return errVisitorLimitReached
|
||||||
|
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *visitor) FirebaseAllowed() error {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
if time.Now().Before(v.firebase) {
|
||||||
|
return errVisitorLimitReached
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *visitor) FirebaseTemporarilyDeny() {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
|
||||||
|
}
|
||||||
|
|
||||||
func (v *visitor) EmailAllowed() error {
|
func (v *visitor) EmailAllowed() error {
|
||||||
if !v.emails.Allow() {
|
if !v.emails.Allow() {
|
||||||
return errVisitorLimitReached
|
return errVisitorLimitReached
|
||||||
|
|
Loading…
Reference in a new issue