Fix previous fix

This commit is contained in:
binwiederhier 2023-06-01 16:01:39 -04:00
parent dc8932cd95
commit f58c1e4c84
5 changed files with 38 additions and 43 deletions

View file

@ -11,23 +11,25 @@ import (
"heckel.io/ntfy/util"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
)
// Event type constants
const (
// MessageEvent identifies a message event
MessageEvent = "message"
KeepaliveEvent = "keepalive"
OpenEvent = "open"
PollRequestEvent = "poll_request"
)
const (
maxResponseBytes = 4096
)
var (
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
)
// Client is the ntfy client that can be used to publish and subscribe to ntfy topics
type Client struct {
Messages chan *Message
@ -96,7 +98,10 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess
// To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache,
// WithNoFirebase, and the generic WithHeader.
func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) {
topicURL := c.expandTopicURL(topic)
topicURL, err := c.expandTopicURL(topic)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", topicURL, body)
if err != nil {
return nil, err
@ -136,11 +141,14 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO
// By default, all messages will be returned, but you can change this behavior using a SubscribeOption.
// See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) {
topicURL, err := c.expandTopicURL(topic)
if err != nil {
return nil, err
}
ctx := context.Background()
messages := make([]*Message, 0)
msgChan := make(chan *Message)
errChan := make(chan error)
topicURL := c.expandTopicURL(topic)
log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
options = append(options, WithPoll())
go func() {
@ -169,15 +177,18 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
// Example:
//
// c := client.New(client.NewConfig())
// subscriptionID := c.Subscribe("mytopic")
// subscriptionID, _ := c.Subscribe("mytopic")
// for m := range c.Messages {
// fmt.Printf("New message: %s", m.Message)
// }
func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, error) {
topicURL, err := c.expandTopicURL(topic)
if err != nil {
return "", err
}
c.mu.Lock()
defer c.mu.Unlock()
subscriptionID := util.RandomString(10)
topicURL := c.expandTopicURL(topic)
log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL))
ctx, cancel := context.WithCancel(context.Background())
c.subscriptions[subscriptionID] = &subscription{
@ -186,7 +197,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
cancel: cancel,
}
go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
return subscriptionID
return subscriptionID, nil
}
// Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique
@ -202,31 +213,16 @@ func (c *Client) Unsubscribe(subscriptionID string) {
sub.cancel()
}
// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe.
// If there are multiple subscriptions matching the topic, all of them are unsubscribed from.
//
// A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
// (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
// config (e.g. mytopic -> https://ntfy.sh/mytopic).
func (c *Client) UnsubscribeAll(topic string) {
c.mu.Lock()
defer c.mu.Unlock()
topicURL := c.expandTopicURL(topic)
for _, sub := range c.subscriptions {
if sub.topicURL == topicURL {
delete(c.subscriptions, sub.ID)
sub.cancel()
}
}
}
func (c *Client) expandTopicURL(topic string) string {
func (c *Client) expandTopicURL(topic string) (string, error) {
if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
return topic
return topic, nil
} else if strings.Contains(topic, "/") {
return fmt.Sprintf("https://%s", topic)
return fmt.Sprintf("https://%s", topic), nil
}
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic)
if !topicRegex.MatchString(topic) {
return "", fmt.Errorf("invalid topic name: %s", topic)
}
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil
}
func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {

View file

@ -21,7 +21,7 @@ func TestClient_Publish_Subscribe(t *testing.T) {
defer test.StopServer(t, s, port)
c := client.New(newTestConfig(port))
subscriptionID := c.Subscribe("mytopic")
subscriptionID, _ := c.Subscribe("mytopic")
time.Sleep(time.Second)
msg, err := c.Publish("mytopic", "some message")

View file

@ -29,7 +29,6 @@ var flagsDefault = []cli.Flag{
var (
logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`)
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
)
// New creates a new CLI application

View file

@ -249,10 +249,6 @@ func parseTopicMessageCommand(c *cli.Context) (topic string, message string, com
if c.String("message") != "" {
message = c.String("message")
}
if !topicRegex.MatchString(topic) {
err = fmt.Errorf("topic %s contains invalid characters", topic)
return
}
return
}

View file

@ -108,8 +108,6 @@ func execSubscribe(c *cli.Context) error {
// Checks
if user != "" && token != "" {
return errors.New("cannot set both --user and --token")
} else if !topicRegex.MatchString(topic) {
return fmt.Errorf("topic %s contains invalid characters", topic)
}
if !fromConfig {
@ -196,7 +194,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
topicOptions = append(topicOptions, auth)
}
subscriptionID := cl.Subscribe(s.Topic, topicOptions...)
subscriptionID, err := cl.Subscribe(s.Topic, topicOptions...)
if err != nil {
return err
}
if s.Command != "" {
cmds[subscriptionID] = s.Command
} else if conf.DefaultCommand != "" {
@ -206,7 +207,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
}
}
if topic != "" {
subscriptionID := cl.Subscribe(topic, options...)
subscriptionID, err := cl.Subscribe(topic, options...)
if err != nil {
return err
}
cmds[subscriptionID] = command
}
for m := range cl.Messages {