diff --git a/config/config.go b/config/config.go index 2dbed00..7fcea45 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ const ( DefaultCacheDuration = 12 * time.Hour DefaultKeepaliveInterval = 30 * time.Second DefaultManagerInterval = time.Minute + DefaultAtSenderInterval = 10 * time.Second ) // Defines all the limits @@ -35,6 +36,7 @@ type Config struct { CacheDuration time.Duration KeepaliveInterval time.Duration ManagerInterval time.Duration + AtSenderInterval time.Duration GlobalTopicLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration @@ -54,6 +56,7 @@ func New(listenHTTP string) *Config { CacheDuration: DefaultCacheDuration, KeepaliveInterval: DefaultKeepaliveInterval, ManagerInterval: DefaultManagerInterval, + AtSenderInterval: DefaultAtSenderInterval, GlobalTopicLimit: DefaultGlobalTopicLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, diff --git a/server/cache.go b/server/cache.go index b355791..64d517d 100644 --- a/server/cache.go +++ b/server/cache.go @@ -14,8 +14,10 @@ var ( // i.e. message structs with the Event messageEvent. type cache interface { AddMessage(m *message) error - Messages(topic string, since sinceTime) ([]*message, error) + Messages(topic string, since sinceTime, scheduled bool) ([]*message, error) + MessagesDue() ([]*message, error) MessageCount(topic string) (int, error) Topics() (map[string]*topic, error) Prune(olderThan time.Time) error + MarkPublished(m *message) error } diff --git a/server/cache_mem.go b/server/cache_mem.go index 9272ebd..31c7bb9 100644 --- a/server/cache_mem.go +++ b/server/cache_mem.go @@ -1,14 +1,16 @@ package server import ( + "sort" "sync" "time" ) type memCache struct { - messages map[string][]*message - nop bool - mu sync.Mutex + messages map[string][]*message + scheduled map[string]*message // Message ID -> message + nop bool + mu sync.Mutex } var _ cache = (*memCache)(nil) @@ -16,8 +18,9 @@ var _ cache = (*memCache)(nil) // newMemCache creates an in-memory cache func newMemCache() *memCache { return &memCache{ - messages: make(map[string][]*message), - nop: false, + messages: make(map[string][]*message), + scheduled: make(map[string]*message), + nop: false, } } @@ -25,77 +28,109 @@ func newMemCache() *memCache { // it is always empty and can be used if caching is entirely disabled func newNopCache() *memCache { return &memCache{ - messages: make(map[string][]*message), - nop: true, + messages: make(map[string][]*message), + scheduled: make(map[string]*message), + nop: true, } } -func (s *memCache) AddMessage(m *message) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.nop { +func (c *memCache) AddMessage(m *message) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.nop { return nil } if m.Event != messageEvent { return errUnexpectedMessageType } - if _, ok := s.messages[m.Topic]; !ok { - s.messages[m.Topic] = make([]*message, 0) + if _, ok := c.messages[m.Topic]; !ok { + c.messages[m.Topic] = make([]*message, 0) } - s.messages[m.Topic] = append(s.messages[m.Topic], m) + delayed := m.Time > time.Now().Unix() + if delayed { + c.scheduled[m.ID] = m + } + c.messages[m.Topic] = append(c.messages[m.Topic], m) return nil } -func (s *memCache) Messages(topic string, since sinceTime) ([]*message, error) { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.messages[topic]; !ok || since.IsNone() { +func (c *memCache) Messages(topic string, since sinceTime, scheduled bool) ([]*message, error) { + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.messages[topic]; !ok || since.IsNone() { return make([]*message, 0), nil } - messages := make([]*message, 0) // copy! - for _, m := range s.messages[topic] { - msgTime := time.Unix(m.Time, 0) - if msgTime == since.Time() || msgTime.After(since.Time()) { + messages := make([]*message, 0) + for _, m := range c.messages[topic] { + _, messageScheduled := c.scheduled[m.ID] + include := m.Time >= since.Time().Unix() && (!messageScheduled || scheduled) + if include { messages = append(messages, m) } } + sort.Slice(messages, func(i, j int) bool { + return messages[i].Time < messages[j].Time + }) return messages, nil } -func (s *memCache) MessageCount(topic string) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.messages[topic]; !ok { - return 0, nil +func (c *memCache) MessagesDue() ([]*message, error) { + c.mu.Lock() + defer c.mu.Unlock() + messages := make([]*message, 0) + for _, m := range c.scheduled { + due := time.Now().Unix() >= m.Time + if due { + messages = append(messages, m) + } } - return len(s.messages[topic]), nil + sort.Slice(messages, func(i, j int) bool { + return messages[i].Time < messages[j].Time + }) + return messages, nil } -func (s *memCache) Topics() (map[string]*topic, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (c *memCache) MarkPublished(m *message) error { + c.mu.Lock() + delete(c.scheduled, m.ID) + c.mu.Unlock() + return nil +} + +func (c *memCache) MessageCount(topic string) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.messages[topic]; !ok { + return 0, nil + } + return len(c.messages[topic]), nil +} + +func (c *memCache) Topics() (map[string]*topic, error) { + c.mu.Lock() + defer c.mu.Unlock() topics := make(map[string]*topic) - for topic := range s.messages { + for topic := range c.messages { topics[topic] = newTopic(topic) } return topics, nil } -func (s *memCache) Prune(olderThan time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - for topic := range s.messages { - s.pruneTopic(topic, olderThan) +func (c *memCache) Prune(olderThan time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + for topic := range c.messages { + c.pruneTopic(topic, olderThan) } return nil } -func (s *memCache) pruneTopic(topic string, olderThan time.Time) { +func (c *memCache) pruneTopic(topic string, olderThan time.Time) { messages := make([]*message, 0) - for _, m := range s.messages[topic] { + for _, m := range c.messages[topic] { if m.Time >= olderThan.Unix() { messages = append(messages, m) } } - s.messages[topic] = messages + c.messages[topic] = messages } diff --git a/server/cache_mem_test.go b/server/cache_mem_test.go index a1c854d..831703a 100644 --- a/server/cache_mem_test.go +++ b/server/cache_mem_test.go @@ -9,6 +9,10 @@ func TestMemCache_Messages(t *testing.T) { testCacheMessages(t, newMemCache()) } +func TestMemCache_MessagesScheduled(t *testing.T) { + testCacheMessagesScheduled(t, newMemCache()) +} + func TestMemCache_Topics(t *testing.T) { testCacheTopics(t, newMemCache()) } @@ -25,7 +29,7 @@ func TestMemCache_NopCache(t *testing.T) { c := newNopCache() assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) - messages, err := c.Messages("mytopic", sinceAllMessages) + messages, err := c.Messages("mytopic", sinceAllMessages, false) assert.Nil(t, err) assert.Empty(t, messages) diff --git a/server/cache_sqlite.go b/server/cache_sqlite.go index 3c3564d..19eddee 100644 --- a/server/cache_sqlite.go +++ b/server/cache_sqlite.go @@ -21,19 +21,32 @@ const ( message VARCHAR(512) NOT NULL, title VARCHAR(256) NOT NULL, priority INT NOT NULL, - tags VARCHAR(256) NOT NULL + tags VARCHAR(256) NOT NULL, + published INT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); COMMIT; ` - insertMessageQuery = `INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)` + insertMessageQuery = `INSERT INTO messages (id, time, topic, message, title, priority, tags, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?)` pruneMessagesQuery = `DELETE FROM messages WHERE time < ?` selectMessagesSinceTimeQuery = ` - SELECT id, time, message, title, priority, tags + SELECT id, time, topic, message, title, priority, tags + FROM messages + WHERE topic = ? AND time >= ? AND published = 1 + ORDER BY time ASC + ` + selectMessagesSinceTimeIncludeScheduledQuery = ` + SELECT id, time, topic, message, title, priority, tags FROM messages WHERE topic = ? AND time >= ? ORDER BY time ASC ` + selectMessagesDueQuery = ` + SELECT id, time, topic, message, title, priority, tags + FROM messages + WHERE time <= ? AND published = 0 + ` + updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE id = ?` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` @@ -41,7 +54,7 @@ const ( // Schema management queries const ( - currentSchemaVersion = 1 + currentSchemaVersion = 2 createSchemaVersionTableQuery = ` CREATE TABLE IF NOT EXISTS schemaVersion ( id INT PRIMARY KEY, @@ -49,6 +62,7 @@ const ( ); ` insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` // 0 -> 1 @@ -59,6 +73,13 @@ const ( ALTER TABLE messages ADD COLUMN tags VARCHAR(256) NOT NULL DEFAULT(''); COMMIT; ` + + // 1 -> 2 + migrate1To2AlterMessagesTableQuery = ` + BEGIN; + ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1); + COMMIT; + ` ) type sqliteCache struct { @@ -84,46 +105,39 @@ func (c *sqliteCache) AddMessage(m *message) error { if m.Event != messageEvent { return errUnexpectedMessageType } - _, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ",")) + published := m.Time <= time.Now().Unix() + _, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message, m.Title, m.Priority, strings.Join(m.Tags, ","), published) return err } -func (c *sqliteCache) Messages(topic string, since sinceTime) ([]*message, error) { +func (c *sqliteCache) Messages(topic string, since sinceTime, scheduled bool) ([]*message, error) { if since.IsNone() { return make([]*message, 0), nil } - rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) + var rows *sql.Rows + var err error + if scheduled { + rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) + } else { + rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) + } if err != nil { return nil, err } - defer rows.Close() - messages := make([]*message, 0) - for rows.Next() { - var timestamp int64 - var priority int - var id, msg, title, tagsStr string - if err := rows.Scan(&id, ×tamp, &msg, &title, &priority, &tagsStr); err != nil { - return nil, err - } - var tags []string - if tagsStr != "" { - tags = strings.Split(tagsStr, ",") - } - messages = append(messages, &message{ - ID: id, - Time: timestamp, - Event: messageEvent, - Topic: topic, - Message: msg, - Title: title, - Priority: priority, - Tags: tags, - }) - } - if err := rows.Err(); err != nil { + return readMessages(rows) +} + +func (c *sqliteCache) MessagesDue() ([]*message, error) { + rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix()) + if err != nil { return nil, err } - return messages, nil + return readMessages(rows) +} + +func (c *sqliteCache) MarkPublished(m *message) error { + _, err := c.db.Exec(updateMessagePublishedQuery, m.ID) + return err } func (c *sqliteCache) MessageCount(topic string) (int, error) { @@ -169,6 +183,37 @@ func (c *sqliteCache) Prune(olderThan time.Time) error { return err } +func readMessages(rows *sql.Rows) ([]*message, error) { + defer rows.Close() + messages := make([]*message, 0) + for rows.Next() { + var timestamp int64 + var priority int + var id, topic, msg, title, tagsStr string + if err := rows.Scan(&id, ×tamp, &topic, &msg, &title, &priority, &tagsStr); err != nil { + return nil, err + } + var tags []string + if tagsStr != "" { + tags = strings.Split(tagsStr, ",") + } + messages = append(messages, &message{ + ID: id, + Time: timestamp, + Event: messageEvent, + Topic: topic, + Message: msg, + Title: title, + Priority: priority, + Tags: tags, + }) + } + if err := rows.Err(); err != nil { + return nil, err + } + return messages, nil +} + func setupDB(db *sql.DB) error { // If 'messages' table does not exist, this must be a new database rowsMC, err := db.Query(selectMessagesCountQuery) @@ -194,7 +239,9 @@ func setupDB(db *sql.DB) error { if schemaVersion == currentSchemaVersion { return nil } else if schemaVersion == 0 { - return migrateFrom0To1(db) + return migrateFrom0(db) + } else if schemaVersion == 1 { + return migrateFrom1(db) } return fmt.Errorf("unexpected schema version found: %d", schemaVersion) } @@ -212,7 +259,7 @@ func setupNewDB(db *sql.DB) error { return nil } -func migrateFrom0To1(db *sql.DB) error { +func migrateFrom0(db *sql.DB) error { log.Print("Migrating cache database schema: from 0 to 1") if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil { return err @@ -223,5 +270,16 @@ func migrateFrom0To1(db *sql.DB) error { if _, err := db.Exec(insertSchemaVersion, 1); err != nil { return err } - return nil + return migrateFrom1(db) +} + +func migrateFrom1(db *sql.DB) error { + log.Print("Migrating cache database schema: from 1 to 2") + if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 2); err != nil { + return err + } + return nil // Update this when a new version is added } diff --git a/server/cache_sqlite_test.go b/server/cache_sqlite_test.go index 0f6c430..eb7b64a 100644 --- a/server/cache_sqlite_test.go +++ b/server/cache_sqlite_test.go @@ -9,10 +9,14 @@ import ( "time" ) -func TestSqliteCache_AddMessage(t *testing.T) { +func TestSqliteCache_Messages(t *testing.T) { testCacheMessages(t, newSqliteTestCache(t)) } +func TestSqliteCache_MessagesScheduled(t *testing.T) { + testCacheMessagesScheduled(t, newSqliteTestCache(t)) +} + func TestSqliteCache_Topics(t *testing.T) { testCacheTopics(t, newSqliteTestCache(t)) } @@ -25,7 +29,7 @@ func TestSqliteCache_Prune(t *testing.T) { testCachePrune(t, newSqliteTestCache(t)) } -func TestSqliteCache_Migration_0to1(t *testing.T) { +func TestSqliteCache_Migration_From0(t *testing.T) { filename := newSqliteTestCacheFile(t) db, err := sql.Open("sqlite3", filename) assert.Nil(t, err) @@ -53,7 +57,7 @@ func TestSqliteCache_Migration_0to1(t *testing.T) { // Create cache to trigger migration c := newSqliteTestCacheFromFile(t, filename) - messages, err := c.Messages("mytopic", sinceAllMessages) + messages, err := c.Messages("mytopic", sinceAllMessages, false) assert.Nil(t, err) assert.Equal(t, 10, len(messages)) assert.Equal(t, "some message 5", messages[5].Message) @@ -67,7 +71,7 @@ func TestSqliteCache_Migration_0to1(t *testing.T) { var schemaVersion int assert.Nil(t, rows.Scan(&schemaVersion)) - assert.Equal(t, 1, schemaVersion) + assert.Equal(t, 2, schemaVersion) } func newSqliteTestCache(t *testing.T) *sqliteCache { diff --git a/server/cache_test.go b/server/cache_test.go index ab65b06..1eae091 100644 --- a/server/cache_test.go +++ b/server/cache_test.go @@ -27,7 +27,7 @@ func testCacheMessages(t *testing.T, c cache) { assert.Equal(t, 2, count) // mytopic: since all - messages, _ := c.Messages("mytopic", sinceAllMessages) + messages, _ := c.Messages("mytopic", sinceAllMessages, false) assert.Equal(t, 2, len(messages)) assert.Equal(t, "my message", messages[0].Message) assert.Equal(t, "mytopic", messages[0].Topic) @@ -38,11 +38,11 @@ func testCacheMessages(t *testing.T, c cache) { assert.Equal(t, "my other message", messages[1].Message) // mytopic: since none - messages, _ = c.Messages("mytopic", sinceNoMessages) + messages, _ = c.Messages("mytopic", sinceNoMessages, false) assert.Empty(t, messages) // mytopic: since 2 - messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0))) + messages, _ = c.Messages("mytopic", sinceTime(time.Unix(2, 0)), false) assert.Equal(t, 1, len(messages)) assert.Equal(t, "my other message", messages[0].Message) @@ -52,7 +52,7 @@ func testCacheMessages(t *testing.T, c cache) { assert.Equal(t, 1, count) // example: since all - messages, _ = c.Messages("example", sinceAllMessages) + messages, _ = c.Messages("example", sinceAllMessages, false) assert.Equal(t, "my example message", messages[0].Message) // non-existing: count @@ -61,7 +61,7 @@ func testCacheMessages(t *testing.T, c cache) { assert.Equal(t, 0, count) // non-existing: since all - messages, _ = c.Messages("doesnotexist", sinceAllMessages) + messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) assert.Empty(t, messages) } @@ -103,7 +103,7 @@ func testCachePrune(t *testing.T, c cache) { assert.Nil(t, err) assert.Equal(t, 0, count) - messages, err := c.Messages("mytopic", sinceAllMessages) + messages, err := c.Messages("mytopic", sinceAllMessages, false) assert.Nil(t, err) assert.Equal(t, 1, len(messages)) assert.Equal(t, "my other message", messages[0].Message) @@ -116,8 +116,34 @@ func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) { m.Title = "some title" assert.Nil(t, c.AddMessage(m)) - messages, _ := c.Messages("mytopic", sinceAllMessages) + messages, _ := c.Messages("mytopic", sinceAllMessages, false) assert.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) assert.Equal(t, 5, messages[0].Priority) assert.Equal(t, "some title", messages[0].Title) } + +func testCacheMessagesScheduled(t *testing.T, c cache) { + m1 := newDefaultMessage("mytopic", "message 1") + m2 := newDefaultMessage("mytopic", "message 2") + m2.Time = time.Now().Add(time.Hour).Unix() + m3 := newDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! + m4 := newDefaultMessage("mytopic2", "message 4") + m4.Time = time.Now().Add(time.Minute).Unix() + assert.Nil(t, c.AddMessage(m1)) + assert.Nil(t, c.AddMessage(m2)) + assert.Nil(t, c.AddMessage(m3)) + + messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled + assert.Equal(t, 1, len(messages)) + assert.Equal(t, "message 1", messages[0].Message) + + messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled + assert.Equal(t, 3, len(messages)) + assert.Equal(t, "message 1", messages[0].Message) + assert.Equal(t, "message 3", messages[1].Message) // Order! + assert.Equal(t, "message 2", messages[2].Message) + + messages, _ = c.MessagesDue() + assert.Empty(t, messages) +} diff --git a/server/server.go b/server/server.go index 7ee039c..73727f2 100644 --- a/server/server.go +++ b/server/server.go @@ -73,6 +73,7 @@ var ( const ( messageLimit = 512 + minDelay = 10 * time.Second ) var ( @@ -183,6 +184,15 @@ func (s *Server) Run() error { s.updateStatsAndExpire() } }() + go func() { + ticker := time.NewTicker(s.config.AtSenderInterval) + for { + <-ticker.C + if err := s.sendDelayedMessages(); err != nil { + log.Printf("error sending scheduled messages: %s", err.Error()) + } + } + }() listenStr := fmt.Sprintf("%s/http", s.config.ListenHTTP) if s.config.ListenHTTPS != "" { listenStr += fmt.Sprintf(" %s/https", s.config.ListenHTTPS) @@ -279,14 +289,17 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito if m.Message == "" { return errHTTPBadRequest } - title, priority, tags, cache, firebase := parseHeaders(r.Header) - m.Title = title - m.Priority = priority - m.Tags = tags - if err := t.Publish(m); err != nil { + cache, firebase, err := parseHeaders(r.Header, m) + if err != nil { return err } - if s.firebase != nil && firebase { + delayed := m.Time > time.Now().Unix() + if !delayed { + if err := t.Publish(m); err != nil { + return err + } + } + if s.firebase != nil && firebase && !delayed { go func() { if err := s.firebase(m); err != nil { log.Printf("Unable to publish to Firebase: %v", err.Error()) @@ -308,35 +321,62 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito return nil } -func parseHeaders(header http.Header) (title string, priority int, tags []string, cache bool, firebase bool) { - title = readHeader(header, "x-title", "title", "ti", "t") +func parseHeaders(header http.Header, m *message) (cache bool, firebase bool, err error) { + cache = readHeader(header, "x-cache", "cache") != "no" + firebase = readHeader(header, "x-firebase", "firebase") != "no" + m.Title = readHeader(header, "x-title", "title", "ti", "t") priorityStr := readHeader(header, "x-priority", "priority", "prio", "p") if priorityStr != "" { switch strings.ToLower(priorityStr) { case "1", "min": - priority = 1 + m.Priority = 1 case "2", "low": - priority = 2 + m.Priority = 2 case "3", "default": - priority = 3 + m.Priority = 3 case "4", "high": - priority = 4 + m.Priority = 4 case "5", "max", "urgent": - priority = 5 + m.Priority = 5 default: - priority = 0 + return false, false, errHTTPBadRequest } } tagsStr := readHeader(header, "x-tags", "tag", "tags", "ta") if tagsStr != "" { - tags = make([]string, 0) + m.Tags = make([]string, 0) for _, s := range strings.Split(tagsStr, ",") { - tags = append(tags, strings.TrimSpace(s)) + m.Tags = append(m.Tags, strings.TrimSpace(s)) } } - cache = readHeader(header, "x-cache", "cache") != "no" - firebase = readHeader(header, "x-firebase", "firebase") != "no" - return title, priority, tags, cache, firebase + atStr := readHeader(header, "x-at", "at", "x-schedule", "schedule", "sched") + if atStr != "" { + if !cache { + return false, false, errHTTPBadRequest + } + at, err := strconv.Atoi(atStr) + if err != nil { + return false, false, errHTTPBadRequest + } else if int64(at) < time.Now().Add(minDelay).Unix() { + return false, false, errHTTPBadRequest + } + m.Time = int64(at) + } else { + delayStr := readHeader(header, "x-delay", "delay", "x-in", "in") + if delayStr != "" { + if !cache { + return false, false, errHTTPBadRequest + } + delay, err := time.ParseDuration(delayStr) + if err != nil { + return false, false, errHTTPBadRequest + } else if delay < minDelay { + return false, false, errHTTPBadRequest + } + m.Time = time.Now().Add(delay).Unix() + } + } + return cache, firebase, nil } func readHeader(header http.Header, names ...string) string { @@ -401,6 +441,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi } var wlock sync.Mutex poll := r.URL.Query().Has("poll") + scheduled := r.URL.Query().Has("scheduled") || r.URL.Query().Has("sched") sub := func(msg *message) error { wlock.Lock() defer wlock.Unlock() @@ -419,7 +460,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { - return s.sendOldMessages(topics, since, sub) + return s.sendOldMessages(topics, since, scheduled, sub) } subscriberIDs := make([]int, 0) for _, t := range topics { @@ -433,7 +474,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message return err } - if err := s.sendOldMessages(topics, since, sub); err != nil { + if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { return err } for { @@ -449,12 +490,12 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi } } -func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscriber) error { +func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error { if since.IsNone() { return nil } for _, t := range topics { - messages, err := s.cache.Messages(t.ID, since) + messages, err := s.cache.Messages(t.ID, since, scheduled) if err != nil { return err } @@ -560,6 +601,32 @@ func (s *Server) updateStatsAndExpire() { s.messages, len(s.topics), subscribers, messages, len(s.visitors)) } +func (s *Server) sendDelayedMessages() error { + s.mu.Lock() + defer s.mu.Unlock() + messages, err := s.cache.MessagesDue() + if err != nil { + return err + } + for _, m := range messages { + t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published + if ok { + if err := t.Publish(m); err != nil { + log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error()) + } + if s.firebase != nil { + if err := s.firebase(m); err != nil { + log.Printf("unable to publish to Firebase: %v", err.Error()) + } + } + } + if err := s.cache.MarkPublished(m); err != nil { + return err + } + } + return nil +} + func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { v := s.visitor(r) if err := v.RequestAllowed(); err != nil {