Self-review, round 2

This commit is contained in:
binwiederhier 2023-02-09 15:24:12 -05:00
parent bcb22d8d4c
commit e6bb5f484c
24 changed files with 288 additions and 183 deletions

View file

@ -1,3 +1,4 @@
// Package user deals with authentication and authorization against topics
package user
import (
@ -28,7 +29,7 @@ const (
tokenPrefix = "tk_"
tokenLength = 32
tokenMaxCount = 20 // Only keep this many tokens in the table per user
tagManager = "user_manager"
tag = "user_manager"
)
// Default constants that may be overridden by configs
@ -47,7 +48,7 @@ var (
const (
createTablesQueriesNoTx = `
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
@ -89,7 +90,7 @@ const (
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
@ -109,7 +110,7 @@ const (
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -121,7 +122,7 @@ const (
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
@ -151,12 +152,12 @@ const (
`
insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
FROM user
SELECT user
FROM user
ORDER BY
CASE role
WHEN 'admin' THEN 1
@ -166,7 +167,7 @@ const (
`
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
@ -174,15 +175,15 @@ const (
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
ON CONFLICT (user_id, topic)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
`
selectUserAccessQuery = `
SELECT topic, read, write
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY write DESC, read DESC, topic
`
selectUserReservationsQuery = `
@ -201,9 +202,9 @@ const (
selectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
AND topic = ?
AND topic = ?
`
selectOtherAccessCountQuery = `
SELECT COUNT(*)
@ -213,13 +214,13 @@ const (
`
deleteAllAccessQuery = `DELETE FROM user_access`
deleteUserAccessQuery = `
DELETE FROM user_access
DELETE FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
deleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
`
@ -239,7 +240,7 @@ const (
SELECT user_id, token
FROM user_token
WHERE user_id = ?
ORDER BY expires DESC
ORDER BY expires DESC
LIMIT ?
)
`
@ -249,7 +250,7 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
updateTierQuery = `
UPDATE tier
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
WHERE code = ?
`
@ -272,7 +273,7 @@ const (
deleteTierQuery = `DELETE FROM tier WHERE code = ?`
updateBillingQuery = `
UPDATE user
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
@ -291,7 +292,7 @@ const (
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
@ -305,6 +306,12 @@ const (
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
}
)
// Manager is an implementation of Manager. It stores users and access control list
// in a SQLite database.
type Manager struct {
@ -350,15 +357,15 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
}
user, err := a.User(username)
if err != nil {
log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if user.Deleted {
log.Tag(tagManager).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
return nil, ErrUnauthenticated
}
return user, nil
@ -372,7 +379,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
}
user, err := a.userByToken(token)
if err != nil {
log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
return nil, ErrUnauthenticated
}
user.Token = token
@ -532,12 +539,12 @@ func (a *Manager) RemoveDeletedUsers() error {
}
// ChangeSettings persists the user settings
func (a *Manager) ChangeSettings(user *User) error {
prefs, err := json.Marshal(user.Prefs)
func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
b, err := json.Marshal(prefs)
if err != nil {
return err
}
if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
return err
}
return nil
@ -554,9 +561,9 @@ func (a *Manager) ResetStats() error {
return nil
}
// EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
// batches at a regular interval
func (a *Manager) EnqueueStats(userID string, stats *Stats) {
func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
a.mu.Lock()
defer a.mu.Unlock()
a.statsQueue[userID] = stats
@ -574,10 +581,10 @@ func (a *Manager) asyncQueueWriter(interval time.Duration) {
ticker := time.NewTicker(interval)
for range ticker.C {
if err := a.writeUserStatsQueue(); err != nil {
log.Tag(tagManager).Err(err).Warn("Writing user stats queue failed")
log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
}
if err := a.writeTokenUpdateQueue(); err != nil {
log.Tag(tagManager).Err(err).Warn("Writing token update queue failed")
log.Tag(tag).Err(err).Warn("Writing token update queue failed")
}
}
}
@ -586,7 +593,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.mu.Lock()
if len(a.statsQueue) == 0 {
a.mu.Unlock()
log.Tag(tagManager).Trace("No user stats updates to commit")
log.Tag(tag).Trace("No user stats updates to commit")
return nil
}
statsQueue := a.statsQueue
@ -597,10 +604,10 @@ func (a *Manager) writeUserStatsQueue() error {
return err
}
defer tx.Rollback()
log.Tag(tagManager).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
for userID, update := range statsQueue {
log.
Tag(tagManager).
Tag(tag).
Fields(log.Context{
"user_id": userID,
"messages_count": update.Messages,
@ -618,7 +625,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
a.mu.Lock()
if len(a.tokenQueue) == 0 {
a.mu.Unlock()
log.Tag(tagManager).Trace("No token updates to commit")
log.Tag(tag).Trace("No token updates to commit")
return nil
}
tokenQueue := a.tokenQueue
@ -629,9 +636,9 @@ func (a *Manager) writeTokenUpdateQueue() error {
return err
}
defer tx.Rollback()
log.Tag(tagManager).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue {
log.Tag(tagManager).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
return err
}
@ -718,7 +725,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
return err
}
defer tx.Rollback()
if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
return err
}
if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
@ -1012,7 +1019,6 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
// CheckAllowAccess tests if a user may create an access control entry for the given topic.
// If there are any ACL entries that are not owned by the user, an error is returned.
// FIXME is this the same as HasReservation?
func (a *Manager) CheckAllowAccess(username string, topic string) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
return ErrInvalidArgument
@ -1275,10 +1281,18 @@ func setupDB(db *sql.DB) error {
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion == 1 {
return migrateFrom1(db)
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewDB(db *sql.DB) error {
@ -1292,7 +1306,7 @@ func setupNewDB(db *sql.DB) error {
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tagManager).Info("Migrating user database schema: from 1 to 2")
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
@ -1339,7 +1353,7 @@ func migrateFrom1(db *sql.DB) error {
if err := tx.Commit(); err != nil {
return err
}
return nil // Update this when a new version is added
return nil
}
func nullString(s string) sql.NullString {