1
0
Fork 0
forked from mirrors/ntfy

Auth rate limiter

This commit is contained in:
binwiederhier 2023-02-08 15:20:44 -05:00
parent 3ac315a9e7
commit e1a4a74905
16 changed files with 152 additions and 60 deletions

View file

@ -79,7 +79,9 @@ user * (role: anonymous, tier: none)
func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{
"ntfy",
"--log-level=ERROR",
"access",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(),
}

View file

@ -253,6 +253,7 @@ func execServe(c *cli.Context) error {
// Run server
conf := server.NewConfig()
conf.File = config
conf.BaseURL = baseURL
conf.ListenHTTP = listenHTTP
conf.ListenHTTPS = listenHTTPS

View file

@ -38,7 +38,9 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
func runTierCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{
"ntfy",
"--log-level=ERROR",
"tier",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(),
}

View file

@ -41,7 +41,9 @@ func TestCLI_Token_AddListRemove(t *testing.T) {
func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{
"ntfy",
"--log-level=ERROR",
"token",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile,
}
return app.Run(append(userArgs, args...))

View file

@ -6,6 +6,7 @@ import (
"heckel.io/ntfy/server"
"heckel.io/ntfy/test"
"heckel.io/ntfy/user"
"os"
"path/filepath"
"testing"
)
@ -113,7 +114,10 @@ func TestCLI_User_Delete(t *testing.T) {
}
func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) {
configFile := filepath.Join(t.TempDir(), "server-dummy.yml")
require.Nil(t, os.WriteFile(configFile, []byte(""), 0600)) // Dummy config file to avoid lookup of real server.yml
conf = server.NewConfig()
conf.File = configFile
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
conf.AuthDefault = user.PermissionDenyAll
s, port = test.StartServerWithConfig(t, conf)
@ -123,7 +127,9 @@ func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config,
func runUserCommand(app *cli.App, conf *server.Config, args ...string) error {
userArgs := []string{
"ntfy",
"--log-level=ERROR",
"user",
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
"--auth-file=" + conf.AuthFile,
"--auth-default-access=" + conf.AuthDefault.String(),
}

View file

@ -82,8 +82,10 @@ func (e *Event) Time(t time.Time) *Event {
// Err adds an "error" field to the log event
func (e *Event) Err(err error) *Event {
if c, ok := err.(Contexter); ok {
return e.Fields(c.Context())
if err == nil {
return e
} else if c, ok := err.(Contexter); ok {
return e.With(c)
}
return e.Field(errorField, err.Error())
}

View file

@ -49,6 +49,8 @@ const (
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAuthFailureLimitBurst = 10
DefaultVisitorAuthFailureLimitReplenish = time.Minute
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
)
@ -60,6 +62,7 @@ var (
// Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct {
File string // Config file, only used for testing
BaseURL string
ListenHTTP string
ListenHTTPS string
@ -113,6 +116,8 @@ type Config struct {
VisitorEmailLimitReplenish time.Duration
VisitorAccountCreationLimitBurst int
VisitorAccountCreationLimitReplenish time.Duration
VisitorAuthFailureLimitBurst int
VisitorAuthFailureLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool
StripeSecretKey string
@ -129,6 +134,7 @@ type Config struct {
// NewConfig instantiates a default new server config
func NewConfig() *Config {
return &Config{
File: "", // Only used for testing
BaseURL: "",
ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "",
@ -182,6 +188,8 @@ func NewConfig() *Config {
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
VisitorAuthFailureLimitBurst: DefaultVisitorAuthFailureLimitBurst,
VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
BehindProxy: false,
StripeSecretKey: "",

View file

@ -87,6 +87,7 @@ var (
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

View file

@ -34,9 +34,9 @@ import (
/*
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Account limit creation triggers when account is taken!
- HIGH Docs
- tiers
- api
- HIGH Self-review
- MEDIUM: Test for expiring messages after reservation removal
- MEDIUM: Test new token endpoints & never-expiring token
@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
return nil
}
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if err := v.RequestAllowed(); err != nil {
logvr(v, r).Err(err).Trace("Request not allowed by rate limiter")
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
}
}
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
// maybeAuthenticate reads the "Authorization" header and will try to authenticate the user
// if it is set.
//
// - If the header is not set, an IP-based visitor is returned
// - If the header is set, authenticate will be called to check the username/password (Basic auth),
// or the token (Bearer auth), and read the user from the database
//
// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
// that subsequent logging calls still have a visitor context.
func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
// Read "Authorization" header value, and exit out early if it's not set
ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
vip := s.visitor(ip, nil)
header, err := readAuthHeader(r)
if err != nil {
return vip, err
} else if header == "" {
return vip, nil
} else if s.userManager == nil {
return vip, errHTTPUnauthorized
}
v = s.visitor(ip, u)
v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs!
// If we're trying to auth, check the rate limiter first
if !vip.AuthAllowed() {
return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
}
u, err := s.authenticate(r, header)
if err != nil {
vip.AuthFailed()
logr(r).Err(err).Debug("Authentication failed")
return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
}
// Authentication with user was successful
return s.visitor(ip, u), nil
}
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
// query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
if strings.HasPrefix(header, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
}
return s.authenticateBasicAuth(r, header)
}
// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
// or from the ?auth... query parameter
func readAuthHeader(r *http.Request) (string, error) {
value := strings.TrimSpace(r.Header.Get("Authorization"))
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
return "", err
}
value = strings.TrimSpace(string(a))
}
if value == "" {
return nil, nil
} else if s.userManager == nil {
return nil, errHTTPUnauthorized
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
}
return s.authenticateBasicAuth(r, value)
return value, nil
}
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
return s.visitors[id]
}
v.Keepalive()
v.SetUser(user) // Always update with the latest user, may be nil!
return v
}

View file

@ -41,6 +41,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
return err
}
v.AccountCreated()
return s.writeJSON(w, newSuccessResponse())
}

View file

@ -39,7 +39,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
if !v.FirebaseAllowed() {
return errFirebaseTemporarilyBanned
}
fbm, err := toFirebaseMessage(m, c.auther)

View file

@ -1,9 +1,21 @@
package server
import (
"heckel.io/ntfy/util"
"net/http"
)
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
return next(w, r, v)
} else if !v.RequestAllowed() {
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb {

View file

@ -374,13 +374,13 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
var wg sync.WaitGroup
for i := 0; i < 209; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
wg.Done()
}()
require.Equal(t, 200, rr.Code, "Failed on %d", i)
}(i)
}
wg.Wait()
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{

View file

@ -733,6 +733,24 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
require.Equal(t, 403, response.Code) // Anonymous read not allowed
}
func TestServer_Auth_Fail_Rate_Limiting(t *testing.T) {
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
for i := 0; i < 10; i++ {
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 401, response.Code)
}
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, response.Code)
require.Equal(t, 42909, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Auth_ViaQuery(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll

View file

@ -64,6 +64,7 @@ type visitor struct {
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
authLimiter *rate.Limiter // Limiter for incorrect login attempts
firebase time.Time // Next allowed Firebase message
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
mu sync.Mutex
@ -130,6 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
emailsLimiter: nil, // Set in resetLimiters
bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil
authLimiter: nil, // Set in resetLimiters, may be nil
}
v.resetLimitersNoLock(messages, emails, false)
return v
@ -154,6 +156,10 @@ func (v *visitor) contextNoLock() log.Context {
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
}
if v.authLimiter != nil {
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
}
if v.user != nil {
fields["user_id"] = v.user.ID
fields["user_name"] = v.user.Name
@ -182,28 +188,16 @@ func visitorExtendedInfoContext(info *visitorInfo) log.Context {
}
}
func (v *visitor) RequestAllowed() error {
func (v *visitor) RequestAllowed() bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if !v.requestLimiter.Allow() {
return errVisitorLimitReached
}
return nil
return v.requestLimiter.Allow()
}
func (v *visitor) RequestLimiter() *rate.Limiter {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
return v.requestLimiter
}
func (v *visitor) FirebaseAllowed() error {
func (v *visitor) FirebaseAllowed() bool {
v.mu.Lock()
defer v.mu.Unlock()
if time.Now().Before(v.firebase) {
return errVisitorLimitReached
}
return nil
return !time.Now().Before(v.firebase)
}
func (v *visitor) FirebaseTemporarilyDeny() {
@ -230,15 +224,44 @@ func (v *visitor) SubscriptionAllowed() bool {
return v.subscriptionLimiter.Allow()
}
// AuthAllowed returns true if an auth request can be attempted (> 1 token available)
func (v *visitor) AuthAllowed() bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.authLimiter == nil {
return true
}
return v.authLimiter.Tokens() > 1
}
// AuthFailed records an auth failure
func (v *visitor) AuthFailed() {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.authLimiter != nil {
v.authLimiter.Allow()
}
}
// AccountCreationAllowed returns true if a new account can be created
func (v *visitor) AccountCreationAllowed() bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
if v.accountLimiter == nil || (v.accountLimiter != nil && v.accountLimiter.Tokens() < 1) {
return false
}
return true
}
// AccountCreated decreases the account limiter. This is to be called after an account was created.
func (v *visitor) AccountCreated() {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
if v.accountLimiter != nil {
v.accountLimiter.Allow()
}
}
func (v *visitor) BandwidthAllowed(bytes int64) bool {
v.mu.Lock() // limiters could be replaced!
defer v.mu.Unlock()
@ -336,8 +359,10 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
v.authLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAuthFailureLimitReplenish), v.config.VisitorAuthFailureLimitBurst)
} else {
v.accountLimiter = nil // Users cannot create accounts when logged in
v.authLimiter = nil // Users are already logged in, no need to limit requests
}
if enqueueUpdate && v.user != nil {
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{

View file

@ -372,6 +372,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
}
user, err := a.userByToken(token)
if err != nil {
log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
return nil, ErrUnauthenticated
}
user.Token = token