Begin unit tests, relates to #35
This commit is contained in:
parent
da8f90d388
commit
be50af0a7a
11 changed files with 198 additions and 61 deletions
|
@ -1,10 +1,17 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnexpectedMessageType = errors.New("unexpected message type")
|
||||
)
|
||||
|
||||
// cache implements a cache for messages of type "message" events,
|
||||
// i.e. message structs with the Event messageEvent.
|
||||
type cache interface {
|
||||
AddMessage(m *message) error
|
||||
Messages(topic string, since sinceTime) ([]*message, error)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -22,6 +21,9 @@ func newMemCache() *memCache {
|
|||
func (s *memCache) AddMessage(m *message) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if m.Event != messageEvent {
|
||||
return errUnexpectedMessageType
|
||||
}
|
||||
if _, ok := s.messages[m.Topic]; !ok {
|
||||
s.messages[m.Topic] = make([]*message, 0)
|
||||
}
|
||||
|
@ -32,7 +34,7 @@ func (s *memCache) AddMessage(m *message) error {
|
|||
func (s *memCache) Messages(topic string, since sinceTime) ([]*message, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.messages[topic]; !ok {
|
||||
if _, ok := s.messages[topic]; !ok || since.IsNone() {
|
||||
return make([]*message, 0), nil
|
||||
}
|
||||
messages := make([]*message, 0) // copy!
|
||||
|
@ -62,7 +64,7 @@ func (s *memCache) Topics() (map[string]*topic, error) {
|
|||
func (s *memCache) Prune(keep time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for topic, _ := range s.messages {
|
||||
for topic := range s.messages {
|
||||
s.pruneTopic(topic, keep)
|
||||
}
|
||||
return nil
|
||||
|
|
12
server/cache_mem_test.go
Normal file
12
server/cache_mem_test.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMemCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemCache())
|
||||
}
|
||||
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemCache())
|
||||
}
|
|
@ -81,11 +81,17 @@ func newSqliteCache(filename string) (*sqliteCache, error) {
|
|||
}
|
||||
|
||||
func (c *sqliteCache) AddMessage(m *message) error {
|
||||
if m.Event != messageEvent {
|
||||
return errUnexpectedMessageType
|
||||
}
|
||||
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ","))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error) {
|
||||
if since.IsNone() {
|
||||
return make([]*message, 0), nil
|
||||
}
|
||||
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -99,9 +105,6 @@ func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error
|
|||
if err := rows.Scan(&id, ×tamp, &msg, &title, &priority, &tagsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg == "" {
|
||||
msg = " " // Hack: never return empty messages; this should not happen
|
||||
}
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
|
@ -141,13 +144,13 @@ func (c *sqliteCache) MessageCount(topic string) (int, error) {
|
|||
return count, nil
|
||||
}
|
||||
|
||||
func (s *sqliteCache) Topics() (map[string]*topic, error) {
|
||||
rows, err := s.db.Query(selectTopicsQuery)
|
||||
func (c *sqliteCache) Topics() (map[string]*topic, error) {
|
||||
rows, err := c.db.Query(selectTopicsQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make(map[string]*topic, 0)
|
||||
topics := make(map[string]*topic)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var last int64
|
||||
|
@ -162,8 +165,8 @@ func (s *sqliteCache) Topics() (map[string]*topic, error) {
|
|||
return topics, nil
|
||||
}
|
||||
|
||||
func (s *sqliteCache) Prune(keep time.Duration) error {
|
||||
_, err := s.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix())
|
||||
func (c *sqliteCache) Prune(keep time.Duration) error {
|
||||
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
23
server/cache_sqlite_test.go
Normal file
23
server/cache_sqlite_test.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSqliteCache_AddMessage(t *testing.T) {
|
||||
testCacheMessages(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func newSqliteTestCache(t *testing.T) cache {
|
||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||
c, err := newSqliteCache(filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
79
server/cache_test.go
Normal file
79
server/cache_test.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func testCacheMessages(t *testing.T, c cache) {
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
assert.Nil(t, c.AddMessage(m1))
|
||||
assert.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
|
||||
assert.Nil(t, c.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
count, err := c.MessageCount("mytopic")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages)
|
||||
assert.Equal(t, 2, len(messages))
|
||||
assert.Equal(t, "my message", messages[0].Message)
|
||||
assert.Equal(t, "mytopic", messages[0].Topic)
|
||||
assert.Equal(t, messageEvent, messages[0].Event)
|
||||
assert.Equal(t, "", messages[0].Title)
|
||||
assert.Equal(t, 0, messages[0].Priority)
|
||||
assert.Nil(t, messages[0].Tags)
|
||||
assert.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = c.Messages("mytopic", sinceNoMessages)
|
||||
assert.Empty(t, messages)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0)))
|
||||
assert.Equal(t, 1, len(messages))
|
||||
assert.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
count, err = c.MessageCount("example")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
|
||||
// example: since all
|
||||
messages, _ = c.Messages("example", sinceAllMessages)
|
||||
assert.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
count, err = c.MessageCount("doesnotexist")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages)
|
||||
assert.Empty(t, messages)
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
|
||||
m := newDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
assert.Nil(t, c.AddMessage(m))
|
||||
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
assert.Equal(t, 5, messages[0].Priority)
|
||||
assert.Equal(t, "some title", messages[0].Title)
|
||||
}
|
|
@ -3,8 +3,7 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"embed"
|
||||
_ "embed" // required for go:embed
|
||||
"embed" // required for go:embed
|
||||
"encoding/json"
|
||||
firebase "firebase.google.com/go"
|
||||
"firebase.google.com/go/messaging"
|
||||
|
@ -27,7 +26,7 @@ import (
|
|||
// TODO add "max messages in a topic" limit
|
||||
// TODO implement "since=<ID>"
|
||||
|
||||
// Server is the main server
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
topics map[string]*topic
|
||||
|
@ -105,6 +104,8 @@ var (
|
|||
errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
|
||||
)
|
||||
|
||||
// New instantiates a new Server. It creates the cache and adds a Firebase
|
||||
// subscriber (if configured).
|
||||
func New(conf *config.Config) (*Server, error) {
|
||||
var firebaseSubscriber subscriber
|
||||
if conf.FirebaseKeyFile != "" {
|
||||
|
@ -170,6 +171,8 @@ func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
||||
// a manager go routine to print stats and prune messages.
|
||||
func (s *Server) Run() error {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.config.ManagerInterval)
|
||||
|
@ -241,11 +244,11 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
|
||||
func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleExample(w http.ResponseWriter, r *http.Request) error {
|
||||
func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
|
||||
_, err := io.WriteString(w, exampleSource)
|
||||
return err
|
||||
}
|
||||
|
@ -260,7 +263,7 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
t, err := s.topicFromID(r.URL.Path[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -466,7 +469,7 @@ func parseSince(r *http.Request) (sinceTime, error) {
|
|||
return sinceNoMessages, errHTTPBadRequest
|
||||
}
|
||||
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error {
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST")
|
||||
return nil
|
||||
|
@ -570,5 +573,5 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
|||
func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
|
||||
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
|
||||
w.WriteHeader(code)
|
||||
io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
|
||||
_, _ = io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue