Add "last access" to access tokens
This commit is contained in:
parent
000bf27c87
commit
e596834096
15 changed files with 276 additions and 145 deletions
111
user/manager.go
111
user/manager.go
|
@ -10,6 +10,7 @@ import (
|
|||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -95,6 +96,8 @@ const (
|
|||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
|
@ -127,9 +130,9 @@ const (
|
|||
selectUserByTokenQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
JOIN user_token t on u.id = t.user_id
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE t.token = ? AND (t.expires = 0 OR t.expires >= ?)
|
||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||
`
|
||||
selectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
|
@ -218,16 +221,17 @@ const (
|
|||
AND topic = ?
|
||||
`
|
||||
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ? AND token = ?`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, expires) VALUES (?, ?, ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
|
@ -297,16 +301,17 @@ const (
|
|||
// in a SQLite database.
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
defaultAccess Permission // Default permission if no ACL matches
|
||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||
bcryptCost int // Makes testing easier
|
||||
defaultAccess Permission // Default permission if no ACL matches
|
||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
|
||||
bcryptCost int // Makes testing easier
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var _ Auther = (*Manager)(nil)
|
||||
|
||||
// NewManager creates a new Manager instance
|
||||
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) (*Manager, error) {
|
||||
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -321,9 +326,10 @@ func NewManager(filename, startupQueries string, defaultAccess Permission, bcryp
|
|||
db: db,
|
||||
defaultAccess: defaultAccess,
|
||||
statsQueue: make(map[string]*Stats),
|
||||
tokenQueue: make(map[string]*TokenUpdate),
|
||||
bcryptCost: bcryptCost,
|
||||
}
|
||||
go manager.userStatsQueueWriter(statsWriterInterval)
|
||||
go manager.asyncQueueWriter(queueWriterInterval)
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
|
@ -367,14 +373,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
|||
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, error) {
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
|
||||
token := util.RandomStringPrefix(tokenPrefix, tokenLength)
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(insertTokenQuery, userID, token, label, expires.Unix()); err != nil {
|
||||
access := time.Now()
|
||||
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := tx.Query(selectTokenCountQuery, userID)
|
||||
|
@ -400,9 +407,11 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token,
|
|||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
Expires: expires,
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: access,
|
||||
LastOrigin: origin,
|
||||
Expires: expires,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -437,20 +446,26 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
|||
}
|
||||
|
||||
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label string
|
||||
var expires int64
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &expires); err != nil {
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||
if err != nil {
|
||||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -521,7 +536,7 @@ func (a *Manager) ChangeSettings(user *User) error {
|
|||
|
||||
// ResetStats resets all user stats in the user database. This touches all users.
|
||||
func (a *Manager) ResetStats() error {
|
||||
a.mu.Lock()
|
||||
a.mu.Lock() // Includes database query to avoid races!
|
||||
defer a.mu.Unlock()
|
||||
if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
|
||||
return err
|
||||
|
@ -538,12 +553,23 @@ func (a *Manager) EnqueueStats(userID string, stats *Stats) {
|
|||
a.statsQueue[userID] = stats
|
||||
}
|
||||
|
||||
func (a *Manager) userStatsQueueWriter(interval time.Duration) {
|
||||
// EnqueueTokenUpdate adds the token update to a queue which writes out token access times
|
||||
// in batches at a regular interval
|
||||
func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.tokenQueue[tokenID] = update
|
||||
}
|
||||
|
||||
func (a *Manager) asyncQueueWriter(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
for range ticker.C {
|
||||
if err := a.writeUserStatsQueue(); err != nil {
|
||||
log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
|
||||
}
|
||||
if err := a.writeTokenUpdateQueue(); err != nil {
|
||||
log.Warn("User Manager: Writing token update queue failed: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -572,6 +598,31 @@ func (a *Manager) writeUserStatsQueue() error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (a *Manager) writeTokenUpdateQueue() error {
|
||||
a.mu.Lock()
|
||||
if len(a.tokenQueue) == 0 {
|
||||
a.mu.Unlock()
|
||||
log.Trace("User Manager: No token updates to commit")
|
||||
return nil
|
||||
}
|
||||
tokenQueue := a.tokenQueue
|
||||
a.tokenQueue = make(map[string]*TokenUpdate)
|
||||
a.mu.Unlock()
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
log.Debug("User Manager: Writing token update queue for %d token(s)", len(tokenQueue))
|
||||
for tokenID, update := range tokenQueue {
|
||||
log.Trace("User Manager: Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
|
||||
if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||
// permission. The user param may be nil to signal an anonymous user.
|
||||
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -139,7 +140,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
require.False(t, u.Deleted)
|
||||
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour))
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err = a.Authenticate("user", "pass")
|
||||
|
@ -397,7 +398,7 @@ func TestManager_Token_Valid(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// Create token for user
|
||||
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour))
|
||||
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
require.Equal(t, "some label", token.Label)
|
||||
|
@ -441,12 +442,12 @@ func TestManager_Token_Expire(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// Create tokens for user
|
||||
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
||||
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token1.Value)
|
||||
require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
|
||||
|
||||
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
||||
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token2.Value)
|
||||
require.NotEqual(t, token1.Value, token2.Value)
|
||||
|
@ -493,7 +494,7 @@ func TestManager_Token_Extend(t *testing.T) {
|
|||
require.Equal(t, errNoTokenProvided, err)
|
||||
|
||||
// Create token for user
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
|
||||
|
@ -520,7 +521,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||
baseTime := time.Now().Add(24 * time.Hour)
|
||||
tokens := make([]string, 0)
|
||||
for i := 0; i < 12; i++ {
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
tokens = append(tokens, token.Value)
|
||||
|
@ -624,6 +625,61 @@ func TestManager_ChangeSettings(t *testing.T) {
|
|||
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
||||
}
|
||||
|
||||
func TestManager_Tier_Create(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
|
||||
// Create tier and user
|
||||
require.Nil(t, a.CreateTier(&Tier{
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessageLimit: 123,
|
||||
MessageExpiryDuration: 86400 * time.Second,
|
||||
EmailLimit: 32,
|
||||
ReservationLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
AttachmentBandwidthLimit: 21474836480,
|
||||
}))
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser))
|
||||
require.Nil(t, a.ChangeTier("phil", "pro"))
|
||||
|
||||
ti, err := a.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err := a.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// These are populated by different SQL queries
|
||||
require.Equal(t, ti, u.Tier)
|
||||
|
||||
// Fields
|
||||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||
require.Equal(t, "pro", ti.Code)
|
||||
require.Equal(t, "Pro", ti.Name)
|
||||
require.Equal(t, int64(123), ti.MessageLimit)
|
||||
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
||||
require.Equal(t, int64(32), ti.EmailLimit)
|
||||
require.Equal(t, int64(2), ti.ReservationLimit)
|
||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
||||
}
|
||||
|
||||
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
|
||||
require.Nil(t, a.CreateTier(&Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
}))
|
||||
|
||||
ti, err := a.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "ti_123", ti.ID)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "user.db")
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
|
|
|
@ -4,6 +4,7 @@ package user
|
|||
import (
|
||||
"errors"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
@ -46,9 +47,17 @@ type Auther interface {
|
|||
|
||||
// Token represents a user token, including expiry date
|
||||
type Token struct {
|
||||
Value string
|
||||
Label string
|
||||
Expires time.Time
|
||||
Value string
|
||||
Label string
|
||||
LastAccess time.Time
|
||||
LastOrigin netip.Addr
|
||||
Expires time.Time
|
||||
}
|
||||
|
||||
// TokenUpdate holds information about the last access time and origin IP address of a token
|
||||
type TokenUpdate struct {
|
||||
LastAccess time.Time
|
||||
LastOrigin netip.Addr
|
||||
}
|
||||
|
||||
// Prefs represents a user's configuration settings
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue