Merge branch 'main' into unix-socket

This commit is contained in:
Philipp Heckel 2022-01-14 20:16:12 -05:00
commit b079cb99a4
36 changed files with 2109 additions and 354 deletions

View file

@ -20,4 +20,6 @@ type cache interface {
Topics() (map[string]*topic, error)
Prune(olderThan time.Time) error
MarkPublished(m *message) error
AttachmentsSize(owner string) (int64, error)
AttachmentsExpired() ([]string, error)
}

View file

@ -125,6 +125,35 @@ func (c *memCache) Prune(olderThan time.Time) error {
return nil
}
func (c *memCache) AttachmentsSize(owner string) (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
var size int64
for topic := range c.messages {
for _, m := range c.messages[topic] {
counted := m.Attachment != nil && m.Attachment.Owner == owner && m.Attachment.Expires > time.Now().Unix()
if counted {
size += m.Attachment.Size
}
}
}
return size, nil
}
func (c *memCache) AttachmentsExpired() ([]string, error) {
c.mu.Lock()
defer c.mu.Unlock()
ids := make([]string, 0)
for topic := range c.messages {
for _, m := range c.messages[topic] {
if m.Attachment != nil && m.Attachment.Expires > 0 && m.Attachment.Expires < time.Now().Unix() {
ids = append(ids, m.ID)
}
}
}
return ids, nil
}
func (c *memCache) pruneTopic(topic string, olderThan time.Time) {
messages := make([]*message, 0)
for _, m := range c.messages[topic] {

View file

@ -25,6 +25,10 @@ func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemCache())
}
func TestMemCache_Attachments(t *testing.T) {
testCacheAttachments(t, newMemCache())
}
func TestMemCache_NopCache(t *testing.T) {
c := newNopCache()
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))

View file

@ -23,27 +23,36 @@ const (
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`
insertMessageQuery = `INSERT INTO messages (id, time, topic, message, title, priority, tags, click, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`
insertMessageQuery = `
INSERT INTO messages (id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1`
selectMessagesSinceTimeQuery = `
SELECT id, time, topic, message, title, priority, tags, click
SELECT id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time ASC
`
selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT id, time, topic, message, title, priority, tags, click
SELECT id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time ASC
`
selectMessagesDueQuery = `
SELECT id, time, topic, message, title, priority, tags, click
SELECT id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner
FROM messages
WHERE time <= ? AND published = 0
`
@ -51,6 +60,8 @@ const (
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ? AND attachment_expires >= ?`
selectAttachmentsExpiredQuery = `SELECT id FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
)
// Schema management queries
@ -82,7 +93,15 @@ const (
// 2 -> 3
migrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
)
@ -110,7 +129,35 @@ func (c *sqliteCache) AddMessage(m *message) error {
return errUnexpectedMessageType
}
published := m.Time <= time.Now().Unix()
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ","), m.Click, published)
tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL, attachmentOwner string
var attachmentSize, attachmentExpires int64
if m.Attachment != nil {
attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type
attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires
attachmentURL = m.Attachment.URL
attachmentOwner = m.Attachment.Owner
}
_, err := c.db.Exec(
insertMessageQuery,
m.ID,
m.Time,
m.Topic,
m.Message,
m.Title,
m.Priority,
tags,
m.Click,
attachmentName,
attachmentType,
attachmentSize,
attachmentExpires,
attachmentURL,
attachmentOwner,
published,
)
return err
}
@ -187,30 +234,80 @@ func (c *sqliteCache) Prune(olderThan time.Time) error {
return err
}
func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix())
if err != nil {
return 0, err
}
defer rows.Close()
var size int64
if !rows.Next() {
return 0, errors.New("no rows found")
}
if err := rows.Scan(&size); err != nil {
return 0, err
} else if err := rows.Err(); err != nil {
return 0, err
}
return size, nil
}
func (c *sqliteCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close()
messages := make([]*message, 0)
for rows.Next() {
var timestamp int64
var timestamp, attachmentSize, attachmentExpires int64
var priority int
var id, topic, msg, title, tagsStr, click string
if err := rows.Scan(&id, &timestamp, &topic, &msg, &title, &priority, &tagsStr, &click); err != nil {
var id, topic, msg, title, tagsStr, click, attachmentName, attachmentType, attachmentURL, attachmentOwner string
if err := rows.Scan(&id, &timestamp, &topic, &msg, &title, &priority, &tagsStr, &click, &attachmentName, &attachmentType, &attachmentSize, &attachmentExpires, &attachmentURL, &attachmentOwner); err != nil {
return nil, err
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
}
var att *attachment
if attachmentName != "" && attachmentURL != "" {
att = &attachment{
Name: attachmentName,
Type: attachmentType,
Size: attachmentSize,
Expires: attachmentExpires,
URL: attachmentURL,
Owner: attachmentOwner,
}
}
messages = append(messages, &message{
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
Attachment: att,
})
}
if err := rows.Err(); err != nil {

View file

@ -29,6 +29,10 @@ func TestSqliteCache_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestCache(t))
}
func TestSqliteCache_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestCache(t))
}
func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)

