Tests, client tests WIP
This commit is contained in:
parent
68d881291c
commit
6a7e9071b6
7 changed files with 104 additions and 30 deletions
42
client/client_test.go
Normal file
42
client/client_test.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package client_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/client"
|
||||||
|
"heckel.io/ntfy/server"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_Publish(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
defer s.Stop()
|
||||||
|
c := client.New(newTestConfig())
|
||||||
|
|
||||||
|
time.Sleep(time.Second) // FIXME Wait for port up
|
||||||
|
|
||||||
|
_, err := c.Publish("mytopic", "some message")
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestConfig() *client.Config {
|
||||||
|
c := client.NewConfig()
|
||||||
|
c.DefaultHost = "http://127.0.0.1:12345"
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTestServer(t *testing.T) *server.Server {
|
||||||
|
conf := server.NewConfig()
|
||||||
|
conf.ListenHTTP = ":12345"
|
||||||
|
s, err := server.New(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
if err := s.Run(); err != nil && err != http.ErrServerClosed {
|
||||||
|
panic(err) // 'go vet' complains about 't.Fatal(err)'
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return s
|
||||||
|
}
|
|
@ -1,7 +1,8 @@
|
||||||
package client
|
package client_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/client"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -21,7 +22,7 @@ subscribe:
|
||||||
priority: high,urgent
|
priority: high,urgent
|
||||||
`), 0600))
|
`), 0600))
|
||||||
|
|
||||||
conf, err := LoadConfig(filename)
|
conf, err := client.LoadConfig(filename)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "http://localhost", conf.DefaultHost)
|
require.Equal(t, "http://localhost", conf.DefaultHost)
|
||||||
require.Equal(t, 3, len(conf.Subscribe))
|
require.Equal(t, 3, len(conf.Subscribe))
|
||||||
|
|
|
@ -85,7 +85,8 @@ func execServe(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
conf := server.NewConfig(listenHTTP)
|
conf := server.NewConfig()
|
||||||
|
conf.ListenHTTP = listenHTTP
|
||||||
conf.ListenHTTPS = listenHTTPS
|
conf.ListenHTTPS = listenHTTPS
|
||||||
conf.KeyFile = keyFile
|
conf.KeyFile = keyFile
|
||||||
conf.CertFile = certFile
|
conf.CertFile = certFile
|
||||||
|
|
|
@ -52,9 +52,9 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig instantiates a default new server config
|
// NewConfig instantiates a default new server config
|
||||||
func NewConfig(listenHTTP string) *Config {
|
func NewConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
ListenHTTP: listenHTTP,
|
ListenHTTP: DefaultListenHTTP,
|
||||||
ListenHTTPS: "",
|
ListenHTTPS: "",
|
||||||
KeyFile: "",
|
KeyFile: "",
|
||||||
CertFile: "",
|
CertFile: "",
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_New(t *testing.T) {
|
func TestConfig_New(t *testing.T) {
|
||||||
c := server.NewConfig(":1234")
|
c := server.NewConfig()
|
||||||
assert.Equal(t, ":1234", c.ListenHTTP)
|
assert.Equal(t, ":80", c.ListenHTTP)
|
||||||
|
assert.Equal(t, server.DefaultKeepaliveInterval, c.KeepaliveInterval)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,16 @@ import (
|
||||||
|
|
||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
topics map[string]*topic
|
httpServer *http.Server
|
||||||
visitors map[string]*visitor
|
httpsServer *http.Server
|
||||||
firebase subscriber
|
topics map[string]*topic
|
||||||
messages int64
|
visitors map[string]*visitor
|
||||||
cache cache
|
firebase subscriber
|
||||||
mu sync.Mutex
|
messages int64
|
||||||
|
cache cache
|
||||||
|
closeChan chan bool
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
||||||
|
@ -198,17 +201,35 @@ func (s *Server) Run() error {
|
||||||
log.Printf("Listening on %s", listenStr)
|
log.Printf("Listening on %s", listenStr)
|
||||||
http.HandleFunc("/", s.handle)
|
http.HandleFunc("/", s.handle)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.closeChan = make(chan bool)
|
||||||
|
s.httpServer = &http.Server{Addr: s.config.ListenHTTP}
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- http.ListenAndServe(s.config.ListenHTTP, nil)
|
errChan <- s.httpServer.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
if s.config.ListenHTTPS != "" {
|
if s.config.ListenHTTPS != "" {
|
||||||
|
s.httpsServer = &http.Server{Addr: s.config.ListenHTTP}
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- http.ListenAndServeTLS(s.config.ListenHTTPS, s.config.CertFile, s.config.KeyFile, nil)
|
errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
return <-errChan
|
return <-errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops HTTP (+HTTPS) server and all managers
|
||||||
|
func (s *Server) Stop() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.httpServer != nil {
|
||||||
|
s.httpServer.Close()
|
||||||
|
}
|
||||||
|
if s.httpsServer != nil {
|
||||||
|
s.httpsServer.Close()
|
||||||
|
}
|
||||||
|
close(s.closeChan)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := s.handleInternal(w, r); err != nil {
|
if err := s.handleInternal(w, r); err != nil {
|
||||||
if e, ok := err.(*errHTTP); ok {
|
if e, ok := err.(*errHTTP); ok {
|
||||||
|
@ -635,21 +656,25 @@ func (s *Server) updateStatsAndPrune() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runManager() {
|
func (s *Server) runManager() {
|
||||||
func() {
|
for {
|
||||||
ticker := time.NewTicker(s.config.ManagerInterval)
|
select {
|
||||||
for {
|
case <-time.After(s.config.ManagerInterval):
|
||||||
<-ticker.C
|
|
||||||
s.updateStatsAndPrune()
|
s.updateStatsAndPrune()
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) runAtSender() {
|
func (s *Server) runAtSender() {
|
||||||
ticker := time.NewTicker(s.config.AtSenderInterval)
|
|
||||||
for {
|
for {
|
||||||
<-ticker.C
|
select {
|
||||||
if err := s.sendDelayedMessages(); err != nil {
|
case <-time.After(s.config.AtSenderInterval):
|
||||||
log.Printf("error sending scheduled messages: %s", err.Error())
|
if err := s.sendDelayedMessages(); err != nil {
|
||||||
|
log.Printf("error sending scheduled messages: %s", err.Error())
|
||||||
|
}
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -658,14 +683,18 @@ func (s *Server) runFirebaseKeepliver() {
|
||||||
if s.firebase == nil {
|
if s.firebase == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ticker := time.NewTicker(s.config.FirebaseKeepaliveInterval)
|
|
||||||
for {
|
for {
|
||||||
<-ticker.C
|
select {
|
||||||
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||||
log.Printf("error sending Firebase keepalive message: %s", err.Error())
|
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||||
|
log.Printf("error sending Firebase keepalive message: %s", err.Error())
|
||||||
|
}
|
||||||
|
case <-s.closeChan:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendDelayedMessages() error {
|
func (s *Server) sendDelayedMessages() error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
|
@ -488,7 +488,7 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestConfig(t *testing.T) *Config {
|
func newTestConfig(t *testing.T) *Config {
|
||||||
conf := NewConfig(":80")
|
conf := NewConfig()
|
||||||
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
|
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
|
||||||
return conf
|
return conf
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue