Tests, client tests WIP

This commit is contained in:
Philipp Heckel 2021-12-22 14:17:50 +01:00
parent 68d881291c
commit 6a7e9071b6
7 changed files with 104 additions and 30 deletions

42
client/client_test.go Normal file
View file

@ -0,0 +1,42 @@
package client_test
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"heckel.io/ntfy/server"
"net/http"
"testing"
"time"
)
func TestClient_Publish(t *testing.T) {
s := startTestServer(t)
defer s.Stop()
c := client.New(newTestConfig())
time.Sleep(time.Second) // FIXME Wait for port up
_, err := c.Publish("mytopic", "some message")
require.Nil(t, err)
}
func newTestConfig() *client.Config {
c := client.NewConfig()
c.DefaultHost = "http://127.0.0.1:12345"
return c
}
func startTestServer(t *testing.T) *server.Server {
conf := server.NewConfig()
conf.ListenHTTP = ":12345"
s, err := server.New(conf)
if err != nil {
t.Fatal(err)
}
go func() {
if err := s.Run(); err != nil && err != http.ErrServerClosed {
panic(err) // 'go vet' complains about 't.Fatal(err)'
}
}()
return s
}

View file

@ -1,7 +1,8 @@
package client
package client_test
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/client"
"os"
"path/filepath"
"testing"
@ -21,7 +22,7 @@ subscribe:
priority: high,urgent
`), 0600))
conf, err := LoadConfig(filename)
conf, err := client.LoadConfig(filename)
require.Nil(t, err)
require.Equal(t, "http://localhost", conf.DefaultHost)
require.Equal(t, 3, len(conf.Subscribe))

View file

@ -85,7 +85,8 @@ func execServe(c *cli.Context) error {
}
// Run server
conf := server.NewConfig(listenHTTP)
conf := server.NewConfig()
conf.ListenHTTP = listenHTTP
conf.ListenHTTPS = listenHTTPS
conf.KeyFile = keyFile
conf.CertFile = certFile

View file

@ -52,9 +52,9 @@ type Config struct {
}
// NewConfig instantiates a default new server config
func NewConfig(listenHTTP string) *Config {
func NewConfig() *Config {
return &Config{
ListenHTTP: listenHTTP,
ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "",
KeyFile: "",
CertFile: "",

View file

@ -7,6 +7,7 @@ import (
)
func TestConfig_New(t *testing.T) {
c := server.NewConfig(":1234")
assert.Equal(t, ":1234", c.ListenHTTP)
c := server.NewConfig()
assert.Equal(t, ":80", c.ListenHTTP)
assert.Equal(t, server.DefaultKeepaliveInterval, c.KeepaliveInterval)
}

View file

@ -27,13 +27,16 @@ import (
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
topics map[string]*topic
visitors map[string]*visitor
firebase subscriber
messages int64
cache cache
mu sync.Mutex
config *Config
httpServer *http.Server
httpsServer *http.Server
topics map[string]*topic
visitors map[string]*visitor
firebase subscriber
messages int64
cache cache
closeChan chan bool
mu sync.Mutex
}
// errHTTP is a generic HTTP error for any non-200 HTTP error
@ -198,17 +201,35 @@ func (s *Server) Run() error {
log.Printf("Listening on %s", listenStr)
http.HandleFunc("/", s.handle)
errChan := make(chan error)
s.mu.Lock()
s.closeChan = make(chan bool)
s.httpServer = &http.Server{Addr: s.config.ListenHTTP}
go func() {
errChan <- http.ListenAndServe(s.config.ListenHTTP, nil)
errChan <- s.httpServer.ListenAndServe()
}()
if s.config.ListenHTTPS != "" {
s.httpsServer = &http.Server{Addr: s.config.ListenHTTP}
go func() {
errChan <- http.ListenAndServeTLS(s.config.ListenHTTPS, s.config.CertFile, s.config.KeyFile, nil)
errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
}()
}
s.mu.Unlock()
return <-errChan
}
// Stop stops HTTP (+HTTPS) server and all managers
func (s *Server) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.httpServer != nil {
s.httpServer.Close()
}
if s.httpsServer != nil {
s.httpsServer.Close()
}
close(s.closeChan)
}
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
if err := s.handleInternal(w, r); err != nil {
if e, ok := err.(*errHTTP); ok {
@ -635,21 +656,25 @@ func (s *Server) updateStatsAndPrune() {
}
func (s *Server) runManager() {
func() {
ticker := time.NewTicker(s.config.ManagerInterval)
for {
<-ticker.C
for {
select {
case <-time.After(s.config.ManagerInterval):
s.updateStatsAndPrune()
case <-s.closeChan:
return
}
}()
}
}
func (s *Server) runAtSender() {
ticker := time.NewTicker(s.config.AtSenderInterval)
for {
<-ticker.C
if err := s.sendDelayedMessages(); err != nil {
log.Printf("error sending scheduled messages: %s", err.Error())
select {
case <-time.After(s.config.AtSenderInterval):
if err := s.sendDelayedMessages(); err != nil {
log.Printf("error sending scheduled messages: %s", err.Error())
}
case <-s.closeChan:
return
}
}
}
@ -658,14 +683,18 @@ func (s *Server) runFirebaseKeepliver() {
if s.firebase == nil {
return
}
ticker := time.NewTicker(s.config.FirebaseKeepaliveInterval)
for {
<-ticker.C
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message: %s", err.Error())
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message: %s", err.Error())
}
case <-s.closeChan:
return
}
}
}
func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()

View file

@ -488,7 +488,7 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig(":80")
conf := NewConfig()
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
return conf
}