diff --git a/server/message_cache.go b/server/message_cache.go index afd4bf1..d6024c8 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -82,7 +82,7 @@ const ( ` updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` - selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` + selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` @@ -332,22 +332,24 @@ func (c *messageCache) MarkPublished(m *message) error { return err } -func (c *messageCache) MessageCount(topic string) (int, error) { - rows, err := c.db.Query(selectMessageCountForTopicQuery, topic) +func (c *messageCache) MessageCounts() (map[string]int, error) { + rows, err := c.db.Query(selectMessageCountPerTopicQuery) if err != nil { - return 0, err + return nil, err } defer rows.Close() + var topic string var count int - if !rows.Next() { - return 0, errors.New("no rows found") + counts := make(map[string]int) + for rows.Next() { + if err := rows.Scan(&topic, &count); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + counts[topic] = count } - if err := rows.Scan(&count); err != nil { - return 0, err - } else if err := rows.Err(); err != nil { - return 0, err - } - return count, nil + return counts, nil } func (c *messageCache) Topics() (map[string]*topic, error) { diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 398f21e..07929e0 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -34,9 +34,9 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! // mytopic: count - count, err := c.MessageCount("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 2, count) + require.Equal(t, 2, counts["mytopic"]) // mytopic: since all messages, _ := c.Messages("mytopic", sinceAllMessages, false) @@ -66,18 +66,18 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, "my other message", messages[0].Message) // example: count - count, err = c.MessageCount("example") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["example"]) // example: since all messages, _ = c.Messages("example", sinceAllMessages, false) require.Equal(t, "my example message", messages[0].Message) // non-existing: count - count, err = c.MessageCount("doesnotexist") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["doesnotexist"]) // non-existing: since all messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) @@ -255,13 +255,13 @@ func testCachePrune(t *testing.T, c *messageCache) { require.Nil(t, c.AddMessage(m3)) require.Nil(t, c.Prune(time.Unix(2, 0))) - count, err := c.MessageCount("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["mytopic"]) - count, err = c.MessageCount("another_topic") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["another_topic"]) messages, err := c.Messages("mytopic", sinceAllMessages, false) require.Nil(t, err) diff --git a/server/server.go b/server/server.go index 4d028d9..e79118b 100644 --- a/server/server.go +++ b/server/server.go @@ -798,6 +798,13 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * return err } var wlock sync.Mutex + defer func() { + // Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time. + // It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER + // this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the + // data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889. + wlock.TryLock() + }() sub := func(v *visitor, msg *message) error { if !filters.Pass(msg) { return nil @@ -1080,18 +1087,23 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { } func (s *Server) updateStatsAndPrune() { - s.mu.Lock() - defer s.mu.Unlock() + log.Debug("Manager: Starting") + defer log.Debug("Manager: Finished") + + // WARNING: Make sure to only selectively lock with the mutex, and be aware that this + // there is no mutex for the entire function. // Expire visitors from rate visitors map + s.mu.Lock() staleVisitors := 0 for ip, v := range s.visitors { if v.Stale() { - log.Debug("Deleting stale visitor %s", v.ip) + log.Trace("Deleting stale visitor %s", v.ip) delete(s.visitors, ip) staleVisitors++ } } + s.mu.Unlock() log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) // Delete expired attachments @@ -1116,22 +1128,31 @@ func (s *Server) updateStatsAndPrune() { log.Warn("Manager: Error pruning cache: %s", err.Error()) } - // Prune old topics, remove subscriptions without subscribers - var subscribers, messages int + // Message count per topic + var messages int + messageCounts, err := s.messageCache.MessageCounts() + if err != nil { + log.Warn("Manager: Cannot get message counts: %s", err.Error()) + messageCounts = make(map[string]int) // Empty, so we can continue + } + for _, count := range messageCounts { + messages += count + } + + // Remove subscriptions without subscribers + s.mu.Lock() + var subscribers int for _, t := range s.topics { - subs := t.Subscribers() - msgs, err := s.messageCache.MessageCount(t.ID) - if err != nil { - log.Warn("Manager: Cannot get stats for topic %s: %s", t.ID, err.Error()) - continue - } - if msgs == 0 && subs == 0 { + subs := t.SubscribersCount() + msgs, exists := messageCounts[t.ID] + if subs == 0 && (!exists || msgs == 0) { + log.Trace("Deleting empty topic %s", t.ID) delete(s.topics, t.ID) continue } subscribers += subs - messages += msgs } + s.mu.Unlock() // Mail stats var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64 @@ -1219,10 +1240,10 @@ func (s *Server) sendDelayedMessages() error { } func (s *Server) sendDelayedMessage(v *visitor, m *message) error { - s.mu.Lock() - defer s.mu.Unlock() log.Debug("%s Sending delayed message", logMessagePrefix(v, m)) + s.mu.Lock() t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published + s.mu.Unlock() if ok { go func() { // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler diff --git a/server/topic.go b/server/topic.go index 889f1eb..3bc7473 100644 --- a/server/topic.go +++ b/server/topic.go @@ -44,14 +44,19 @@ func (t *topic) Unsubscribe(id int) { // Publish asynchronously publishes to all subscribers func (t *topic) Publish(v *visitor, m *message) error { go func() { - t.mu.Lock() - defer t.mu.Unlock() - if len(t.subscribers) > 0 { - log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(t.subscribers)) - for _, s := range t.subscribers { - if err := s(v, m); err != nil { - log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) - } + // We want to lock the topic as short as possible, so we make a shallow copy of the + // subscribers map here. Actually sending out the messages then doesn't have to lock. + subscribers := t.subscribersCopy() + if len(subscribers) > 0 { + log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers)) + for _, s := range subscribers { + // We call the subscriber functions in their own Go routines because they are blocking, and + // we don't want individual slow subscribers to be able to block others. + go func(s subscriber) { + if err := s(v, m); err != nil { + log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) + } + }(s) } } else { log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m)) @@ -60,9 +65,20 @@ func (t *topic) Publish(v *visitor, m *message) error { return nil } -// Subscribers returns the number of subscribers to this topic -func (t *topic) Subscribers() int { +// SubscribersCount returns the number of subscribers to this topic +func (t *topic) SubscribersCount() int { t.mu.Lock() defer t.mu.Unlock() return len(t.subscribers) } + +// subscribersCopy returns a shallow copy of the subscribers map +func (t *topic) subscribersCopy() map[int]subscriber { + t.mu.Lock() + defer t.mu.Unlock() + subscribers := make(map[int]subscriber) + for k, v := range t.subscribers { + subscribers[k] = v + } + return subscribers +}