Begin unit tests, relates to #35

This commit is contained in:
Philipp Heckel 2021-12-07 11:45:15 -05:00
parent da8f90d388
commit be50af0a7a
11 changed files with 198 additions and 61 deletions

View file

@ -1,10 +1,17 @@
package server
import (
"errors"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"time"
)
var (
errUnexpectedMessageType = errors.New("unexpected message type")
)
// cache implements a cache for messages of type "message" events,
// i.e. message structs with the Event messageEvent.
type cache interface {
AddMessage(m *message) error
Messages(topic string, since sinceTime) ([]*message, error)

View file

@ -1,7 +1,6 @@
package server
import (
_ "github.com/mattn/go-sqlite3" // SQLite driver
"sync"
"time"
)
@ -22,6 +21,9 @@ func newMemCache() *memCache {
func (s *memCache) AddMessage(m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
if m.Event != messageEvent {
return errUnexpectedMessageType
}
if _, ok := s.messages[m.Topic]; !ok {
s.messages[m.Topic] = make([]*message, 0)
}
@ -32,7 +34,7 @@ func (s *memCache) AddMessage(m *message) error {
func (s *memCache) Messages(topic string, since sinceTime) ([]*message, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[topic]; !ok {
if _, ok := s.messages[topic]; !ok || since.IsNone() {
return make([]*message, 0), nil
}
messages := make([]*message, 0) // copy!
@ -62,7 +64,7 @@ func (s *memCache) Topics() (map[string]*topic, error) {
func (s *memCache) Prune(keep time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
for topic, _ := range s.messages {
for topic := range s.messages {
s.pruneTopic(topic, keep)
}
return nil

12
server/cache_mem_test.go Normal file
View file

@ -0,0 +1,12 @@
package server
import (
"testing"
)
func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemCache())
}
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemCache())
}

View file

@ -81,11 +81,17 @@ func newSqliteCache(filename string) (*sqliteCache, error) {
}
func (c *sqliteCache) AddMessage(m *message) error {
if m.Event != messageEvent {
return errUnexpectedMessageType
}
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ","))
return err
}
func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error) {
if since.IsNone() {
return make([]*message, 0), nil
}
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix())
if err != nil {
return nil, err
@ -99,9 +105,6 @@ func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error
if err := rows.Scan(&id, &timestamp, &msg, &title, &priority, &tagsStr); err != nil {
return nil, err
}
if msg == "" {
msg = " " // Hack: never return empty messages; this should not happen
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
@ -141,13 +144,13 @@ func (c *sqliteCache) MessageCount(topic string) (int, error) {
return count, nil
}
func (s *sqliteCache) Topics() (map[string]*topic, error) {
rows, err := s.db.Query(selectTopicsQuery)
func (c *sqliteCache) Topics() (map[string]*topic, error) {
rows, err := c.db.Query(selectTopicsQuery)
if err != nil {
return nil, err
}
defer rows.Close()
topics := make(map[string]*topic, 0)
topics := make(map[string]*topic)
for rows.Next() {
var id string
var last int64
@ -162,8 +165,8 @@ func (s *sqliteCache) Topics() (map[string]*topic, error) {
return topics, nil
}
func (s *sqliteCache) Prune(keep time.Duration) error {
_, err := s.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix())
func (c *sqliteCache) Prune(keep time.Duration) error {
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix())
return err
}

View file

@ -0,0 +1,23 @@
package server
import (
"path/filepath"
"testing"
)
func TestSqliteCache_AddMessage(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t))
}
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
}
func newSqliteTestCache(t *testing.T) cache {
filename := filepath.Join(t.TempDir(), "cache.db")
c, err := newSqliteCache(filename)
if err != nil {
t.Fatal(err)
}
return c
}

79
server/cache_test.go Normal file
View file

@ -0,0 +1,79 @@
package server
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func testCacheMessages(t *testing.T, c cache) {
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
assert.Nil(t, c.AddMessage(m2))
// Adding invalid
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
// mytopic: count
count, err := c.MessageCount("mytopic")
assert.Nil(t, err)
assert.Equal(t, 2, count)
// mytopic: since all
messages, _ := c.Messages("mytopic", sinceAllMessages)
assert.Equal(t, 2, len(messages))
assert.Equal(t, "my message", messages[0].Message)
assert.Equal(t, "mytopic", messages[0].Topic)
assert.Equal(t, messageEvent, messages[0].Event)
assert.Equal(t, "", messages[0].Title)
assert.Equal(t, 0, messages[0].Priority)
assert.Nil(t, messages[0].Tags)
assert.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = c.Messages("mytopic", sinceNoMessages)
assert.Empty(t, messages)
// mytopic: since 2
messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0)))
assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message)
// example: count
count, err = c.MessageCount("example")
assert.Nil(t, err)
assert.Equal(t, 1, count)
// example: since all
messages, _ = c.Messages("example", sinceAllMessages)
assert.Equal(t, "my example message", messages[0].Message)
// non-existing: count
count, err = c.MessageCount("doesnotexist")
assert.Nil(t, err)
assert.Equal(t, 0, count)
// non-existing: since all
messages, _ = c.Messages("doesnotexist", sinceAllMessages)
assert.Empty(t, messages)
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
assert.Nil(t, c.AddMessage(m))
messages, _ := c.Messages("mytopic", sinceAllMessages)
assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
assert.Equal(t, 5, messages[0].Priority)
assert.Equal(t, "some title", messages[0].Title)
}

View file

@ -3,8 +3,7 @@ package server
import (
"bytes"
"context"
"embed"
_ "embed" // required for go:embed
"embed" // required for go:embed
"encoding/json"
firebase "firebase.google.com/go"
"firebase.google.com/go/messaging"
@ -27,7 +26,7 @@ import (
// TODO add "max messages in a topic" limit
// TODO implement "since=<ID>"
// Server is the main server
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *config.Config
topics map[string]*topic
@ -105,6 +104,8 @@ var (
errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
)
// New instantiates a new Server. It creates the cache and adds a Firebase
// subscriber (if configured).
func New(conf *config.Config) (*Server, error) {
var firebaseSubscriber subscriber
if conf.FirebaseKeyFile != "" {
@ -170,6 +171,8 @@ func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
}, nil
}
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
// a manager go routine to print stats and prune messages.
func (s *Server) Run() error {
go func() {
ticker := time.NewTicker(s.config.ManagerInterval)
@ -241,11 +244,11 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
})
}
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error {
return nil
}
func (s *Server) handleExample(w http.ResponseWriter, r *http.Request) error {
func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
_, err := io.WriteString(w, exampleSource)
return err
}
@ -260,7 +263,7 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visitor) error {
t, err := s.topicFromID(r.URL.Path[1:])
if err != nil {
return err
@ -466,7 +469,7 @@ func parseSince(r *http.Request) (sinceTime, error) {
return sinceNoMessages, errHTTPBadRequest
}
func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error {
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error {
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST")
return nil
@ -570,5 +573,5 @@ func (s *Server) visitor(r *http.Request) *visitor {
func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
w.WriteHeader(code)
io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
_, _ = io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
}