Fix previous fix

This commit is contained in:
binwiederhier 2023-06-01 16:01:39 -04:00
parent dc8932cd95
commit f58c1e4c84
5 changed files with 38 additions and 43 deletions

View file

@ -11,23 +11,25 @@ import (
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"net/http" "net/http"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
// Event type constants
const ( const (
MessageEvent = "message" // MessageEvent identifies a message event
KeepaliveEvent = "keepalive" MessageEvent = "message"
OpenEvent = "open"
PollRequestEvent = "poll_request"
) )
const ( const (
maxResponseBytes = 4096 maxResponseBytes = 4096
) )
var (
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
)
// Client is the ntfy client that can be used to publish and subscribe to ntfy topics // Client is the ntfy client that can be used to publish and subscribe to ntfy topics
type Client struct { type Client struct {
Messages chan *Message Messages chan *Message
@ -96,7 +98,10 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess
// To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache, // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache,
// WithNoFirebase, and the generic WithHeader. // WithNoFirebase, and the generic WithHeader.
func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) { func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) {
topicURL := c.expandTopicURL(topic) topicURL, err := c.expandTopicURL(topic)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", topicURL, body) req, err := http.NewRequest("POST", topicURL, body)
if err != nil { if err != nil {
return nil, err return nil, err
@ -136,11 +141,14 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO
// By default, all messages will be returned, but you can change this behavior using a SubscribeOption. // By default, all messages will be returned, but you can change this behavior using a SubscribeOption.
// See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam. // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) { func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) {
topicURL, err := c.expandTopicURL(topic)
if err != nil {
return nil, err
}
ctx := context.Background() ctx := context.Background()
messages := make([]*Message, 0) messages := make([]*Message, 0)
msgChan := make(chan *Message) msgChan := make(chan *Message)
errChan := make(chan error) errChan := make(chan error)
topicURL := c.expandTopicURL(topic)
log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
options = append(options, WithPoll()) options = append(options, WithPoll())
go func() { go func() {
@ -169,15 +177,18 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
// Example: // Example:
// //
// c := client.New(client.NewConfig()) // c := client.New(client.NewConfig())
// subscriptionID := c.Subscribe("mytopic") // subscriptionID, _ := c.Subscribe("mytopic")
// for m := range c.Messages { // for m := range c.Messages {
// fmt.Printf("New message: %s", m.Message) // fmt.Printf("New message: %s", m.Message)
// } // }
func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, error) {
topicURL, err := c.expandTopicURL(topic)
if err != nil {
return "", err
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
subscriptionID := util.RandomString(10) subscriptionID := util.RandomString(10)
topicURL := c.expandTopicURL(topic)
log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL)) log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c.subscriptions[subscriptionID] = &subscription{ c.subscriptions[subscriptionID] = &subscription{
@ -186,7 +197,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
cancel: cancel, cancel: cancel,
} }
go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
return subscriptionID return subscriptionID, nil
} }
// Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique
@ -202,31 +213,16 @@ func (c *Client) Unsubscribe(subscriptionID string) {
sub.cancel() sub.cancel()
} }
// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe. func (c *Client) expandTopicURL(topic string) (string, error) {
// If there are multiple subscriptions matching the topic, all of them are unsubscribed from.
//
// A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
// (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
// config (e.g. mytopic -> https://ntfy.sh/mytopic).
func (c *Client) UnsubscribeAll(topic string) {
c.mu.Lock()
defer c.mu.Unlock()
topicURL := c.expandTopicURL(topic)
for _, sub := range c.subscriptions {
if sub.topicURL == topicURL {
delete(c.subscriptions, sub.ID)
sub.cancel()
}
}
}
func (c *Client) expandTopicURL(topic string) string {
if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") { if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
return topic return topic, nil
} else if strings.Contains(topic, "/") { } else if strings.Contains(topic, "/") {
return fmt.Sprintf("https://%s", topic) return fmt.Sprintf("https://%s", topic), nil
} }
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic) if !topicRegex.MatchString(topic) {
return "", fmt.Errorf("invalid topic name: %s", topic)
}
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil
} }
func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {

View file

@ -21,7 +21,7 @@ func TestClient_Publish_Subscribe(t *testing.T) {
defer test.StopServer(t, s, port) defer test.StopServer(t, s, port)
c := client.New(newTestConfig(port)) c := client.New(newTestConfig(port))
subscriptionID := c.Subscribe("mytopic") subscriptionID, _ := c.Subscribe("mytopic")
time.Sleep(time.Second) time.Sleep(time.Second)
msg, err := c.Publish("mytopic", "some message") msg, err := c.Publish("mytopic", "some message")

View file

@ -29,7 +29,6 @@ var flagsDefault = []cli.Flag{
var ( var (
logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`) logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`)
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
) )
// New creates a new CLI application // New creates a new CLI application

View file

@ -249,10 +249,6 @@ func parseTopicMessageCommand(c *cli.Context) (topic string, message string, com
if c.String("message") != "" { if c.String("message") != "" {
message = c.String("message") message = c.String("message")
} }
if !topicRegex.MatchString(topic) {
err = fmt.Errorf("topic %s contains invalid characters", topic)
return
}
return return
} }

View file

@ -108,8 +108,6 @@ func execSubscribe(c *cli.Context) error {
// Checks // Checks
if user != "" && token != "" { if user != "" && token != "" {
return errors.New("cannot set both --user and --token") return errors.New("cannot set both --user and --token")
} else if !topicRegex.MatchString(topic) {
return fmt.Errorf("topic %s contains invalid characters", topic)
} }
if !fromConfig { if !fromConfig {
@ -196,7 +194,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
topicOptions = append(topicOptions, auth) topicOptions = append(topicOptions, auth)
} }
subscriptionID := cl.Subscribe(s.Topic, topicOptions...) subscriptionID, err := cl.Subscribe(s.Topic, topicOptions...)
if err != nil {
return err
}
if s.Command != "" { if s.Command != "" {
cmds[subscriptionID] = s.Command cmds[subscriptionID] = s.Command
} else if conf.DefaultCommand != "" { } else if conf.DefaultCommand != "" {
@ -206,7 +207,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
} }
} }
if topic != "" { if topic != "" {
subscriptionID := cl.Subscribe(topic, options...) subscriptionID, err := cl.Subscribe(topic, options...)
if err != nil {
return err
}
cmds[subscriptionID] = command cmds[subscriptionID] = command
} }
for m := range cl.Messages { for m := range cl.Messages {