WIP: persist message stats

This commit is contained in:
binwiederhier 2023-04-20 22:04:11 -04:00
parent 4783cb1211
commit 6be95f8285
4 changed files with 130 additions and 31 deletions

View file

@ -17,6 +17,7 @@ import (
var ( var (
errUnexpectedMessageType = errors.New("unexpected message type") errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found") errMessageNotFound = errors.New("message not found")
errNoRows = errors.New("no rows found")
) )
// Messages cache // Messages cache
@ -54,6 +55,11 @@ const (
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user); CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT; COMMIT;
` `
insertMessageQuery = ` insertMessageQuery = `
@ -108,11 +114,14 @@ const (
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
) )
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 10 currentSchemaVersion = 11
createSchemaVersionTableQuery = ` createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion ( CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY, id INT PRIMARY KEY,
@ -222,20 +231,30 @@ const (
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
` `
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
) )
var ( var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0, 0: migrateFrom0,
1: migrateFrom1, 1: migrateFrom1,
2: migrateFrom2, 2: migrateFrom2,
3: migrateFrom3, 3: migrateFrom3,
4: migrateFrom4, 4: migrateFrom4,
5: migrateFrom5, 5: migrateFrom5,
6: migrateFrom6, 6: migrateFrom6,
7: migrateFrom7, 7: migrateFrom7,
8: migrateFrom8, 8: migrateFrom8,
9: migrateFrom9, 9: migrateFrom9,
10: migrateFrom10,
} }
) )
@ -706,6 +725,26 @@ func readMessage(rows *sql.Rows) (*message, error) {
}, nil }, nil
} }
func (c *messageCache) UpdateStats(messages int64) error {
_, err := c.db.Exec(updateStatsQuery, messages)
return err
}
func (c *messageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(selectStatsQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *messageCache) Close() error { func (c *messageCache) Close() error {
return c.db.Close() return c.db.Close()
} }
@ -889,3 +928,19 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
} }
return tx.Commit() return tx.Commit()
} }
func migrateFrom10(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}

View file

@ -48,7 +48,8 @@ type Server struct {
topics map[string]*topic topics map[string]*topic
visitors map[string]*visitor // ip:<ip> or user:<user> visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient firebaseClient *firebaseClient
messages int64 messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil! userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments fileCache *fileCache // File system based cache that stores attachments
@ -56,7 +57,7 @@ type Server struct {
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
closeChan chan bool closeChan chan bool
mu sync.Mutex mu sync.RWMutex
} }
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@ -79,7 +80,8 @@ var (
matrixPushPath = "/_matrix/push/v1/notify" matrixPushPath = "/_matrix/push/v1/notify"
metricsPath = "/metrics" metricsPath = "/metrics"
apiHealthPath = "/v1/health" apiHealthPath = "/v1/health"
apiTiers = "/v1/tiers" apiStatsPath = "/v1/stats"
apiTiersPath = "/v1/tiers"
apiAccountPath = "/v1/account" apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token" apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password" apiAccountPasswordPath = "/v1/account/password"
@ -116,9 +118,10 @@ const (
newMessageBody = "New message" // Used in poll requests as generic message newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384 jsonBodyBytesLimit = 16384 // Max number of bytes for a JSON request body
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
unifiedPushTopicLength = 14 unifiedPushTopicLength = 14 // Length of UnifiedPush topics, including the "up" part
messagesHistoryMax = 10 // Number of message count values to keep in memory
) )
// WebSocket constants // WebSocket constants
@ -148,6 +151,10 @@ func New(conf *Config) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
messages, err := messageCache.Stats()
if err != nil {
return nil, err
}
var fileCache *fileCache var fileCache *fileCache
if conf.AttachmentCacheDir != "" { if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
@ -177,15 +184,17 @@ func New(conf *Config) (*Server, error) {
firebaseClient = newFirebaseClient(sender, auther) firebaseClient = newFirebaseClient(sender, auther)
} }
s := &Server{ s := &Server{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
fileCache: fileCache, fileCache: fileCache,
firebaseClient: firebaseClient, firebaseClient: firebaseClient,
smtpSender: mailer, smtpSender: mailer,
topics: topics, topics: topics,
userManager: userManager, userManager: userManager,
visitors: make(map[string]*visitor), messages: messages,
stripe: stripe, messagesHistory: []int64{messages},
visitors: make(map[string]*visitor),
stripe: stripe,
} }
s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration) s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
return s, nil return s, nil
@ -441,7 +450,9 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v) return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath { } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe! return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
} else if r.Method == http.MethodGet && r.URL.Path == apiTiers { } else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
return s.handleStats(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v) return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath { } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w) return s.handleMatrixDiscovery(w)
@ -546,17 +557,32 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visito
return nil return nil
} }
// handleStatic returns all static resources (excluding the docs), including the web app
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error { func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
r.URL.Path = webSiteDir + r.URL.Path r.URL.Path = webSiteDir + r.URL.Path
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r) util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
return nil return nil
} }
// handleDocs returns static resources related to the docs
func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error { func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r) util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
return nil return nil
} }
// handleStats returns the publicly available server stats
func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
s.mu.RLock()
n := len(s.messagesHistory)
rate := float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds())
response := &apiStatsResponse{
Messages: s.messages,
MessagesRate: rate,
}
s.mu.RUnlock()
return s.writeJSON(w, response)
}
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file. // handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader. // can associate the download bandwidth with the uploader.
@ -1580,9 +1606,9 @@ func (s *Server) sendDelayedMessages() error {
func (s *Server) sendDelayedMessage(v *visitor, m *message) error { func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
logvm(v, m).Debug("Sending delayed message") logvm(v, m).Debug("Sending delayed message")
s.mu.Lock() s.mu.RLock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
s.mu.Unlock() s.mu.RUnlock()
if ok { if ok {
go func() { go func() {
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler

View file

@ -73,9 +73,9 @@ func (s *Server) execManager() {
} }
// Print stats // Print stats
s.mu.Lock() s.mu.RLock()
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors) messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
s.mu.Unlock() s.mu.RUnlock()
log. log.
Tag(tagManager). Tag(tagManager).
Fields(log.Context{ Fields(log.Context{
@ -98,6 +98,19 @@ func (s *Server) execManager() {
mset(metricUsers, usersCount) mset(metricUsers, usersCount)
mset(metricSubscribers, subscribers) mset(metricSubscribers, subscribers)
mset(metricTopics, topicsCount) mset(metricTopics, topicsCount)
// Write stats
s.mu.Lock()
s.messagesHistory = append(s.messagesHistory, messagesCount)
if len(s.messagesHistory) > messagesHistoryMax {
s.messagesHistory = s.messagesHistory[1:]
}
s.mu.Unlock()
go func() {
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
}
}()
} }
func (s *Server) pruneVisitors() { func (s *Server) pruneVisitors() {

View file

@ -239,6 +239,11 @@ type apiHealthResponse struct {
Healthy bool `json:"healthy"` Healthy bool `json:"healthy"`
} }
type apiStatsResponse struct {
Messages int64 `json:"messages"`
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
}
type apiAccountCreateRequest struct { type apiAccountCreateRequest struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`