User database migration
This commit is contained in:
parent
bd2ec7b2af
commit
f4ffcebb14
3 changed files with 151 additions and 15 deletions
|
@ -40,7 +40,6 @@ import (
|
||||||
message cache duration
|
message cache duration
|
||||||
Keep 10000 messages or keep X days?
|
Keep 10000 messages or keep X days?
|
||||||
Attachment expiration based on plan
|
Attachment expiration based on plan
|
||||||
database migration
|
|
||||||
reserve topics
|
reserve topics
|
||||||
purge accounts that were not logged into in X
|
purge accounts that were not logged into in X
|
||||||
reset daily limits for users
|
reset daily limits for users
|
||||||
|
|
|
@ -24,8 +24,7 @@ const (
|
||||||
|
|
||||||
// Manager-related queries
|
// Manager-related queries
|
||||||
const (
|
const (
|
||||||
createAuthTablesQueries = `
|
createTablesQueriesNoTx = `
|
||||||
BEGIN;
|
|
||||||
CREATE TABLE IF NOT EXISTS plan (
|
CREATE TABLE IF NOT EXISTS plan (
|
||||||
id INT NOT NULL,
|
id INT NOT NULL,
|
||||||
code TEXT NOT NULL,
|
code TEXT NOT NULL,
|
||||||
|
@ -67,8 +66,8 @@ const (
|
||||||
version INT NOT NULL
|
version INT NOT NULL
|
||||||
);
|
);
|
||||||
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
|
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
|
||||||
COMMIT;
|
|
||||||
`
|
`
|
||||||
|
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
|
||||||
selectUserByNameQuery = `
|
selectUserByNameQuery = `
|
||||||
SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit
|
SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit
|
||||||
FROM user u
|
FROM user u
|
||||||
|
@ -130,9 +129,27 @@ const (
|
||||||
|
|
||||||
// Schema management queries
|
// Schema management queries
|
||||||
const (
|
const (
|
||||||
currentSchemaVersion = 1
|
currentSchemaVersion = 2
|
||||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
|
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
|
|
||||||
|
// 1 -> 2 (complex migration!)
|
||||||
|
migrate1To2RenameUserTableQueryNoTx = `
|
||||||
|
ALTER TABLE user RENAME TO user_old;
|
||||||
|
`
|
||||||
|
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||||
|
INSERT INTO user (user, pass, role)
|
||||||
|
SELECT user, pass, role FROM user_old;
|
||||||
|
|
||||||
|
INSERT INTO user_access (user_id, topic, read, write)
|
||||||
|
SELECT u.id, a.topic, a.read, a.write
|
||||||
|
FROM user u
|
||||||
|
JOIN access a ON u.user = a.user;
|
||||||
|
|
||||||
|
DROP TABLE access;
|
||||||
|
DROP TABLE user_old;
|
||||||
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager is an implementation of Manager. It stores users and access control list
|
// Manager is an implementation of Manager. It stores users and access control list
|
||||||
|
@ -159,7 +176,7 @@ func newManager(filename string, defaultRead, defaultWrite bool, tokenExpiryDura
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := setupAuthDB(db); err != nil {
|
if err := setupDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
|
@ -364,16 +381,21 @@ func (a *Manager) RemoveUser(username string) error {
|
||||||
if !AllowedUsername(username) {
|
if !AllowedUsername(username) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
if _, err := a.db.Exec(deleteUserAccessQuery, username); err != nil {
|
tx, err := a.db.Begin()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := a.db.Exec(deleteUserTokensQuery, username); err != nil {
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(deleteUserAccessQuery, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
|
if _, err := tx.Exec(deleteUserTokensQuery, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
if _, err := tx.Exec(deleteUserQuery, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Users returns a list of users. It always also returns the Everyone user ("*").
|
// Users returns a list of users. It always also returns the Everyone user ("*").
|
||||||
|
@ -567,11 +589,11 @@ func fromSQLWildcard(s string) string {
|
||||||
return strings.ReplaceAll(s, "%", "*")
|
return strings.ReplaceAll(s, "%", "*")
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthDB(db *sql.DB) error {
|
func setupDB(db *sql.DB) error {
|
||||||
// If 'schemaVersion' table does not exist, this must be a new database
|
// If 'schemaVersion' table does not exist, this must be a new database
|
||||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewAuthDB(db)
|
return setupNewDB(db)
|
||||||
}
|
}
|
||||||
defer rowsSV.Close()
|
defer rowsSV.Close()
|
||||||
|
|
||||||
|
@ -588,12 +610,14 @@ func setupAuthDB(db *sql.DB) error {
|
||||||
// Do migrations
|
// Do migrations
|
||||||
if schemaVersion == currentSchemaVersion {
|
if schemaVersion == currentSchemaVersion {
|
||||||
return nil
|
return nil
|
||||||
|
} else if schemaVersion == 1 {
|
||||||
|
return migrateFrom1(db)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
|
return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewAuthDB(db *sql.DB) error {
|
func setupNewDB(db *sql.DB) error {
|
||||||
if _, err := db.Exec(createAuthTablesQueries); err != nil {
|
if _, err := db.Exec(createTablesQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||||
|
@ -601,3 +625,28 @@ func setupNewAuthDB(db *sql.DB) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func migrateFrom1(db *sql.DB) error {
|
||||||
|
log.Info("Migrating user database schema: from 1 to 2")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil // Update this when a new version is added
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -350,8 +351,95 @@ func TestManager_EnqueueStats(t *testing.T) {
|
||||||
require.Equal(t, int64(2), u.Stats.Emails)
|
require.Equal(t, int64(2), u.Stats.Emails)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||||
|
filename := filepath.Join(t.TempDir(), "user.db")
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create "version 1" schema
|
||||||
|
_, err = db.Exec(`
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS user (
|
||||||
|
user TEXT NOT NULL PRIMARY KEY,
|
||||||
|
pass TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS access (
|
||||||
|
user TEXT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
read INT NOT NULL,
|
||||||
|
write INT NOT NULL,
|
||||||
|
PRIMARY KEY (topic, user)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||||
|
COMMIT;
|
||||||
|
`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Insert a bunch of users and ACL entries
|
||||||
|
_, err = db.Exec(`
|
||||||
|
BEGIN;
|
||||||
|
INSERT INTO user (user, pass, role) VALUES ('ben', '$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy', 'user');
|
||||||
|
INSERT INTO user (user, pass, role) VALUES ('phil', '$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C', 'admin');
|
||||||
|
INSERT INTO access (user, topic, read, write) VALUES ('ben', 'stats', 1, 1);
|
||||||
|
INSERT INTO access (user, topic, read, write) VALUES ('ben', 'secret', 1, 0);
|
||||||
|
INSERT INTO access (user, topic, read, write) VALUES ('*', 'stats', 1, 0);
|
||||||
|
COMMIT;
|
||||||
|
`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create manager to trigger migration
|
||||||
|
a := newTestManagerFromFile(t, filename, false, false, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
||||||
|
checkSchemaVersion(t, a.db)
|
||||||
|
|
||||||
|
users, err := a.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, len(users))
|
||||||
|
phil, ben, everyone := users[0], users[1], users[2]
|
||||||
|
|
||||||
|
require.Equal(t, "phil", phil.Name)
|
||||||
|
require.Equal(t, RoleAdmin, phil.Role)
|
||||||
|
require.Equal(t, 0, len(phil.Grants))
|
||||||
|
|
||||||
|
require.Equal(t, "ben", ben.Name)
|
||||||
|
require.Equal(t, RoleUser, ben.Role)
|
||||||
|
require.Equal(t, 2, len(ben.Grants))
|
||||||
|
require.Equal(t, "stats", ben.Grants[0].TopicPattern)
|
||||||
|
require.Equal(t, true, ben.Grants[0].AllowRead)
|
||||||
|
require.Equal(t, true, ben.Grants[0].AllowWrite)
|
||||||
|
require.Equal(t, "secret", ben.Grants[1].TopicPattern)
|
||||||
|
require.Equal(t, true, ben.Grants[1].AllowRead)
|
||||||
|
require.Equal(t, false, ben.Grants[1].AllowWrite)
|
||||||
|
|
||||||
|
require.Equal(t, Everyone, everyone.Name)
|
||||||
|
require.Equal(t, RoleAnonymous, everyone.Role)
|
||||||
|
require.Equal(t, 1, len(everyone.Grants))
|
||||||
|
require.Equal(t, "stats", everyone.Grants[0].TopicPattern)
|
||||||
|
require.Equal(t, true, everyone.Grants[0].AllowRead)
|
||||||
|
require.Equal(t, false, everyone.Grants[0].AllowWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||||
|
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
|
||||||
|
var schemaVersion int
|
||||||
|
require.Nil(t, rows.Scan(&schemaVersion))
|
||||||
|
require.Equal(t, currentSchemaVersion, schemaVersion)
|
||||||
|
require.Nil(t, rows.Close())
|
||||||
|
}
|
||||||
|
|
||||||
func newTestManager(t *testing.T, defaultRead, defaultWrite bool) *Manager {
|
func newTestManager(t *testing.T, defaultRead, defaultWrite bool) *Manager {
|
||||||
a, err := NewManager(filepath.Join(t.TempDir(), "db"), defaultRead, defaultWrite)
|
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultRead, defaultWrite, userTokenExpiryDuration, userStatsQueueWriterInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestManagerFromFile(t *testing.T, filename string, defaultRead, defaultWrite bool, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager {
|
||||||
|
a, err := newManager(filename, defaultRead, defaultWrite, tokenExpiryDuration, statsWriterInterval)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue