diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..b010fd9 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,42 @@ +package client_test + +import ( + "github.com/stretchr/testify/require" + "heckel.io/ntfy/client" + "heckel.io/ntfy/server" + "net/http" + "testing" + "time" +) + +func TestClient_Publish(t *testing.T) { + s := startTestServer(t) + defer s.Stop() + c := client.New(newTestConfig()) + + time.Sleep(time.Second) // FIXME Wait for port up + + _, err := c.Publish("mytopic", "some message") + require.Nil(t, err) +} + +func newTestConfig() *client.Config { + c := client.NewConfig() + c.DefaultHost = "http://127.0.0.1:12345" + return c +} + +func startTestServer(t *testing.T) *server.Server { + conf := server.NewConfig() + conf.ListenHTTP = ":12345" + s, err := server.New(conf) + if err != nil { + t.Fatal(err) + } + go func() { + if err := s.Run(); err != nil && err != http.ErrServerClosed { + panic(err) // 'go vet' complains about 't.Fatal(err)' + } + }() + return s +} diff --git a/client/config_test.go b/client/config_test.go index 67b69be..8d32211 100644 --- a/client/config_test.go +++ b/client/config_test.go @@ -1,7 +1,8 @@ -package client +package client_test import ( "github.com/stretchr/testify/require" + "heckel.io/ntfy/client" "os" "path/filepath" "testing" @@ -21,7 +22,7 @@ subscribe: priority: high,urgent `), 0600)) - conf, err := LoadConfig(filename) + conf, err := client.LoadConfig(filename) require.Nil(t, err) require.Equal(t, "http://localhost", conf.DefaultHost) require.Equal(t, 3, len(conf.Subscribe)) diff --git a/cmd/serve.go b/cmd/serve.go index 874cf0e..312319f 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -85,7 +85,8 @@ func execServe(c *cli.Context) error { } // Run server - conf := server.NewConfig(listenHTTP) + conf := server.NewConfig() + conf.ListenHTTP = listenHTTP conf.ListenHTTPS = listenHTTPS conf.KeyFile = keyFile conf.CertFile = certFile diff --git a/server/config.go b/server/config.go index f44344f..e77d71f 100644 --- a/server/config.go +++ b/server/config.go @@ -52,9 +52,9 @@ type Config struct { } // NewConfig instantiates a default new server config -func NewConfig(listenHTTP string) *Config { +func NewConfig() *Config { return &Config{ - ListenHTTP: listenHTTP, + ListenHTTP: DefaultListenHTTP, ListenHTTPS: "", KeyFile: "", CertFile: "", diff --git a/server/config_test.go b/server/config_test.go index 902549d..14f028f 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -7,6 +7,7 @@ import ( ) func TestConfig_New(t *testing.T) { - c := server.NewConfig(":1234") - assert.Equal(t, ":1234", c.ListenHTTP) + c := server.NewConfig() + assert.Equal(t, ":80", c.ListenHTTP) + assert.Equal(t, server.DefaultKeepaliveInterval, c.KeepaliveInterval) } diff --git a/server/server.go b/server/server.go index a2bff6c..4d31215 100644 --- a/server/server.go +++ b/server/server.go @@ -27,13 +27,16 @@ import ( // Server is the main server, providing the UI and API for ntfy type Server struct { - config *Config - topics map[string]*topic - visitors map[string]*visitor - firebase subscriber - messages int64 - cache cache - mu sync.Mutex + config *Config + httpServer *http.Server + httpsServer *http.Server + topics map[string]*topic + visitors map[string]*visitor + firebase subscriber + messages int64 + cache cache + closeChan chan bool + mu sync.Mutex } // errHTTP is a generic HTTP error for any non-200 HTTP error @@ -198,17 +201,35 @@ func (s *Server) Run() error { log.Printf("Listening on %s", listenStr) http.HandleFunc("/", s.handle) errChan := make(chan error) + s.mu.Lock() + s.closeChan = make(chan bool) + s.httpServer = &http.Server{Addr: s.config.ListenHTTP} go func() { - errChan <- http.ListenAndServe(s.config.ListenHTTP, nil) + errChan <- s.httpServer.ListenAndServe() }() if s.config.ListenHTTPS != "" { + s.httpsServer = &http.Server{Addr: s.config.ListenHTTP} go func() { - errChan <- http.ListenAndServeTLS(s.config.ListenHTTPS, s.config.CertFile, s.config.KeyFile, nil) + errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile) }() } + s.mu.Unlock() return <-errChan } +// Stop stops HTTP (+HTTPS) server and all managers +func (s *Server) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer != nil { + s.httpServer.Close() + } + if s.httpsServer != nil { + s.httpsServer.Close() + } + close(s.closeChan) +} + func (s *Server) handle(w http.ResponseWriter, r *http.Request) { if err := s.handleInternal(w, r); err != nil { if e, ok := err.(*errHTTP); ok { @@ -635,21 +656,25 @@ func (s *Server) updateStatsAndPrune() { } func (s *Server) runManager() { - func() { - ticker := time.NewTicker(s.config.ManagerInterval) - for { - <-ticker.C + for { + select { + case <-time.After(s.config.ManagerInterval): s.updateStatsAndPrune() + case <-s.closeChan: + return } - }() + } } func (s *Server) runAtSender() { - ticker := time.NewTicker(s.config.AtSenderInterval) for { - <-ticker.C - if err := s.sendDelayedMessages(); err != nil { - log.Printf("error sending scheduled messages: %s", err.Error()) + select { + case <-time.After(s.config.AtSenderInterval): + if err := s.sendDelayedMessages(); err != nil { + log.Printf("error sending scheduled messages: %s", err.Error()) + } + case <-s.closeChan: + return } } } @@ -658,14 +683,18 @@ func (s *Server) runFirebaseKeepliver() { if s.firebase == nil { return } - ticker := time.NewTicker(s.config.FirebaseKeepaliveInterval) for { - <-ticker.C - if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { - log.Printf("error sending Firebase keepalive message: %s", err.Error()) + select { + case <-time.After(s.config.FirebaseKeepaliveInterval): + if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { + log.Printf("error sending Firebase keepalive message: %s", err.Error()) + } + case <-s.closeChan: + return } } } + func (s *Server) sendDelayedMessages() error { s.mu.Lock() defer s.mu.Unlock() diff --git a/server/server_test.go b/server/server_test.go index 715b416..6619588 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -488,7 +488,7 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) { } func newTestConfig(t *testing.T) *Config { - conf := NewConfig(":80") + conf := NewConfig() conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") return conf }