forked from mirrors/ntfy
Auth rate limiter
This commit is contained in:
parent
3ac315a9e7
commit
e1a4a74905
16 changed files with 152 additions and 60 deletions
|
@ -79,7 +79,9 @@ user * (role: anonymous, tier: none)
|
||||||
func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error {
|
func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error {
|
||||||
userArgs := []string{
|
userArgs := []string{
|
||||||
"ntfy",
|
"ntfy",
|
||||||
|
"--log-level=ERROR",
|
||||||
"access",
|
"access",
|
||||||
|
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
|
||||||
"--auth-file=" + conf.AuthFile,
|
"--auth-file=" + conf.AuthFile,
|
||||||
"--auth-default-access=" + conf.AuthDefault.String(),
|
"--auth-default-access=" + conf.AuthDefault.String(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,6 +253,7 @@ func execServe(c *cli.Context) error {
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
conf := server.NewConfig()
|
conf := server.NewConfig()
|
||||||
|
conf.File = config
|
||||||
conf.BaseURL = baseURL
|
conf.BaseURL = baseURL
|
||||||
conf.ListenHTTP = listenHTTP
|
conf.ListenHTTP = listenHTTP
|
||||||
conf.ListenHTTPS = listenHTTPS
|
conf.ListenHTTPS = listenHTTPS
|
||||||
|
|
|
@ -38,7 +38,9 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
|
||||||
func runTierCommand(app *cli.App, conf *server.Config, args ...string) error {
|
func runTierCommand(app *cli.App, conf *server.Config, args ...string) error {
|
||||||
userArgs := []string{
|
userArgs := []string{
|
||||||
"ntfy",
|
"ntfy",
|
||||||
|
"--log-level=ERROR",
|
||||||
"tier",
|
"tier",
|
||||||
|
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
|
||||||
"--auth-file=" + conf.AuthFile,
|
"--auth-file=" + conf.AuthFile,
|
||||||
"--auth-default-access=" + conf.AuthDefault.String(),
|
"--auth-default-access=" + conf.AuthDefault.String(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,9 @@ func TestCLI_Token_AddListRemove(t *testing.T) {
|
||||||
func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error {
|
func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error {
|
||||||
userArgs := []string{
|
userArgs := []string{
|
||||||
"ntfy",
|
"ntfy",
|
||||||
|
"--log-level=ERROR",
|
||||||
"token",
|
"token",
|
||||||
|
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
|
||||||
"--auth-file=" + conf.AuthFile,
|
"--auth-file=" + conf.AuthFile,
|
||||||
}
|
}
|
||||||
return app.Run(append(userArgs, args...))
|
return app.Run(append(userArgs, args...))
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"heckel.io/ntfy/server"
|
"heckel.io/ntfy/server"
|
||||||
"heckel.io/ntfy/test"
|
"heckel.io/ntfy/test"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -113,7 +114,10 @@ func TestCLI_User_Delete(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) {
|
func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) {
|
||||||
|
configFile := filepath.Join(t.TempDir(), "server-dummy.yml")
|
||||||
|
require.Nil(t, os.WriteFile(configFile, []byte(""), 0600)) // Dummy config file to avoid lookup of real server.yml
|
||||||
conf = server.NewConfig()
|
conf = server.NewConfig()
|
||||||
|
conf.File = configFile
|
||||||
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||||
conf.AuthDefault = user.PermissionDenyAll
|
conf.AuthDefault = user.PermissionDenyAll
|
||||||
s, port = test.StartServerWithConfig(t, conf)
|
s, port = test.StartServerWithConfig(t, conf)
|
||||||
|
@ -123,7 +127,9 @@ func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config,
|
||||||
func runUserCommand(app *cli.App, conf *server.Config, args ...string) error {
|
func runUserCommand(app *cli.App, conf *server.Config, args ...string) error {
|
||||||
userArgs := []string{
|
userArgs := []string{
|
||||||
"ntfy",
|
"ntfy",
|
||||||
|
"--log-level=ERROR",
|
||||||
"user",
|
"user",
|
||||||
|
"--config=" + conf.File, // Dummy config file to avoid lookups of real file
|
||||||
"--auth-file=" + conf.AuthFile,
|
"--auth-file=" + conf.AuthFile,
|
||||||
"--auth-default-access=" + conf.AuthDefault.String(),
|
"--auth-default-access=" + conf.AuthDefault.String(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,8 +82,10 @@ func (e *Event) Time(t time.Time) *Event {
|
||||||
|
|
||||||
// Err adds an "error" field to the log event
|
// Err adds an "error" field to the log event
|
||||||
func (e *Event) Err(err error) *Event {
|
func (e *Event) Err(err error) *Event {
|
||||||
if c, ok := err.(Contexter); ok {
|
if err == nil {
|
||||||
return e.Fields(c.Context())
|
return e
|
||||||
|
} else if c, ok := err.(Contexter); ok {
|
||||||
|
return e.With(c)
|
||||||
}
|
}
|
||||||
return e.Field(errorField, err.Error())
|
return e.Field(errorField, err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,8 @@ const (
|
||||||
DefaultVisitorEmailLimitReplenish = time.Hour
|
DefaultVisitorEmailLimitReplenish = time.Hour
|
||||||
DefaultVisitorAccountCreationLimitBurst = 3
|
DefaultVisitorAccountCreationLimitBurst = 3
|
||||||
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
|
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
|
||||||
|
DefaultVisitorAuthFailureLimitBurst = 10
|
||||||
|
DefaultVisitorAuthFailureLimitReplenish = time.Minute
|
||||||
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
|
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
|
||||||
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
|
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
|
||||||
)
|
)
|
||||||
|
@ -60,6 +62,7 @@ var (
|
||||||
|
|
||||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
File string // Config file, only used for testing
|
||||||
BaseURL string
|
BaseURL string
|
||||||
ListenHTTP string
|
ListenHTTP string
|
||||||
ListenHTTPS string
|
ListenHTTPS string
|
||||||
|
@ -113,6 +116,8 @@ type Config struct {
|
||||||
VisitorEmailLimitReplenish time.Duration
|
VisitorEmailLimitReplenish time.Duration
|
||||||
VisitorAccountCreationLimitBurst int
|
VisitorAccountCreationLimitBurst int
|
||||||
VisitorAccountCreationLimitReplenish time.Duration
|
VisitorAccountCreationLimitReplenish time.Duration
|
||||||
|
VisitorAuthFailureLimitBurst int
|
||||||
|
VisitorAuthFailureLimitReplenish time.Duration
|
||||||
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
||||||
BehindProxy bool
|
BehindProxy bool
|
||||||
StripeSecretKey string
|
StripeSecretKey string
|
||||||
|
@ -129,6 +134,7 @@ type Config struct {
|
||||||
// NewConfig instantiates a default new server config
|
// NewConfig instantiates a default new server config
|
||||||
func NewConfig() *Config {
|
func NewConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
|
File: "", // Only used for testing
|
||||||
BaseURL: "",
|
BaseURL: "",
|
||||||
ListenHTTP: DefaultListenHTTP,
|
ListenHTTP: DefaultListenHTTP,
|
||||||
ListenHTTPS: "",
|
ListenHTTPS: "",
|
||||||
|
@ -182,6 +188,8 @@ func NewConfig() *Config {
|
||||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||||
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
||||||
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
|
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
|
||||||
|
VisitorAuthFailureLimitBurst: DefaultVisitorAuthFailureLimitBurst,
|
||||||
|
VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish,
|
||||||
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
||||||
BehindProxy: false,
|
BehindProxy: false,
|
||||||
StripeSecretKey: "",
|
StripeSecretKey: "",
|
||||||
|
|
|
@ -87,6 +87,7 @@ var (
|
||||||
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
||||||
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
|
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
|
||||||
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
||||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
||||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
||||||
|
|
|
@ -34,9 +34,9 @@ import (
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
||||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
|
||||||
- HIGH Account limit creation triggers when account is taken!
|
|
||||||
- HIGH Docs
|
- HIGH Docs
|
||||||
|
- tiers
|
||||||
|
- api
|
||||||
- HIGH Self-review
|
- HIGH Self-review
|
||||||
- MEDIUM: Test for expiring messages after reservation removal
|
- MEDIUM: Test for expiring messages after reservation removal
|
||||||
- MEDIUM: Test new token endpoints & never-expiring token
|
- MEDIUM: Test new token endpoints & never-expiring token
|
||||||
|
@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
||||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
|
||||||
return next(w, r, v)
|
|
||||||
} else if err := v.RequestAllowed(); err != nil {
|
|
||||||
logvr(v, r).Err(err).Trace("Request not allowed by rate limiter")
|
|
||||||
return errHTTPTooManyRequestsLimitRequests
|
|
||||||
}
|
|
||||||
return next(w, r, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
|
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
|
||||||
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
|
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
|
||||||
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
||||||
|
@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
|
// maybeAuthenticate reads the "Authorization" header and will try to authenticate the user
|
||||||
// Note that this function will always return a visitor, even if an error occurs.
|
// if it is set.
|
||||||
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
//
|
||||||
|
// - If the header is not set, an IP-based visitor is returned
|
||||||
|
// - If the header is set, authenticate will be called to check the username/password (Basic auth),
|
||||||
|
// or the token (Bearer auth), and read the user from the database
|
||||||
|
//
|
||||||
|
// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
|
||||||
|
// that subsequent logging calls still have a visitor context.
|
||||||
|
func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
|
||||||
|
// Read "Authorization" header value, and exit out early if it's not set
|
||||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||||
var u *user.User // may stay nil if no auth header!
|
vip := s.visitor(ip, nil)
|
||||||
if u, err = s.authenticate(r); err != nil {
|
header, err := readAuthHeader(r)
|
||||||
logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
|
if err != nil {
|
||||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
return vip, err
|
||||||
|
} else if header == "" {
|
||||||
|
return vip, nil
|
||||||
|
} else if s.userManager == nil {
|
||||||
|
return vip, errHTTPUnauthorized
|
||||||
}
|
}
|
||||||
v = s.visitor(ip, u)
|
// If we're trying to auth, check the rate limiter first
|
||||||
v.SetUser(u) // Update visitor user with latest from database!
|
if !vip.AuthAllowed() {
|
||||||
return v, err // Always return visitor, even when error occurs!
|
return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
|
||||||
|
}
|
||||||
|
u, err := s.authenticate(r, header)
|
||||||
|
if err != nil {
|
||||||
|
vip.AuthFailed()
|
||||||
|
logr(r).Err(err).Debug("Authentication failed")
|
||||||
|
return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||||
|
}
|
||||||
|
// Authentication with user was successful
|
||||||
|
return s.visitor(ip, u), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
|
// authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
|
||||||
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
|
// The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
|
||||||
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
|
// support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
|
||||||
// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
|
// query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
|
||||||
func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
|
func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
|
||||||
|
if strings.HasPrefix(header, "Bearer") {
|
||||||
|
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
|
||||||
|
}
|
||||||
|
return s.authenticateBasicAuth(r, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
|
||||||
|
// or from the ?auth... query parameter
|
||||||
|
func readAuthHeader(r *http.Request) (string, error) {
|
||||||
value := strings.TrimSpace(r.Header.Get("Authorization"))
|
value := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||||
queryParam := readQueryParam(r, "authorization", "auth")
|
queryParam := readQueryParam(r, "authorization", "auth")
|
||||||
if queryParam != "" {
|
if queryParam != "" {
|
||||||
a, err := base64.RawURLEncoding.DecodeString(queryParam)
|
a, err := base64.RawURLEncoding.DecodeString(queryParam)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
value = strings.TrimSpace(string(a))
|
value = strings.TrimSpace(string(a))
|
||||||
}
|
}
|
||||||
if value == "" {
|
return value, nil
|
||||||
return nil, nil
|
|
||||||
} else if s.userManager == nil {
|
|
||||||
return nil, errHTTPUnauthorized
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(value, "Bearer") {
|
|
||||||
return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
|
|
||||||
}
|
|
||||||
return s.authenticateBasicAuth(r, value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
|
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
|
||||||
|
@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||||
return s.visitors[id]
|
return s.visitors[id]
|
||||||
}
|
}
|
||||||
v.Keepalive()
|
v.Keepalive()
|
||||||
|
v.SetUser(user) // Always update with the latest user, may be nil!
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
||||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
|
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
v.AccountCreated()
|
||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||||
if err := v.FirebaseAllowed(); err != nil {
|
if !v.FirebaseAllowed() {
|
||||||
return errFirebaseTemporarilyBanned
|
return errFirebaseTemporarilyBanned
|
||||||
}
|
}
|
||||||
fbm, err := toFirebaseMessage(m, c.auther)
|
fbm, err := toFirebaseMessage(m, c.auther)
|
||||||
|
|
|
@ -1,9 +1,21 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
|
return next(w, r, v)
|
||||||
|
} else if !v.RequestAllowed() {
|
||||||
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
|
}
|
||||||
|
return next(w, r, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
|
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if !s.config.EnableWeb {
|
if !s.config.EnableWeb {
|
||||||
|
|
|
@ -374,13 +374,13 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 209; i++ {
|
for i := 0; i < 209; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code, "Failed on %d", i)
|
||||||
wg.Done()
|
}(i)
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
|
|
@ -733,6 +733,24 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
|
||||||
require.Equal(t, 403, response.Code) // Anonymous read not allowed
|
require.Equal(t, 403, response.Code) // Anonymous read not allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_Auth_Fail_Rate_Limiting(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, response.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := request(t, s, "PUT", "/announcements", "test", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 429, response.Code)
|
||||||
|
require.Equal(t, 42909, toHTTPError(t, response.Body.String()).Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_Auth_ViaQuery(t *testing.T) {
|
func TestServer_Auth_ViaQuery(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
|
|
|
@ -64,6 +64,7 @@ type visitor struct {
|
||||||
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
|
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
|
||||||
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
|
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
|
||||||
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
|
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
|
||||||
|
authLimiter *rate.Limiter // Limiter for incorrect login attempts
|
||||||
firebase time.Time // Next allowed Firebase message
|
firebase time.Time // Next allowed Firebase message
|
||||||
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
|
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -130,6 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
||||||
emailsLimiter: nil, // Set in resetLimiters
|
emailsLimiter: nil, // Set in resetLimiters
|
||||||
bandwidthLimiter: nil, // Set in resetLimiters
|
bandwidthLimiter: nil, // Set in resetLimiters
|
||||||
accountLimiter: nil, // Set in resetLimiters, may be nil
|
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||||
|
authLimiter: nil, // Set in resetLimiters, may be nil
|
||||||
}
|
}
|
||||||
v.resetLimitersNoLock(messages, emails, false)
|
v.resetLimitersNoLock(messages, emails, false)
|
||||||
return v
|
return v
|
||||||
|
@ -154,6 +156,10 @@ func (v *visitor) contextNoLock() log.Context {
|
||||||
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
|
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
|
||||||
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
|
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
|
||||||
}
|
}
|
||||||
|
if v.authLimiter != nil {
|
||||||
|
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
|
||||||
|
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
|
||||||
|
}
|
||||||
if v.user != nil {
|
if v.user != nil {
|
||||||
fields["user_id"] = v.user.ID
|
fields["user_id"] = v.user.ID
|
||||||
fields["user_name"] = v.user.Name
|
fields["user_name"] = v.user.Name
|
||||||
|
@ -182,28 +188,16 @@ func visitorExtendedInfoContext(info *visitorInfo) log.Context {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
func (v *visitor) RequestAllowed() error {
|
func (v *visitor) RequestAllowed() bool {
|
||||||
v.mu.Lock() // limiters could be replaced!
|
v.mu.Lock() // limiters could be replaced!
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
if !v.requestLimiter.Allow() {
|
return v.requestLimiter.Allow()
|
||||||
return errVisitorLimitReached
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) RequestLimiter() *rate.Limiter {
|
func (v *visitor) FirebaseAllowed() bool {
|
||||||
v.mu.Lock() // limiters could be replaced!
|
|
||||||
defer v.mu.Unlock()
|
|
||||||
return v.requestLimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *visitor) FirebaseAllowed() error {
|
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
if time.Now().Before(v.firebase) {
|
return !time.Now().Before(v.firebase)
|
||||||
return errVisitorLimitReached
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) FirebaseTemporarilyDeny() {
|
func (v *visitor) FirebaseTemporarilyDeny() {
|
||||||
|
@ -230,15 +224,44 @@ func (v *visitor) SubscriptionAllowed() bool {
|
||||||
return v.subscriptionLimiter.Allow()
|
return v.subscriptionLimiter.Allow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthAllowed returns true if an auth request can be attempted (> 1 token available)
|
||||||
|
func (v *visitor) AuthAllowed() bool {
|
||||||
|
v.mu.Lock() // limiters could be replaced!
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
if v.authLimiter == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return v.authLimiter.Tokens() > 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthFailed records an auth failure
|
||||||
|
func (v *visitor) AuthFailed() {
|
||||||
|
v.mu.Lock() // limiters could be replaced!
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
if v.authLimiter != nil {
|
||||||
|
v.authLimiter.Allow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountCreationAllowed returns true if a new account can be created
|
||||||
func (v *visitor) AccountCreationAllowed() bool {
|
func (v *visitor) AccountCreationAllowed() bool {
|
||||||
v.mu.Lock() // limiters could be replaced!
|
v.mu.Lock() // limiters could be replaced!
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
|
if v.accountLimiter == nil || (v.accountLimiter != nil && v.accountLimiter.Tokens() < 1) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountCreated decreases the account limiter. This is to be called after an account was created.
|
||||||
|
func (v *visitor) AccountCreated() {
|
||||||
|
v.mu.Lock() // limiters could be replaced!
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
if v.accountLimiter != nil {
|
||||||
|
v.accountLimiter.Allow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (v *visitor) BandwidthAllowed(bytes int64) bool {
|
func (v *visitor) BandwidthAllowed(bytes int64) bool {
|
||||||
v.mu.Lock() // limiters could be replaced!
|
v.mu.Lock() // limiters could be replaced!
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
|
@ -336,8 +359,10 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
|
||||||
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
|
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
|
||||||
if v.user == nil {
|
if v.user == nil {
|
||||||
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||||
|
v.authLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAuthFailureLimitReplenish), v.config.VisitorAuthFailureLimitBurst)
|
||||||
} else {
|
} else {
|
||||||
v.accountLimiter = nil // Users cannot create accounts when logged in
|
v.accountLimiter = nil // Users cannot create accounts when logged in
|
||||||
|
v.authLimiter = nil // Users are already logged in, no need to limit requests
|
||||||
}
|
}
|
||||||
if enqueueUpdate && v.user != nil {
|
if enqueueUpdate && v.user != nil {
|
||||||
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{
|
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{
|
||||||
|
|
|
@ -372,6 +372,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
||||||
}
|
}
|
||||||
user, err := a.userByToken(token)
|
user, err := a.userByToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
|
||||||
return nil, ErrUnauthenticated
|
return nil, ErrUnauthenticated
|
||||||
}
|
}
|
||||||
user.Token = token
|
user.Token = token
|
||||||
|
|
Loading…
Reference in a new issue