From 6be95f828561da88b2e83e591da5b3f456d34f8b Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 20 Apr 2023 22:04:11 -0400 Subject: [PATCH] WIP: persist message stats --- server/message_cache.go | 77 ++++++++++++++++++++++++++++++++++------ server/server.go | 62 ++++++++++++++++++++++---------- server/server_manager.go | 17 +++++++-- server/types.go | 5 +++ 4 files changed, 130 insertions(+), 31 deletions(-) diff --git a/server/message_cache.go b/server/message_cache.go index a8b8ff1..1d7302a 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -17,6 +17,7 @@ import ( var ( errUnexpectedMessageType = errors.New("unexpected message type") errMessageNotFound = errors.New("message not found") + errNoRows = errors.New("no rows found") ) // Messages cache @@ -54,6 +55,11 @@ const ( CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); CREATE INDEX IF NOT EXISTS idx_user ON messages (user); CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); COMMIT; ` insertMessageQuery = ` @@ -108,11 +114,14 @@ const ( selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` + + selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` + updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'` ) // Schema management queries const ( - currentSchemaVersion = 10 + currentSchemaVersion = 11 createSchemaVersionTableQuery = ` CREATE TABLE IF NOT EXISTS schemaVersion ( id INT PRIMARY KEY, @@ -222,20 +231,30 @@ const ( CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); ` migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` + + // 10 -> 11 + migrate10To11AlterMessagesTableQuery = ` + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + ` ) var ( migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ - 0: migrateFrom0, - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, - 6: migrateFrom6, - 7: migrateFrom7, - 8: migrateFrom8, - 9: migrateFrom9, + 0: migrateFrom0, + 1: migrateFrom1, + 2: migrateFrom2, + 3: migrateFrom3, + 4: migrateFrom4, + 5: migrateFrom5, + 6: migrateFrom6, + 7: migrateFrom7, + 8: migrateFrom8, + 9: migrateFrom9, + 10: migrateFrom10, } ) @@ -706,6 +725,26 @@ func readMessage(rows *sql.Rows) (*message, error) { }, nil } +func (c *messageCache) UpdateStats(messages int64) error { + _, err := c.db.Exec(updateStatsQuery, messages) + return err +} + +func (c *messageCache) Stats() (messages int64, err error) { + rows, err := c.db.Query(selectStatsQuery) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + if err := rows.Scan(&messages); err != nil { + return 0, err + } + return messages, nil +} + func (c *messageCache) Close() error { return c.db.Close() } @@ -889,3 +928,19 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error { } return tx.Commit() } + +func migrateFrom10(db *sql.DB, cacheDuration time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 11); err != nil { + return err + } + return tx.Commit() +} diff --git a/server/server.go b/server/server.go index c86307d..e545e94 100644 --- a/server/server.go +++ b/server/server.go @@ -48,7 +48,8 @@ type Server struct { topics map[string]*topic visitors map[string]*visitor // ip: or user: firebaseClient *firebaseClient - messages int64 + messages int64 // Total number of messages (persisted if messageCache enabled) + messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! messageCache *messageCache // Database that stores the messages fileCache *fileCache // File system based cache that stores attachments @@ -56,7 +57,7 @@ type Server struct { priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set closeChan chan bool - mu sync.Mutex + mu sync.RWMutex } // handleFunc extends the normal http.HandlerFunc to be able to easily return errors @@ -79,7 +80,8 @@ var ( matrixPushPath = "/_matrix/push/v1/notify" metricsPath = "/metrics" apiHealthPath = "/v1/health" - apiTiers = "/v1/tiers" + apiStatsPath = "/v1/stats" + apiTiersPath = "/v1/tiers" apiAccountPath = "/v1/account" apiAccountTokenPath = "/v1/account/token" apiAccountPasswordPath = "/v1/account/password" @@ -116,9 +118,10 @@ const ( newMessageBody = "New message" // Used in poll requests as generic message defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages - jsonBodyBytesLimit = 16384 - unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber - unifiedPushTopicLength = 14 + jsonBodyBytesLimit = 16384 // Max number of bytes for a JSON request body + unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber + unifiedPushTopicLength = 14 // Length of UnifiedPush topics, including the "up" part + messagesHistoryMax = 10 // Number of message count values to keep in memory ) // WebSocket constants @@ -148,6 +151,10 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } + messages, err := messageCache.Stats() + if err != nil { + return nil, err + } var fileCache *fileCache if conf.AttachmentCacheDir != "" { fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) @@ -177,15 +184,17 @@ func New(conf *Config) (*Server, error) { firebaseClient = newFirebaseClient(sender, auther) } s := &Server{ - config: conf, - messageCache: messageCache, - fileCache: fileCache, - firebaseClient: firebaseClient, - smtpSender: mailer, - topics: topics, - userManager: userManager, - visitors: make(map[string]*visitor), - stripe: stripe, + config: conf, + messageCache: messageCache, + fileCache: fileCache, + firebaseClient: firebaseClient, + smtpSender: mailer, + topics: topics, + userManager: userManager, + messages: messages, + messagesHistory: []int64{messages}, + visitors: make(map[string]*visitor), + stripe: stripe, } s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration) return s, nil @@ -441,7 +450,9 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath { return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe! - } else if r.Method == http.MethodGet && r.URL.Path == apiTiers { + } else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath { + return s.handleStats(w, r, v) + } else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath { return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { return s.handleMatrixDiscovery(w) @@ -546,17 +557,32 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visito return nil } +// handleStatic returns all static resources (excluding the docs), including the web app func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error { r.URL.Path = webSiteDir + r.URL.Path util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r) return nil } +// handleDocs returns static resources related to the docs func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error { util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r) return nil } +// handleStats returns the publicly available server stats +func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error { + s.mu.RLock() + n := len(s.messagesHistory) + rate := float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds()) + response := &apiStatsResponse{ + Messages: s.messages, + MessagesRate: rate, + } + s.mu.RUnlock() + return s.writeJSON(w, response) +} + // handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file. // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it // can associate the download bandwidth with the uploader. @@ -1580,9 +1606,9 @@ func (s *Server) sendDelayedMessages() error { func (s *Server) sendDelayedMessage(v *visitor, m *message) error { logvm(v, m).Debug("Sending delayed message") - s.mu.Lock() + s.mu.RLock() t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published - s.mu.Unlock() + s.mu.RUnlock() if ok { go func() { // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler diff --git a/server/server_manager.go b/server/server_manager.go index 891366f..445f830 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -73,9 +73,9 @@ func (s *Server) execManager() { } // Print stats - s.mu.Lock() + s.mu.RLock() messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors) - s.mu.Unlock() + s.mu.RUnlock() log. Tag(tagManager). Fields(log.Context{ @@ -98,6 +98,19 @@ func (s *Server) execManager() { mset(metricUsers, usersCount) mset(metricSubscribers, subscribers) mset(metricTopics, topicsCount) + + // Write stats + s.mu.Lock() + s.messagesHistory = append(s.messagesHistory, messagesCount) + if len(s.messagesHistory) > messagesHistoryMax { + s.messagesHistory = s.messagesHistory[1:] + } + s.mu.Unlock() + go func() { + if err := s.messageCache.UpdateStats(messagesCount); err != nil { + log.Tag(tagManager).Err(err).Warn("Cannot write messages stats") + } + }() } func (s *Server) pruneVisitors() { diff --git a/server/types.go b/server/types.go index b11424f..563cafb 100644 --- a/server/types.go +++ b/server/types.go @@ -239,6 +239,11 @@ type apiHealthResponse struct { Healthy bool `json:"healthy"` } +type apiStatsResponse struct { + Messages int64 `json:"messages"` + MessagesRate float64 `json:"messages_rate"` // Average number of messages per second +} + type apiAccountCreateRequest struct { Username string `json:"username"` Password string `json:"password"`