Write stats to user table asynchronously

This commit is contained in:
binwiederhier 2022-12-20 21:18:33 -05:00
parent 2f567af80b
commit cc55bec521
6 changed files with 102 additions and 51 deletions

View file

@ -17,6 +17,7 @@ type Manager interface {
CreateToken(user *User) (string, error) CreateToken(user *User) (string, error)
RemoveToken(user *User) error RemoveToken(user *User) error
ChangeSettings(user *User) error ChangeSettings(user *User) error
EnqueueUpdateStats(user *User)
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
@ -65,6 +66,7 @@ type User struct {
Grants []Grant Grants []Grant
Prefs *UserPrefs Prefs *UserPrefs
Plan *Plan Plan *Plan
Stats *Stats
} }
type UserPrefs struct { type UserPrefs struct {
@ -102,6 +104,11 @@ type UserNotificationPrefs struct {
DeleteAfter int `json:"delete_after,omitempty"` DeleteAfter int `json:"delete_after,omitempty"`
} }
type Stats struct {
Messages int64
Emails int64
}
// Grant is a struct that represents an access control entry to a topic // Grant is a struct that represents an access control entry to a topic
type Grant struct { type Grant struct {
TopicPattern string // May include wildcard (*) TopicPattern string // May include wildcard (*)

View file

@ -7,14 +7,18 @@ import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"strings" "strings"
"sync"
"time"
) )
const ( const (
tokenLength = 32 tokenLength = 32
bcryptCost = 10 bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
statsWriterInterval = 10 * time.Second
) )
// Manager-related queries // Manager-related queries
@ -36,6 +40,8 @@ const (
user TEXT NOT NULL, user TEXT NOT NULL,
pass TEXT NOT NULL, pass TEXT NOT NULL,
role TEXT NOT NULL, role TEXT NOT NULL,
messages INT NOT NULL DEFAULT (0),
emails INT NOT NULL DEFAULT (0),
settings JSON, settings JSON,
FOREIGN KEY (plan_id) REFERENCES plan (id) FOREIGN KEY (plan_id) REFERENCES plan (id)
); );
@ -46,13 +52,14 @@ const (
read INT NOT NULL, read INT NOT NULL,
write INT NOT NULL, write INT NOT NULL,
PRIMARY KEY (user_id, topic), PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
); );
CREATE TABLE IF NOT EXISTS user_token ( CREATE TABLE IF NOT EXISTS user_token (
user_id INT NOT NULL, user_id INT NOT NULL,
token TEXT NOT NULL, token TEXT NOT NULL,
expires INT NOT NULL, expires INT NOT NULL,
PRIMARY KEY (user_id, token) PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
); );
CREATE TABLE IF NOT EXISTS schemaVersion ( CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY, id INT PRIMARY KEY,
@ -62,13 +69,13 @@ const (
COMMIT; COMMIT;
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit
FROM user u FROM user u
LEFT JOIN plan p on p.id = u.plan_id LEFT JOIN plan p on p.id = u.plan_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN plan p on p.id = u.plan_id LEFT JOIN plan p on p.id = u.plan_id
@ -90,6 +97,7 @@ const (
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserSettingsQuery = `UPDATE user SET settings = ? WHERE user = ?` updateUserSettingsQuery = `UPDATE user SET settings = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET messages = ?, emails = ? WHERE user = ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?` deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `INSERT INTO user_access (user_id, topic, read, write) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)` upsertUserAccessQuery = `INSERT INTO user_access (user_id, topic, read, write) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)`
@ -116,6 +124,8 @@ type SQLiteAuthManager struct {
db *sql.DB db *sql.DB
defaultRead bool defaultRead bool
defaultWrite bool defaultWrite bool
statsQueue map[string]*Stats // Username -> Stats
mu sync.Mutex
} }
var _ Manager = (*SQLiteAuthManager)(nil) var _ Manager = (*SQLiteAuthManager)(nil)
@ -129,11 +139,14 @@ func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQL
if err := setupAuthDB(db); err != nil { if err := setupAuthDB(db); err != nil {
return nil, err return nil, err
} }
return &SQLiteAuthManager{ manager := &SQLiteAuthManager{
db: db, db: db,
defaultRead: defaultRead, defaultRead: defaultRead,
defaultWrite: defaultWrite, defaultWrite: defaultWrite,
}, nil statsQueue: make(map[string]*Stats),
}
go manager.statsWriter()
return manager, nil
} }
// Authenticate checks username and password and returns a user if correct. The method // Authenticate checks username and password and returns a user if correct. The method
@ -194,6 +207,39 @@ func (a *SQLiteAuthManager) ChangeSettings(user *User) error {
return nil return nil
} }
func (a *SQLiteAuthManager) EnqueueUpdateStats(user *User) {
a.mu.Lock()
defer a.mu.Unlock()
a.statsQueue[user.Name] = user.Stats
}
func (a *SQLiteAuthManager) statsWriter() {
ticker := time.NewTicker(statsWriterInterval)
for range ticker.C {
if err := a.writeStats(); err != nil {
log.Warn("UserManager: Writing user stats failed: %s", err.Error())
}
}
}
func (a *SQLiteAuthManager) writeStats() error {
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
a.mu.Lock()
statsQueue := a.statsQueue
a.statsQueue = make(map[string]*Stats)
a.mu.Unlock()
for username, stats := range statsQueue {
if _, err := tx.Exec(updateUserStatsQuery, stats.Messages, stats.Emails, username); err != nil {
return err
}
}
return tx.Commit()
}
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission) error { func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission) error {
@ -325,12 +371,13 @@ func (a *SQLiteAuthManager) userByToken(token string) (*User, error) {
func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) { func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var username, hash, role string var username, hash, role string
var prefs, planCode sql.NullString var settings, planCode sql.NullString
var messages, emails int64
var messagesLimit, emailsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit sql.NullInt64 var messagesLimit, emailsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrNotFound return nil, ErrNotFound
} }
if err := rows.Scan(&username, &hash, &role, &prefs, &planCode, &messagesLimit, &emailsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit); err != nil { if err := rows.Scan(&username, &hash, &role, &messages, &emails, &settings, &planCode, &messagesLimit, &emailsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -344,10 +391,14 @@ func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) {
Hash: hash, Hash: hash,
Role: Role(role), Role: Role(role),
Grants: grants, Grants: grants,
Stats: &Stats{
Messages: messages,
Emails: emails,
},
} }
if prefs.Valid { if settings.Valid {
user.Prefs = &UserPrefs{} user.Prefs = &UserPrefs{}
if err := json.Unmarshal([]byte(prefs.String), user.Prefs); err != nil { if err := json.Unmarshal([]byte(settings.String), user.Prefs); err != nil {
return nil, err return nil, err
} }
} }

