Expire attachments properly

This commit is contained in:
Philipp Heckel 2022-01-07 15:15:33 +01:00
parent c45a28e6af
commit e7c19a2bad
5 changed files with 101 additions and 14 deletions

View file

@ -21,4 +21,5 @@ type cache interface {
Prune(olderThan time.Time) error Prune(olderThan time.Time) error
MarkPublished(m *message) error MarkPublished(m *message) error
AttachmentsSize(owner string) (int64, error) AttachmentsSize(owner string) (int64, error)
AttachmentsExpired() ([]string, error)
} }

View file

@ -139,6 +139,20 @@ func (c *memCache) AttachmentsSize(owner string) (int64, error) {
return size, nil return size, nil
} }
func (c *memCache) AttachmentsExpired() ([]string, error) {
c.mu.Lock()
defer c.mu.Unlock()
ids := make([]string, 0)
for topic := range c.messages {
for _, m := range c.messages[topic] {
if m.Attachment != nil && m.Attachment.Expires > 0 && m.Attachment.Expires < time.Now().Unix() {
ids = append(ids, m.ID)
}
}
}
return ids, nil
}
func (c *memCache) pruneTopic(topic string, olderThan time.Time) { func (c *memCache) pruneTopic(topic string, olderThan time.Time) {
messages := make([]*message, 0) messages := make([]*message, 0)
for _, m := range c.messages[topic] { for _, m := range c.messages[topic] {

View file

@ -60,7 +60,8 @@ const (
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ?` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ? AND attachment_expires >= ?`
selectAttachmentsExpiredQuery = `SELECT id FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
) )
// Schema management queries // Schema management queries
@ -234,7 +235,7 @@ func (c *sqliteCache) Prune(olderThan time.Time) error {
} }
func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) { func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeQuery, owner) rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -251,6 +252,26 @@ func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) {
return size, nil return size, nil
} }
func (c *sqliteCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func readMessages(rows *sql.Rows) ([]*message, error) { func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close() defer rows.Close()
messages := make([]*message, 0) messages := make([]*message, 0)

View file

@ -28,18 +28,10 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC
if err := os.MkdirAll(dir, 0700); err != nil { if err := os.MkdirAll(dir, 0700); err != nil {
return nil, err return nil, err
} }
entries, err := os.ReadDir(dir) size, err := dirSize(dir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var size int64
for _, e := range entries {
info, err := e.Info()
if err != nil {
return nil, err
}
size += info.Size()
}
return &fileCache{ return &fileCache{
dir: dir, dir: dir,
totalSizeCurrent: size, totalSizeCurrent: size,
@ -58,8 +50,8 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i
return 0, err return 0, err
} }
defer f.Close() defer f.Close()
log.Printf("remaining total: %d", c.remainingTotalSize()) log.Printf("remaining total: %d", c.Remaining())
limiters = append(limiters, util.NewLimiter(c.remainingTotalSize()), util.NewLimiter(c.fileSizeLimit)) limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit))
limitWriter := util.NewLimitWriter(f, limiters...) limitWriter := util.NewLimitWriter(f, limiters...)
size, err := io.Copy(limitWriter, in) size, err := io.Copy(limitWriter, in)
if err != nil { if err != nil {
@ -77,7 +69,40 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i
} }
func (c *fileCache) remainingTotalSize() int64 { func (c *fileCache) Remove(ids []string) error {
var firstErr error
for _, id := range ids {
if err := c.removeFile(id); err != nil {
if firstErr == nil {
firstErr = err // Continue despite error; we want to delete as many as we can
}
}
}
size, err := dirSize(c.dir)
if err != nil {
return err
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return firstErr
}
func (c *fileCache) removeFile(id string) error {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
file := filepath.Join(c.dir, id)
return os.Remove(file)
}
func (c *fileCache) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *fileCache) Remaining() int64 {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent remaining := c.totalSizeLimit - c.totalSizeCurrent
@ -86,3 +111,19 @@ func (c *fileCache) remainingTotalSize() int64 {
} }
return remaining return remaining
} }
func dirSize(dir string) (int64, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return 0, err
}
var size int64
for _, e := range entries {
info, err := e.Info()
if err != nil {
return 0, err
}
size += info.Size()
}
return size, nil
}

View file

@ -832,6 +832,16 @@ func (s *Server) updateStatsAndPrune() {
} }
} }
// Delete expired attachments
ids, err := s.cache.AttachmentsExpired()
if err == nil {
if err := s.fileCache.Remove(ids); err != nil {
log.Printf("error while deleting attachments: %s", err.Error())
}
} else {
log.Printf("error retrieving expired attachments: %s", err.Error())
}
// Prune message cache // Prune message cache
olderThan := time.Now().Add(-1 * s.config.CacheDuration) olderThan := time.Now().Add(-1 * s.config.CacheDuration)
if err := s.cache.Prune(olderThan); err != nil { if err := s.cache.Prune(olderThan); err != nil {