View file

@ -1,7 +1,7 @@
package server
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -13,71 +13,71 @@ func testCacheMessages(t *testing.T, c cache) {
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
assert.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
require.Nil(t, c.AddMessage(m2))
// Adding invalid
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
assert.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
// mytopic: count
count, err := c.MessageCount("mytopic")
assert.Nil(t, err)
assert.Equal(t, 2, count)
require.Nil(t, err)
require.Equal(t, 2, count)
// mytopic: since all
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
assert.Equal(t, 2, len(messages))
assert.Equal(t, "my message", messages[0].Message)
assert.Equal(t, "mytopic", messages[0].Topic)
assert.Equal(t, messageEvent, messages[0].Event)
assert.Equal(t, "", messages[0].Title)
assert.Equal(t, 0, messages[0].Priority)
assert.Nil(t, messages[0].Tags)
assert.Equal(t, "my other message", messages[1].Message)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, messageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
assert.Empty(t, messages)
require.Empty(t, messages)
// mytopic: since 2
messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0)), false)
assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
count, err = c.MessageCount("example")
assert.Nil(t, err)
assert.Equal(t, 1, count)
require.Nil(t, err)
require.Equal(t, 1, count)
// example: since all
messages, _ = c.Messages("example", sinceAllMessages, false)
assert.Equal(t, "my example message", messages[0].Message)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: count
count, err = c.MessageCount("doesnotexist")
assert.Nil(t, err)
assert.Equal(t, 0, count)
require.Nil(t, err)
require.Equal(t, 0, count)
// non-existing: since all
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
assert.Empty(t, messages)
require.Empty(t, messages)
}
func testCacheTopics(t *testing.T, c cache) {
assert.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
topics, err := c.Topics()
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 2, len(topics))
assert.Equal(t, "topic1", topics["topic1"].ID)
assert.Equal(t, "topic2", topics["topic2"].ID)
require.Equal(t, 2, len(topics))
require.Equal(t, "topic1", topics["topic1"].ID)
require.Equal(t, "topic2", topics["topic2"].ID)
}
func testCachePrune(t *testing.T, c cache) {
@ -90,23 +90,23 @@ func testCachePrune(t *testing.T, c cache) {
m3 := newDefaultMessage("another_topic", "and another one")
m3.Time = 1
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(m2))
assert.Nil(t, c.AddMessage(m3))
assert.Nil(t, c.Prune(time.Unix(2, 0)))
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
require.Nil(t, c.Prune(time.Unix(2, 0)))
count, err := c.MessageCount("mytopic")
assert.Nil(t, err)
assert.Equal(t, 1, count)
require.Nil(t, err)
require.Equal(t, 1, count)
count, err = c.MessageCount("another_topic")
assert.Nil(t, err)
assert.Equal(t, 0, count)
require.Nil(t, err)
require.Equal(t, 0, count)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
assert.Nil(t, err)
assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
@ -114,12 +114,12 @@ func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
assert.Nil(t, c.AddMessage(m))
require.Nil(t, c.AddMessage(m))
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
assert.Equal(t, 5, messages[0].Priority)
assert.Equal(t, "some title", messages[0].Title)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
}
func testCacheMessagesScheduled(t *testing.T, c cache) {
@ -130,20 +130,93 @@ func testCacheMessagesScheduled(t *testing.T, c cache) {
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := newDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(m2))
assert.Nil(t, c.AddMessage(m3))
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
assert.Equal(t, 1, len(messages))
assert.Equal(t, "message 1", messages[0].Message)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
assert.Equal(t, 3, len(messages))
assert.Equal(t, "message 1", messages[0].Message)
assert.Equal(t, "message 3", messages[1].Message) // Order!
assert.Equal(t, "message 2", messages[2].Message)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = c.MessagesDue()
assert.Empty(t, messages)
require.Empty(t, messages)
}
func testCacheAttachments(t *testing.T, c cache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix()
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Attachment = &attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
Owner: "1.2.3.4",
}
require.Nil(t, c.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
Owner: "1.2.3.4",
}
require.Nil(t, c.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.Attachment = &attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
Owner: "1.2.3.4",
}
require.Nil(t, c.AddMessage(m))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Attachment.Owner)
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Attachment.Owner)
size, err := c.AttachmentsSize("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(30000), size)
size, err = c.AttachmentsSize("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size)
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, []string{"m1"}, ids)
}

View file