View file

@ -36,7 +36,7 @@ import (
/* /*
TODO TODO
persist user stats in user table publishXHR + poll should pick current user, not from userManager
expire tokens expire tokens
auto-refresh tokens from UI auto-refresh tokens from UI
reserve topics reserve topics
@ -498,7 +498,6 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m = newPollRequestMessage(t.ID, m.PollID) m = newPollRequestMessage(t.ID, m.PollID)
} }
if v.user != nil { if v.user != nil {
log.Info("user is %s", v.user.Name)
m.User = v.user.Name m.User = v.user.Name
} }
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
@ -537,6 +536,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
} }
} }
v.IncrMessages() v.IncrMessages()
if v.user != nil {
s.auth.EnqueueUpdateStats(v.user)
}
s.mu.Lock() s.mu.Lock()
s.messages++ s.messages++
s.mu.Unlock() s.mu.Unlock()
@ -772,14 +774,14 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() { } else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
} }
visitorStats, err := v.Stats() stats, err := v.Stats()
if err != nil { if err != nil {
return err return err
} }
contentLengthStr := r.Header.Get("Content-Length") contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
if err == nil && (contentLength > visitorStats.AttachmentTotalSizeRemaining || contentLength > s.config.AttachmentFileSizeLimit) { if err == nil && (contentLength > stats.AttachmentTotalSizeRemaining || contentLength > stats.AttachmentFileSizeLimit) {
return errHTTPEntityTooLargeAttachmentTooLarge return errHTTPEntityTooLargeAttachmentTooLarge
} }
} }
@ -797,7 +799,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Message == "" { if m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
} }
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(visitorStats.AttachmentTotalSizeRemaining)) m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(stats.AttachmentTotalSizeRemaining))
if err == util.ErrLimitReached { if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachmentTooLarge return errHTTPEntityTooLargeAttachmentTooLarge
} else if err != nil { } else if err != nil {
@ -1446,33 +1448,11 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc
} }
} }
// extractUserPass reads the username/password from the basic auth header (Authorization: Basic ...),
// or from the ?auth=... query param. The latter is required only to support the WebSocket JavaScript
// class, which does not support passing headers during the initial request. The auth query param
// is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func extractUserPass(r *http.Request) (username string, password string, ok bool) {
username, password, ok = r.BasicAuth()
if ok {
return
}
authParam := readQueryParam(r, "authorization", "auth")
if authParam != "" {
a, err := base64.RawURLEncoding.DecodeString(authParam)
if err != nil {
return
}
r.Header.Set("Authorization", string(a))
return r.BasicAuth()
}
return
}
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs. // Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) { func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
ip := s.extractIPAddress(r) ip := s.extractIPAddress(r)
visitorID := fmt.Sprintf("ip:%s", ip.String()) visitorID := fmt.Sprintf("ip:%s", ip.String())
var user *auth.User // may stay nil if no auth header! var user *auth.User // may stay nil if no auth header!
if user, err = s.authenticate(r); err != nil { if user, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error()) log.Debug("authentication failed: %s", err.Error())
@ -1486,6 +1466,10 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
return v, err // Always return visitor, even when error occurs! return v, err // Always return visitor, even when error occurs!
} }
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) { func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
value := r.Header.Get("Authorization") value := r.Header.Get("Authorization")
queryParam := readQueryParam(r, "authorization", "auth") queryParam := readQueryParam(r, "authorization", "auth")

View file

@ -28,11 +28,11 @@ type visitor struct {
messageCache *messageCache messageCache *messageCache
ip netip.Addr ip netip.Addr
user *auth.User user *auth.User
messages int64 messages int64 // Number of messages sent
emails int64 emails int64 // Number of emails sent
requestLimiter *rate.Limiter requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
emailsLimiter *rate.Limiter emailsLimiter *rate.Limiter // Rate limiter for emails
subscriptionLimiter util.Limiter subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter util.Limiter bandwidthLimiter util.Limiter
firebase time.Time // Next allowed Firebase message firebase time.Time // Next allowed Firebase message
seen time.Time seen time.Time
@ -55,6 +55,11 @@ type visitorStats struct {
func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *auth.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *auth.User) *visitor {
var requestLimiter, emailsLimiter *rate.Limiter var requestLimiter, emailsLimiter *rate.Limiter
var messages, emails int64
if user != nil {
messages = user.Stats.Messages
emails = user.Stats.Emails
}
if user != nil && user.Plan != nil { if user != nil && user.Plan != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Plan.MessagesLimit), conf.VisitorRequestLimitBurst) requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Plan.MessagesLimit), conf.VisitorRequestLimitBurst)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Plan.EmailsLimit), conf.VisitorEmailLimitBurst) emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Plan.EmailsLimit), conf.VisitorEmailLimitBurst)
@ -67,8 +72,8 @@ func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *a
messageCache: messageCache, messageCache: messageCache,
ip: ip, ip: ip,
user: user, user: user,
messages: 0, // TODO messages: messages,
emails: 0, // TODO emails: emails,
requestLimiter: requestLimiter, requestLimiter: requestLimiter,
emailsLimiter: emailsLimiter, emailsLimiter: emailsLimiter,
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
@ -142,12 +147,18 @@ func (v *visitor) IncrMessages() {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
v.messages++ v.messages++
if v.user != nil {
v.user.Stats.Messages = v.messages
}
} }
func (v *visitor) IncrEmails() { func (v *visitor) IncrEmails() {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
v.emails++ v.emails++
if v.user != nil {
v.user.Stats.Emails = v.emails
}
} }
func (v *visitor) Stats() (*visitorStats, error) { func (v *visitor) Stats() (*visitorStats, error) {
@ -186,9 +197,9 @@ func (v *visitor) Stats() (*visitorStats, error) {
return nil, err return nil, err
} }
stats.Messages = messages stats.Messages = messages
stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.MessagesLimit) stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.Messages)
stats.Emails = emails stats.Emails = emails
stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.EmailsRemaining) stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.Emails)
stats.AttachmentTotalSize = attachmentsBytesUsed stats.AttachmentTotalSize = attachmentsBytesUsed
stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize) stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize)
return stats, nil return stats, nil

View file

@ -62,7 +62,7 @@
"publish_dialog_progress_uploading_detail": "Hochladen {{loaded}}/{{total}} ({{percent}} %) …", "publish_dialog_progress_uploading_detail": "Hochladen {{loaded}}/{{total}} ({{percent}} %) …",
"publish_dialog_priority_max": "Max. Priorität", "publish_dialog_priority_max": "Max. Priorität",
"publish_dialog_topic_placeholder": "Thema, z.B. phil_alerts", "publish_dialog_topic_placeholder": "Thema, z.B. phil_alerts",
"publish_dialog_attachment_limits_file_reached": "überschreitet das Dateigrößen-Limit {{filesizeLimit}}", "publish_dialog_attachment_limits_file_reached": "überschreitet das Dateigrößen-Limit {{fileSizeLimit}}",
"publish_dialog_topic_label": "Thema", "publish_dialog_topic_label": "Thema",
"publish_dialog_priority_default": "Standard-Priorität", "publish_dialog_priority_default": "Standard-Priorität",
"publish_dialog_base_url_placeholder": "Service-URL, z.B. https://example.com", "publish_dialog_base_url_placeholder": "Service-URL, z.B. https://example.com",

View file

@ -162,9 +162,7 @@ const PublishDialog = (props) => {
try { try {
const account = await api.getAccount(baseUrl, session.token()); const account = await api.getAccount(baseUrl, session.token());
const fileSizeLimit = account.limits.attachment_file_size ?? 0; const fileSizeLimit = account.limits.attachment_file_size ?? 0;
const totalSizeLimit = account.limits.attachment_total_size ?? 0; const remainingBytes = account.stats.attachment_total_size_remaining;
const usedSize = account.usage.attachments_size ?? 0;
const remainingBytes = (totalSizeLimit > 0) ? totalSizeLimit - usedSize : 0;
const fileSizeLimitReached = fileSizeLimit > 0 && file.size > fileSizeLimit; const fileSizeLimitReached = fileSizeLimit > 0 && file.size > fileSizeLimit;
const quotaReached = remainingBytes > 0 && file.size > remainingBytes; const quotaReached = remainingBytes > 0 && file.size > remainingBytes;
if (fileSizeLimitReached && quotaReached) { if (fileSizeLimitReached && quotaReached) {
@ -179,7 +177,7 @@ const PublishDialog = (props) => {
} }
setAttachFileError(""); setAttachFileError("");
} catch (e) { } catch (e) {
console.log(`[SendDialog] Retrieving attachment limits failed`, e); console.log(`[PublishDialog] Retrieving attachment limits failed`, e);
setAttachFileError(""); // Reset error (rely on server-side checking) setAttachFileError(""); // Reset error (rely on server-side checking)
} }
}; };