Polishing

This commit is contained in:
binwiederhier 2023-02-23 20:46:53 -05:00
parent 8eae44ea61
commit 2329695a47
5 changed files with 88 additions and 44 deletions

View file

@ -4,11 +4,6 @@ import (
"heckel.io/ntfy/log"
"math/rand"
"sync"
"time"
)
const (
topicExpiryDuration = 6 * time.Hour
)
// topic represents a channel to which subscribers can subscribe, and publishers
@ -17,13 +12,12 @@ type topic struct {
ID string
subscribers map[int]*topicSubscriber
rateVisitor *visitor
expires time.Time
mu sync.RWMutex
}
type topicSubscriber struct {
userID string // User ID associated with this subscription, may be empty
subscriber subscriber
visitor *visitor // User ID associated with this subscription, may be empty
cancel func()
}
@ -39,12 +33,12 @@ func newTopic(id string) *topic {
}
// Subscribe subscribes to this topic
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
t.mu.Lock()
defer t.mu.Unlock()
subscriberID := rand.Int()
t.subscribers[subscriberID] = &topicSubscriber{
visitor: visitor, // May be empty
userID: userID, // May be empty
subscriber: s,
cancel: cancel,
}
@ -54,7 +48,10 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
func (t *topic) Stale() bool {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.subscribers) == 0 && t.expires.Before(time.Now())
if t.rateVisitor != nil && !t.rateVisitor.Stale() {
return false
}
return len(t.subscribers) == 0
}
func (t *topic) SetRateVisitor(v *visitor) {
@ -66,6 +63,9 @@ func (t *topic) SetRateVisitor(v *visitor) {
func (t *topic) RateVisitor() *visitor {
t.mu.Lock()
defer t.mu.Unlock()
if t.rateVisitor != nil && t.rateVisitor.Stale() {
t.rateVisitor = nil
}
return t.rateVisitor
}
@ -74,9 +74,6 @@ func (t *topic) Unsubscribe(id int) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.subscribers, id)
if len(t.subscribers) == 0 {
t.expires = time.Now().Add(topicExpiryDuration)
}
}
// Publish asynchronously publishes to all subscribers
@ -115,9 +112,14 @@ func (t *topic) CancelSubscribers(exceptUserID string) {
t.mu.Lock()
defer t.mu.Unlock()
for _, s := range t.subscribers {
if s.visitor.MaybeUserID() != exceptUserID {
// TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change.
log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID())
if s.userID != exceptUserID {
log.
Tag(tagSubscribe).
Fields(log.Context{
"message_topic": t.ID,
"user_id": s.userID,
}).
Debug("Canceling subscriber %s", s.userID)
s.cancel()
}
}
@ -130,7 +132,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber {
subscribers := make(map[int]*topicSubscriber)
for k, sub := range t.subscribers {
subscribers[k] = &topicSubscriber{
visitor: sub.visitor,
userID: sub.userID,
subscriber: sub.subscriber,
cancel: sub.cancel,
}