diff --git a/server/message_cache.go b/server/message_cache.go index 376c761..7172e1b 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -40,6 +40,7 @@ const ( attachment_expires INT NOT NULL, attachment_url TEXT NOT NULL, sender TEXT NOT NULL, + user TEXT NOT NULL, encoding TEXT NOT NULL, published INT NOT NULL ); @@ -49,46 +50,47 @@ const ( COMMIT; ` insertMessageQuery = ` - INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding, published) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics selectMessagesSinceTimeQuery = ` - SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding + SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND time >= ? AND published = 1 ORDER BY time, id ` selectMessagesSinceTimeIncludeScheduledQuery = ` - SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding + SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND time >= ? ORDER BY time, id ` selectMessagesSinceIDQuery = ` - SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding + SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND id > ? AND published = 1 ORDER BY time, id ` selectMessagesSinceIDIncludeScheduledQuery = ` - SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding + SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND (id > ? OR published = 0) ORDER BY time, id ` selectMessagesDueQuery = ` - SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding + SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE time <= ? AND published = 0 ORDER BY time, id ` - updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` - selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` - selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` - selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` - selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` + updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` + selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` + selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` + selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` + selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` + selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` ) // Schema management queries @@ -316,6 +318,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentExpires, attachmentURL, sender, + m.User, m.Encoding, published, ) @@ -442,11 +445,23 @@ func (c *messageCache) Prune(olderThan time.Time) error { return nil } -func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeQuery, sender, time.Now().Unix()) +func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { + rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix()) if err != nil { return 0, err } + return c.readAttachmentBytesUsed(rows) +} + +func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) { + rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix()) + if err != nil { + return 0, err + } + return c.readAttachmentBytesUsed(rows) +} + +func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { defer rows.Close() var size int64 if !rows.Next() { @@ -477,7 +492,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { for rows.Next() { var timestamp, attachmentSize, attachmentExpires int64 var priority int - var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, encoding string + var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string err := rows.Scan( &id, ×tamp, @@ -495,6 +510,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { &attachmentExpires, &attachmentURL, &sender, + &user, &encoding, ) if err != nil { @@ -538,6 +554,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) { Actions: actions, Attachment: att, Sender: senderIP, // Must parse assuming database must be correct + User: user, Encoding: encoding, }) } @@ -598,6 +615,7 @@ func setupCacheDB(db *sql.DB, startupQueries string) error { } else if schemaVersion == 8 { return migrateFrom8(db) } + // TODO add user column return fmt.Errorf("unexpected schema version found: %d", schemaVersion) } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 2dcd7b3..fc1f2e5 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -343,11 +343,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) require.Equal(t, "1.2.3.4", messages[1].Sender.String()) - size, err := c.AttachmentBytesUsed("1.2.3.4") + size, err := c.AttachmentBytesUsedBySender("1.2.3.4") require.Nil(t, err) require.Equal(t, int64(30000), size) - size, err = c.AttachmentBytesUsed("5.6.7.8") + size, err = c.AttachmentBytesUsedBySender("5.6.7.8") require.Nil(t, err) require.Equal(t, int64(0), size) } diff --git a/server/server.go b/server/server.go index da168f3..29924b2 100644 --- a/server/server.go +++ b/server/server.go @@ -495,6 +495,10 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } + if v.user != nil { + log.Info("user is %s", v.user.Name) + m.User = v.user.Name + } if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -502,8 +506,8 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - log.Debug("%s Received message: event=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s", - logMessagePrefix(v, m), m.Event, len(m.Message), delayed, firebase, cache, unifiedpush, email) + log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s", + logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email) if log.IsTrace() { log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m)) } diff --git a/server/server_account.go b/server/server_account.go index 65e911b..6529560 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -75,19 +75,18 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis Code: v.user.Plan.Code, Upgradable: v.user.Plan.Upgradable, } + } else if v.user.Role == auth.RoleAdmin { + response.Plan = &apiAccountPlan{ + Code: string(auth.PlanUnlimited), + Upgradable: false, + } } else { - if v.user.Role == auth.RoleAdmin { - response.Plan = &apiAccountPlan{ - Code: string(auth.PlanUnlimited), - Upgradable: false, - } - } else { - response.Plan = &apiAccountPlan{ - Code: string(auth.PlanDefault), - Upgradable: true, - } + response.Plan = &apiAccountPlan{ + Code: string(auth.PlanDefault), + Upgradable: true, } } + } else { response.Username = auth.Everyone response.Role = string(auth.RoleAnonymous) diff --git a/server/server_test.go b/server/server_test.go index 762e1d1..8f776a5 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1151,7 +1151,7 @@ func TestServer_PublishAttachment(t *testing.T) { require.Equal(t, "", response.Body.String()) // Slightly unrelated cross-test: make sure we add an owner for internal attachments - size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request() + size, err := s.messageCache.AttachmentBytesUsedBySender("9.9.9.9") // See request() require.Nil(t, err) require.Equal(t, int64(5000), size) } @@ -1180,7 +1180,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { require.Equal(t, content, response.Body.String()) // Slightly unrelated cross-test: make sure we add an owner for internal attachments - size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4") + size, err := s.messageCache.AttachmentBytesUsedBySender("1.2.3.4") require.Nil(t, err) require.Equal(t, int64(21), size) } @@ -1200,7 +1200,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) { require.Equal(t, netip.Addr{}, msg.Sender) // Slightly unrelated cross-test: make sure we don't add an owner for external attachments - size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") + size, err := s.messageCache.AttachmentBytesUsedBySender("127.0.0.1") require.Nil(t, err) require.Equal(t, int64(0), size) } diff --git a/server/types.go b/server/types.go index 8cdfc43..bc66d55 100644 --- a/server/types.go +++ b/server/types.go @@ -36,8 +36,9 @@ type message struct { Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` - Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting + User string `json:"-"` // Username of the uploader, used to associated attachments } type attachment struct { diff --git a/server/visitor.go b/server/visitor.go index 0e81c0f..c4840b9 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -151,12 +151,10 @@ func (v *visitor) IncrEmails() { } func (v *visitor) Stats() (*visitorStats, error) { - attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String()) - if err != nil { - return nil, err - } v.mu.Lock() - defer v.mu.Unlock() + messages := v.messages + emails := v.emails + v.mu.Unlock() stats := &visitorStats{} if v.user != nil && v.user.Role == auth.RoleAdmin { stats.Basis = "role" @@ -174,12 +172,22 @@ func (v *visitor) Stats() (*visitorStats, error) { stats.Basis = "ip" stats.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish) stats.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish) - stats.AttachmentTotalSizeLimit = v.config.AttachmentTotalSizeLimit + stats.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit stats.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit } - stats.Messages = v.messages + var attachmentsBytesUsed int64 + var err error + if v.user != nil { + attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name) + } else { + attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) + } + if err != nil { + return nil, err + } + stats.Messages = messages stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.MessagesLimit) - stats.Emails = v.emails + stats.Emails = emails stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.EmailsRemaining) stats.AttachmentTotalSize = attachmentsBytesUsed stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize)