Subscription limit

This commit is contained in:
Philipp Heckel 2021-11-01 15:21:38 -04:00
parent 5f2bb4f876
commit fa7a45902f
3 changed files with 97 additions and 37 deletions

View file

@ -17,8 +17,9 @@ const (
// Defines the max number of requests, here: // Defines the max number of requests, here:
// 50 requests bucket, replenished at a rate of 1 per second // 50 requests bucket, replenished at a rate of 1 per second
var ( var (
defaultLimit = rate.Every(time.Second) defaultRequestLimit = rate.Every(time.Second)
defaultLimitBurst = 50 defaultRequestLimitBurst = 50
defaultSubscriptionLimit = 30 // per visitor
) )
// Config is the main config struct for the application. Use New to instantiate a default config struct. // Config is the main config struct for the application. Use New to instantiate a default config struct.
@ -28,8 +29,9 @@ type Config struct {
MessageBufferDuration time.Duration MessageBufferDuration time.Duration
KeepaliveInterval time.Duration KeepaliveInterval time.Duration
ManagerInterval time.Duration ManagerInterval time.Duration
Limit rate.Limit RequestLimit rate.Limit
LimitBurst int RequestLimitBurst int
SubscriptionLimit int
} }
// New instantiates a default new config // New instantiates a default new config
@ -40,7 +42,8 @@ func New(listenHTTP string) *Config {
MessageBufferDuration: DefaultMessageBufferDuration, MessageBufferDuration: DefaultMessageBufferDuration,
KeepaliveInterval: DefaultKeepaliveInterval, KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval, ManagerInterval: DefaultManagerInterval,
Limit: defaultLimit, RequestLimit: defaultRequestLimit,
LimitBurst: defaultLimitBurst, RequestLimitBurst: defaultRequestLimitBurst,
SubscriptionLimit: defaultSubscriptionLimit,
} }
} }

View file

