diff --git a/cmd/access.go b/cmd/access.go index 9bd65f8..e03abd0 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" "github.com/urfave/cli/v2" - "heckel.io/ntfy/auth" + "heckel.io/ntfy/user" "heckel.io/ntfy/util" ) @@ -77,7 +77,7 @@ func execUserAccess(c *cli.Context) error { } username := c.Args().Get(0) if username == userEveryone { - username = auth.Everyone + username = user.Everyone } topic := c.Args().Get(1) perms := c.Args().Get(2) @@ -96,16 +96,16 @@ func execUserAccess(c *cli.Context) error { return changeAccess(c, manager, username, topic, perms) } -func changeAccess(c *cli.Context, manager auth.Manager, username string, topic string, perms string) error { +func changeAccess(c *cli.Context, manager user.Manager, username string, topic string, perms string) error { if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) { return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)") } read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) - user, err := manager.User(username) - if err == auth.ErrNotFound { + u, err := manager.User(username) + if err == user.ErrNotFound { return fmt.Errorf("user %s does not exist", username) - } else if user.Role == auth.RoleAdmin { + } else if u.Role == user.RoleAdmin { return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) } if err := manager.AllowAccess(username, topic, read, write); err != nil { @@ -123,7 +123,7 @@ func changeAccess(c *cli.Context, manager auth.Manager, username string, topic s return showUserAccess(c, manager, username) } -func resetAccess(c *cli.Context, manager auth.Manager, username, topic string) error { +func resetAccess(c *cli.Context, manager user.Manager, username, topic string) error { if username == "" { return resetAllAccess(c, manager) } else if topic == "" { @@ -132,7 +132,7 @@ func resetAccess(c *cli.Context, manager auth.Manager, username, topic string) e return resetUserTopicAccess(c, manager, username, topic) } -func resetAllAccess(c *cli.Context, manager auth.Manager) error { +func resetAllAccess(c *cli.Context, manager user.Manager) error { if err := manager.ResetAccess("", ""); err != nil { return err } @@ -140,7 +140,7 @@ func resetAllAccess(c *cli.Context, manager auth.Manager) error { return nil } -func resetUserAccess(c *cli.Context, manager auth.Manager, username string) error { +func resetUserAccess(c *cli.Context, manager user.Manager, username string) error { if err := manager.ResetAccess(username, ""); err != nil { return err } @@ -148,7 +148,7 @@ func resetUserAccess(c *cli.Context, manager auth.Manager, username string) erro return showUserAccess(c, manager, username) } -func resetUserTopicAccess(c *cli.Context, manager auth.Manager, username string, topic string) error { +func resetUserTopicAccess(c *cli.Context, manager user.Manager, username string, topic string) error { if err := manager.ResetAccess(username, topic); err != nil { return err } @@ -156,14 +156,14 @@ func resetUserTopicAccess(c *cli.Context, manager auth.Manager, username string, return showUserAccess(c, manager, username) } -func showAccess(c *cli.Context, manager auth.Manager, username string) error { +func showAccess(c *cli.Context, manager user.Manager, username string) error { if username == "" { return showAllAccess(c, manager) } return showUserAccess(c, manager, username) } -func showAllAccess(c *cli.Context, manager auth.Manager) error { +func showAllAccess(c *cli.Context, manager user.Manager) error { users, err := manager.Users() if err != nil { return err @@ -171,23 +171,23 @@ func showAllAccess(c *cli.Context, manager auth.Manager) error { return showUsers(c, manager, users) } -func showUserAccess(c *cli.Context, manager auth.Manager, username string) error { +func showUserAccess(c *cli.Context, manager user.Manager, username string) error { users, err := manager.User(username) - if err == auth.ErrNotFound { + if err == user.ErrNotFound { return fmt.Errorf("user %s does not exist", username) } else if err != nil { return err } - return showUsers(c, manager, []*auth.User{users}) + return showUsers(c, manager, []*user.User{users}) } -func showUsers(c *cli.Context, manager auth.Manager, users []*auth.User) error { - for _, user := range users { - fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", user.Name, user.Role) - if user.Role == auth.RoleAdmin { +func showUsers(c *cli.Context, manager user.Manager, users []*user.User) error { + for _, u := range users { + fmt.Fprintf(c.App.ErrWriter, "user %s (%s)\n", u.Name, u.Role) + if u.Role == user.RoleAdmin { fmt.Fprintf(c.App.ErrWriter, "- read-write access to all topics (admin role)\n") - } else if len(user.Grants) > 0 { - for _, grant := range user.Grants { + } else if len(u.Grants) > 0 { + for _, grant := range u.Grants { if grant.AllowRead && grant.AllowWrite { fmt.Fprintf(c.App.ErrWriter, "- read-write access to topic %s\n", grant.TopicPattern) } else if grant.AllowRead { @@ -201,7 +201,7 @@ func showUsers(c *cli.Context, manager auth.Manager, users []*auth.User) error { } else { fmt.Fprintf(c.App.ErrWriter, "- no topic-specific permissions\n") } - if user.Name == auth.Everyone { + if u.Name == user.Everyone { defaultRead, defaultWrite := manager.DefaultAccess() if defaultRead && defaultWrite { fmt.Fprintln(c.App.ErrWriter, "- read-write access to all (other) topics (server config)") diff --git a/cmd/user.go b/cmd/user.go index 094d31c..8fe1b30 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -6,12 +6,12 @@ import ( "crypto/subtle" "errors" "fmt" + "heckel.io/ntfy/user" "os" "strings" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" - "heckel.io/ntfy/auth" "heckel.io/ntfy/util" ) @@ -41,7 +41,7 @@ var cmdUser = &cli.Command{ UsageText: "ntfy user add [--role=admin|user] USERNAME\nNTFY_PASSWORD=... ntfy user add [--role=admin|user] USERNAME", Action: execUserAdd, Flags: []cli.Flag{ - &cli.StringFlag{Name: "role", Aliases: []string{"r"}, Value: string(auth.RoleUser), Usage: "user role"}, + &cli.StringFlag{Name: "role", Aliases: []string{"r"}, Value: string(user.RoleUser), Usage: "user role"}, }, Description: `Add a new user to the ntfy user database. @@ -152,13 +152,13 @@ variable to pass the new password. This is useful if you are creating/updating u func execUserAdd(c *cli.Context) error { username := c.Args().Get(0) - role := auth.Role(c.String("role")) + role := user.Role(c.String("role")) password := os.Getenv("NTFY_PASSWORD") if username == "" { return errors.New("username expected, type 'ntfy user add --help' for help") } else if username == userEveryone { return errors.New("username not allowed") - } else if !auth.AllowedRole(role) { + } else if !user.AllowedRole(role) { return errors.New("role must be either 'user' or 'admin'") } manager, err := createAuthManager(c) @@ -194,7 +194,7 @@ func execUserDel(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == auth.ErrNotFound { + if _, err := manager.User(username); err == user.ErrNotFound { return fmt.Errorf("user %s does not exist", username) } if err := manager.RemoveUser(username); err != nil { @@ -216,7 +216,7 @@ func execUserChangePass(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == auth.ErrNotFound { + if _, err := manager.User(username); err == user.ErrNotFound { return fmt.Errorf("user %s does not exist", username) } if password == "" { @@ -234,8 +234,8 @@ func execUserChangePass(c *cli.Context) error { func execUserChangeRole(c *cli.Context) error { username := c.Args().Get(0) - role := auth.Role(c.Args().Get(1)) - if username == "" || !auth.AllowedRole(role) { + role := user.Role(c.Args().Get(1)) + if username == "" || !user.AllowedRole(role) { return errors.New("username and new role expected, type 'ntfy user change-role --help' for help") } else if username == userEveryone { return errors.New("username not allowed") @@ -244,7 +244,7 @@ func execUserChangeRole(c *cli.Context) error { if err != nil { return err } - if _, err := manager.User(username); err == auth.ErrNotFound { + if _, err := manager.User(username); err == user.ErrNotFound { return fmt.Errorf("user %s does not exist", username) } if err := manager.ChangeRole(username, role); err != nil { @@ -266,7 +266,7 @@ func execUserList(c *cli.Context) error { return showUsers(c, manager, users) } -func createAuthManager(c *cli.Context) (auth.Manager, error) { +func createAuthManager(c *cli.Context) (user.Manager, error) { authFile := c.String("auth-file") authDefaultAccess := c.String("auth-default-access") if authFile == "" { @@ -278,7 +278,7 @@ func createAuthManager(c *cli.Context) (auth.Manager, error) { } authDefaultRead := authDefaultAccess == "read-write" || authDefaultAccess == "read-only" authDefaultWrite := authDefaultAccess == "read-write" || authDefaultAccess == "write-only" - return auth.NewSQLiteAuthManager(authFile, authDefaultRead, authDefaultWrite) + return user.NewSQLiteAuthManager(authFile, authDefaultRead, authDefaultWrite) } func readPasswordAndConfirm(c *cli.Context) (string, error) { diff --git a/server/errors.go b/server/errors.go index 5d5accd..d7a1cdc 100644 --- a/server/errors.go +++ b/server/errors.go @@ -54,6 +54,7 @@ var ( errHTTPBadRequestMatrixPushkeyBaseURLMismatch = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"} errHTTPBadRequestIconURLInvalid = &errHTTP{40021, http.StatusBadRequest, "invalid request: icon URL is invalid", "https://ntfy.sh/docs/publish/#icons"} errHTTPBadRequestSignupNotEnabled = &errHTTP{40022, http.StatusBadRequest, "invalid request: signup not enabled", "https://ntfy.sh/docs/config"} + errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} diff --git a/server/server.go b/server/server.go index c3105a4..7cbc77a 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "heckel.io/ntfy/user" "io" "net" "net/http" @@ -30,17 +31,17 @@ import ( "github.com/emersion/go-smtp" "github.com/gorilla/websocket" "golang.org/x/sync/errgroup" - "heckel.io/ntfy/auth" "heckel.io/ntfy/util" ) /* TODO + expire tokens + auto-extend tokens from UI use token auth in "SubscribeDialog" upload files based on user limit + database migration publishXHR + poll should pick current user, not from userManager - expire tokens - auto-refresh tokens from UI reserve topics purge accounts that were not logged into in X sync subscription display name @@ -55,7 +56,11 @@ import ( Polishing: aria-label for everything - + Tests: + - APIs + - CRUD tokens + - Expire tokens + - */ // Server is the main server, providing the UI and API for ntfy @@ -71,7 +76,7 @@ type Server struct { visitors map[string]*visitor // ip: or user: firebaseClient *firebaseClient messages int64 - auth auth.Manager + userManager user.Manager messageCache *messageCache fileCache *fileCache closeChan chan bool @@ -159,9 +164,9 @@ func New(conf *Config) (*Server, error) { return nil, err } } - var auther auth.Manager + var auther user.Manager if conf.AuthFile != "" { - auther, err = auth.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite) + auther, err = user.NewSQLiteAuthManager(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite) if err != nil { return nil, err } @@ -181,7 +186,7 @@ func New(conf *Config) (*Server, error) { firebaseClient: firebaseClient, smtpSender: mailer, topics: topics, - auth: auther, + userManager: auther, visitors: make(map[string]*visitor), }, nil } @@ -342,11 +347,13 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.handleAccountDelete(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath { return s.handleAccountPasswordChange(w, r, v) - } else if r.Method == http.MethodGet && r.URL.Path == accountTokenPath { - return s.handleAccountTokenGet(w, r, v) + } else if r.Method == http.MethodPost && r.URL.Path == accountTokenPath { + return s.handleAccountTokenIssue(w, r, v) + } else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath { + return s.handleAccountTokenExtend(w, r, v) } else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath { return s.handleAccountTokenDelete(w, r, v) - } else if r.Method == http.MethodPost && r.URL.Path == accountSettingsPath { + } else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath { return s.handleAccountSettingsChange(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath { return s.handleAccountSubscriptionAdd(w, r, v) @@ -557,7 +564,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } v.IncrMessages() if v.user != nil { - s.auth.EnqueueUpdateStats(v.user) + s.userManager.EnqueueStats(v.user) } s.mu.Lock() s.messages++ @@ -1122,7 +1129,7 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) { } func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error { - w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, DELETE") + w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE") w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible? return nil @@ -1192,6 +1199,11 @@ func (s *Server) updateStatsAndPrune() { s.mu.Unlock() log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) + // Delete expired user tokens + if err := s.userManager.RemoveExpiredTokens(); err != nil { + log.Warn("Error expiring user tokens: %s", err.Error()) + } + // Delete expired attachments if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 { olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration) @@ -1323,7 +1335,7 @@ func (s *Server) sendDelayedMessages() error { for _, m := range messages { var v *visitor if m.User != "" { - user, err := s.auth.User(m.User) + user, err := s.userManager.User(m.User) if err != nil { log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error()) continue @@ -1457,16 +1469,16 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc { } func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc { - return s.autorizeTopic(next, auth.PermissionWrite) + return s.autorizeTopic(next, user.PermissionWrite) } func (s *Server) authorizeTopicRead(next handleFunc) handleFunc { - return s.autorizeTopic(next, auth.PermissionRead) + return s.autorizeTopic(next, user.PermissionRead) } -func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc { +func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if s.auth == nil { + if s.userManager == nil { return next(w, r, v) } topics, _, err := s.topicsFromPath(r.URL.Path) @@ -1474,7 +1486,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc return err } for _, t := range topics { - if err := s.auth.Authorize(v.user, t.ID, perm); err != nil { + if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil { log.Info("unauthorized: %s", err.Error()) return errHTTPForbidden } @@ -1487,7 +1499,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc // Note that this function will always return a visitor, even if an error occurs. func (s *Server) visitor(r *http.Request) (v *visitor, err error) { ip := extractIPAddress(r, s.config.BehindProxy) - var user *auth.User // may stay nil if no auth header! + var user *user.User // may stay nil if no auth header! if user, err = s.authenticate(r); err != nil { log.Debug("authentication failed: %s", err.Error()) err = errHTTPUnauthorized // Always return visitor, even when error occurs! @@ -1505,7 +1517,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) { // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)). -func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) { +func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { value := r.Header.Get("Authorization") queryParam := readQueryParam(r, "authorization", "auth") if queryParam != "" { @@ -1524,21 +1536,21 @@ func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) { return s.authenticateBasicAuth(r, value) } -func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) { +func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) { r.Header.Set("Authorization", value) username, password, ok := r.BasicAuth() if !ok { return nil, errors.New("invalid basic auth") } - return s.auth.Authenticate(username, password) + return s.userManager.Authenticate(username, password) } -func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) { +func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) { token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer")) - return s.auth.AuthenticateToken(token) + return s.userManager.AuthenticateToken(token) } -func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor { +func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor { s.mu.Lock() defer s.mu.Unlock() v, exists := s.visitors[visitorID] @@ -1554,6 +1566,6 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor { return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil) } -func (s *Server) visitorFromUser(user *auth.User, ip netip.Addr) *visitor { +func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor { return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user) } diff --git a/server/server_account.go b/server/server_account.go index 7069ca9..1f9d849 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -3,13 +3,13 @@ package server import ( "encoding/json" "errors" - "heckel.io/ntfy/auth" + "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" ) func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { - admin := v.user != nil && v.user.Role == auth.RoleAdmin + admin := v.user != nil && v.user.Role == user.RoleAdmin if !admin { if !s.config.EnableSignup { return errHTTPBadRequestSignupNotEnabled @@ -26,13 +26,13 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * if err := json.NewDecoder(body).Decode(&newAccount); err != nil { return err } - if existingUser, _ := s.auth.User(newAccount.Username); existingUser != nil { + if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil { return errHTTPConflictUserExists } if v.accountLimiter != nil && !v.accountLimiter.Allow() { return errHTTPTooManyRequestsAccountCreateLimit } - if err := s.auth.AddUser(newAccount.Username, newAccount.Password, auth.RoleUser); err != nil { // TODO this should return a User + if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User return err } w.Header().Set("Content-Type", "application/json") @@ -84,23 +84,23 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis Code: v.user.Plan.Code, Upgradable: v.user.Plan.Upgradable, } - } else if v.user.Role == auth.RoleAdmin { + } else if v.user.Role == user.RoleAdmin { response.Plan = &apiAccountPlan{ - Code: string(auth.PlanUnlimited), + Code: string(user.PlanUnlimited), Upgradable: false, } } else { response.Plan = &apiAccountPlan{ - Code: string(auth.PlanDefault), + Code: string(user.PlanDefault), Upgradable: true, } } } else { - response.Username = auth.Everyone - response.Role = string(auth.RoleAnonymous) + response.Username = user.Everyone + response.Role = string(user.RoleAnonymous) response.Plan = &apiAccountPlan{ - Code: string(auth.PlanNone), + Code: string(user.PlanNone), Upgradable: true, } } @@ -114,7 +114,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * if v.user == nil { return errHTTPUnauthorized } - if err := s.auth.RemoveUser(v.user.Name); err != nil { + if err := s.userManager.RemoveUser(v.user.Name); err != nil { return err } w.Header().Set("Content-Type", "application/json") @@ -136,7 +136,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ if err := json.NewDecoder(body).Decode(&newPassword); err != nil { return err } - if err := s.auth.ChangePassword(v.user.Name, newPassword.Password); err != nil { + if err := s.userManager.ChangePassword(v.user.Name, newPassword.Password); err != nil { return err } w.Header().Set("Content-Type", "application/json") @@ -145,19 +145,43 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ return nil } -func (s *Server) handleAccountTokenGet(w http.ResponseWriter, r *http.Request, v *visitor) error { +func (s *Server) handleAccountTokenIssue(w http.ResponseWriter, r *http.Request, v *visitor) error { // TODO rate limit if v.user == nil { return errHTTPUnauthorized } - token, err := s.auth.CreateToken(v.user) + token, err := s.userManager.CreateToken(v.user) if err != nil { return err } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this response := &apiAccountTokenResponse{ - Token: token, + Token: token.Value, + Expires: token.Expires, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + +func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, r *http.Request, v *visitor) error { + // TODO rate limit + if v.user == nil { + return errHTTPUnauthorized + } else if v.user.Token == "" { + return errHTTPBadRequestNoTokenProvided + } + token, err := s.userManager.ExtendToken(v.user) + if err != nil { + return err + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this + response := &apiAccountTokenResponse{ + Token: token.Value, + Expires: token.Expires, } if err := json.NewEncoder(w).Encode(response); err != nil { return err @@ -170,7 +194,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, r *http.Request if v.user == nil || v.user.Token == "" { return errHTTPUnauthorized } - if err := s.auth.RemoveToken(v.user); err != nil { + if err := s.userManager.RemoveToken(v.user); err != nil { return err } w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this @@ -188,12 +212,12 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ return err } defer r.Body.Close() - var newPrefs auth.UserPrefs + var newPrefs user.Prefs if err := json.NewDecoder(body).Decode(&newPrefs); err != nil { return err } if v.user.Prefs == nil { - v.user.Prefs = &auth.UserPrefs{} + v.user.Prefs = &user.Prefs{} } prefs := v.user.Prefs if newPrefs.Language != "" { @@ -201,7 +225,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ } if newPrefs.Notification != nil { if prefs.Notification == nil { - prefs.Notification = &auth.UserNotificationPrefs{} + prefs.Notification = &user.NotificationPrefs{} } if newPrefs.Notification.DeleteAfter > 0 { prefs.Notification.DeleteAfter = newPrefs.Notification.DeleteAfter @@ -213,7 +237,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ prefs.Notification.MinPriority = newPrefs.Notification.MinPriority } } - return s.auth.ChangeSettings(v.user) + return s.userManager.ChangeSettings(v.user) } func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -227,12 +251,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req return err } defer r.Body.Close() - var newSubscription auth.UserSubscription + var newSubscription user.Subscription if err := json.NewDecoder(body).Decode(&newSubscription); err != nil { return err } if v.user.Prefs == nil { - v.user.Prefs = &auth.UserPrefs{} + v.user.Prefs = &user.Prefs{} } newSubscription.ID = "" // Client cannot set ID for _, subscription := range v.user.Prefs.Subscriptions { @@ -244,7 +268,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req if newSubscription.ID == "" { newSubscription.ID = util.RandomString(16) v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, &newSubscription) - if err := s.auth.ChangeSettings(v.user); err != nil { + if err := s.userManager.ChangeSettings(v.user); err != nil { return err } } @@ -268,7 +292,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { return nil } - newSubscriptions := make([]*auth.UserSubscription, 0) + newSubscriptions := make([]*user.Subscription, 0) for _, subscription := range v.user.Prefs.Subscriptions { if subscription.ID != subscriptionID { newSubscriptions = append(newSubscriptions, subscription) @@ -276,7 +300,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. } if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) { v.user.Prefs.Subscriptions = newSubscriptions - if err := s.auth.ChangeSettings(v.user); err != nil { + if err := s.userManager.ChangeSettings(v.user); err != nil { return err } } diff --git a/server/server_firebase.go b/server/server_firebase.go index 5124195..629aebc 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -8,8 +8,8 @@ import ( "firebase.google.com/go/v4/messaging" "fmt" "google.golang.org/api/option" - "heckel.io/ntfy/auth" "heckel.io/ntfy/log" + "heckel.io/ntfy/user" "heckel.io/ntfy/util" "strings" ) @@ -28,10 +28,10 @@ var ( // The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable. type firebaseClient struct { sender firebaseSender - auther auth.Manager + auther user.Manager } -func newFirebaseClient(sender firebaseSender, auther auth.Manager) *firebaseClient { +func newFirebaseClient(sender firebaseSender, auther user.Manager) *firebaseClient { return &firebaseClient{ sender: sender, auther: auther, @@ -112,7 +112,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error { // On Android, this will trigger the app to poll the topic and thereby displaying new messages. // - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded // to Firebase here. This is mainly for iOS to support self-hosted servers. -func toFirebaseMessage(m *message, auther auth.Manager) (*messaging.Message, error) { +func toFirebaseMessage(m *message, auther user.Manager) (*messaging.Message, error) { var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format var apnsConfig *messaging.APNSConfig switch m.Event { @@ -137,7 +137,7 @@ func toFirebaseMessage(m *message, auther auth.Manager) (*messaging.Message, err case messageEvent: allowForward := true if auther != nil { - allowForward = auther.Authorize(nil, m.Topic, auth.PermissionRead) == nil + allowForward = auther.Authorize(nil, m.Topic, user.PermissionRead) == nil } if allowForward { data = map[string]string{ diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index ba2ab1d..034511f 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -11,18 +11,17 @@ import ( "firebase.google.com/go/v4/messaging" "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" ) type testAuther struct { Allow bool } -func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) { +func (t testAuther) AuthenticateUser(_, _ string) (*user.User, error) { return nil, errors.New("not used") } -func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error { +func (t testAuther) Authorize(_ *user.User, _ string, _ user.Permission) error { if t.Allow { return nil } diff --git a/server/server_test.go b/server/server_test.go index 8f776a5..dc047e7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" "heckel.io/ntfy/util" ) @@ -626,8 +625,8 @@ func TestServer_Auth_Success_Admin(t *testing.T) { c.AuthFile = filepath.Join(t.TempDir(), "user.db") s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": basicAuth("phil:phil"), @@ -643,8 +642,8 @@ func TestServer_Auth_Success_User(t *testing.T) { c.AuthDefaultWrite = false s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ @@ -660,8 +659,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) { c.AuthDefaultWrite = false s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, manager.AllowAccess("ben", "anothertopic", true, true)) @@ -683,8 +682,8 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) { c.AuthDefaultWrite = false s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": basicAuth("phil:INVALID"), @@ -699,8 +698,8 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) { c.AuthDefaultWrite = false s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic! response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ @@ -716,10 +715,10 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) { c.AuthDefaultWrite = true // Open by default s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) - require.Nil(t, manager.AllowAccess(auth.Everyone, "private", false, false)) - require.Nil(t, manager.AllowAccess(auth.Everyone, "announcements", true, false)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("phil", "phil", user.RoleAdmin)) + require.Nil(t, manager.AllowAccess(user.Everyone, "private", false, false)) + require.Nil(t, manager.AllowAccess(user.Everyone, "announcements", true, false)) response := request(t, s, "PUT", "/mytopic", "test", nil) require.Equal(t, 200, response.Code) @@ -749,8 +748,8 @@ func TestServer_Auth_ViaQuery(t *testing.T) { c.AuthDefaultWrite = false s := newTestServer(t, c) - manager := s.auth.(auth.Manager) - require.Nil(t, manager.AddUser("ben", "some pass", auth.RoleAdmin)) + manager := s.userManager.(user.Manager) + require.Nil(t, manager.AddUser("ben", "some pass", user.RoleAdmin)) u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(basicAuth("ben:some pass")))) response := request(t, s, "GET", u, "", nil) diff --git a/server/types.go b/server/types.go index 2f14485..e6e5186 100644 --- a/server/types.go +++ b/server/types.go @@ -1,7 +1,7 @@ package server import ( - "heckel.io/ntfy/auth" + "heckel.io/ntfy/user" "net/http" "net/netip" "time" @@ -226,7 +226,8 @@ type apiAccountCreateRequest struct { } type apiAccountTokenResponse struct { - Token string `json:"token"` + Token string `json:"token"` + Expires int64 `json:"expires"` } type apiAccountPlan struct { @@ -252,12 +253,12 @@ type apiAccountStats struct { } type apiAccountSettingsResponse struct { - Username string `json:"username"` - Role string `json:"role,omitempty"` - Language string `json:"language,omitempty"` - Notification *auth.UserNotificationPrefs `json:"notification,omitempty"` - Subscriptions []*auth.UserSubscription `json:"subscriptions,omitempty"` - Plan *apiAccountPlan `json:"plan,omitempty"` - Limits *apiAccountLimits `json:"limits,omitempty"` - Stats *apiAccountStats `json:"stats,omitempty"` + Username string `json:"username"` + Role string `json:"role,omitempty"` + Language string `json:"language,omitempty"` + Notification *user.NotificationPrefs `json:"notification,omitempty"` + Subscriptions []*user.Subscription `json:"subscriptions,omitempty"` + Plan *apiAccountPlan `json:"plan,omitempty"` + Limits *apiAccountLimits `json:"limits,omitempty"` + Stats *apiAccountStats `json:"stats,omitempty"` } diff --git a/server/visitor.go b/server/visitor.go index add3273..9b53033 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,7 +2,7 @@ package server import ( "errors" - "heckel.io/ntfy/auth" + "heckel.io/ntfy/user" "net/netip" "sync" "time" @@ -27,7 +27,7 @@ type visitor struct { config *Config messageCache *messageCache ip netip.Addr - user *auth.User + user *user.User messages int64 // Number of messages sent emails int64 // Number of emails sent requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) @@ -54,7 +54,7 @@ type visitorStats struct { AttachmentFileSizeLimit int64 } -func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *auth.User) *visitor { +func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr, user *user.User) *visitor { var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter var messages, emails int64 if user != nil { @@ -171,7 +171,7 @@ func (v *visitor) Stats() (*visitorStats, error) { emails := v.emails v.mu.Unlock() stats := &visitorStats{} - if v.user != nil && v.user.Role == auth.RoleAdmin { + if v.user != nil && v.user.Role == user.RoleAdmin { stats.Basis = "role" stats.MessagesLimit = 0 stats.EmailsLimit = 0 diff --git a/auth/auth.go b/user/manager.go similarity index 91% rename from auth/auth.go rename to user/manager.go index 1c97830..a8de6e5 100644 --- a/auth/auth.go +++ b/user/manager.go @@ -1,5 +1,5 @@ // Package auth deals with authentication and authorization against topics -package auth +package user import ( "errors" @@ -14,10 +14,12 @@ type Manager interface { Authenticate(username, password string) (*User, error) AuthenticateToken(token string) (*User, error) - CreateToken(user *User) (string, error) + CreateToken(user *User) (*Token, error) + ExtendToken(user *User) (*Token, error) RemoveToken(user *User) error + RemoveExpiredTokens() error ChangeSettings(user *User) error - EnqueueUpdateStats(user *User) + EnqueueStats(user *User) // Authorize returns nil if the given user has access to the given topic using the desired // permission. The user param may be nil to signal an anonymous user. @@ -64,15 +66,20 @@ type User struct { Token string // Only set if token was used to log in Role Role Grants []Grant - Prefs *UserPrefs + Prefs *Prefs Plan *Plan Stats *Stats } -type UserPrefs struct { - Language string `json:"language,omitempty"` - Notification *UserNotificationPrefs `json:"notification,omitempty"` - Subscriptions []*UserSubscription `json:"subscriptions,omitempty"` +type Token struct { + Value string + Expires int64 +} + +type Prefs struct { + Language string `json:"language,omitempty"` + Notification *NotificationPrefs `json:"notification,omitempty"` + Subscriptions []*Subscription `json:"subscriptions,omitempty"` } type PlanCode string @@ -92,13 +99,13 @@ type Plan struct { AttachmentTotalSizeLimit int64 `json:"attachment_total_size_limit"` } -type UserSubscription struct { +type Subscription struct { ID string `json:"id"` BaseURL string `json:"base_url"` Topic string `json:"topic"` } -type UserNotificationPrefs struct { +type NotificationPrefs struct { Sound string `json:"sound,omitempty"` MinPriority int `json:"min_priority,omitempty"` DeleteAfter int `json:"delete_after,omitempty"` diff --git a/auth/auth_sqlite.go b/user/manager_sqlite.go similarity index 76% rename from auth/auth_sqlite.go rename to user/manager_sqlite.go index 4c1559f..70b3dc8 100644 --- a/auth/auth_sqlite.go +++ b/user/manager_sqlite.go @@ -1,4 +1,4 @@ -package auth +package user import ( "database/sql" @@ -15,10 +15,11 @@ import ( ) const ( - tokenLength = 32 - bcryptCost = 10 - intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost - statsWriterInterval = 10 * time.Second + tokenLength = 32 + bcryptCost = 10 + intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost + userStatsQueueWriterInterval = 33 * time.Second + userTokenExpiryDuration = 72 * time.Hour ) // Manager-related queries @@ -106,9 +107,11 @@ const ( deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` - insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` - deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` - deleteUserTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?)` + insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` + updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` + deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` + deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` + deleteUserTokensQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?)` ) // Schema management queries @@ -118,20 +121,20 @@ const ( selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` ) -// SQLiteAuthManager is an implementation of Manager. It stores users and access control list +// SQLiteManager is an implementation of Manager. It stores users and access control list // in a SQLite database. -type SQLiteAuthManager struct { +type SQLiteManager struct { db *sql.DB defaultRead bool defaultWrite bool - statsQueue map[string]*Stats // Username -> Stats + statsQueue map[string]*User // Username -> User, for "unimportant" user updates mu sync.Mutex } -var _ Manager = (*SQLiteAuthManager)(nil) +var _ Manager = (*SQLiteManager)(nil) -// NewSQLiteAuthManager creates a new SQLiteAuthManager instance -func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQLiteAuthManager, error) { +// NewSQLiteAuthManager creates a new SQLiteManager instance +func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQLiteManager, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -139,20 +142,20 @@ func NewSQLiteAuthManager(filename string, defaultRead, defaultWrite bool) (*SQL if err := setupAuthDB(db); err != nil { return nil, err } - manager := &SQLiteAuthManager{ + manager := &SQLiteManager{ db: db, defaultRead: defaultRead, defaultWrite: defaultWrite, - statsQueue: make(map[string]*Stats), + statsQueue: make(map[string]*User), } - go manager.statsWriter() + go manager.userStatsQueueWriter() return manager, nil } // Authenticate checks username and password and returns a user if correct. The method // returns in constant-ish time, regardless of whether the user exists or the password is // correct or incorrect. -func (a *SQLiteAuthManager) Authenticate(username, password string) (*User, error) { +func (a *SQLiteManager) Authenticate(username, password string) (*User, error) { if username == Everyone { return nil, ErrUnauthenticated } @@ -168,7 +171,7 @@ func (a *SQLiteAuthManager) Authenticate(username, password string) (*User, erro return user, nil } -func (a *SQLiteAuthManager) AuthenticateToken(token string) (*User, error) { +func (a *SQLiteManager) AuthenticateToken(token string) (*User, error) { user, err := a.userByToken(token) if err != nil { return nil, ErrUnauthenticated @@ -177,16 +180,30 @@ func (a *SQLiteAuthManager) AuthenticateToken(token string) (*User, error) { return user, nil } -func (a *SQLiteAuthManager) CreateToken(user *User) (string, error) { +func (a *SQLiteManager) CreateToken(user *User) (*Token, error) { token := util.RandomString(tokenLength) - expires := 1 // FIXME - if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil { - return "", err + expires := time.Now().Add(userTokenExpiryDuration) + if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { + return nil, err } - return token, nil + return &Token{ + Value: token, + Expires: expires.Unix(), + }, nil } -func (a *SQLiteAuthManager) RemoveToken(user *User) error { +func (a *SQLiteManager) ExtendToken(user *User) (*Token, error) { + newExpires := time.Now().Add(userTokenExpiryDuration) + if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil { + return nil, err + } + return &Token{ + Value: user.Token, + Expires: newExpires.Unix(), + }, nil +} + +func (a *SQLiteManager) RemoveToken(user *User) error { if user.Token == "" { return ErrUnauthorized } @@ -196,7 +213,14 @@ func (a *SQLiteAuthManager) RemoveToken(user *User) error { return nil } -func (a *SQLiteAuthManager) ChangeSettings(user *User) error { +func (a *SQLiteManager) RemoveExpiredTokens() error { + if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil { + return err + } + return nil +} + +func (a *SQLiteManager) ChangeSettings(user *User) error { settings, err := json.Marshal(user.Prefs) if err != nil { return err @@ -207,33 +231,40 @@ func (a *SQLiteAuthManager) ChangeSettings(user *User) error { return nil } -func (a *SQLiteAuthManager) EnqueueUpdateStats(user *User) { +func (a *SQLiteManager) EnqueueStats(user *User) { a.mu.Lock() defer a.mu.Unlock() - a.statsQueue[user.Name] = user.Stats + a.statsQueue[user.Name] = user } -func (a *SQLiteAuthManager) statsWriter() { - ticker := time.NewTicker(statsWriterInterval) +func (a *SQLiteManager) userStatsQueueWriter() { + ticker := time.NewTicker(userStatsQueueWriterInterval) for range ticker.C { - if err := a.writeStats(); err != nil { - log.Warn("UserManager: Writing user stats failed: %s", err.Error()) + if err := a.writeUserStatsQueue(); err != nil { + log.Warn("UserManager: Writing user stats queue failed: %s", err.Error()) } } } -func (a *SQLiteAuthManager) writeStats() error { +func (a *SQLiteManager) writeUserStatsQueue() error { + a.mu.Lock() + if len(a.statsQueue) == 0 { + a.mu.Unlock() + log.Trace("UserManager: No user stats updates to commit") + return nil + } + statsQueue := a.statsQueue + a.statsQueue = make(map[string]*User) + a.mu.Unlock() tx, err := a.db.Begin() if err != nil { return err } defer tx.Rollback() - a.mu.Lock() - statsQueue := a.statsQueue - a.statsQueue = make(map[string]*Stats) - a.mu.Unlock() - for username, stats := range statsQueue { - if _, err := tx.Exec(updateUserStatsQuery, stats.Messages, stats.Emails, username); err != nil { + log.Debug("UserManager: Writing user stats queue for %d user(s)", len(statsQueue)) + for username, u := range statsQueue { + log.Trace("UserManager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails) + if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil { return err } } @@ -242,7 +273,7 @@ func (a *SQLiteAuthManager) writeStats() error { // Authorize returns nil if the given user has access to the given topic using the desired // permission. The user param may be nil to signal an anonymous user. -func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission) error { +func (a *SQLiteManager) Authorize(user *User, topic string, perm Permission) error { if user != nil && user.Role == RoleAdmin { return nil // Admin can do everything } @@ -270,7 +301,7 @@ func (a *SQLiteAuthManager) Authorize(user *User, topic string, perm Permission) return a.resolvePerms(read, write, perm) } -func (a *SQLiteAuthManager) resolvePerms(read, write bool, perm Permission) error { +func (a *SQLiteManager) resolvePerms(read, write bool, perm Permission) error { if perm == PermissionRead && read { return nil } else if perm == PermissionWrite && write { @@ -281,7 +312,7 @@ func (a *SQLiteAuthManager) resolvePerms(read, write bool, perm Permission) erro // AddUser adds a user with the given username, password and role. The password should be hashed // before it is stored in a persistence layer. -func (a *SQLiteAuthManager) AddUser(username, password string, role Role) error { +func (a *SQLiteManager) AddUser(username, password string, role Role) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } @@ -297,14 +328,14 @@ func (a *SQLiteAuthManager) AddUser(username, password string, role Role) error // RemoveUser deletes the user with the given username. The function returns nil on success, even // if the user did not exist in the first place. -func (a *SQLiteAuthManager) RemoveUser(username string) error { +func (a *SQLiteManager) RemoveUser(username string) error { if !AllowedUsername(username) { return ErrInvalidArgument } if _, err := a.db.Exec(deleteUserAccessQuery, username); err != nil { return err } - if _, err := a.db.Exec(deleteUserTokenQuery, username); err != nil { + if _, err := a.db.Exec(deleteUserTokensQuery, username); err != nil { return err } if _, err := a.db.Exec(deleteUserQuery, username); err != nil { @@ -314,7 +345,7 @@ func (a *SQLiteAuthManager) RemoveUser(username string) error { } // Users returns a list of users. It always also returns the Everyone user ("*"). -func (a *SQLiteAuthManager) Users() ([]*User, error) { +func (a *SQLiteManager) Users() ([]*User, error) { rows, err := a.db.Query(selectUsernamesQuery) if err != nil { return nil, err @@ -349,7 +380,7 @@ func (a *SQLiteAuthManager) Users() ([]*User, error) { // User returns the user with the given username if it exists, or ErrNotFound otherwise. // You may also pass Everyone to retrieve the anonymous user and its Grant list. -func (a *SQLiteAuthManager) User(username string) (*User, error) { +func (a *SQLiteManager) User(username string) (*User, error) { if username == Everyone { return a.everyoneUser() } @@ -360,7 +391,7 @@ func (a *SQLiteAuthManager) User(username string) (*User, error) { return a.readUser(rows) } -func (a *SQLiteAuthManager) userByToken(token string) (*User, error) { +func (a *SQLiteManager) userByToken(token string) (*User, error) { rows, err := a.db.Query(selectUserByTokenQuery, token) if err != nil { return nil, err @@ -368,7 +399,7 @@ func (a *SQLiteAuthManager) userByToken(token string) (*User, error) { return a.readUser(rows) } -func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) { +func (a *SQLiteManager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var username, hash, role string var settings, planCode sql.NullString @@ -397,7 +428,7 @@ func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) { }, } if settings.Valid { - user.Prefs = &UserPrefs{} + user.Prefs = &Prefs{} if err := json.Unmarshal([]byte(settings.String), user.Prefs); err != nil { return nil, err } @@ -415,7 +446,7 @@ func (a *SQLiteAuthManager) readUser(rows *sql.Rows) (*User, error) { return user, nil } -func (a *SQLiteAuthManager) everyoneUser() (*User, error) { +func (a *SQLiteManager) everyoneUser() (*User, error) { grants, err := a.readGrants(Everyone) if err != nil { return nil, err @@ -428,7 +459,7 @@ func (a *SQLiteAuthManager) everyoneUser() (*User, error) { }, nil } -func (a *SQLiteAuthManager) readGrants(username string) ([]Grant, error) { +func (a *SQLiteManager) readGrants(username string) ([]Grant, error) { rows, err := a.db.Query(selectUserAccessQuery, username) if err != nil { return nil, err @@ -453,7 +484,7 @@ func (a *SQLiteAuthManager) readGrants(username string) ([]Grant, error) { } // ChangePassword changes a user's password -func (a *SQLiteAuthManager) ChangePassword(username, password string) error { +func (a *SQLiteManager) ChangePassword(username, password string) error { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) if err != nil { return err @@ -466,7 +497,7 @@ func (a *SQLiteAuthManager) ChangePassword(username, password string) error { // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin, // all existing access control entries (Grant) are removed, since they are no longer needed. -func (a *SQLiteAuthManager) ChangeRole(username string, role Role) error { +func (a *SQLiteManager) ChangeRole(username string, role Role) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } @@ -483,7 +514,7 @@ func (a *SQLiteAuthManager) ChangeRole(username string, role Role) error { // AllowAccess adds or updates an entry in th access control list for a specific user. It controls // read/write access to a topic. The parameter topicPattern may include wildcards (*). -func (a *SQLiteAuthManager) AllowAccess(username string, topicPattern string, read bool, write bool) error { +func (a *SQLiteManager) AllowAccess(username string, topicPattern string, read bool, write bool) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } @@ -495,7 +526,7 @@ func (a *SQLiteAuthManager) AllowAccess(username string, topicPattern string, re // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). -func (a *SQLiteAuthManager) ResetAccess(username string, topicPattern string) error { +func (a *SQLiteManager) ResetAccess(username string, topicPattern string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { @@ -513,7 +544,7 @@ func (a *SQLiteAuthManager) ResetAccess(username string, topicPattern string) er } // DefaultAccess returns the default read/write access if no access control entry matches -func (a *SQLiteAuthManager) DefaultAccess() (read bool, write bool) { +func (a *SQLiteManager) DefaultAccess() (read bool, write bool) { return a.defaultRead, a.defaultWrite } diff --git a/auth/auth_sqlite_test.go b/user/manager_sqlite_test.go similarity index 50% rename from auth/auth_sqlite_test.go rename to user/manager_sqlite_test.go index c13917a..40af95f 100644 --- a/auth/auth_sqlite_test.go +++ b/user/manager_sqlite_test.go @@ -1,8 +1,7 @@ -package auth_test +package user_test import ( "github.com/stretchr/testify/require" - "heckel.io/ntfy/auth" "path/filepath" "strings" "testing" @@ -13,29 +12,29 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) { a := newTestAuth(t, false, false) - require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) - require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) + require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin)) + require.Nil(t, a.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair! - require.Nil(t, a.AllowAccess(auth.Everyone, "announcements", true, false)) - require.Nil(t, a.AllowAccess(auth.Everyone, "everyonewrite", true, true)) - require.Nil(t, a.AllowAccess(auth.Everyone, "up*", false, true)) // Everyone can write to /up* + require.Nil(t, a.AllowAccess(user.Everyone, "announcements", true, false)) + require.Nil(t, a.AllowAccess(user.Everyone, "everyonewrite", true, true)) + require.Nil(t, a.AllowAccess(user.Everyone, "up*", false, true)) // Everyone can write to /up* phil, err := a.Authenticate("phil", "phil") require.Nil(t, err) require.Equal(t, "phil", phil.Name) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) - require.Equal(t, auth.RoleAdmin, phil.Role) - require.Equal(t, []auth.Grant{}, phil.Grants) + require.Equal(t, user.RoleAdmin, phil.Role) + require.Equal(t, []user.Grant{}, phil.Grants) ben, err := a.Authenticate("ben", "ben") require.Nil(t, err) require.Equal(t, "ben", ben.Name) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) - require.Equal(t, auth.RoleUser, ben.Role) - require.Equal(t, []auth.Grant{ + require.Equal(t, user.RoleUser, ben.Role) + require.Equal(t, []user.Grant{ {"mytopic", true, true}, {"readme", true, false}, {"writeme", false, true}, @@ -44,62 +43,62 @@ func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) { notben, err := a.Authenticate("ben", "this is wrong") require.Nil(t, notben) - require.Equal(t, auth.ErrUnauthenticated, err) + require.Equal(t, user.ErrUnauthenticated, err) // Admin can do everything - require.Nil(t, a.Authorize(phil, "sometopic", auth.PermissionWrite)) - require.Nil(t, a.Authorize(phil, "mytopic", auth.PermissionRead)) - require.Nil(t, a.Authorize(phil, "readme", auth.PermissionWrite)) - require.Nil(t, a.Authorize(phil, "writeme", auth.PermissionWrite)) - require.Nil(t, a.Authorize(phil, "announcements", auth.PermissionWrite)) - require.Nil(t, a.Authorize(phil, "everyonewrite", auth.PermissionWrite)) + require.Nil(t, a.Authorize(phil, "sometopic", user.PermissionWrite)) + require.Nil(t, a.Authorize(phil, "mytopic", user.PermissionRead)) + require.Nil(t, a.Authorize(phil, "readme", user.PermissionWrite)) + require.Nil(t, a.Authorize(phil, "writeme", user.PermissionWrite)) + require.Nil(t, a.Authorize(phil, "announcements", user.PermissionWrite)) + require.Nil(t, a.Authorize(phil, "everyonewrite", user.PermissionWrite)) // User cannot do everything - require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionWrite)) - require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionRead)) - require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "readme", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "writeme", auth.PermissionRead)) - require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) - require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "everyonewrite", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "everyonewrite", auth.PermissionWrite)) - require.Nil(t, a.Authorize(ben, "announcements", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "announcements", auth.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead)) + require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "readme", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "writeme", user.PermissionRead)) + require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "everyonewrite", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "everyonewrite", user.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "announcements", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "announcements", user.PermissionWrite)) // Everyone else can do barely anything - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "mytopic", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "mytopic", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "readme", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "readme", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "writeme", auth.PermissionRead)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "writeme", auth.PermissionWrite)) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(nil, "announcements", auth.PermissionWrite)) - require.Nil(t, a.Authorize(nil, "announcements", auth.PermissionRead)) - require.Nil(t, a.Authorize(nil, "everyonewrite", auth.PermissionRead)) - require.Nil(t, a.Authorize(nil, "everyonewrite", auth.PermissionWrite)) - require.Nil(t, a.Authorize(nil, "up1234", auth.PermissionWrite)) // Wildcard permission - require.Nil(t, a.Authorize(nil, "up5678", auth.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "sometopicnotinthelist", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "mytopic", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "mytopic", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "readme", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "readme", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "writeme", user.PermissionRead)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "writeme", user.PermissionWrite)) + require.Equal(t, user.ErrUnauthorized, a.Authorize(nil, "announcements", user.PermissionWrite)) + require.Nil(t, a.Authorize(nil, "announcements", user.PermissionRead)) + require.Nil(t, a.Authorize(nil, "everyonewrite", user.PermissionRead)) + require.Nil(t, a.Authorize(nil, "everyonewrite", user.PermissionWrite)) + require.Nil(t, a.Authorize(nil, "up1234", user.PermissionWrite)) // Wildcard permission + require.Nil(t, a.Authorize(nil, "up5678", user.PermissionWrite)) } func TestSQLiteAuth_AddUser_Invalid(t *testing.T) { a := newTestAuth(t, false, false) - require.Equal(t, auth.ErrInvalidArgument, a.AddUser(" invalid ", "pass", auth.RoleAdmin)) - require.Equal(t, auth.ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role")) + require.Equal(t, user.ErrInvalidArgument, a.AddUser(" invalid ", "pass", user.RoleAdmin)) + require.Equal(t, user.ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role")) } func TestSQLiteAuth_AddUser_Timing(t *testing.T) { a := newTestAuth(t, false, false) start := time.Now().UnixMilli() - require.Nil(t, a.AddUser("user", "pass", auth.RoleAdmin)) + require.Nil(t, a.AddUser("user", "pass", user.RoleAdmin)) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) } func TestSQLiteAuth_Authenticate_Timing(t *testing.T) { a := newTestAuth(t, false, false) - require.Nil(t, a.AddUser("user", "pass", auth.RoleAdmin)) + require.Nil(t, a.AddUser("user", "pass", user.RoleAdmin)) // Timing a correct attempt start := time.Now().UnixMilli() @@ -110,53 +109,53 @@ func TestSQLiteAuth_Authenticate_Timing(t *testing.T) { // Timing an incorrect attempt start = time.Now().UnixMilli() _, err = a.Authenticate("user", "INCORRECT") - require.Equal(t, auth.ErrUnauthenticated, err) + require.Equal(t, user.ErrUnauthenticated, err) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) // Timing a non-existing user attempt start = time.Now().UnixMilli() _, err = a.Authenticate("DOES-NOT-EXIST", "hithere") - require.Equal(t, auth.ErrUnauthenticated, err) + require.Equal(t, user.ErrUnauthenticated, err) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) } func TestSQLiteAuth_UserManagement(t *testing.T) { a := newTestAuth(t, false, false) - require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) - require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) + require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin)) + require.Nil(t, a.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "everyonewrite", false, false)) // How unfair! - require.Nil(t, a.AllowAccess(auth.Everyone, "announcements", true, false)) - require.Nil(t, a.AllowAccess(auth.Everyone, "everyonewrite", true, true)) + require.Nil(t, a.AllowAccess(user.Everyone, "announcements", true, false)) + require.Nil(t, a.AllowAccess(user.Everyone, "everyonewrite", true, true)) // Query user details phil, err := a.User("phil") require.Nil(t, err) require.Equal(t, "phil", phil.Name) require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) - require.Equal(t, auth.RoleAdmin, phil.Role) - require.Equal(t, []auth.Grant{}, phil.Grants) + require.Equal(t, user.RoleAdmin, phil.Role) + require.Equal(t, []user.Grant{}, phil.Grants) ben, err := a.User("ben") require.Nil(t, err) require.Equal(t, "ben", ben.Name) require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) - require.Equal(t, auth.RoleUser, ben.Role) - require.Equal(t, []auth.Grant{ + require.Equal(t, user.RoleUser, ben.Role) + require.Equal(t, []user.Grant{ {"mytopic", true, true}, {"readme", true, false}, {"writeme", false, true}, {"everyonewrite", false, false}, }, ben.Grants) - everyone, err := a.User(auth.Everyone) + everyone, err := a.User(user.Everyone) require.Nil(t, err) require.Equal(t, "*", everyone.Name) require.Equal(t, "", everyone.Hash) - require.Equal(t, auth.RoleAnonymous, everyone.Role) - require.Equal(t, []auth.Grant{ + require.Equal(t, user.RoleAnonymous, everyone.Role) + require.Equal(t, []user.Grant{ {"announcements", true, false}, {"everyonewrite", true, true}, }, everyone.Grants) @@ -165,22 +164,22 @@ func TestSQLiteAuth_UserManagement(t *testing.T) { require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) - require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionRead)) - require.Nil(t, a.Authorize(ben, "mytopic", auth.PermissionWrite)) - require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) - require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead)) + require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionWrite)) + require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead)) + require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite)) // Revoke access for "ben" to "mytopic", then check again require.Nil(t, a.ResetAccess("ben", "mytopic")) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "mytopic", auth.PermissionWrite)) // Revoked - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "mytopic", auth.PermissionRead)) // Revoked - require.Nil(t, a.Authorize(ben, "readme", auth.PermissionRead)) // Unchanged - require.Nil(t, a.Authorize(ben, "writeme", auth.PermissionWrite)) // Unchanged + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "mytopic", user.PermissionWrite)) // Revoked + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "mytopic", user.PermissionRead)) // Revoked + require.Nil(t, a.Authorize(ben, "readme", user.PermissionRead)) // Unchanged + require.Nil(t, a.Authorize(ben, "writeme", user.PermissionWrite)) // Unchanged // Revoke rest of the access require.Nil(t, a.ResetAccess("ben", "")) - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "readme", auth.PermissionRead)) // Revoked - require.Equal(t, auth.ErrUnauthorized, a.Authorize(ben, "wrtiteme", auth.PermissionWrite)) // Revoked + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "readme", user.PermissionRead)) // Revoked + require.Equal(t, user.ErrUnauthorized, a.Authorize(ben, "wrtiteme", user.PermissionWrite)) // Revoked // User list users, err := a.Users() @@ -193,7 +192,7 @@ func TestSQLiteAuth_UserManagement(t *testing.T) { // Remove user require.Nil(t, a.RemoveUser("ben")) _, err = a.User("ben") - require.Equal(t, auth.ErrNotFound, err) + require.Equal(t, user.ErrNotFound, err) users, err = a.Users() require.Nil(t, err) @@ -204,40 +203,40 @@ func TestSQLiteAuth_UserManagement(t *testing.T) { func TestSQLiteAuth_ChangePassword(t *testing.T) { a := newTestAuth(t, false, false) - require.Nil(t, a.AddUser("phil", "phil", auth.RoleAdmin)) + require.Nil(t, a.AddUser("phil", "phil", user.RoleAdmin)) _, err := a.Authenticate("phil", "phil") require.Nil(t, err) require.Nil(t, a.ChangePassword("phil", "newpass")) _, err = a.Authenticate("phil", "phil") - require.Equal(t, auth.ErrUnauthenticated, err) + require.Equal(t, user.ErrUnauthenticated, err) _, err = a.Authenticate("phil", "newpass") require.Nil(t, err) } func TestSQLiteAuth_ChangeRole(t *testing.T) { a := newTestAuth(t, false, false) - require.Nil(t, a.AddUser("ben", "ben", auth.RoleUser)) + require.Nil(t, a.AddUser("ben", "ben", user.RoleUser)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "readme", true, false)) ben, err := a.User("ben") require.Nil(t, err) - require.Equal(t, auth.RoleUser, ben.Role) + require.Equal(t, user.RoleUser, ben.Role) require.Equal(t, 2, len(ben.Grants)) - require.Nil(t, a.ChangeRole("ben", auth.RoleAdmin)) + require.Nil(t, a.ChangeRole("ben", user.RoleAdmin)) ben, err = a.User("ben") require.Nil(t, err) - require.Equal(t, auth.RoleAdmin, ben.Role) + require.Equal(t, user.RoleAdmin, ben.Role) require.Equal(t, 0, len(ben.Grants)) } -func newTestAuth(t *testing.T, defaultRead, defaultWrite bool) *auth.SQLiteAuthManager { +func newTestAuth(t *testing.T, defaultRead, defaultWrite bool) *user.SQLiteAuthManager { filename := filepath.Join(t.TempDir(), "user.db") - a, err := auth.NewSQLiteAuthManager(filename, defaultRead, defaultWrite) + a, err := user.NewSQLiteAuthManager(filename, defaultRead, defaultWrite) require.Nil(t, err) return a } diff --git a/web/src/app/Api.js b/web/src/app/Api.js index 3d79bac..219861f 100644 --- a/web/src/app/Api.js +++ b/web/src/app/Api.js @@ -1,14 +1,18 @@ import { + accountPasswordUrl, + accountSettingsUrl, + accountSubscriptionSingleUrl, + accountSubscriptionUrl, + accountTokenUrl, + accountUrl, fetchLinesIterator, - maybeWithBasicAuth, maybeWithBearerAuth, + maybeWithBasicAuth, + maybeWithBearerAuth, topicShortUrl, topicUrl, topicUrlAuth, topicUrlJsonPoll, - topicUrlJsonPollWithSince, - accountSettingsUrl, - accountTokenUrl, - userStatsUrl, accountSubscriptionUrl, accountSubscriptionSingleUrl, accountUrl, accountPasswordUrl + topicUrlJsonPollWithSince } from "./utils"; import userManager from "./UserManager"; @@ -74,7 +78,7 @@ class Api { xhr.setRequestHeader(key, value); } xhr.upload.addEventListener("progress", onProgress); - xhr.addEventListener('readystatechange', (ev) => { + xhr.addEventListener('readystatechange', () => { if (xhr.readyState === 4 && xhr.status >= 200 && xhr.status <= 299) { console.log(`[Api] Publish successful (HTTP ${xhr.status})`, xhr.response); resolve(xhr.response); @@ -123,6 +127,7 @@ class Api { const url = accountTokenUrl(baseUrl); console.log(`[Api] Checking auth for ${url}`); const response = await fetch(url, { + method: "POST", headers: maybeWithBasicAuth({}, user) }); if (response.status === 401 || response.status === 403) { @@ -218,12 +223,26 @@ class Api { } } + async extendToken(baseUrl, token) { + const url = accountTokenUrl(baseUrl); + console.log(`[Api] Extending user access token ${url}`); + const response = await fetch(url, { + method: "PATCH", + headers: maybeWithBearerAuth({}, token) + }); + if (response.status === 401 || response.status === 403) { + throw new UnauthorizedError(); + } else if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + } + async updateAccountSettings(baseUrl, token, payload) { const url = accountSettingsUrl(baseUrl); const body = JSON.stringify(payload); console.log(`[Api] Updating user account ${url}: ${body}`); const response = await fetch(url, { - method: "POST", + method: "PATCH", headers: maybeWithBearerAuth({}, token), body: body });