From 85f2252a77a2629434460a70e45b281b5dbebd2a Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 21 Jun 2022 19:07:27 -0400 Subject: [PATCH 1/5] WIP: Shorter lock, for #338 --- server/message_cache.go | 26 ++++++++++++++------------ server/message_cache_test.go | 10 +++++----- server/server.go | 32 ++++++++++++++++++++++---------- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/server/message_cache.go b/server/message_cache.go index afd4bf1..d6024c8 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -82,7 +82,7 @@ const ( ` updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` - selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` + selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` @@ -332,22 +332,24 @@ func (c *messageCache) MarkPublished(m *message) error { return err } -func (c *messageCache) MessageCount(topic string) (int, error) { - rows, err := c.db.Query(selectMessageCountForTopicQuery, topic) +func (c *messageCache) MessageCounts() (map[string]int, error) { + rows, err := c.db.Query(selectMessageCountPerTopicQuery) if err != nil { - return 0, err + return nil, err } defer rows.Close() + var topic string var count int - if !rows.Next() { - return 0, errors.New("no rows found") + counts := make(map[string]int) + for rows.Next() { + if err := rows.Scan(&topic, &count); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + counts[topic] = count } - if err := rows.Scan(&count); err != nil { - return 0, err - } else if err := rows.Err(); err != nil { - return 0, err - } - return count, nil + return counts, nil } func (c *messageCache) Topics() (map[string]*topic, error) { diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 398f21e..9132088 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -34,7 +34,7 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! // mytopic: count - count, err := c.MessageCount("mytopic") + count, err := c.MessageCounts("mytopic") require.Nil(t, err) require.Equal(t, 2, count) @@ -66,7 +66,7 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, "my other message", messages[0].Message) // example: count - count, err = c.MessageCount("example") + count, err = c.MessageCounts("example") require.Nil(t, err) require.Equal(t, 1, count) @@ -75,7 +75,7 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, "my example message", messages[0].Message) // non-existing: count - count, err = c.MessageCount("doesnotexist") + count, err = c.MessageCounts("doesnotexist") require.Nil(t, err) require.Equal(t, 0, count) @@ -255,11 +255,11 @@ func testCachePrune(t *testing.T, c *messageCache) { require.Nil(t, c.AddMessage(m3)) require.Nil(t, c.Prune(time.Unix(2, 0))) - count, err := c.MessageCount("mytopic") + count, err := c.MessageCounts("mytopic") require.Nil(t, err) require.Equal(t, 1, count) - count, err = c.MessageCount("another_topic") + count, err = c.MessageCounts("another_topic") require.Nil(t, err) require.Equal(t, 0, count) diff --git a/server/server.go b/server/server.go index 4d028d9..2c887cc 100644 --- a/server/server.go +++ b/server/server.go @@ -1080,10 +1080,13 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { } func (s *Server) updateStatsAndPrune() { - s.mu.Lock() - defer s.mu.Unlock() + log.Debug("Manager: Running cleanup") + + // WARNING: Make sure to only selectively lock with the mutex, and be aware that this + // there is no mutex for the entire function. // Expire visitors from rate visitors map + s.mu.Lock() staleVisitors := 0 for ip, v := range s.visitors { if v.Stale() { @@ -1092,6 +1095,7 @@ func (s *Server) updateStatsAndPrune() { staleVisitors++ } } + s.mu.Unlock() log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) // Delete expired attachments @@ -1116,22 +1120,30 @@ func (s *Server) updateStatsAndPrune() { log.Warn("Manager: Error pruning cache: %s", err.Error()) } + // Message count per topic + var messages int + messageCounts, err := s.messageCache.MessageCounts() + if err != nil { + log.Warn("Manager: Cannot get message counts: %s", err.Error()) + messageCounts = make(map[string]int) // Empty, so we can continue + } + for _, count := range messageCounts { + messages += count + } + // Prune old topics, remove subscriptions without subscribers - var subscribers, messages int + s.mu.Lock() + var subscribers int for _, t := range s.topics { subs := t.Subscribers() - msgs, err := s.messageCache.MessageCount(t.ID) - if err != nil { - log.Warn("Manager: Cannot get stats for topic %s: %s", t.ID, err.Error()) - continue - } - if msgs == 0 && subs == 0 { + msgs, exists := messageCounts[t.ID] + if subs == 0 && (!exists || msgs == 0) { delete(s.topics, t.ID) continue } subscribers += subs - messages += msgs } + s.mu.Unlock() // Mail stats var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64 From c1f7bed8d10d42a317f377295b05524b0b77d11a Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 21 Jun 2022 19:45:23 -0400 Subject: [PATCH 2/5] Fix tests, lock topic as short as possible --- server/message_cache_test.go | 20 ++++++++++---------- server/server.go | 7 ++++--- server/topic.go | 26 +++++++++++++++++++------- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 9132088..07929e0 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -34,9 +34,9 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! // mytopic: count - count, err := c.MessageCounts("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 2, count) + require.Equal(t, 2, counts["mytopic"]) // mytopic: since all messages, _ := c.Messages("mytopic", sinceAllMessages, false) @@ -66,18 +66,18 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Equal(t, "my other message", messages[0].Message) // example: count - count, err = c.MessageCounts("example") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["example"]) // example: since all messages, _ = c.Messages("example", sinceAllMessages, false) require.Equal(t, "my example message", messages[0].Message) // non-existing: count - count, err = c.MessageCounts("doesnotexist") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["doesnotexist"]) // non-existing: since all messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) @@ -255,13 +255,13 @@ func testCachePrune(t *testing.T, c *messageCache) { require.Nil(t, c.AddMessage(m3)) require.Nil(t, c.Prune(time.Unix(2, 0))) - count, err := c.MessageCounts("mytopic") + counts, err := c.MessageCounts() require.Nil(t, err) - require.Equal(t, 1, count) + require.Equal(t, 1, counts["mytopic"]) - count, err = c.MessageCounts("another_topic") + counts, err = c.MessageCounts() require.Nil(t, err) - require.Equal(t, 0, count) + require.Equal(t, 0, counts["another_topic"]) messages, err := c.Messages("mytopic", sinceAllMessages, false) require.Nil(t, err) diff --git a/server/server.go b/server/server.go index 2c887cc..0801ed9 100644 --- a/server/server.go +++ b/server/server.go @@ -1090,7 +1090,7 @@ func (s *Server) updateStatsAndPrune() { staleVisitors := 0 for ip, v := range s.visitors { if v.Stale() { - log.Debug("Deleting stale visitor %s", v.ip) + log.Trace("Deleting stale visitor %s", v.ip) delete(s.visitors, ip) staleVisitors++ } @@ -1131,13 +1131,14 @@ func (s *Server) updateStatsAndPrune() { messages += count } - // Prune old topics, remove subscriptions without subscribers + // Remove subscriptions without subscribers s.mu.Lock() var subscribers int for _, t := range s.topics { - subs := t.Subscribers() + subs := t.SubscribersCount() msgs, exists := messageCounts[t.ID] if subs == 0 && (!exists || msgs == 0) { + log.Trace("Deleting empty topic %s", t.ID) delete(s.topics, t.ID) continue } diff --git a/server/topic.go b/server/topic.go index 889f1eb..8ce7953 100644 --- a/server/topic.go +++ b/server/topic.go @@ -44,11 +44,12 @@ func (t *topic) Unsubscribe(id int) { // Publish asynchronously publishes to all subscribers func (t *topic) Publish(v *visitor, m *message) error { go func() { - t.mu.Lock() - defer t.mu.Unlock() - if len(t.subscribers) > 0 { - log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(t.subscribers)) - for _, s := range t.subscribers { + // We want to lock the topic as short as possible, so we make a shallow copy of the + // subscribers map here. Actually sending out the messages then doesn't have to lock. + subscribers := t.subscribersCopy() + if len(subscribers) > 0 { + log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers)) + for _, s := range subscribers { if err := s(v, m); err != nil { log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) } @@ -60,9 +61,20 @@ func (t *topic) Publish(v *visitor, m *message) error { return nil } -// Subscribers returns the number of subscribers to this topic -func (t *topic) Subscribers() int { +// SubscribersCount returns the number of subscribers to this topic +func (t *topic) SubscribersCount() int { t.mu.Lock() defer t.mu.Unlock() return len(t.subscribers) } + +// subscribersCopy returns a shallow copy of the subscribers map +func (t *topic) subscribersCopy() map[int]subscriber { + t.mu.Lock() + defer t.mu.Unlock() + subscribers := make(map[int]subscriber) + for k, v := range t.subscribers { + subscribers[k] = v + } + return subscribers +} From edfc1b78a1952d67a1ae94e9401b2110a0bdf2e2 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Tue, 21 Jun 2022 20:07:08 -0400 Subject: [PATCH 3/5] Delayed message lock shorter --- server/server.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/server.go b/server/server.go index 0801ed9..680296c 100644 --- a/server/server.go +++ b/server/server.go @@ -1080,7 +1080,8 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { } func (s *Server) updateStatsAndPrune() { - log.Debug("Manager: Running cleanup") + log.Debug("Manager: Starting") + defer log.Debug("Manager: Finished") // WARNING: Make sure to only selectively lock with the mutex, and be aware that this // there is no mutex for the entire function. @@ -1232,10 +1233,10 @@ func (s *Server) sendDelayedMessages() error { } func (s *Server) sendDelayedMessage(v *visitor, m *message) error { - s.mu.Lock() - defer s.mu.Unlock() log.Debug("%s Sending delayed message", logMessagePrefix(v, m)) + s.mu.Lock() t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published + s.mu.Unlock() if ok { go func() { // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler From ed9d99fd57d225d41e2c86fc551077f841e2f6cd Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Wed, 22 Jun 2022 13:47:54 -0400 Subject: [PATCH 4/5] "Fix" data race --- server/server.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server/server.go b/server/server.go index 680296c..e79118b 100644 --- a/server/server.go +++ b/server/server.go @@ -798,6 +798,13 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * return err } var wlock sync.Mutex + defer func() { + // Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time. + // It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER + // this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the + // data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889. + wlock.TryLock() + }() sub := func(v *visitor, msg *message) error { if !filters.Pass(msg) { return nil From 9cee8ab8887dd24c2f1e89a5e4125c5d2c24bf80 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Wed, 22 Jun 2022 13:52:49 -0400 Subject: [PATCH 5/5] Call subscriber funtions in individual goroutines --- server/topic.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/server/topic.go b/server/topic.go index 8ce7953..3bc7473 100644 --- a/server/topic.go +++ b/server/topic.go @@ -50,9 +50,13 @@ func (t *topic) Publish(v *visitor, m *message) error { if len(subscribers) > 0 { log.Debug("%s Forwarding to %d subscriber(s)", logMessagePrefix(v, m), len(subscribers)) for _, s := range subscribers { - if err := s(v, m); err != nil { - log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) - } + // We call the subscriber functions in their own Go routines because they are blocking, and + // we don't want individual slow subscribers to be able to block others. + go func(s subscriber) { + if err := s(v, m); err != nil { + log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m)) + } + }(s) } } else { log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))