diff --git a/server/cache.go b/server/cache.go index be4f704..df4c2c9 100644 --- a/server/cache.go +++ b/server/cache.go @@ -1,61 +1,14 @@ package server import ( - "database/sql" - "time" _ "github.com/mattn/go-sqlite3" // SQLite driver + "time" ) -const ( - createTableQuery = `CREATE TABLE IF NOT EXISTS messages ( - id VARCHAR(20) PRIMARY KEY, - time INT NOT NULL, - topic VARCHAR(64) NOT NULL, - message VARCHAR(1024) NOT NULL - )` - insertQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)` - pruneOlderThanQuery = `DELETE FROM messages WHERE time < ?` -) - -type cache struct { - db *sql.DB - insert *sql.Stmt - prune *sql.Stmt -} - -func newCache(filename string) (*cache, error) { - db, err := sql.Open("sqlite3", filename) - if err != nil { - return nil, err - } - if _, err := db.Exec(createTableQuery); err != nil { - return nil, err - } - insert, err := db.Prepare(insertQuery) - if err != nil { - return nil, err - } - prune, err := db.Prepare(pruneOlderThanQuery) - if err != nil { - return nil, err - } - return &cache{ - db: db, - insert: insert, - prune: prune, - }, nil -} - -func (c *cache) Load() (map[string]*topic, error) { - -} - -func (c *cache) Add(m *message) error { - _, err := c.insert.Exec(m.ID, m.Time, m.Topic, m.Message) - return err -} - -func (c *cache) Prune(olderThan time.Duration) error { - _, err := c.prune.Exec(time.Now().Add(-1 * olderThan).Unix()) - return err +type cache interface { + AddMessage(m *message) error + Messages(topic string, since time.Time) ([]*message, error) + MessageCount(topic string) (int, error) + Topics() (map[string]*topic, error) + Prune(keep time.Duration) error } diff --git a/server/cache_mem.go b/server/cache_mem.go new file mode 100644 index 0000000..1e7e08d --- /dev/null +++ b/server/cache_mem.go @@ -0,0 +1,80 @@ +package server + +import ( + _ "github.com/mattn/go-sqlite3" // SQLite driver + "sync" + "time" +) + +type memCache struct { + messages map[string][]*message + mu sync.Mutex +} + +var _ cache = (*memCache)(nil) + +func newMemCache() *memCache { + return &memCache{ + messages: make(map[string][]*message), + } +} + +func (s *memCache) AddMessage(m *message) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.messages[m.Topic]; !ok { + s.messages[m.Topic] = make([]*message, 0) + } + s.messages[m.Topic] = append(s.messages[m.Topic], m) + return nil +} + +func (s *memCache) Messages(topic string, since time.Time) ([]*message, error) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.messages[topic]; !ok { + return make([]*message, 0), nil + } + messages := make([]*message, 0) // copy! + for _, m := range s.messages[topic] { + msgTime := time.Unix(m.Time, 0) + if msgTime == since || msgTime.After(since) { + messages = append(messages, m) + } + } + return messages, nil +} + +func (s *memCache) MessageCount(topic string) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.messages[topic]; !ok { + return 0, nil + } + return len(s.messages[topic]), nil +} + +func (s *memCache) Topics() (map[string]*topic, error) { + // Hack since we know when this is called there are no messages! + return make(map[string]*topic), nil +} + +func (s *memCache) Prune(keep time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + for topic, _ := range s.messages { + s.pruneTopic(topic, keep) + } + return nil +} + +func (s *memCache) pruneTopic(topic string, keep time.Duration) { + for i, m := range s.messages[topic] { + msgTime := time.Unix(m.Time, 0) + if time.Since(msgTime) < keep { + s.messages[topic] = s.messages[topic][i:] + return + } + } + s.messages[topic] = make([]*message, 0) // all messages expired +} diff --git a/server/cache_sqlite.go b/server/cache_sqlite.go new file mode 100644 index 0000000..6f041f1 --- /dev/null +++ b/server/cache_sqlite.go @@ -0,0 +1,127 @@ +package server + +import ( + "database/sql" + "errors" + _ "github.com/mattn/go-sqlite3" // SQLite driver + "time" +) + +const ( + createTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id VARCHAR(20) PRIMARY KEY, + time INT NOT NULL, + topic VARCHAR(64) NOT NULL, + message VARCHAR(1024) NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + COMMIT; + ` + insertMessageQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)` + pruneMessagesQuery = `DELETE FROM messages WHERE time < ?` + selectMessagesSinceTimeQuery = ` + SELECT id, time, message + FROM messages + WHERE topic = ? AND time >= ? + ORDER BY time ASC + ` + selectMessageCountQuery = `SELECT count(*) FROM messages WHERE topic = ?` + selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY TOPIC` +) + +type sqliteCache struct { + db *sql.DB +} + +var _ cache = (*sqliteCache)(nil) + +func newSqliteCache(filename string) (*sqliteCache, error) { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return nil, err + } + if _, err := db.Exec(createTableQuery); err != nil { + return nil, err + } + return &sqliteCache{ + db: db, + }, nil +} + +func (c *sqliteCache) AddMessage(m *message) error { + _, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message) + return err +} + +func (c *sqliteCache) Messages(topic string, since time.Time) ([]*message, error) { + rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + messages := make([]*message, 0) + for rows.Next() { + var timestamp int64 + var id, msg string + if err := rows.Scan(&id, ×tamp, &msg); err != nil { + return nil, err + } + messages = append(messages, &message{ + ID: id, + Time: timestamp, + Event: messageEvent, + Topic: topic, + Message: msg, + }) + } + if err := rows.Err(); err != nil { + return nil, err + } + return messages, nil +} + +func (c *sqliteCache) MessageCount(topic string) (int, error) { + rows, err := c.db.Query(selectMessageCountQuery, topic) + if err != nil { + return 0, err + } + defer rows.Close() + var count int + if !rows.Next() { + return 0, errors.New("no rows found") + } + if err := rows.Scan(&count); err != nil { + return 0, err + } else if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + +func (s *sqliteCache) Topics() (map[string]*topic, error) { + rows, err := s.db.Query(selectTopicsQuery) + if err != nil { + return nil, err + } + defer rows.Close() + topics := make(map[string]*topic, 0) + for rows.Next() { + var id string + var last int64 + if err := rows.Scan(&id, &last); err != nil { + return nil, err + } + topics[id] = newTopic(id, time.Unix(last, 0)) + } + if err := rows.Err(); err != nil { + return nil, err + } + return topics, nil +} + +func (c *sqliteCache) Prune(keep time.Duration) error { + _, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1 * keep).Unix()) + return err +} diff --git a/server/server.go b/server/server.go index 65adddf..208cc45 100644 --- a/server/server.go +++ b/server/server.go @@ -32,7 +32,7 @@ type Server struct { visitors map[string]*visitor firebase subscriber messages int64 - cache *cache + cache cache mu sync.Mutex } @@ -78,30 +78,28 @@ func New(conf *config.Config) (*Server, error) { return nil, err } } - cache, err := maybeCreateCache(conf) + cache, err := createCache(conf) if err != nil { return nil, err } - topics := make(map[string]*topic) - if cache != nil { - if topics, err = cache.Load(); err != nil { - return nil, err - } + topics, err := cache.Topics() + if err != nil { + return nil, err } return &Server{ config: conf, - cache: cache, + cache: cache, firebase: firebaseSubscriber, topics: topics, visitors: make(map[string]*visitor), }, nil } -func maybeCreateCache(conf *config.Config) (*cache, error) { - if conf.CacheFile == "" { - return nil, nil +func createCache(conf *config.Config) (cache, error) { + if conf.CacheFile != "" { + return newSqliteCache(conf.CacheFile) } - return newCache(conf.CacheFile) + return newMemCache(), nil } func createFirebaseSubscriber(conf *config.Config) (subscriber, error) { @@ -202,8 +200,8 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err := t.Publish(m); err != nil { return err } - if s.cache != nil { - s.cache.Add(m) + if err := s.cache.AddMessage(m); err != nil { + return err } w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests s.mu.Lock() @@ -277,20 +275,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType) if poll { - return sendOldMessages(t, since, sub) + return s.sendOldMessages(t, since, sub) } subscriberID := t.Subscribe(sub) defer t.Unsubscribe(subscriberID) if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message return err } - if err := sendOldMessages(t, since, sub); err != nil { + if err := s.sendOldMessages(t, since, sub); err != nil { return err } for { select { - case <-t.ctx.Done(): - return nil case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): @@ -302,11 +298,15 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi } } -func sendOldMessages(t *topic, since time.Time, sub subscriber) error { +func (s *Server) sendOldMessages(t *topic, since time.Time, sub subscriber) error { if since.IsZero() { return nil } - for _, m := range t.Messages(since) { + messages, err := s.cache.Messages(t.id, since) + if err != nil { + return err + } + for _, m := range messages { if err := sub(m); err != nil { return err } @@ -340,7 +340,7 @@ func (s *Server) topic(id string) (*topic, error) { if len(s.topics) >= s.config.GlobalTopicLimit { return nil, errHTTPTooManyRequests } - s.topics[id] = newTopic(id) + s.topics[id] = newTopic(id, time.Now()) if s.firebase != nil { s.topics[id].Subscribe(s.firebase) } @@ -360,28 +360,28 @@ func (s *Server) updateStatsAndExpire() { } // Prune cache - if s.cache != nil { - if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil { - log.Printf("error pruning cache: %s", err.Error()) - } + if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil { + log.Printf("error pruning cache: %s", err.Error()) } // Prune old messages, remove subscriptions without subscribers - for _, t := range s.topics { - t.Prune(s.config.MessageBufferDuration) - subs, msgs := t.Stats() - if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber! - delete(s.topics, t.id) - } - } - - // Print stats var subscribers, messages int for _, t := range s.topics { - subs, msgs := t.Stats() + subs := t.Subscribers() + msgs, err := s.cache.MessageCount(t.id) + if err != nil { + log.Printf("cannot get stats for topic %s: %s", t.id, err.Error()) + continue + } + if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber! + delete(s.topics, t.id) + continue + } subscribers += subs messages += msgs } + + // Print stats log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", s.messages, len(s.topics), subscribers, messages, len(s.visitors)) } diff --git a/server/topic.go b/server/topic.go index 3ad98fc..881a28b 100644 --- a/server/topic.go +++ b/server/topic.go @@ -1,7 +1,6 @@ package server import ( - "context" "log" "math/rand" "sync" @@ -12,11 +11,8 @@ import ( // can publish a message type topic struct { id string - subscribers map[int]subscriber - messages []*message last time.Time - ctx context.Context - cancel context.CancelFunc + subscribers map[int]subscriber mu sync.Mutex } @@ -24,15 +20,11 @@ type topic struct { type subscriber func(msg *message) error // newTopic creates a new topic -func newTopic(id string) *topic { - ctx, cancel := context.WithCancel(context.Background()) +func newTopic(id string, last time.Time) *topic { return &topic{ id: id, + last: last, subscribers: make(map[int]subscriber), - messages: make([]*message, 0), - last: time.Now(), - ctx: ctx, - cancel: cancel, } } @@ -55,7 +47,6 @@ func (t *topic) Publish(m *message) error { t.mu.Lock() defer t.mu.Unlock() t.last = time.Now() - t.messages = append(t.messages, m) for _, s := range t.subscribers { if err := s(m); err != nil { log.Printf("error publishing message to subscriber") @@ -64,38 +55,8 @@ func (t *topic) Publish(m *message) error { return nil } -func (t *topic) Messages(since time.Time) []*message { +func (t *topic) Subscribers() int { t.mu.Lock() defer t.mu.Unlock() - messages := make([]*message, 0) // copy! - for _, m := range t.messages { - msgTime := time.Unix(m.Time, 0) - if msgTime == since || msgTime.After(since) { - messages = append(messages, m) - } - } - return messages -} - -func (t *topic) Prune(keep time.Duration) { - t.mu.Lock() - defer t.mu.Unlock() - for i, m := range t.messages { - msgTime := time.Unix(m.Time, 0) - if time.Since(msgTime) < keep { - t.messages = t.messages[i:] - return - } - } - t.messages = make([]*message, 0) -} - -func (t *topic) Stats() (subscribers int, messages int) { - t.mu.Lock() - defer t.mu.Unlock() - return len(t.subscribers), len(t.messages) -} - -func (t *topic) Close() { - t.cancel() + return len(t.subscribers) }