Query filters

This commit is contained in:
Philipp Heckel 2021-12-21 21:22:27 +01:00
parent 85b4abde6c
commit 9315829bc4
7 changed files with 167 additions and 37 deletions

View file

@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"fmt"
"heckel.io/ntfy/util"
"io"
"log"
"net/http"
@ -39,16 +40,21 @@ type Message struct {
Event string
Time int64
Topic string
TopicURL string
Message string
Title string
Priority int
Tags []string
Raw string
// Additional fields
TopicURL string
SubscriptionID string
Raw string
}
type subscription struct {
cancel context.CancelFunc
ID string
topicURL string
cancel context.CancelFunc
}
// New creates a new Client using a given Config
@ -88,7 +94,7 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess
if err != nil {
return nil, err
}
m, err := toMessage(string(b), topicURL)
m, err := toMessage(string(b), topicURL, "")
if err != nil {
return nil, err
}
@ -111,7 +117,7 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
errChan := make(chan error)
topicURL := c.expandTopicURL(topic)
go func() {
err := performSubscribeRequest(ctx, msgChan, topicURL, options...)
err := performSubscribeRequest(ctx, msgChan, topicURL, "", options...)
close(msgChan)
errChan <- err
}()
@ -131,39 +137,58 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
// By default, only new messages will be returned, but you can change this behavior using a SubscribeOption.
// See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
//
// The method returns a unique subscriptionID that can be used in Unsubscribe.
//
// Example:
// c := client.New(client.NewConfig())
// c.Subscribe("mytopic")
// subscriptionID := c.Subscribe("mytopic")
// for m := range c.Messages {
// fmt.Printf("New message: %s", m.Message)
// }
func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
c.mu.Lock()
defer c.mu.Unlock()
subscriptionID := util.RandomString(10)
topicURL := c.expandTopicURL(topic)
if _, ok := c.subscriptions[topicURL]; ok {
return topicURL
}
ctx, cancel := context.WithCancel(context.Background())
c.subscriptions[topicURL] = &subscription{cancel}
go handleSubscribeConnLoop(ctx, c.Messages, topicURL, options...)
return topicURL
c.subscriptions[subscriptionID] = &subscription{
ID: subscriptionID,
topicURL: topicURL,
cancel: cancel,
}
go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
return subscriptionID
}
// Unsubscribe unsubscribes from a topic that has been previously subscribed with Subscribe.
// Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique
// subscriptionID returned in Subscribe.
func (c *Client) Unsubscribe(subscriptionID string) {
c.mu.Lock()
defer c.mu.Unlock()
sub, ok := c.subscriptions[subscriptionID]
if !ok {
return
}
delete(c.subscriptions, subscriptionID)
sub.cancel()
}
// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe.
// If there are multiple subscriptions matching the topic, all of them are unsubscribed from.
//
// A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
// (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
// config (e.g. mytopic -> https://ntfy.sh/mytopic).
func (c *Client) Unsubscribe(topic string) {
func (c *Client) UnsubscribeAll(topic string) {
c.mu.Lock()
defer c.mu.Unlock()
topicURL := c.expandTopicURL(topic)
sub, ok := c.subscriptions[topicURL]
if !ok {
return
for _, sub := range c.subscriptions {
if sub.topicURL == topicURL {
delete(c.subscriptions, sub.ID)
sub.cancel()
}
}
sub.cancel()
}
func (c *Client) expandTopicURL(topic string) string {
@ -175,9 +200,11 @@ func (c *Client) expandTopicURL(topic string) string {
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic)
}
func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL string, options ...SubscribeOption) {
func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {
for {
if err := performSubscribeRequest(ctx, msgChan, topicURL, options...); err != nil {
// TODO The retry logic is crude and may lose messages. It should record the last message like the
// Android client, use since=, and do incremental backoff too
if err := performSubscribeRequest(ctx, msgChan, topicURL, subcriptionID, options...); err != nil {
log.Printf("Connection to %s failed: %s", topicURL, err.Error())
}
select {
@ -189,7 +216,7 @@ func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicUR
}
}
func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, options ...SubscribeOption) error {
func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/json", topicURL), nil)
if err != nil {
return err
@ -206,7 +233,7 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
m, err := toMessage(scanner.Text(), topicURL)
m, err := toMessage(scanner.Text(), topicURL, subscriptionID)
if err != nil {
return err
}
@ -215,12 +242,13 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR
return nil
}
func toMessage(s, topicURL string) (*Message, error) {
func toMessage(s, topicURL, subscriptionID string) (*Message, error) {
var m *Message
if err := json.NewDecoder(strings.NewReader(s)).Decode(&m); err != nil {
return nil, err
}
m.TopicURL = topicURL
m.SubscriptionID = subscriptionID
m.Raw = s
return m, nil
}

View file

@ -9,9 +9,9 @@ const (
type Config struct {
DefaultHost string `yaml:"default-host"`
Subscribe []struct {
Topic string `yaml:"topic"`
Command string `yaml:"command"`
// If []map[string]string TODO This would be cool
Topic string `yaml:"topic"`
Command string `yaml:"command"`
If map[string]string `yaml:"if"`
} `yaml:"subscribe"`
}

View file

@ -88,6 +88,32 @@ func WithScheduled() SubscribeOption {
return WithQueryParam("scheduled", "1")
}
// WithFilter is a generic subscribe option meant to be used to filter for certain messages only
func WithFilter(param, value string) SubscribeOption {
return WithQueryParam(param, value)
}
// WithMessageFilter instructs the server to only return messages that match the exact message
func WithMessageFilter(message string) SubscribeOption {
return WithQueryParam("message", message)
}
// WithTitleFilter instructs the server to only return messages with a title that match the exact string
func WithTitleFilter(title string) SubscribeOption {
return WithQueryParam("title", title)
}
// WithPriorityFilter instructs the server to only return messages with the matching priority. Not that messages
// without priority also implicitly match priority 3.
func WithPriorityFilter(priority int) SubscribeOption {
return WithQueryParam("priority", fmt.Sprintf("%d", priority))
}
// WithTagsFilter instructs the server to only return messages that contain all of the given tags
func WithTagsFilter(tags []string) SubscribeOption {
return WithQueryParam("tags", strings.Join(tags, ","))
}
// WithHeader is a generic option to add headers to a request
func WithHeader(header, value string) RequestOption {
return func(r *http.Request) error {

View file

@ -135,17 +135,21 @@ func doPollSingle(c *cli.Context, cl *client.Client, topic, command string, opti
}
func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, command string, options ...client.SubscribeOption) error {
commands := make(map[string]string)
for _, s := range conf.Subscribe { // May be nil
topicURL := cl.Subscribe(s.Topic, options...)
commands[topicURL] = s.Command
commands := make(map[string]string) // Subscription ID -> command
for _, s := range conf.Subscribe { // May be nil
topicOptions := append(make([]client.SubscribeOption, 0), options...)
for filter, value := range s.If {
topicOptions = append(topicOptions, client.WithFilter(filter, value))
}
subscriptionID := cl.Subscribe(s.Topic, topicOptions...)
commands[subscriptionID] = s.Command
}
if topic != "" {
topicURL := cl.Subscribe(topic, options...)
commands[topicURL] = command
subscriptionID := cl.Subscribe(topic, options...)
commands[subscriptionID] = command
}
for m := range cl.Messages {
command, ok := commands[m.TopicURL]
command, ok := commands[m.SubscriptionID]
if !ok {
continue
}

View file

@ -334,7 +334,7 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase
tagsStr := readParam(r, "x-tags", "tag", "tags", "ta")
if tagsStr != "" {
m.Tags = make([]string, 0)
for _, s := range strings.Split(tagsStr, ",") {
for _, s := range util.SplitNoEmpty(tagsStr, ",") {
m.Tags = append(m.Tags, strings.TrimSpace(s))
}
}
@ -413,7 +413,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
}
defer v.RemoveSubscription()
topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
topicIDs := strings.Split(topicsStr, ",")
topicIDs := util.SplitNoEmpty(topicsStr, ",")
topics, err := s.topicsFromIDs(topicIDs...)
if err != nil {
return err
@ -425,13 +425,20 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
var wlock sync.Mutex
poll := r.URL.Query().Has("poll")
scheduled := r.URL.Query().Has("scheduled") || r.URL.Query().Has("sched")
messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
if err != nil {
return err
}
sub := func(msg *message) error {
wlock.Lock()
defer wlock.Unlock()
if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
return nil
}
m, err := encoder(msg)
if err != nil {
return err
}
wlock.Lock()
defer wlock.Unlock()
if _, err := w.Write([]byte(m)); err != nil {
return err
}
@ -473,6 +480,34 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
}
}
func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter int, tagsFilter []string, err error) {
messageFilter = r.URL.Query().Get("message")
titleFilter = r.URL.Query().Get("title")
tagsFilter = util.SplitNoEmpty(r.URL.Query().Get("tags"), ",")
priorityFilter, err = util.ParsePriority(r.URL.Query().Get("priority"))
return
}
func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter int, tagsFilter []string) bool {
if messageFilter != "" && msg.Message != messageFilter {
log.Printf("1")
return false
}
if titleFilter != "" && msg.Title != titleFilter {
log.Printf("2")
return false
}
if priorityFilter > 0 && (msg.Priority != priorityFilter || (msg.Priority == 0 && priorityFilter != 3)) {
log.Printf("3")
return false
}
if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) {
log.Printf("4")
return false
}
return true
}
func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error {
if since.IsNone() {
return nil

View file

@ -37,6 +37,30 @@ func InStringList(haystack []string, needle string) bool {
return false
}
// InStringListAll returns true if all needles are contained in haystack
func InStringListAll(haystack []string, needles []string) bool {
matches := 0
for _, s := range haystack {
for _, needle := range needles {
if s == needle {
matches++
}
}
}
return matches == len(needles)
}
// SplitNoEmpty splits a string using strings.Split, but filters out empty strings
func SplitNoEmpty(s string, sep string) []string {
res := make([]string, 0)
for _, r := range strings.Split(s, sep) {
if r != "" {
res = append(res, r)
}
}
return res
}
// RandomString returns a random string with a given length
func RandomString(length int) string {
randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!

View file

@ -56,6 +56,19 @@ func TestInStringList(t *testing.T) {
require.False(t, InStringList(s, "three"))
}
func TestInStringListAll(t *testing.T) {
s := []string{"one", "two", "three", "four"}
require.True(t, InStringListAll(s, []string{"two", "four"}))
require.False(t, InStringListAll(s, []string{"three", "five"}))
}
func TestSplitNoEmpty(t *testing.T) {
require.Equal(t, []string{}, SplitNoEmpty("", ","))
require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))
require.Equal(t, []string{"tag1", "tag2"}, SplitNoEmpty("tag1,tag2", ","))
require.Equal(t, []string{"tag1", "tag2"}, SplitNoEmpty("tag1,tag2,", ","))
}
func TestExpandHome_WithTilde(t *testing.T) {
require.Equal(t, os.Getenv("HOME")+"/this/is/a/path", ExpandHome("~/this/is/a/path"))
}