diff --git a/server/server.go b/server/server.go index 0477c4b..fb795bb 100644 --- a/server/server.go +++ b/server/server.go @@ -45,10 +45,6 @@ import ( "account topic" sync mechanism purge accounts that were not logged into in X reset daily limits for users - max token issue limit - user db startup queries -> foreign keys - UI - - Feature flag for "reserve topic" feature Sync: - "mute" setting - figure out what settings are "web" or "phone" diff --git a/user/manager.go b/user/manager.go index 0aae865..d7e069c 100644 --- a/user/manager.go +++ b/user/manager.go @@ -15,16 +15,18 @@ import ( ) const ( - tokenLength = 32 bcryptCost = 10 intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost userStatsQueueWriterInterval = 33 * time.Second - userTokenExpiryDuration = 72 * time.Hour + tokenLength = 32 + tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much + tokenMaxCount = 10 // Only keep this many tokens in the table per user ) var ( errNoTokenProvided = errors.New("no token provided") errTopicOwnedByOthers = errors.New("topic owned by others") + errNoRows = errors.New("no rows found") ) // Manager-related queries @@ -139,7 +141,7 @@ const ( ORDER BY a_user.topic ` selectOtherAccessCountQuery = ` - SELECT count(*) + SELECT COUNT(*) FROM user_access WHERE (topic = ? OR ? LIKE topic) AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) @@ -148,10 +150,22 @@ const ( deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` + selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` + deleteExcessTokensQuery = ` + DELETE FROM user_token + WHERE (user_id, token) NOT IN ( + SELECT user_id, token + FROM user_token + WHERE user_id = (SELECT id FROM user WHERE user = ?) + ORDER BY expires DESC + LIMIT ? + ) +; + ` ) // Schema management queries @@ -182,22 +196,21 @@ const ( // Manager is an implementation of Manager. It stores users and access control list // in a SQLite database. type Manager struct { - db *sql.DB - defaultAccess Permission // Default permission if no ACL matches - statsQueue map[string]*User // Username -> User, for "unimportant" user updates - tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended - mu sync.Mutex + db *sql.DB + defaultAccess Permission // Default permission if no ACL matches + statsQueue map[string]*User // Username -> User, for "unimportant" user updates + mu sync.Mutex } var _ Auther = (*Manager)(nil) // NewManager creates a new Manager instance func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) { - return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) + return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval) } // NewManager creates a new Manager instance -func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) { +func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -209,10 +222,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token return nil, err } manager := &Manager{ - db: db, - defaultAccess: defaultAccess, - statsQueue: make(map[string]*User), - tokenExpiryInterval: tokenExpiryDuration, + db: db, + defaultAccess: defaultAccess, + statsQueue: make(map[string]*User), } go manager.userStatsQueueWriter(statsWriterInterval) return manager, nil @@ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { } // CreateToken generates a random token for the given user and returns it. The token expires -// after a fixed duration unless ExtendToken is called. +// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the +// given user, if there are too many of them. func (a *Manager) CreateToken(user *User) (*Token, error) { - token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration) - if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { + token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration) + tx, err := a.db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { + return nil, err + } + rows, err := tx.Query(selectTokenCountQuery, user.Name) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, errNoRows + } + var tokenCount int + if err := rows.Scan(&tokenCount); err != nil { + return nil, err + } + if tokenCount >= tokenMaxCount { + // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup + // on two indices, whereas the query below is a full table scan. + if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil { + return nil, err + } + } + if err := tx.Commit(); err != nil { return nil, err } return &Token{ @@ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) { if user.Token == "" { return nil, errNoTokenProvided } - newExpires := time.Now().Add(userTokenExpiryDuration) + newExpires := time.Now().Add(tokenExpiryDuration) if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil { return nil, err } @@ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error { } defer rows.Close() if !rows.Next() { - return errors.New("no rows found") + return errNoRows } var otherCount int if err := rows.Scan(&otherCount); err != nil { diff --git a/user/manager_test.go b/user/manager_test.go index 8845372..611c8df 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) { require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix()) } +func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { + a := newTestManager(t, PermissionDenyAll) + require.Nil(t, a.AddUser("ben", "ben", RoleUser)) + + // Try to extend token for user without token + u, err := a.User("ben") + require.Nil(t, err) + + // Tokens + baseTime := time.Now().Add(24 * time.Hour) + tokens := make([]string, 0) + for i := 0; i < 12; i++ { + token, err := a.CreateToken(u) + require.Nil(t, err) + require.NotEmpty(t, token.Value) + tokens = append(tokens, token.Value) + + // Manually modify expiry date to avoid sorting issues (this is a hack) + _, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value) + require.Nil(t, err) + } + + _, err = a.AuthenticateToken(tokens[0]) + require.Equal(t, ErrUnauthenticated, err) + + _, err = a.AuthenticateToken(tokens[1]) + require.Equal(t, ErrUnauthenticated, err) + + for i := 2; i < 12; i++ { + userWithToken, err := a.AuthenticateToken(tokens[i]) + require.Nil(t, err, "token[%d]=%s failed", i, tokens[i]) + require.Equal(t, "ben", userWithToken.Name) + require.Equal(t, tokens[i], userWithToken.Token) + } + + var count int + rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`) + require.Nil(t, err) + require.True(t, rows.Next()) + require.Nil(t, rows.Scan(&count)) + require.Equal(t, 10, count) +} + func TestManager_EnqueueStats(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) + a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) { } func TestManager_ChangeSettings(t *testing.T) { - a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) + a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) require.Nil(t, err) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) @@ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) { require.Nil(t, err) // Create manager to trigger migration - a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval) + a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval) checkSchemaVersion(t, a.db) users, err := a.Users() @@ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { } func newTestManager(t *testing.T, defaultAccess Permission) *Manager { - return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) + return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval) } -func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { - a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval) +func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager { + a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval) require.Nil(t, err) return a }