From c80e4e1aa9318b85a8abb67fe623451418079fe6 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 31 May 2022 23:16:44 -0400 Subject: [PATCH] Make Firebase logic testable, test it --- server/server.go | 72 +++++++++++++++---------------- server/server_firebase.go | 77 ++++++++++++++++++++++++++-------- server/server_firebase_test.go | 38 +++++++++++++++++ server/server_test.go | 49 +++++++--------------- 4 files changed, 147 insertions(+), 89 deletions(-) diff --git a/server/server.go b/server/server.go index 7384ab4..253a422 100644 --- a/server/server.go +++ b/server/server.go @@ -32,22 +32,22 @@ import ( // Server is the main server, providing the UI and API for ntfy type Server struct { - config *Config - httpServer *http.Server - httpsServer *http.Server - unixListener net.Listener - smtpServer *smtp.Server - smtpBackend *smtpBackend - topics map[string]*topic - visitors map[string]*visitor - firebase subscriber - mailer mailer - messages int64 - auth auth.Auther - messageCache *messageCache - fileCache *fileCache - closeChan chan bool - mu sync.Mutex + config *Config + httpServer *http.Server + httpsServer *http.Server + unixListener net.Listener + smtpServer *smtp.Server + smtpBackend *smtpBackend + topics map[string]*topic + visitors map[string]*visitor + firebaseClient *firebaseClient + mailer mailer + messages int64 + auth auth.Auther + messageCache *messageCache + fileCache *fileCache + closeChan chan bool + mu sync.Mutex } // handleFunc extends the normal http.HandlerFunc to be able to easily return errors @@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) { return nil, err } } - var firebaseSubscriber subscriber + var firebaseClient *firebaseClient if conf.FirebaseKeyFile != "" { - var err error - firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther) + sender, err := newFirebaseSender(conf.FirebaseKeyFile) if err != nil { return nil, err } + firebaseClient = newFirebaseClient(sender, auther) } return &Server{ - config: conf, - messageCache: messageCache, - fileCache: fileCache, - firebase: firebaseSubscriber, - mailer: mailer, - topics: topics, - auth: auther, - visitors: make(map[string]*visitor), + config: conf, + messageCache: messageCache, + fileCache: fileCache, + firebaseClient: firebaseClient, + mailer: mailer, + topics: topics, + auth: auther, + visitors: make(map[string]*visitor), }, nil } @@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } } - if s.firebase != nil && firebase && !delayed { + if s.firebaseClient != nil && firebase && !delayed { go s.sendToFirebase(v, m) } if s.mailer != nil && email != "" && !delayed { @@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } func (s *Server) sendToFirebase(v *visitor, m *message) { - if err := s.firebase(v, m); err != nil { + if err := s.firebaseClient.Send(v, m); err != nil { log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) } } @@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() { } func (s *Server) runFirebaseKeepaliver() { - if s.firebase == nil { + if s.firebaseClient == nil { return } - v := newVisitor(s.config, s.messageCache, "0.0.0.0") + v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor for { select { case <-time.After(s.config.FirebaseKeepaliveInterval): - if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil { - log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error()) - } + s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic)) case <-time.After(s.config.FirebasePollInterval): - if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil { - log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error()) - } + s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic)) case <-s.closeChan: return } @@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error { } }() } - if s.firebase != nil { // Firebase subscribers may not show up in topics map + if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map go s.sendToFirebase(v, m) } if s.config.UpstreamBaseURL != "" { diff --git a/server/server_firebase.go b/server/server_firebase.go index 8368337..47d2755 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "fmt" "log" "strings" @@ -18,33 +19,75 @@ const ( fcmApnsBodyMessageLimit = 100 ) -func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) { +var ( + errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded") +) + +// firebaseClient is a generic client that formats and sends messages to Firebase. +// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable. +type firebaseClient struct { + sender firebaseSender + auther auth.Auther +} + +func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient { + return &firebaseClient{ + sender: sender, + auther: auther, + } +} + +func (c *firebaseClient) Send(v *visitor, m *message) error { + if err := v.FirebaseAllowed(); err != nil { + return errFirebaseQuotaExceeded + } + fbm, err := toFirebaseMessage(m, c.auther) + if err != nil { + return err + } + err = c.sender.Send(fbm) + if err == errFirebaseQuotaExceeded { + log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic) + v.FirebaseTemporarilyDeny() + } + return err +} + +// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging. +// In tests, this can be implemented with a mock. +type firebaseSender interface { + // Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded + // if a rate limit has reached. + Send(m *messaging.Message) error +} + +// firebaseSenderImpl is a firebaseSender that actually talks to Firebase +type firebaseSenderImpl struct { + client *messaging.Client +} + +func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) { fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile)) if err != nil { return nil, err } - msg, err := fb.Messaging(context.Background()) + client, err := fb.Messaging(context.Background()) if err != nil { return nil, err } - return func(v *visitor, m *message) error { - if err := v.FirebaseAllowed(); err != nil { - return errHTTPTooManyRequestsFirebaseQuotaReached - } - fbm, err := toFirebaseMessage(m, auther) - if err != nil { - return err - } - _, err = msg.Send(context.Background(), fbm) - if err != nil && messaging.IsQuotaExceeded(err) { - log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic) - v.FirebaseTemporarilyDeny() - return errHTTPTooManyRequestsFirebaseQuotaReached - } - return err + return &firebaseSenderImpl{ + client: client, }, nil } +func (c *firebaseSenderImpl) Send(m *messaging.Message) error { + _, err := c.client.Send(context.Background(), m) + if err != nil && messaging.IsQuotaExceeded(err) { + return errFirebaseQuotaExceeded + } + return err +} + // toFirebaseMessage converts a message to a Firebase message. // // Normal messages ("message"): diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 6ad6fde..8e08b0d 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error { return errors.New("unauthorized") } +type testFirebaseSender struct { + allowed int + messages []*messaging.Message +} + +func newTestFirebaseSender(allowed int) *testFirebaseSender { + return &testFirebaseSender{ + allowed: allowed, + messages: make([]*messaging.Message, 0), + } +} +func (s *testFirebaseSender) Send(m *messaging.Message) error { + if len(s.messages)+1 > s.allowed { + return errFirebaseQuotaExceeded + } + s.messages = append(s.messages, m) + return nil +} + func TestToFirebaseMessage_Keepalive(t *testing.T) { m := newKeepaliveMessage("mytopic") fbm, err := toFirebaseMessage(m, nil) @@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) } + +func TestToFirebaseSender_Abuse(t *testing.T) { + sender := &testFirebaseSender{allowed: 2} + client := newFirebaseClient(sender, &testAuther{}) + visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4") + + require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, 1, len(sender.messages)) + + require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, 2, len(sender.messages)) + + require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, 2, len(sender.messages)) + + sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working + require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, 0, len(sender.messages)) +} diff --git a/server/server_test.go b/server/server_test.go index 1fec1f5..d05075f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -9,7 +9,6 @@ import ( "math/rand" "net/http" "net/http/httptest" - "os" "path/filepath" "strings" "sync" @@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) { require.Equal(t, "my second message", lines[1]) // \n -> " " } +func TestServer_PublishWithFirebase(t *testing.T) { + sender := newTestFirebaseSender(10) + s := newTestServer(t, newTestConfig(t)) + s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true}) + + response := request(t, s, "PUT", "/mytopic", "my first message", nil) + msg1 := toMessage(t, response.Body.String()) + require.NotEmpty(t, msg1.ID) + require.Equal(t, "my first message", msg1.Message) + require.Equal(t, 1, len(sender.messages)) + require.Equal(t, "my first message", sender.messages[0].Data["message"]) + require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body) + require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"]) +} + func TestServer_SubscribeOpenAndKeepalive(t *testing.T) { c := newTestConfig(t) c.KeepaliveInterval = time.Second @@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) { require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n ! } -func TestServer_PublishFirebase(t *testing.T) { - // This is unfortunately not much of a test, since it merely fires the messages towards Firebase, - // but cannot re-read them. There is no way from Go to read the messages back, or even get an error back. - // I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ... - - c := newTestConfig(t) - c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test! - s := newTestServer(t, c) - - // Normal message - response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil) - msg := toMessage(t, response.Body.String()) - require.NotEmpty(t, msg.ID) - - // Keepalive message - v := newVisitor(s.config, s.messageCache, "1.2.3.4") - require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic))) - - time.Sleep(500 * time.Millisecond) // Time for sends -} - func TestServer_PublishInvalidTopic(t *testing.T) { s := newTestServer(t, newTestConfig(t)) s.mailer = &testMailer{} @@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP { return &e } -func firebaseServiceAccountFile(t *testing.T) string { - if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" { - return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") - } else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" { - filename := filepath.Join(t.TempDir(), "firebase.json") - require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600)) - return filename - } - t.SkipNow() - return "" -} - func basicAuth(s string) string { return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s))) }