Fix all the tests

This commit is contained in:
binwiederhier 2022-12-28 13:28:28 -05:00
parent d9722a9825
commit a2e474c375
4 changed files with 44 additions and 35 deletions

View file

@ -40,8 +40,6 @@ import (
reserve topics reserve topics
purge accounts that were not logged into in X purge accounts that were not logged into in X
reset daily limits for users reset daily limits for users
"user list" shows * twice
"ntfy access everyone user4topic <bla>" twice -> UNIQUE constraint error
Account usage not updated "in real time" Account usage not updated "in real time"
Attachment expiration based on plan Attachment expiration based on plan
Plan: Keep 10000 messages or keep X days? Plan: Keep 10000 messages or keep X days?
@ -66,6 +64,7 @@ import (
- Expire tokens - Expire tokens
- userManager can be nil - userManager can be nil
- visitor with/without user - visitor with/without user
- userManager.<NEWSTUFF>
*/ */
// Server is the main server, providing the UI and API for ntfy // Server is the main server, providing the UI and API for ntfy

View file

@ -93,16 +93,30 @@ const (
// Manager-related queries // Manager-related queries
const ( const (
insertUserQuery = `INSERT INTO user (user, pass, role) VALUES (?, ?, ?)` insertUserQuery = `INSERT INTO user (user, pass, role) VALUES (?, ?, ?)`
selectUsernamesQuery = `SELECT user FROM user ORDER BY role, user` selectUsernamesQuery = `
SELECT user
FROM user
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user
`
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserSettingsQuery = `UPDATE user SET settings = ? WHERE user = ?` updateUserSettingsQuery = `UPDATE user SET settings = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET messages = ?, emails = ? WHERE user = ?` updateUserStatsQuery = `UPDATE user SET messages = ?, emails = ? WHERE user = ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?` deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `INSERT INTO user_access (user_id, topic, read, write) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)` upsertUserAccessQuery = `
selectUserAccessQuery = `SELECT topic, read, write FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` INSERT INTO user_access (user_id, topic, read, write)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write
`
selectUserAccessQuery = `SELECT topic, read, write FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) ORDER BY write DESC, read DESC, topic`
deleteAllAccessQuery = `DELETE FROM user_access` deleteAllAccessQuery = `DELETE FROM user_access`
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
@ -152,7 +166,7 @@ func NewManager(filename string, defaultRead, defaultWrite bool) (*Manager, erro
return manager, nil return manager, nil
} }
// Authenticate checks username and password and returns a user if correct. The method // Authenticate checks username and password and returns a User if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is // returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect. // correct or incorrect.
func (a *Manager) Authenticate(username, password string) (*User, error) { func (a *Manager) Authenticate(username, password string) (*User, error) {
@ -171,6 +185,8 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
return user, nil return user, nil
} }
// AuthenticateToken checks if the token exists and returns the associated User if it does.
// The method sets the User.Token value to the token that was used for authentication.
func (a *Manager) AuthenticateToken(token string) (*User, error) { func (a *Manager) AuthenticateToken(token string) (*User, error) {
user, err := a.userByToken(token) user, err := a.userByToken(token)
if err != nil { if err != nil {
@ -180,9 +196,10 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
return user, nil return user, nil
} }
// CreateToken generates a random token for the given user and returns it. The token expires
// after a fixed duration unless ExtendToken is called.
func (a *Manager) CreateToken(user *User) (*Token, error) { func (a *Manager) CreateToken(user *User) (*Token, error) {
token := util.RandomString(tokenLength) token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration)
expires := time.Now().Add(userTokenExpiryDuration)
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
return nil, err return nil, err
} }
@ -192,6 +209,7 @@ func (a *Manager) CreateToken(user *User) (*Token, error) {
}, nil }, nil
} }
// ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
func (a *Manager) ExtendToken(user *User) (*Token, error) { func (a *Manager) ExtendToken(user *User) (*Token, error) {
newExpires := time.Now().Add(userTokenExpiryDuration) newExpires := time.Now().Add(userTokenExpiryDuration)
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil { if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
@ -203,6 +221,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
}, nil }, nil
} }
// RemoveToken deletes the token defined in User.Token
func (a *Manager) RemoveToken(user *User) error { func (a *Manager) RemoveToken(user *User) error {
if user.Token == "" { if user.Token == "" {
return ErrUnauthorized return ErrUnauthorized
@ -213,6 +232,7 @@ func (a *Manager) RemoveToken(user *User) error {
return nil return nil
} }
// RemoveExpiredTokens deletes all expired tokens from the database
func (a *Manager) RemoveExpiredTokens() error { func (a *Manager) RemoveExpiredTokens() error {
if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil { if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
return err return err
@ -370,20 +390,23 @@ func (a *Manager) Users() ([]*User, error) {
} }
users = append(users, user) users = append(users, user)
} }
everyone, err := a.everyoneUser() /*sort.Slice(users, func(i, j int) bool {
if err != nil { if users[i].Role != users[j].Role {
return nil, err return true
} }
users = append(users, everyone) if users[i].Name == Everyone || users[j].Name == Everyone {
return users[i].Name != Everyone
} else if string(users[i].Role) < string(users[j].Role) {
return true
}
return users[i].Name < users[j].Name
})*/
return users, nil return users, nil
} }
// User returns the user with the given username if it exists, or ErrNotFound otherwise. // User returns the user with the given username if it exists, or ErrNotFound otherwise.
// You may also pass Everyone to retrieve the anonymous user and its Grant list. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
func (a *Manager) User(username string) (*User, error) { func (a *Manager) User(username string) (*User, error) {
if username == Everyone {
return a.everyoneUser()
}
rows, err := a.db.Query(selectUserByNameQuery, username) rows, err := a.db.Query(selectUserByNameQuery, username)
if err != nil { if err != nil {
return nil, err return nil, err
@ -446,19 +469,6 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
return user, nil return user, nil
} }
func (a *Manager) everyoneUser() (*User, error) {
grants, err := a.readGrants(Everyone)
if err != nil {
return nil, err
}
return &User{
Name: Everyone,
Hash: "",
Role: RoleAnonymous,
Grants: grants,
}, nil
}
func (a *Manager) readGrants(username string) ([]Grant, error) { func (a *Manager) readGrants(username string) ([]Grant, error) {
rows, err := a.db.Query(selectUserAccessQuery, username) rows, err := a.db.Query(selectUserAccessQuery, username)
if err != nil { if err != nil {

View file

@ -37,8 +37,8 @@ func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) {
require.Equal(t, user.RoleUser, ben.Role) require.Equal(t, user.RoleUser, ben.Role)
require.Equal(t, []user.Grant{ require.Equal(t, []user.Grant{
{"mytopic", true, true}, {"mytopic", true, true},
{"readme", true, false},
{"writeme", false, true}, {"writeme", false, true},
{"readme", true, false},
{"everyonewrite", false, false}, {"everyonewrite", false, false},
}, ben.Grants) }, ben.Grants)
@ -146,8 +146,8 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
require.Equal(t, user.RoleUser, ben.Role) require.Equal(t, user.RoleUser, ben.Role)
require.Equal(t, []user.Grant{ require.Equal(t, []user.Grant{
{"mytopic", true, true}, {"mytopic", true, true},
{"readme", true, false},
{"writeme", false, true}, {"writeme", false, true},
{"readme", true, false},
{"everyonewrite", false, false}, {"everyonewrite", false, false},
}, ben.Grants) }, ben.Grants)
@ -157,12 +157,12 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
require.Equal(t, "", everyone.Hash) require.Equal(t, "", everyone.Hash)
require.Equal(t, user.RoleAnonymous, everyone.Role) require.Equal(t, user.RoleAnonymous, everyone.Role)
require.Equal(t, []user.Grant{ require.Equal(t, []user.Grant{
{"announcements", true, false},
{"everyonewrite", true, true}, {"everyonewrite", true, true},
{"announcements", true, false},
}, everyone.Grants) }, everyone.Grants)
// Ben: Before revoking // Ben: Before revoking
require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) // Overwrite!
require.Nil(t, a.AllowAccess("ben", "readme", true, false)) require.Nil(t, a.AllowAccess("ben", "readme", true, false))
require.Nil(t, a.AllowAccess("ben", "writeme", false, true)) require.Nil(t, a.AllowAccess("ben", "writeme", false, true))
require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead)) require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead))

View file

@ -96,7 +96,7 @@ type Role string
// User roles // User roles
const ( const (
RoleAdmin = Role("admin") RoleAdmin = Role("admin") // Some queries have these values hardcoded!
RoleUser = Role("user") RoleUser = Role("user")
RoleAnonymous = Role("anonymous") RoleAnonymous = Role("anonymous")
) )