Make Firebase logic testable, test it

This commit is contained in:
Philipp Heckel 2022-05-31 23:16:44 -04:00
parent f9284a098a
commit c80e4e1aa9
4 changed files with 147 additions and 89 deletions

View file

@ -32,22 +32,22 @@ import (
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy
type Server struct { type Server struct {
config *Config config *Config
httpServer *http.Server httpServer *http.Server
httpsServer *http.Server httpsServer *http.Server
unixListener net.Listener unixListener net.Listener
smtpServer *smtp.Server smtpServer *smtp.Server
smtpBackend *smtpBackend smtpBackend *smtpBackend
topics map[string]*topic topics map[string]*topic
visitors map[string]*visitor visitors map[string]*visitor
firebase subscriber firebaseClient *firebaseClient
mailer mailer mailer mailer
messages int64 messages int64
auth auth.Auther auth auth.Auther
messageCache *messageCache messageCache *messageCache
fileCache *fileCache fileCache *fileCache
closeChan chan bool closeChan chan bool
mu sync.Mutex mu sync.Mutex
} }
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) {
return nil, err return nil, err
} }
} }
var firebaseSubscriber subscriber var firebaseClient *firebaseClient
if conf.FirebaseKeyFile != "" { if conf.FirebaseKeyFile != "" {
var err error sender, err := newFirebaseSender(conf.FirebaseKeyFile)
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
if err != nil { if err != nil {
return nil, err return nil, err
} }
firebaseClient = newFirebaseClient(sender, auther)
} }
return &Server{ return &Server{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
fileCache: fileCache, fileCache: fileCache,
firebase: firebaseSubscriber, firebaseClient: firebaseClient,
mailer: mailer, mailer: mailer,
topics: topics, topics: topics,
auth: auther, auth: auther,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err return err
} }
} }
if s.firebase != nil && firebase && !delayed { if s.firebaseClient != nil && firebase && !delayed {
go s.sendToFirebase(v, m) go s.sendToFirebase(v, m)
} }
if s.mailer != nil && email != "" && !delayed { if s.mailer != nil && email != "" && !delayed {
@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
func (s *Server) sendToFirebase(v *visitor, m *message) { func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(v, m); err != nil { if err := s.firebaseClient.Send(v, m); err != nil {
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
} }
} }
@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() {
} }
func (s *Server) runFirebaseKeepaliver() { func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil { if s.firebaseClient == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0") v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil { s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
}
case <-time.After(s.config.FirebasePollInterval): case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil { s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
}
case <-s.closeChan: case <-s.closeChan:
return return
} }
@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
} }
}() }()
} }
if s.firebase != nil { // Firebase subscribers may not show up in topics map if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
go s.sendToFirebase(v, m) go s.sendToFirebase(v, m)
} }
if s.config.UpstreamBaseURL != "" { if s.config.UpstreamBaseURL != "" {

View file

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log" "log"
"strings" "strings"
@ -18,33 +19,75 @@ const (
fcmApnsBodyMessageLimit = 100 fcmApnsBodyMessageLimit = 100
) )
func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) { var (
errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded")
)
// firebaseClient is a generic client that formats and sends messages to Firebase.
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
type firebaseClient struct {
sender firebaseSender
auther auth.Auther
}
func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
return &firebaseClient{
sender: sender,
auther: auther,
}
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errFirebaseQuotaExceeded
}
fbm, err := toFirebaseMessage(m, c.auther)
if err != nil {
return err
}
err = c.sender.Send(fbm)
if err == errFirebaseQuotaExceeded {
log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
}
return err
}
// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
// In tests, this can be implemented with a mock.
type firebaseSender interface {
// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
// if a rate limit has reached.
Send(m *messaging.Message) error
}
// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
type firebaseSenderImpl struct {
client *messaging.Client
}
func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile)) fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg, err := fb.Messaging(context.Background()) client, err := fb.Messaging(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return func(v *visitor, m *message) error { return &firebaseSenderImpl{
if err := v.FirebaseAllowed(); err != nil { client: client,
return errHTTPTooManyRequestsFirebaseQuotaReached
}
fbm, err := toFirebaseMessage(m, auther)
if err != nil {
return err
}
_, err = msg.Send(context.Background(), fbm)
if err != nil && messaging.IsQuotaExceeded(err) {
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
return errHTTPTooManyRequestsFirebaseQuotaReached
}
return err
}, nil }, nil
} }
func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
_, err := c.client.Send(context.Background(), m)
if err != nil && messaging.IsQuotaExceeded(err) {
return errFirebaseQuotaExceeded
}
return err
}
// toFirebaseMessage converts a message to a Firebase message. // toFirebaseMessage converts a message to a Firebase message.
// //
// Normal messages ("message"): // Normal messages ("message"):

View file

@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
return errors.New("unauthorized") return errors.New("unauthorized")
} }
type testFirebaseSender struct {
allowed int
messages []*messaging.Message
}
func newTestFirebaseSender(allowed int) *testFirebaseSender {
return &testFirebaseSender{
allowed: allowed,
messages: make([]*messaging.Message, 0),
}
}
func (s *testFirebaseSender) Send(m *messaging.Message) error {
if len(s.messages)+1 > s.allowed {
return errFirebaseQuotaExceeded
}
s.messages = append(s.messages, m)
return nil
}
func TestToFirebaseMessage_Keepalive(t *testing.T) { func TestToFirebaseMessage_Keepalive(t *testing.T) {
m := newKeepaliveMessage("mytopic") m := newKeepaliveMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil) fbm, err := toFirebaseMessage(m, nil)
@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
} }
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.messages))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 0, len(sender.messages))
}

View file

@ -9,7 +9,6 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
require.Equal(t, "my second message", lines[1]) // \n -> " " require.Equal(t, "my second message", lines[1]) // \n -> " "
} }
func TestServer_PublishWithFirebase(t *testing.T) {
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
response := request(t, s, "PUT", "/mytopic", "my first message", nil)
msg1 := toMessage(t, response.Body.String())
require.NotEmpty(t, msg1.ID)
require.Equal(t, "my first message", msg1.Message)
require.Equal(t, 1, len(sender.messages))
require.Equal(t, "my first message", sender.messages[0].Data["message"])
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
}
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) { func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
c := newTestConfig(t) c := newTestConfig(t)
c.KeepaliveInterval = time.Second c.KeepaliveInterval = time.Second
@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n ! require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
} }
func TestServer_PublishFirebase(t *testing.T) {
// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
c := newTestConfig(t)
c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
s := newTestServer(t, c)
// Normal message
response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
// Keepalive message
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends
}
func TestServer_PublishInvalidTopic(t *testing.T) { func TestServer_PublishInvalidTopic(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{} s.mailer = &testMailer{}
@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
return &e return &e
} }
func firebaseServiceAccountFile(t *testing.T) string {
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
filename := filepath.Join(t.TempDir(), "firebase.json")
require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
return filename
}
t.SkipNow()
return ""
}
func basicAuth(s string) string { func basicAuth(s string) string {
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s))) return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
} }