@ -4,7 +4,7 @@ import (
"time"
)
// Defines default config settings
// Defines default config settings (excluding limits, see below)
const (
DefaultListenHTTP = ":80"
DefaultCacheDuration = 12 * time.Hour
@ -13,82 +13,109 @@ const (
DefaultAtSenderInterval = 10 * time.Second
DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour
DefaultMessageLimit = 4096
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery
)
// Defines all the limits
// - global topic limit: max number of topics overall
// Defines all global and per-visitor limits
// - message size limit: the max number of bytes for a message
// - total topic limit: max number of topics overall
// - various attachment limits
const (
DefaultMessageLengthLimit = 4096 // Bytes
DefaultTotalTopicLimit = 15000
DefaultAttachmentTotalSizeLimit = int64(5 * 1024 * 1024 * 1024) // 5 GB
DefaultAttachmentFileSizeLimit = int64(15 * 1024 * 1024) // 15 MB
DefaultAttachmentExpiryDuration = 3 * time.Hour
)
// Defines all per-visitor limits
// - per visitor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
// - per visitor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
// - per visitor email limit: max number of emails (here: 16 email bucket, replenished at a rate of one per hour)
// - per visitor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
// - per visitor attachment size limit: total per-visitor attachment size in bytes to be stored on the server
// - per visitor attachment daily bandwidth limit: number of bytes that can be transferred to/from the server
const (
DefaultGlobalTopicLimit = 5000
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorSubscriptionLimit = 30
DefaultVisitorSubscriptionLimit = 30
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 10 * time.Second
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
)
// Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct {
BaseURL string
ListenHTTP string
ListenHTTPS string
ListenUnix string
KeyFile string
CertFile string
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
KeepaliveInterval time.Duration
ManagerInterval time.Duration
AtSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration
SMTPSenderAddr string
SMTPSenderUser string
SMTPSenderPass string
SMTPSenderFrom string
SMTPServerListen string
SMTPServerDomain string
SMTPServerAddrPrefix string
MessageLimit int
MinDelay time.Duration
MaxDelay time.Duration
GlobalTopicLimit int
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
VisitorSubscriptionLimit int
BehindProxy bool
BaseURL string
ListenHTTP string
ListenHTTPS string
ListenUnix string
KeyFile string
CertFile string
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
AttachmentCacheDir string
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration
KeepaliveInterval time.Duration
ManagerInterval time.Duration
AtSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration
SMTPSenderAddr string
SMTPSenderUser string
SMTPSenderPass string
SMTPSenderFrom string
SMTPServerListen string
SMTPServerDomain string
SMTPServerAddrPrefix string
MessageLimit int
MinDelay time.Duration
MaxDelay time.Duration
TotalTopicLimit int
TotalAttachmentSizeLimit int64
VisitorSubscriptionLimit int
VisitorAttachmentTotalSizeLimit int64
VisitorAttachmentDailyBandwidthLimit int
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
BehindProxy bool
}
// NewConfig instantiates a default new server config
func NewConfig() *Config {
return &Config{
BaseURL: "",
ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "",
KeyFile: "",
CertFile: "",
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
MessageLimit: DefaultMessageLimit,
MinDelay: DefaultMinDelay,
MaxDelay: DefaultMaxDelay,
AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
GlobalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
BehindProxy: false,
BaseURL: "",
ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "",
ListenUnix: "",
KeyFile: "",
CertFile: "",
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,
AttachmentCacheDir: "",
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
AttachmentExpiryDuration: DefaultAttachmentExpiryDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
MessageLimit: DefaultMessageLengthLimit,
MinDelay: DefaultMinDelay,
MaxDelay: DefaultMaxDelay,
AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
TotalTopicLimit: DefaultTotalTopicLimit,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
BehindProxy: false,
}
}

121
server/file_cache.go Normal file
View file

@ -0,0 +1,121 @@
package server
import (
"errors"
"heckel.io/ntfy/util"
"io"
"os"
"path/filepath"
"regexp"
"sync"
)
var (
fileIDRegex = regexp.MustCompile(`^[-_A-Za-z0-9]+$`)
errInvalidFileID = errors.New("invalid file ID")
errFileExists = errors.New("file exists")
)
type fileCache struct {
dir string
totalSizeCurrent int64
totalSizeLimit int64
fileSizeLimit int64
mu sync.Mutex
}
func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileCache, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, err
}
size, err := dirSize(dir)
if err != nil {
return nil, err
}
return &fileCache{
dir: dir,
totalSizeCurrent: size,
totalSizeLimit: totalSizeLimit,
fileSizeLimit: fileSizeLimit,
}, nil
}
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
file := filepath.Join(c.dir, id)
if _, err := os.Stat(file); err == nil {
return 0, errFileExists
}
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return 0, err
}
defer f.Close()
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()), util.NewFixedLimiter(c.fileSizeLimit))
limitWriter := util.NewLimitWriter(f, limiters...)
size, err := io.Copy(limitWriter, in)
if err != nil {
os.Remove(file)
return 0, err
}
if err := f.Close(); err != nil {
os.Remove(file)
return 0, err
}
c.mu.Lock()
c.totalSizeCurrent += size
c.mu.Unlock()
return size, nil
}
func (c *fileCache) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
file := filepath.Join(c.dir, id)
_ = os.Remove(file) // Best effort delete
}
size, err := dirSize(c.dir)
if err != nil {
return err
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return nil
}
func (c *fileCache) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *fileCache) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent
if remaining < 0 {
return 0
}
return remaining
}
func dirSize(dir string) (int64, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return 0, err
}
var size int64
for _, e := range entries {
info, err := e.Info()
if err != nil {
return 0, err
}
size += info.Size()
}
return size, nil
}

