From 0170f673bd84b8a94f53f9ee3009e9787f05790a Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 5 Nov 2021 13:46:27 -0400 Subject: [PATCH] Fix rate limiting behind proxy, make configurable --- cmd/app.go | 19 ++++++++++++++-- config/config.go | 57 ++++++++++++++++++++++++----------------------- config/config.yml | 26 ++++++++++++++++++++- server/server.go | 40 ++++++++++++++++++++++----------- server/visitor.go | 2 +- 5 files changed, 99 insertions(+), 45 deletions(-) diff --git a/cmd/app.go b/cmd/app.go index b7d3722..6ef9f94 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -22,8 +22,13 @@ func New() *cli.App { altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), - altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}), - altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), + altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), } return &cli.App{ Name: "ntfy", @@ -50,6 +55,11 @@ func execRun(c *cli.Context) error { cacheDuration := c.Duration("cache-duration") keepaliveInterval := c.Duration("keepalive-interval") managerInterval := c.Duration("manager-interval") + globalTopicLimit := c.Int("global-topic-limit") + visitorSubscriptionLimit := c.Int("visitor-subscription-limit") + visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") + visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") + behindProxy := c.Bool("behind-proxy") // Check values if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { @@ -69,6 +79,11 @@ func execRun(c *cli.Context) error { conf.CacheDuration = cacheDuration conf.KeepaliveInterval = keepaliveInterval conf.ManagerInterval = managerInterval + conf.GlobalTopicLimit = globalTopicLimit + conf.VisitorSubscriptionLimit = visitorSubscriptionLimit + conf.VisitorRequestLimitBurst = visitorRequestLimitBurst + conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish + conf.BehindProxy = behindProxy s, err := server.New(conf) if err != nil { log.Fatalln(err) diff --git a/config/config.go b/config/config.go index f80cc9a..902f23b 100644 --- a/config/config.go +++ b/config/config.go @@ -2,7 +2,6 @@ package config import ( - "golang.org/x/time/rate" "time" ) @@ -15,42 +14,44 @@ const ( ) // Defines all the limits -// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds) // - global topic limit: max number of topics overall -// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP -var ( - defaultGlobalTopicLimit = 5000 - defaultVisitorRequestLimit = rate.Every(10 * time.Second) - defaultVisitorRequestLimitBurst = 60 - defaultVisitorSubscriptionLimit = 30 +// - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds) +// - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP +const ( + DefaultGlobalTopicLimit = 5000 + DefaultVisitorRequestLimitBurst = 60 + DefaultVisitorRequestLimitReplenish = 10 * time.Second + DefaultVisitorSubscriptionLimit = 30 ) // Config is the main config struct for the application. Use New to instantiate a default config struct. type Config struct { - ListenHTTP string - FirebaseKeyFile string - CacheFile string - CacheDuration time.Duration - KeepaliveInterval time.Duration - ManagerInterval time.Duration - GlobalTopicLimit int - VisitorRequestLimit rate.Limit - VisitorRequestLimitBurst int - VisitorSubscriptionLimit int + ListenHTTP string + FirebaseKeyFile string + CacheFile string + CacheDuration time.Duration + KeepaliveInterval time.Duration + ManagerInterval time.Duration + GlobalTopicLimit int + VisitorRequestLimitBurst int + VisitorRequestLimitReplenish time.Duration + VisitorSubscriptionLimit int + BehindProxy bool } // New instantiates a default new config func New(listenHTTP string) *Config { return &Config{ - ListenHTTP: listenHTTP, - FirebaseKeyFile: "", - CacheFile: "", - CacheDuration: DefaultCacheDuration, - KeepaliveInterval: DefaultKeepaliveInterval, - ManagerInterval: DefaultManagerInterval, - GlobalTopicLimit: defaultGlobalTopicLimit, - VisitorRequestLimit: defaultVisitorRequestLimit, - VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst, - VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit, + ListenHTTP: listenHTTP, + FirebaseKeyFile: "", + CacheFile: "", + CacheDuration: DefaultCacheDuration, + KeepaliveInterval: DefaultKeepaliveInterval, + ManagerInterval: DefaultManagerInterval, + GlobalTopicLimit: DefaultGlobalTopicLimit, + VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, + VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, + VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, + BehindProxy: false, } } diff --git a/config/config.yml b/config/config.yml index e4a6fc0..210df07 100644 --- a/config/config.yml +++ b/config/config.yml @@ -25,6 +25,30 @@ # # keepalive-interval: 30s -# Interval in which the manager prunes old messages, deletes topics and prints the stats. +# Interval in which the manager prunes old messages, deletes topics +# and prints the stats. # # manager-interval: 1m + +# Rate limiting: Total number of topics before the server rejects new topics. +# +# global-topic-limit: 5000 + +# Rate limiting: Number of subscriptions per visitor (IP address) +# +# visitor-subscription-limit: 30 + +# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor: +# - visitor-request-limit-burst is the initial bucket of requests each visitor has +# - visitor-request-limit-replenish is the rate at which the bucket is refilled +# +# visitor-request-limit-burst: 60 +# visitor-request-limit-replenish: 10s + +# If set, the X-Forwarded-For header is used to determine the visitor IP address +# instead of the remote address of the connection. +# +# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited +# as if they are one. +# +# behind-proxy: false diff --git a/server/server.go b/server/server.go index 64d26a0..bbde734 100644 --- a/server/server.go +++ b/server/server.go @@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { - v := s.visitor(r.RemoteAddr) - if err := v.RequestAllowed(); err != nil { - return err - } if r.Method == http.MethodGet && r.URL.Path == "/" { return s.handleHome(w, r) + } else if r.Method == http.MethodHead && r.URL.Path == "/" { + return s.handleEmpty(w, r) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { return s.handleStatic(w, r) - } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { - return s.handlePublish(w, r, v) - } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { - return s.handleSubscribeJSON(w, r, v) - } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { - return s.handleSubscribeSSE(w, r, v) - } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { - return s.handleSubscribeRaw(w, r, v) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) + } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { + return s.withRateLimit(w, r, s.handlePublish) + } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { + return s.withRateLimit(w, r, s.handleSubscribeJSON) + } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { + return s.withRateLimit(w, r, s.handleSubscribeSSE) + } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { + return s.withRateLimit(w, r, s.handleSubscribeRaw) } return errHTTPNotFound } @@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { return err } +func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error { + return nil +} + func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r) return nil @@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() { s.messages, len(s.topics), subscribers, messages, len(s.visitors)) } +func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { + v := s.visitor(r) + if err := v.RequestAllowed(); err != nil { + return err + } + return handler(w, r, v) +} + // visitor creates or retrieves a rate.Limiter for the given visitor. // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). -func (s *Server) visitor(remoteAddr string) *visitor { +func (s *Server) visitor(r *http.Request) *visitor { s.mu.Lock() defer s.mu.Unlock() + remoteAddr := r.RemoteAddr ip, _, err := net.SplitHostPort(remoteAddr) if err != nil { ip = remoteAddr // This should not happen in real life; only in tests. } + if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" { + ip = r.Header.Get("X-Forwarded-For") + } v, exists := s.visitors[ip] if !exists { s.visitors[ip] = newVisitor(s.config) diff --git a/server/visitor.go b/server/visitor.go index 3e028ac..7c23f89 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -24,7 +24,7 @@ type visitor struct { func newVisitor(conf *config.Config) *visitor { return &visitor{ config: conf, - limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst), + limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)), seen: time.Now(), }