@ -9,7 +9,6 @@ import (
firebase "firebase.google.com/go" firebase "firebase.google.com/go"
"firebase.google.com/go/messaging" "firebase.google.com/go/messaging"
"fmt" "fmt"
"golang.org/x/time/rate"
"google.golang.org/api/option" "google.golang.org/api/option"
"heckel.io/ntfy/config" "heckel.io/ntfy/config"
"io" "io"
@ -23,9 +22,8 @@ import (
"time" "time"
) )
// TODO add "max connections open" limit
// TODO add "max messages in a topic" limit // TODO add "max messages in a topic" limit
// TODO add "max topics" limit // TODO implement persistence
// Server is the main server // Server is the main server
type Server struct { type Server struct {
@ -37,12 +35,6 @@ type Server struct {
mu sync.Mutex mu sync.Mutex
} }
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
limiter *rate.Limiter
seen time.Time
}
// errHTTP is a generic HTTP error for any non-200 HTTP error // errHTTP is a generic HTTP error for any non-200 HTTP error
type errHTTP struct { type errHTTP struct {
Code int Code int
@ -55,7 +47,6 @@ func (e errHTTP) Error() string {
const ( const (
messageLimit = 1024 messageLimit = 1024
visitorExpungeAfter = 30 * time.Minute
) )
var ( var (
@ -147,8 +138,8 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r.RemoteAddr) v := s.visitor(r.RemoteAddr)
if !v.limiter.Allow() { if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequests return err
} }
if r.Method == http.MethodGet && r.URL.Path == "/" { if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r) return s.handleHome(w, r)
@ -157,11 +148,11 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.handlePublish(w, r) return s.handlePublish(w, r)
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
return s.handleSubscribeJSON(w, r) return s.handleSubscribeJSON(w, r, v)
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
return s.handleSubscribeSSE(w, r) return s.handleSubscribeSSE(w, r, v)
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
return s.handleSubscribeRaw(w, r) return s.handleSubscribeRaw(w, r, v)
} else if r.Method == http.MethodOptions { } else if r.Method == http.MethodOptions {
return s.handleOptions(w, r) return s.handleOptions(w, r)
} }
@ -195,7 +186,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) error {
return nil return nil
} }
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) { encoder := func(msg *message) (string, error) {
var buf bytes.Buffer var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(&msg); err != nil { if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
@ -203,10 +194,10 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) err
} }
return buf.String(), nil return buf.String(), nil
} }
return s.handleSubscribe(w, r, "json", "application/stream+json", encoder) return s.handleSubscribe(w, r, v, "json", "application/stream+json", encoder)
} }
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) { encoder := func(msg *message) (string, error) {
var buf bytes.Buffer var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(&msg); err != nil { if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
@ -217,20 +208,24 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) erro
} }
return fmt.Sprintf("data: %s\n", buf.String()), nil return fmt.Sprintf("data: %s\n", buf.String()), nil
} }
return s.handleSubscribe(w, r, "sse", "text/event-stream", encoder) return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder)
} }
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) { encoder := func(msg *message) (string, error) {
if msg.Event == "" { // only handle default events if msg.Event == "" { // only handle default events
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
} }
return "\n", nil // "keepalive" and "open" events just send an empty line return "\n", nil // "keepalive" and "open" events just send an empty line
} }
return s.handleSubscribe(w, r, "raw", "text/plain", encoder) return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder)
} }
func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format string, contentType string, encoder messageEncoder) error { func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
if err := v.AddSubscription(); err != nil {
return err
}
defer v.RemoveSubscription()
t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack
since, err := parseSince(r) since, err := parseSince(r)
if err != nil { if err != nil {
@ -270,6 +265,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format
case <-r.Context().Done(): case <-r.Context().Done():
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message
return err return err
} }
@ -326,12 +322,12 @@ func (s *Server) updateStatsAndExpire() {
// Expire visitors from rate visitors map // Expire visitors from rate visitors map
for ip, v := range s.visitors { for ip, v := range s.visitors {
if time.Since(v.seen) > visitorExpungeAfter { if v.Stale() {
delete(s.visitors, ip) delete(s.visitors, ip)
} }
} }
// Prune old messages, remove topics without subscribers // Prune old messages, remove subscriptions without subscribers
for _, t := range s.topics { for _, t := range s.topics {
t.Prune(s.config.MessageBufferDuration) t.Prune(s.config.MessageBufferDuration)
subs, msgs := t.Stats() subs, msgs := t.Stats()
@ -362,12 +358,8 @@ func (s *Server) visitor(remoteAddr string) *visitor {
} }
v, exists := s.visitors[ip] v, exists := s.visitors[ip]
if !exists { if !exists {
v = &visitor{ s.visitors[ip] = newVisitor(s.config)
rate.NewLimiter(s.config.Limit, s.config.LimitBurst), return s.visitors[ip]
time.Now(),
}
s.visitors[ip] = v
return v
} }
v.seen = time.Now() v.seen = time.Now()
return v return v

65
server/visitor.go Normal file
View file

@ -0,0 +1,65 @@
package server
import (
"golang.org/x/time/rate"
"heckel.io/ntfy/config"
"sync"
"time"
)
const (
visitorExpungeAfter = 30 * time.Minute
)
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *config.Config
limiter *rate.Limiter
subscriptions int
seen time.Time
mu sync.Mutex
}
func newVisitor(conf *config.Config) *visitor {
return &visitor{
config: conf,
limiter: rate.NewLimiter(conf.RequestLimit, conf.RequestLimitBurst),
seen: time.Now(),
}
}
func (v *visitor) RequestAllowed() error {
if !v.limiter.Allow() {
return errHTTPTooManyRequests
}
return nil
}
func (v *visitor) AddSubscription() error {
v.mu.Lock()
defer v.mu.Unlock()
if v.subscriptions >= v.config.SubscriptionLimit {
return errHTTPTooManyRequests
}
v.subscriptions++
return nil
}
func (v *visitor) RemoveSubscription() {
v.mu.Lock()
defer v.mu.Unlock()
v.subscriptions--
}
func (v *visitor) Keepalive() {
v.mu.Lock()
defer v.mu.Unlock()
v.seen = time.Now()
}
func (v *visitor) Stale() bool {
v.mu.Lock()
defer v.mu.Unlock()
v.seen = time.Now()
return time.Since(v.seen) > visitorExpungeAfter
}