83
server/file_cache_test.go Normal file
View file

@ -0,0 +1,83 @@
package server
import (
"bytes"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/util"
"os"
"strings"
"testing"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t)
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abc"))
require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining())
}
func TestFileCache_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abc1")
require.FileExists(t, dir+"/abc5")
require.Nil(t, c.Remove("abc1", "abc5"))
require.NoFileExists(t, dir+"/abc1")
require.NoFileExists(t, dir+"/abc5")
require.Equal(t, int64(7992), c.Size())
require.Equal(t, int64(2248), c.Remaining())
}
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc11")
}
func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1025)))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc")
}
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc")
}
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024, 1*1024)
require.Nil(t, err)
return dir, cache
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

View file

@ -18,15 +18,25 @@ const (
// message represents a message published to a topic
type message struct {
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
}
type attachment struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"`
URL string `json:"url"`
Owner string `json:"-"` // IP address of uploader, used for rate limiting
}
// messageEncoder is a function that knows how to encode a message

View file

@ -18,12 +18,16 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
// TODO add "max messages in a topic" limit
@ -43,6 +47,7 @@ type Server struct {
mailer mailer
messages int64
cache cache
fileCache *fileCache
closeChan chan bool
mu sync.Mutex
}
@ -98,7 +103,9 @@ var (
staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
disallowedTopics = []string{"docs", "static"}
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
disallowedTopics = []string{"docs", "static", "file"}
attachURLRegex = regexp.MustCompile(`^https?://`)
templateFnMap = template.FuncMap{
"durationToHuman": util.DurationToHuman,
@ -119,28 +126,36 @@ var (
docsStaticFs embed.FS
docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
errHTTPBadRequestMessageNotUTF8 = &errHTTP{40011, http.StatusBadRequest, "invalid message: message must be UTF-8 encoded", ""}
errHTTPBadRequestAttachmentTooLarge = &errHTTP{40012, http.StatusBadRequest, "invalid request: attachment too large, or bandwidth limit reached", ""}
errHTTPBadRequestAttachmentURLInvalid = &errHTTP{40013, http.StatusBadRequest, "invalid request: attachment URL is invalid", ""}
errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachments not allowed", ""}
errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
)
const (
firebaseControlTopic = "~control" // See Android if changed
emptyMessageBody = "triggered"
fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details
firebaseControlTopic = "~control" // See Android if changed
emptyMessageBody = "triggered" // Used if message body is empty
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details
)
// New instantiates a new Server. It creates the cache and adds a Firebase
@ -166,13 +181,21 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var fileCache *fileCache
if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, conf.AttachmentFileSizeLimit)
if err != nil {
return nil, err
}
}
return &Server{
config: conf,
cache: cache,
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
visitors: make(map[string]*visitor),
config: conf,
cache: cache,
fileCache: fileCache,
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
visitors: make(map[string]*visitor),
}, nil
}
@ -216,6 +239,13 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) {
"title": m.Title,
"message": m.Message,
}
if m.Attachment != nil {
data["attachment_name"] = m.Attachment.Name
data["attachment_type"] = m.Attachment.Type
data["attachment_size"] = fmt.Sprintf("%d", m.Attachment.Size)
data["attachment_expires"] = fmt.Sprintf("%d", m.Attachment.Expires)
data["attachment_url"] = m.Attachment.URL
}
}
var androidConfig *messaging.AndroidConfig
if m.Priority >= 4 {
@ -280,7 +310,7 @@ func (s *Server) Run() error {
}()
}
if s.config.ListenHTTPS != "" {
s.httpsServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
go func() {
errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
}()
@ -339,7 +369,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
if e, ok = err.(*errHTTP); !ok {
e = errHTTPInternalError
}
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, err.Error())
log.Printf("[%s] %s - %d - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, e.Code, err.Error())
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.WriteHeader(e.HTTPCode)
@ -358,6 +388,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
return s.handleStatic(w, r)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.handleDocs(w, r)
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
return s.withRateLimit(w, r, s.handleFile)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
@ -413,28 +445,49 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" {
return errHTTPInternalError
}
matches := fileRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidFilePath
}
messageID := matches[1]
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
stat, err := os.Stat(file)
if err != nil {
return errHTTPNotFound
}
if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
return errHTTPTooManyRequestsAttachmentBandwidthLimit
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
f, err := os.Open(file)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
return err
}
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
}
reader := io.LimitReader(r.Body, int64(s.config.MessageLimit))
b, err := io.ReadAll(reader)
body, err := util.Peak(r.Body, s.config.MessageLimit)
if err != nil {
return err
}
m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
cache, firebase, email, err := s.parsePublishParams(r, m)
m := newDefaultMessage(t.ID, "")
cache, firebase, email, err := s.parsePublishParams(r, v, m)
if err != nil {
return err
}
if email != "" {
if err := v.EmailAllowed(); err != nil {
return errHTTPTooManyRequestsLimitEmails
}
}
if s.mailer == nil && email != "" {
return errHTTPBadRequestEmailDisabled
if err := s.handlePublishBody(r, v, m, body); err != nil {
return err
}
if m.Message == "" {
m.Message = emptyMessageBody
@ -473,12 +526,46 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return nil
}
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) {
func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) {
cache = readParam(r, "x-cache", "cache") != "no"
firebase = readParam(r, "x-firebase", "firebase") != "no"
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
m.Title = readParam(r, "x-title", "title", "t")
m.Click = readParam(r, "x-click", "click")
filename := readParam(r, "x-filename", "filename", "file", "f")
attach := readParam(r, "x-attach", "attach", "a")
if attach != "" || filename != "" {
m.Attachment = &attachment{}
}
if filename != "" {
m.Attachment.Name = filename
}
if attach != "" {
if !attachURLRegex.MatchString(attach) {
return false, false, "", errHTTPBadRequestAttachmentURLInvalid
}
m.Attachment.URL = attach
if m.Attachment.Name == "" {
u, err := url.Parse(m.Attachment.URL)
if err == nil {
m.Attachment.Name = path.Base(u.Path)
if m.Attachment.Name == "." || m.Attachment.Name == "/" {
m.Attachment.Name = ""
}
}
}
if m.Attachment.Name == "" {
m.Attachment.Name = "attachment"
}
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" {
if err := v.EmailAllowed(); err != nil {
return false, false, "", errHTTPTooManyRequestsLimitEmails
}
}
if s.mailer == nil && email != "" {
return false, false, "", errHTTPBadRequestEmailDisabled
}
messageStr := readParam(r, "x-message", "message", "m")
if messageStr != "" {
m.Message = messageStr
@ -535,6 +622,81 @@ func readParam(r *http.Request, names ...string) string {
return ""
}
// handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
//
// 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
// Body must be a message, because we attached an external URL
// 2. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
// Body must be attachment, because we passed a filename
// 3. curl -T file.txt ntfy.sh/mytopic
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
// 4. curl -T file.txt ntfy.sh/mytopic
// If file.txt is > message limit, treat it as an attachment
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error {
if m.Attachment != nil && m.Attachment.URL != "" {
return s.handleBodyAsMessage(m, body) // Case 1
} else if m.Attachment != nil && m.Attachment.Name != "" {
return s.handleBodyAsAttachment(r, v, m, body) // Case 2
} else if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
return s.handleBodyAsMessage(m, body) // Case 3
}
return s.handleBodyAsAttachment(r, v, m, body) // Case 4
}
func (s *Server) handleBodyAsMessage(m *message, body *util.PeakedReadCloser) error {
if !utf8.Valid(body.PeakedBytes) {
return errHTTPBadRequestMessageNotUTF8
}
if len(body.PeakedBytes) > 0 { // Empty body should not override message (publish via GET!)
m.Message = strings.TrimSpace(string(body.PeakedBytes)) // Truncates the message to the peak limit if required
}
if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
}
return nil
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
return errHTTPBadRequestAttachmentsDisallowed
} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
}
visitorAttachmentsSize, err := s.cache.AttachmentsSize(v.ip)
if err != nil {
return err
}
remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize
contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
if err == nil && (contentLength > remainingVisitorAttachmentSize || contentLength > s.config.AttachmentFileSizeLimit) {
return errHTTPBadRequestAttachmentTooLarge
}
}
if m.Attachment == nil {
m.Attachment = &attachment{}
}
var ext string
m.Attachment.Owner = v.ip // Important for attachment rate limiting
m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix()
m.Attachment.Type, ext = util.DetectContentType(body.PeakedBytes, m.Attachment.Name)
m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
if m.Attachment.Name == "" {
m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
}
if m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
}
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(remainingVisitorAttachmentSize))
if err == util.ErrLimitReached {
return errHTTPBadRequestAttachmentTooLarge
} else if err != nil {
return err
}
return nil
}
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
var buf bytes.Buffer
@ -748,8 +910,8 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return nil, errHTTPBadRequestTopicDisallowed
}
if _, ok := s.topics[id]; !ok {
if len(s.topics) >= s.config.GlobalTopicLimit {
return nil, errHTTPTooManyRequestsLimitGlobalTopics
if len(s.topics) >= s.config.TotalTopicLimit {
return nil, errHTTPTooManyRequestsLimitTotalTopics
}
s.topics[id] = newTopic(id)
}
@ -769,6 +931,18 @@ func (s *Server) updateStatsAndPrune() {
}
}
// Delete expired attachments
if s.fileCache != nil {
ids, err := s.cache.AttachmentsExpired()
if err == nil {
if err := s.fileCache.Remove(ids...); err != nil {
log.Printf("error while deleting attachments: %s", err.Error())
}
} else {
log.Printf("error retrieving expired attachments: %s", err.Error())
}
}
// Prune message cache
olderThan := time.Now().Add(-1 * s.config.CacheDuration)
if err := s.cache.Prune(olderThan); err != nil {
@ -885,12 +1059,11 @@ func (s *Server) sendDelayedMessages() error {
if err := t.Publish(m); err != nil {
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
if s.firebase != nil {
if err := s.firebase(m); err != nil {
log.Printf("unable to publish to Firebase: %v", err.Error())
}
}
if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(m); err != nil {
log.Printf("unable to publish to Firebase: %v", err.Error())
}
// TODO delayed email sending
}
if err := s.cache.MarkPublished(m); err != nil {
return err

View file

@ -8,9 +8,16 @@
# Listen address for the HTTP & HTTPS web server. If "listen-https" is set, you must also
# set "key-file" and "cert-file". Format: <hostname>:<port>
#
# To disable HTTP, set "listen-http" to "-".
#
# listen-http: ":80"
# listen-https:
# Listen on a Unix socket, e.g. /var/lib/ntfy/ntfy.sock
# This can be useful to avoid port issues on local systems, and to simplify permissions.
#
# listen-unix: <socket-path>
# Path to the private key & cert file for the HTTPS web server. Not used if "listen-https" is not set.
#
# key-file:
@ -36,7 +43,7 @@
#
# You can disable the cache entirely by setting this to 0.
#
# cache-duration: 12h
# cache-duration: "12h"
# If set, the X-Forwarded-For header is used to determine the visitor IP address
# instead of the remote address of the connection.
@ -46,6 +53,19 @@
#
# behind-proxy: false
# If enabled, clients can attach files to notifications as attachments. Minimum settings to enable attachments
# are "attachment-cache-dir" and "base-url".
#
# - attachment-cache-dir is the cache directory for attached files
# - attachment-total-size-limit is the limit of the on-disk attachment cache directory (total size)
# - attachment-file-size-limit is the per-file attachment size limit (e.g. 300k, 2M, 100M)
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
#
# attachment-cache-dir:
# attachment-total-size-limit: "5G"
# attachment-file-size-limit: "15M"
# attachment-expiry-duration: "3h"
# If enabled, allow outgoing e-mail notifications via the 'X-Email' header. If this header is set,
# messages will additionally be sent out as e-mail using an external SMTP server. As of today, only
# SMTP servers with plain text auth and STARTLS are supported. Please also refer to the rate limiting settings
@ -78,16 +98,16 @@
#
# Note that the Android app has a hardcoded timeout at 77s, so it should be less than that.
#
# keepalive-interval: 30s
# keepalive-interval: "30s"
# Interval in which the manager prunes old messages, deletes topics
# and prints the stats.
#
# manager-interval: 1m
# manager-interval: "1m"
# Rate limiting: Total number of topics before the server rejects new topics.
#
# global-topic-limit: 5000
# global-topic-limit: 15000
# Rate limiting: Number of subscriptions per visitor (IP address)
#
@ -98,11 +118,18 @@
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
#
# visitor-request-limit-burst: 60
# visitor-request-limit-replenish: 10s
# visitor-request-limit-replenish: "10s"
# Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled
#
# visitor-email-limit-burst: 16
# visitor-email-limit-replenish: 1h
# visitor-email-limit-replenish: "1h"
# Rate limiting: Attachment size and bandwidth limits per visitor:
# - visitor-attachment-total-size-limit is the total storage limit used for attachments per visitor
# - visitor-attachment-daily-bandwidth-limit is the total daily attachment download/upload traffic limit per visitor
#
# visitor-attachment-total-size-limit: "100M"
# visitor-attachment-daily-bandwidth-limit: "500M"

View file

@ -7,6 +7,7 @@ import (
"firebase.google.com/go/messaging"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/util"
"net/http"
"net/http/httptest"
"os"
@ -163,20 +164,13 @@ func TestServer_StaticSites(t *testing.T) {
}
func TestServer_PublishLargeMessage(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
c := newTestConfig(t)
c.AttachmentCacheDir = "" // Disable attachments
s := newTestServer(t, c)
body := strings.Repeat("this is a large message", 5000)
truncated := body[0:4096]
response := request(t, s, "PUT", "/mytopic", body, nil)
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
require.Equal(t, truncated, msg.Message)
require.Equal(t, 4096, len(msg.Message))
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
messages := toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages))
require.Equal(t, truncated, messages[0].Message)
require.Equal(t, 400, response.Code)
}
func TestServer_PublishPriority(t *testing.T) {
@ -205,6 +199,9 @@ func TestServer_PublishPriority(t *testing.T) {
response = request(t, s, "GET", "/mytopic/trigger?priority=urgent", "test", nil)
require.Equal(t, 5, toMessage(t, response.Body.String()).Priority)
response = request(t, s, "GET", "/mytopic/trigger?priority=INVALID", "test", nil)
require.Equal(t, 40007, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PublishNoCache(t *testing.T) {
@ -268,13 +265,28 @@ func TestServer_PublishAtTooShortDelay(t *testing.T) {
func TestServer_PublishAtTooLongDelay(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{
"In": "99999999h",
})
require.Equal(t, 400, response.Code)
}
func TestServer_PublishAtInvalidDelay(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic?delay=INVALID", "a message", nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 40004, err.Code)
}
func TestServer_PublishAtTooLarge(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic?x-in=99999h", "a message", nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 40006, err.Code)
}
func TestServer_PublishAtAndPrune(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
@ -356,6 +368,19 @@ func TestServer_PublishAndPollSince(t *testing.T) {
messages := toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages))
require.Equal(t, "test 2", messages[0].Message)
response = request(t, s, "GET", "/mytopic/json?poll=1&since=10s", "", nil)
messages = toMessages(t, response.Body.String())
require.Equal(t, 2, len(messages))
require.Equal(t, "test 1", messages[0].Message)
response = request(t, s, "GET", "/mytopic/json?poll=1&since=100ms", "", nil)
messages = toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages))
require.Equal(t, "test 2", messages[0].Message)
response = request(t, s, "GET", "/mytopic/json?poll=1&since=INVALID", "", nil)
require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PublishViaGET(t *testing.T) {
@ -396,6 +421,13 @@ func TestServer_PublishFirebase(t *testing.T) {
time.Sleep(500 * time.Millisecond) // Time for sends
}
func TestServer_PublishInvalidTopic(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{}
response := request(t, s, "PUT", "/docs", "fail", nil)
require.Equal(t, 40010, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PollWithQueryFilters(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
@ -649,9 +681,241 @@ func TestServer_MaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
}
func TestServer_PublishAttachment(t *testing.T) {
content := util.RandomString(5000) // > 4096
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", content, nil)
msg := toMessage(t, response.Body.String())
require.Equal(t, "attachment.txt", msg.Attachment.Name)
require.Equal(t, "text/plain; charset=utf-8", msg.Attachment.Type)
require.Equal(t, int64(5000), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Attachment.Owner) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, "5000", response.Header().Get("Content-Length"))
require.Equal(t, content, response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.cache.AttachmentsSize("9.9.9.9") // See request()
require.Nil(t, err)
require.Equal(t, int64(5000), size)
}
func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
c := newTestConfig(t)
c.BehindProxy = true
s := newTestServer(t, c)
content := "this is an ATTACHMENT"
response := request(t, s, "PUT", "/mytopic?f=myfile.txt", content, map[string]string{
"X-Forwarded-For": "1.2.3.4",
})
msg := toMessage(t, response.Body.String())
require.Equal(t, "myfile.txt", msg.Attachment.Name)
require.Equal(t, "text/plain; charset=utf-8", msg.Attachment.Type)
require.Equal(t, int64(21), msg.Attachment.Size)
require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
require.Equal(t, "", msg.Attachment.Owner) // Should never be returned
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, "21", response.Header().Get("Content-Length"))
require.Equal(t, content, response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.cache.AttachmentsSize("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(21), size)
}
func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "", map[string]string{
"Attach": "https://upload.wikimedia.org/wikipedia/commons/f/fd/Pink_flower.jpg",
})
msg := toMessage(t, response.Body.String())
require.Equal(t, "You received a file: Pink_flower.jpg", msg.Message)
require.Equal(t, "Pink_flower.jpg", msg.Attachment.Name)
require.Equal(t, "https://upload.wikimedia.org/wikipedia/commons/f/fd/Pink_flower.jpg", msg.Attachment.URL)
require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Attachment.Owner)
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
size, err := s.cache.AttachmentsSize("127.0.0.1")
require.Nil(t, err)
require.Equal(t, int64(0), size)
}
func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "This is a custom message", map[string]string{
"X-Attach": "https://upload.wikimedia.org/wikipedia/commons/f/fd/Pink_flower.jpg",
"File": "some file.jpg",
})
msg := toMessage(t, response.Body.String())
require.Equal(t, "This is a custom message", msg.Message)
require.Equal(t, "some file.jpg", msg.Attachment.Name)
require.Equal(t, "https://upload.wikimedia.org/wikipedia/commons/f/fd/Pink_flower.jpg", msg.Attachment.URL)
require.Equal(t, "", msg.Attachment.Type)
require.Equal(t, int64(0), msg.Attachment.Size)
require.Equal(t, int64(0), msg.Attachment.Expires)
require.Equal(t, "", msg.Attachment.Owner)
}
func TestServer_PublishAttachmentBadURL(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic?a=not+a+URL", "", nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 400, err.HTTPCode)
require.Equal(t, 40013, err.Code)
}
func TestServer_PublishAttachmentTooLargeContentLength(t *testing.T) {
content := util.RandomString(5000) // > 4096
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Content-Length": "20000000",
})
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 400, err.HTTPCode)
require.Equal(t, 40012, err.Code)
}
func TestServer_PublishAttachmentTooLargeBodyAttachmentFileSizeLimit(t *testing.T) {
content := util.RandomString(5001) // > 5000, see below
c := newTestConfig(t)
c.AttachmentFileSizeLimit = 5000
s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", content, nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 400, err.HTTPCode)
require.Equal(t, 40012, err.Code)
}
func TestServer_PublishAttachmentExpiryBeforeDelivery(t *testing.T) {
c := newTestConfig(t)
c.AttachmentExpiryDuration = 10 * time.Minute
s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", util.RandomString(5000), map[string]string{
"Delay": "11 min", // > AttachmentExpiryDuration
})
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 400, err.HTTPCode)
require.Equal(t, 40015, err.Code)
}
func TestServer_PublishAttachmentTooLargeBodyVisitorAttachmentTotalSizeLimit(t *testing.T) {
c := newTestConfig(t)
c.VisitorAttachmentTotalSizeLimit = 10000
s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", util.RandomString(5000), nil)
msg := toMessage(t, response.Body.String())
require.Equal(t, 200, response.Code)
require.Equal(t, "You received a file: attachment.txt", msg.Message)
require.Equal(t, int64(5000), msg.Attachment.Size)
content := util.RandomString(5001) // 5000+5001 > , see below
response = request(t, s, "PUT", "/mytopic", content, nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 400, err.HTTPCode)
require.Equal(t, 40012, err.Code)
}
func TestServer_PublishAttachmentAndPrune(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfig(t)
c.AttachmentExpiryDuration = time.Millisecond // Hack
s := newTestServer(t, c)
// Publish and make sure we can retrieve it
response := request(t, s, "PUT", "/mytopic", content, nil)
msg := toMessage(t, response.Body.String())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
file := filepath.Join(s.config.AttachmentCacheDir, msg.ID)
require.FileExists(t, file)
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, content, response.Body.String())
// Prune and makes sure it's gone
time.Sleep(time.Second) // Sigh ...
s.updateStatsAndPrune()
require.NoFileExists(t, file)
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 404, response.Code)
}
func TestServer_PublishAttachmentBandwidthLimit(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfig(t)
c.VisitorAttachmentDailyBandwidthLimit = 5*5000 + 123 // A little more than 1 upload and 3 downloads
s := newTestServer(t, c)
// Publish attachment
response := request(t, s, "PUT", "/mytopic", content, nil)
msg := toMessage(t, response.Body.String())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
// Get it 4 times successfully
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
for i := 1; i <= 4; i++ { // 4 successful downloads
response = request(t, s, "GET", path, "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, content, response.Body.String())
}
// And then fail with a 429
response = request(t, s, "GET", path, "", nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 429, response.Code)
require.Equal(t, 42905, err.Code)
}
func TestServer_PublishAttachmentBandwidthLimitUploadOnly(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfig(t)
c.VisitorAttachmentDailyBandwidthLimit = 5*5000 + 500 // 5 successful uploads
s := newTestServer(t, c)
// 5 successful uploads
for i := 1; i <= 5; i++ {
response := request(t, s, "PUT", "/mytopic", content, nil)
msg := toMessage(t, response.Body.String())
require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
}
// And a failed one
response := request(t, s, "PUT", "/mytopic", content, nil)
err := toHTTPError(t, response.Body.String())
require.Equal(t, 400, response.Code)
require.Equal(t, 40012, err.Code)
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345"
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
conf.AttachmentCacheDir = t.TempDir()
return conf
}
@ -669,6 +933,7 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = "9.9.9.9" // Used for tests
for k, v := range headers {
req.Header.Set(k, v)
}

View file

@ -25,7 +25,8 @@ type visitor struct {
ip string
requests *rate.Limiter
emails *rate.Limiter
subscriptions *util.Limiter
subscriptions util.Limiter
bandwidth util.Limiter
seen time.Time
mu sync.Mutex
}
@ -36,7 +37,8 @@ func newVisitor(conf *Config, ip string) *visitor {
ip: ip,
requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
seen: time.Now(),
}
}
@ -62,7 +64,7 @@ func (v *visitor) EmailAllowed() error {
func (v *visitor) SubscriptionAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if err := v.subscriptions.Add(1); err != nil {
if err := v.subscriptions.Allow(1); err != nil {
return errVisitorLimitReached
}
return nil
@ -71,7 +73,7 @@ func (v *visitor) SubscriptionAllowed() error {
func (v *visitor) RemoveSubscription() {
v.mu.Lock()
defer v.mu.Unlock()
v.subscriptions.Sub(1)
v.subscriptions.Allow(-1)
}
func (v *visitor) Keepalive() {
@ -80,6 +82,10 @@ func (v *visitor) Keepalive() {
v.seen = time.Now()
}
func (v *visitor) BandwidthLimiter() util.Limiter {
return v.bandwidth
}
func (v *visitor) Stale() bool {
v.mu.Lock()
defer v.mu.Unlock()